diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index e83a9244a7..7d8ebe0264 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -562,6 +562,7 @@ async def mcp_tool_handler(context: str, client: df.DurableOrchestrationClient) logger.debug("[MCP Tool Trigger] Received invocation for agent: %s", agent_name) return await self._handle_mcp_tool_invocation(agent_name=agent_name, context=context, client=client) + _ = mcp_tool_handler logger.debug("[AgentFunctionApp] Registered MCP tool trigger for agent: %s", agent_name) async def _handle_mcp_tool_invocation( @@ -587,15 +588,17 @@ async def _handle_mcp_tool_invocation( # Parse JSON context string try: - parsed_context = json.loads(context) + parsed_context: Any = json.loads(context) except json.JSONDecodeError as e: raise ValueError(f"Invalid MCP context format: {e}") from e + parsed_context = cast(Mapping[str, Any], parsed_context) if isinstance(parsed_context, dict) else {} + # Extract arguments from MCP context - arguments = parsed_context.get("arguments", {}) if isinstance(parsed_context, dict) else {} + arguments: dict[str, Any] = parsed_context.get("arguments", {}) # Validate required 'query' argument - query = arguments.get("query") + query: Any = arguments.get("query") if not query or not isinstance(query, str): raise ValueError("MCP Tool invocation is missing required 'query' argument of type string.") @@ -951,10 +954,9 @@ def _extract_normalized_headers(self, req: func.HttpRequest) -> dict[str, str]: """Create a lowercase header mapping from the incoming request.""" headers: dict[str, str] = {} raw_headers = req.headers - if isinstance(raw_headers, Mapping): - for key, value in raw_headers.items(): - if value is not None: - headers[str(key).lower()] = str(value) + for key, value in cast(Mapping[str, str], raw_headers).items(): + headers[key.lower()] = value + return headers @staticmethod diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py index 0d9166373f..ffb71d2367 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py @@ -32,7 +32,7 @@ import json from datetime import datetime, timezone from enum import Enum -from typing import Any +from typing import Any, cast from agent_framework import ( AgentRunResponse, @@ -74,6 +74,130 @@ def _parse_created_at(value: Any) -> datetime: return datetime.now(tz=timezone.utc) +def _parse_messages(data: dict[str, Any]) -> list[DurableAgentStateMessage]: + """Parse messages from a dictionary, converting dicts to DurableAgentStateMessage objects. + + Args: + data: Dictionary containing a 'messages' key with a list of message data + + Returns: + List of DurableAgentStateMessage objects + """ + messages: list[DurableAgentStateMessage] = [] + raw_messages: list[Any] = data.get("messages", []) + for raw_msg in raw_messages: + if isinstance(raw_msg, dict): + messages.append(DurableAgentStateMessage.from_dict(cast(dict[str, Any], raw_msg))) + elif isinstance(raw_msg, DurableAgentStateMessage): + messages.append(raw_msg) + return messages + + +def _parse_history_entries(data_dict: dict[str, Any]) -> list[DurableAgentStateEntry]: + """Parse conversation history entries from a dictionary. + + Args: + data_dict: Dictionary containing a 'conversationHistory' key with a list of entry data + + Returns: + List of DurableAgentStateEntry objects (requests and responses) + """ + history_data: list[Any] = data_dict.get("conversationHistory", []) + deserialized_history: list[DurableAgentStateEntry] = [] + for raw_entry in history_data: + if isinstance(raw_entry, dict): + entry_dict = cast(dict[str, Any], raw_entry) + entry_type = entry_dict.get("$type") or entry_dict.get("json_type") + if entry_type == DurableAgentStateEntryJsonType.RESPONSE: + deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict)) + elif entry_type == DurableAgentStateEntryJsonType.REQUEST: + deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict)) + else: + deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict)) + elif isinstance(raw_entry, DurableAgentStateEntry): + deserialized_history.append(raw_entry) + return deserialized_history + + +def _parse_contents(data: dict[str, Any]) -> list[DurableAgentStateContent]: + """Parse content items from a dictionary. + + Args: + data: Dictionary containing a 'contents' key with a list of content data + + Returns: + List of DurableAgentStateContent objects + """ + contents: list[DurableAgentStateContent] = [] + raw_contents: list[Any] = data.get("contents", []) + for raw_content in raw_contents: + if isinstance(raw_content, dict): + content_dict = cast(dict[str, Any], raw_content) + content_type: str | None = content_dict.get("$type") + if content_type == DurableAgentStateTextContent.type: + contents.append(DurableAgentStateTextContent(text=content_dict.get("text"))) + elif content_type == DurableAgentStateDataContent.type: + contents.append( + DurableAgentStateDataContent( + uri=str(content_dict.get("uri", "")), + media_type=content_dict.get("mediaType"), + ) + ) + elif content_type == DurableAgentStateErrorContent.type: + contents.append( + DurableAgentStateErrorContent( + message=content_dict.get("message"), + error_code=content_dict.get("errorCode"), + details=content_dict.get("details"), + ) + ) + elif content_type == DurableAgentStateFunctionCallContent.type: + contents.append( + DurableAgentStateFunctionCallContent( + call_id=str(content_dict.get("callId", "")), + name=str(content_dict.get("name", "")), + arguments=content_dict.get("arguments", {}), + ) + ) + elif content_type == DurableAgentStateFunctionResultContent.type: + contents.append( + DurableAgentStateFunctionResultContent( + call_id=str(content_dict.get("callId", "")), + result=content_dict.get("result"), + ) + ) + elif content_type == DurableAgentStateHostedFileContent.type: + contents.append(DurableAgentStateHostedFileContent(file_id=str(content_dict.get("fileId", "")))) + elif content_type == DurableAgentStateHostedVectorStoreContent.type: + contents.append( + DurableAgentStateHostedVectorStoreContent( + vector_store_id=str(content_dict.get("vectorStoreId", "")) + ) + ) + elif content_type == DurableAgentStateTextReasoningContent.type: + contents.append(DurableAgentStateTextReasoningContent(text=content_dict.get("text"))) + elif content_type == DurableAgentStateUriContent.type: + contents.append( + DurableAgentStateUriContent( + uri=str(content_dict.get("uri", "")), + media_type=str(content_dict.get("mediaType", "")), + ) + ) + elif content_type == DurableAgentStateUsageContent.type: + usage_data = content_dict.get("usage") + if usage_data and isinstance(usage_data, dict): + contents.append( + DurableAgentStateUsageContent( + usage=DurableAgentStateUsage.from_dict(cast(dict[str, Any], usage_data)) + ) + ) + elif content_type == DurableAgentStateUnknownContent.type: + contents.append(DurableAgentStateUnknownContent(content=content_dict.get("content", {}))) + elif isinstance(raw_content, DurableAgentStateContent): + contents.append(raw_content) + return contents + + class DurableAgentStateContent: """Base class for all content types in durable agent state messages. @@ -197,25 +321,8 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data_dict: dict[str, Any]) -> DurableAgentStateData: - # Restore the conversation history - deserialize entries from dicts to objects - history_data = data_dict.get("conversationHistory", []) - deserialized_history: list[DurableAgentStateEntry] = [] - for entry_dict in history_data: - if isinstance(entry_dict, dict): - # Deserialize based on $type discriminator - entry_type = entry_dict.get("$type") or entry_dict.get("json_type") - if entry_type == DurableAgentStateEntryJsonType.RESPONSE: - deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict)) - elif entry_type == DurableAgentStateEntryJsonType.REQUEST: - deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict)) - else: - deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict)) - else: - # Already an object - deserialized_history.append(entry_dict) - return cls( - conversation_history=deserialized_history, + conversation_history=_parse_history_entries(data_dict), extension_data=data_dict.get("extensionData"), ) @@ -227,7 +334,7 @@ class DurableAgentState: in Azure Durable Entities. It maintains the conversation history as a sequence of request and response entries, each with their messages, timestamps, and metadata. - The state follows a versioned schema (currently 1.0.0) that defines the structure for: + The state follows a versioned schema (see SCHEMA_VERSION class constant) that defines the structure for: - Request entries: User/system messages with optional response format specifications - Response entries: Assistant messages with token usage information - Messages: Individual chat messages with role, content items, and timestamps @@ -235,7 +342,7 @@ class DurableAgentState: State is serialized to JSON with this structure: { - "schemaVersion": "1.0.0", + "schemaVersion": "", "data": { "conversationHistory": [ {"$type": "request", "correlationId": "...", "createdAt": "...", "messages": [...]}, @@ -246,17 +353,20 @@ class DurableAgentState: Attributes: data: Container for conversation history and optional extension data - schema_version: Schema version string (defaults to "1.0.0") + schema_version: Schema version string (defaults to SCHEMA_VERSION) """ + # Durable Agent Schema version + SCHEMA_VERSION: str = "1.1.0" + data: DurableAgentStateData - schema_version: str = "1.0.0" + schema_version: str = SCHEMA_VERSION - def __init__(self, schema_version: str = "1.0.0"): + def __init__(self, schema_version: str = SCHEMA_VERSION): """Initialize a new durable agent state. Args: - schema_version: Schema version to use (defaults to "1.0.0") + schema_version: Schema version to use (defaults to SCHEMA_VERSION) """ self.data = DurableAgentStateData() self.schema_version = schema_version @@ -283,7 +393,7 @@ def from_dict(cls, state: dict[str, Any]) -> DurableAgentState: logger.warning("Resetting state as it is incompatible with the current schema, all history will be lost") return cls() - instance = cls(schema_version=state.get("schemaVersion", "1.0.0")) + instance = cls(schema_version=state.get("schemaVersion", DurableAgentState.SCHEMA_VERSION)) instance.data = DurableAgentStateData.from_dict(state.get("data", {})) return instance @@ -325,7 +435,7 @@ def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: if entry.correlation_id == correlation_id and isinstance(entry, DurableAgentStateResponse): # Found the entry, extract response data # Get the text content from assistant messages only - content = "\n".join(message.text for message in entry.messages if message.text is not None) + content = "\n".join(message.text for message in entry.messages if message.text) return {"content": content, "message_count": self.message_count, "correlationId": correlation_id} return None @@ -388,28 +498,17 @@ def __init__( self.extension_data = extension_data def to_dict(self) -> dict[str, Any]: - # Ensure createdAt is never null - created_at_value = self.created_at - if created_at_value is None: - created_at_value = datetime.now(tz=timezone.utc) - return { "$type": self.json_type, "correlationId": self.correlation_id, - "createdAt": created_at_value.isoformat() if isinstance(created_at_value, datetime) else created_at_value, + "createdAt": self.created_at.isoformat(), "messages": [m.to_dict() for m in self.messages], } @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateEntry: created_at = _parse_created_at(data.get("created_at")) - - messages = [] - for msg_dict in data.get("messages", []): - if isinstance(msg_dict, dict): - messages.append(DurableAgentStateMessage.from_dict(msg_dict)) - else: - messages.append(msg_dict) + messages = _parse_messages(data) return cls( json_type=DurableAgentStateEntryJsonType(data.get("$type", "entry")), @@ -430,6 +529,7 @@ class DurableAgentStateRequest(DurableAgentStateEntry): Attributes: response_type: Expected response type ("text" or "json") response_schema: JSON schema for structured responses (when response_type is "json") + orchestration_id: ID of the orchestration that initiated this request (if any) correlationId: Unique identifier linking this request to its response created_at: Timestamp when the request was created messages: List of messages included in this request @@ -438,6 +538,7 @@ class DurableAgentStateRequest(DurableAgentStateEntry): response_type: str | None = None response_schema: dict[str, Any] | None = None + orchestration_id: str | None = None def __init__( self, @@ -447,6 +548,7 @@ def __init__( extension_data: dict[str, Any] | None = None, response_type: str | None = None, response_schema: dict[str, Any] | None = None, + orchestration_id: str | None = None, ) -> None: super().__init__( json_type=DurableAgentStateEntryJsonType.REQUEST, @@ -457,9 +559,12 @@ def __init__( ) self.response_type = response_type self.response_schema = response_schema + self.orchestration_id = orchestration_id def to_dict(self) -> dict[str, Any]: data = super().to_dict() + if self.orchestration_id is not None: + data["orchestrationId"] = self.orchestration_id if self.response_type is not None: data["responseType"] = self.response_type if self.response_schema is not None: @@ -469,13 +574,7 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateRequest: created_at = _parse_created_at(data.get("created_at")) - - messages = [] - for msg_dict in data.get("messages", []): - if isinstance(msg_dict, dict): - messages.append(DurableAgentStateMessage.from_dict(msg_dict)) - else: - messages.append(msg_dict) + messages = _parse_messages(data) return cls( correlation_id=data.get("correlationId", ""), @@ -484,6 +583,7 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateRequest: extension_data=data.get("extensionData"), response_type=data.get("responseType"), response_schema=data.get("responseSchema"), + orchestration_id=data.get("orchestrationId"), ) @staticmethod @@ -495,6 +595,7 @@ def from_run_request(request: RunRequest) -> DurableAgentStateRequest: created_at=datetime.now(tz=timezone.utc), response_type=request.request_response_format, response_schema=serialize_response_format(request.response_format), + orchestration_id=request.orchestration_id, ) @@ -545,20 +646,12 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateResponse: created_at = _parse_created_at(data.get("created_at")) - - messages = [] - for msg_dict in data.get("messages", []): - if isinstance(msg_dict, dict): - messages.append(DurableAgentStateMessage.from_dict(msg_dict)) - else: - messages.append(msg_dict) + messages = _parse_messages(data) usage_dict = data.get("usage") - usage = None + usage: DurableAgentStateUsage | None = None if usage_dict and isinstance(usage_dict, dict): - usage = DurableAgentStateUsage.from_dict(usage_dict) - elif usage_dict: - usage = usage_dict + usage = DurableAgentStateUsage.from_dict(cast(dict[str, Any], usage_dict)) return cls( correlation_id=data.get("correlationId", ""), @@ -639,68 +732,9 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateMessage: - contents: list[DurableAgentStateContent] = [] - for content_dict in data.get("contents", []): - if isinstance(content_dict, dict): - content_type = content_dict.get("$type") - if content_type == DurableAgentStateTextContent.type: - contents.append(DurableAgentStateTextContent(text=content_dict.get("text"))) - elif content_type == DurableAgentStateDataContent.type: - contents.append( - DurableAgentStateDataContent( - uri=content_dict.get("uri", ""), media_type=content_dict.get("mediaType") - ) - ) - elif content_type == DurableAgentStateErrorContent.type: - contents.append( - DurableAgentStateErrorContent( - message=content_dict.get("message"), - error_code=content_dict.get("errorCode"), - details=content_dict.get("details"), - ) - ) - elif content_type == DurableAgentStateFunctionCallContent.type: - contents.append( - DurableAgentStateFunctionCallContent( - call_id=content_dict.get("callId", ""), - name=content_dict.get("name", ""), - arguments=content_dict.get("arguments", {}), - ) - ) - elif content_type == DurableAgentStateFunctionResultContent.type: - contents.append( - DurableAgentStateFunctionResultContent( - call_id=content_dict.get("callId", ""), result=content_dict.get("result") - ) - ) - elif content_type == DurableAgentStateHostedFileContent.type: - contents.append(DurableAgentStateHostedFileContent(file_id=content_dict.get("fileId", ""))) - elif content_type == DurableAgentStateHostedVectorStoreContent.type: - contents.append( - DurableAgentStateHostedVectorStoreContent(vector_store_id=content_dict.get("vectorStoreId", "")) - ) - elif content_type == DurableAgentStateTextReasoningContent.type: - contents.append(DurableAgentStateTextReasoningContent(text=content_dict.get("text"))) - elif content_type == DurableAgentStateUriContent.type: - contents.append( - DurableAgentStateUriContent( - uri=content_dict.get("uri", ""), media_type=content_dict.get("mediaType", "") - ) - ) - elif content_type == DurableAgentStateUsageContent.type: - usage_data = content_dict.get("usage") - if usage_data and isinstance(usage_data, dict): - contents.append( - DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_dict(usage_data)) - ) - elif content_type == DurableAgentStateUnknownContent.type: - contents.append(DurableAgentStateUnknownContent(content=content_dict.get("content", {}))) - else: - contents.append(content_dict) # type: ignore - return cls( role=data.get("role", ""), - contents=contents, + contents=_parse_contents(data), author_name=data.get("authorName"), created_at=_parse_created_at(data.get("createdAt")), extension_data=data.get("extensionData"), @@ -709,7 +743,7 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateMessage: @property def text(self) -> str: """Extract text from the contents list.""" - text_parts = [] + text_parts: list[str] = [] for content in self.contents: if isinstance(content, DurableAgentStateTextContent): text_parts.append(content.text or "") diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py index e9ed6f7cad..2ab9667575 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py @@ -287,6 +287,7 @@ class RunRequest: thread_id: Optional thread ID for tracking correlation_id: Optional correlation ID for tracking the response to this specific request created_at: Optional timestamp when the request was created + orchestration_id: Optional ID of the orchestration that initiated this request """ message: str @@ -297,6 +298,7 @@ class RunRequest: thread_id: str | None = None correlation_id: str | None = None created_at: str | None = None + orchestration_id: str | None = None def __init__( self, @@ -308,6 +310,7 @@ def __init__( thread_id: str | None = None, correlation_id: str | None = None, created_at: str | None = None, + orchestration_id: str | None = None, ) -> None: self.message = message self.role = self.coerce_role(role) @@ -317,6 +320,7 @@ def __init__( self.thread_id = thread_id self.correlation_id = correlation_id self.created_at = created_at + self.orchestration_id = orchestration_id @staticmethod def coerce_role(value: Role | str | None) -> Role: @@ -346,6 +350,8 @@ def to_dict(self) -> dict[str, Any]: result["correlationId"] = self.correlation_id if self.created_at: result["created_at"] = self.created_at + if self.orchestration_id: + result["orchestrationId"] = self.orchestration_id return result @@ -361,4 +367,5 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest: thread_id=data.get("thread_id"), correlation_id=data.get("correlationId"), created_at=data.get("created_at"), + orchestration_id=data.get("orchestrationId"), ) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index fb6613b85b..0f7e786778 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -272,12 +272,14 @@ def my_orchestration(context): ) # Prepare the request using RunRequest model + # Include the orchestration's instance_id so it can be stored in the agent's entity state run_request = RunRequest( message=message_str, enable_tool_calls=enable_tool_calls, correlation_id=correlation_id, thread_id=session_id.key, response_format=response_format, + orchestration_id=self.context.instance_id, ) logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100]) diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index d47fae4a1f..1c3f5168a5 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -79,7 +79,7 @@ def test_init_creates_entity(self) -> None: assert entity.agent == mock_agent assert len(entity.state.data.conversation_history) == 0 assert entity.state.data.extension_data is None - assert entity.state.schema_version == "1.0.0" + assert entity.state.schema_version == DurableAgentState.SCHEMA_VERSION def test_init_stores_agent_reference(self) -> None: """Test that the agent reference is stored correctly.""" @@ -124,8 +124,7 @@ async def test_run_agent_executes_agent(self) -> None: # Verify agent.run was called mock_agent.run.assert_called_once() _, kwargs = mock_agent.run.call_args - sent_messages = kwargs.get("messages") - assert isinstance(sent_messages, list) + sent_messages: list[Any] = kwargs.get("messages") assert len(sent_messages) == 1 sent_message = sent_messages[0] assert isinstance(sent_message, ChatMessage) @@ -910,5 +909,98 @@ async def test_entity_function_with_run_request_dict(self) -> None: assert text_found, f"Response text not found in message: {message}" +class TestDurableAgentStateRequestOrchestrationId: + """Test suite for DurableAgentStateRequest orchestration_id field.""" + + def test_request_with_orchestration_id(self) -> None: + """Test creating a request with an orchestration_id.""" + request = DurableAgentStateRequest( + correlation_id="corr-123", + created_at=datetime.now(), + messages=[ + DurableAgentStateMessage( + role="user", + contents=[DurableAgentStateTextContent(text="test")], + ) + ], + orchestration_id="orch-456", + ) + + assert request.orchestration_id == "orch-456" + + def test_request_to_dict_includes_orchestration_id(self) -> None: + """Test that to_dict includes orchestrationId when set.""" + request = DurableAgentStateRequest( + correlation_id="corr-123", + created_at=datetime.now(), + messages=[ + DurableAgentStateMessage( + role="user", + contents=[DurableAgentStateTextContent(text="test")], + ) + ], + orchestration_id="orch-789", + ) + + data = request.to_dict() + + assert "orchestrationId" in data + assert data["orchestrationId"] == "orch-789" + + def test_request_to_dict_excludes_orchestration_id_when_none(self) -> None: + """Test that to_dict excludes orchestrationId when not set.""" + request = DurableAgentStateRequest( + correlation_id="corr-123", + created_at=datetime.now(), + messages=[ + DurableAgentStateMessage( + role="user", + contents=[DurableAgentStateTextContent(text="test")], + ) + ], + ) + + data = request.to_dict() + + assert "orchestrationId" not in data + + def test_request_from_dict_with_orchestration_id(self) -> None: + """Test from_dict correctly parses orchestrationId.""" + data = { + "$type": "request", + "correlationId": "corr-123", + "createdAt": "2024-01-01T00:00:00Z", + "messages": [{"role": "user", "contents": [{"$type": "text", "text": "test"}]}], + "orchestrationId": "orch-from-dict", + } + + request = DurableAgentStateRequest.from_dict(data) + + assert request.orchestration_id == "orch-from-dict" + + def test_request_from_run_request_with_orchestration_id(self) -> None: + """Test from_run_request correctly transfers orchestration_id.""" + run_request = RunRequest( + message="test message", + correlation_id="corr-run", + orchestration_id="orch-from-run-request", + ) + + durable_request = DurableAgentStateRequest.from_run_request(run_request) + + assert durable_request.orchestration_id == "orch-from-run-request" + + def test_request_from_run_request_without_orchestration_id(self) -> None: + """Test from_run_request correctly handles missing orchestration_id.""" + run_request = RunRequest( + message="test message", + correlation_id="corr-run", + ) + + durable_request = DurableAgentStateRequest.from_run_request(run_request) + + assert durable_request.orchestration_id is None + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_models.py b/python/packages/azurefunctions/tests/test_models.py index b4341e8c45..74efa9c166 100644 --- a/python/packages/azurefunctions/tests/test_models.py +++ b/python/packages/azurefunctions/tests/test_models.py @@ -336,6 +336,71 @@ def test_round_trip_with_correlationId(self) -> None: assert restored.correlation_id == original.correlation_id assert restored.thread_id == original.thread_id + def test_init_with_orchestration_id(self) -> None: + """Test RunRequest initialization with orchestration_id.""" + request = RunRequest( + message="Test message", + thread_id="thread-orch-init", + orchestration_id="orch-123", + ) + + assert request.message == "Test message" + assert request.orchestration_id == "orch-123" + + def test_to_dict_with_orchestration_id(self) -> None: + """Test to_dict includes orchestrationId.""" + request = RunRequest( + message="Test", + thread_id="thread-orch-to-dict", + orchestration_id="orch-456", + ) + data = request.to_dict() + + assert data["message"] == "Test" + assert data["orchestrationId"] == "orch-456" + + def test_to_dict_excludes_orchestration_id_when_none(self) -> None: + """Test to_dict excludes orchestrationId when not set.""" + request = RunRequest( + message="Test", + thread_id="thread-orch-none", + ) + data = request.to_dict() + + assert "orchestrationId" not in data + + def test_from_dict_with_orchestration_id(self) -> None: + """Test from_dict with orchestrationId.""" + data = { + "message": "Test", + "orchestrationId": "orch-789", + "thread_id": "thread-orch-from-dict", + } + request = RunRequest.from_dict(data) + + assert request.message == "Test" + assert request.orchestration_id == "orch-789" + assert request.thread_id == "thread-orch-from-dict" + + def test_round_trip_with_orchestration_id(self) -> None: + """Test round-trip to_dict and from_dict with orchestration_id.""" + original = RunRequest( + message="Test message", + thread_id="thread-123", + role=Role.SYSTEM, + correlation_id="corr-123", + orchestration_id="orch-123", + ) + + data = original.to_dict() + restored = RunRequest.from_dict(data) + + assert restored.message == original.message + assert restored.role == original.role + assert restored.correlation_id == original.correlation_id + assert restored.orchestration_id == original.orchestration_id + assert restored.thread_id == original.thread_id + class TestModelIntegration: """Test suite for integration between models.""" diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index c65724c160..0f845d4105 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -302,6 +302,28 @@ def test_run_creates_entity_call(self) -> None: assert request["correlationId"] == "correlation-guid" assert "thread_id" in request assert request["thread_id"] == "thread-guid" + # Verify orchestration ID is set from context.instance_id + assert "orchestrationId" in request + assert request["orchestrationId"] == "test-instance-001" + + def test_run_sets_orchestration_id(self) -> None: + """Test that run() sets the orchestration_id from context.instance_id.""" + mock_context = Mock() + mock_context.instance_id = "my-orchestration-123" + mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) + + entity_task = _create_entity_task() + mock_context.call_entity = Mock(return_value=entity_task) + + agent = DurableAIAgent(mock_context, "TestAgent") + thread = agent.get_new_thread() + + agent.run(messages="Test", thread=thread) + + call_args = mock_context.call_entity.call_args + request = call_args[0][2] + + assert request["orchestrationId"] == "my-orchestration-123" def test_run_without_thread(self) -> None: """Test that run() works without explicit thread (creates unique session key).""" diff --git a/python/uv.lock b/python/uv.lock index 3f269ea78a..6308986fc2 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -1806,7 +1806,7 @@ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "(python_full_version < '3.11' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'win32')" }, + { name = "typing-extensions", marker = "(python_full_version < '3.13' and sys_platform == 'darwin') or (python_full_version < '3.13' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform == 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [