diff --git a/.gitignore b/.gitignore index 162eec9ba7..f0f8c09495 100644 --- a/.gitignore +++ b/.gitignore @@ -204,6 +204,8 @@ agents.md # AI .claude/ WARP.md +**/memory-bank/ +**/projectBrief.md # Azurite storage emulator files */__azurite_db_blob__.json diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index bb7c4398fc..e0bc3ba51a 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -16,27 +16,29 @@ from agent_framework import AgentProtocol, get_logger from ._callbacks import AgentResponseCallbackProtocol +from ._constants import ( + DEFAULT_MAX_POLL_RETRIES, + DEFAULT_POLL_INTERVAL_SECONDS, + MIMETYPE_APPLICATION_JSON, + MIMETYPE_TEXT_PLAIN, + REQUEST_RESPONSE_FORMAT_JSON, + REQUEST_RESPONSE_FORMAT_TEXT, + THREAD_ID_FIELD, + THREAD_ID_HEADER, + WAIT_FOR_RESPONSE_FIELD, + WAIT_FOR_RESPONSE_HEADER, +) +from ._durable_agent_state import DurableAgentState from ._entities import create_agent_entity from ._errors import IncomingRequestError from ._models import AgentSessionId, RunRequest from ._orchestration import AgentOrchestrationContextType, DurableAIAgent -from ._state import AgentState logger = get_logger("agent_framework.azurefunctions") -THREAD_ID_FIELD: str = "thread_id" -RESPONSE_FORMAT_JSON: str = "json" -RESPONSE_FORMAT_TEXT: str = "text" -WAIT_FOR_RESPONSE_FIELD: str = "wait_for_response" -WAIT_FOR_RESPONSE_HEADER: str = "x-ms-wait-for-response" - - EntityHandler = Callable[[df.DurableEntityContext], None] HandlerT = TypeVar("HandlerT", bound=Callable[..., Any]) -DEFAULT_MAX_POLL_RETRIES: int = 30 -DEFAULT_POLL_INTERVAL_SECONDS: float = 1.0 - if TYPE_CHECKING: class DFAppBase: @@ -317,11 +319,11 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien """ logger.debug(f"[HTTP Trigger] Received request on route: /api/agents/{agent_name}/run") - response_format: str = RESPONSE_FORMAT_JSON + request_response_format: str = REQUEST_RESPONSE_FORMAT_JSON thread_id: str | None = None try: - req_body, message, response_format = self._parse_incoming_request(req) + req_body, message, request_response_format = self._parse_incoming_request(req) thread_id = self._resolve_thread_id(req=req, req_body=req_body) wait_for_response = self._should_wait_for_response(req=req, req_body=req_body) @@ -334,7 +336,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload={"error": "Message is required"}, status_code=400, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) @@ -351,6 +353,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien message, thread_id, correlation_id, + request_response_format, ) logger.debug("Signalling entity %s with request: %s", entity_instance_id, run_request) await client.signal_entity(entity_instance_id, "run_agent", run_request) @@ -370,7 +373,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload=result, status_code=200 if result.get("status") == "success" else 500, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) @@ -383,7 +386,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload=accepted_response, status_code=202, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) @@ -392,7 +395,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload={"error": str(exc)}, status_code=exc.status_code, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) except ValueError as exc: @@ -400,7 +403,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload={"error": "Invalid JSON"}, status_code=400, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) except Exception as exc: @@ -408,7 +411,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien return self._create_http_response( payload={"error": str(exc)}, status_code=500, - response_format=response_format, + request_response_format=request_response_format, thread_id=thread_id, ) @@ -466,7 +469,7 @@ def health_check(req: func.HttpRequest) -> func.HttpResponse: return func.HttpResponse( json.dumps({"status": "healthy", "agents": agent_info, "agent_count": len(self.agents)}), status_code=200, - mimetype="application/json", + mimetype=MIMETYPE_APPLICATION_JSON, ) _ = health_check @@ -491,7 +494,7 @@ async def _read_cached_state( self, client: df.DurableOrchestrationClient, entity_instance_id: df.EntityId, - ) -> AgentState | None: + ) -> DurableAgentState | None: state_response = await client.read_entity_state(entity_instance_id) if not state_response or not state_response.entity_exists: return None @@ -502,9 +505,7 @@ async def _read_cached_state( typed_state_payload = cast(dict[str, Any], state_payload) - agent_state = AgentState() - agent_state.restore_state(typed_state_payload) - return agent_state + return DurableAgentState.from_dict(typed_state_payload) async def _get_response_from_entity( self, @@ -580,31 +581,58 @@ async def _poll_entity_for_response( return result - async def _build_timeout_result(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]: - """Create the timeout response.""" - return { - "response": "Agent is still processing or timed out...", + def _build_response_payload( + self, + *, + response: str | None, + message: str, + thread_id: str, + status: str, + correlation_id: str, + extra_fields: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Create a consistent response structure and allow optional extra fields.""" + payload = { + "response": response, "message": message, THREAD_ID_FIELD: thread_id, - "status": "timeout", + "status": status, "correlation_id": correlation_id, } + if extra_fields: + payload.update(extra_fields) + return payload + + async def _build_timeout_result(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]: + """Create the timeout response.""" + return self._build_response_payload( + response="Agent is still processing or timed out...", + message=message, + thread_id=thread_id, + status="timeout", + correlation_id=correlation_id, + ) def _build_success_result( - self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: AgentState + self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: DurableAgentState ) -> dict[str, Any]: """Build the success result returned to the HTTP caller.""" - return { - "response": response_data.get("content"), - "message": message, - THREAD_ID_FIELD: thread_id, - "status": "success", - "message_count": response_data.get("message_count", state.message_count), - "correlation_id": correlation_id, - } + return self._build_response_payload( + response=response_data.get("content"), + message=message, + thread_id=thread_id, + status="success", + correlation_id=correlation_id, + extra_fields={"message_count": response_data.get("message_count", state.message_count)}, + ) def _build_request_data( - self, req_body: dict[str, Any], message: str, thread_id: str, correlation_id: str + self, + req_body: dict[str, Any], + message: str, + thread_id: str, + correlation_id: str, + request_response_format: str, ) -> dict[str, Any]: """Create the durable entity request payload.""" enable_tool_calls_value = req_body.get("enable_tool_calls") @@ -613,6 +641,7 @@ def _build_request_data( return RunRequest( message=message, role=req_body.get("role"), + request_response_format=request_response_format, response_format=req_body.get("response_format"), enable_tool_calls=enable_tool_calls, thread_id=thread_id, @@ -621,23 +650,23 @@ def _build_request_data( def _build_accepted_response(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]: """Build the response returned when not waiting for completion.""" - return { - "response": "Agent request accepted", - "message": message, - THREAD_ID_FIELD: thread_id, - "status": "accepted", - "correlation_id": correlation_id, - } + return self._build_response_payload( + response="Agent request accepted", + message=message, + thread_id=thread_id, + status="accepted", + correlation_id=correlation_id, + ) def _create_http_response( self, payload: dict[str, Any] | str, status_code: int, - response_format: str, + request_response_format: str, thread_id: str | None, ) -> func.HttpResponse: """Create the HTTP response using helper serializers for clarity.""" - if response_format == RESPONSE_FORMAT_TEXT: + if request_response_format == REQUEST_RESPONSE_FORMAT_TEXT: return self._build_plain_text_response(payload=payload, status_code=status_code, thread_id=thread_id) return self._build_json_response(payload=payload, status_code=status_code) @@ -650,13 +679,13 @@ def _build_plain_text_response( ) -> func.HttpResponse: """Return a plain-text response with optional thread identifier header.""" body_text = payload if isinstance(payload, str) else self._convert_payload_to_text(payload) - headers = {"x-ms-thread-id": thread_id} if thread_id is not None else None - return func.HttpResponse(body_text, status_code=status_code, mimetype="text/plain", headers=headers) + headers = {THREAD_ID_HEADER: thread_id} if thread_id is not None else None + return func.HttpResponse(body_text, status_code=status_code, mimetype=MIMETYPE_TEXT_PLAIN, headers=headers) def _build_json_response(self, payload: dict[str, Any] | str, status_code: int) -> func.HttpResponse: """Return the JSON response, serializing dictionaries as needed.""" body_json = payload if isinstance(payload, str) else json.dumps(payload) - return func.HttpResponse(body_json, status_code=status_code, mimetype="application/json") + return func.HttpResponse(body_json, status_code=status_code, mimetype=MIMETYPE_APPLICATION_JSON) def _convert_payload_to_text(self, payload: dict[str, Any]) -> str: """Convert a structured payload into a human-readable text response.""" @@ -702,18 +731,19 @@ def _parse_incoming_request(self, req: func.HttpRequest) -> tuple[dict[str, Any] normalized_content_type = self._extract_content_type(headers) body_parser, body_format = self._select_body_parser(normalized_content_type) prefers_json = self._accepts_json_response(headers) - response_format = self._select_response_format(body_format=body_format, prefers_json=prefers_json) + request_response_format = self._select_request_response_format( + body_format=body_format, prefers_json=prefers_json + ) req_body, message = body_parser(req) - return req_body, message, response_format + return req_body, message, request_response_format def _extract_normalized_headers(self, req: func.HttpRequest) -> dict[str, str]: """Create a lowercase header mapping from the incoming request.""" headers: dict[str, str] = {} raw_headers = req.headers if isinstance(raw_headers, Mapping): - header_mapping: Mapping[str, Any] = cast(Mapping[str, Any], raw_headers) - for key, value in header_mapping.items(): + for key, value in raw_headers.items(): if value is not None: headers[str(key).lower()] = str(value) return headers @@ -729,9 +759,9 @@ def _select_body_parser( normalized_content_type: str, ) -> tuple[Callable[[func.HttpRequest], tuple[dict[str, Any], str]], str]: """Choose the body parser and declared body format.""" - if normalized_content_type in {"application/json"} or normalized_content_type.endswith("+json"): - return self._parse_json_body, RESPONSE_FORMAT_JSON - return self._parse_text_body, RESPONSE_FORMAT_TEXT + if normalized_content_type in {MIMETYPE_APPLICATION_JSON} or normalized_content_type.endswith("+json"): + return self._parse_json_body, REQUEST_RESPONSE_FORMAT_JSON + return self._parse_text_body, REQUEST_RESPONSE_FORMAT_TEXT @staticmethod def _accepts_json_response(headers: dict[str, str]) -> bool: @@ -742,16 +772,16 @@ def _accepts_json_response(headers: dict[str, str]) -> bool: for value in accept_header.split(","): media_type = value.split(";")[0].strip().lower() - if media_type == "application/json": + if media_type == MIMETYPE_APPLICATION_JSON: return True return False @staticmethod - def _select_response_format(body_format: str, prefers_json: bool) -> str: + def _select_request_response_format(body_format: str, prefers_json: bool) -> str: """Combine body format and accept preference to determine response format.""" - if body_format == RESPONSE_FORMAT_JSON or prefers_json: - return RESPONSE_FORMAT_JSON - return RESPONSE_FORMAT_TEXT + if body_format == REQUEST_RESPONSE_FORMAT_JSON or prefers_json: + return REQUEST_RESPONSE_FORMAT_JSON + return REQUEST_RESPONSE_FORMAT_TEXT @staticmethod def _parse_json_body(req: func.HttpRequest) -> tuple[dict[str, Any], str]: diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py new file mode 100644 index 0000000000..8c4cded196 --- /dev/null +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Constants for Azure Functions Agent Framework integration.""" + +# Supported request/response formats and MIME types +REQUEST_RESPONSE_FORMAT_JSON: str = "json" +REQUEST_RESPONSE_FORMAT_TEXT: str = "text" +MIMETYPE_APPLICATION_JSON: str = "application/json" +MIMETYPE_TEXT_PLAIN: str = "text/plain" + +# Field and header names +THREAD_ID_FIELD: str = "thread_id" +THREAD_ID_HEADER: str = "x-ms-thread-id" +WAIT_FOR_RESPONSE_FIELD: str = "wait_for_response" +WAIT_FOR_RESPONSE_HEADER: str = "x-ms-wait-for-response" + +# Polling configuration +DEFAULT_MAX_POLL_RETRIES: int = 30 +DEFAULT_POLL_INTERVAL_SECONDS: float = 1.0 diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py new file mode 100644 index 0000000000..73695e61f2 --- /dev/null +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py @@ -0,0 +1,1192 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Durable agent state management conforming to the durable-agent-entity-state.json schema. + +This module provides classes for managing conversation state in Azure Durable Functions agents. +It implements the versioned schema that defines how agent conversations are persisted and restored +across invocations, enabling stateful, long-running agent sessions. + +The module includes: +- DurableAgentState: Root state container with schema version and conversation history +- DurableAgentStateEntry and subclasses: Request and response entries in conversation history +- DurableAgentStateMessage: Individual messages with role, content items, and metadata +- Content type classes: Specialized types for text, function calls, errors, and other content +- Serialization/deserialization: Conversion between Python objects and JSON schema format + +The state structure follows this hierarchy: + DurableAgentState + └── DurableAgentStateData + └── conversationHistory: List[DurableAgentStateEntry] + ├── DurableAgentStateRequest (user/system messages) + └── DurableAgentStateResponse (assistant messages with usage stats) + └── messages: List[DurableAgentStateMessage] + └── contents: List[DurableAgentStateContent subclasses] + +All classes support bidirectional conversion between: +- Durable state format (JSON with camelCase, $type discriminators) +- Agent framework objects (Python objects with snake_case) +""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from enum import Enum +from typing import Any + +from agent_framework import ( + AgentRunResponse, + BaseContent, + ChatMessage, + DataContent, + ErrorContent, + FunctionCallContent, + FunctionResultContent, + HostedFileContent, + HostedVectorStoreContent, + TextContent, + TextReasoningContent, + UriContent, + UsageContent, + UsageDetails, + get_logger, +) +from dateutil import parser as date_parser + +from ._models import RunRequest, _serialize_response_format + +logger = get_logger("agent_framework.azurefunctions.durable_agent_state") + + +def _parse_created_at(value: Any) -> datetime: + """Normalize created_at values coming from persisted durable state.""" + if isinstance(value, datetime): + return value + + if isinstance(value, str): + try: + parsed = date_parser.parse(value) + if isinstance(parsed, datetime): + return parsed + except (ValueError, TypeError): + pass + + return datetime.now(tz=timezone.utc) + + +class DurableAgentStateContent: + """Base class for all content types in durable agent state messages. + + This abstract base class defines the interface for content items that can be + stored in conversation history. Content types include text, function calls, + function results, errors, and other specialized content types defined by the + agent framework. + + Subclasses must implement to_dict() and to_ai_content() to handle conversion + between the durable state representation and the agent framework's content objects. + + Attributes: + extensionData: Optional additional metadata (not serialized per schema) + """ + + extensionData: dict[str, Any] | None = None + type: str = "" + + def to_dict(self) -> dict[str, Any]: + """Serialize this content to a dictionary for JSON storage. + + Returns: + Dictionary representation including $type discriminator and content-specific fields + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError + + def to_ai_content(self) -> Any: + """Convert this durable state content back to an agent framework content object. + + Returns: + An agent framework content object (TextContent, FunctionCallContent, etc.) + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError + + @staticmethod + def from_ai_content(content: Any) -> DurableAgentStateContent: + """Create a durable state content object from an agent framework content object. + + This factory method maps agent framework content types (TextContent, FunctionCallContent, + etc.) to their corresponding durable state representations. Unknown content types are + wrapped in DurableAgentStateUnknownContent. + + Args: + content: An agent framework content object (TextContent, FunctionCallContent, etc.) + + Returns: + The corresponding DurableAgentStateContent subclass instance + """ + # Map AI content type to appropriate DurableAgentStateContent subclass + if isinstance(content, DataContent): + return DurableAgentStateDataContent.from_data_content(content) + if isinstance(content, ErrorContent): + return DurableAgentStateErrorContent.from_error_content(content) + if isinstance(content, FunctionCallContent): + return DurableAgentStateFunctionCallContent.from_function_call_content(content) + if isinstance(content, FunctionResultContent): + return DurableAgentStateFunctionResultContent.from_function_result_content(content) + if isinstance(content, HostedFileContent): + return DurableAgentStateHostedFileContent.from_hosted_file_content(content) + if isinstance(content, HostedVectorStoreContent): + return DurableAgentStateHostedVectorStoreContent.from_hosted_vector_store_content(content) + if isinstance(content, TextContent): + return DurableAgentStateTextContent.from_text_content(content) + if isinstance(content, TextReasoningContent): + return DurableAgentStateTextReasoningContent.from_text_reasoning_content(content) + if isinstance(content, UriContent): + return DurableAgentStateUriContent.from_uri_content(content) + if isinstance(content, UsageContent): + return DurableAgentStateUsageContent.from_usage_content(content) + return DurableAgentStateUnknownContent.from_unknown_content(content) + + +# Core state classes + + +class DurableAgentStateData: + """Container for the core data within durable agent state. + + This class holds the primary data structures for agent conversation state, + including the conversation history (a sequence of request and response entries) + and optional extension data for custom metadata. + + The data structure is nested within DurableAgentState under the "data" property, + conforming to the durable-agent-entity-state.json schema structure. + + Attributes: + conversation_history: Ordered list of conversation entries (requests and responses) + extension_data: Optional dictionary for custom metadata (not part of core schema) + """ + + conversation_history: list[DurableAgentStateEntry] + extension_data: dict[str, Any] | None + + def __init__( + self, + conversation_history: list[DurableAgentStateEntry] | None = None, + extension_data: dict[str, Any] | None = None, + ) -> None: + """Initialize the data container. + + Args: + conversation_history: Initial conversation history (defaults to empty list) + extension_data: Optional custom metadata + """ + self.conversation_history = conversation_history or [] + self.extension_data = extension_data + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "conversationHistory": [entry.to_dict() for entry in self.conversation_history], + } + if self.extension_data is not None: + result["extensionData"] = self.extension_data + return result + + @classmethod + def from_dict(cls, data_dict: dict[str, Any]) -> DurableAgentStateData: + # Restore the conversation history - deserialize entries from dicts to objects + history_data = data_dict.get("conversationHistory", []) + deserialized_history: list[DurableAgentStateEntry] = [] + for entry_dict in history_data: + if isinstance(entry_dict, dict): + # Deserialize based on $type discriminator + entry_type = entry_dict.get("$type") or entry_dict.get("json_type") + if entry_type == DurableAgentStateEntryJsonType.RESPONSE: + deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict)) + elif entry_type == DurableAgentStateEntryJsonType.REQUEST: + deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict)) + else: + deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict)) + else: + # Already an object + deserialized_history.append(entry_dict) + + return cls( + conversation_history=deserialized_history, + extension_data=data_dict.get("extensionData"), + ) + + +class DurableAgentState: + """Manages durable agent state conforming to the durable-agent-entity-state.json schema. + + This class provides the root container for agent conversation state that can be persisted + in Azure Durable Entities. It maintains the conversation history as a sequence of request + and response entries, each with their messages, timestamps, and metadata. + + The state follows a versioned schema (currently 1.0.0) that defines the structure for: + - Request entries: User/system messages with optional response format specifications + - Response entries: Assistant messages with token usage information + - Messages: Individual chat messages with role, content items, and timestamps + - Content items: Text, function calls, function results, errors, and other content types + + State is serialized to JSON with this structure: + { + "schemaVersion": "1.0.0", + "data": { + "conversationHistory": [ + {"$type": "request", "correlationId": "...", "createdAt": "...", "messages": [...]}, + {"$type": "response", "correlationId": "...", "createdAt": "...", "messages": [...], "usage": {...}} + ] + } + } + + Attributes: + data: Container for conversation history and optional extension data + schema_version: Schema version string (defaults to "1.0.0") + """ + + data: DurableAgentStateData + schema_version: str = "1.0.0" + + def __init__(self, schema_version: str = "1.0.0"): + """Initialize a new durable agent state. + + Args: + schema_version: Schema version to use (defaults to "1.0.0") + """ + self.data = DurableAgentStateData() + self.schema_version = schema_version + + def to_dict(self) -> dict[str, Any]: + + return { + "schemaVersion": self.schema_version, + "data": self.data.to_dict(), + } + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + @classmethod + def from_dict(cls, state: dict[str, Any]) -> DurableAgentState: + """Restore state from a dictionary. + + Args: + state: Dictionary containing schemaVersion and data (full state structure) + """ + schema_version = state.get("schemaVersion") + if schema_version is None: + logger.warning("Resetting state as it is incompatible with the current schema, all history will be lost") + return cls() + + instance = cls(schema_version=state.get("schemaVersion", "1.0.0")) + instance.data = DurableAgentStateData.from_dict(state.get("data", {})) + + return instance + + @classmethod + def from_json(cls, json_str: str) -> DurableAgentState: + try: + obj = json.loads(json_str) + except json.JSONDecodeError as e: + raise ValueError("The durable agent state is not valid JSON.") from e + + return cls.from_dict(obj) + + @property + def message_count(self) -> int: + """Get the count of conversation entries (requests + responses).""" + return len(self.data.conversation_history) + + def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: + """Try to get an agent response by correlation ID. + + This method searches the conversation history for a response entry matching the given + correlation ID and returns a dictionary suitable for HTTP API responses. + + Note: The returned dictionary includes computed properties (message_count) that are + NOT part of the persisted state schema. These are derived values included for backward + compatibility with the HTTP API response format and should not be considered part of + the durable state structure. + + Args: + correlation_id: The correlation ID to search for + + Returns: + Response data dict with 'content', 'message_count', and 'correlationId' if found, + None otherwise + """ + # Search through conversation history for a response with this correlationId + for entry in self.data.conversation_history: + if entry.correlation_id == correlation_id and isinstance(entry, DurableAgentStateResponse): + # Found the entry, extract response data + # Get the text content from assistant messages only + content = "\n".join(message.text for message in entry.messages if message.text is not None) + + return {"content": content, "message_count": self.message_count, "correlationId": correlation_id} + return None + + +class DurableAgentStateEntryJsonType(str, Enum): + """Enum for conversation history entry types. + + Discriminator values for the $type field in DurableAgentStateEntry objects. + """ + + REQUEST = "request" + RESPONSE = "response" + + +class DurableAgentStateEntry: + """Base class for conversation history entries (requests and responses). + + This class represents a single entry in the conversation history. Each entry can be + either a request (user/system messages sent to the agent) or a response (assistant + messages from the agent). The $type discriminator field determines which type of entry + it represents. + + Entries are linked together using correlation IDs, allowing responses to be matched + with their originating requests. + + Common Attributes: + json_type: Discriminator for entry type ("request" or "response") + correlationId: Unique identifier linking requests and responses + created_at: Timestamp when the entry was created + messages: List of messages in this entry + extensionData: Optional additional metadata (not serialized per schema) + + Request-only Attributes: + responseType: Expected response type ("text" or "json") - only for request entries + responseSchema: JSON schema for structured responses - only for request entries + + Response-only Attributes: + usage: Token usage statistics - only for response entries + """ + + json_type: DurableAgentStateEntryJsonType + correlation_id: str | None + created_at: datetime + messages: list[DurableAgentStateMessage] + extension_data: dict[str, Any] | None + + def __init__( + self, + json_type: DurableAgentStateEntryJsonType, + correlation_id: str | None, + created_at: datetime, + messages: list[DurableAgentStateMessage], + extension_data: dict[str, Any] | None = None, + ) -> None: + self.json_type = json_type + self.correlation_id = correlation_id + self.created_at = created_at + self.messages = messages + self.extension_data = extension_data + + def to_dict(self) -> dict[str, Any]: + # Ensure createdAt is never null + created_at_value = self.created_at + if created_at_value is None: + created_at_value = datetime.now(tz=timezone.utc) + + return { + "$type": self.json_type, + "correlationId": self.correlation_id, + "createdAt": created_at_value.isoformat() if isinstance(created_at_value, datetime) else created_at_value, + "messages": [m.to_dict() for m in self.messages], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateEntry: + created_at = _parse_created_at(data.get("created_at")) + + messages = [] + for msg_dict in data.get("messages", []): + if isinstance(msg_dict, dict): + messages.append(DurableAgentStateMessage.from_dict(msg_dict)) + else: + messages.append(msg_dict) + + return cls( + json_type=DurableAgentStateEntryJsonType(data.get("$type", "entry")), + correlation_id=data.get("correlationId", ""), + created_at=created_at, + messages=messages, + extension_data=data.get("extensionData"), + ) + + +class DurableAgentStateRequest(DurableAgentStateEntry): + """Represents a request entry in the durable agent conversation history. + + A request entry captures a user or system message sent to the agent, along with + optional response format specifications. Each request is stored as a separate + entry in the conversation history with a unique correlation ID. + + Attributes: + response_type: Expected response type ("text" or "json") + response_schema: JSON schema for structured responses (when response_type is "json") + correlationId: Unique identifier linking this request to its response + created_at: Timestamp when the request was created + messages: List of messages included in this request + json_type: Always "request" for this class + """ + + response_type: str | None = None + response_schema: dict[str, Any] | None = None + + def __init__( + self, + correlation_id: str | None, + created_at: datetime, + messages: list[DurableAgentStateMessage], + extension_data: dict[str, Any] | None = None, + response_type: str | None = None, + response_schema: dict[str, Any] | None = None, + ) -> None: + super().__init__( + json_type=DurableAgentStateEntryJsonType.REQUEST, + correlation_id=correlation_id, + created_at=created_at, + messages=messages, + extension_data=extension_data, + ) + self.response_type = response_type + self.response_schema = response_schema + + def to_dict(self) -> dict[str, Any]: + data = super().to_dict() + if self.response_type is not None: + data["responseType"] = self.response_type + if self.response_schema is not None: + data["responseSchema"] = self.response_schema + return data + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateRequest: + created_at = _parse_created_at(data.get("created_at")) + + messages = [] + for msg_dict in data.get("messages", []): + if isinstance(msg_dict, dict): + messages.append(DurableAgentStateMessage.from_dict(msg_dict)) + else: + messages.append(msg_dict) + + return cls( + correlation_id=data.get("correlationId", ""), + created_at=created_at, + messages=messages, + extension_data=data.get("extensionData"), + response_type=data.get("responseType"), + response_schema=data.get("responseSchema"), + ) + + @staticmethod + def from_run_request(request: RunRequest) -> DurableAgentStateRequest: + # Determine response_type based on response_format + return DurableAgentStateRequest( + correlation_id=request.correlation_id, + messages=[DurableAgentStateMessage.from_run_request(request)], + created_at=datetime.now(tz=timezone.utc), + response_type=request.request_response_format, + response_schema=_serialize_response_format(request.response_format), + ) + + +class DurableAgentStateResponse(DurableAgentStateEntry): + """Represents a response entry in the durable agent conversation history. + + A response entry captures the agent's reply to a user request, including any + assistant messages, tool calls, and token usage information. Each response is + linked to its originating request via a correlation ID. + + Attributes: + usage: Token usage statistics for this response (input, output, and total tokens) + is_error: Flag indicating if this response represents an error (not persisted in schema) + correlation_id: Unique identifier linking this response to its request + created_at: Timestamp when the response was created + messages: List of assistant messages in this response + json_type: Always "response" for this class + """ + + usage: DurableAgentStateUsage | None = None + is_error: bool = False + + def __init__( + self, + correlation_id: str, + created_at: datetime, + messages: list[DurableAgentStateMessage], + extension_data: dict[str, Any] | None = None, + usage: DurableAgentStateUsage | None = None, + is_error: bool = False, + ) -> None: + super().__init__( + json_type=DurableAgentStateEntryJsonType.RESPONSE, + correlation_id=correlation_id, + created_at=created_at, + messages=messages, + extension_data=extension_data, + ) + self.usage = usage + self.is_error = is_error + + def to_dict(self) -> dict[str, Any]: + data = super().to_dict() + if self.usage is not None: + data["usage"] = self.usage.to_dict() + return data + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateResponse: + created_at = _parse_created_at(data.get("created_at")) + + messages = [] + for msg_dict in data.get("messages", []): + if isinstance(msg_dict, dict): + messages.append(DurableAgentStateMessage.from_dict(msg_dict)) + else: + messages.append(msg_dict) + + usage_dict = data.get("usage") + usage = None + if usage_dict and isinstance(usage_dict, dict): + usage = DurableAgentStateUsage.from_dict(usage_dict) + elif usage_dict: + usage = usage_dict + + return cls( + correlation_id=data.get("correlationId", ""), + created_at=created_at, + messages=messages, + extension_data=data.get("extensionData"), + usage=usage, + ) + + @staticmethod + def from_run_response(correlation_id: str, response: AgentRunResponse) -> DurableAgentStateResponse: + """Creates a DurableAgentStateResponse from an AgentRunResponse.""" + return DurableAgentStateResponse( + correlation_id=correlation_id, + created_at=_parse_created_at(response.created_at), + messages=[DurableAgentStateMessage.from_chat_message(m) for m in response.messages], + usage=DurableAgentStateUsage.from_usage(response.usage_details), + ) + + def to_run_response(self) -> Any: + """Converts this DurableAgentStateResponse back to an AgentRunResponse.""" + return AgentRunResponse( + created_at=self.created_at.isoformat() if self.created_at else None, + messages=[m.to_chat_message() for m in self.messages], + usage=self.usage.to_usage_details() if self.usage else None, + ) + + +class DurableAgentStateMessage: + """Represents a message within a conversation history entry. + + A message contains the role (user, assistant, system), content items (text, function calls, + tool results, etc.), and optional metadata. Messages are the building blocks of both + request and response entries in the conversation history. + + Attributes: + role: The sender role ("user", "assistant", or "system") + contents: List of content items (text, function calls, errors, etc.) + author_name: Optional name of the message author (typically set for assistant messages) + created_at: Optional timestamp when the message was created + extension_data: Optional additional metadata (not serialized per schema) + """ + + role: str + contents: list[DurableAgentStateContent] + author_name: str | None = None + created_at: datetime | None = None + extension_data: dict[str, Any] | None = None + + def __init__( + self, + role: str, + contents: list[DurableAgentStateContent], + author_name: str | None = None, + created_at: datetime | None = None, + extension_data: dict[str, Any] | None = None, + ) -> None: + self.role = role + self.contents = contents + self.author_name = author_name + self.created_at = created_at + self.extension_data = extension_data + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "role": self.role, + "contents": [ + {"$type": c.to_dict().get("type", "text"), **{k: v for k, v in c.to_dict().items() if k != "type"}} + for c in self.contents + ], + } + # Only include optional fields if they have values + if self.created_at is not None: + result["createdAt"] = self.created_at.isoformat() + if self.author_name is not None: + result["authorName"] = self.author_name + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateMessage: + contents: list[DurableAgentStateContent] = [] + for content_dict in data.get("contents", []): + if isinstance(content_dict, dict): + content_type = content_dict.get("$type") + if content_type == DurableAgentStateTextContent.type: + contents.append(DurableAgentStateTextContent(text=content_dict.get("text"))) + elif content_type == DurableAgentStateDataContent.type: + contents.append( + DurableAgentStateDataContent( + uri=content_dict.get("uri", ""), media_type=content_dict.get("mediaType") + ) + ) + elif content_type == DurableAgentStateErrorContent.type: + contents.append( + DurableAgentStateErrorContent( + message=content_dict.get("message"), + error_code=content_dict.get("errorCode"), + details=content_dict.get("details"), + ) + ) + elif content_type == DurableAgentStateFunctionCallContent.type: + contents.append( + DurableAgentStateFunctionCallContent( + call_id=content_dict.get("callId", ""), + name=content_dict.get("name", ""), + arguments=content_dict.get("arguments", {}), + ) + ) + elif content_type == DurableAgentStateFunctionResultContent.type: + contents.append( + DurableAgentStateFunctionResultContent( + call_id=content_dict.get("callId", ""), result=content_dict.get("result") + ) + ) + elif content_type == DurableAgentStateHostedFileContent.type: + contents.append(DurableAgentStateHostedFileContent(file_id=content_dict.get("fileId", ""))) + elif content_type == DurableAgentStateHostedVectorStoreContent.type: + contents.append( + DurableAgentStateHostedVectorStoreContent(vector_store_id=content_dict.get("vectorStoreId", "")) + ) + elif content_type == DurableAgentStateTextReasoningContent.type: + contents.append(DurableAgentStateTextReasoningContent(text=content_dict.get("text"))) + elif content_type == DurableAgentStateUriContent.type: + contents.append( + DurableAgentStateUriContent( + uri=content_dict.get("uri", ""), media_type=content_dict.get("mediaType", "") + ) + ) + elif content_type == DurableAgentStateUsageContent.type: + usage_data = content_dict.get("usage") + if usage_data and isinstance(usage_data, dict): + contents.append( + DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_dict(usage_data)) + ) + elif content_type == DurableAgentStateUnknownContent.type: + contents.append(DurableAgentStateUnknownContent(content=content_dict.get("content", {}))) + else: + contents.append(content_dict) # type: ignore + + return cls( + role=data.get("role", ""), + contents=contents, + author_name=data.get("authorName"), + created_at=_parse_created_at(data.get("createdAt")), + extension_data=data.get("extensionData"), + ) + + @property + def text(self) -> str: + """Extract text from the contents list.""" + text_parts = [] + for content in self.contents: + if isinstance(content, DurableAgentStateTextContent): + text_parts.append(content.text or "") + return "".join(text_parts) + + @staticmethod + def from_run_request(request: RunRequest) -> DurableAgentStateMessage: + """Converts a RunRequest from the agent framework to a DurableAgentStateMessage. + + Args: + request: RunRequest object with role, message/contents, and metadata + Returns: + DurableAgentStateMessage with converted content items and metadata + """ + return DurableAgentStateMessage( + role=request.role.value, + contents=[DurableAgentStateTextContent(text=request.message)], + created_at=_parse_created_at(request.created_at), + ) + + @staticmethod + def from_chat_message(chat_message: ChatMessage) -> DurableAgentStateMessage: + """Converts an Agent Framework chat message to a durable state message. + + Args: + chat_message: ChatMessage object with role, contents, and metadata to convert + + Returns: + DurableAgentStateMessage with converted content items and metadata + """ + contents_list: list[DurableAgentStateContent] = [ + DurableAgentStateContent.from_ai_content(c) for c in chat_message.contents + ] + + return DurableAgentStateMessage( + role=chat_message.role.value, + contents=contents_list, + author_name=chat_message.author_name, + extension_data=dict(chat_message.additional_properties) if chat_message.additional_properties else None, + ) + + def to_chat_message(self) -> Any: + """Converts this DurableAgentStateMessage back to an agent framework ChatMessage. + + Returns: + ChatMessage object with role, contents, and metadata converted back to agent framework types + """ + # Convert DurableAgentStateContent objects back to agent_framework content objects + ai_contents = [c.to_ai_content() for c in self.contents] + + # Build kwargs for ChatMessage + kwargs: dict[str, Any] = { + "role": self.role, + "contents": ai_contents, + } + + if self.author_name is not None: + kwargs["author_name"] = self.author_name + + if self.extension_data is not None: + kwargs["additional_properties"] = self.extension_data + + return ChatMessage(**kwargs) + + +class DurableAgentStateDataContent(DurableAgentStateContent): + """Represents data content with a URI reference. + + This content type is used to reference data stored at a specific URI location, + optionally with a media type specification. Common use cases include referencing + files, documents, or other data resources. + + Attributes: + uri: URI pointing to the data resource + media_type: Optional MIME type of the data (e.g., "application/json", "text/plain") + """ + + uri: str = "" + media_type: str | None = None + type: str = "data" + + def __init__(self, uri: str, media_type: str | None = None) -> None: + self.uri = uri + self.media_type = media_type + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "uri": self.uri, "mediaType": self.media_type} + + @staticmethod + def from_data_content(content: DataContent) -> DurableAgentStateDataContent: + return DurableAgentStateDataContent(uri=content.uri, media_type=content.media_type) + + def to_ai_content(self) -> DataContent: + return DataContent(uri=self.uri, media_type=self.media_type) + + +class DurableAgentStateErrorContent(DurableAgentStateContent): + """Represents error content in agent responses. + + This content type is used to communicate errors that occurred during agent execution, + including error messages, error codes, and additional details for debugging. + + Attributes: + message: Human-readable error message + error_code: Machine-readable error code or exception type + details: Additional error details or stack trace information + """ + + message: str | None = None + error_code: str | None = None + details: str | None = None + + type: str = "error" + + def __init__(self, message: str | None = None, error_code: str | None = None, details: str | None = None) -> None: + self.message = message + self.error_code = error_code + self.details = details + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "message": self.message, "errorCode": self.error_code, "details": self.details} + + @staticmethod + def from_error_content(content: ErrorContent) -> DurableAgentStateErrorContent: + return DurableAgentStateErrorContent( + message=content.message, error_code=content.error_code, details=content.details + ) + + def to_ai_content(self) -> ErrorContent: + return ErrorContent(message=self.message, error_code=self.error_code, details=self.details) + + +class DurableAgentStateFunctionCallContent(DurableAgentStateContent): + """Represents a function/tool call request from the agent. + + This content type is used when the agent requests execution of a function or tool, + including the function name, arguments, and a unique call identifier for tracking + the call-result pair. + + Attributes: + call_id: Unique identifier for this function call (used to match with results) + name: Name of the function/tool to execute + arguments: Dictionary of argument names to values for the function call + """ + + call_id: str + name: str + arguments: dict[str, Any] + + type: str = "functionCall" + + def __init__(self, call_id: str, name: str, arguments: dict[str, Any]) -> None: + self.call_id = call_id + self.name = name + self.arguments = arguments + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "callId": self.call_id, "name": self.name, "arguments": self.arguments} + + @staticmethod + def from_function_call_content(content: FunctionCallContent) -> DurableAgentStateFunctionCallContent: + # Ensure arguments is a dict; parse string if needed + arguments: dict[str, Any] = {} + if content.arguments: + if isinstance(content.arguments, dict): + arguments = content.arguments + elif isinstance(content.arguments, str): + # Parse JSON string to dict + try: + arguments = json.loads(content.arguments) + except json.JSONDecodeError: + arguments = {} + + return DurableAgentStateFunctionCallContent(call_id=content.call_id, name=content.name, arguments=arguments) + + def to_ai_content(self) -> FunctionCallContent: + return FunctionCallContent(call_id=self.call_id, name=self.name, arguments=self.arguments) + + +class DurableAgentStateFunctionResultContent(DurableAgentStateContent): + """Represents the result of a function/tool call execution. + + This content type is used to communicate the result of executing a function or tool + that was previously requested by the agent. The call_id links this result back to + the original function call request. + + Attributes: + call_id: Unique identifier matching the original function call + result: The return value from the function execution (can be any serializable type) + """ + + call_id: str + result: object | None = None + + type: str = "functionResult" + + def __init__(self, call_id: str, result: Any | None = None) -> None: + self.call_id = call_id + self.result = result + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "callId": self.call_id, "result": self.result} + + @staticmethod + def from_function_result_content(content: FunctionResultContent) -> DurableAgentStateFunctionResultContent: + return DurableAgentStateFunctionResultContent(call_id=content.call_id, result=content.result) + + def to_ai_content(self) -> FunctionResultContent: + return FunctionResultContent(call_id=self.call_id, result=self.result) + + +class DurableAgentStateHostedFileContent(DurableAgentStateContent): + """Represents a reference to a hosted file resource. + + This content type is used to reference files that are hosted by the agent platform + or a file storage service, identified by a unique file ID. + + Attributes: + file_id: Unique identifier for the hosted file + """ + + file_id: str + + type: str = "hostedFile" + + def __init__(self, file_id: str) -> None: + self.file_id = file_id + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "fileId": self.file_id} + + @staticmethod + def from_hosted_file_content(content: HostedFileContent) -> DurableAgentStateHostedFileContent: + return DurableAgentStateHostedFileContent(file_id=content.file_id) + + def to_ai_content(self) -> HostedFileContent: + return HostedFileContent(file_id=self.file_id) + + +class DurableAgentStateHostedVectorStoreContent(DurableAgentStateContent): + """Represents a reference to a hosted vector store resource. + + This content type is used to reference vector stores (used for semantic search + and retrieval-augmented generation) that are hosted by the agent platform, + identified by a unique vector store ID. + + Attributes: + vector_store_id: Unique identifier for the hosted vector store + """ + + vector_store_id: str + + type: str = "hostedVectorStore" + + def __init__(self, vector_store_id: str) -> None: + self.vector_store_id = vector_store_id + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "vectorStoreId": self.vector_store_id} + + @staticmethod + def from_hosted_vector_store_content( + content: HostedVectorStoreContent, + ) -> DurableAgentStateHostedVectorStoreContent: + return DurableAgentStateHostedVectorStoreContent(vector_store_id=content.vector_store_id) + + def to_ai_content(self) -> HostedVectorStoreContent: + return HostedVectorStoreContent(vector_store_id=self.vector_store_id) + + +class DurableAgentStateTextContent(DurableAgentStateContent): + """Represents plain text content in messages. + + This is the most common content type, used for regular text messages from users + and text responses from the agent. + + Attributes: + text: The text content of the message + """ + + type: str = "text" + + def __init__(self, text: str | None) -> None: + self.text = text + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "text": self.text} + + @staticmethod + def from_text_content(content: TextContent) -> DurableAgentStateTextContent: + return DurableAgentStateTextContent(text=content.text) + + def to_ai_content(self) -> TextContent: + return TextContent(text=self.text or "") + + +class DurableAgentStateTextReasoningContent(DurableAgentStateContent): + """Represents reasoning or thought process text from the agent. + + This content type is used to capture the agent's internal reasoning, chain of thought, + or explanation of its decision-making process, separate from the final response text. + + Attributes: + text: The reasoning or thought process text + """ + + type: str = "reasoning" + + def __init__(self, text: str | None) -> None: + self.text = text + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "text": self.text} + + @staticmethod + def from_text_reasoning_content(content: TextReasoningContent) -> DurableAgentStateTextReasoningContent: + return DurableAgentStateTextReasoningContent(text=content.text) + + def to_ai_content(self) -> TextReasoningContent: + return TextReasoningContent(text=self.text or "") + + +class DurableAgentStateUriContent(DurableAgentStateContent): + """Represents content referenced by a URI with media type. + + This content type is used to reference external content via a URI, with an associated + media type to indicate how the content should be interpreted. + + Attributes: + uri: URI pointing to the content resource + media_type: MIME type of the content (e.g., "image/png", "application/pdf") + """ + + uri: str + media_type: str + + type: str = "uri" + + def __init__(self, uri: str, media_type: str) -> None: + self.uri = uri + self.media_type = media_type + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "uri": self.uri, "mediaType": self.media_type} + + @staticmethod + def from_uri_content(content: UriContent) -> DurableAgentStateUriContent: + return DurableAgentStateUriContent(uri=content.uri, media_type=content.media_type) + + def to_ai_content(self) -> UriContent: + return UriContent(uri=self.uri, media_type=self.media_type) + + +class DurableAgentStateUsage: + """Represents token usage statistics for agent responses. + + This class tracks the number of tokens consumed during agent execution, + including input tokens (from the request), output tokens (in the response), + and the total token count. + + Attributes: + input_token_count: Number of tokens in the input/request + output_token_count: Number of tokens in the output/response + total_token_count: Total number of tokens consumed (input + output) + extensionData: Optional additional metadata + """ + + input_token_count: int | None = None + output_token_count: int | None = None + total_token_count: int | None = None + extensionData: dict[str, Any] | None = None + + def __init__( + self, + input_token_count: int | None = None, + output_token_count: int | None = None, + total_token_count: int | None = None, + extensionData: dict[str, Any] | None = None, + ) -> None: + self.input_token_count = input_token_count + self.output_token_count = output_token_count + self.total_token_count = total_token_count + self.extensionData = extensionData + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "inputTokenCount": self.input_token_count, + "outputTokenCount": self.output_token_count, + "totalTokenCount": self.total_token_count, + } + if self.extensionData is not None: + result["extensionData"] = self.extensionData + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateUsage: + return cls( + input_token_count=data.get("inputTokenCount"), + output_token_count=data.get("outputTokenCount"), + total_token_count=data.get("totalTokenCount"), + extensionData=data.get("extensionData"), + ) + + @staticmethod + def from_usage(usage: UsageDetails | None) -> DurableAgentStateUsage | None: + if usage is None: + return None + return DurableAgentStateUsage( + input_token_count=usage.input_token_count, + output_token_count=usage.output_token_count, + total_token_count=usage.total_token_count, + ) + + def to_usage_details(self) -> UsageDetails: + # Convert back to AI SDK UsageDetails + return UsageDetails( + input_token_count=self.input_token_count, + output_token_count=self.output_token_count, + total_token_count=self.total_token_count, + ) + + +class DurableAgentStateUsageContent(DurableAgentStateContent): + """Represents token usage information as message content. + + This content type is used to communicate token usage statistics as part of + message content, allowing usage information to be tracked alongside other + content types in the conversation history. + + Attributes: + usage: DurableAgentStateUsage object containing token counts + """ + + usage: DurableAgentStateUsage = DurableAgentStateUsage() + + type: str = "usage" + + def __init__(self, usage: DurableAgentStateUsage) -> None: + self.usage = usage + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "usage": self.usage.to_dict() if hasattr(self.usage, "to_dict") else self.usage} + + @staticmethod + def from_usage_content(content: UsageContent) -> DurableAgentStateUsageContent: + return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.details)) # type: ignore + + def to_ai_content(self) -> UsageContent: + return UsageContent(details=self.usage.to_usage_details()) + + +class DurableAgentStateUnknownContent(DurableAgentStateContent): + """Represents unknown or unrecognized content types. + + This content type serves as a fallback for content that doesn't match any of the + known content type classes. It preserves the original content object for later + inspection or processing. + + Attributes: + content: The unknown content object + """ + + content: Any + + type: str = "unknown" + + def __init__(self, content: Any) -> None: + self.content = content + + def to_dict(self) -> dict[str, Any]: + return {"$type": self.type, "content": self.content} + + @staticmethod + def from_unknown_content(content: Any) -> DurableAgentStateUnknownContent: + return DurableAgentStateUnknownContent(content=content) + + def to_ai_content(self) -> BaseContent: + if not self.content: + raise Exception("The content is missing and cannot be converted to valid AI content.") + return BaseContent(content=self.content) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py index 8df8e3f335..a79269bd4d 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py @@ -11,14 +11,30 @@ import inspect import json from collections.abc import AsyncIterable, Callable +from datetime import datetime, timezone from typing import Any, cast import azure.durable_functions as df -from agent_framework import AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, Role, get_logger +from agent_framework import ( + AgentProtocol, + AgentRunResponse, + AgentRunResponseUpdate, + ChatMessage, + ErrorContent, + Role, + get_logger, +) from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol +from ._durable_agent_state import ( + DurableAgentState, + DurableAgentStateData, + DurableAgentStateEntry, + DurableAgentStateMessage, + DurableAgentStateRequest, + DurableAgentStateResponse, +) from ._models import AgentResponse, RunRequest -from ._state import AgentState logger = get_logger("agent_framework.azurefunctions.entities") @@ -38,11 +54,11 @@ class AgentEntity: Attributes: agent: The AgentProtocol instance - state: The AgentState managing conversation history + state: The DurableAgentState managing conversation history """ agent: AgentProtocol - state: AgentState + state: DurableAgentState def __init__( self, @@ -56,11 +72,27 @@ def __init__( callback: Optional callback invoked during streaming updates and final responses """ self.agent = agent - self.state = AgentState() + self.state = DurableAgentState() self.callback = callback logger.debug(f"[AgentEntity] Initialized with agent type: {type(agent).__name__}") + def _is_error_response(self, entry: DurableAgentStateEntry) -> bool: + """Check if a conversation history entry is an error response. + + Error responses should be kept in history for tracking but not sent to the agent + since Azure OpenAI doesn't support 'error' content type. + + Args: + entry: A conversation history entry (DurableAgentStateEntry or dict) + + Returns: + True if the entry is a response containing error content, False otherwise + """ + if isinstance(entry, DurableAgentStateResponse): + return entry.is_error + return False + async def run_agent( self, context: df.DurableEntityContext, @@ -94,26 +126,27 @@ async def run_agent( raise ValueError("RunRequest must include a thread_id") if not correlation_id: raise ValueError("RunRequest must include a correlation_id") - role = run_request.role or Role.USER response_format = run_request.response_format enable_tool_calls = run_request.enable_tool_calls - logger.debug(f"[AgentEntity.run_agent] Received message: {message}") - logger.debug(f"[AgentEntity.run_agent] Thread ID: {thread_id}") - logger.debug(f"[AgentEntity.run_agent] Correlation ID: {correlation_id}") - logger.debug(f"[AgentEntity.run_agent] Role: {role.value}") - logger.debug(f"[AgentEntity.run_agent] Enable tool calls: {enable_tool_calls}") - logger.debug(f"[AgentEntity.run_agent] Response format: {'provided' if response_format else 'none'}") + state_request = DurableAgentStateRequest.from_run_request(run_request) + self.state.data.conversation_history.append(state_request) - # Store message in history with role - self.state.add_user_message(message, role=role, correlation_id=correlation_id) - - logger.debug("[AgentEntity.run_agent] Executing agent...") + logger.debug(f"[AgentEntity.run_agent] Received Message: {state_request}") try: logger.debug("[AgentEntity.run_agent] Starting agent invocation") - run_kwargs: dict[str, Any] = {"messages": self.state.get_chat_messages()} + # Build messages from conversation history, excluding error responses + # Error responses are kept in history for tracking but not sent to the agent + chat_messages: list[ChatMessage] = [ + m.to_chat_message() + for entry in self.state.data.conversation_history + if not self._is_error_response(entry) + for m in entry.messages + ] + + run_kwargs: dict[str, Any] = {"messages": chat_messages} if not enable_tool_calls: run_kwargs["tools"] = None if response_format: @@ -133,8 +166,8 @@ async def run_agent( response_text = None structured_response = None - response_str: str | None = None + try: if response_format: try: @@ -156,18 +189,19 @@ async def run_agent( ) response_text = "Error extracting response" + state_response = DurableAgentStateResponse.from_run_response(correlation_id, agent_run_response) + self.state.data.conversation_history.append(state_response) + agent_response = AgentResponse( response=response_text, message=str(message), thread_id=str(thread_id), status="success", - message_count=self.state.message_count, + message_count=len(self.state.data.conversation_history), structured_response=structured_response, ) result = agent_response.to_dict() - content = json.dumps(structured_response) if structured_response else (response_text or "") - self.state.add_assistant_message(content, agent_run_response, correlation_id) logger.debug("[AgentEntity.run_agent] AgentRunResponse stored in conversation history") return result @@ -181,12 +215,28 @@ async def run_agent( logger.error(f"Error type: {type(exc).__name__}") logger.error(f"Full traceback:\n{error_traceback}") + # Create error message + error_message = DurableAgentStateMessage.from_chat_message( + ChatMessage( + role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)] + ) + ) + + # Create and store error response in conversation history + error_state_response = DurableAgentStateResponse( + correlation_id=correlation_id, + created_at=datetime.now(tz=timezone.utc), + messages=[error_message], + is_error=True, + ) + self.state.data.conversation_history.append(error_state_response) + error_response = AgentResponse( response=f"Error: {exc!s}", message=str(message), thread_id=str(thread_id), status="error", - message_count=self.state.message_count, + message_count=len(self.state.data.conversation_history), error=str(exc), error_type=type(exc).__name__, ) @@ -333,7 +383,7 @@ def _build_callback_context( def reset(self, context: df.DurableEntityContext) -> None: """Reset the entity state (clear conversation history).""" logger.debug("[AgentEntity.reset] Resetting entity state") - self.state.reset() + self.state.data = DurableAgentStateData(conversation_history=[]) logger.debug("[AgentEntity.reset] State reset complete") @@ -362,7 +412,7 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: entity = AgentEntity(agent, callback) if current_state is not None: - entity.state.restore_state(current_state) + entity.state = DurableAgentState.from_dict(current_state) logger.debug( "[entity_function] Restored entity from state (message_count: %s)", entity.state.message_count ) @@ -392,8 +442,9 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: logger.error("[entity_function] Unknown operation: %s", operation) context.set_result({"error": f"Unknown operation: {operation}"}) + logger.debug("State dict: %s", entity.state.to_dict()) context.set_state(entity.state.to_dict()) - logger.debug(f"[entity_function] Operation {operation} completed successfully") + logger.info(f"[entity_function] Operation {operation} completed successfully") except Exception as exc: import traceback diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py index 015ca40754..19f175a485 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py @@ -17,6 +17,8 @@ import azure.durable_functions as df from agent_framework import AgentThread, Role +from ._constants import REQUEST_RESPONSE_FORMAT_TEXT + if TYPE_CHECKING: # pragma: no cover - type checking imports only from pydantic import BaseModel @@ -278,35 +280,43 @@ class RunRequest: Attributes: message: The message to send to the agent + request_response_format: The desired response format (e.g., "text" or "json") role: The role of the message sender (user, system, or assistant) response_format: Optional Pydantic BaseModel type describing the structured response format enable_tool_calls: Whether to enable tool calls for this request thread_id: Optional thread ID for tracking correlation_id: Optional correlation ID for tracking the response to this specific request + created_at: Optional timestamp when the request was created """ message: str + request_response_format: str role: Role = Role.USER response_format: type[BaseModel] | None = None enable_tool_calls: bool = True thread_id: str | None = None correlation_id: str | None = None + created_at: str | None = None def __init__( self, message: str, + request_response_format: str = REQUEST_RESPONSE_FORMAT_TEXT, role: Role | str | None = Role.USER, response_format: type[BaseModel] | None = None, enable_tool_calls: bool = True, thread_id: str | None = None, correlation_id: str | None = None, + created_at: str | None = None, ) -> None: self.message = message self.role = self.coerce_role(role) self.response_format = response_format + self.request_response_format = request_response_format self.enable_tool_calls = enable_tool_calls self.thread_id = thread_id self.correlation_id = correlation_id + self.created_at = created_at @staticmethod def coerce_role(value: Role | str | None) -> Role: @@ -326,13 +336,17 @@ def to_dict(self) -> dict[str, Any]: "message": self.message, "enable_tool_calls": self.enable_tool_calls, "role": self.role.value, + "request_response_format": self.request_response_format, } if self.response_format: result["response_format"] = _serialize_response_format(self.response_format) if self.thread_id: result["thread_id"] = self.thread_id if self.correlation_id: - result["correlation_id"] = self.correlation_id + result["correlationId"] = self.correlation_id + if self.created_at: + result["created_at"] = self.created_at + return result @classmethod @@ -340,11 +354,13 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest: """Create RunRequest from dictionary.""" return cls( message=data.get("message", ""), + request_response_format=data.get("request_response_format", REQUEST_RESPONSE_FORMAT_TEXT), role=cls.coerce_role(data.get("role")), response_format=_deserialize_response_format(data.get("response_format")), enable_tool_calls=data.get("enable_tool_calls", True), thread_id=data.get("thread_id"), - correlation_id=data.get("correlation_id"), + correlation_id=data.get("correlationId"), + created_at=data.get("created_at"), ) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_state.py deleted file mode 100644 index c9d54b8333..0000000000 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_state.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Agent State Management. - -This module defines the AgentState class for managing conversation state and -serializing agent framework responses. -""" - -from collections.abc import MutableMapping -from datetime import datetime, timezone -from typing import Any, cast - -from agent_framework import AgentRunResponse, ChatMessage, Role, get_logger - -logger = get_logger("agent_framework.azurefunctions.state") - - -class AgentState: - """Manages agent conversation state using agent_framework types (ChatMessage, AgentRunResponse). - - This class handles: - - Conversation history tracking using ChatMessage objects - - Agent response storage using AgentRunResponse objects with correlation IDs - - State persistence and restoration - - Message counting - """ - - def __init__(self) -> None: - """Initialize empty agent state.""" - self.conversation_history: list[ChatMessage] = [] - self.last_response: str | None = None - self.message_count: int = 0 - - def _current_timestamp(self) -> str: - """Return an ISO 8601 UTC timestamp.""" - return datetime.now(timezone.utc).isoformat() - - def add_user_message( - self, - content: str, - role: Role = Role.USER, - correlation_id: str | None = None, - ) -> None: - """Add a user message to the conversation history as a ChatMessage object. - - Args: - content: The message content - role: The message role (user, system, etc.) - correlation_id: Optional correlation identifier associated with the user message - """ - self.message_count += 1 - timestamp = self._current_timestamp() - additional_props: MutableMapping[str, Any] = {"timestamp": timestamp} - if correlation_id is not None: - additional_props["correlation_id"] = correlation_id - chat_message = ChatMessage(role=role, text=content, additional_properties=additional_props) - self.conversation_history.append(chat_message) - logger.debug(f"Added {role} ChatMessage to history (message #{self.message_count})") - - def add_assistant_message( - self, content: str, agent_response: AgentRunResponse, correlation_id: str | None = None - ) -> None: - """Add an assistant message to the conversation history with full agent response. - - Args: - content: The text content of the response - agent_response: The AgentRunResponse object from the agent framework - correlation_id: Optional correlation ID for tracking this response - """ - self.last_response = content - timestamp = self._current_timestamp() - serialized_response = self.serialize_response(agent_response) - - # Create a ChatMessage for the assistant response - # The agent_response already contains messages, but we store it as a custom ChatMessage - # with the agent_response stored in additional_properties for full metadata preservation - additional_props: dict[str, Any] = { - "agent_response": serialized_response, - "correlation_id": correlation_id, - "timestamp": timestamp, - "message_count": self.message_count, - } - chat_message = ChatMessage(role="assistant", text=content, additional_properties=additional_props) - - self.conversation_history.append(chat_message) - - logger.debug( - f"Added assistant ChatMessage to history with AgentRunResponse metadata (correlation_id: {correlation_id})" - ) - - def get_chat_messages(self) -> list[ChatMessage]: - """Return a copy of the full conversation history.""" - return list(self.conversation_history) - - def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: - """Get an agent response by correlation ID. - - Args: - correlation_id: The correlation ID to look up - - Returns: - The agent response data if found, None otherwise - """ - for message in reversed(self.conversation_history): - metadata = getattr(message, "additional_properties", {}) or {} - if metadata.get("correlation_id") == correlation_id: - return self._build_agent_response_payload(message, metadata) - - return None - - def serialize_response(self, response: AgentRunResponse) -> dict[str, Any]: - """Serialize an ``AgentRunResponse`` to a dictionary. - - Args: - response: The agent framework response object - - Returns: - Dictionary containing all response fields - """ - try: - return response.to_dict() - except Exception as exc: # pragma: no cover - defensive logging path - logger.warning(f"Error serializing response: {exc}") - return {"response": str(response), "serialization_error": str(exc)} - - def to_dict(self) -> dict[str, Any]: - """Get the current state as a dictionary for persistence. - - Returns: - Dictionary containing conversation_history (as serialized ChatMessages), - last_response, and message_count - """ - return { - "conversation_history": [msg.to_dict() for msg in self.conversation_history], - "last_response": self.last_response, - "message_count": self.message_count, - } - - def restore_state(self, state: dict[str, Any]) -> None: - """Restore state from a dictionary, reconstructing ChatMessage objects. - - Args: - state: Dictionary containing conversation_history, last_response, and message_count - """ - # Restore conversation history as ChatMessage objects - history_data = state.get("conversation_history", []) - restored_history: list[ChatMessage] = [] - for raw_message in history_data: - if isinstance(raw_message, dict): - restored_history.append(ChatMessage.from_dict(cast(dict[str, Any], raw_message))) - else: - restored_history.append(cast(ChatMessage, raw_message)) - - self.conversation_history = restored_history - - self.last_response = state.get("last_response") - self.message_count = state.get("message_count", 0) - logger.debug("Restored state: %s ChatMessages in history", len(self.conversation_history)) - - def reset(self) -> None: - """Reset the state to empty.""" - self.conversation_history = [] - self.last_response = None - self.message_count = 0 - logger.debug("State reset to empty") - - def __repr__(self) -> str: - """String representation of the state.""" - return f"AgentState(messages={self.message_count}, history_length={len(self.conversation_history)})" - - def _build_agent_response_payload(self, message: ChatMessage, metadata: dict[str, Any]) -> dict[str, Any]: - """Construct the agent response payload returned to callers.""" - return { - "content": message.text, - "agent_response": metadata.get("agent_response"), - "message_count": metadata.get("message_count", self.message_count), - "timestamp": metadata.get("timestamp"), - "correlation_id": metadata.get("correlation_id"), - } diff --git a/python/packages/azurefunctions/pyproject.toml b/python/packages/azurefunctions/pyproject.toml index 07695c24d3..052431916f 100644 --- a/python/packages/azurefunctions/pyproject.toml +++ b/python/packages/azurefunctions/pyproject.toml @@ -27,6 +27,11 @@ dependencies = [ "azure-functions-durable", ] +[dependency-groups] +dev = [ + "types-python-dateutil>=2.9.0", +] + [tool.uv] prerelease = "if-necessary-or-explicit" environments = [ diff --git a/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py b/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py index f97916825c..cd93e6a352 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py @@ -16,6 +16,8 @@ import pytest +from agent_framework_azurefunctions._constants import THREAD_ID_HEADER + from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file @@ -67,7 +69,7 @@ def test_simple_message_plain_text(self) -> None: # Agent responded with plain text when the request body was text/plain. assert response.text.strip() - assert response.headers.get("x-ms-thread-id") is not None + assert response.headers.get(THREAD_ID_HEADER) is not None def test_thread_id_in_query(self) -> None: """Test using thread_id in query parameter.""" @@ -77,7 +79,7 @@ def test_thread_id_in_query(self) -> None: assert response.status_code in [200, 202] assert response.text.strip() - assert response.headers.get("x-ms-thread-id") == "test-query-thread" + assert response.headers.get(THREAD_ID_HEADER) == "test-query-thread" def test_conversation_continuity(self) -> None: """Test conversation context is maintained across requests.""" @@ -92,7 +94,7 @@ def test_conversation_continuity(self) -> None: if response1.status_code == 200: data1 = response1.json() - assert data1["message_count"] == 1 + assert data1["message_count"] == 2 # Initial + reply # Second message in same session response2 = SampleTestHelper.post_json( @@ -100,7 +102,7 @@ def test_conversation_continuity(self) -> None: ) assert response2.status_code == 200 data2 = response2.json() - assert data2["message_count"] == 2 + assert data2["message_count"] == 4 else: # In async mode, we can't easily test message count # Just verify we can make multiple calls diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index c11bb873d6..ebf6eef3e6 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -13,7 +13,13 @@ from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._app import WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER -from agent_framework_azurefunctions._entities import AgentEntity, AgentState, create_agent_entity +from agent_framework_azurefunctions._constants import ( + MIMETYPE_APPLICATION_JSON, + MIMETYPE_TEXT_PLAIN, + THREAD_ID_HEADER, +) +from agent_framework_azurefunctions._durable_agent_state import DurableAgentState +from agent_framework_azurefunctions._entities import AgentEntity, create_agent_entity TFunc = TypeVar("TFunc", bound=Callable[..., Any]) @@ -333,14 +339,14 @@ async def test_entity_run_agent_operation(self) -> None: result = await entity.run_agent( mock_context, - {"message": "Test message", "thread_id": "test-conv-123", "correlation_id": "corr-app-entity-1"}, + {"message": "Test message", "thread_id": "test-conv-123", "correlationId": "corr-app-entity-1"}, ) assert result["status"] == "success" assert result["response"] == "Test response" assert result["message"] == "Test message" assert result["thread_id"] == "test-conv-123" - assert entity.state.message_count == 1 + assert entity.state.message_count == 2 async def test_entity_stores_conversation_history(self) -> None: """Test that the entity stores conversation history.""" @@ -354,18 +360,29 @@ async def test_entity_stores_conversation_history(self) -> None: # Send first message await entity.run_agent( - mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlation_id": "corr-app-entity-2"} + mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-app-entity-2"} + ) + + # Each conversation turn creates 2 entries: request and response + history = entity.state.data.conversation_history[0].messages # Request entry + assert len(history) == 1 # Just the user message + + # Send second message + await entity.run_agent( + mock_context, {"message": "Message 2", "thread_id": "conv-2", "correlationId": "corr-app-entity-2b"} ) - history = entity.state.conversation_history - assert len(history) == 2 # User + assistant + # Now we have 4 entries total (2 requests + 2 responses) + # Access the first request entry + history2 = entity.state.data.conversation_history[2].messages # Second request entry + assert len(history2) == 1 # Just the user message user_msg = history[0] user_role = getattr(user_msg.role, "value", user_msg.role) assert user_role == "user" assert user_msg.text == "Message 1" - assistant_msg = history[1] + assistant_msg = entity.state.data.conversation_history[1].messages[0] assistant_role = getattr(assistant_msg.role, "value", assistant_msg.role) assert assistant_role == "assistant" assert assistant_msg.text == "Response 1" @@ -380,17 +397,17 @@ async def test_entity_increments_message_count(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - assert entity.state.message_count == 0 + assert len(entity.state.data.conversation_history) == 0 await entity.run_agent( - mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlation_id": "corr-app-entity-3a"} + mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-app-entity-3a"} ) - assert entity.state.message_count == 1 + assert len(entity.state.data.conversation_history) == 2 await entity.run_agent( - mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlation_id": "corr-app-entity-3b"} + mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-app-entity-3b"} ) - assert entity.state.message_count == 2 + assert len(entity.state.data.conversation_history) == 4 def test_entity_reset(self) -> None: """Test that entity reset clears state.""" @@ -398,19 +415,13 @@ def test_entity_reset(self) -> None: entity = AgentEntity(mock_agent) # Set some state - entity.state.message_count = 10 - entity.state.last_response = "Some response" - entity.state.conversation_history = [ - ChatMessage(role="user", text="test", additional_properties={"timestamp": "2024-01-01T00:00:00Z"}) - ] + entity.state = DurableAgentState() # Reset mock_context = Mock() entity.reset(mock_context) - assert entity.state.message_count == 0 - assert entity.state.last_response is None - assert len(entity.state.conversation_history) == 0 + assert len(entity.state.data.conversation_history) == 0 class TestAgentEntityFactory: @@ -438,7 +449,7 @@ def test_entity_function_handles_run_agent_operation(self) -> None: mock_context.get_input.return_value = { "message": "Test message", "thread_id": "conv-123", - "correlation_id": "corr-app-factory-1", + "correlationId": "corr-app-factory-1", } mock_context.get_state.return_value = None @@ -458,9 +469,27 @@ def test_entity_function_handles_reset_operation(self) -> None: mock_context = Mock() mock_context.operation_name = "reset" mock_context.get_state.return_value = { - "message_count": 5, - "conversation_history": [{"role": "user", "content": "test"}], - "last_response": "Test", + "schemaVersion": "1.0.0", + "data": { + "conversationHistory": [ + { + "$type": "request", + "correlationId": "corr-reset-test", + "createdAt": "2024-01-01T00:00:00Z", + "messages": [ + { + "role": "user", + "contents": [ + { + "$type": "text", + "text": "test", + } + ], + } + ], + } + ], + }, } # Execute entity function @@ -497,19 +526,53 @@ def test_entity_function_restores_state(self) -> None: # Mock context with existing state existing_state = { - "message_count": 3, - "conversation_history": [{"role": "user", "content": "msg1"}, {"role": "assistant", "content": "resp1"}], - "last_response": "resp1", + "schemaVersion": "1.0.0", + "data": { + "conversationHistory": [ + { + "$type": "request", + "correlationId": "corr-existing-1", + "createdAt": "2024-01-01T00:00:00Z", + "messages": [ + { + "role": "user", + "contents": [ + { + "$type": "text", + "text": "msg1", + } + ], + } + ], + }, + { + "$type": "response", + "correlationId": "corr-existing-1", + "createdAt": "2024-01-01T00:05:00Z", + "messages": [ + { + "role": "assistant", + "contents": [ + { + "$type": "text", + "text": "resp1", + } + ], + } + ], + }, + ], + }, } mock_context = Mock() mock_context.operation_name = "reset" mock_context.get_state.return_value = existing_state - with patch.object(AgentState, "restore_state") as restore_state_mock: + with patch.object(DurableAgentState, "from_dict", wraps=DurableAgentState.from_dict) as from_dict_mock: entity_function(mock_context) - restore_state_mock.assert_called_once_with(existing_state) + from_dict_mock.assert_called_once_with(existing_state) class TestErrorHandling: @@ -524,7 +587,7 @@ async def test_entity_handles_agent_error(self) -> None: mock_context = Mock() result = await entity.run_agent( - mock_context, {"message": "Test message", "thread_id": "conv-1", "correlation_id": "corr-app-error-1"} + mock_context, {"message": "Test message", "thread_id": "conv-1", "correlationId": "corr-app-error-1"} ) assert result["status"] == "error" @@ -600,7 +663,7 @@ def test_accept_header_prefers_json(self) -> None: app = self._create_app() request = Mock() - request.headers = {"accept": "application/json"} + request.headers = {"accept": MIMETYPE_APPLICATION_JSON} request.params = {} request.get_json.side_effect = ValueError("Invalid JSON") request.get_body.return_value = b"Plain text message" @@ -674,8 +737,8 @@ async def test_http_run_accepts_plain_text(self) -> None: response = await handler(request, client) assert response.status_code == 202 - assert response.mimetype == "text/plain" - assert response.headers.get("x-ms-thread-id") is not None + assert response.mimetype == MIMETYPE_TEXT_PLAIN + assert response.headers.get(THREAD_ID_HEADER) is not None assert response.get_body().decode("utf-8") == "Agent request accepted" signal_args = client.signal_entity.call_args[0] @@ -693,7 +756,7 @@ async def test_http_run_accept_header_returns_json(self) -> None: handler = self._get_run_handler(mock_agent) request = Mock() - request.headers = {WAIT_FOR_RESPONSE_HEADER: "false", "Accept": "application/json"} + request.headers = {WAIT_FOR_RESPONSE_HEADER: "false", "Accept": MIMETYPE_APPLICATION_JSON} request.params = {} request.route_params = {} request.get_json.side_effect = ValueError("Invalid JSON") @@ -704,8 +767,8 @@ async def test_http_run_accept_header_returns_json(self) -> None: response = await handler(request, client) assert response.status_code == 202 - assert response.mimetype == "application/json" - assert response.headers.get("x-ms-thread-id") is None + assert response.mimetype == MIMETYPE_APPLICATION_JSON + assert response.headers.get(THREAD_ID_HEADER) is None body = response.get_body().decode("utf-8") assert '"status": "accepted"' in body @@ -728,8 +791,8 @@ async def test_http_run_rejects_empty_message(self) -> None: response = await handler(request, client) assert response.status_code == 400 - assert response.mimetype == "text/plain" - assert response.headers.get("x-ms-thread-id") is not None + assert response.mimetype == MIMETYPE_TEXT_PLAIN + assert response.headers.get(THREAD_ID_HEADER) is not None assert response.get_body().decode("utf-8") == "Message is required" client.signal_entity.assert_not_called() diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index 5b053efbf3..2f73f1daa8 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -15,14 +15,20 @@ from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, Role from pydantic import BaseModel +from agent_framework_azurefunctions._durable_agent_state import ( + DurableAgentState, + DurableAgentStateData, + DurableAgentStateMessage, + DurableAgentStateRequest, + DurableAgentStateTextContent, +) from agent_framework_azurefunctions._entities import AgentEntity, create_agent_entity from agent_framework_azurefunctions._models import RunRequest -from agent_framework_azurefunctions._state import AgentState TFunc = TypeVar("TFunc", bound=Callable[..., Any]) -def _role_value(chat_message: ChatMessage) -> str: +def _role_value(chat_message: DurableAgentStateMessage) -> str: """Helper to extract the string role from a ChatMessage.""" role = getattr(chat_message, "role", None) role_value = getattr(role, "value", role) @@ -71,9 +77,9 @@ def test_init_creates_entity(self) -> None: entity = AgentEntity(mock_agent) assert entity.agent == mock_agent - assert entity.state.conversation_history == [] - assert entity.state.last_response is None - assert entity.state.message_count == 0 + assert len(entity.state.data.conversation_history) == 0 + assert entity.state.data.extension_data is None + assert entity.state.schema_version == "1.0.0" def test_init_stores_agent_reference(self) -> None: """Test that the agent reference is stored correctly.""" @@ -112,7 +118,7 @@ async def test_run_agent_executes_agent(self) -> None: mock_context = Mock() result = await entity.run_agent( - mock_context, {"message": "Test message", "thread_id": "conv-123", "correlation_id": "corr-entity-1"} + mock_context, {"message": "Test message", "thread_id": "conv-123", "correlationId": "corr-entity-1"} ) # Verify agent.run was called @@ -123,8 +129,8 @@ async def test_run_agent_executes_agent(self) -> None: assert len(sent_messages) == 1 sent_message = sent_messages[0] assert isinstance(sent_message, ChatMessage) - assert sent_message.text == "Test message" - assert _role_value(sent_message) == "user" + assert getattr(sent_message, "text", None) == "Test message" + assert getattr(sent_message.role, "value", sent_message.role) == "user" # Verify result assert result["status"] == "success" @@ -158,7 +164,7 @@ async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]: { "message": "Tell me something", "thread_id": "session-1", - "correlation_id": "corr-stream-1", + "correlationId": "corr-stream-1", }, ) @@ -205,7 +211,7 @@ async def test_run_agent_final_callback_without_streaming(self) -> None: { "message": "Hi", "thread_id": "session-2", - "correlation_id": "corr-final-1", + "correlationId": "corr-final-1", }, ) @@ -233,19 +239,20 @@ async def test_run_agent_updates_conversation_history(self) -> None: mock_context = Mock() await entity.run_agent( - mock_context, {"message": "User message", "thread_id": "conv-1", "correlation_id": "corr-entity-2"} + mock_context, {"message": "User message", "thread_id": "conv-1", "correlationId": "corr-entity-2"} ) - # Should have 2 entries: user message + assistant response - history = entity.state.conversation_history + # Should have 1 entry: user message + assistant response + user_history = entity.state.data.conversation_history[0].messages + assistant_history = entity.state.data.conversation_history[1].messages - assert len(history) == 2 + assert len(user_history) == 1 - user_msg = history[0] + user_msg = user_history[0] assert _role_value(user_msg) == "user" assert user_msg.text == "User message" - assistant_msg = history[1] + assistant_msg = assistant_history[0] assert _role_value(assistant_msg) == "assistant" assert assistant_msg.text == "Agent response" @@ -257,41 +264,22 @@ async def test_run_agent_increments_message_count(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - assert entity.state.message_count == 0 + assert len(entity.state.data.conversation_history) == 0 await entity.run_agent( - mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlation_id": "corr-entity-3a"} + mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-3a"} ) - assert entity.state.message_count == 1 + assert len(entity.state.data.conversation_history) == 2 await entity.run_agent( - mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlation_id": "corr-entity-3b"} + mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-3b"} ) - assert entity.state.message_count == 2 + assert len(entity.state.data.conversation_history) == 4 await entity.run_agent( - mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlation_id": "corr-entity-3c"} + mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-3c"} ) - assert entity.state.message_count == 3 - - async def test_run_agent_stores_last_response(self) -> None: - """Test that run_agent stores the last response.""" - mock_agent = Mock() - mock_agent.run = AsyncMock(return_value=_agent_response("Response 1")) - - entity = AgentEntity(mock_agent) - mock_context = Mock() - - await entity.run_agent( - mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlation_id": "corr-entity-4a"} - ) - assert entity.state.last_response == "Response 1" - - mock_agent.run = AsyncMock(return_value=_agent_response("Response 2")) - await entity.run_agent( - mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlation_id": "corr-entity-4b"} - ) - assert entity.state.last_response == "Response 2" + assert len(entity.state.data.conversation_history) == 6 async def test_run_agent_with_none_thread_id(self) -> None: """Test run_agent with a None thread identifier.""" @@ -303,7 +291,7 @@ async def test_run_agent_with_none_thread_id(self) -> None: with pytest.raises(ValueError, match="thread_id"): await entity.run_agent( - mock_context, {"message": "Message", "thread_id": None, "correlation_id": "corr-entity-5"} + mock_context, {"message": "Message", "thread_id": None, "correlationId": "corr-entity-5"} ) async def test_run_agent_handles_response_without_text_attribute(self) -> None: @@ -322,7 +310,7 @@ def text(self) -> str: # type: ignore[override] mock_context = Mock() result = await entity.run_agent( - mock_context, {"message": "Message", "thread_id": "conv-1", "correlation_id": "corr-entity-6"} + mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-6"} ) # Should handle gracefully @@ -338,7 +326,7 @@ async def test_run_agent_handles_none_response_text(self) -> None: mock_context = Mock() result = await entity.run_agent( - mock_context, {"message": "Message", "thread_id": "conv-1", "correlation_id": "corr-entity-7"} + mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-7"} ) assert result["status"] == "success" @@ -354,18 +342,18 @@ async def test_run_agent_multiple_conversations(self) -> None: # Send multiple messages await entity.run_agent( - mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlation_id": "corr-entity-8a"} + mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-8a"} ) await entity.run_agent( - mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlation_id": "corr-entity-8b"} + mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-8b"} ) await entity.run_agent( - mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlation_id": "corr-entity-8c"} + mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-8c"} ) - history = entity.state.conversation_history + history = entity.state.data.conversation_history assert len(history) == 6 - assert entity.state.message_count == 3 + assert entity.state.message_count == 6 class TestAgentEntityReset: @@ -376,40 +364,47 @@ def test_reset_clears_conversation_history(self) -> None: mock_agent = Mock() entity = AgentEntity(mock_agent) - # Add some history - entity.state.conversation_history = [ - ChatMessage(role="user", text="msg1"), - ChatMessage(role="assistant", text="resp1"), + # Add some history with proper DurableAgentStateEntry objects + entity.state.data.conversation_history = [ + DurableAgentStateRequest( + correlation_id="test-1", + created_at=datetime.now(), + messages=[ + DurableAgentStateMessage( + role="user", + contents=[DurableAgentStateTextContent(text="msg1")], + ) + ], + ), ] mock_context = Mock() entity.reset(mock_context) - assert entity.state.conversation_history == [] + assert entity.state.data.conversation_history == [] - def test_reset_clears_last_response(self) -> None: - """Test that reset clears the last response.""" + def test_reset_with_extension_data(self) -> None: + """Test that reset works when entity has extension data.""" mock_agent = Mock() entity = AgentEntity(mock_agent) - entity.state.last_response = "Some response" + # Set up some initial state with conversation history + entity.state.data = DurableAgentStateData(conversation_history=[], extension_data={"some_key": "some_value"}) mock_context = Mock() entity.reset(mock_context) - assert entity.state.last_response is None + assert len(entity.state.data.conversation_history) == 0 def test_reset_clears_message_count(self) -> None: """Test that reset clears the message count.""" mock_agent = Mock() entity = AgentEntity(mock_agent) - entity.state.message_count = 10 - mock_context = Mock() entity.reset(mock_context) - assert entity.state.message_count == 0 + assert len(entity.state.data.conversation_history) == 0 async def test_reset_after_conversation(self) -> None: """Test reset after a full conversation.""" @@ -421,23 +416,22 @@ async def test_reset_after_conversation(self) -> None: # Have a conversation await entity.run_agent( - mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlation_id": "corr-entity-10a"} + mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-10a"} ) await entity.run_agent( - mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlation_id": "corr-entity-10b"} + mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-10b"} ) # Verify state before reset - assert entity.state.message_count == 2 - assert len(entity.state.conversation_history) == 4 + assert entity.state.message_count == 4 + assert len(entity.state.data.conversation_history) == 4 # Reset entity.reset(mock_context) # Verify state after reset assert entity.state.message_count == 0 - assert len(entity.state.conversation_history) == 0 - assert entity.state.last_response is None + assert len(entity.state.data.conversation_history) == 0 class TestCreateAgentEntity: @@ -464,7 +458,7 @@ def test_entity_function_handles_run_agent(self) -> None: mock_context.get_input.return_value = { "message": "Test message", "thread_id": "conv-123", - "correlation_id": "corr-entity-factory", + "correlationId": "corr-entity-factory", } mock_context.get_state.return_value = None @@ -485,13 +479,22 @@ def test_entity_function_handles_reset(self) -> None: mock_context = Mock() mock_context.operation_name = "reset" mock_context.get_state.return_value = { - "message_count": 5, - "conversation_history": [ - ChatMessage( - role="user", text="test", additional_properties={"timestamp": "2024-01-01T00:00:00Z"} - ).to_dict() - ], - "last_response": "Test", + "schemaVersion": "1.0.0", + "data": { + "conversationHistory": [ + { + "$type": "request", + "correlationId": "test-correlation-id", + "createdAt": "2024-01-01T00:00:00Z", + "messages": [ + { + "role": "user", + "contents": [{"$type": "text", "text": "test"}], + } + ], + } + ] + }, } # Execute @@ -505,9 +508,7 @@ def test_entity_function_handles_reset(self) -> None: # Verify state was cleared assert mock_context.set_state.called state = mock_context.set_state.call_args[0][0] - assert state["message_count"] == 0 - assert state["conversation_history"] == [] - assert state["last_response"] is None + assert state["data"]["conversationHistory"] == [] def test_entity_function_handles_unknown_operation(self) -> None: """Test that the entity function handles unknown operations.""" @@ -547,8 +548,7 @@ def test_entity_function_creates_new_entity_on_first_call(self) -> None: assert result["status"] == "reset" assert mock_context.set_state.called state = mock_context.set_state.call_args[0][0] - assert state["message_count"] == 0 - assert state["conversation_history"] == [] + assert state["data"] == {"conversationHistory": []} def test_entity_function_restores_existing_state(self) -> None: """Test that the entity function restores existing state.""" @@ -557,26 +557,53 @@ def test_entity_function_restores_existing_state(self) -> None: entity_function = create_agent_entity(mock_agent) existing_state = { - "message_count": 5, - "conversation_history": [ - ChatMessage( - role="user", text="msg1", additional_properties={"timestamp": "2024-01-01T00:00:00Z"} - ).to_dict(), - ChatMessage( - role="assistant", text="resp1", additional_properties={"timestamp": "2024-01-01T00:05:00Z"} - ).to_dict(), - ], - "last_response": "resp1", + "schemaVersion": "1.0.0", + "data": { + "conversationHistory": [ + { + "$type": "request", + "correlationId": "corr-existing-1", + "createdAt": "2024-01-01T00:00:00Z", + "messages": [ + { + "role": "user", + "contents": [ + { + "$type": "text", + "text": "msg1", + } + ], + } + ], + }, + { + "$type": "response", + "correlationId": "corr-existing-1", + "createdAt": "2024-01-01T00:05:00Z", + "messages": [ + { + "role": "assistant", + "contents": [ + { + "$type": "text", + "text": "resp1", + } + ], + } + ], + }, + ], + }, } mock_context = Mock() mock_context.operation_name = "reset" mock_context.get_state.return_value = existing_state - with patch.object(AgentState, "restore_state") as restore_state_mock: + with patch.object(DurableAgentState, "from_dict", wraps=DurableAgentState.from_dict) as from_dict_mock: entity_function(mock_context) - restore_state_mock.assert_called_once_with(existing_state) + from_dict_mock.assert_called_once_with(existing_state) class TestErrorHandling: @@ -591,7 +618,7 @@ async def test_run_agent_handles_agent_exception(self) -> None: mock_context = Mock() result = await entity.run_agent( - mock_context, {"message": "Message", "thread_id": "conv-1", "correlation_id": "corr-entity-error-1"} + mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-1"} ) assert result["status"] == "error" @@ -608,7 +635,7 @@ async def test_run_agent_handles_value_error(self) -> None: mock_context = Mock() result = await entity.run_agent( - mock_context, {"message": "Message", "thread_id": "conv-1", "correlation_id": "corr-entity-error-2"} + mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-2"} ) assert result["status"] == "error" @@ -624,7 +651,7 @@ async def test_run_agent_handles_timeout_error(self) -> None: mock_context = Mock() result = await entity.run_agent( - mock_context, {"message": "Message", "thread_id": "conv-1", "correlation_id": "corr-entity-error-3"} + mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-3"} ) assert result["status"] == "error" @@ -659,7 +686,7 @@ async def test_run_agent_preserves_message_on_error(self) -> None: result = await entity.run_agent( mock_context, - {"message": "Test message", "thread_id": "conv-123", "correlation_id": "corr-entity-error-4"}, + {"message": "Test message", "thread_id": "conv-123", "correlationId": "corr-entity-error-4"}, ) # Even on error, message info should be preserved @@ -680,15 +707,15 @@ async def test_conversation_history_has_timestamps(self) -> None: mock_context = Mock() await entity.run_agent( - mock_context, {"message": "Message", "thread_id": "conv-1", "correlation_id": "corr-entity-history-1"} + mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-history-1"} ) # Check both user and assistant messages have timestamps - for entry in entity.state.conversation_history: - timestamp = entry.additional_properties.get("timestamp") + for entry in entity.state.data.conversation_history: + timestamp = entry.created_at assert timestamp is not None # Verify timestamp is in ISO format - datetime.fromisoformat(timestamp) + datetime.fromisoformat(str(timestamp)) async def test_conversation_history_ordering(self) -> None: """Test that conversation history maintains the correct order.""" @@ -701,29 +728,30 @@ async def test_conversation_history_ordering(self) -> None: mock_agent.run = AsyncMock(return_value=_agent_response("Response 1")) await entity.run_agent( mock_context, - {"message": "Message 1", "thread_id": "conv-1", "correlation_id": "corr-entity-history-2a"}, + {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-history-2a"}, ) mock_agent.run = AsyncMock(return_value=_agent_response("Response 2")) await entity.run_agent( mock_context, - {"message": "Message 2", "thread_id": "conv-1", "correlation_id": "corr-entity-history-2b"}, + {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-history-2b"}, ) mock_agent.run = AsyncMock(return_value=_agent_response("Response 3")) await entity.run_agent( mock_context, - {"message": "Message 3", "thread_id": "conv-1", "correlation_id": "corr-entity-history-2c"}, + {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-history-2c"}, ) # Verify order - history = entity.state.conversation_history - assert history[0].text == "Message 1" - assert history[1].text == "Response 1" - assert history[2].text == "Message 2" - assert history[3].text == "Response 2" - assert history[4].text == "Message 3" - assert history[5].text == "Response 3" + history = entity.state.data.conversation_history + # Each conversation turn creates 2 entries: request and response + assert history[0].messages[0].text == "Message 1" # Request 1 + assert history[1].messages[0].text == "Response 1" # Response 1 + assert history[2].messages[0].text == "Message 2" # Request 2 + assert history[3].messages[0].text == "Response 2" # Response 2 + assert history[4].messages[0].text == "Message 3" # Request 3 + assert history[5].messages[0].text == "Response 3" # Response 3 async def test_conversation_history_role_alternation(self) -> None: """Test that conversation history alternates between user and assistant roles.""" @@ -735,19 +763,20 @@ async def test_conversation_history_role_alternation(self) -> None: await entity.run_agent( mock_context, - {"message": "Message 1", "thread_id": "conv-1", "correlation_id": "corr-entity-history-3a"}, + {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-history-3a"}, ) await entity.run_agent( mock_context, - {"message": "Message 2", "thread_id": "conv-1", "correlation_id": "corr-entity-history-3b"}, + {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-history-3b"}, ) # Check role alternation - history = entity.state.conversation_history - assert _role_value(history[0]) == "user" - assert _role_value(history[1]) == "assistant" - assert _role_value(history[2]) == "user" - assert _role_value(history[3]) == "assistant" + history = entity.state.data.conversation_history + # Each conversation turn creates 2 entries: request and response + assert history[0].messages[0].role == "user" # Request 1 + assert history[1].messages[0].role == "assistant" # Response 1 + assert history[2].messages[0].role == "user" # Request 2 + assert history[3].messages[0].role == "assistant" # Response 2 class TestRunRequestSupport: @@ -789,7 +818,7 @@ async def test_run_agent_with_dict_request(self) -> None: "thread_id": "conv-456", "role": "system", "enable_tool_calls": False, - "correlation_id": "corr-runreq-2", + "correlationId": "corr-runreq-2", } result = await entity.run_agent(mock_context, request_dict) @@ -828,9 +857,9 @@ async def test_run_agent_stores_role_in_history(self) -> None: await entity.run_agent(mock_context, request) # Check that system role was stored - history = entity.state.conversation_history - assert _role_value(history[0]) == "system" - assert history[0].text == "System message" + history = entity.state.data.conversation_history + assert history[0].messages[0].role == "system" + assert history[0].messages[0].text == "System message" async def test_run_agent_with_response_format(self) -> None: """Test run_agent with a JSON response format.""" @@ -887,7 +916,7 @@ async def test_entity_function_with_run_request_dict(self) -> None: "thread_id": "conv-789", "role": "user", "enable_tool_calls": True, - "correlation_id": "corr-runreq-6", + "correlationId": "corr-runreq-6", } mock_context.get_state.return_value = None diff --git a/python/packages/azurefunctions/tests/test_models.py b/python/packages/azurefunctions/tests/test_models.py index bb802956ff..5b803ead13 100644 --- a/python/packages/azurefunctions/tests/test_models.py +++ b/python/packages/azurefunctions/tests/test_models.py @@ -295,32 +295,32 @@ def test_round_trip_with_pydantic_response_format(self) -> None: restored = RunRequest.from_dict(data) assert restored.response_format is ModuleStructuredResponse - def test_init_with_correlation_id(self) -> None: - """Test RunRequest initialization with correlation_id.""" + def test_init_with_correlationId(self) -> None: + """Test RunRequest initialization with correlationId.""" request = RunRequest(message="Test message", thread_id="thread-corr-init", correlation_id="corr-123") assert request.message == "Test message" assert request.correlation_id == "corr-123" - def test_to_dict_with_correlation_id(self) -> None: - """Test to_dict includes correlation_id.""" + def test_to_dict_with_correlationId(self) -> None: + """Test to_dict includes correlationId.""" request = RunRequest(message="Test", thread_id="thread-corr-to-dict", correlation_id="corr-456") data = request.to_dict() assert data["message"] == "Test" - assert data["correlation_id"] == "corr-456" + assert data["correlationId"] == "corr-456" - def test_from_dict_with_correlation_id(self) -> None: - """Test from_dict with correlation_id.""" - data = {"message": "Test", "correlation_id": "corr-789", "thread_id": "thread-corr-from-dict"} + def test_from_dict_with_correlationId(self) -> None: + """Test from_dict with correlationId.""" + data = {"message": "Test", "correlationId": "corr-789", "thread_id": "thread-corr-from-dict"} request = RunRequest.from_dict(data) assert request.message == "Test" assert request.correlation_id == "corr-789" assert request.thread_id == "thread-corr-from-dict" - def test_round_trip_with_correlation_id(self) -> None: - """Test round-trip to_dict and from_dict with correlation_id.""" + def test_round_trip_with_correlationId(self) -> None: + """Test round-trip to_dict and from_dict with correlationId.""" original = RunRequest( message="Test message", thread_id="thread-123", diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index c30f9f0bec..93201a64e9 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -136,8 +136,8 @@ def test_run_creates_entity_call(self) -> None: assert operation == "run_agent" assert request["message"] == "Test message" assert request["enable_tool_calls"] is True - assert "correlation_id" in request - assert request["correlation_id"] == "correlation-guid" + assert "correlationId" in request + assert request["correlationId"] == "correlation-guid" assert "thread_id" in request assert request["thread_id"] == "thread-guid" @@ -145,7 +145,7 @@ def test_run_without_thread(self) -> None: """Test that run() works without explicit thread (creates unique session key).""" mock_context = Mock() mock_context.instance_id = "test-instance-002" - # Two calls to new_uuid: one for session_key, one for correlation_id + # Two calls to new_uuid: one for session_key, one for correlationId mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"]) mock_task = Mock() @@ -164,7 +164,7 @@ def test_run_without_thread(self) -> None: entity_id = call_args[0][0] assert entity_id.name == "dafx-TestAgent" assert entity_id.key == "auto-generated-guid" - # Should be called twice: once for session_key, once for correlation_id + # Should be called twice: once for session_key, once for correlationId assert mock_context.new_uuid.call_count == 2 def test_run_with_response_format(self) -> None: @@ -307,8 +307,8 @@ def test_sequential_agent_calls_simulation(self) -> None: mock_context.instance_id = "test-orchestration-001" # new_uuid will be called 3 times: # 1. thread creation - # 2. correlation_id for first call - # 3. correlation_id for second call + # 2. correlationId for first call + # 3. correlationId for second call mock_context.new_uuid = Mock(side_effect=["deterministic-guid-001", "corr-1", "corr-2"]) # Track entity calls diff --git a/python/packages/azurefunctions/tests/test_state.py b/python/packages/azurefunctions/tests/test_state.py deleted file mode 100644 index 52aa7458f0..0000000000 --- a/python/packages/azurefunctions/tests/test_state.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Unit tests for AgentState correlation ID tracking.""" - -from unittest.mock import Mock - -import pytest -from agent_framework import AgentRunResponse - -from agent_framework_azurefunctions._state import AgentState - - -class TestAgentStateCorrelationId: - """Test suite for AgentState correlation ID tracking.""" - - def _create_mock_response(self, text: str = "Response") -> Mock: - """Create a mock AgentRunResponse with the provided text.""" - mock_response = Mock(spec=AgentRunResponse) - mock_response.to_dict.return_value = {"text": text, "messages": []} - return mock_response - - def test_add_assistant_message_with_correlation_id(self) -> None: - state = AgentState() - state.add_user_message("Hello", correlation_id="corr-123-request") - state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123") - message_metadata = state.conversation_history[-1].additional_properties or {} - assert message_metadata.get("correlation_id") == "corr-123" - - response_data = state.try_get_agent_response("corr-123") - assert response_data is not None - assert response_data["content"] == "Response" - assert response_data["agent_response"] == {"text": "Response", "messages": []} - - def test_try_get_agent_response_returns_response(self) -> None: - state = AgentState() - state.add_user_message("Hello", correlation_id="corr-200-request") - state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-456") - - response_data = state.try_get_agent_response("corr-456") - - assert response_data is not None - assert response_data["content"] == "Response" - - def test_try_get_agent_response_returns_none_for_missing_id(self) -> None: - state = AgentState() - state.add_user_message("Hello", correlation_id="corr-300-request") - state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123") - - assert state.try_get_agent_response("non-existent") is None - - def test_multiple_responses_tracked_separately(self) -> None: - state = AgentState() - - for index in range(3): - state.add_user_message(f"Message {index}", correlation_id=f"corr-{index}-request") - state.add_assistant_message( - f"Response {index}", - self._create_mock_response(text=f"Response {index}"), - correlation_id=f"corr-{index}", - ) - - for index in range(3): - payload = state.try_get_agent_response(f"corr-{index}") - assert payload is not None - assert payload["content"] == f"Response {index}" - - def test_add_assistant_message_without_correlation_id(self) -> None: - state = AgentState() - state.add_user_message("Hello", correlation_id="corr-400-request") - state.add_assistant_message("Response", self._create_mock_response()) - - assert state.try_get_agent_response("missing") is None - assert state.last_response == "Response" - - def test_to_dict_does_not_duplicate_agent_responses(self) -> None: - state = AgentState() - state.add_user_message("Hello", correlation_id="corr-500-request") - state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123") - - state_snapshot = state.to_dict() - - assert "agent_responses" not in state_snapshot - metadata = state_snapshot["conversation_history"][-1]["additional_properties"] - assert metadata["correlation_id"] == "corr-123" - - def test_restore_state_preserves_agent_response_lookup(self) -> None: - state = AgentState() - state.add_user_message("Hello", correlation_id="corr-600-request") - state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123") - - restored_state = AgentState() - restored_state.restore_state(state.to_dict()) - - payload = restored_state.try_get_agent_response("corr-123") - assert payload is not None - assert payload["content"] == "Response" - - def test_reset_clears_conversation_history(self) -> None: - state = AgentState() - state.add_user_message("Hello", correlation_id="corr-700-request") - state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123") - - state.reset() - - assert len(state.conversation_history) == 0 - assert state.try_get_agent_response("corr-123") is None - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/uv.lock b/python/uv.lock index 4104a624f9..cbbd28046a 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -237,6 +237,11 @@ dependencies = [ { name = "azure-functions-durable", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] +[package.dev-dependencies] +dev = [ + { name = "types-python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] + [package.metadata] requires-dist = [ { name = "agent-framework-core", editable = "packages/core" }, @@ -244,6 +249,9 @@ requires-dist = [ { name = "azure-functions-durable" }, ] +[package.metadata.requires-dev] +dev = [{ name = "types-python-dateutil", specifier = ">=2.9.0" }] + [[package]] name = "agent-framework-chatkit" version = "1.0.0b251114" @@ -1749,7 +1757,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "(python_full_version < '3.13' and sys_platform == 'darwin') or (python_full_version < '3.13' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform == 'win32')" }, + { name = "typing-extensions", marker = "(python_full_version < '3.11' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -6168,6 +6176,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/dd/5cbf31f402f1cc0ab087c94d4669cfa55bd1e818688b910631e131d74e75/typer_slim-0.20.0-py3-none-any.whl", hash = "sha256:f42a9b7571a12b97dddf364745d29f12221865acef7a2680065f9bb29c7dc89d", size = 47087, upload-time = "2025-10-20T17:03:44.546Z" }, ] +[[package]] +name = "types-python-dateutil" +version = "2.9.0.20251115" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/36/06d01fb52c0d57e9ad0c237654990920fa41195e4b3d640830dabf9eeb2f/types_python_dateutil-2.9.0.20251115.tar.gz", hash = "sha256:8a47f2c3920f52a994056b8786309b43143faa5a64d4cbb2722d6addabdf1a58", size = 16363, upload-time = "2025-11-15T03:00:13.717Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/0b/56961d3ba517ed0df9b3a27bfda6514f3d01b28d499d1bce9068cfe4edd1/types_python_dateutil-2.9.0.20251115-py3-none-any.whl", hash = "sha256:9cf9c1c582019753b8639a081deefd7e044b9fa36bd8217f565c6c4e36ee0624", size = 18251, upload-time = "2025-11-15T03:00:12.317Z" }, +] + [[package]] name = "types-pyyaml" version = "6.0.12.20250915"