diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py index 73695e61f2..0d9166373f 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py @@ -53,7 +53,7 @@ ) from dateutil import parser as date_parser -from ._models import RunRequest, _serialize_response_format +from ._models import RunRequest, serialize_response_format logger = get_logger("agent_framework.azurefunctions.durable_agent_state") @@ -494,7 +494,7 @@ def from_run_request(request: RunRequest) -> DurableAgentStateRequest: messages=[DurableAgentStateMessage.from_run_request(request)], created_at=datetime.now(tz=timezone.utc), response_type=request.request_response_format, - response_schema=_serialize_response_format(request.response_format), + response_schema=serialize_response_format(request.response_format), ) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py index a79269bd4d..45872ce1a1 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py @@ -9,9 +9,7 @@ import asyncio import inspect -import json from collections.abc import AsyncIterable, Callable -from datetime import datetime, timezone from typing import Any, cast import azure.durable_functions as df @@ -30,11 +28,10 @@ DurableAgentState, DurableAgentStateData, DurableAgentStateEntry, - DurableAgentStateMessage, DurableAgentStateRequest, DurableAgentStateResponse, ) -from ._models import AgentResponse, RunRequest +from ._models import RunRequest logger = get_logger("agent_framework.azurefunctions.entities") @@ -97,7 +94,7 @@ async def run_agent( self, context: df.DurableEntityContext, request: RunRequest | dict[str, Any] | str, - ) -> dict[str, Any]: + ) -> AgentRunResponse: """Execute the agent with a message directly in the entity. Args: @@ -105,13 +102,8 @@ async def run_agent( request: RunRequest object, dict, or string message (for backward compatibility) Returns: - Dict with status information and response (serialized AgentResponse) - - Note: - The agent returns an AgentRunResponse object which is stored in state. - This method extracts the text/structured response and returns an AgentResponse dict. + AgentRunResponse enriched with execution metadata. """ - # Convert string or dict to RunRequest if isinstance(request, str): run_request = RunRequest(message=request, role=Role.USER) elif isinstance(request, dict): @@ -135,8 +127,6 @@ async def run_agent( logger.debug(f"[AgentEntity.run_agent] Received Message: {state_request}") try: - logger.debug("[AgentEntity.run_agent] Starting agent invocation") - # Build messages from conversation history, excluding error responses # Error responses are kept in history for tracking but not sent to the agent chat_messages: list[ChatMessage] = [ @@ -164,83 +154,39 @@ async def run_agent( type(agent_run_response).__name__, ) - response_text = None - structured_response = None - response_str: str | None = None - try: - if response_format: - try: - response_str = agent_run_response.text - structured_response = json.loads(response_str) - logger.debug("Parsed structured JSON response") - except json.JSONDecodeError as decode_error: - logger.warning(f"Failed to parse JSON response: {decode_error}") - response_text = response_str - else: - raw_text = agent_run_response.text - response_text = raw_text if raw_text else "No response" - preview = response_text - logger.debug(f"Response: {preview[:100]}..." if len(preview) > 100 else f"Response: {preview}") + response_text = agent_run_response.text if agent_run_response.text else "No response" + logger.debug(f"Response: {response_text[:100]}...") except Exception as extraction_error: logger.error( - f"Error extracting response: {extraction_error}", + "Error extracting response text: %s", + extraction_error, exc_info=True, ) - response_text = "Error extracting response" state_response = DurableAgentStateResponse.from_run_response(correlation_id, agent_run_response) self.state.data.conversation_history.append(state_response) - agent_response = AgentResponse( - response=response_text, - message=str(message), - thread_id=str(thread_id), - status="success", - message_count=len(self.state.data.conversation_history), - structured_response=structured_response, - ) - result = agent_response.to_dict() - logger.debug("[AgentEntity.run_agent] AgentRunResponse stored in conversation history") - return result + return agent_run_response except Exception as exc: - import traceback - - error_traceback = traceback.format_exc() - logger.error("[AgentEntity.run_agent] Agent execution failed") - logger.error(f"Error: {exc!s}") - logger.error(f"Error type: {type(exc).__name__}") - logger.error(f"Full traceback:\n{error_traceback}") + logger.exception("[AgentEntity.run_agent] Agent execution failed.") # Create error message - error_message = DurableAgentStateMessage.from_chat_message( - ChatMessage( - role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)] - ) + error_message = ChatMessage( + role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)] ) + error_response = AgentRunResponse(messages=[error_message]) + # Create and store error response in conversation history - error_state_response = DurableAgentStateResponse( - correlation_id=correlation_id, - created_at=datetime.now(tz=timezone.utc), - messages=[error_message], - is_error=True, - ) + error_state_response = DurableAgentStateResponse.from_run_response(correlation_id, error_response) + error_state_response.is_error = True self.state.data.conversation_history.append(error_state_response) - error_response = AgentResponse( - response=f"Error: {exc!s}", - message=str(message), - thread_id=str(thread_id), - status="error", - message_count=len(self.state.data.conversation_history), - error=str(exc), - error_type=type(exc).__name__, - ) - return error_response.to_dict() + return error_response async def _invoke_agent( self, @@ -432,7 +378,7 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: request = "" if input_data is None else str(cast(object, input_data)) result = await entity.run_agent(context, request) - context.set_result(result) + context.set_result(result.to_dict()) elif operation == "reset": entity.reset(context) @@ -442,15 +388,13 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: logger.error("[entity_function] Unknown operation: %s", operation) context.set_result({"error": f"Unknown operation: {operation}"}) - logger.debug("State dict: %s", entity.state.to_dict()) - context.set_state(entity.state.to_dict()) + serialized_state = entity.state.to_dict() + logger.debug("State dict: %s", serialized_state) + context.set_state(serialized_state) logger.info(f"[entity_function] Operation {operation} completed successfully") except Exception as exc: - import traceback - - logger.error("[entity_function] Error in entity: %s", exc) - logger.error(f"[entity_function] Traceback:\n{traceback.format_exc()}") + logger.exception("[entity_function] Error executing entity operation %s", exc) context.set_result({"error": str(exc), "status": "error"}) def entity_function(context: df.DurableEntityContext) -> None: diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py index 19f175a485..e9ed6f7cad 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py @@ -213,7 +213,7 @@ async def deserialize( return thread -def _serialize_response_format(response_format: type[BaseModel] | None) -> Any: +def serialize_response_format(response_format: type[BaseModel] | None) -> Any: """Serialize response format for transport across durable function boundaries.""" if response_format is None: return None @@ -339,7 +339,7 @@ def to_dict(self) -> dict[str, Any]: "request_response_format": self.request_response_format, } if self.response_format: - result["response_format"] = _serialize_response_format(self.response_format) + result["response_format"] = serialize_response_format(self.response_format) if self.thread_id: result["thread_id"] = self.thread_id if self.correlation_id: @@ -362,50 +362,3 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest: correlation_id=data.get("correlationId"), created_at=data.get("created_at"), ) - - -@dataclass -class AgentResponse: - """Response from agent execution. - - Attributes: - response: The agent's text response (or None for structured responses) - message: The original message sent to the agent - thread_id: The thread identifier - status: Status of the execution (success, error, etc.) - message_count: Number of messages in the conversation - error: Error message if status is error - error_type: Type of error if status is error - structured_response: Structured response if response_format was provided - """ - - response: str | None - message: str - thread_id: str | None - status: str - message_count: int = 0 - error: str | None = None - error_type: str | None = None - structured_response: dict[str, Any] | None = None - - def to_dict(self) -> dict[str, Any]: - """Convert to dictionary for JSON serialization.""" - result: dict[str, Any] = { - "message": self.message, - "thread_id": self.thread_id, - "status": self.status, - "message_count": self.message_count, - } - - # Add response or structured_response based on what's available - if self.structured_response is not None: - result["structured_response"] = self.structured_response - elif self.response is not None: - result["response"] = self.response - - if self.error: - result["error"] = self.error - if self.error_type: - result["error_type"] = self.error_type - - return result diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index 2fd4522964..fb6613b85b 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -6,21 +6,148 @@ """ import uuid -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from typing import TYPE_CHECKING, Any, TypeAlias, cast -from agent_framework import AgentProtocol, AgentRunResponseUpdate, AgentThread, ChatMessage, get_logger +from agent_framework import ( + AgentProtocol, + AgentRunResponse, + AgentRunResponseUpdate, + AgentThread, + ChatMessage, + get_logger, +) +from azure.durable_functions.models import TaskBase +from azure.durable_functions.models.Task import CompoundTask, TaskState +from pydantic import BaseModel from ._models import AgentSessionId, DurableAgentThread, RunRequest logger = get_logger("agent_framework.azurefunctions.orchestration") +CompoundActionConstructor: TypeAlias = Callable[[list[Any]], Any] | None + if TYPE_CHECKING: - from azure.durable_functions import DurableOrchestrationContext as _DurableOrchestrationContext + from azure.durable_functions import DurableOrchestrationContext + + class _TypedCompoundTask(CompoundTask): # type: ignore[misc] + _first_error: Any + + def __init__( + self, + tasks: list[TaskBase], + compound_action_constructor: CompoundActionConstructor = None, + ) -> None: ... - AgentOrchestrationContextType: TypeAlias = _DurableOrchestrationContext + AgentOrchestrationContextType: TypeAlias = DurableOrchestrationContext else: AgentOrchestrationContextType = Any + _TypedCompoundTask = CompoundTask + + +class AgentTask(_TypedCompoundTask): + """A custom Task that wraps entity calls and provides typed AgentRunResponse results. + + This task wraps the underlying entity call task and intercepts its completion + to convert the raw result into a typed AgentRunResponse object. + """ + + def __init__( + self, + entity_task: TaskBase, + response_format: type[BaseModel] | None, + correlation_id: str, + ): + """Initialize the AgentTask. + + Args: + entity_task: The underlying entity call task + response_format: Optional Pydantic model for response parsing + correlation_id: Correlation ID for logging + """ + super().__init__([entity_task]) + self._response_format = response_format + self._correlation_id = correlation_id + + # Override action_repr to expose the inner task's action directly + # This ensures compatibility with ReplaySchema V3 which expects Action objects. + self.action_repr = entity_task.action_repr + + # Also copy the task ID to match the entity task's identity + self.id = entity_task.id + + def try_set_value(self, child: TaskBase) -> None: + """Transition the AgentTask to a terminal state and set its value to `AgentRunResponse`. + + Parameters + ---------- + child : TaskBase + The entity call task that just completed + """ + if child.state is TaskState.SUCCEEDED: + # Delegate to parent class for standard completion logic + if len(self.pending_tasks) == 0: + # Transform the raw result before setting it + raw_result = child.result + logger.debug( + "[AgentTask] Converting raw result for correlation_id %s", + self._correlation_id, + ) + + try: + response = self._load_agent_response(raw_result) + + if self._response_format is not None: + self._ensure_response_format( + self._response_format, + self._correlation_id, + response, + ) + + # Set the typed AgentRunResponse as this task's result + self.set_value(is_error=False, value=response) + except Exception as e: + logger.exception( + "[AgentTask] Failed to convert result for correlation_id: %s", + self._correlation_id, + ) + self.set_value(is_error=True, value=e) + else: + # If error not handled by the parent, set it explicitly. + if self._first_error is None: + self._first_error = child.result + self.set_value(is_error=True, value=self._first_error) + + def _load_agent_response(self, agent_response: AgentRunResponse | dict[str, Any] | None) -> AgentRunResponse: + """Convert raw payloads into AgentRunResponse instance.""" + if agent_response is None: + raise ValueError("agent_response cannot be None") + + logger.debug("[load_agent_response] Loading agent response of type: %s", type(agent_response)) + + if isinstance(agent_response, AgentRunResponse): + return agent_response + if isinstance(agent_response, dict): + logger.debug("[load_agent_response] Converting dict payload using AgentRunResponse.from_dict") + return AgentRunResponse.from_dict(agent_response) + + raise TypeError(f"Unsupported type for agent_response: {type(agent_response)}") + + def _ensure_response_format( + self, + response_format: type[BaseModel] | None, + correlation_id: str, + response: AgentRunResponse, + ) -> None: + """Ensure the AgentRunResponse value is parsed into the expected response_format.""" + if response_format is not None and not isinstance(response.value, response_format): + response.try_parse_value(response_format) + + logger.debug( + "[DurableAIAgent] Loaded AgentRunResponse.value for correlation_id %s with type: %s", + correlation_id, + type(response.value).__name__, + ) class DurableAIAgent(AgentProtocol): @@ -59,7 +186,7 @@ def __init__(self, context: AgentOrchestrationContextType, agent_name: str): self._name = agent_name self._display_name = agent_name self._description = f"Durable agent proxy for {agent_name}" - logger.debug(f"[DurableAIAgent] Initialized for agent: {agent_name}") + logger.debug("[DurableAIAgent] Initialized for agent: %s", agent_name) @property def id(self) -> str: @@ -81,38 +208,45 @@ def description(self) -> str | None: """Get the description of the agent.""" return self._description - def run( + # We return an AgentTask here which is a TaskBase subclass. + # This is an intentional deviation from AgentProtocol which defines run() as async. + # The AgentTask can be yielded in Durable Functions orchestrations and will provide + # a typed AgentRunResponse result. + def run( # type: ignore[override] self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: AgentThread | None = None, + response_format: type[BaseModel] | None = None, **kwargs: Any, - ) -> Any: # TODO(msft-team): Add a wrapper to respond correctly with `AgentRunResponse` - """Execute the agent with messages and return a Task for orchestrations. + ) -> AgentTask: + """Execute the agent with messages and return an AgentTask for orchestrations. - This method implements AgentProtocol and returns a Task that can be yielded - in Durable Functions orchestrations. + This method implements AgentProtocol and returns an AgentTask (subclass of TaskBase) + that can be yielded in Durable Functions orchestrations. The task's result will be + a typed AgentRunResponse. Args: messages: The message(s) to send to the agent thread: Optional agent thread for conversation context - **kwargs: Additional arguments (enable_tool_calls, response_format, etc.) + response_format: Optional Pydantic model for response parsing + **kwargs: Additional arguments (enable_tool_calls) Returns: - Task that will resolve to the agent response + An AgentTask that resolves to an AgentRunResponse when yielded Example: @app.orchestration_trigger(context_name="context") def my_orchestration(context): agent = app.get_agent(context, "MyAgent") thread = agent.get_new_thread() - result = yield agent.run("Hello", thread=thread) + response = yield agent.run("Hello", thread=thread) + # response is typed as AgentRunResponse """ message_str = self._normalize_messages(messages) # Extract optional parameters from kwargs enable_tool_calls = kwargs.get("enable_tool_calls", True) - response_format = kwargs.get("response_format") # Get the session ID for the entity if isinstance(thread, DurableAgentThread) and thread.session_id is not None: @@ -122,7 +256,7 @@ def my_orchestration(context): # This ensures each call gets its own conversation context session_key = str(self.context.new_uuid()) session_id = AgentSessionId(name=self.agent_name, key=session_key) - logger.warning(f"[DurableAIAgent] No thread provided, created unique session_id: {session_id}") + logger.warning("[DurableAIAgent] No thread provided, created unique session_id: %s", session_id) # Create entity ID from session ID entity_id = session_id.to_entity_id() @@ -130,6 +264,12 @@ def my_orchestration(context): # Generate a deterministic correlation ID for this call # This is required by the entity and must be unique per call correlation_id = str(self.context.new_uuid()) + logger.debug( + "[DurableAIAgent] Using correlation_id: %s for entity_id: %s for session_id: %s", + correlation_id, + entity_id, + session_id, + ) # Prepare the request using RunRequest model run_request = RunRequest( @@ -140,11 +280,24 @@ def my_orchestration(context): response_format=response_format, ) - logger.debug(f"[DurableAIAgent] Calling entity {entity_id} with message: {message_str[:100]}...") + logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100]) + + # Call the entity to get the underlying task + entity_task = self.context.call_entity(entity_id, "run_agent", run_request.to_dict()) + + # Wrap it in an AgentTask that will convert the result to AgentRunResponse + agent_task = AgentTask( + entity_task=entity_task, + response_format=response_format, + correlation_id=correlation_id, + ) + + logger.debug( + "[DurableAIAgent] Created AgentTask for correlation_id %s", + correlation_id, + ) - # Call the entity and return the Task directly - # The orchestration will yield this Task - return self.context.call_entity(entity_id, "run_agent", run_request.to_dict()) + return agent_task def run_stream( self, @@ -179,7 +332,7 @@ def get_new_thread(self, **kwargs: Any) -> AgentThread: thread = DurableAgentThread.from_session_id(session_id, **kwargs) - logger.debug(f"[DurableAIAgent] Created new thread with session_id: {session_id}") + logger.debug("[DurableAIAgent] Created new thread with session_id: %s", session_id) return thread def _messages_to_string(self, messages: list[ChatMessage]) -> str: diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index ebf6eef3e6..a6961195a1 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -9,7 +9,7 @@ import azure.durable_functions as df import azure.functions as func import pytest -from agent_framework import AgentRunResponse, ChatMessage +from agent_framework import AgentRunResponse, ChatMessage, ErrorContent from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._app import WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER @@ -342,10 +342,8 @@ async def test_entity_run_agent_operation(self) -> None: {"message": "Test message", "thread_id": "test-conv-123", "correlationId": "corr-app-entity-1"}, ) - assert result["status"] == "success" - assert result["response"] == "Test response" - assert result["message"] == "Test message" - assert result["thread_id"] == "test-conv-123" + assert isinstance(result, AgentRunResponse) + assert result.text == "Test response" assert entity.state.message_count == 2 async def test_entity_stores_conversation_history(self) -> None: @@ -590,10 +588,12 @@ async def test_entity_handles_agent_error(self) -> None: mock_context, {"message": "Test message", "thread_id": "conv-1", "correlationId": "corr-app-error-1"} ) - assert result["status"] == "error" - assert "error" in result - assert "Agent error" in result["error"] - assert result["error_type"] == "Exception" + assert isinstance(result, AgentRunResponse) + assert len(result.messages) == 1 + content = result.messages[0].contents[0] + assert isinstance(content, ErrorContent) + assert "Agent error" in (content.message or "") + assert content.error_code == "Exception" def test_entity_function_handles_exception(self) -> None: """Test that the entity function handles exceptions gracefully.""" diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index 2f73f1daa8..d47fae4a1f 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, Role +from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ErrorContent, Role from pydantic import BaseModel from agent_framework_azurefunctions._durable_agent_state import ( @@ -133,10 +133,8 @@ async def test_run_agent_executes_agent(self) -> None: assert getattr(sent_message.role, "value", sent_message.role) == "user" # Verify result - assert result["status"] == "success" - assert result["response"] == "Test response" - assert result["message"] == "Test message" - assert result["thread_id"] == "conv-123" + assert isinstance(result, AgentRunResponse) + assert result.text == "Test response" async def test_run_agent_streaming_callbacks_invoked(self) -> None: """Ensure streaming updates trigger callbacks and run() is not used.""" @@ -168,8 +166,8 @@ async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]: }, ) - assert result["status"] == "success" - assert "Hello" in result.get("response", "") + assert isinstance(result, AgentRunResponse) + assert "Hello" in result.text assert callback.stream_mock.await_count == len(updates) assert callback.response_mock.await_count == 1 mock_agent.run.assert_not_called() @@ -215,8 +213,8 @@ async def test_run_agent_final_callback_without_streaming(self) -> None: }, ) - assert result["status"] == "success" - assert result.get("response") == "Final response" + assert isinstance(result, AgentRunResponse) + assert result.text == "Final response" assert callback.stream_mock.await_count == 0 assert callback.response_mock.await_count == 1 @@ -294,44 +292,6 @@ async def test_run_agent_with_none_thread_id(self) -> None: mock_context, {"message": "Message", "thread_id": None, "correlationId": "corr-entity-5"} ) - async def test_run_agent_handles_response_without_text_attribute(self) -> None: - """Test that run_agent handles responses without a text attribute.""" - mock_agent = Mock() - - class NoTextResponse(AgentRunResponse): - @property - def text(self) -> str: # type: ignore[override] - raise AttributeError("text attribute missing") - - mock_response = NoTextResponse(messages=[ChatMessage(role="assistant", text="ignored")]) - mock_agent.run = AsyncMock(return_value=mock_response) - - entity = AgentEntity(mock_agent) - mock_context = Mock() - - result = await entity.run_agent( - mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-6"} - ) - - # Should handle gracefully - assert result["status"] == "success" - assert result["response"] == "Error extracting response" - - async def test_run_agent_handles_none_response_text(self) -> None: - """Test that run_agent handles responses with None text.""" - mock_agent = Mock() - mock_agent.run = AsyncMock(return_value=_agent_response(None)) - - entity = AgentEntity(mock_agent) - mock_context = Mock() - - result = await entity.run_agent( - mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-7"} - ) - - assert result["status"] == "success" - assert result["response"] == "No response" - async def test_run_agent_multiple_conversations(self) -> None: """Test that run_agent maintains history across multiple messages.""" mock_agent = Mock() @@ -621,10 +581,12 @@ async def test_run_agent_handles_agent_exception(self) -> None: mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-1"} ) - assert result["status"] == "error" - assert "error" in result - assert "Agent failed" in result["error"] - assert result["error_type"] == "Exception" + assert isinstance(result, AgentRunResponse) + assert len(result.messages) == 1 + content = result.messages[0].contents[0] + assert isinstance(content, ErrorContent) + assert "Agent failed" in (content.message or "") + assert content.error_code == "Exception" async def test_run_agent_handles_value_error(self) -> None: """Test that run_agent handles ValueError instances.""" @@ -638,9 +600,12 @@ async def test_run_agent_handles_value_error(self) -> None: mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-2"} ) - assert result["status"] == "error" - assert result["error_type"] == "ValueError" - assert "Invalid input" in result["error"] + assert isinstance(result, AgentRunResponse) + assert len(result.messages) == 1 + content = result.messages[0].contents[0] + assert isinstance(content, ErrorContent) + assert content.error_code == "ValueError" + assert "Invalid input" in str(content.message) async def test_run_agent_handles_timeout_error(self) -> None: """Test that run_agent handles TimeoutError instances.""" @@ -654,8 +619,11 @@ async def test_run_agent_handles_timeout_error(self) -> None: mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-3"} ) - assert result["status"] == "error" - assert result["error_type"] == "TimeoutError" + assert isinstance(result, AgentRunResponse) + assert len(result.messages) == 1 + content = result.messages[0].contents[0] + assert isinstance(content, ErrorContent) + assert content.error_code == "TimeoutError" def test_entity_function_handles_exception_in_operation(self) -> None: """Test that the entity function handles exceptions gracefully.""" @@ -690,9 +658,10 @@ async def test_run_agent_preserves_message_on_error(self) -> None: ) # Even on error, message info should be preserved - assert result["message"] == "Test message" - assert result["thread_id"] == "conv-123" - assert result["status"] == "error" + assert isinstance(result, AgentRunResponse) + assert len(result.messages) == 1 + content = result.messages[0].contents[0] + assert isinstance(content, ErrorContent) class TestConversationHistory: @@ -800,10 +769,8 @@ async def test_run_agent_with_run_request_object(self) -> None: result = await entity.run_agent(mock_context, request) - assert result["status"] == "success" - assert result["response"] == "Response" - assert result["message"] == "Test message" - assert result["thread_id"] == "conv-123" + assert isinstance(result, AgentRunResponse) + assert result.text == "Response" async def test_run_agent_with_dict_request(self) -> None: """Test run_agent with a dictionary request.""" @@ -823,9 +790,8 @@ async def test_run_agent_with_dict_request(self) -> None: result = await entity.run_agent(mock_context, request_dict) - assert result["status"] == "success" - assert result["message"] == "Test message" - assert result["thread_id"] == "conv-456" + assert isinstance(result, AgentRunResponse) + assert result.text == "Response" async def test_run_agent_with_string_raises_without_correlation(self) -> None: """Test that run_agent rejects legacy string input without correlation ID.""" @@ -879,10 +845,9 @@ async def test_run_agent_with_response_format(self) -> None: result = await entity.run_agent(mock_context, request) - assert result["status"] == "success" - # Should have structured_response - if "structured_response" in result: - assert result["structured_response"]["answer"] == 42 + assert isinstance(result, AgentRunResponse) + assert result.text == '{"answer": 42}' + assert result.value is None async def test_run_agent_disable_tool_calls(self) -> None: """Test run_agent with tool calls disabled.""" @@ -898,7 +863,7 @@ async def test_run_agent_disable_tool_calls(self) -> None: result = await entity.run_agent(mock_context, request) - assert result["status"] == "success" + assert isinstance(result, AgentRunResponse) # Agent should have been called (tool disabling is framework-dependent) mock_agent.run.assert_called_once() @@ -925,8 +890,24 @@ async def test_entity_function_with_run_request_dict(self) -> None: # Verify result was set assert mock_context.set_result.called result = mock_context.set_result.call_args[0][0] - assert result["status"] == "success" - assert result["message"] == "Test message" + assert isinstance(result, dict) + + # Check if messages are present + assert "messages" in result + assert len(result["messages"]) > 0 + message = result["messages"][0] + + # Check for text in various possible locations + text_found = False + if "text" in message and message["text"] == "Response": + text_found = True + elif "contents" in message: + for content in message["contents"]: + if isinstance(content, dict) and content.get("text") == "Response": + text_found = True + break + + assert text_found, f"Response text not found in message: {message}" if __name__ == "__main__": diff --git a/python/packages/azurefunctions/tests/test_models.py b/python/packages/azurefunctions/tests/test_models.py index 5b803ead13..b4341e8c45 100644 --- a/python/packages/azurefunctions/tests/test_models.py +++ b/python/packages/azurefunctions/tests/test_models.py @@ -7,7 +7,7 @@ from agent_framework import Role from pydantic import BaseModel -from agent_framework_azurefunctions._models import AgentResponse, AgentSessionId, RunRequest +from agent_framework_azurefunctions._models import AgentSessionId, RunRequest class ModuleStructuredResponse(BaseModel): @@ -337,107 +337,6 @@ def test_round_trip_with_correlationId(self) -> None: assert restored.thread_id == original.thread_id -class TestAgentResponse: - """Test suite for AgentResponse.""" - - def test_init_with_required_fields(self) -> None: - """Test AgentResponse initialization with required fields.""" - response = AgentResponse( - response="Test response", message="Test message", thread_id="thread-123", status="success" - ) - - assert response.response == "Test response" - assert response.message == "Test message" - assert response.thread_id == "thread-123" - assert response.status == "success" - assert response.message_count == 0 - assert response.error is None - assert response.error_type is None - assert response.structured_response is None - - def test_init_with_all_fields(self) -> None: - """Test AgentResponse initialization with all fields.""" - structured = {"answer": "42"} - response = AgentResponse( - response=None, - message="What is the answer?", - thread_id="thread-456", - status="success", - message_count=5, - error=None, - error_type=None, - structured_response=structured, - ) - - assert response.response is None - assert response.structured_response == structured - assert response.message_count == 5 - - def test_to_dict_with_text_response(self) -> None: - """Test to_dict with text response.""" - response = AgentResponse( - response="Text response", message="Message", thread_id="thread-1", status="success", message_count=3 - ) - data = response.to_dict() - - assert data["response"] == "Text response" - assert data["message"] == "Message" - assert data["thread_id"] == "thread-1" - assert data["status"] == "success" - assert data["message_count"] == 3 - assert "structured_response" not in data - assert "error" not in data - assert "error_type" not in data - - def test_to_dict_with_structured_response(self) -> None: - """Test to_dict with structured response.""" - structured = {"answer": 42, "confidence": 0.95} - response = AgentResponse( - response=None, - message="Question", - thread_id="thread-2", - status="success", - structured_response=structured, - ) - data = response.to_dict() - - assert data["structured_response"] == structured - assert "response" not in data - - def test_to_dict_with_error(self) -> None: - """Test to_dict with error.""" - response = AgentResponse( - response=None, - message="Failed message", - thread_id="thread-3", - status="error", - error="Something went wrong", - error_type="ValueError", - ) - data = response.to_dict() - - assert data["status"] == "error" - assert data["error"] == "Something went wrong" - assert data["error_type"] == "ValueError" - - def test_to_dict_prefers_structured_over_text(self) -> None: - """Test to_dict prefers structured_response over response.""" - structured = {"result": "structured"} - response = AgentResponse( - response="Text response", - message="Message", - thread_id="thread-4", - status="success", - structured_response=structured, - ) - data = response.to_dict() - - assert "structured_response" in data - assert data["structured_response"] == structured - # Text response should not be included when structured is present - assert "response" not in data - - class TestModelIntegration: """Test suite for integration between models.""" @@ -450,21 +349,6 @@ def test_run_request_with_session_id(self) -> None: assert request.thread_id == str(session_id) assert request.thread_id.startswith("@AgentEntity@") - def test_response_from_run_request(self) -> None: - """Test creating AgentResponse from RunRequest.""" - request = RunRequest(message="What is 2+2?", thread_id="thread-123", role=Role.USER) - - response = AgentResponse( - response="4", - message=request.message, - thread_id=request.thread_id, - status="success", - message_count=1, - ) - - assert response.message == request.message - assert response.thread_id == request.thread_id - if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index 93201a64e9..c65724c160 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -6,10 +6,12 @@ from unittest.mock import Mock import pytest -from agent_framework import AgentThread +from agent_framework import AgentRunResponse, AgentThread, ChatMessage +from azure.durable_functions.models.Task import TaskBase, TaskState from agent_framework_azurefunctions import AgentFunctionApp, DurableAIAgent from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread +from agent_framework_azurefunctions._orchestration import AgentTask def _app_with_registered_agents(*agent_names: str) -> AgentFunctionApp: @@ -21,6 +23,169 @@ def _app_with_registered_agents(*agent_names: str) -> AgentFunctionApp: return app +class _FakeTask(TaskBase): + """Concrete TaskBase for testing AgentTask wiring.""" + + def __init__(self, task_id: int = 1): + super().__init__(task_id, []) + self._set_is_scheduled(False) + self.action_repr = [] + self.state = TaskState.RUNNING + + +def _create_entity_task(task_id: int = 1) -> TaskBase: + """Create a minimal TaskBase instance for AgentTask tests.""" + return _FakeTask(task_id) + + +class TestAgentResponseHelpers: + """Tests for helper utilities that prepare AgentRunResponse values.""" + + @staticmethod + def _create_agent_task() -> AgentTask: + entity_task = _create_entity_task() + return AgentTask(entity_task, None, "correlation-id") + + def test_load_agent_response_from_instance(self) -> None: + task = self._create_agent_task() + response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"foo": "bar"}')]) + + loaded = task._load_agent_response(response) + + assert loaded is response + assert loaded.value is None + + def test_load_agent_response_from_serialized(self) -> None: + task = self._create_agent_task() + serialized = AgentRunResponse(messages=[ChatMessage(role="assistant", text="structured")]).to_dict() + serialized["value"] = {"answer": 42} + + loaded = task._load_agent_response(serialized) + + assert loaded is not None + assert loaded.value == {"answer": 42} + loaded_dict = loaded.to_dict() + assert loaded_dict["type"] == "agent_run_response" + + def test_load_agent_response_rejects_none(self) -> None: + task = self._create_agent_task() + + with pytest.raises(ValueError): + task._load_agent_response(None) + + def test_load_agent_response_rejects_unsupported_type(self) -> None: + task = self._create_agent_task() + + with pytest.raises(TypeError, match="Unsupported type"): + task._load_agent_response(["invalid", "list"]) # type: ignore[arg-type] + + def test_try_set_value_success(self) -> None: + """Test try_set_value correctly processes successful task completion.""" + entity_task = _create_entity_task() + task = AgentTask(entity_task, None, "correlation-id") + + # Simulate successful entity task completion + entity_task.state = TaskState.SUCCEEDED + entity_task.result = AgentRunResponse(messages=[ChatMessage(role="assistant", text="Test response")]).to_dict() + + # Clear pending_tasks to simulate that parent has processed the child + task.pending_tasks.clear() + + # Call try_set_value + task.try_set_value(entity_task) + + # Verify task completed successfully with AgentRunResponse + assert task.state == TaskState.SUCCEEDED + assert isinstance(task.result, AgentRunResponse) + assert task.result.text == "Test response" + + def test_try_set_value_failure(self) -> None: + """Test try_set_value correctly handles failed task completion.""" + entity_task = _create_entity_task() + task = AgentTask(entity_task, None, "correlation-id") + + # Simulate failed entity task + entity_task.state = TaskState.FAILED + entity_task.result = Exception("Entity call failed") + + # Call try_set_value + task.try_set_value(entity_task) + + # Verify task failed with the error + assert task.state == TaskState.FAILED + assert isinstance(task.result, Exception) + assert str(task.result) == "Entity call failed" + + def test_try_set_value_with_response_format(self) -> None: + """Test try_set_value parses structured output when response_format is provided.""" + from pydantic import BaseModel + + class TestSchema(BaseModel): + answer: str + + entity_task = _create_entity_task() + task = AgentTask(entity_task, TestSchema, "correlation-id") + + # Simulate successful entity task with JSON response + entity_task.state = TaskState.SUCCEEDED + entity_task.result = AgentRunResponse( + messages=[ChatMessage(role="assistant", text='{"answer": "42"}')] + ).to_dict() + + # Clear pending_tasks to simulate that parent has processed the child + task.pending_tasks.clear() + + # Call try_set_value + task.try_set_value(entity_task) + + # Verify task completed and value was parsed + assert task.state == TaskState.SUCCEEDED + assert isinstance(task.result, AgentRunResponse) + assert isinstance(task.result.value, TestSchema) + assert task.result.value.answer == "42" + + def test_ensure_response_format_parses_value(self) -> None: + """Test _ensure_response_format correctly parses response value.""" + from pydantic import BaseModel + + class SampleSchema(BaseModel): + name: str + + task = self._create_agent_task() + response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"name": "test"}')]) + + # Value should be None initially + assert response.value is None + + # Parse the value + task._ensure_response_format(SampleSchema, "test-correlation", response) + + # Value should now be parsed + assert isinstance(response.value, SampleSchema) + assert response.value.name == "test" + + def test_ensure_response_format_skips_if_already_parsed(self) -> None: + """Test _ensure_response_format does not re-parse if value already matches format.""" + from pydantic import BaseModel + + class SampleSchema(BaseModel): + name: str + + task = self._create_agent_task() + existing_value = SampleSchema(name="existing") + response = AgentRunResponse( + messages=[ChatMessage(role="assistant", text='{"name": "new"}')], + value=existing_value, + ) + + # Call _ensure_response_format + task._ensure_response_format(SampleSchema, "test-correlation", response) + + # Value should remain unchanged (not re-parsed) + assert response.value is existing_value + assert response.value.name == "existing" + + class TestDurableAIAgent: """Test suite for DurableAIAgent wrapper.""" @@ -111,22 +276,19 @@ def test_run_creates_entity_call(self) -> None: mock_context.instance_id = "test-instance-001" mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) - # Mock call_entity to return a Task-like object - mock_task = Mock() - mock_task._is_scheduled = False # Task attribute that orchestration checks - - mock_context.call_entity = Mock(return_value=mock_task) + entity_task = _create_entity_task() + mock_context.call_entity = Mock(return_value=entity_task) agent = DurableAIAgent(mock_context, "TestAgent") # Create thread thread = agent.get_new_thread() - # Call run() - it should return the Task directly + # Call run() - returns AgentTask directly task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True) - # Verify run() returns the Task from call_entity - assert task == mock_task + assert isinstance(task, AgentTask) + assert task.children[0] == entity_task # Verify call_entity was called with correct parameters assert mock_context.call_entity.called @@ -145,19 +307,18 @@ def test_run_without_thread(self) -> None: """Test that run() works without explicit thread (creates unique session key).""" mock_context = Mock() mock_context.instance_id = "test-instance-002" - # Two calls to new_uuid: one for session_key, one for correlationId mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"]) - mock_task = Mock() - mock_task._is_scheduled = False - mock_context.call_entity = Mock(return_value=mock_task) + entity_task = _create_entity_task() + mock_context.call_entity = Mock(return_value=entity_task) agent = DurableAIAgent(mock_context, "TestAgent") # Call without thread task = agent.run(messages="Test message") - assert task == mock_task + assert isinstance(task, AgentTask) + assert task.children[0] == entity_task # Verify the entity ID uses the auto-generated GUID with dafx- prefix call_args = mock_context.call_entity.call_args @@ -172,9 +333,8 @@ def test_run_with_response_format(self) -> None: mock_context = Mock() mock_context.instance_id = "test-instance-003" - mock_task = Mock() - mock_task._is_scheduled = False - mock_context.call_entity = Mock(return_value=mock_task) + entity_task = _create_entity_task() + mock_context.call_entity = Mock(return_value=entity_task) agent = DurableAIAgent(mock_context, "TestAgent") @@ -188,7 +348,8 @@ class SampleSchema(BaseModel): task = agent.run(messages="Test message", thread=thread, response_format=SampleSchema) - assert task == mock_task + assert isinstance(task, AgentTask) + assert task.children[0] == entity_task # Verify schema was passed in the call_entity arguments call_args = mock_context.call_entity.call_args @@ -221,8 +382,8 @@ def test_run_with_chat_message(self) -> None: mock_context = Mock() mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) - mock_task = Mock() - mock_context.call_entity = Mock(return_value=mock_task) + entity_task = _create_entity_task() + mock_context.call_entity = Mock(return_value=entity_task) agent = DurableAIAgent(mock_context, "TestAgent") thread = agent.get_new_thread() @@ -231,7 +392,8 @@ def test_run_with_chat_message(self) -> None: msg = ChatMessage(role="user", text="Hello") task = agent.run(messages=msg, thread=thread) - assert task == mock_task + assert isinstance(task, AgentTask) + assert task.children[0] == entity_task # Verify message was converted to string call_args = mock_context.call_entity.call_args @@ -255,7 +417,7 @@ def test_entity_id_format(self) -> None: mock_context = Mock() mock_context.new_uuid = Mock(return_value="test-guid-789") - mock_context.call_entity = Mock(return_value=Mock()) + mock_context.call_entity = Mock(return_value=_create_entity_task()) agent = DurableAIAgent(mock_context, "WriterAgent") thread = agent.get_new_thread() @@ -314,13 +476,9 @@ def test_sequential_agent_calls_simulation(self) -> None: # Track entity calls entity_calls: list[dict[str, Any]] = [] - def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock: + def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> TaskBase: entity_calls.append({"entity_id": str(entity_id), "operation": operation, "input": input_data}) - - # Return a mock Task - mock_task = Mock() - mock_task._is_scheduled = False - return mock_task + return _create_entity_task() mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect) @@ -330,13 +488,13 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic # Create thread thread = agent.get_new_thread() - # First call - returns Task + # First call - returns AgentTask task1 = agent.run("Write something", thread=thread) - assert hasattr(task1, "_is_scheduled") + assert isinstance(task1, AgentTask) - # Second call - returns Task + # Second call - returns AgentTask task2 = agent.run("Improve: something", thread=thread) - assert hasattr(task2, "_is_scheduled") + assert isinstance(task2, AgentTask) # Verify both calls used the same entity (same session key) assert len(entity_calls) == 2 @@ -356,11 +514,9 @@ def test_multiple_agents_in_orchestration(self) -> None: entity_calls: list[str] = [] - def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock: + def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> TaskBase: entity_calls.append(str(entity_id)) - mock_task = Mock() - mock_task._is_scheduled = False - return mock_task + return _create_entity_task() mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect) @@ -371,12 +527,12 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic writer_thread = writer.get_new_thread() editor_thread = editor.get_new_thread() - # Call both agents - returns Tasks + # Call both agents - returns AgentTasks writer_task = writer.run("Write", thread=writer_thread) editor_task = editor.run("Edit", thread=editor_thread) - assert hasattr(writer_task, "_is_scheduled") - assert hasattr(editor_task, "_is_scheduled") + assert isinstance(writer_task, AgentTask) + assert isinstance(editor_task, AgentTask) # Verify different entity IDs were used assert len(entity_calls) == 2 diff --git a/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py b/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py index 377af2763a..cc05a323f3 100644 --- a/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py +++ b/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py @@ -57,7 +57,7 @@ def single_agent_orchestration(context: DurableOrchestrationContext): improved_prompt = ( "Improve this further while keeping it under 25 words: " - f"{initial.get('response', '').strip()}" + f"{initial.text}" ) refined = yield writer.run( @@ -65,7 +65,7 @@ def single_agent_orchestration(context: DurableOrchestrationContext): thread=writer_thread, ) - return refined.get("response", "") + return refined.text # 5. HTTP endpoint to kick off the orchestration and return the status query URI. diff --git a/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py b/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py index 59f59a4f40..69ea8816b2 100644 --- a/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py +++ b/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py @@ -10,8 +10,9 @@ import json import logging -from typing import Any +from typing import Any, cast +from agent_framework import AgentRunResponse import azure.functions as func from agent_framework.azure import AgentFunctionApp, AzureOpenAIChatClient from azure.durable_functions import DurableOrchestrationClient, DurableOrchestrationContext @@ -63,14 +64,19 @@ def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext): physicist_thread = physicist.get_new_thread() chemist_thread = chemist.get_new_thread() + # Create tasks from agent.run() calls physicist_task = physicist.run(messages=str(prompt), thread=physicist_thread) chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread) - results = yield context.task_all([physicist_task, chemist_task]) + # Execute both tasks concurrently using task_all + task_results = yield context.task_all([physicist_task, chemist_task]) + + physicist_result = cast(AgentRunResponse, task_results[0]) + chemist_result = cast(AgentRunResponse, task_results[1]) return { - "physicist": results[0].get("response", ""), - "chemist": results[1].get("response", ""), + "physicist": physicist_result.text, + "chemist": chemist_result.text, } diff --git a/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py b/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py index d32884014f..2ee445d423 100644 --- a/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py +++ b/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py @@ -102,7 +102,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext): response_format=SpamDetectionResult, ) - spam_result = cast(SpamDetectionResult, _coerce_structured(spam_result_raw, SpamDetectionResult)) + spam_result = cast(SpamDetectionResult, spam_result_raw.value) if spam_result.is_spam: result = yield context.call_activity("handle_spam_email", spam_result.reason) @@ -123,7 +123,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext): response_format=EmailResponse, ) - email_result = cast(EmailResponse, _coerce_structured(email_result_raw, EmailResponse)) + email_result = cast(EmailResponse, email_result_raw.value) result = yield context.call_activity("send_email", email_result.response) return result @@ -231,24 +231,6 @@ def _build_status_url(request_url: str, instance_id: str, *, route: str) -> str: return f"{base_url}/api/{route}/status/{instance_id}" -def _coerce_structured(result: Mapping[str, Any], model: type[BaseModel]) -> BaseModel: - structured = result.get("structured_response") if isinstance(result, Mapping) else None - if structured is not None: - return model.model_validate(structured) - - response_text = result.get("response") if isinstance(result, Mapping) else None - if isinstance(response_text, str) and response_text.strip(): - try: - parsed = json.loads(response_text) - if isinstance(parsed, Mapping): - return model.model_validate(parsed) - except json.JSONDecodeError: - logger.warning("[ConditionalOrchestration] Failed to parse agent JSON response; raising error.") - - # If parsing failed, raise to surface the issue to the caller. - raise ValueError(f"Agent response could not be parsed as {model.__name__}.") - - """ Expected response from `POST /api/spamdetection/run`: diff --git a/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py b/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py index 69b2e3fae9..c8e2bbaa9c 100644 --- a/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py +++ b/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py @@ -100,7 +100,12 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext): thread=writer_thread, response_format=GeneratedContent, ) - content = _coerce_generated_content(initial_raw) + + content = initial_raw.value + logger.info("Type of content after extraction: %s", type(content)) + + if content is None or not isinstance(content, GeneratedContent): + raise ValueError("Agent returned no content after extraction.") attempt = 0 while attempt < payload.max_review_attempts: @@ -142,7 +147,12 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext): thread=writer_thread, response_format=GeneratedContent, ) - content = _coerce_generated_content(rewritten_raw) + + rewritten_value = rewritten_raw.value + if rewritten_value is None or not isinstance(rewritten_value, GeneratedContent): + raise ValueError("Agent returned no content after rewrite.") + + content = rewritten_value else: context.set_custom_status( f"Human approval timed out after {payload.approval_timeout_hours} hour(s). Treating as rejection." @@ -317,23 +327,6 @@ def _build_status_url(request_url: str, instance_id: str, *, route: str) -> str: return f"{base_url}/api/{route}/status/{instance_id}" -def _coerce_generated_content(result: Mapping[str, Any]) -> GeneratedContent: - structured = result.get("structured_response") if isinstance(result, Mapping) else None - if structured is not None: - return GeneratedContent.model_validate(structured) - - response_text = result.get("response") if isinstance(result, Mapping) else None - if isinstance(response_text, str) and response_text.strip(): - try: - parsed = json.loads(response_text) - if isinstance(parsed, Mapping): - return GeneratedContent.model_validate(parsed) - except json.JSONDecodeError: - logger.warning("[HITL] Failed to parse agent JSON response; falling back to defaults.") - - raise ValueError("Agent response could not be parsed as GeneratedContent.") - - def _parse_human_approval(raw: Any) -> HumanApproval: if isinstance(raw, Mapping): return HumanApproval.model_validate(raw)