From 97926c58b382b20ef58ba53d4ab8a5a683d04e2f Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Tue, 18 Nov 2025 17:36:16 -0800 Subject: [PATCH 1/3] Initial fixes --- .gitignore | 2 + .../agent_framework_azurefunctions/_app.py | 130 +++++++++----- .../_constants.py | 7 +- .../_durable_agent_state.py | 164 ++++++++---------- .../_entities.py | 24 ++- .../agent_framework_azurefunctions/_models.py | 8 + .../integration_tests/test_01_single_agent.py | 6 +- .../packages/azurefunctions/tests/test_app.py | 37 ++-- .../azurefunctions/tests/test_entities.py | 45 +++-- 9 files changed, 231 insertions(+), 192 deletions(-) diff --git a/.gitignore b/.gitignore index 162eec9ba7..f0f8c09495 100644 --- a/.gitignore +++ b/.gitignore @@ -204,6 +204,8 @@ agents.md # AI .claude/ WARP.md +**/memory-bank/ +**/projectBrief.md # Azurite storage emulator files */__azurite_db_blob__.json diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 7898cf6a78..96a68770cb 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -19,9 +19,12 @@ from ._constants import ( DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS, - RESPONSE_FORMAT_JSON, - RESPONSE_FORMAT_TEXT, + MIMETYPE_APPLICATION_JSON, + MIMETYPE_TEXT_PLAIN, + REQUEST_RESPONSE_FORMAT_JSON, + REQUEST_RESPONSE_FORMAT_TEXT, THREAD_ID_FIELD, + THREAD_ID_HEADER, WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER, ) @@ -316,11 +319,11 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien """ logger.debug(f"[HTTP Trigger] Received request on route: /api/agents/{agent_name}/run") - response_format: str = RESPONSE_FORMAT_JSON + request_response_format: str = REQUEST_RESPONSE_FORMAT_JSON thread_id: str | None = None try: - req_body, message, response_format = self._parse_incoming_request(req) + req_body, message, request_response_format = self._parse_incoming_request(req) thread_id = self._resolve_thread_id(req=req, req_body=req_body) wait_for_response = self._should_wait_for_response(req=req, req_body=req_body) @@ -333,7 +336,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload={"error": "Message is required"}, status_code=400, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) @@ -350,6 +353,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien message, thread_id, correlation_id, + request_response_format, ) logger.debug("Signalling entity %s with request: %s", entity_instance_id, run_request) await client.signal_entity(entity_instance_id, "run_agent", run_request) @@ -369,7 +373,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload=result, status_code=200 if result.get("status") == "success" else 500, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) @@ -382,7 +386,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload=accepted_response, status_code=202, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) @@ -391,7 +395,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload={"error": str(exc)}, status_code=exc.status_code, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) except ValueError as exc: @@ -399,7 +403,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload={"error": "Invalid JSON"}, status_code=400, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) except Exception as exc: @@ -407,7 +411,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload={"error": str(exc)}, status_code=500, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) @@ -465,7 +469,7 @@ def health_check(req: func.HttpRequest) -> func.HttpResponse: return func.HttpResponse( json.dumps({"status": "healthy", "agents": agent_info, "agent_count": len(self.agents)}), status_code=200, - mimetype="application/json", + mimetype=MIMETYPE_APPLICATION_JSON, ) _ = health_check @@ -579,31 +583,58 @@ async def _poll_entity_for_response( return result - async def _build_timeout_result(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]: - """Create the timeout response.""" - return { - "response": "Agent is still processing or timed out...", + def _build_response_payload( + self, + *, + response: str | None, + message: str, + thread_id: str, + status: str, + correlation_id: str, + extra_fields: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Create a consistent response structure and allow optional extra fields.""" + payload = { + "response": response, "message": message, THREAD_ID_FIELD: thread_id, - "status": "timeout", - "correlationId": correlation_id, + "status": status, + "correlation_id": correlation_id, } + if extra_fields: + payload.update(extra_fields) + return payload + + async def _build_timeout_result(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]: + """Create the timeout response.""" + return self._build_response_payload( + response="Agent is still processing or timed out...", + message=message, + thread_id=thread_id, + status="timeout", + correlation_id=correlation_id, + ) def _build_success_result( self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: DurableAgentState ) -> dict[str, Any]: """Build the success result returned to the HTTP caller.""" - return { - "response": response_data.get("content"), - "message": message, - THREAD_ID_FIELD: thread_id, - "status": "success", - "message_count": response_data.get("message_count", state.message_count), - "correlationId": correlation_id, - } + return self._build_response_payload( + response=response_data.get("content"), + message=message, + thread_id=thread_id, + status="success", + correlation_id=correlation_id, + extra_fields={"message_count": response_data.get("message_count", state.message_count)}, + ) def _build_request_data( - self, req_body: dict[str, Any], message: str, thread_id: str, correlation_id: str + self, + req_body: dict[str, Any], + message: str, + thread_id: str, + correlation_id: str, + request_response_format: str, ) -> dict[str, Any]: """Create the durable entity request payload.""" enable_tool_calls_value = req_body.get("enable_tool_calls") @@ -612,6 +643,7 @@ def _build_request_data( return RunRequest( message=message, role=req_body.get("role"), + request_response_format=request_response_format, response_format=req_body.get("response_format"), enable_tool_calls=enable_tool_calls, thread_id=thread_id, @@ -620,23 +652,23 @@ def _build_request_data( def _build_accepted_response(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]: """Build the response returned when not waiting for completion.""" - return { - "response": "Agent request accepted", - "message": message, - THREAD_ID_FIELD: thread_id, - "status": "accepted", - "correlationId": correlation_id, - } + return self._build_response_payload( + response="Agent request accepted", + message=message, + thread_id=thread_id, + status="accepted", + correlation_id=correlation_id, + ) def _create_http_response( self, payload: dict[str, Any] | str, status_code: int, - response_format: str, + request_response_format: str, thread_id: str | None, ) -> func.HttpResponse: """Create the HTTP response using helper serializers for clarity.""" - if response_format == RESPONSE_FORMAT_TEXT: + if request_response_format == REQUEST_RESPONSE_FORMAT_TEXT: return self._build_plain_text_response(payload=payload, status_code=status_code, thread_id=thread_id) return self._build_json_response(payload=payload, status_code=status_code) @@ -649,13 +681,13 @@ def _build_plain_text_response( ) -> func.HttpResponse: """Return a plain-text response with optional thread identifier header.""" body_text = payload if isinstance(payload, str) else self._convert_payload_to_text(payload) - headers = {"x-ms-thread-id": thread_id} if thread_id is not None else None - return func.HttpResponse(body_text, status_code=status_code, mimetype="text/plain", headers=headers) + headers = {THREAD_ID_HEADER: thread_id} if thread_id is not None else None + return func.HttpResponse(body_text, status_code=status_code, mimetype=MIMETYPE_TEXT_PLAIN, headers=headers) def _build_json_response(self, payload: dict[str, Any] | str, status_code: int) -> func.HttpResponse: """Return the JSON response, serializing dictionaries as needed.""" body_json = payload if isinstance(payload, str) else json.dumps(payload) - return func.HttpResponse(body_json, status_code=status_code, mimetype="application/json") + return func.HttpResponse(body_json, status_code=status_code, mimetype=MIMETYPE_APPLICATION_JSON) def _convert_payload_to_text(self, payload: dict[str, Any]) -> str: """Convert a structured payload into a human-readable text response.""" @@ -701,10 +733,12 @@ def _parse_incoming_request(self, req: func.HttpRequest) -> tuple[dict[str, Any] normalized_content_type = self._extract_content_type(headers) body_parser, body_format = self._select_body_parser(normalized_content_type) prefers_json = self._accepts_json_response(headers) - response_format = self._select_response_format(body_format=body_format, prefers_json=prefers_json) + request_response_format = self._select_request_response_format( + body_format=body_format, prefers_json=prefers_json + ) req_body, message = body_parser(req) - return req_body, message, response_format + return req_body, message, request_response_format def _extract_normalized_headers(self, req: func.HttpRequest) -> dict[str, str]: """Create a lowercase header mapping from the incoming request.""" @@ -727,9 +761,9 @@ def _select_body_parser( normalized_content_type: str, ) -> tuple[Callable[[func.HttpRequest], tuple[dict[str, Any], str]], str]: """Choose the body parser and declared body format.""" - if normalized_content_type in {"application/json"} or normalized_content_type.endswith("+json"): - return self._parse_json_body, RESPONSE_FORMAT_JSON - return self._parse_text_body, RESPONSE_FORMAT_TEXT + if normalized_content_type in {MIMETYPE_APPLICATION_JSON} or normalized_content_type.endswith("+json"): + return self._parse_json_body, REQUEST_RESPONSE_FORMAT_JSON + return self._parse_text_body, REQUEST_RESPONSE_FORMAT_TEXT @staticmethod def _accepts_json_response(headers: dict[str, str]) -> bool: @@ -740,16 +774,16 @@ def _accepts_json_response(headers: dict[str, str]) -> bool: for value in accept_header.split(","): media_type = value.split(";")[0].strip().lower() - if media_type == "application/json": + if media_type == MIMETYPE_APPLICATION_JSON: return True return False @staticmethod - def _select_response_format(body_format: str, prefers_json: bool) -> str: + def _select_request_response_format(body_format: str, prefers_json: bool) -> str: """Combine body format and accept preference to determine response format.""" - if body_format == RESPONSE_FORMAT_JSON or prefers_json: - return RESPONSE_FORMAT_JSON - return RESPONSE_FORMAT_TEXT + if body_format == REQUEST_RESPONSE_FORMAT_JSON or prefers_json: + return REQUEST_RESPONSE_FORMAT_JSON + return REQUEST_RESPONSE_FORMAT_TEXT @staticmethod def _parse_json_body(req: func.HttpRequest) -> tuple[dict[str, Any], str]: diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py index 3a8dfa50e4..291235b690 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py @@ -3,11 +3,14 @@ """Constants used across the Azure Functions agent framework.""" # Response format constants -RESPONSE_FORMAT_JSON: str = "json" -RESPONSE_FORMAT_TEXT: str = "text" +REQUEST_RESPONSE_FORMAT_JSON: str = "json" +REQUEST_RESPONSE_FORMAT_TEXT: str = "text" +MIMETYPE_APPLICATION_JSON: str = "application/json" +MIMETYPE_TEXT_PLAIN: str = "text/plain" # Field and header names THREAD_ID_FIELD: str = "thread_id" +THREAD_ID_HEADER: str = "x-ms-thread-id" WAIT_FOR_RESPONSE_FIELD: str = "wait_for_response" WAIT_FOR_RESPONSE_HEADER: str = "x-ms-wait-for-response" diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py index b8fd14d2d8..5f48f260eb 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py @@ -31,6 +31,7 @@ import json from datetime import datetime, timezone +from enum import Enum from typing import Any from agent_framework import ( @@ -51,7 +52,6 @@ ) from dateutil import parser as date_parser # type: ignore -from ._constants import RESPONSE_FORMAT_JSON, RESPONSE_FORMAT_TEXT from ._models import RunRequest, _serialize_response_format @@ -146,26 +146,26 @@ class DurableAgentStateData: conforming to the durable-agent-entity-state.json schema structure. Attributes: - conversationHistory: Ordered list of conversation entries (requests and responses) - extensionData: Optional dictionary for custom metadata (not part of core schema) + conversation_history: Ordered list of conversation entries (requests and responses) + extension_data: Optional dictionary for custom metadata (not part of core schema) """ - conversationHistory: list[DurableAgentStateEntry] - extensionData: dict[str, Any] | None + conversation_history: list[DurableAgentStateEntry] + extension_data: dict[str, Any] | None def __init__( self, - conversationHistory: list[DurableAgentStateEntry] | None = None, - extensionData: dict[str, Any] | None = None, + conversation_history: list[DurableAgentStateEntry] | None = None, + extension_data: dict[str, Any] | None = None, ) -> None: """Initialize the data container. Args: - conversationHistory: Initial conversation history (defaults to empty list) - extensionData: Optional custom metadata + conversation_history: Initial conversation history (defaults to empty list) + extension_data: Optional custom metadata """ - self.conversationHistory = conversationHistory or [] - self.extensionData = extensionData + self.conversation_history = conversation_history or [] + self.extension_data = extension_data class DurableAgentState: @@ -212,7 +212,7 @@ def __init__(self, schema_version: str = "1.0.0"): def to_dict(self) -> dict[str, Any]: # Serialize conversationHistory serialized_history: list[dict[str, Any]] = [] - for entry in self.data.conversationHistory: + for entry in self.data.conversation_history: # Properly serialize each entry to a dictionary if hasattr(entry, "to_dict"): serialized_history.append(entry.to_dict()) @@ -244,8 +244,8 @@ def from_dict(cls, obj: dict[str, Any]) -> DurableAgentState: # Deserialize the data dict into DurableAgentStateData if isinstance(data_dict, dict): instance.data = DurableAgentStateData( - conversationHistory=data_dict.get("conversationHistory", []), - extensionData=data_dict.get("extensionData"), + conversation_history=data_dict.get("conversationHistory", []), + extension_data=data_dict.get("extensionData"), ) return instance @@ -284,13 +284,13 @@ def restore_state(self, state: dict[str, Any]) -> None: # Already an object deserialized_history.append(entry_dict) # type: ignore - self.data.conversationHistory = deserialized_history - self.data.extensionData = data_dict.get("extensionData") + self.data.conversation_history = deserialized_history + self.data.extension_data = data_dict.get("extensionData") @property def message_count(self) -> int: """Get the count of conversation entries (requests + responses).""" - return len(self.data.conversationHistory) + return len(self.data.conversation_history) def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: """Try to get an agent response by correlation ID. @@ -311,29 +311,26 @@ def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: None otherwise """ # Search through conversation history for a response with this correlationId - for entry in self.data.conversationHistory: - if ( - hasattr(entry, "correlation_id") - and entry.correlation_id == correlation_id - and isinstance(entry, DurableAgentStateResponse) - ): + for entry in self.data.conversation_history: + if entry.correlation_id == correlation_id and isinstance(entry, DurableAgentStateResponse): # Found the entry, extract response data # Get the text content from assistant messages only - content = "" - for message in entry.messages: - if hasattr(message, "role") and message.role == "assistant" and hasattr(message, "contents"): - # Extract text from message contents - for content_item in message.contents: - # Handle both object and dict forms - if hasattr(content_item, "text") and content_item.text: - content += content_item.text - elif isinstance(content_item, dict) and content_item.get("text"): - content += content_item.get("text", "") + content = "".join(message.text for message in entry.messages if message.text is not None) return {"content": content, "message_count": self.message_count, "correlationId": correlation_id} return None +class DurableAgentStateEntryJsonType(str, Enum): + """Enum for conversation history entry types. + + Discriminator values for the $type field in DurableAgentStateEntry objects. + """ + + REQUEST = "request" + RESPONSE = "response" + + class DurableAgentStateEntry: """Base class for conversation history entries (requests and responses). @@ -360,7 +357,7 @@ class DurableAgentStateEntry: usage: Token usage statistics - only for response entries """ - json_type: str + json_type: DurableAgentStateEntryJsonType correlation_id: str | None created_at: datetime messages: list[DurableAgentStateMessage] @@ -375,7 +372,7 @@ class DurableAgentStateEntry: def __init__( self, - json_type: str, + json_type: DurableAgentStateEntryJsonType, correlation_id: str | None, created_at: datetime, messages: list[DurableAgentStateMessage], @@ -405,13 +402,13 @@ def to_dict(self) -> dict[str, Any]: "createdAt": created_at_value.isoformat() if isinstance(created_at_value, datetime) else created_at_value, "messages": [m.to_dict() if hasattr(m, "to_dict") else m for m in self.messages], } - if self.json_type == "request": + if self.json_type == DurableAgentStateEntryJsonType.REQUEST: # Only include responseType and responseSchema if they're not None if self.response_type is not None: data["responseType"] = self.response_type if self.response_schema is not None: data["responseSchema"] = self.response_schema - elif self.json_type == "response": + elif self.json_type == DurableAgentStateEntryJsonType.RESPONSE: # Only include usage if it's not None if self.usage is not None: data["usage"] = self.usage.to_dict() @@ -457,21 +454,17 @@ class DurableAgentStateRequest(DurableAgentStateEntry): json_type: Always "request" for this class """ - response_type: str | None = None - response_schema: dict[str, Any] | None = None - def __init__( self, correlation_id: str | None, created_at: datetime, messages: list[DurableAgentStateMessage], - json_type: str, extension_data: dict[str, Any] | None = None, response_type: str | None = None, response_schema: dict[str, Any] | None = None, ) -> None: super().__init__( - json_type=json_type, + json_type=DurableAgentStateEntryJsonType.REQUEST, correlation_id=correlation_id, created_at=created_at, messages=messages, @@ -479,8 +472,6 @@ def __init__( response_type=response_type, response_schema=response_schema, ) - self.response_type = response_type - self.response_schema = response_schema @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateRequest: @@ -498,7 +489,6 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateRequest: messages.append(msg_dict) return cls( - json_type=data.get("$type", "request"), correlation_id=data.get("correlationId", ""), created_at=created_at, messages=messages, @@ -508,22 +498,15 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateRequest: ) @staticmethod - def from_run_request(content: RunRequest) -> DurableAgentStateRequest: + def from_run_request(request: RunRequest) -> DurableAgentStateRequest: # Determine response_type based on response_format - response_type = None - if content.response_format is not None: - response_type = ( - RESPONSE_FORMAT_TEXT if isinstance(content.response_format, TextContent) else RESPONSE_FORMAT_JSON - ) - return DurableAgentStateRequest( - correlation_id=content.correlation_id, - messages=[DurableAgentStateMessage.from_chat_message(content)], + correlation_id=request.correlation_id, + messages=[DurableAgentStateMessage.from_run_request(request)], created_at=datetime.now(tz=timezone.utc), - json_type="request", - extension_data=content.extension_data if hasattr(content, "extensionData") else None, - response_type=response_type, - response_schema=_serialize_response_format(content.response_format), + extension_data=request.extension_data, + response_type=request.request_response_format, + response_schema=_serialize_response_format(request.response_format), ) @@ -548,7 +531,6 @@ class DurableAgentStateResponse(DurableAgentStateEntry): def __init__( self, - json_type: str, correlation_id: str, created_at: datetime, messages: list[DurableAgentStateMessage], @@ -557,7 +539,7 @@ def __init__( is_error: bool = False, ) -> None: super().__init__( - json_type=json_type, + json_type=DurableAgentStateEntryJsonType.RESPONSE, correlation_id=correlation_id, created_at=created_at, messages=messages, @@ -590,7 +572,6 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateResponse: usage = usage_dict return cls( - json_type=data.get("$type", "response"), correlation_id=data.get("correlationId", ""), created_at=created_at, messages=messages, @@ -603,7 +584,7 @@ def from_run_response(correlation_id: str, response: AgentRunResponse) -> Durabl """Creates a DurableAgentStateResponse from an AgentRunResponse.""" # Use response.created_at if available, otherwise use current time created_at = datetime.now(tz=timezone.utc) - if hasattr(response, "created_at") and response.created_at is not None: + if response.created_at is not None: # AgentRunResponse.created_at is a string, parse it to datetime if isinstance(response.created_at, str): created_at = date_parser.parse(response.created_at) @@ -611,13 +592,10 @@ def from_run_response(correlation_id: str, response: AgentRunResponse) -> Durabl created_at = response.created_at return DurableAgentStateResponse( - json_type="response", correlation_id=correlation_id, created_at=created_at, messages=[DurableAgentStateMessage.from_chat_message(m) for m in response.messages], - usage=DurableAgentStateUsage.from_usage(response.usage) - if hasattr(response, "usage") and response.usage - else None, + usage=DurableAgentStateUsage.from_usage(response.usage_details), ) def to_run_response(self) -> Any: @@ -766,36 +744,46 @@ def text(self) -> str: return "".join(text_parts) @staticmethod - def from_chat_message(content: Any) -> DurableAgentStateMessage: - """Converts a ChatMessage or RunRequest from the agent framework to a DurableAgentStateMessage. + def from_run_request(request: RunRequest) -> DurableAgentStateMessage: + """Converts a RunRequest from the agent framework to a DurableAgentStateMessage. Args: - content: ChatMessage or RunRequest object with role, message/contents, and metadata - + request: RunRequest object with role, message/contents, and metadata Returns: DurableAgentStateMessage with converted content items and metadata """ - # Convert to a list of DurableAgentStateContent objects - contents_list: list[DurableAgentStateContent] = [] + # Parse created_at if it's a string + created_at = request.created_at + if isinstance(created_at, str): + created_at = date_parser.parse(created_at) - if hasattr(content, "message") and isinstance(content.message, str): - # RunRequest with 'message' attribute - contents_list = [DurableAgentStateTextContent(text=content.message)] - elif hasattr(content, "contents") and content.contents: - # ChatMessage with 'contents' attribute - convert each content object - for c in content.contents: - converted = DurableAgentStateContent.from_ai_content(c) - contents_list.append(converted) + return DurableAgentStateMessage( + role=request.role.value, + contents=[DurableAgentStateTextContent(text=request.message)], + author_name=request.author_name, + created_at=created_at, + extension_data=request.extension_data, + ) - # Convert role enum to string if needed - role_value = content.role.value if hasattr(content.role, "value") else str(content.role) + @staticmethod + def from_chat_message(chat_message: ChatMessage) -> DurableAgentStateMessage: + """Converts an Agent Framework chat message to a durable state message. + + Args: + chat_message: ChatMessage object with role, contents, and metadata to convert + + Returns: + DurableAgentStateMessage with converted content items and metadata + """ + contents_list: list[DurableAgentStateContent] = [ + DurableAgentStateContent.from_ai_content(c) for c in chat_message.contents + ] return DurableAgentStateMessage( - role=role_value, + role=chat_message.role.value, contents=contents_list, - author_name=content.author_name if hasattr(content, "author_name") else None, - created_at=content.created_at if hasattr(content, "created_at") else None, - extension_data=content.extension_data if hasattr(content, "extensionData") else None, + author_name=chat_message.author_name, + extension_data=dict(chat_message.additional_properties) if chat_message.additional_properties else None, ) def to_chat_message(self) -> Any: @@ -1149,7 +1137,7 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateUsage: ) @staticmethod - def from_usage(usage: Any) -> DurableAgentStateUsage | None: + def from_usage(usage: UsageDetails | None) -> DurableAgentStateUsage | None: if usage is None: return None return DurableAgentStateUsage( @@ -1158,7 +1146,7 @@ def from_usage(usage: Any) -> DurableAgentStateUsage | None: total_token_count=usage.total_token_count, ) - def to_usage_details(self) -> Any: + def to_usage_details(self) -> UsageDetails: # Convert back to AI SDK UsageDetails return UsageDetails( input_token_count=self.input_token_count, diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py index a967f5d9e6..d29c8b42b7 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py @@ -130,9 +130,9 @@ async def run_agent( enable_tool_calls = run_request.enable_tool_calls state_request = DurableAgentStateRequest.from_run_request(run_request) - self.state.data.conversationHistory.append(state_request) + self.state.data.conversation_history.append(state_request) - logger.debug(f"[AgentEntity.run_agent] Saved state request: {state_request}") + logger.debug(f"[AgentEntity.run_agent] Received Message: {state_request}") try: logger.debug("[AgentEntity.run_agent] Starting agent invocation") @@ -141,7 +141,7 @@ async def run_agent( # Error responses are kept in history for tracking but not sent to the agent chat_messages: list[ChatMessage] = [ m.to_chat_message() - for entry in self.state.data.conversationHistory + for entry in self.state.data.conversation_history if not self._is_error_response(entry) for m in entry.messages ] @@ -196,14 +196,14 @@ async def run_agent( response_text = "Error extracting response" state_response = DurableAgentStateResponse.from_run_response(correlation_id, agent_run_response) - self.state.data.conversationHistory.append(state_response) + self.state.data.conversation_history.append(state_response) agent_response = AgentResponse( response=response_text, message=str(message), thread_id=str(thread_id), status="success", - message_count=len(self.state.data.conversationHistory), + message_count=len(self.state.data.conversation_history), structured_response=structured_response, ) result = agent_response.to_dict() @@ -223,27 +223,25 @@ async def run_agent( # Create error message error_message = DurableAgentStateMessage.from_chat_message( - ChatMessage(role="assistant", contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)]) + ChatMessage( + role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)] + ) ) # Create and store error response in conversation history error_state_response = DurableAgentStateResponse( - json_type="response", correlation_id=correlation_id, created_at=datetime.now(tz=timezone.utc), messages=[error_message], - extension_data=None, - usage=None, - is_error=True, ) - self.state.data.conversationHistory.append(error_state_response) + self.state.data.conversation_history.append(error_state_response) error_response = AgentResponse( response=f"Error: {exc!s}", message=str(message), thread_id=str(thread_id), status="error", - message_count=len(self.state.data.conversationHistory), + message_count=len(self.state.data.conversation_history), error=str(exc), error_type=type(exc).__name__, ) @@ -390,7 +388,7 @@ def _build_callback_context( def reset(self, context: df.DurableEntityContext) -> None: """Reset the entity state (clear conversation history).""" logger.debug("[AgentEntity.reset] Resetting entity state") - self.state.data = DurableAgentStateData(conversationHistory=[]) + self.state.data = DurableAgentStateData(conversation_history=[]) logger.debug("[AgentEntity.reset] State reset complete") diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py index c6b66df20d..b0dfd13d35 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py @@ -17,6 +17,8 @@ import azure.durable_functions as df from agent_framework import AgentThread, Role +from ._constants import REQUEST_RESPONSE_FORMAT_TEXT + if TYPE_CHECKING: # pragma: no cover - type checking imports only from pydantic import BaseModel @@ -289,6 +291,7 @@ class RunRequest: """ message: str + request_response_format: str role: Role = Role.USER response_format: type[BaseModel] | None = None enable_tool_calls: bool = True @@ -301,6 +304,7 @@ class RunRequest: def __init__( self, message: str, + request_response_format: str = REQUEST_RESPONSE_FORMAT_TEXT, role: Role | str | None = Role.USER, response_format: type[BaseModel] | None = None, enable_tool_calls: bool = True, @@ -313,6 +317,7 @@ def __init__( self.message = message self.role = self.coerce_role(role) self.response_format = response_format + self.request_response_format = request_response_format self.enable_tool_calls = enable_tool_calls self.thread_id = thread_id self.correlation_id = correlation_id @@ -338,6 +343,7 @@ def to_dict(self) -> dict[str, Any]: "message": self.message, "enable_tool_calls": self.enable_tool_calls, "role": self.role.value, + "request_response_format": self.request_response_format, } if self.response_format: result["response_format"] = _serialize_response_format(self.response_format) @@ -351,6 +357,7 @@ def to_dict(self) -> dict[str, Any]: result["created_at"] = self.created_at if self.extension_data: result["extension_data"] = self.extension_data + return result @classmethod @@ -358,6 +365,7 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest: """Create RunRequest from dictionary.""" return cls( message=data.get("message", ""), + request_response_format=data.get("request_response_format", REQUEST_RESPONSE_FORMAT_TEXT), role=cls.coerce_role(data.get("role")), response_format=_deserialize_response_format(data.get("response_format")), enable_tool_calls=data.get("enable_tool_calls", True), diff --git a/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py b/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py index f97916825c..4b1910ea7b 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py @@ -16,6 +16,8 @@ import pytest +from agent_framework_azurefunctions._constants import THREAD_ID_HEADER + from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file @@ -67,7 +69,7 @@ def test_simple_message_plain_text(self) -> None: # Agent responded with plain text when the request body was text/plain. assert response.text.strip() - assert response.headers.get("x-ms-thread-id") is not None + assert response.headers.get(THREAD_ID_HEADER) is not None def test_thread_id_in_query(self) -> None: """Test using thread_id in query parameter.""" @@ -77,7 +79,7 @@ def test_thread_id_in_query(self) -> None: assert response.status_code in [200, 202] assert response.text.strip() - assert response.headers.get("x-ms-thread-id") == "test-query-thread" + assert response.headers.get(THREAD_ID_HEADER) == "test-query-thread" def test_conversation_continuity(self) -> None: """Test conversation context is maintained across requests.""" diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 4caf3ed9d6..35ead2c4e8 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -13,6 +13,11 @@ from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._app import WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER +from agent_framework_azurefunctions._constants import ( + MIMETYPE_APPLICATION_JSON, + MIMETYPE_TEXT_PLAIN, + THREAD_ID_HEADER, +) from agent_framework_azurefunctions._durable_agent_state import DurableAgentState from agent_framework_azurefunctions._entities import AgentEntity, create_agent_entity @@ -359,7 +364,7 @@ async def test_entity_stores_conversation_history(self) -> None: ) # Each conversation turn creates 2 entries: request and response - history = entity.state.data.conversationHistory[0].messages # Request entry + history = entity.state.data.conversation_history[0].messages # Request entry assert len(history) == 1 # Just the user message # Send second message @@ -369,7 +374,7 @@ async def test_entity_stores_conversation_history(self) -> None: # Now we have 4 entries total (2 requests + 2 responses) # Access the first request entry - history2 = entity.state.data.conversationHistory[2].messages # Second request entry + history2 = entity.state.data.conversation_history[2].messages # Second request entry assert len(history2) == 1 # Just the user message user_msg = history[0] @@ -377,7 +382,7 @@ async def test_entity_stores_conversation_history(self) -> None: assert user_role == "user" assert user_msg.text == "Message 1" - assistant_msg = entity.state.data.conversationHistory[1].messages[0] + assistant_msg = entity.state.data.conversation_history[1].messages[0] assistant_role = getattr(assistant_msg.role, "value", assistant_msg.role) assert assistant_role == "assistant" assert assistant_msg.text == "Response 1" @@ -392,17 +397,17 @@ async def test_entity_increments_message_count(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - assert len(entity.state.data.conversationHistory) == 0 + assert len(entity.state.data.conversation_history) == 0 await entity.run_agent( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-app-entity-3a"} ) - assert len(entity.state.data.conversationHistory) == 2 + assert len(entity.state.data.conversation_history) == 2 await entity.run_agent( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-app-entity-3b"} ) - assert len(entity.state.data.conversationHistory) == 4 + assert len(entity.state.data.conversation_history) == 4 def test_entity_reset(self) -> None: """Test that entity reset clears state.""" @@ -410,13 +415,13 @@ def test_entity_reset(self) -> None: entity = AgentEntity(mock_agent) # Set some state - entity.state.data = DurableAgentState() + entity.state = DurableAgentState() # Reset mock_context = Mock() entity.reset(mock_context) - assert len(entity.state.data.conversationHistory) == 0 + assert len(entity.state.data.conversation_history) == 0 class TestAgentEntityFactory: @@ -608,7 +613,7 @@ def test_accept_header_prefers_json(self) -> None: app = self._create_app() request = Mock() - request.headers = {"accept": "application/json"} + request.headers = {"accept": MIMETYPE_APPLICATION_JSON} request.params = {} request.get_json.side_effect = ValueError("Invalid JSON") request.get_body.return_value = b"Plain text message" @@ -682,8 +687,8 @@ async def test_http_run_accepts_plain_text(self) -> None: response = await handler(request, client) assert response.status_code == 202 - assert response.mimetype == "text/plain" - assert response.headers.get("x-ms-thread-id") is not None + assert response.mimetype == MIMETYPE_TEXT_PLAIN + assert response.headers.get(THREAD_ID_HEADER) is not None assert response.get_body().decode("utf-8") == "Agent request accepted" signal_args = client.signal_entity.call_args[0] @@ -701,7 +706,7 @@ async def test_http_run_accept_header_returns_json(self) -> None: handler = self._get_run_handler(mock_agent) request = Mock() - request.headers = {WAIT_FOR_RESPONSE_HEADER: "false", "Accept": "application/json"} + request.headers = {WAIT_FOR_RESPONSE_HEADER: "false", "Accept": MIMETYPE_APPLICATION_JSON} request.params = {} request.route_params = {} request.get_json.side_effect = ValueError("Invalid JSON") @@ -712,8 +717,8 @@ async def test_http_run_accept_header_returns_json(self) -> None: response = await handler(request, client) assert response.status_code == 202 - assert response.mimetype == "application/json" - assert response.headers.get("x-ms-thread-id") is None + assert response.mimetype == MIMETYPE_APPLICATION_JSON + assert response.headers.get(THREAD_ID_HEADER) is None body = response.get_body().decode("utf-8") assert '"status": "accepted"' in body @@ -736,8 +741,8 @@ async def test_http_run_rejects_empty_message(self) -> None: response = await handler(request, client) assert response.status_code == 400 - assert response.mimetype == "text/plain" - assert response.headers.get("x-ms-thread-id") is not None + assert response.mimetype == MIMETYPE_TEXT_PLAIN + assert response.headers.get(THREAD_ID_HEADER) is not None assert response.get_body().decode("utf-8") == "Message is required" client.signal_entity.assert_not_called() diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index 200194c2c5..7d2b163e5c 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -28,7 +28,7 @@ TFunc = TypeVar("TFunc", bound=Callable[..., Any]) -def _role_value(chat_message: ChatMessage) -> str: +def _role_value(chat_message: DurableAgentStateMessage) -> str: """Helper to extract the string role from a ChatMessage.""" role = getattr(chat_message, "role", None) role_value = getattr(role, "value", role) @@ -77,8 +77,8 @@ def test_init_creates_entity(self) -> None: entity = AgentEntity(mock_agent) assert entity.agent == mock_agent - assert len(entity.state.data.conversationHistory) == 0 - assert entity.state.data.extensionData is None + assert len(entity.state.data.conversation_history) == 0 + assert entity.state.data.extension_data is None assert entity.state.schema_version == "1.0.0" def test_init_stores_agent_reference(self) -> None: @@ -128,7 +128,7 @@ async def test_run_agent_executes_agent(self) -> None: assert isinstance(sent_messages, list) assert len(sent_messages) == 1 sent_message = sent_messages[0] - assert isinstance(sent_message, ChatMessage) + assert isinstance(sent_message, DurableAgentStateMessage) assert sent_message.text == "Test message" assert _role_value(sent_message) == "user" @@ -243,8 +243,8 @@ async def test_run_agent_updates_conversation_history(self) -> None: ) # Should have 1 entry: user message + assistant response - user_history = entity.state.data.conversationHistory[0].messages - assistant_history = entity.state.data.conversationHistory[1].messages + user_history = entity.state.data.conversation_history[0].messages + assistant_history = entity.state.data.conversation_history[1].messages assert len(user_history) == 1 @@ -264,22 +264,22 @@ async def test_run_agent_increments_message_count(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - assert len(entity.state.data.conversationHistory) == 0 + assert len(entity.state.data.conversation_history) == 0 await entity.run_agent( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-3a"} ) - assert len(entity.state.data.conversationHistory) == 2 + assert len(entity.state.data.conversation_history) == 2 await entity.run_agent( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-3b"} ) - assert len(entity.state.data.conversationHistory) == 4 + assert len(entity.state.data.conversation_history) == 4 await entity.run_agent( mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-3c"} ) - assert len(entity.state.data.conversationHistory) == 6 + assert len(entity.state.data.conversation_history) == 6 async def test_run_agent_with_none_thread_id(self) -> None: """Test run_agent with a None thread identifier.""" @@ -351,7 +351,7 @@ async def test_run_agent_multiple_conversations(self) -> None: mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-8c"} ) - history = entity.state.data.conversationHistory + history = entity.state.data.conversation_history assert len(history) == 6 assert entity.state.message_count == 6 @@ -365,9 +365,8 @@ def test_reset_clears_conversation_history(self) -> None: entity = AgentEntity(mock_agent) # Add some history with proper DurableAgentStateEntry objects - entity.state.data.conversationHistory = [ + entity.state.data.conversation_history = [ DurableAgentStateRequest( - json_type="request", correlation_id="test-1", created_at=datetime.now(), messages=[ @@ -382,7 +381,7 @@ def test_reset_clears_conversation_history(self) -> None: mock_context = Mock() entity.reset(mock_context) - assert entity.state.data.conversationHistory == [] + assert entity.state.data.conversation_history == [] def test_reset_with_extension_data(self) -> None: """Test that reset works when entity has extension data.""" @@ -390,12 +389,12 @@ def test_reset_with_extension_data(self) -> None: entity = AgentEntity(mock_agent) # Set up some initial state with conversation history - entity.state.data = DurableAgentStateData(conversationHistory=[], extensionData={"some_key": "some_value"}) + entity.state.data = DurableAgentStateData(conversation_history=[], extension_data={"some_key": "some_value"}) mock_context = Mock() entity.reset(mock_context) - assert len(entity.state.data.conversationHistory) == 0 + assert len(entity.state.data.conversation_history) == 0 def test_reset_clears_message_count(self) -> None: """Test that reset clears the message count.""" @@ -405,7 +404,7 @@ def test_reset_clears_message_count(self) -> None: mock_context = Mock() entity.reset(mock_context) - assert len(entity.state.data.conversationHistory) == 0 + assert len(entity.state.data.conversation_history) == 0 async def test_reset_after_conversation(self) -> None: """Test reset after a full conversation.""" @@ -425,14 +424,14 @@ async def test_reset_after_conversation(self) -> None: # Verify state before reset assert entity.state.message_count == 4 - assert len(entity.state.data.conversationHistory) == 4 + assert len(entity.state.data.conversation_history) == 4 # Reset entity.reset(mock_context) # Verify state after reset assert entity.state.message_count == 0 - assert len(entity.state.data.conversationHistory) == 0 + assert len(entity.state.data.conversation_history) == 0 class TestCreateAgentEntity: @@ -686,7 +685,7 @@ async def test_conversation_history_has_timestamps(self) -> None: ) # Check both user and assistant messages have timestamps - for entry in entity.state.data.conversationHistory: + for entry in entity.state.data.conversation_history: timestamp = entry.created_at assert timestamp is not None # Verify timestamp is in ISO format @@ -719,7 +718,7 @@ async def test_conversation_history_ordering(self) -> None: ) # Verify order - history = entity.state.data.conversationHistory + history = entity.state.data.conversation_history # Each conversation turn creates 2 entries: request and response assert history[0].messages[0].text == "Message 1" # Request 1 assert history[1].messages[0].text == "Response 1" # Response 1 @@ -746,7 +745,7 @@ async def test_conversation_history_role_alternation(self) -> None: ) # Check role alternation - history = entity.state.data.conversationHistory + history = entity.state.data.conversation_history # Each conversation turn creates 2 entries: request and response assert history[0].messages[0].role == "user" # Request 1 assert history[1].messages[0].role == "assistant" # Response 1 @@ -832,7 +831,7 @@ async def test_run_agent_stores_role_in_history(self) -> None: await entity.run_agent(mock_context, request) # Check that system role was stored - history = entity.state.data.conversationHistory + history = entity.state.data.conversation_history assert history[0].messages[0].role == "system" assert history[0].messages[0].text == "System message" From 0022936da971c3ec5a71e93f53d316d780e3924c Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Tue, 18 Nov 2025 22:18:14 -0800 Subject: [PATCH 2/3] Address more comments --- .../agent_framework_azurefunctions/_app.py | 4 +- .../_durable_agent_state.py | 92 ++++++++----------- .../_entities.py | 10 +- .../integration_tests/test_01_single_agent.py | 4 +- .../packages/azurefunctions/tests/test_app.py | 58 +++++++++++- .../azurefunctions/tests/test_entities.py | 48 +++++++--- 6 files changed, 136 insertions(+), 80 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 96a68770cb..e0bc3ba51a 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -505,9 +505,7 @@ async def _read_cached_state( typed_state_payload = cast(dict[str, Any], state_payload) - agent_state = DurableAgentState() - agent_state.restore_state(typed_state_payload) - return agent_state + return DurableAgentState.from_dict(typed_state_payload) async def _get_response_from_entity( self, diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py index 5f48f260eb..1384737cf7 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py @@ -49,11 +49,14 @@ UriContent, UsageContent, UsageDetails, + get_logger, ) from dateutil import parser as date_parser # type: ignore from ._models import RunRequest, _serialize_response_format +logger = get_logger("agent_framework.azurefunctions.durable_agent_state") + class DurableAgentStateContent: """Base class for all content types in durable agent state messages. @@ -229,41 +232,17 @@ def to_json(self) -> str: return json.dumps(self.to_dict()) @classmethod - def from_dict(cls, obj: dict[str, Any]) -> DurableAgentState: - # Default to "1.0.0" for backward compatibility with legacy state that doesn't have schemaVersion - schema_version = obj.get("schemaVersion", "1.0.0") - - if not schema_version.startswith("1."): - raise ValueError(f"The durable agent state schema version '{schema_version}' is not supported.") - - data_dict = obj.get("data") - if data_dict is None: - raise ValueError("The durable agent state is missing the 'data' property.") - - instance = cls(schema_version=schema_version) - # Deserialize the data dict into DurableAgentStateData - if isinstance(data_dict, dict): - instance.data = DurableAgentStateData( - conversation_history=data_dict.get("conversationHistory", []), - extension_data=data_dict.get("extensionData"), - ) - return instance - - @classmethod - def from_json(cls, json_str: str) -> DurableAgentState: - try: - obj = json.loads(json_str) - except json.JSONDecodeError as e: - raise ValueError("The durable agent state is not valid JSON.") from e - - return cls.from_dict(obj) - - def restore_state(self, state: dict[str, Any]) -> None: + def from_dict(cls, state: dict[str, Any]) -> DurableAgentState: """Restore state from a dictionary. Args: state: Dictionary containing schemaVersion and data (full state structure) """ + schema_version = state.get("schemaVersion") + if schema_version is None: + logger.warning("Resetting state as it is incompatible with the current schema, all history will be lost") + return cls() + # Extract the data portion from the state data_dict = state.get("data", {}) @@ -274,18 +253,30 @@ def restore_state(self, state: dict[str, Any]) -> None: if isinstance(entry_dict, dict): # Deserialize based on $type discriminator entry_type = entry_dict.get("$type") or entry_dict.get("json_type") - if entry_type == "response": + if entry_type == DurableAgentStateEntryJsonType.RESPONSE: deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict)) - elif entry_type == "request": + elif entry_type == DurableAgentStateEntryJsonType.REQUEST: deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict)) else: deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict)) else: # Already an object - deserialized_history.append(entry_dict) # type: ignore + deserialized_history.append(entry_dict) + + instance = cls(schema_version=state.get("schemaVersion", "1.0.0")) + instance.data.conversation_history = deserialized_history + instance.data.extension_data = data_dict.get("extensionData") + + return instance + + @classmethod + def from_json(cls, json_str: str) -> DurableAgentState: + try: + obj = json.loads(json_str) + except json.JSONDecodeError as e: + raise ValueError("The durable agent state is not valid JSON.") from e - self.data.conversation_history = deserialized_history - self.data.extension_data = data_dict.get("extensionData") + return cls.from_dict(obj) @property def message_count(self) -> int: @@ -400,7 +391,7 @@ def to_dict(self) -> dict[str, Any]: "$type": self.json_type, "correlationId": self.correlation_id, "createdAt": created_at_value.isoformat() if isinstance(created_at_value, datetime) else created_at_value, - "messages": [m.to_dict() if hasattr(m, "to_dict") else m for m in self.messages], + "messages": [m.to_dict() for m in self.messages], } if self.json_type == DurableAgentStateEntryJsonType.REQUEST: # Only include responseType and responseSchema if they're not None @@ -430,7 +421,7 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateEntry: messages.append(msg_dict) return cls( - json_type=data.get("$type", "entry"), + json_type=DurableAgentStateEntryJsonType(data.get("$type", "entry")), correlation_id=data.get("correlationId", ""), created_at=created_at, messages=messages, @@ -807,10 +798,7 @@ def to_chat_message(self) -> Any: if self.extension_data is not None: kwargs["additional_properties"] = self.extension_data - return ChatMessage(**kwargs) # type: ignore[call-overload] - - -# Content subclasses + return ChatMessage(**kwargs) class DurableAgentStateDataContent(DurableAgentStateContent): @@ -839,7 +827,7 @@ def to_dict(self) -> dict[str, Any]: def from_data_content(content: DataContent) -> DurableAgentStateDataContent: return DurableAgentStateDataContent(uri=content.uri, media_type=content.media_type) - def to_ai_content(self) -> Any: + def to_ai_content(self) -> DataContent: return DataContent(uri=self.uri, media_type=self.media_type) @@ -873,7 +861,7 @@ def from_error_content(content: ErrorContent) -> DurableAgentStateErrorContent: message=content.message, error_code=content.error_code, details=content.details ) - def to_ai_content(self) -> Any: + def to_ai_content(self) -> ErrorContent: return ErrorContent(message=self.message, error_code=self.error_code, details=self.details) @@ -918,7 +906,7 @@ def from_function_call_content(content: FunctionCallContent) -> DurableAgentStat return DurableAgentStateFunctionCallContent(call_id=content.call_id, name=content.name, arguments=arguments) - def to_ai_content(self) -> Any: + def to_ai_content(self) -> FunctionCallContent: return FunctionCallContent(call_id=self.call_id, name=self.name, arguments=self.arguments) @@ -948,7 +936,7 @@ def to_dict(self) -> dict[str, Any]: def from_function_result_content(content: FunctionResultContent) -> DurableAgentStateFunctionResultContent: return DurableAgentStateFunctionResultContent(call_id=content.call_id, result=content.result) - def to_ai_content(self) -> Any: + def to_ai_content(self) -> FunctionResultContent: return FunctionResultContent(call_id=self.call_id, result=self.result) @@ -974,7 +962,7 @@ def to_dict(self) -> dict[str, Any]: def from_hosted_file_content(content: HostedFileContent) -> DurableAgentStateHostedFileContent: return DurableAgentStateHostedFileContent(file_id=content.file_id) - def to_ai_content(self) -> Any: + def to_ai_content(self) -> HostedFileContent: return HostedFileContent(file_id=self.file_id) @@ -1003,7 +991,7 @@ def from_hosted_vector_store_content( ) -> DurableAgentStateHostedVectorStoreContent: return DurableAgentStateHostedVectorStoreContent(vector_store_id=content.vector_store_id) - def to_ai_content(self) -> Any: + def to_ai_content(self) -> HostedVectorStoreContent: return HostedVectorStoreContent(vector_store_id=self.vector_store_id) @@ -1029,7 +1017,7 @@ def to_dict(self) -> dict[str, Any]: def from_text_content(content: TextContent) -> DurableAgentStateTextContent: return DurableAgentStateTextContent(text=content.text) - def to_ai_content(self) -> Any: + def to_ai_content(self) -> TextContent: return TextContent(text=self.text or "") @@ -1055,7 +1043,7 @@ def to_dict(self) -> dict[str, Any]: def from_text_reasoning_content(content: TextReasoningContent) -> DurableAgentStateTextReasoningContent: return DurableAgentStateTextReasoningContent(text=content.text) - def to_ai_content(self) -> Any: + def to_ai_content(self) -> TextReasoningContent: return TextReasoningContent(text=self.text or "") @@ -1084,7 +1072,7 @@ def to_dict(self) -> dict[str, Any]: def from_uri_content(content: UriContent) -> DurableAgentStateUriContent: return DurableAgentStateUriContent(uri=content.uri, media_type=content.media_type) - def to_ai_content(self) -> Any: + def to_ai_content(self) -> UriContent: return UriContent(uri=self.uri, media_type=self.media_type) @@ -1178,7 +1166,7 @@ def to_dict(self) -> dict[str, Any]: def from_usage_content(content: UsageContent) -> DurableAgentStateUsageContent: return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.details)) # type: ignore - def to_ai_content(self) -> Any: + def to_ai_content(self) -> UsageContent: return UsageContent(details=self.usage.to_usage_details()) @@ -1205,7 +1193,7 @@ def to_dict(self) -> dict[str, Any]: def from_unknown_content(content: Any) -> DurableAgentStateUnknownContent: return DurableAgentStateUnknownContent(content=content) - def to_ai_content(self) -> Any: + def to_ai_content(self) -> BaseContent: if not self.content: raise Exception("The content is missing and cannot be converted to valid AI content.") return BaseContent(content=self.content) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py index d29c8b42b7..2bea456ddc 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py @@ -77,7 +77,7 @@ def __init__( logger.debug(f"[AgentEntity] Initialized with agent type: {type(agent).__name__}") - def _is_error_response(self, entry: DurableAgentStateEntry | dict[str, Any]) -> bool: + def _is_error_response(self, entry: DurableAgentStateEntry) -> bool: """Check if a conversation history entry is an error response. Error responses should be kept in history for tracking but not sent to the agent @@ -146,12 +146,6 @@ async def run_agent( for m in entry.messages ] - # Strip additional_properties from all messages to avoid metadata being sent to Azure OpenAI - # Azure OpenAI doesn't support the 'metadata' field in messages - for msg in chat_messages: - if hasattr(msg, "additional_properties"): - msg.additional_properties = {} - run_kwargs: dict[str, Any] = {"messages": chat_messages} if not enable_tool_calls: run_kwargs["tools"] = None @@ -417,7 +411,7 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: entity = AgentEntity(agent, callback) if current_state is not None: - entity.state.restore_state(current_state) + entity.state = DurableAgentState.from_dict(current_state) logger.debug( "[entity_function] Restored entity from state (message_count: %s)", entity.state.message_count ) diff --git a/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py b/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py index 4b1910ea7b..cd93e6a352 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py @@ -94,7 +94,7 @@ def test_conversation_continuity(self) -> None: if response1.status_code == 200: data1 = response1.json() - assert data1["message_count"] == 1 + assert data1["message_count"] == 2 # Initial + reply # Second message in same session response2 = SampleTestHelper.post_json( @@ -102,7 +102,7 @@ def test_conversation_continuity(self) -> None: ) assert response2.status_code == 200 data2 = response2.json() - assert data2["message_count"] == 2 + assert data2["message_count"] == 4 else: # In async mode, we can't easily test message count # Just verify we can make multiple calls diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 35ead2c4e8..ebf6eef3e6 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -471,7 +471,24 @@ def test_entity_function_handles_reset_operation(self) -> None: mock_context.get_state.return_value = { "schemaVersion": "1.0.0", "data": { - "conversationHistory": [{"role": "user", "content": "test"}], + "conversationHistory": [ + { + "$type": "request", + "correlationId": "corr-reset-test", + "createdAt": "2024-01-01T00:00:00Z", + "messages": [ + { + "role": "user", + "contents": [ + { + "$type": "text", + "text": "test", + } + ], + } + ], + } + ], }, } @@ -511,7 +528,40 @@ def test_entity_function_restores_state(self) -> None: existing_state = { "schemaVersion": "1.0.0", "data": { - "conversationHistory": [{"role": "user", "content": "msg1"}, {"role": "assistant", "content": "resp1"}], + "conversationHistory": [ + { + "$type": "request", + "correlationId": "corr-existing-1", + "createdAt": "2024-01-01T00:00:00Z", + "messages": [ + { + "role": "user", + "contents": [ + { + "$type": "text", + "text": "msg1", + } + ], + } + ], + }, + { + "$type": "response", + "correlationId": "corr-existing-1", + "createdAt": "2024-01-01T00:05:00Z", + "messages": [ + { + "role": "assistant", + "contents": [ + { + "$type": "text", + "text": "resp1", + } + ], + } + ], + }, + ], }, } @@ -519,10 +569,10 @@ def test_entity_function_restores_state(self) -> None: mock_context.operation_name = "reset" mock_context.get_state.return_value = existing_state - with patch.object(DurableAgentState, "restore_state") as restore_state_mock: + with patch.object(DurableAgentState, "from_dict", wraps=DurableAgentState.from_dict) as from_dict_mock: entity_function(mock_context) - restore_state_mock.assert_called_once_with(existing_state) + from_dict_mock.assert_called_once_with(existing_state) class TestErrorHandling: diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index 7d2b163e5c..2f73f1daa8 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -128,9 +128,9 @@ async def test_run_agent_executes_agent(self) -> None: assert isinstance(sent_messages, list) assert len(sent_messages) == 1 sent_message = sent_messages[0] - assert isinstance(sent_message, DurableAgentStateMessage) - assert sent_message.text == "Test message" - assert _role_value(sent_message) == "user" + assert isinstance(sent_message, ChatMessage) + assert getattr(sent_message, "text", None) == "Test message" + assert getattr(sent_message.role, "value", sent_message.role) == "user" # Verify result assert result["status"] == "success" @@ -560,12 +560,38 @@ def test_entity_function_restores_existing_state(self) -> None: "schemaVersion": "1.0.0", "data": { "conversationHistory": [ - ChatMessage( - role="user", text="msg1", additional_properties={"timestamp": "2024-01-01T00:00:00Z"} - ).to_dict(), - ChatMessage( - role="assistant", text="resp1", additional_properties={"timestamp": "2024-01-01T00:05:00Z"} - ).to_dict(), + { + "$type": "request", + "correlationId": "corr-existing-1", + "createdAt": "2024-01-01T00:00:00Z", + "messages": [ + { + "role": "user", + "contents": [ + { + "$type": "text", + "text": "msg1", + } + ], + } + ], + }, + { + "$type": "response", + "correlationId": "corr-existing-1", + "createdAt": "2024-01-01T00:05:00Z", + "messages": [ + { + "role": "assistant", + "contents": [ + { + "$type": "text", + "text": "resp1", + } + ], + } + ], + }, ], }, } @@ -574,10 +600,10 @@ def test_entity_function_restores_existing_state(self) -> None: mock_context.operation_name = "reset" mock_context.get_state.return_value = existing_state - with patch.object(DurableAgentState, "restore_state") as restore_state_mock: + with patch.object(DurableAgentState, "from_dict", wraps=DurableAgentState.from_dict) as from_dict_mock: entity_function(mock_context) - restore_state_mock.assert_called_once_with(existing_state) + from_dict_mock.assert_called_once_with(existing_state) class TestErrorHandling: From f54d7ef4535b717a7e17c231314a319e8af22129 Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Wed, 19 Nov 2025 11:30:54 -0800 Subject: [PATCH 3/3] Address remaining comments --- .../_constants.py | 4 +- .../_durable_agent_state.py | 72 +++++++++---------- .../agent_framework_azurefunctions/_models.py | 12 ---- 3 files changed, 38 insertions(+), 50 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py index 291235b690..8c4cded196 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. -"""Constants used across the Azure Functions agent framework.""" +"""Constants for Azure Functions Agent Framework integration.""" -# Response format constants +# Supported request/response formats and MIME types REQUEST_RESPONSE_FORMAT_JSON: str = "json" REQUEST_RESPONSE_FORMAT_TEXT: str = "text" MIMETYPE_APPLICATION_JSON: str = "application/json" diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py index 1384737cf7..63e6b24eb7 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py @@ -170,6 +170,36 @@ def __init__( self.conversation_history = conversation_history or [] self.extension_data = extension_data + def to_dict(self) -> dict[str, Any]: + return { + "conversationHistory": [entry.to_dict() for entry in self.conversation_history], + "extensionData": self.extension_data, + } + + @classmethod + def from_dict(cls, data_dict: dict[str, Any]) -> DurableAgentStateData: + # Restore the conversation history - deserialize entries from dicts to objects + history_data = data_dict.get("conversationHistory", []) + deserialized_history: list[DurableAgentStateEntry] = [] + for entry_dict in history_data: + if isinstance(entry_dict, dict): + # Deserialize based on $type discriminator + entry_type = entry_dict.get("$type") or entry_dict.get("json_type") + if entry_type == DurableAgentStateEntryJsonType.RESPONSE: + deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict)) + elif entry_type == DurableAgentStateEntryJsonType.REQUEST: + deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict)) + else: + deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict)) + else: + # Already an object + deserialized_history.append(entry_dict) + + return cls( + conversation_history=deserialized_history, + extension_data=data_dict.get("extensionData"), + ) + class DurableAgentState: """Manages durable agent state conforming to the durable-agent-entity-state.json schema. @@ -213,19 +243,10 @@ def __init__(self, schema_version: str = "1.0.0"): self.schema_version = schema_version def to_dict(self) -> dict[str, Any]: - # Serialize conversationHistory - serialized_history: list[dict[str, Any]] = [] - for entry in self.data.conversation_history: - # Properly serialize each entry to a dictionary - if hasattr(entry, "to_dict"): - serialized_history.append(entry.to_dict()) - else: - # Fallback for already-serialized entries - serialized_history.append(entry) # type: ignore return { "schemaVersion": self.schema_version, - "data": {"conversationHistory": serialized_history}, + "data": self.data.to_dict(), } def to_json(self) -> str: @@ -243,29 +264,8 @@ def from_dict(cls, state: dict[str, Any]) -> DurableAgentState: logger.warning("Resetting state as it is incompatible with the current schema, all history will be lost") return cls() - # Extract the data portion from the state - data_dict = state.get("data", {}) - - # Restore the conversation history - deserialize entries from dicts to objects - history_data = data_dict.get("conversationHistory", []) - deserialized_history: list[DurableAgentStateEntry] = [] - for entry_dict in history_data: - if isinstance(entry_dict, dict): - # Deserialize based on $type discriminator - entry_type = entry_dict.get("$type") or entry_dict.get("json_type") - if entry_type == DurableAgentStateEntryJsonType.RESPONSE: - deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict)) - elif entry_type == DurableAgentStateEntryJsonType.REQUEST: - deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict)) - else: - deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict)) - else: - # Already an object - deserialized_history.append(entry_dict) - instance = cls(schema_version=state.get("schemaVersion", "1.0.0")) - instance.data.conversation_history = deserialized_history - instance.data.extension_data = data_dict.get("extensionData") + instance.data = DurableAgentStateData.from_dict(state.get("data", {})) return instance @@ -306,7 +306,7 @@ def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: if entry.correlation_id == correlation_id and isinstance(entry, DurableAgentStateResponse): # Found the entry, extract response data # Get the text content from assistant messages only - content = "".join(message.text for message in entry.messages if message.text is not None) + content = "\n".join(message.text for message in entry.messages if message.text is not None) return {"content": content, "message_count": self.message_count, "correlationId": correlation_id} return None @@ -464,6 +464,9 @@ def __init__( response_schema=response_schema, ) + def to_dict(self) -> dict[str, Any]: + return super().to_dict() + @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateRequest: created_at = data.get("created_at") @@ -495,7 +498,6 @@ def from_run_request(request: RunRequest) -> DurableAgentStateRequest: correlation_id=request.correlation_id, messages=[DurableAgentStateMessage.from_run_request(request)], created_at=datetime.now(tz=timezone.utc), - extension_data=request.extension_data, response_type=request.request_response_format, response_schema=_serialize_response_format(request.response_format), ) @@ -751,9 +753,7 @@ def from_run_request(request: RunRequest) -> DurableAgentStateMessage: return DurableAgentStateMessage( role=request.role.value, contents=[DurableAgentStateTextContent(text=request.message)], - author_name=request.author_name, created_at=created_at, - extension_data=request.extension_data, ) @staticmethod diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py index b0dfd13d35..a1b97bfc3d 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py @@ -297,9 +297,7 @@ class RunRequest: enable_tool_calls: bool = True thread_id: str | None = None correlation_id: str | None = None - author_name: str | None = None created_at: str | None = None - extension_data: dict[str, Any] | None = None def __init__( self, @@ -310,9 +308,7 @@ def __init__( enable_tool_calls: bool = True, thread_id: str | None = None, correlation_id: str | None = None, - author_name: str | None = None, created_at: str | None = None, - extension_data: dict[str, Any] | None = None, ) -> None: self.message = message self.role = self.coerce_role(role) @@ -321,9 +317,7 @@ def __init__( self.enable_tool_calls = enable_tool_calls self.thread_id = thread_id self.correlation_id = correlation_id - self.author_name = author_name self.created_at = created_at - self.extension_data = extension_data @staticmethod def coerce_role(value: Role | str | None) -> Role: @@ -351,12 +345,8 @@ def to_dict(self) -> dict[str, Any]: result["thread_id"] = self.thread_id if self.correlation_id: result["correlationId"] = self.correlation_id - if self.author_name: - result["author_name"] = self.author_name if self.created_at: result["created_at"] = self.created_at - if self.extension_data: - result["extension_data"] = self.extension_data return result @@ -371,9 +361,7 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest: enable_tool_calls=data.get("enable_tool_calls", True), thread_id=data.get("thread_id"), correlation_id=data.get("correlationId"), - author_name=data.get("author_name"), created_at=data.get("created_at"), - extension_data=data.get("extension_data"), )