From b3c8c876bd0b2163777381ca4cc401a4ca31bb3d Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 30 Jan 2026 18:17:57 -0600 Subject: [PATCH 01/29] Add workflow support for Azure Functions --- .../agent_framework_azurefunctions/_app.py | 318 ++++++- .../agent_framework_azurefunctions/_utils.py | 640 +++++++++++++ .../_workflow.py | 854 ++++++++++++++++++ .../packages/azurefunctions/tests/test_app.py | 97 ++ .../azurefunctions/tests/test_utils.py | 463 ++++++++++ .../azurefunctions/tests/test_workflow.py | 379 ++++++++ .../_workflows/_agent_executor.py | 5 + .../demo.http | 2 +- .../function_app.py | 4 +- .../09_workflow_shared_state/.gitignore | 18 + .../09_workflow_shared_state/README.md | 99 ++ .../09_workflow_shared_state/demo.http | 31 + .../09_workflow_shared_state/function_app.py | 294 ++++++ .../09_workflow_shared_state/host.json | 16 + .../local.settings.json.sample | 11 + .../09_workflow_shared_state/requirements.txt | 3 + .../10_workflow_no_shared_state/.env.sample | 4 + .../10_workflow_no_shared_state/.gitignore | 2 + .../10_workflow_no_shared_state/README.md | 159 ++++ .../10_workflow_no_shared_state/demo.http | 32 + .../function_app.py | 244 +++++ .../10_workflow_no_shared_state/host.json | 16 + .../local.settings.json.sample | 12 + .../requirements.txt | 3 + .../11_workflow_parallel/.env.template | 14 + .../11_workflow_parallel/.gitignore | 4 + .../11_workflow_parallel/README.md | 193 ++++ .../11_workflow_parallel/demo.http | 29 + .../11_workflow_parallel/function_app.py | 538 +++++++++++ .../11_workflow_parallel/host.json | 16 + .../local.settings.json.sample | 12 + .../11_workflow_parallel/requirements.txt | 3 + .../12_workflow_hitl/.gitignore | 5 + .../12_workflow_hitl/README.md | 141 +++ .../12_workflow_hitl/demo.http | 123 +++ .../12_workflow_hitl/function_app.py | 468 ++++++++++ .../12_workflow_hitl/host.json | 16 + .../local.settings.json.sample | 11 + .../12_workflow_hitl/requirements.txt | 3 + 39 files changed, 5278 insertions(+), 4 deletions(-) create mode 100644 python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py create mode 100644 python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py create mode 100644 python/packages/azurefunctions/tests/test_utils.py create mode 100644 python/packages/azurefunctions/tests/test_workflow.py create mode 100644 python/samples/getting_started/azure_functions/09_workflow_shared_state/.gitignore create mode 100644 python/samples/getting_started/azure_functions/09_workflow_shared_state/README.md create mode 100644 python/samples/getting_started/azure_functions/09_workflow_shared_state/demo.http create mode 100644 python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py create mode 100644 python/samples/getting_started/azure_functions/09_workflow_shared_state/host.json create mode 100644 python/samples/getting_started/azure_functions/09_workflow_shared_state/local.settings.json.sample create mode 100644 python/samples/getting_started/azure_functions/09_workflow_shared_state/requirements.txt create mode 100644 python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.env.sample create mode 100644 python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.gitignore create mode 100644 python/samples/getting_started/azure_functions/10_workflow_no_shared_state/README.md create mode 100644 python/samples/getting_started/azure_functions/10_workflow_no_shared_state/demo.http create mode 100644 
python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py create mode 100644 python/samples/getting_started/azure_functions/10_workflow_no_shared_state/host.json create mode 100644 python/samples/getting_started/azure_functions/10_workflow_no_shared_state/local.settings.json.sample create mode 100644 python/samples/getting_started/azure_functions/10_workflow_no_shared_state/requirements.txt create mode 100644 python/samples/getting_started/azure_functions/11_workflow_parallel/.env.template create mode 100644 python/samples/getting_started/azure_functions/11_workflow_parallel/.gitignore create mode 100644 python/samples/getting_started/azure_functions/11_workflow_parallel/README.md create mode 100644 python/samples/getting_started/azure_functions/11_workflow_parallel/demo.http create mode 100644 python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py create mode 100644 python/samples/getting_started/azure_functions/11_workflow_parallel/host.json create mode 100644 python/samples/getting_started/azure_functions/11_workflow_parallel/local.settings.json.sample create mode 100644 python/samples/getting_started/azure_functions/11_workflow_parallel/requirements.txt create mode 100644 python/samples/getting_started/azure_functions/12_workflow_hitl/.gitignore create mode 100644 python/samples/getting_started/azure_functions/12_workflow_hitl/README.md create mode 100644 python/samples/getting_started/azure_functions/12_workflow_hitl/demo.http create mode 100644 python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py create mode 100644 python/samples/getting_started/azure_functions/12_workflow_hitl/host.json create mode 100644 python/samples/getting_started/azure_functions/12_workflow_hitl/local.settings.json.sample create mode 100644 python/samples/getting_started/azure_functions/12_workflow_hitl/requirements.txt diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 1c77d22e4d..7b3863eeda 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -8,6 +8,7 @@ from __future__ import annotations +import asyncio import json import logging import re @@ -19,7 +20,7 @@ import azure.durable_functions as df import azure.functions as func -from agent_framework import SupportsAgentRun +from agent_framework import AgentExecutor, SupportsAgentRun, Workflow, WorkflowOutputEvent, get_logger from agent_framework_durabletask import ( DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS, @@ -42,6 +43,14 @@ from ._entities import create_agent_entity from ._errors import IncomingRequestError from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor +from ._utils import ( + CapturingRunnerContext, + _execute_hitl_response_handler, + deserialize_value, + reconstruct_message_for_handler, + serialize_message, +) +from ._workflow import run_workflow_orchestrator logger = logging.getLogger("agent_framework.azurefunctions") @@ -154,16 +163,19 @@ def my_orchestration(context): enable_mcp_tool_trigger: Whether MCP tool triggers are created for agents max_poll_retries: Maximum polling attempts when waiting for responses poll_interval_seconds: Delay (seconds) between polling attempts + workflow: Optional Workflow instance for workflow orchestration """ _agent_metadata: dict[str, AgentMetadata] 
enable_health_check: bool enable_http_endpoints: bool enable_mcp_tool_trigger: bool + workflow: Workflow | None def __init__( self, agents: list[SupportsAgentRun] | None = None, + workflow: Workflow | None = None, http_auth_level: func.AuthLevel = func.AuthLevel.FUNCTION, enable_health_check: bool = True, enable_http_endpoints: bool = True, @@ -175,6 +187,7 @@ def __init__( """Initialize the AgentFunctionApp. :param agents: List of agent instances to register. + :param workflow: Optional Workflow instance to extract agents from and set up orchestration. :param http_auth_level: HTTP authentication level (default: ``func.AuthLevel.FUNCTION``). :param enable_health_check: Enable the built-in health check endpoint (default: ``True``). :param enable_http_endpoints: Enable HTTP endpoints for agents (default: ``True``). @@ -199,6 +212,7 @@ def __init__( self.enable_http_endpoints = enable_http_endpoints self.enable_mcp_tool_trigger = enable_mcp_tool_trigger self.default_callback = default_callback + self.workflow = workflow try: retries = int(max_poll_retries) @@ -212,6 +226,20 @@ def __init__( interval = DEFAULT_POLL_INTERVAL_SECONDS self.poll_interval_seconds = interval if interval > 0 else DEFAULT_POLL_INTERVAL_SECONDS + # If workflow is provided, extract agents and set up orchestration + if workflow: + if agents is None: + agents = [] + logger.debug("[AgentFunctionApp] Extracting agents from workflow") + for executor in workflow.executors.values(): + if isinstance(executor, AgentExecutor): + agents.append(executor.agent) + else: + # Setup individual activity for each non-agent executor + self._setup_executor_activity(executor.id) + + self._setup_workflow_orchestration() + if agents: # Register all provided agents logger.debug(f"[AgentFunctionApp] Registering {len(agents)} agent(s)") @@ -224,6 +252,294 @@ def __init__( logger.debug("[AgentFunctionApp] Initialization complete") + def _setup_executor_activity(self, executor_id: str) -> None: + """Register an activity for executing a specific non-agent executor. + + Args: + executor_id: The ID of the executor to create an activity for. + """ + activity_name = f"dafx-{executor_id}" + logger.debug(f"[AgentFunctionApp] Registering activity '{activity_name}' for executor '{executor_id}'") + + # Capture executor_id in closure + captured_executor_id = executor_id + + @self.function_name(activity_name) + @self.activity_trigger(input_name="inputData") + def executor_activity(inputData: str) -> str: + """Activity to execute a specific non-agent executor. + + Note: We use str type annotations instead of dict to work around + Azure Functions worker type validation issues with dict[str, Any]. 
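+
+            The payload is produced by _prepare_activity_task on the orchestrator
+            side; this activity reads its "message", "shared_state_snapshot" and
+            "source_executor_ids" keys, and the returned JSON string reports
+            "sent_messages", "outputs", "shared_state_updates",
+            "shared_state_deletes" and "pending_request_info_events" back to the
+            orchestrator.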
+ """ + import json as json_module + + from agent_framework import SharedState + + data = json_module.loads(inputData) + message_data = data["message"] + shared_state_snapshot = data.get("shared_state_snapshot", {}) + source_executor_ids = data.get("source_executor_ids", ["__orchestrator__"]) + + if not self.workflow: + raise RuntimeError("Workflow not initialized in AgentFunctionApp") + + executor = self.workflow.executors.get(captured_executor_id) + if not executor: + raise ValueError(f"Unknown executor: {captured_executor_id}") + + # Reconstruct message - try to match handler's expected types using public input_types + message = reconstruct_message_for_handler(message_data, executor.input_types) + + # Check if this is a HITL response message + is_hitl_response = isinstance(message_data, dict) and message_data.get("__hitl_response__") + + async def run() -> dict[str, Any]: + # Create runner context and shared state + runner_context = CapturingRunnerContext() + shared_state = SharedState() + + # Deserialize shared state values to reconstruct dataclasses/Pydantic models + deserialized_state = {k: deserialize_value(v) for k, v in (shared_state_snapshot or {}).items()} + original_snapshot = dict(deserialized_state) + await shared_state.import_state(deserialized_state) + + if is_hitl_response: + # Handle HITL response by calling the executor's @response_handler + await _execute_hitl_response_handler( + executor=executor, + hitl_message=message_data, + shared_state=shared_state, + runner_context=runner_context, + ) + else: + # Execute using the public execute() method + await executor.execute( + message=message, + source_executor_ids=source_executor_ids, + shared_state=shared_state, + runner_context=runner_context, + ) + + # Export current state and compute changes + current_state = await shared_state.export_state() + original_keys = set(original_snapshot.keys()) + current_keys = set(current_state.keys()) + + # Deleted = was in original, not in current + deletes = original_keys - current_keys + + # Updates = keys in current that are new or have different values + updates = { + k: v for k, v in current_state.items() if k not in original_snapshot or original_snapshot[k] != v + } + + # Drain messages and events from runner context + sent_messages = await runner_context.drain_messages() + events = await runner_context.drain_events() + + # Extract outputs from WorkflowOutputEvent instances + outputs: list[Any] = [] + for event in events: + if isinstance(event, WorkflowOutputEvent): + outputs.append(serialize_message(event.data)) + + # Get pending request info events for HITL + pending_request_info_events = await runner_context.get_pending_request_info_events() + + # Serialize pending request info events for orchestrator + serialized_pending_requests = [] + for _request_id, event in pending_request_info_events.items(): + serialized_pending_requests.append({ + "request_id": event.request_id, + "source_executor_id": event.source_executor_id, + "data": serialize_message(event.data), + "request_type": f"{type(event.data).__module__}:{type(event.data).__name__}", + "response_type": f"{event.response_type.__module__}:{event.response_type.__name__}" + if event.response_type + else None, + }) + + # Serialize messages for JSON compatibility + serialized_sent_messages = [] + for _source_id, msg_list in sent_messages.items(): + for msg in msg_list: + serialized_sent_messages.append({ + "message": serialize_message(msg.data), + "target_id": msg.target_id, + "source_id": msg.source_id, + }) + + 
serialized_updates = {k: serialize_message(v) for k, v in updates.items()} + + return { + "sent_messages": serialized_sent_messages, + "outputs": outputs, + "shared_state_updates": serialized_updates, + "shared_state_deletes": list(deletes), + "pending_request_info_events": serialized_pending_requests, + } + + result = asyncio.run(run()) + return json_module.dumps(result) + + # Ensure the function is registered (prevents garbage collection) + _ = executor_activity + + def _setup_workflow_orchestration(self) -> None: + """Register the workflow orchestration and related HTTP endpoints.""" + + @self.orchestration_trigger(context_name="context") + def workflow_orchestrator(context: df.DurableOrchestrationContext): # type: ignore[type-arg] + """Generic orchestrator for running the configured workflow.""" + input_data = context.get_input() + + # Ensure input is a string for the agent + initial_message = json.dumps(input_data) if isinstance(input_data, (dict, list)) else str(input_data) + + # Create local shared state dict for cross-executor state sharing + shared_state: dict[str, Any] = {} + + outputs = yield from run_workflow_orchestrator(context, self.workflow, initial_message, shared_state) + # Durable Functions runtime extracts return value from StopIteration + return outputs # noqa: B901 + + @self.route(route="workflow/run", methods=["POST"]) + @self.durable_client_input(client_name="client") + async def start_workflow_orchestration( + req: func.HttpRequest, client: df.DurableOrchestrationClient + ) -> func.HttpResponse: + """HTTP endpoint to start the workflow.""" + try: + req_body = req.get_json() + except ValueError: + return func.HttpResponse( + json.dumps({"error": "Invalid JSON body"}), + status_code=400, + mimetype="application/json", + ) + + instance_id = await client.start_new("workflow_orchestrator", client_input=req_body) + + base_url = self._build_base_url(req.url) + status_url = f"{base_url}/api/workflow/status/{instance_id}" + + return func.HttpResponse( + json.dumps({ + "instanceId": instance_id, + "statusQueryGetUri": status_url, + "respondUri": f"{base_url}/api/workflow/respond/{instance_id}/{{requestId}}", + "message": "Workflow started", + }), + status_code=202, + mimetype="application/json", + ) + + @self.route(route="workflow/status/{instanceId}", methods=["GET"]) + @self.durable_client_input(client_name="client") + async def get_workflow_status( + req: func.HttpRequest, client: df.DurableOrchestrationClient + ) -> func.HttpResponse: + """HTTP endpoint to get workflow status.""" + instance_id = req.route_params.get("instanceId") + status = await client.get_status(instance_id) + + if not status: + return func.HttpResponse( + json.dumps({"error": "Instance not found"}), + status_code=404, + mimetype="application/json", + ) + + response = { + "instanceId": status.instance_id, + "runtimeStatus": status.runtime_status.name if status.runtime_status else None, + "customStatus": status.custom_status, + "output": status.output, + "error": status.output if status.runtime_status == df.OrchestrationRuntimeStatus.Failed else None, + "createdTime": status.created_time.isoformat() if status.created_time else None, + "lastUpdatedTime": status.last_updated_time.isoformat() if status.last_updated_time else None, + } + + # Add pending HITL requests info if available + custom_status = status.custom_status or {} + if isinstance(custom_status, dict) and custom_status.get("pending_requests"): + base_url = self._build_base_url(req.url) + pending_requests = [] + for req_id, req_data in 
custom_status["pending_requests"].items(): + pending_requests.append({ + "requestId": req_id, + "sourceExecutor": req_data.get("source_executor_id"), + "requestData": req_data.get("data"), + "requestType": req_data.get("request_type"), + "responseType": req_data.get("response_type"), + "respondUrl": f"{base_url}/api/workflow/respond/{instance_id}/{req_id}", + }) + response["pendingHumanInputRequests"] = pending_requests + + return func.HttpResponse( + json.dumps(response, default=str), + status_code=200, + mimetype="application/json", + ) + + @self.route(route="workflow/respond/{instanceId}/{requestId}", methods=["POST"]) + @self.durable_client_input(client_name="client") + async def send_hitl_response(req: func.HttpRequest, client: df.DurableOrchestrationClient) -> func.HttpResponse: + """HTTP endpoint to send a response to a pending HITL request. + + The requestId in the URL corresponds to the request_id from the RequestInfoEvent. + The request body should contain the response data matching the expected response_type. + """ + instance_id = req.route_params.get("instanceId") + request_id = req.route_params.get("requestId") + + if not instance_id or not request_id: + return func.HttpResponse( + json.dumps({"error": "Instance ID and Request ID are required."}), + status_code=400, + mimetype="application/json", + ) + + try: + response_data = req.get_json() + except ValueError: + return func.HttpResponse( + json.dumps({"error": "Request body must be valid JSON."}), + status_code=400, + mimetype="application/json", + ) + + # Send the response as an external event + # The request_id is used as the event name for correlation + await client.raise_event( + instance_id=instance_id, + event_name=request_id, + event_data=response_data, + ) + + return func.HttpResponse( + json.dumps({ + "message": "Response delivered successfully", + "instanceId": instance_id, + "requestId": request_id, + }), + status_code=200, + mimetype="application/json", + ) + + def _build_status_url(self, request_url: str, instance_id: str) -> str: + """Build the status URL for a workflow instance.""" + base_url = self._build_base_url(request_url) + return f"{base_url}/api/workflow/status/{instance_id}" + + def _build_base_url(self, request_url: str) -> str: + """Extract the base URL from a request URL.""" + base_url, _, _ = request_url.partition("/api/") + if not base_url: + base_url = request_url.rstrip("/") + return base_url + @property def agents(self) -> dict[str, SupportsAgentRun]: """Returns dict of agent names to agent instances. diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py new file mode 100644 index 0000000000..3b25f5db85 --- /dev/null +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py @@ -0,0 +1,640 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Utility functions for workflow execution. + +This module provides helper functions for serialization, deserialization, and +context management used by the workflow orchestrator and executors. 
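+
+The main entry points are serialize_message and deserialize_value for moving
+messages across the activity boundary, reconstruct_message_for_handler for
+matching payloads to an executor's declared input types, and
+CapturingRunnerContext, the in-memory RunnerContext used inside activities.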
+""" + +from __future__ import annotations + +import asyncio +import logging +import types +from dataclasses import asdict, fields, is_dataclass +from typing import Any, Union, get_args, get_origin + +from agent_framework import ( + AgentExecutorRequest, + AgentExecutorResponse, + AgentRunResponse, + ChatMessage, + CheckpointStorage, + Message, + RequestInfoEvent, + RunnerContext, + SharedState, + WorkflowCheckpoint, + WorkflowEvent, +) +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class CapturingRunnerContext(RunnerContext): + """A RunnerContext implementation that captures messages and events for Azure Functions activities. + + This context is designed for executing standard Executors within Azure Functions activities. + It captures all messages and events produced during execution without requiring durable + entity storage, allowing the results to be returned to the orchestrator. + + Unlike the full InProcRunnerContext, this implementation: + - Does NOT support checkpointing (always returns False for has_checkpointing) + - Does NOT support streaming (always returns False for is_streaming) + - Captures messages and events in memory for later retrieval + + The orchestrator manages state coordination; this context just captures execution output. + """ + + def __init__(self) -> None: + """Initialize the capturing runner context.""" + self._messages: dict[str, list[Message]] = {} + self._event_queue: asyncio.Queue[WorkflowEvent] = asyncio.Queue() + self._pending_request_info_events: dict[str, RequestInfoEvent] = {} + self._workflow_id: str | None = None + self._streaming: bool = False + + # region Messaging + + async def send_message(self, message: Message) -> None: + """Capture a message sent by an executor.""" + self._messages.setdefault(message.source_id, []) + self._messages[message.source_id].append(message) + + async def drain_messages(self) -> dict[str, list[Message]]: + """Drain and return all captured messages.""" + from copy import copy + + messages = copy(self._messages) + self._messages.clear() + return messages + + async def has_messages(self) -> bool: + """Check if there are any captured messages.""" + return bool(self._messages) + + # endregion Messaging + + # region Events + + async def add_event(self, event: WorkflowEvent) -> None: + """Capture an event produced during execution.""" + await self._event_queue.put(event) + + async def drain_events(self) -> list[WorkflowEvent]: + """Drain all currently queued events without blocking.""" + events: list[WorkflowEvent] = [] + while True: + try: + events.append(self._event_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return events + + async def has_events(self) -> bool: + """Check if there are any queued events.""" + return not self._event_queue.empty() + + async def next_event(self) -> WorkflowEvent: + """Wait for and return the next event.""" + return await self._event_queue.get() + + # endregion Events + + # region Checkpointing (not supported in activity context) + + def has_checkpointing(self) -> bool: + """Checkpointing is not supported in activity context.""" + return False + + def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None: + """No-op: checkpointing not supported in activity context.""" + pass + + def clear_runtime_checkpoint_storage(self) -> None: + """No-op: checkpointing not supported in activity context.""" + pass + + async def create_checkpoint( + self, + shared_state: SharedState, + iteration_count: int, + metadata: dict[str, Any] | None = 
None, + ) -> str: + """Checkpointing not supported in activity context.""" + raise NotImplementedError("Checkpointing is not supported in Azure Functions activity context") + + async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: + """Checkpointing not supported in activity context.""" + raise NotImplementedError("Checkpointing is not supported in Azure Functions activity context") + + async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: + """Checkpointing not supported in activity context.""" + raise NotImplementedError("Checkpointing is not supported in Azure Functions activity context") + + # endregion Checkpointing + + # region Workflow Configuration + + def set_workflow_id(self, workflow_id: str) -> None: + """Set the workflow ID.""" + self._workflow_id = workflow_id + + def reset_for_new_run(self) -> None: + """Reset the context for a new run.""" + self._messages.clear() + self._event_queue = asyncio.Queue() + self._pending_request_info_events.clear() + self._streaming = False + + def set_streaming(self, streaming: bool) -> None: + """Set streaming mode (not used in activity context).""" + self._streaming = streaming + + def is_streaming(self) -> bool: + """Check if streaming mode is enabled (always False in activity context).""" + return self._streaming + + # endregion Workflow Configuration + + # region Request Info Events + + async def add_request_info_event(self, event: RequestInfoEvent) -> None: + """Add a RequestInfoEvent and track it for correlation.""" + self._pending_request_info_events[event.request_id] = event + await self.add_event(event) + + async def send_request_info_response(self, request_id: str, response: Any) -> None: + """Send a response correlated to a pending request. + + Note: This is not supported in activity context since human-in-the-loop + scenarios require orchestrator-level coordination. + """ + raise NotImplementedError( + "send_request_info_response is not supported in Azure Functions activity context. " + "Human-in-the-loop scenarios should be handled at the orchestrator level." + ) + + async def get_pending_request_info_events(self) -> dict[str, RequestInfoEvent]: + """Get the mapping of request IDs to their corresponding RequestInfoEvent.""" + return dict(self._pending_request_info_events) + + # endregion Request Info Events + + +def _serialize_value(value: Any) -> Any: + """Recursively serialize a value for JSON compatibility.""" + # Handle None + if value is None: + return None + + # Handle objects with to_dict() method (like ChatMessage) + if hasattr(value, "to_dict") and callable(value.to_dict): + return value.to_dict() + + # Handle dataclasses + if is_dataclass(value) and not isinstance(value, type): + d: dict[str, Any] = {} + for k, v in asdict(value).items(): + d[k] = _serialize_value(v) + d["__type__"] = type(value).__name__ + d["__module__"] = type(value).__module__ + return d + + # Handle Pydantic models + if isinstance(value, BaseModel): + d = value.model_dump() + d["__type__"] = type(value).__name__ + d["__module__"] = type(value).__module__ + return d + + # Handle lists + if isinstance(value, list): + return [_serialize_value(item) for item in value] + + # Handle dicts + if isinstance(value, dict): + return {k: _serialize_value(v) for k, v in value.items()} + + # Handle primitives and other types + return value + + +def serialize_message(message: Any) -> Any: + """Helper to serialize messages for activity input. 
+ + Adds type metadata (__type__, __module__) to dataclasses and Pydantic models + to enable reconstruction on the receiving end. Handles nested ChatMessage + and other objects with to_dict() methods. + """ + return _serialize_value(message) + + +def deserialize_value(data: Any, type_registry: dict[str, type] | None = None) -> Any: + """Attempt to deserialize a value using embedded type metadata. + + Args: + data: The serialized data (could be dict with __type__ metadata) + type_registry: Optional dict mapping type names to types for reconstruction + + Returns: + Reconstructed object if type metadata found and type available, otherwise original data + """ + if not isinstance(data, dict): + return data + + type_name = data.get("__type__") + module_name = data.get("__module__") + + # Special handling for MAF types with nested objects + if type_name == "AgentExecutorRequest" or ("messages" in data and "should_respond" in data): + try: + return reconstruct_agent_executor_request(data) + except Exception: + logger.debug("Could not reconstruct as AgentExecutorRequest, trying next strategy") + + if type_name == "AgentExecutorResponse" or ("executor_id" in data and "agent_run_response" in data): + try: + return reconstruct_agent_executor_response(data) + except Exception: + logger.debug("Could not reconstruct as AgentExecutorResponse, trying next strategy") + + if not type_name: + return data + + # Try to find the type + target_type = None + + # First check the registry + if type_registry and type_name in type_registry: + target_type = type_registry[type_name] + else: + # Try to import from module + if module_name: + try: + import importlib + + module = importlib.import_module(module_name) + target_type = getattr(module, type_name, None) + except Exception: + logger.debug("Could not import module %s for type %s", module_name, type_name) + + if target_type: + # Remove metadata before reconstruction + clean_data = {k: v for k, v in data.items() if not k.startswith("__")} + try: + if is_dataclass(target_type): + # Recursively reconstruct nested fields for dataclasses + reconstructed_data = _reconstruct_dataclass_fields(target_type, clean_data) + return target_type(**reconstructed_data) + if issubclass(target_type, BaseModel): + # Pydantic handles nested model validation automatically + return target_type.model_validate(clean_data) + except Exception: + logger.debug("Could not reconstruct type %s from data", type_name) + + return data + + +def _reconstruct_dataclass_fields(dataclass_type: type, data: dict[str, Any]) -> dict[str, Any]: + """Recursively reconstruct nested dataclass and Pydantic fields. + + This function processes each field of a dataclass, looking up the expected type + from type hints and reconstructing nested objects (dataclasses, Pydantic models, lists). 
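+
+    For example, a nested dataclass field arrives as a plain dict carrying the
+    __type__ / __module__ markers added by _serialize_value and is rebuilt into
+    the concrete dataclass instance here before the parent constructor is called.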
+ + Args: + dataclass_type: The dataclass type being constructed + data: The dict of field values + + Returns: + Dict with nested objects properly reconstructed + """ + if not is_dataclass(dataclass_type): + return data + + result = {} + type_hints = {} + + # Get type hints for the dataclass + try: + import typing + + type_hints = typing.get_type_hints(dataclass_type) + except Exception: + # Fall back to field annotations if get_type_hints fails + for f in fields(dataclass_type): + type_hints[f.name] = f.type + + for key, value in data.items(): + if key not in type_hints: + result[key] = value + continue + + field_type = type_hints[key] + + # Handle Optional types (Union with None) + origin = get_origin(field_type) + if origin is Union or isinstance(field_type, types.UnionType): + args = get_args(field_type) + # Filter out NoneType to get the actual type + non_none_types = [t for t in args if t is not type(None)] + if len(non_none_types) == 1: + field_type = non_none_types[0] + + # Recursively reconstruct the value + result[key] = _reconstruct_typed_value(value, field_type) + + return result + + +def _reconstruct_typed_value(value: Any, target_type: type) -> Any: + """Reconstruct a single value to the target type. + + Handles dataclasses, Pydantic models, and lists with typed elements. + + Args: + value: The value to reconstruct + target_type: The expected type + + Returns: + The reconstructed value + """ + if value is None: + return None + + # If already the correct type, return as-is + try: + if isinstance(value, target_type): + return value + except TypeError: + # target_type might not be a valid type for isinstance + pass + + # Handle dict values that need reconstruction + if isinstance(value, dict): + # First try deserialize_value which uses embedded type metadata + if "__type__" in value: + deserialized = deserialize_value(value) + if deserialized is not value: + return deserialized + + # Handle Pydantic models + if hasattr(target_type, "model_validate"): + try: + return target_type.model_validate(value) + except Exception: + logger.debug("Could not validate Pydantic model %s", target_type) + + # Handle dataclasses + if is_dataclass(target_type) and isinstance(target_type, type): + try: + # Recursively reconstruct nested fields + reconstructed = _reconstruct_dataclass_fields(target_type, value) + return target_type(**reconstructed) + except Exception: + logger.debug("Could not construct dataclass %s", target_type) + + # Handle list values + if isinstance(value, list): + origin = get_origin(target_type) + if origin is list: + args = get_args(target_type) + if args: + element_type = args[0] + return [_reconstruct_typed_value(item, element_type) for item in value] + + return value + + +def reconstruct_agent_executor_request(data: dict[str, Any]) -> AgentExecutorRequest: + """Helper to reconstruct AgentExecutorRequest from dict.""" + # Reconstruct ChatMessage objects in messages + messages_data = data.get("messages", []) + messages = [ChatMessage.from_dict(m) if isinstance(m, dict) else m for m in messages_data] + + return AgentExecutorRequest(messages=messages, should_respond=data.get("should_respond", True)) + + +def reconstruct_agent_executor_response(data: dict[str, Any]) -> AgentExecutorResponse: + """Helper to reconstruct AgentExecutorResponse from dict.""" + # Reconstruct AgentRunResponse + arr_data = data.get("agent_run_response", {}) + agent_run_response = AgentRunResponse.from_dict(arr_data) if isinstance(arr_data, dict) else arr_data + + # Reconstruct full_conversation + 
fc_data = data.get("full_conversation", []) + full_conversation = None + if fc_data: + full_conversation = [ChatMessage.from_dict(m) if isinstance(m, dict) else m for m in fc_data] + + return AgentExecutorResponse( + executor_id=data["executor_id"], agent_run_response=agent_run_response, full_conversation=full_conversation + ) + + +def reconstruct_message_for_handler(data: Any, input_types: list[type[Any]]) -> Any: + """Attempt to reconstruct a message to match one of the handler's expected types. + + Handles: + - Dicts with __type__ metadata -> reconstructs to original dataclass/Pydantic model + - Lists (from fan-in) -> recursively reconstructs each item + - Union types (T | U) -> tries each type in the union + - AgentExecutorRequest/Response -> special handling for nested ChatMessage objects + + Args: + data: The serialized message data (could be dict, str, list, etc.) + input_types: List of message types the executor can accept + + Returns: + Reconstructed message if possible, otherwise the original data + """ + # Flatten union types in input_types (e.g., T | U becomes [T, U]) + flattened_types: list[type[Any]] = [] + for input_type in input_types: + origin = get_origin(input_type) + # Handle both typing.Union and types.UnionType (Python 3.10+ | syntax) + if origin is Union or isinstance(input_type, types.UnionType): + # This is a Union type (T | U), extract the component types + flattened_types.extend(get_args(input_type)) + else: + flattened_types.append(input_type) + + # Handle lists (fan-in aggregation) - recursively reconstruct each item + if isinstance(data, list): + # Extract element types from list[T] annotations in input_types if possible + element_types: list[type[Any]] = [] + for input_type in input_types: + origin = get_origin(input_type) + if origin is list: + args = get_args(input_type) + if args: + # Handle union types inside list[T | U] + for arg in args: + arg_origin = get_origin(arg) + if arg_origin is Union or isinstance(arg, types.UnionType): + element_types.extend(get_args(arg)) + else: + element_types.append(arg) + + # Recursively reconstruct each item in the list + return [reconstruct_message_for_handler(item, element_types or flattened_types) for item in data] + + if not isinstance(data, dict): + return data + + # Try AgentExecutorResponse first - it needs special handling for nested objects + if "executor_id" in data and "agent_run_response" in data: + try: + return reconstruct_agent_executor_response(data) + except Exception: + logger.debug("Could not reconstruct as AgentExecutorResponse in handler context") + + # Try AgentExecutorRequest - also needs special handling for nested ChatMessage objects + if "messages" in data and "should_respond" in data: + try: + return reconstruct_agent_executor_request(data) + except Exception: + logger.debug("Could not reconstruct as AgentExecutorRequest in handler context") + + # Try deserialize_value which uses embedded type metadata (__type__, __module__) + if "__type__" in data: + deserialized = deserialize_value(data) + if deserialized is not data: + return deserialized + + # Try to match against input types by checking dict keys vs dataclass fields + # Filter out metadata keys when comparing + data_keys = {k for k in data if not k.startswith("__")} + for msg_type in flattened_types: + if is_dataclass(msg_type): + # Check if the dict keys match the dataclass fields + field_names = {f.name for f in fields(msg_type)} + if field_names == data_keys or field_names.issubset(data_keys): + try: + # Remove metadata before 
constructing + clean_data = {k: v for k, v in data.items() if not k.startswith("__")} + # Recursively reconstruct nested objects based on field types + reconstructed_data = _reconstruct_dataclass_fields(msg_type, clean_data) + return msg_type(**reconstructed_data) + except Exception: + logger.debug("Could not construct %s from matching fields", msg_type.__name__) + + return data + + +# ============================================================================ +# HITL Response Handler Execution +# ============================================================================ + + +async def _execute_hitl_response_handler( + executor: Any, + hitl_message: dict[str, Any], + shared_state: SharedState, + runner_context: CapturingRunnerContext, +) -> None: + """Execute a HITL response handler on an executor. + + This function handles the delivery of a HITL response to the executor's + @response_handler method. It: + 1. Deserializes the original request and response + 2. Finds the matching response handler based on types + 3. Creates a WorkflowContext and invokes the handler + + Args: + executor: The executor instance that has a @response_handler + hitl_message: The HITL response message containing original_request and response + shared_state: The shared state for the workflow context + runner_context: The runner context for capturing outputs + """ + from agent_framework._workflows._workflow_context import WorkflowContext + + # Extract the response data + original_request_data = hitl_message.get("original_request") + response_data = hitl_message.get("response") + response_type_str = hitl_message.get("response_type") + + # Deserialize the original request + original_request = deserialize_value(original_request_data) + + # Deserialize the response - try to match expected type + response = _deserialize_hitl_response(response_data, response_type_str) + + # Find the matching response handler + handler = executor._find_response_handler(original_request, response) + + if handler is None: + logger.warning( + "No response handler found for HITL response in executor %s. Request type: %s, Response type: %s", + executor.id, + type(original_request).__name__, + type(response).__name__, + ) + return + + # Create a WorkflowContext for the handler + # Use a special source ID to indicate this is a HITL response + ctx = WorkflowContext( + executor=executor, + source_executor_ids=["__hitl_response__"], + runner_context=runner_context, + shared_state=shared_state, + ) + + # Call the response handler + # Note: handler is already a partial with original_request bound + logger.debug( + "Invoking response handler for HITL request in executor %s", + executor.id, + ) + await handler(response, ctx) + + +def _deserialize_hitl_response(response_data: Any, response_type_str: str | None) -> Any: + """Deserialize a HITL response to its expected type. + + Args: + response_data: The raw response data (typically a dict from JSON) + response_type_str: The fully qualified type name (module:classname) + + Returns: + The deserialized response, or the original data if deserialization fails + """ + logger.debug( + "Deserializing HITL response. 
response_type_str=%s, response_data type=%s", + response_type_str, + type(response_data).__name__, + ) + + if response_data is None: + return None + + # If already a primitive, return as-is + if not isinstance(response_data, dict): + logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__) + return response_data + + # Try to deserialize using the type hint + if response_type_str: + try: + module_name, class_name = response_type_str.rsplit(":", 1) + import importlib + + module = importlib.import_module(module_name) + response_type = getattr(module, class_name, None) + + if response_type: + logger.debug("Found response type %s, attempting reconstruction", response_type) + # Use the shared reconstruction logic which handles nested objects + result = _reconstruct_typed_value(response_data, response_type) + logger.debug("Reconstructed response type: %s", type(result).__name__) + return result + logger.warning("Could not find class %s in module %s", class_name, module_name) + + except Exception as e: + logger.warning("Could not deserialize HITL response to %s: %s", response_type_str, e) + + # Fall back to generic deserialization + logger.debug("Falling back to generic deserialization") + return deserialize_value(response_data) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py new file mode 100644 index 0000000000..20bb12db63 --- /dev/null +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -0,0 +1,854 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Workflow Execution for Durable Functions. + +This module provides the workflow orchestration engine that executes MAF Workflows +using Azure Durable Functions. It reuses MAF's edge group routing logic while +adapting execution to the DF generator-based model (yield instead of await). 
+ +Key components: +- run_workflow_orchestrator: Main orchestration function for workflow execution +- route_message_through_edge_groups: Routing helper using MAF edge group APIs +- build_agent_executor_response: Helper to construct AgentExecutorResponse + +HITL (Human-in-the-Loop) Support: +- Detects pending RequestInfoEvents from executor activities +- Uses wait_for_external_event to pause for human input +- Routes responses back to executor's @response_handler methods +""" + +from __future__ import annotations + +import json +import logging +from collections import defaultdict +from dataclasses import dataclass +from datetime import timedelta +from enum import Enum +from typing import Any + +from agent_framework import ( + AgentExecutor, + AgentExecutorRequest, + AgentExecutorResponse, + AgentRunResponse, + ChatMessage, + Workflow, +) +from agent_framework._workflows._edge import ( + EdgeGroup, + FanInEdgeGroup, + FanOutEdgeGroup, + SingleEdgeGroup, + SwitchCaseEdgeGroup, +) +from agent_framework_durabletask import AgentSessionId, DurableAgentThread, DurableAIAgent +from azure.durable_functions import DurableOrchestrationContext + +from ._orchestration import AzureFunctionsAgentExecutor +from ._utils import deserialize_value, serialize_message + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Task Types and Data Structures +# ============================================================================ + + +class TaskType(Enum): + """Type of executor task.""" + + AGENT = "agent" + ACTIVITY = "activity" + + +@dataclass +class TaskMetadata: + """Metadata for a pending task.""" + + executor_id: str + message: Any + source_executor_id: str + task_type: TaskType + remaining_messages: list[tuple[str, Any, str]] | None = None # For agents with multiple messages + + +@dataclass +class ExecutorResult: + """Result from executing an agent or activity.""" + + executor_id: str + output_message: AgentExecutorResponse | None + activity_result: dict[str, Any] | None + task_type: TaskType + + +@dataclass +class PendingHITLRequest: + """Tracks a pending Human-in-the-Loop request in the orchestrator. + + Attributes: + request_id: Unique identifier for correlation with external events + source_executor_id: The executor that called ctx.request_info() + request_data: The serialized request payload + request_type: Fully qualified type name of the request data + response_type: Fully qualified type name of expected response + """ + + request_id: str + source_executor_id: str + request_data: Any + request_type: str | None + response_type: str | None + + +# Default timeout for HITL requests (72 hours) +DEFAULT_HITL_TIMEOUT_HOURS = 72.0 + + +# ============================================================================ +# Routing Functions +# ============================================================================ + + +def route_message_through_edge_groups( + edge_groups: list[EdgeGroup], + source_id: str, + message: Any, +) -> list[str]: + """Route a message through edge groups to find target executor IDs. + + Delegates to MAF's edge group routing logic instead of manual inspection. 
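+
+    For a SwitchCaseEdgeGroup or FanOutEdgeGroup the group's selection_func (when
+    present) picks the targets; a SingleEdgeGroup forwards to its single target only
+    if the edge's condition (should_route) accepts the message; FanInEdgeGroups are
+    skipped here because aggregation is handled separately in the orchestrator loop.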
+ + Args: + edge_groups: List of EdgeGroup instances from the workflow + source_id: The ID of the source executor + message: The message to route + + Returns: + List of target executor IDs that should receive the message + """ + targets: list[str] = [] + + for group in edge_groups: + if source_id not in group.source_executor_ids: + continue + + # SwitchCaseEdgeGroup and FanOutEdgeGroup use selection_func + if isinstance(group, (SwitchCaseEdgeGroup, FanOutEdgeGroup)): + if group.selection_func is not None: + selected = group.selection_func(message, group.target_executor_ids) + targets.extend(selected) + else: + # No selection func means broadcast to all targets + targets.extend(group.target_executor_ids) + + elif isinstance(group, SingleEdgeGroup): + # SingleEdgeGroup has exactly one edge + edge = group.edges[0] + if edge.should_route(message): + targets.append(edge.target_id) + + elif isinstance(group, FanInEdgeGroup): + # FanIn is handled separately in the orchestrator loop + # since it requires aggregation + pass + + else: + # Generic EdgeGroup: check each edge's condition + for edge in group.edges: + if edge.source_id == source_id and edge.should_route(message): + targets.append(edge.target_id) + + return targets + + +def build_agent_executor_response( + executor_id: str, + response_text: str | None, + structured_response: dict[str, Any] | None, + previous_message: Any, +) -> AgentExecutorResponse: + """Build an AgentExecutorResponse from entity response data. + + Shared helper to construct the response object consistently. + + Args: + executor_id: The ID of the executor that produced the response + response_text: Plain text response from the agent (if any) + structured_response: Structured JSON response (if any) + previous_message: The input message that triggered this response + + Returns: + AgentExecutorResponse with reconstructed conversation + """ + final_text = response_text + if structured_response: + final_text = json.dumps(structured_response) + + assistant_message = ChatMessage(role="assistant", text=final_text) + + agent_run_response = AgentRunResponse( + messages=[assistant_message], + ) + + # Build conversation history + full_conversation: list[ChatMessage] = [] + if isinstance(previous_message, AgentExecutorResponse) and previous_message.full_conversation: + full_conversation.extend(previous_message.full_conversation) + elif isinstance(previous_message, str): + full_conversation.append(ChatMessage(role="user", text=previous_message)) + + full_conversation.append(assistant_message) + + return AgentExecutorResponse( + executor_id=executor_id, + agent_run_response=agent_run_response, + full_conversation=full_conversation, + ) + + +# ============================================================================ +# Task Preparation Helpers +# ============================================================================ + + +def _prepare_agent_task( + context: DurableOrchestrationContext, + executor_id: str, + message: Any, +) -> Any: + """Prepare an agent task for execution. 
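+
+    The agent runs on a DurableAgentThread whose AgentSessionId is keyed by the
+    orchestration instance id, so successive messages to the same agent within one
+    workflow run share a single conversation.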
+ + Args: + context: The Durable Functions orchestration context + executor_id: The agent executor ID (agent name) + message: The input message for the agent + + Returns: + A task that can be yielded to execute the agent + """ + message_content = _extract_message_content(message) + session_id = AgentSessionId(name=executor_id, key=context.instance_id) + thread = DurableAgentThread(session_id=session_id) + + az_executor = AzureFunctionsAgentExecutor(context) + agent = DurableAIAgent(az_executor, executor_id) + return agent.run(message_content, thread=thread) + + +def _prepare_activity_task( + context: DurableOrchestrationContext, + executor_id: str, + message: Any, + source_executor_id: str, + shared_state_snapshot: dict[str, Any] | None, +) -> Any: + """Prepare an activity task for execution. + + Args: + context: The Durable Functions orchestration context + executor_id: The activity executor ID + message: The input message for the activity + source_executor_id: The ID of the executor that sent the message + shared_state_snapshot: Current shared state snapshot + + Returns: + A task that can be yielded to execute the activity + """ + activity_input = { + "executor_id": executor_id, + "message": serialize_message(message), + "shared_state_snapshot": shared_state_snapshot, + "source_executor_ids": [source_executor_id], + } + activity_input_json = json.dumps(activity_input) + # Use the prefixed activity name that matches the registered function + activity_name = f"dafx-{executor_id}" + return context.call_activity(activity_name, activity_input_json) + + +# ============================================================================ +# Result Processing Helpers +# ============================================================================ + + +def _process_agent_response( + agent_response: AgentRunResponse, + executor_id: str, + message: Any, +) -> ExecutorResult: + """Process an agent response into an ExecutorResult. + + Args: + agent_response: The response from the agent + executor_id: The agent executor ID + message: The original input message + + Returns: + ExecutorResult containing the processed response + """ + response_text = agent_response.text if agent_response else None + structured_response = None + + if agent_response and agent_response.value is not None: + if hasattr(agent_response.value, "model_dump"): + structured_response = agent_response.value.model_dump() + elif isinstance(agent_response.value, dict): + structured_response = agent_response.value + + output_message = build_agent_executor_response( + executor_id=executor_id, + response_text=response_text, + structured_response=structured_response, + previous_message=message, + ) + + return ExecutorResult( + executor_id=executor_id, + output_message=output_message, + activity_result=None, + task_type=TaskType.AGENT, + ) + + +def _process_activity_result( + result_json: str | None, + executor_id: str, + shared_state: dict[str, Any] | None, + workflow_outputs: list[Any], +) -> ExecutorResult: + """Process an activity result and apply shared state updates. 
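+
+    Shared-state updates and deletes reported by the activity are folded into the
+    orchestrator-local shared_state dict, and any workflow outputs it produced are
+    appended to workflow_outputs.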
+ + Args: + result_json: The JSON result from the activity + executor_id: The activity executor ID + shared_state: The shared state dict to update (mutated in place) + workflow_outputs: List to append outputs to (mutated in place) + + Returns: + ExecutorResult containing the processed result + """ + result = json.loads(result_json) if result_json else None + + # Apply shared state updates + if shared_state is not None and result: + if result.get("shared_state_updates"): + updates = result["shared_state_updates"] + logger.debug("[workflow] Applying SharedState updates from %s: %s", executor_id, updates) + shared_state.update(updates) + if result.get("shared_state_deletes"): + deletes = result["shared_state_deletes"] + logger.debug("[workflow] Applying SharedState deletes from %s: %s", executor_id, deletes) + for key in deletes: + shared_state.pop(key, None) + + # Collect outputs + if result and result.get("outputs"): + workflow_outputs.extend(result["outputs"]) + + return ExecutorResult( + executor_id=executor_id, + output_message=None, + activity_result=result, + task_type=TaskType.ACTIVITY, + ) + + +# ============================================================================ +# Routing Helpers +# ============================================================================ + + +def _route_result_messages( + result: ExecutorResult, + workflow: Workflow, + next_pending_messages: dict[str, list[tuple[Any, str]]], + fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]], +) -> None: + """Route messages from an executor result to their targets. + + Args: + result: The executor result containing messages to route + workflow: The workflow definition + next_pending_messages: Dict to accumulate next iteration's messages (mutated) + fan_in_pending: Dict tracking fan-in state (mutated) + """ + executor_id = result.executor_id + messages_to_route: list[tuple[Any, str | None]] = [] + + # Collect messages from agent response + if result.output_message: + messages_to_route.append((result.output_message, None)) + + # Collect sent_messages from activity results + if result.activity_result and result.activity_result.get("sent_messages"): + for msg_data in result.activity_result["sent_messages"]: + sent_msg = msg_data.get("message") + target_id = msg_data.get("target_id") + if sent_msg: + sent_msg = deserialize_value(sent_msg) + messages_to_route.append((sent_msg, target_id)) + + # Route each message + for msg_to_route, explicit_target in messages_to_route: + logger.debug("Routing output from %s", executor_id) + + # If explicit target specified, route directly + if explicit_target: + if explicit_target not in next_pending_messages: + next_pending_messages[explicit_target] = [] + next_pending_messages[explicit_target].append((msg_to_route, executor_id)) + logger.debug("Routed message from %s to explicit target %s", executor_id, explicit_target) + continue + + # Check for FanInEdgeGroup sources + for group in workflow.edge_groups: + if isinstance(group, FanInEdgeGroup) and executor_id in group.source_executor_ids: + fan_in_pending[group.id][executor_id].append((msg_to_route, executor_id)) + logger.debug("Accumulated message for FanIn group %s from %s", group.id, executor_id) + + # Use MAF's edge group routing for other edge types + targets = route_message_through_edge_groups(workflow.edge_groups, executor_id, msg_to_route) + + for target_id in targets: + logger.debug("Routing to %s", target_id) + if target_id not in next_pending_messages: + next_pending_messages[target_id] = [] + 
next_pending_messages[target_id].append((msg_to_route, executor_id)) + + +def _check_fan_in_ready( + workflow: Workflow, + fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]], + next_pending_messages: dict[str, list[tuple[Any, str]]], +) -> None: + """Check if any FanInEdgeGroups are ready and deliver their messages. + + Args: + workflow: The workflow definition + fan_in_pending: Dict tracking fan-in state (mutated - cleared when delivered) + next_pending_messages: Dict to add aggregated messages to (mutated) + """ + for group in workflow.edge_groups: + if not isinstance(group, FanInEdgeGroup): + continue + + pending_sources = fan_in_pending.get(group.id, {}) + + # Check if all sources have contributed at least one message + if not all(src in pending_sources and pending_sources[src] for src in group.source_executor_ids): + continue + + # Aggregate all messages into a single list + aggregated: list[Any] = [] + aggregated_sources: list[str] = [] + for src in group.source_executor_ids: + for msg, msg_source in pending_sources[src]: + aggregated.append(msg) + aggregated_sources.append(msg_source) + + target_id = group.target_executor_ids[0] + logger.debug("FanIn group %s ready, delivering %d messages to %s", group.id, len(aggregated), target_id) + + if target_id not in next_pending_messages: + next_pending_messages[target_id] = [] + + first_source = aggregated_sources[0] if aggregated_sources else "__fan_in__" + next_pending_messages[target_id].append((aggregated, first_source)) + + # Clear the pending sources for this group + fan_in_pending[group.id] = defaultdict(list) + + +# ============================================================================ +# HITL (Human-in-the-Loop) Helpers +# ============================================================================ + + +def _collect_hitl_requests( + result: ExecutorResult, + pending_hitl_requests: dict[str, PendingHITLRequest], +) -> None: + """Collect pending HITL requests from an activity result. + + Args: + result: The executor result that may contain pending request info events + pending_hitl_requests: Dict to accumulate pending requests (mutated) + """ + if result.activity_result and result.activity_result.get("pending_request_info_events"): + for req_data in result.activity_result["pending_request_info_events"]: + request_id = req_data.get("request_id") + if request_id: + pending_hitl_requests[request_id] = PendingHITLRequest( + request_id=request_id, + source_executor_id=req_data.get("source_executor_id", result.executor_id), + request_data=req_data.get("data"), + request_type=req_data.get("request_type"), + response_type=req_data.get("response_type"), + ) + logger.debug( + "Collected HITL request %s from executor %s", + request_id, + result.executor_id, + ) + + +def _route_hitl_response( + hitl_request: PendingHITLRequest, + raw_response: Any, + pending_messages: dict[str, list[tuple[Any, str]]], +) -> None: + """Route a HITL response back to the source executor's @response_handler. + + The response is packaged as a special HITL response message that the executor + activity can recognize and route to the appropriate @response_handler method. 
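+
+    The packaged dict carries the original request, the raw human response and the
+    expected response type (as "module:classname"), which is what
+    _execute_hitl_response_handler uses to rebuild both sides before invoking the
+    handler.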
+ + Args: + hitl_request: The original HITL request + raw_response: The raw response data from the external event + pending_messages: Dict to add the response message to (mutated) + """ + # Create a message structure that the executor can recognize + # This mimics what the InProcRunnerContext does for request_info responses + response_message = { + "__hitl_response__": True, + "request_id": hitl_request.request_id, + "original_request": hitl_request.request_data, + "response": raw_response, + "response_type": hitl_request.response_type, + } + + target_id = hitl_request.source_executor_id + if target_id not in pending_messages: + pending_messages[target_id] = [] + + # Use a special source ID to indicate this is a HITL response + source_id = f"__hitl_response__{hitl_request.request_id}" + pending_messages[target_id].append((response_message, source_id)) + + logger.debug( + "Routed HITL response for request %s to executor %s", + hitl_request.request_id, + target_id, + ) + + +# ============================================================================ +# Main Orchestrator +# ============================================================================ + + +def run_workflow_orchestrator( + context: DurableOrchestrationContext, + workflow: Workflow, + initial_message: Any, + shared_state: dict[str, Any] | None = None, + hitl_timeout_hours: float = DEFAULT_HITL_TIMEOUT_HOURS, +): + """Traverse and execute the workflow graph using Durable Functions. + + This orchestrator reuses MAF's edge group routing logic while adapting + execution to the DF generator-based model (yield instead of await). + + Supports: + - SingleEdgeGroup: Direct 1:1 routing with optional condition + - SwitchCaseEdgeGroup: First matching condition wins + - FanOutEdgeGroup: Broadcast to multiple targets - **executed in parallel** + - FanInEdgeGroup: Aggregates messages from multiple sources before delivery + - SharedState: Local shared state accessible to all executors + - HITL: Human-in-the-loop via request_info / @response_handler pattern + + Execution model: + - All pending executors (agents AND activities) run in parallel via single task_all() + - Multiple messages to the SAME agent are processed sequentially for conversation coherence + - SharedState updates are applied in order after parallel tasks complete + - HITL requests pause the orchestration until external events are received + + Args: + context: The Durable Functions orchestration context + workflow: The MAF Workflow instance to execute + initial_message: The initial message to send to the start executor + shared_state: Optional dict for cross-executor state sharing (local to orchestration) + hitl_timeout_hours: Timeout in hours for HITL requests (default: 72 hours) + + Returns: + List of workflow outputs collected from executor activities + """ + pending_messages: dict[str, list[tuple[Any, str]]] = { + workflow.start_executor_id: [(initial_message, "__workflow_start__")] + } + workflow_outputs: list[Any] = [] + iteration = 0 + + # Track pending sources for FanInEdgeGroups using defaultdict for cleaner access + fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]] = { + group.id: defaultdict(list) for group in workflow.edge_groups if isinstance(group, FanInEdgeGroup) + } + + # Track pending HITL requests + pending_hitl_requests: dict[str, PendingHITLRequest] = {} + + while pending_messages and iteration < workflow.max_iterations: + logger.debug("Orchestrator iteration %d", iteration) + next_pending_messages: dict[str, list[tuple[Any, str]]] = {} + + # Phase 
1: Prepare all tasks (agents and activities unified) + all_tasks, task_metadata_list, remaining_agent_messages = _prepare_all_tasks( + context, workflow, pending_messages, shared_state + ) + + # Phase 2: Execute all tasks in parallel (single task_all for true parallelism) + all_results: list[ExecutorResult] = [] + if all_tasks: + logger.debug("Executing %d tasks in parallel (agents + activities)", len(all_tasks)) + raw_results = yield context.task_all(all_tasks) + logger.debug("All %d tasks completed", len(all_tasks)) + + # Process results based on task type + for idx, raw_result in enumerate(raw_results): + metadata = task_metadata_list[idx] + if metadata.task_type == TaskType.AGENT: + result = _process_agent_response(raw_result, metadata.executor_id, metadata.message) + else: + result = _process_activity_result(raw_result, metadata.executor_id, shared_state, workflow_outputs) + all_results.append(result) + + # Phase 3: Process sequential agent messages (for same-agent conversation coherence) + for executor_id, message, _source_executor_id in remaining_agent_messages: + logger.debug("Processing sequential message for agent: %s", executor_id) + task = _prepare_agent_task(context, executor_id, message) + agent_response: AgentRunResponse = yield task + logger.debug("Agent %s sequential response completed", executor_id) + + result = _process_agent_response(agent_response, executor_id, message) + all_results.append(result) + + # Phase 4: Collect pending HITL requests from activity results + for result in all_results: + _collect_hitl_requests(result, pending_hitl_requests) + + # Phase 5: Route all results to next iteration + for result in all_results: + _route_result_messages(result, workflow, next_pending_messages, fan_in_pending) + + # Phase 6: Check if any FanInEdgeGroups are ready to deliver + _check_fan_in_ready(workflow, fan_in_pending, next_pending_messages) + + pending_messages = next_pending_messages + + # Phase 7: Handle HITL - if no pending work but HITL requests exist, wait for responses + if not pending_messages and pending_hitl_requests: + logger.debug("Workflow paused for HITL - %d pending requests", len(pending_hitl_requests)) + + # Update custom status to expose pending requests + context.set_custom_status({ + "state": "waiting_for_human_input", + "pending_requests": { + req_id: { + "request_id": req.request_id, + "source_executor_id": req.source_executor_id, + "data": req.request_data, + "request_type": req.request_type, + "response_type": req.response_type, + } + for req_id, req in pending_hitl_requests.items() + }, + }) + + # Wait for external events for each pending request + # Process responses one at a time to maintain ordering + for request_id, hitl_request in list(pending_hitl_requests.items()): + logger.debug("Waiting for HITL response for request: %s", request_id) + + # Create tasks for approval and timeout + approval_task = context.wait_for_external_event(request_id) + timeout_task = context.create_timer(context.current_utc_datetime + timedelta(hours=hitl_timeout_hours)) + + winner = yield context.task_any([approval_task, timeout_task]) + + if winner == approval_task: + # Cancel the timeout + timeout_task.cancel() + + # Get the response + raw_response = approval_task.result + logger.debug( + "Received HITL response for request %s. 
Type: %s, Value: %s", + request_id, + type(raw_response).__name__, + raw_response, + ) + + # Durable Functions may return a JSON string; parse it if so + if isinstance(raw_response, str): + try: + import json + + raw_response = json.loads(raw_response) + logger.debug("Parsed JSON string response to: %s", type(raw_response).__name__) + except (json.JSONDecodeError, TypeError): + logger.debug("Response is not JSON, keeping as string") + + # Remove from pending + del pending_hitl_requests[request_id] + + # Route the response back to the source executor's @response_handler + _route_hitl_response( + hitl_request, + raw_response, + pending_messages, + ) + else: + # Timeout occurred + logger.warning("HITL request %s timed out after %s hours", request_id, hitl_timeout_hours) + raise TimeoutError( + f"Human-in-the-loop request '{request_id}' timed out after {hitl_timeout_hours} hours." + ) + + # Clear custom status after HITL is resolved + context.set_custom_status({"state": "running"}) + + iteration += 1 + + # Durable Functions runtime extracts return value from StopIteration + return workflow_outputs # noqa: B901 + + +def _prepare_all_tasks( + context: DurableOrchestrationContext, + workflow: Workflow, + pending_messages: dict[str, list[tuple[Any, str]]], + shared_state: dict[str, Any] | None, +) -> tuple[list[Any], list[TaskMetadata], list[tuple[str, Any, str]]]: + """Prepare all pending tasks for parallel execution. + + Groups agent messages by executor ID so that only the first message per agent + runs in the parallel batch. Additional messages to the same agent are returned + for sequential processing. + + Args: + context: The Durable Functions orchestration context + workflow: The workflow definition + pending_messages: Messages pending for each executor + shared_state: Current shared state snapshot + + Returns: + Tuple of (tasks, metadata, remaining_agent_messages): + - tasks: List of tasks ready for task_all() + - metadata: TaskMetadata for each task (same order as tasks) + - remaining_agent_messages: Agent messages requiring sequential processing + """ + all_tasks: list[Any] = [] + task_metadata_list: list[TaskMetadata] = [] + remaining_agent_messages: list[tuple[str, Any, str]] = [] + + # Group agent messages by executor_id for sequential handling of same-agent messages + agent_messages_by_executor: dict[str, list[tuple[str, Any, str]]] = defaultdict(list) + + # Categorize all pending messages + for executor_id, messages_with_sources in pending_messages.items(): + executor = workflow.executors[executor_id] + is_agent = isinstance(executor, AgentExecutor) + + for message, source_executor_id in messages_with_sources: + if is_agent: + agent_messages_by_executor[executor_id].append((executor_id, message, source_executor_id)) + else: + # Activity tasks can all run in parallel + logger.debug("Preparing activity task: %s", executor_id) + task = _prepare_activity_task(context, executor_id, message, source_executor_id, shared_state) + all_tasks.append(task) + task_metadata_list.append( + TaskMetadata( + executor_id=executor_id, + message=message, + source_executor_id=source_executor_id, + task_type=TaskType.ACTIVITY, + ) + ) + + # Process agent messages: first message per agent goes to parallel batch + for executor_id, messages_list in agent_messages_by_executor.items(): + first_msg = messages_list[0] + remaining = messages_list[1:] + + logger.debug("Preparing agent task: %s", executor_id) + task = _prepare_agent_task(context, first_msg[0], first_msg[1]) + all_tasks.append(task) + 
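+        # task_metadata_list must stay index-aligned with all_tasks: Phase 2 matches each raw task_all() result back to its executor and message by position.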
task_metadata_list.append( + TaskMetadata( + executor_id=first_msg[0], + message=first_msg[1], + source_executor_id=first_msg[2], + task_type=TaskType.AGENT, + ) + ) + + # Queue remaining messages for sequential processing + remaining_agent_messages.extend(remaining) + + return all_tasks, task_metadata_list, remaining_agent_messages + + +# ============================================================================ +# Message Content Extraction +# ============================================================================ + + +def _extract_message_content(message: Any) -> str: + """Extract text content from various message types.""" + message_content = "" + if isinstance(message, AgentExecutorResponse) and message.agent_run_response: + if message.agent_run_response.text: + message_content = message.agent_run_response.text + elif message.agent_run_response.messages: + message_content = message.agent_run_response.messages[-1].text or "" + elif isinstance(message, AgentExecutorRequest) and message.messages: + # Extract text from the last message in the request + message_content = message.messages[-1].text or "" + elif isinstance(message, dict): + message_content = _extract_message_content_from_dict(message) + elif isinstance(message, str): + message_content = message + + return message_content + + +def _extract_message_content_from_dict(message: dict[str, Any]) -> str: + """Extract text content from serialized message dictionaries.""" + message_content = "" + + if message.get("messages"): + # AgentExecutorRequest dict - messages is a list of ChatMessage dicts + last_msg = message["messages"][-1] + if isinstance(last_msg, dict): + # ChatMessage serialized via to_dict() has structure: + # {"type": "chat_message", "contents": [{"type": "text", "text": "..."}], ...} + if last_msg.get("contents"): + first_content = last_msg["contents"][0] + if isinstance(first_content, dict): + message_content = first_content.get("text") or "" + # Fallback to direct text field if not in contents structure + if not message_content: + message_content = last_msg.get("text") or last_msg.get("_text") or "" + elif hasattr(last_msg, "text"): + message_content = last_msg.text or "" + elif "agent_run_response" in message: + # AgentExecutorResponse dict + arr = message.get("agent_run_response", {}) + if isinstance(arr, dict): + message_content = arr.get("text") or "" + if not message_content and arr.get("messages"): + last_msg = arr["messages"][-1] + if isinstance(last_msg, dict): + # Check for contents structure first + if last_msg.get("contents"): + first_content = last_msg["contents"][0] + if isinstance(first_content, dict): + message_content = first_content.get("text") or "" + if not message_content: + message_content = last_msg.get("text") or last_msg.get("_text") or "" + + return message_content diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 2729d73e7e..68bcb3b9b0 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1317,5 +1317,102 @@ def test_coerce_to_bool_with_none(self) -> None: assert app._coerce_to_bool([]) is False +class TestAgentFunctionAppWorkflow: + """Test suite for AgentFunctionApp workflow support.""" + + def test_init_with_workflow_stores_workflow(self) -> None: + """Test that workflow is stored when provided.""" + mock_workflow = Mock() + mock_workflow.executors = {} + + with ( + patch.object(AgentFunctionApp, "_setup_executor_activity"), + 
patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), + ): + app = AgentFunctionApp(workflow=mock_workflow) + + assert app.workflow is mock_workflow + + def test_init_with_workflow_extracts_agents(self) -> None: + """Test that agents are extracted from workflow executors.""" + from agent_framework import AgentExecutor + + mock_agent = Mock() + mock_agent.name = "WorkflowAgent" + + mock_executor = Mock(spec=AgentExecutor) + mock_executor.agent = mock_agent + + mock_workflow = Mock() + mock_workflow.executors = {"WorkflowAgent": mock_executor} + + with ( + patch.object(AgentFunctionApp, "_setup_executor_activity"), + patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), + patch.object(AgentFunctionApp, "_setup_agent_functions"), + ): + app = AgentFunctionApp(workflow=mock_workflow) + + assert "WorkflowAgent" in app.agents + + def test_init_with_workflow_calls_setup_methods(self) -> None: + """Test that workflow setup methods are called.""" + mock_workflow = Mock() + mock_workflow.executors = {} + + with ( + patch.object(AgentFunctionApp, "_setup_executor_activity") as setup_exec, + patch.object(AgentFunctionApp, "_setup_workflow_orchestration") as setup_orch, + ): + AgentFunctionApp(workflow=mock_workflow) + + setup_exec.assert_called_once() + setup_orch.assert_called_once() + + def test_init_without_workflow_does_not_call_workflow_setup(self) -> None: + """Test that workflow setup is not called when no workflow provided.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + with ( + patch.object(AgentFunctionApp, "_setup_executor_activity") as setup_exec, + patch.object(AgentFunctionApp, "_setup_workflow_orchestration") as setup_orch, + ): + AgentFunctionApp(agents=[mock_agent]) + + setup_exec.assert_not_called() + setup_orch.assert_not_called() + + def test_build_status_url(self) -> None: + """Test _build_status_url constructs correct URL.""" + mock_workflow = Mock() + mock_workflow.executors = {} + + with ( + patch.object(AgentFunctionApp, "_setup_executor_activity"), + patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), + ): + app = AgentFunctionApp(workflow=mock_workflow) + + url = app._build_status_url("http://localhost:7071/api/workflow/run", "instance-123") + + assert url == "http://localhost:7071/api/workflow/status/instance-123" + + def test_build_status_url_handles_trailing_slash(self) -> None: + """Test _build_status_url handles URLs without /api/ correctly.""" + mock_workflow = Mock() + mock_workflow.executors = {} + + with ( + patch.object(AgentFunctionApp, "_setup_executor_activity"), + patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), + ): + app = AgentFunctionApp(workflow=mock_workflow) + + url = app._build_status_url("http://localhost:7071/", "instance-456") + + assert "instance-456" in url + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_utils.py b/python/packages/azurefunctions/tests/test_utils.py new file mode 100644 index 0000000000..c95b663161 --- /dev/null +++ b/python/packages/azurefunctions/tests/test_utils.py @@ -0,0 +1,463 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Unit tests for workflow utility functions.""" + +from dataclasses import dataclass +from unittest.mock import Mock + +import pytest +from agent_framework import ( + AgentExecutorRequest, + AgentExecutorResponse, + AgentRunResponse, + ChatMessage, + Message, + WorkflowOutputEvent, +) +from pydantic import BaseModel + +from agent_framework_azurefunctions._utils import ( + CapturingRunnerContext, + deserialize_value, + reconstruct_agent_executor_request, + reconstruct_agent_executor_response, + reconstruct_message_for_handler, + serialize_message, +) + + +class TestCapturingRunnerContext: + """Test suite for CapturingRunnerContext.""" + + @pytest.fixture + def context(self) -> CapturingRunnerContext: + """Create a fresh CapturingRunnerContext for each test.""" + return CapturingRunnerContext() + + @pytest.mark.asyncio + async def test_send_message_captures_message(self, context: CapturingRunnerContext) -> None: + """Test that send_message captures messages correctly.""" + message = Message(data="test data", target_id="target_1", source_id="source_1") + + await context.send_message(message) + + messages = await context.drain_messages() + assert "source_1" in messages + assert len(messages["source_1"]) == 1 + assert messages["source_1"][0].data == "test data" + + @pytest.mark.asyncio + async def test_send_multiple_messages_groups_by_source(self, context: CapturingRunnerContext) -> None: + """Test that messages are grouped by source_id.""" + msg1 = Message(data="msg1", target_id="target", source_id="source_a") + msg2 = Message(data="msg2", target_id="target", source_id="source_a") + msg3 = Message(data="msg3", target_id="target", source_id="source_b") + + await context.send_message(msg1) + await context.send_message(msg2) + await context.send_message(msg3) + + messages = await context.drain_messages() + assert len(messages["source_a"]) == 2 + assert len(messages["source_b"]) == 1 + + @pytest.mark.asyncio + async def test_drain_messages_clears_messages(self, context: CapturingRunnerContext) -> None: + """Test that drain_messages clears the message store.""" + message = Message(data="test", target_id="t", source_id="s") + await context.send_message(message) + + await context.drain_messages() # First drain + messages = await context.drain_messages() # Second drain + + assert messages == {} + + @pytest.mark.asyncio + async def test_has_messages_returns_correct_status(self, context: CapturingRunnerContext) -> None: + """Test has_messages returns correct boolean.""" + assert await context.has_messages() is False + + await context.send_message(Message(data="test", target_id="t", source_id="s")) + + assert await context.has_messages() is True + + @pytest.mark.asyncio + async def test_add_event_queues_event(self, context: CapturingRunnerContext) -> None: + """Test that add_event queues events correctly.""" + event = WorkflowOutputEvent(data="output", source_executor_id="exec_1") + + await context.add_event(event) + + events = await context.drain_events() + assert len(events) == 1 + assert isinstance(events[0], WorkflowOutputEvent) + assert events[0].data == "output" + + @pytest.mark.asyncio + async def test_drain_events_clears_queue(self, context: CapturingRunnerContext) -> None: + """Test that drain_events clears the event queue.""" + await context.add_event(WorkflowOutputEvent(data="test", source_executor_id="e")) + + await context.drain_events() # First drain + events = await context.drain_events() # Second drain + + assert events == [] + + @pytest.mark.asyncio + async def 
test_has_events_returns_correct_status(self, context: CapturingRunnerContext) -> None: + """Test has_events returns correct boolean.""" + assert await context.has_events() is False + + await context.add_event(WorkflowOutputEvent(data="test", source_executor_id="e")) + + assert await context.has_events() is True + + @pytest.mark.asyncio + async def test_next_event_waits_for_event(self, context: CapturingRunnerContext) -> None: + """Test that next_event returns queued events.""" + event = WorkflowOutputEvent(data="waited", source_executor_id="e") + await context.add_event(event) + + result = await context.next_event() + + assert result.data == "waited" + + def test_has_checkpointing_returns_false(self, context: CapturingRunnerContext) -> None: + """Test that checkpointing is not supported.""" + assert context.has_checkpointing() is False + + def test_is_streaming_returns_false_by_default(self, context: CapturingRunnerContext) -> None: + """Test streaming is disabled by default.""" + assert context.is_streaming() is False + + def test_set_streaming(self, context: CapturingRunnerContext) -> None: + """Test setting streaming mode.""" + context.set_streaming(True) + assert context.is_streaming() is True + + context.set_streaming(False) + assert context.is_streaming() is False + + def test_set_workflow_id(self, context: CapturingRunnerContext) -> None: + """Test setting workflow ID.""" + context.set_workflow_id("workflow-123") + assert context._workflow_id == "workflow-123" + + @pytest.mark.asyncio + async def test_reset_for_new_run_clears_state(self, context: CapturingRunnerContext) -> None: + """Test that reset_for_new_run clears all state.""" + await context.send_message(Message(data="test", target_id="t", source_id="s")) + await context.add_event(WorkflowOutputEvent(data="event", source_executor_id="e")) + context.set_streaming(True) + + context.reset_for_new_run() + + assert await context.has_messages() is False + assert await context.has_events() is False + assert context.is_streaming() is False + + @pytest.mark.asyncio + async def test_create_checkpoint_raises_not_implemented(self, context: CapturingRunnerContext) -> None: + """Test that checkpointing methods raise NotImplementedError.""" + from agent_framework import SharedState + + with pytest.raises(NotImplementedError): + await context.create_checkpoint(SharedState(), 1) + + @pytest.mark.asyncio + async def test_load_checkpoint_raises_not_implemented(self, context: CapturingRunnerContext) -> None: + """Test that load_checkpoint raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + await context.load_checkpoint("some-id") + + @pytest.mark.asyncio + async def test_apply_checkpoint_raises_not_implemented(self, context: CapturingRunnerContext) -> None: + """Test that apply_checkpoint raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + await context.apply_checkpoint(Mock()) + + +class TestSerializeMessage: + """Test suite for serialize_message function.""" + + def test_serialize_none(self) -> None: + """Test serializing None.""" + assert serialize_message(None) is None + + def test_serialize_primitive_types(self) -> None: + """Test serializing primitive types.""" + assert serialize_message("hello") == "hello" + assert serialize_message(42) == 42 + assert serialize_message(3.14) == 3.14 + assert serialize_message(True) is True + + def test_serialize_list(self) -> None: + """Test serializing lists.""" + result = serialize_message([1, 2, 3]) + assert result == [1, 2, 3] + + def 
test_serialize_dict(self) -> None: + """Test serializing dicts.""" + result = serialize_message({"key": "value", "num": 42}) + assert result == {"key": "value", "num": 42} + + def test_serialize_dataclass(self) -> None: + """Test serializing dataclasses with type metadata.""" + + @dataclass + class TestData: + name: str + value: int + + data = TestData(name="test", value=123) + result = serialize_message(data) + + assert result["name"] == "test" + assert result["value"] == 123 + assert result["__type__"] == "TestData" + assert "__module__" in result + + def test_serialize_pydantic_model(self) -> None: + """Test serializing Pydantic models with type metadata.""" + + class TestModel(BaseModel): + title: str + count: int + + model = TestModel(title="Hello", count=5) + result = serialize_message(model) + + assert result["title"] == "Hello" + assert result["count"] == 5 + assert result["__type__"] == "TestModel" + assert "__module__" in result + + def test_serialize_nested_structures(self) -> None: + """Test serializing nested structures.""" + + @dataclass + class Inner: + x: int + + @dataclass + class Outer: + inner: Inner + items: list[int] + + outer = Outer(inner=Inner(x=10), items=[1, 2, 3]) + result = serialize_message(outer) + + assert result["__type__"] == "Outer" + # Nested dataclass is serialized via asdict, which doesn't add __type__ recursively + assert result["inner"]["x"] == 10 + assert result["items"] == [1, 2, 3] + + def test_serialize_object_with_to_dict(self) -> None: + """Test serializing objects with to_dict method.""" + message = ChatMessage(role="user", text="Hello") + result = serialize_message(message) + + # ChatMessage has to_dict() method which returns a specific structure + assert isinstance(result, dict) + assert "contents" in result # ChatMessage uses contents structure + + +class TestDeserializeValue: + """Test suite for deserialize_value function.""" + + def test_deserialize_non_dict_returns_original(self) -> None: + """Test that non-dict values are returned as-is.""" + assert deserialize_value("string") == "string" + assert deserialize_value(42) == 42 + assert deserialize_value([1, 2, 3]) == [1, 2, 3] + + def test_deserialize_dict_without_type_returns_original(self) -> None: + """Test that dicts without type metadata are returned as-is.""" + data = {"key": "value", "num": 42} + result = deserialize_value(data) + assert result == data + + def test_deserialize_agent_executor_request(self) -> None: + """Test deserializing AgentExecutorRequest.""" + data = { + "messages": [{"type": "chat_message", "role": "user", "contents": [{"type": "text", "text": "Hello"}]}], + "should_respond": True, + } + + result = deserialize_value(data) + + assert isinstance(result, AgentExecutorRequest) + assert len(result.messages) == 1 + assert result.should_respond is True + + def test_deserialize_agent_executor_response(self) -> None: + """Test deserializing AgentExecutorResponse.""" + data = { + "executor_id": "test_exec", + "agent_run_response": { + "type": "agent_run_response", + "messages": [ + {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "Hi there"}]} + ], + }, + } + + result = deserialize_value(data) + + assert isinstance(result, AgentExecutorResponse) + assert result.executor_id == "test_exec" + + def test_deserialize_with_type_registry(self) -> None: + """Test deserializing with type registry.""" + + @dataclass + class CustomType: + name: str + + data = {"name": "test", "__type__": "CustomType"} + result = deserialize_value(data, 
type_registry={"CustomType": CustomType}) + + assert isinstance(result, CustomType) + assert result.name == "test" + + +class TestReconstructAgentExecutorRequest: + """Test suite for reconstruct_agent_executor_request function.""" + + def test_reconstruct_with_chat_messages(self) -> None: + """Test reconstructing request with ChatMessage dicts.""" + data = { + "messages": [ + {"type": "chat_message", "role": "user", "contents": [{"type": "text", "text": "Hello"}]}, + {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "Hi"}]}, + ], + "should_respond": True, + } + + result = reconstruct_agent_executor_request(data) + + assert isinstance(result, AgentExecutorRequest) + assert len(result.messages) == 2 + assert result.should_respond is True + + def test_reconstruct_defaults_should_respond_to_true(self) -> None: + """Test that should_respond defaults to True.""" + data = {"messages": []} + + result = reconstruct_agent_executor_request(data) + + assert result.should_respond is True + + +class TestReconstructAgentExecutorResponse: + """Test suite for reconstruct_agent_executor_response function.""" + + def test_reconstruct_with_agent_run_response(self) -> None: + """Test reconstructing response with agent_run_response.""" + data = { + "executor_id": "my_executor", + "agent_run_response": { + "type": "agent_run_response", + "messages": [ + {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "Response"}]} + ], + }, + "full_conversation": [], + } + + result = reconstruct_agent_executor_response(data) + + assert isinstance(result, AgentExecutorResponse) + assert result.executor_id == "my_executor" + assert isinstance(result.agent_run_response, AgentRunResponse) + + def test_reconstruct_with_full_conversation(self) -> None: + """Test reconstructing response with full_conversation.""" + data = { + "executor_id": "exec", + "agent_run_response": {"type": "agent_run_response", "messages": []}, + "full_conversation": [ + {"type": "chat_message", "role": "user", "contents": [{"type": "text", "text": "Q"}]}, + {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "A"}]}, + ], + } + + result = reconstruct_agent_executor_response(data) + + assert result.full_conversation is not None + assert len(result.full_conversation) == 2 + + +class TestReconstructMessageForHandler: + """Test suite for reconstruct_message_for_handler function.""" + + def test_reconstruct_non_dict_returns_original(self) -> None: + """Test that non-dict messages are returned as-is.""" + assert reconstruct_message_for_handler("string", []) == "string" + assert reconstruct_message_for_handler(42, []) == 42 + + def test_reconstruct_agent_executor_response(self) -> None: + """Test reconstructing AgentExecutorResponse.""" + data = { + "executor_id": "exec", + "agent_run_response": {"type": "agent_run_response", "messages": []}, + } + + result = reconstruct_message_for_handler(data, [AgentExecutorResponse]) + + assert isinstance(result, AgentExecutorResponse) + + def test_reconstruct_agent_executor_request(self) -> None: + """Test reconstructing AgentExecutorRequest.""" + data = { + "messages": [{"type": "chat_message", "role": "user", "contents": [{"type": "text", "text": "Hi"}]}], + "should_respond": True, + } + + result = reconstruct_message_for_handler(data, [AgentExecutorRequest]) + + assert isinstance(result, AgentExecutorRequest) + + def test_reconstruct_with_type_metadata(self) -> None: + """Test reconstructing using __type__ metadata.""" + + 
@dataclass + class CustomMsg: + content: str + + # Serialize includes type metadata + serialized = serialize_message(CustomMsg(content="test")) + + result = reconstruct_message_for_handler(serialized, [CustomMsg]) + + assert isinstance(result, CustomMsg) + assert result.content == "test" + + def test_reconstruct_matches_dataclass_fields(self) -> None: + """Test reconstruction by matching dataclass field names.""" + + @dataclass + class MyData: + field_a: str + field_b: int + + data = {"field_a": "hello", "field_b": 42} + + result = reconstruct_message_for_handler(data, [MyData]) + + assert isinstance(result, MyData) + assert result.field_a == "hello" + assert result.field_b == 42 + + def test_reconstruct_returns_original_if_no_match(self) -> None: + """Test that original dict is returned if no type matches.""" + + @dataclass + class UnrelatedType: + completely_different_field: str + + data = {"some_key": "some_value"} + + result = reconstruct_message_for_handler(data, [UnrelatedType]) + + assert result == data diff --git a/python/packages/azurefunctions/tests/test_workflow.py b/python/packages/azurefunctions/tests/test_workflow.py new file mode 100644 index 0000000000..f401fc96c2 --- /dev/null +++ b/python/packages/azurefunctions/tests/test_workflow.py @@ -0,0 +1,379 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for workflow orchestration functions.""" + +import json +from dataclasses import dataclass +from typing import Any + +from agent_framework import ( + AgentExecutorRequest, + AgentExecutorResponse, + AgentRunResponse, + ChatMessage, +) +from agent_framework._workflows._edge import ( + FanInEdgeGroup, + FanOutEdgeGroup, + SingleEdgeGroup, + SwitchCaseEdgeGroup, + SwitchCaseEdgeGroupCase, + SwitchCaseEdgeGroupDefault, +) + +from agent_framework_azurefunctions._workflow import ( + _extract_message_content, + _extract_message_content_from_dict, + build_agent_executor_response, + route_message_through_edge_groups, +) + + +class TestRouteMessageThroughEdgeGroups: + """Test suite for route_message_through_edge_groups function.""" + + def test_single_edge_group_routes_when_condition_matches(self) -> None: + """Test SingleEdgeGroup routes when condition is satisfied.""" + group = SingleEdgeGroup(source_id="src", target_id="tgt", condition=lambda m: True) + + targets = route_message_through_edge_groups([group], "src", "any message") + + assert targets == ["tgt"] + + def test_single_edge_group_does_not_route_when_condition_fails(self) -> None: + """Test SingleEdgeGroup does not route when condition fails.""" + group = SingleEdgeGroup(source_id="src", target_id="tgt", condition=lambda m: False) + + targets = route_message_through_edge_groups([group], "src", "any message") + + assert targets == [] + + def test_single_edge_group_ignores_different_source(self) -> None: + """Test SingleEdgeGroup ignores messages from different sources.""" + group = SingleEdgeGroup(source_id="src", target_id="tgt", condition=lambda m: True) + + targets = route_message_through_edge_groups([group], "other_src", "any message") + + assert targets == [] + + def test_switch_case_with_selection_func(self) -> None: + """Test SwitchCaseEdgeGroup uses selection_func.""" + + def select_first_target(msg: Any, targets: list[str]) -> list[str]: + return [targets[0]] + + group = SwitchCaseEdgeGroup( + source_id="src", + cases=[ + SwitchCaseEdgeGroupCase(condition=lambda m: True, target_id="target_a"), + SwitchCaseEdgeGroupDefault(target_id="target_b"), + ], + ) + # Manually set the selection function + 
group._selection_func = select_first_target + + targets = route_message_through_edge_groups([group], "src", "test") + + assert targets == ["target_a"] + + def test_switch_case_without_selection_func_broadcasts(self) -> None: + """Test SwitchCaseEdgeGroup without selection_func broadcasts to all.""" + group = SwitchCaseEdgeGroup( + source_id="src", + cases=[ + SwitchCaseEdgeGroupCase(condition=lambda m: True, target_id="target_a"), + SwitchCaseEdgeGroupDefault(target_id="target_b"), + ], + ) + group._selection_func = None + + targets = route_message_through_edge_groups([group], "src", "test") + + assert set(targets) == {"target_a", "target_b"} + + def test_fan_out_with_selection_func(self) -> None: + """Test FanOutEdgeGroup uses selection_func.""" + + def select_all(msg: Any, targets: list[str]) -> list[str]: + return targets + + group = FanOutEdgeGroup( + source_id="src", + target_ids=["fan_a", "fan_b", "fan_c"], + selection_func=select_all, + ) + + targets = route_message_through_edge_groups([group], "src", "broadcast") + + assert set(targets) == {"fan_a", "fan_b", "fan_c"} + + def test_fan_in_is_not_routed_directly(self) -> None: + """Test FanInEdgeGroup is handled separately (not routed here).""" + group = FanInEdgeGroup( + source_ids=["src_a", "src_b"], + target_id="aggregator", + ) + + # Fan-in should not add targets through this function + targets = route_message_through_edge_groups([group], "src_a", "message") + + assert targets == [] + + def test_multiple_edge_groups_aggregated(self) -> None: + """Test that targets from multiple edge groups are aggregated.""" + group1 = SingleEdgeGroup(source_id="src", target_id="t1", condition=lambda m: True) + group2 = SingleEdgeGroup(source_id="src", target_id="t2", condition=lambda m: True) + + targets = route_message_through_edge_groups([group1, group2], "src", "msg") + + assert set(targets) == {"t1", "t2"} + + +class TestBuildAgentExecutorResponse: + """Test suite for build_agent_executor_response function.""" + + def test_builds_response_with_text(self) -> None: + """Test building response with plain text.""" + response = build_agent_executor_response( + executor_id="my_executor", + response_text="Hello, world!", + structured_response=None, + previous_message="User input", + ) + + assert response.executor_id == "my_executor" + assert response.agent_run_response.text == "Hello, world!" 
+ assert len(response.full_conversation) == 2 # User + Assistant + + def test_builds_response_with_structured_response(self) -> None: + """Test building response with structured JSON response.""" + structured = {"answer": 42, "reason": "because"} + + response = build_agent_executor_response( + executor_id="calc", + response_text="Original text", + structured_response=structured, + previous_message="Calculate", + ) + + # Structured response overrides text + assert response.agent_run_response.text == json.dumps(structured) + + def test_conversation_includes_previous_string_message(self) -> None: + """Test that string previous_message is included in conversation.""" + response = build_agent_executor_response( + executor_id="exec", + response_text="Response", + structured_response=None, + previous_message="User said this", + ) + + assert len(response.full_conversation) == 2 + assert response.full_conversation[0].role.value == "user" + assert response.full_conversation[0].text == "User said this" + assert response.full_conversation[1].role.value == "assistant" + + def test_conversation_extends_previous_agent_executor_response(self) -> None: + """Test that previous AgentExecutorResponse's conversation is extended.""" + # Create a previous response with conversation history + previous = AgentExecutorResponse( + executor_id="prev", + agent_run_response=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Previous")]), + full_conversation=[ + ChatMessage(role="user", text="First"), + ChatMessage(role="assistant", text="Previous"), + ], + ) + + response = build_agent_executor_response( + executor_id="current", + response_text="Current response", + structured_response=None, + previous_message=previous, + ) + + # Should have 3 messages: First + Previous + Current + assert len(response.full_conversation) == 3 + assert response.full_conversation[0].text == "First" + assert response.full_conversation[1].text == "Previous" + assert response.full_conversation[2].text == "Current response" + + +class TestExtractMessageContent: + """Test suite for _extract_message_content function.""" + + def test_extract_from_string(self) -> None: + """Test extracting content from plain string.""" + result = _extract_message_content("Hello, world!") + + assert result == "Hello, world!" 
+ + def test_extract_from_agent_executor_response_with_text(self) -> None: + """Test extracting from AgentExecutorResponse with text.""" + response = AgentExecutorResponse( + executor_id="exec", + agent_run_response=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response text")]), + ) + + result = _extract_message_content(response) + + assert result == "Response text" + + def test_extract_from_agent_executor_response_with_messages(self) -> None: + """Test extracting from AgentExecutorResponse with messages.""" + response = AgentExecutorResponse( + executor_id="exec", + agent_run_response=AgentRunResponse( + messages=[ + ChatMessage(role="user", text="First"), + ChatMessage(role="assistant", text="Last message"), + ] + ), + ) + + result = _extract_message_content(response) + + # AgentRunResponse.text concatenates all message texts + assert result == "FirstLast message" + + def test_extract_from_agent_executor_request(self) -> None: + """Test extracting from AgentExecutorRequest.""" + request = AgentExecutorRequest( + messages=[ + ChatMessage(role="user", text="First"), + ChatMessage(role="user", text="Last request"), + ] + ) + + result = _extract_message_content(request) + + assert result == "Last request" + + def test_extract_from_dict_agent_executor_request(self) -> None: + """Test extracting from serialized AgentExecutorRequest dict.""" + msg_dict = { + "messages": [ + { + "type": "chat_message", + "contents": [{"type": "text", "text": "Hello from dict"}], + } + ] + } + + result = _extract_message_content(msg_dict) + + assert result == "Hello from dict" + + def test_extract_returns_empty_for_unknown_type(self) -> None: + """Test that unknown types return empty string.""" + result = _extract_message_content(12345) + + assert result == "" + + +class TestExtractMessageContentFromDict: + """Test suite for _extract_message_content_from_dict function.""" + + def test_extract_from_messages_with_contents(self) -> None: + """Test extracting from messages with contents structure.""" + msg_dict = {"messages": [{"contents": [{"type": "text", "text": "Content text"}]}]} + + result = _extract_message_content_from_dict(msg_dict) + + assert result == "Content text" + + def test_extract_from_messages_with_direct_text(self) -> None: + """Test extracting from messages with direct text field.""" + msg_dict = {"messages": [{"text": "Direct text"}]} + + result = _extract_message_content_from_dict(msg_dict) + + assert result == "Direct text" + + def test_extract_from_agent_run_response(self) -> None: + """Test extracting from agent_run_response dict.""" + msg_dict = {"agent_run_response": {"text": "Response text"}} + + result = _extract_message_content_from_dict(msg_dict) + + assert result == "Response text" + + def test_extract_from_agent_run_response_with_messages(self) -> None: + """Test extracting from agent_run_response with messages.""" + msg_dict = {"agent_run_response": {"messages": [{"contents": [{"type": "text", "text": "Nested content"}]}]}} + + result = _extract_message_content_from_dict(msg_dict) + + assert result == "Nested content" + + def test_extract_returns_empty_for_empty_dict(self) -> None: + """Test that empty dict returns empty string.""" + result = _extract_message_content_from_dict({}) + + assert result == "" + + def test_extract_returns_empty_for_empty_messages(self) -> None: + """Test that empty messages list returns empty string.""" + result = _extract_message_content_from_dict({"messages": []}) + + assert result == "" + + +class TestEdgeGroupIntegration: + 
"""Integration tests for edge group routing with realistic scenarios.""" + + def test_conditional_routing_by_message_type(self) -> None: + """Test routing based on message content/type.""" + + @dataclass + class SpamResult: + is_spam: bool + reason: str + + def is_spam_condition(msg: Any) -> bool: + if isinstance(msg, SpamResult): + return msg.is_spam + return False + + def is_not_spam_condition(msg: Any) -> bool: + if isinstance(msg, SpamResult): + return not msg.is_spam + return False + + spam_group = SingleEdgeGroup( + source_id="detector", + target_id="spam_handler", + condition=is_spam_condition, + ) + legit_group = SingleEdgeGroup( + source_id="detector", + target_id="email_handler", + condition=is_not_spam_condition, + ) + + # Test spam message + spam_msg = SpamResult(is_spam=True, reason="Suspicious content") + targets = route_message_through_edge_groups([spam_group, legit_group], "detector", spam_msg) + assert targets == ["spam_handler"] + + # Test legitimate message + legit_msg = SpamResult(is_spam=False, reason="Clean") + targets = route_message_through_edge_groups([spam_group, legit_group], "detector", legit_msg) + assert targets == ["email_handler"] + + def test_fan_out_to_multiple_workers(self) -> None: + """Test fan-out to multiple parallel workers.""" + + def select_all_workers(msg: Any, targets: list[str]) -> list[str]: + return targets + + group = FanOutEdgeGroup( + source_id="coordinator", + target_ids=["worker_1", "worker_2", "worker_3"], + selection_func=select_all_workers, + ) + + targets = route_message_through_edge_groups([group], "coordinator", {"task": "process"}) + + assert len(targets) == 3 + assert set(targets) == {"worker_1", "worker_2", "worker_3"} diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 3b10579055..d78ee1cfbc 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -107,6 +107,11 @@ def __init__( # This tracks the full conversation after each run self._full_conversation: list[Message] = [] + @property + def agent(self) -> SupportsAgentRun: + """Get the underlying agent wrapped by this executor.""" + return self._agent + @property def description(self) -> str | None: """Get the description of the underlying agent.""" diff --git a/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/demo.http b/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/demo.http index 42f93b8543..28231a08a8 100644 --- a/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/demo.http +++ b/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/demo.http @@ -20,7 +20,7 @@ Content-Type: application/json ### Replace INSTANCE_ID_GOES_HERE below with the value returned from the POST call -@instanceId= +@instanceId=ccf3950407b5496893df93d1357a5afa ### Check the status of the orchestration GET http://localhost:7071/api/hitl/status/{{instanceId}} diff --git a/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py b/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py index 644ed9ed23..69651e45cf 100644 --- a/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py +++ b/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py @@ -64,7 +64,7 @@ 
def _create_writer_agent() -> Any: # 3. Activities encapsulate external work for review notifications and publishing. @app.activity_trigger(input_name="content") -def notify_user_for_approval(content: dict[str, str]) -> None: +def notify_user_for_approval(content: dict) -> None: model = GeneratedContent.model_validate(content) logger.info("NOTIFICATION: Please review the following content for approval:") logger.info("Title: %s", model.title or "(untitled)") @@ -73,7 +73,7 @@ def notify_user_for_approval(content: dict[str, str]) -> None: @app.activity_trigger(input_name="content") -def publish_content(content: dict[str, str]) -> None: +def publish_content(content: dict) -> None: model = GeneratedContent.model_validate(content) logger.info("PUBLISHING: Content has been published successfully:") logger.info("Title: %s", model.title or "(untitled)") diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/.gitignore b/python/samples/getting_started/azure_functions/09_workflow_shared_state/.gitignore new file mode 100644 index 0000000000..560ff95106 --- /dev/null +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/.gitignore @@ -0,0 +1,18 @@ +# Local settings +local.settings.json +.env + +# Python +__pycache__/ +*.py[cod] +.venv/ +venv/ + +# Azure Functions +bin/ +obj/ +.python_packages/ + +# IDE +.vscode/ +.idea/ diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/README.md b/python/samples/getting_started/azure_functions/09_workflow_shared_state/README.md new file mode 100644 index 0000000000..bd6e33c916 --- /dev/null +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/README.md @@ -0,0 +1,99 @@ +# Workflow with SharedState Sample + +This sample demonstrates running **Agent Framework workflows with SharedState** in Azure Durable Functions. + +## Overview + +This sample shows how to use `AgentFunctionApp` to execute a `WorkflowBuilder` workflow that uses SharedState to pass data between executors. SharedState is a local dictionary maintained by the orchestration that allows executors to share data across workflow steps. + +## What This Sample Demonstrates + +1. **Workflow Execution** - Running `WorkflowBuilder` workflows in Azure Durable Functions +2. **SharedState APIs** - Using `ctx.set_shared_state()` and `ctx.get_shared_state()` to share data +3. **Conditional Routing** - Routing messages based on spam detection results +4. **Agent + Executor Composition** - Combining AI agents with non-AI function executors + +## Workflow Architecture + +``` +store_email → spam_detector (agent) → to_detection_result → [branch]: + ├── If spam: handle_spam → yield "Email marked as spam: {reason}" + └── If not spam: submit_to_email_assistant → email_assistant (agent) → finalize_and_send → yield "Email sent: {response}" +``` + +### SharedState Usage by Executor + +| Executor | SharedState Operations | +|----------|----------------------| +| `store_email` | `set_shared_state("email:{id}", email)`, `set_shared_state("current_email_id", id)` | +| `to_detection_result` | `get_shared_state("current_email_id")` | +| `submit_to_email_assistant` | `get_shared_state("email:{id}")` | + +SharedState allows executors to pass large payloads (like email content) by reference rather than through message routing. + +## Prerequisites + +1. **Azure OpenAI** - Endpoint and deployment configured +2. **Azurite** - For local storage emulation + +## Setup + +1. 
Copy `local.settings.json.sample` to `local.settings.json` and configure: + ```json + { + "Values": { + "AZURE_OPENAI_ENDPOINT": "https://your-resource.openai.azure.com/", + "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME": "gpt-4o" + } + } + ``` + +2. Install dependencies: + ```bash + pip install -r requirements.txt + ``` + +3. Start Azurite: + ```bash + azurite --silent + ``` + +4. Run the function app: + ```bash + func start + ``` + +## Testing + +Use the `demo.http` file with REST Client extension or curl: + +### Test Spam Email +```bash +curl -X POST http://localhost:7071/api/workflow/run \ + -H "Content-Type: application/json" \ + -d '"URGENT! You have won $1,000,000! Click here to claim!"' +``` + +### Test Legitimate Email +```bash +curl -X POST http://localhost:7071/api/workflow/run \ + -H "Content-Type: application/json" \ + -d '"Hi team, reminder about our meeting tomorrow at 10 AM."' +``` + +## Expected Output + +**Spam email:** +``` +Email marked as spam: This email exhibits spam characteristics including urgent language, unrealistic claims of monetary winnings, and requests to click suspicious links. +``` + +**Legitimate email:** +``` +Email sent: Hi, Thank you for the reminder about the sprint planning meeting tomorrow at 10 AM. I will review the agenda and come prepared with my updates. See you then! +``` + +## Related Samples + +- `10_workflow_no_shared_state` - Workflow execution without SharedState usage +- `06_multi_agent_orchestration_conditionals` - Manual Durable Functions orchestration with agents diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/demo.http b/python/samples/getting_started/azure_functions/09_workflow_shared_state/demo.http new file mode 100644 index 0000000000..48b6a73f72 --- /dev/null +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/demo.http @@ -0,0 +1,31 @@ +@endpoint = http://localhost:7071 + +### Start the workflow with a spam email +POST {{endpoint}}/api/workflow/run +Content-Type: application/json + +"URGENT! You have won $1,000,000! Click here to claim your prize now before it expires!" + +### Start the workflow with a legitimate email +POST {{endpoint}}/api/workflow/run +Content-Type: application/json + +"Hi team, just a reminder about the sprint planning meeting tomorrow at 10 AM. Please review the agenda items in Jira before the call." + +### Start the workflow with another legitimate email +POST {{endpoint}}/api/workflow/run +Content-Type: application/json + +"Hello, I wanted to follow up on our conversation from last week regarding the project timeline. Could we schedule a brief call this afternoon to discuss the next steps?" + +### Start the workflow with a phishing attempt +POST {{endpoint}}/api/workflow/run +Content-Type: application/json + +"Dear Customer, Your account has been compromised! 
Click this link immediately to secure your account: http://totallylegit.suspicious.com/secure" + +### Check workflow status (replace {instanceId} with actual instance ID from response) +GET {{endpoint}}/runtime/webhooks/durabletask/instances/{instanceId} + +### Purge all orchestration instances (use for cleanup) +POST {{endpoint}}/runtime/webhooks/durabletask/instances/purge?createdTimeFrom=2020-01-01T00:00:00Z&createdTimeTo=2030-12-31T23:59:59Z diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py b/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py new file mode 100644 index 0000000000..bf38dfc72b --- /dev/null +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py @@ -0,0 +1,294 @@ +# Copyright (c) Microsoft. All rights reserved. +""" +Sample: Shared state with agents and conditional routing. + +Store an email once by id, classify it with a detector agent, then either draft a reply with an assistant +agent or finish with a spam notice. Stream events as the workflow runs. + +Purpose: +Show how to: +- Use shared state to decouple large payloads from messages and pass around lightweight references. +- Enforce structured agent outputs with Pydantic models via response_format for robust parsing. +- Route using conditional edges based on a typed intermediate DetectionResult. +- Compose agent backed executors with function style executors and yield the final output when the workflow completes. + +Prerequisites: +- Azure OpenAI configured for AzureOpenAIChatClient with required environment variables. +- Authentication via azure-identity. Use DefaultAzureCredential and run az login before executing the sample. +- Familiarity with WorkflowBuilder, executors, conditional edges, and streaming runs. +""" + +import logging +import os +from dataclasses import dataclass +from typing import Any +from uuid import uuid4 + +from agent_framework import ( + AgentExecutorRequest, + AgentExecutorResponse, + ChatMessage, + Role, + Workflow, + WorkflowBuilder, + WorkflowContext, + executor, +) +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import AzureCliCredential +from pydantic import BaseModel +from typing_extensions import Never + +from agent_framework_azurefunctions import AgentFunctionApp + +logger = logging.getLogger(__name__) + +# Environment variable names +AZURE_OPENAI_ENDPOINT_ENV = "AZURE_OPENAI_ENDPOINT" +AZURE_OPENAI_DEPLOYMENT_ENV = "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME" +AZURE_OPENAI_API_KEY_ENV = "AZURE_OPENAI_API_KEY" + +EMAIL_STATE_PREFIX = "email:" +CURRENT_EMAIL_ID_KEY = "current_email_id" + + +class DetectionResultAgent(BaseModel): + """Structured output returned by the spam detection agent.""" + + is_spam: bool + reason: str + + +class EmailResponse(BaseModel): + """Structured output returned by the email assistant agent.""" + + response: str + + +@dataclass +class DetectionResult: + """Internal detection result enriched with the shared state email_id for later lookups.""" + + is_spam: bool + reason: str + email_id: str + + +@dataclass +class Email: + """In memory record stored in shared state to avoid re-sending large bodies on edges.""" + + email_id: str + email_content: str + + +def get_condition(expected_result: bool): + """Create a condition predicate for DetectionResult.is_spam. + + Contract: + - If the message is not a DetectionResult, allow it to pass to avoid accidental dead ends. 
+ - Otherwise, return True only when is_spam matches expected_result. + """ + + def condition(message: Any) -> bool: + if not isinstance(message, DetectionResult): + return True + return message.is_spam == expected_result + + return condition + + +@executor(id="store_email") +async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None: + """Persist the raw email content in shared state and trigger spam detection. + + Responsibilities: + - Generate a unique email_id (UUID) for downstream retrieval. + - Store the Email object under a namespaced key and set the current id pointer. + - Emit an AgentExecutorRequest asking the detector to respond. + """ + new_email = Email(email_id=str(uuid4()), email_content=email_text) + await ctx.set_shared_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email) + await ctx.set_shared_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) + + await ctx.send_message( + AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=new_email.email_content)], should_respond=True) + ) + + +@executor(id="to_detection_result") +async def to_detection_result(response: AgentExecutorResponse, ctx: WorkflowContext[DetectionResult]) -> None: + """Parse spam detection JSON into a structured model and enrich with email_id. + + Steps: + 1) Validate the agent's JSON output into DetectionResultAgent. + 2) Retrieve the current email_id from shared state. + 3) Send a typed DetectionResult for conditional routing. + """ + parsed = DetectionResultAgent.model_validate_json(response.agent_run_response.text) + email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY) + await ctx.send_message(DetectionResult(is_spam=parsed.is_spam, reason=parsed.reason, email_id=email_id)) + + +@executor(id="submit_to_email_assistant") +async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowContext[AgentExecutorRequest]) -> None: + """Forward non spam email content to the drafting agent. + + Guard: + - This path should only receive non spam. Raise if misrouted. + """ + if detection.is_spam: + raise RuntimeError("This executor should only handle non-spam messages.") + + # Load the original content by id from shared state and forward it to the assistant. 
+ email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") + await ctx.send_message( + AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=email.email_content)], should_respond=True) + ) + + +@executor(id="finalize_and_send") +async def finalize_and_send(response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None: + """Validate the drafted reply and yield the final output.""" + parsed = EmailResponse.model_validate_json(response.agent_run_response.text) + await ctx.yield_output(f"Email sent: {parsed.response}") + + +@executor(id="handle_spam") +async def handle_spam(detection: DetectionResult, ctx: WorkflowContext[Never, str]) -> None: + """Yield output describing why the email was marked as spam.""" + if detection.is_spam: + await ctx.yield_output(f"Email marked as spam: {detection.reason}") + else: + raise RuntimeError("This executor should only handle spam messages.") + + +# ============================================================================ +# Workflow Creation +# ============================================================================ + + +def _build_client_kwargs() -> dict[str, Any]: + """Build Azure OpenAI client configuration from environment variables.""" + endpoint = os.getenv(AZURE_OPENAI_ENDPOINT_ENV) + if not endpoint: + raise RuntimeError(f"{AZURE_OPENAI_ENDPOINT_ENV} environment variable is required.") + + deployment = os.getenv(AZURE_OPENAI_DEPLOYMENT_ENV) + if not deployment: + raise RuntimeError(f"{AZURE_OPENAI_DEPLOYMENT_ENV} environment variable is required.") + + client_kwargs: dict[str, Any] = { + "endpoint": endpoint, + "deployment_name": deployment, + } + + api_key = os.getenv(AZURE_OPENAI_API_KEY_ENV) + if api_key: + client_kwargs["api_key"] = api_key + else: + client_kwargs["credential"] = AzureCliCredential() + + return client_kwargs + + +def _create_workflow() -> Workflow: + """Create the email classification workflow with conditional routing.""" + client_kwargs = _build_client_kwargs() + chat_client = AzureOpenAIChatClient(**client_kwargs) + + spam_detection_agent = chat_client.create_agent( + instructions=( + "You are a spam detection assistant that identifies spam emails. " + "Always return JSON with fields is_spam (bool) and reason (string)." + ), + response_format=DetectionResultAgent, + name="spam_detection_agent", + ) + + email_assistant_agent = chat_client.create_agent( + instructions=( + "You are an email assistant that helps users draft responses to emails with professionalism. " + "Return JSON with a single field 'response' containing the drafted reply." + ), + response_format=EmailResponse, + name="email_assistant_agent", + ) + + # Build the workflow graph with conditional edges. 
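+    # (A DetectionResult is evaluated against both outgoing conditions and is delivered only to
+    # the branch whose predicate matches; any other message type satisfies both predicates by
+    # design - see get_condition above - so it is never silently dropped.)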
+ # Flow: + # store_email -> spam_detection_agent -> to_detection_result -> branch: + # False -> submit_to_email_assistant -> email_assistant_agent -> finalize_and_send + # True -> handle_spam + workflow = ( + WorkflowBuilder() + .set_start_executor(store_email) + .add_edge(store_email, spam_detection_agent) + .add_edge(spam_detection_agent, to_detection_result) + .add_edge(to_detection_result, submit_to_email_assistant, condition=get_condition(False)) + .add_edge(to_detection_result, handle_spam, condition=get_condition(True)) + .add_edge(submit_to_email_assistant, email_assistant_agent) + .add_edge(email_assistant_agent, finalize_and_send) + .build() + ) + + return workflow + + +# ============================================================================ +# Application Entry Point +# ============================================================================ + + +def launch(durable: bool = True) -> AgentFunctionApp | None: + """Launch the function app or DevUI. + + Args: + durable: If True, returns AgentFunctionApp for Azure Functions. + If False, launches DevUI for local MAF development. + """ + if durable: + # Azure Functions mode with Durable Functions + # SharedState is enabled by default, which this sample requires for storing emails + workflow = _create_workflow() + app = AgentFunctionApp(workflow=workflow, enable_health_check=True) + return app + else: + # Pure MAF mode with DevUI for local development + from pathlib import Path + + from agent_framework.devui import serve + from dotenv import load_dotenv + + env_path = Path(__file__).parent / ".env" + load_dotenv(dotenv_path=env_path) + + logger.info("Starting Workflow Shared State Sample in MAF mode") + logger.info("Available at: http://localhost:8096") + logger.info("\nThis workflow demonstrates:") + logger.info("- Shared state to decouple large payloads from messages") + logger.info("- Structured agent outputs with Pydantic models") + logger.info("- Conditional routing based on detection results") + logger.info("\nFlow: store_email -> spam_detection -> branch (spam/not spam)") + + workflow = _create_workflow() + serve(entities=[workflow], port=8096, auto_open=True) + + return None + + +# Default: Azure Functions mode +# Run with `python function_app.py --maf` for pure MAF mode with DevUI +app = launch(durable=True) + + +if __name__ == "__main__": + import sys + + if "--maf" in sys.argv: + # Run in pure MAF mode with DevUI + launch(durable=False) + else: + print("Usage: python function_app.py --maf") + print(" --maf Run in pure MAF mode with DevUI (http://localhost:8096)") + print("\nFor Azure Functions mode, use: func start") diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/host.json b/python/samples/getting_started/azure_functions/09_workflow_shared_state/host.json new file mode 100644 index 0000000000..292562af8e --- /dev/null +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/host.json @@ -0,0 +1,16 @@ +{ + "version": "2.0", + "extensionBundle": { + "id": "Microsoft.Azure.Functions.ExtensionBundle", + "version": "[4.*, 5.0.0)" + }, + "extensions": { + "durableTask": { + "hubName": "%TASKHUB_NAME%", + "storageProvider": { + "type": "AzureManaged", + "connectionStringName": "DURABLE_TASK_SCHEDULER_CONNECTION_STRING" + } + } + } +} diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/local.settings.json.sample b/python/samples/getting_started/azure_functions/09_workflow_shared_state/local.settings.json.sample new file mode 100644 
index 0000000000..69c08a3386 --- /dev/null +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/local.settings.json.sample @@ -0,0 +1,11 @@ +{ + "IsEncrypted": false, + "Values": { + "AzureWebJobsStorage": "UseDevelopmentStorage=true", + "DURABLE_TASK_SCHEDULER_CONNECTION_STRING": "Endpoint=http://localhost:8080;TaskHub=default;Authentication=None", + "TASKHUB_NAME": "default", + "FUNCTIONS_WORKER_RUNTIME": "python", + "AZURE_OPENAI_ENDPOINT": "", + "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME": "" + } +} diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/requirements.txt b/python/samples/getting_started/azure_functions/09_workflow_shared_state/requirements.txt new file mode 100644 index 0000000000..5739f93aa3 --- /dev/null +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/requirements.txt @@ -0,0 +1,3 @@ +agent-framework-azurefunctions +azure-identity +agent-framework diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.env.sample b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.env.sample new file mode 100644 index 0000000000..cf8fe3d05c --- /dev/null +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.env.sample @@ -0,0 +1,4 @@ +# Azure OpenAI Configuration +AZURE_OPENAI_ENDPOINT=https://.openai.azure.com/ +AZURE_OPENAI_CHAT_DEPLOYMENT_NAME= +AZURE_OPENAI_API_KEY= diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.gitignore b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.gitignore new file mode 100644 index 0000000000..1d5b48c35f --- /dev/null +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.gitignore @@ -0,0 +1,2 @@ +.env +local.settings.json diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/README.md b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/README.md new file mode 100644 index 0000000000..f5f77f3c91 --- /dev/null +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/README.md @@ -0,0 +1,159 @@ +# Workflow Execution Sample + +This sample demonstrates running **Agent Framework workflows** in Azure Durable Functions without using SharedState. + +## Overview + +This sample shows how to use `AgentFunctionApp` with a `WorkflowBuilder` workflow. The workflow is passed directly to `AgentFunctionApp`, which orchestrates execution using Durable Functions: + +```python +workflow = _create_workflow() # Build the workflow graph +app = AgentFunctionApp(workflow=workflow) +``` + +This approach provides durable, fault-tolerant workflow execution with minimal code. + +## What This Sample Demonstrates + +1. **Workflow Registration** - Pass a `Workflow` directly to `AgentFunctionApp` +2. **Durable Execution** - Workflow executes with Durable Functions durability and scalability +3. **Conditional Routing** - Route messages based on spam detection (is_spam → spam handler, not spam → email assistant) +4.
**Agent + Executor Composition** - Combine AI agents with non-AI executor classes + +## Workflow Architecture + +``` +SpamDetectionAgent → [branch based on is_spam]: + ├── If spam: SpamHandlerExecutor → yield "Email marked as spam: {reason}" + └── If not spam: EmailAssistantAgent → EmailSenderExecutor → yield "Email sent: {response}" +``` + +### Components + +| Component | Type | Description | +|-----------|------|-------------| +| `SpamDetectionAgent` | AI Agent | Analyzes emails for spam indicators | +| `EmailAssistantAgent` | AI Agent | Drafts professional email responses | +| `SpamHandlerExecutor` | Executor | Handles spam emails (non-AI) | +| `EmailSenderExecutor` | Executor | Sends email responses (non-AI) | + +## Prerequisites + +1. **Azure OpenAI** - Endpoint and deployment configured +2. **Azurite** - For local storage emulation + +## Setup + +1. Copy configuration files: + ```bash + cp local.settings.json.sample local.settings.json + ``` + +2. Configure `local.settings.json`: + +3. Install dependencies: + ```bash + pip install -r requirements.txt + ``` + +4. Start Azurite: + ```bash + azurite --silent + ``` + +5. Run the function app: + ```bash + func start + ``` + +## Testing + +Use the `demo.http` file with REST Client extension or curl: + +### Test Spam Email +```bash +curl -X POST http://localhost:7071/api/workflow/run \ + -H "Content-Type: application/json" \ + -d '{"email_id": "test-001", "email_content": "URGENT! You have won $1,000,000! Click here!"}' +``` + +### Test Legitimate Email +```bash +curl -X POST http://localhost:7071/api/workflow/run \ + -H "Content-Type: application/json" \ + -d '{"email_id": "test-002", "email_content": "Hi team, reminder about our meeting tomorrow at 10 AM."}' +``` + +### Check Status +```bash +curl http://localhost:7071/api/workflow/status/{instanceId} +``` + +## Expected Output + +**Spam email:** +``` +Email marked as spam: This email exhibits spam characteristics including urgent language, unrealistic claims of monetary winnings, and requests to click suspicious links. +``` + +**Legitimate email:** +``` +Email sent: Hi, Thank you for the reminder about the sprint planning meeting tomorrow at 10 AM. I will be there. +``` + +## Code Highlights + +### Creating the Workflow + +```python +workflow = ( + WorkflowBuilder() + .set_start_executor(spam_agent) + .add_switch_case_edge_group( + spam_agent, + [ + Case(condition=is_spam_detected, target=spam_handler), + Default(target=email_agent), + ], + ) + .add_edge(email_agent, email_sender) + .build() +) +``` + +### Registering with AgentFunctionApp + +```python +app = AgentFunctionApp(workflow=workflow, enable_health_check=True) +``` + +### Executor Classes + +```python +class SpamHandlerExecutor(Executor): + @handler + async def handle_spam_result( + self, + agent_response: AgentExecutorResponse, + ctx: WorkflowContext[Never, str], + ) -> None: + spam_result = SpamDetectionResult.model_validate_json(agent_response.agent_run_response.text) + await ctx.yield_output(f"Email marked as spam: {spam_result.reason}") +``` + +## Standalone Mode (DevUI) + +This sample also supports running standalone for local development: + +```python +# Change launch(durable=True) to launch(durable=False) in function_app.py +# Then run: +python function_app.py +``` + +This starts the DevUI at `http://localhost:8094` for interactive testing. 
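+
+## How the Routing Condition Works
+
+The switch-case group shown under Code Highlights routes on a plain predicate: the condition receives the raw `AgentExecutorResponse`, parses the agent's structured JSON itself, and treats anything it cannot parse as "not spam" so the email falls through to the assistant. This is the predicate as defined in `function_app.py`:
+
+```python
+def is_spam_detected(message: Any) -> bool:
+    """Check if spam was detected in the email."""
+    if not isinstance(message, AgentExecutorResponse):
+        return False
+    try:
+        result = SpamDetectionResult.model_validate_json(message.agent_run_response.text)
+        return result.is_spam
+    except Exception:
+        return False
+```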
+ +## Related Samples + +- `09_workflow_shared_state` - Workflow with SharedState for passing data between executors +- `06_multi_agent_orchestration_conditionals` - Manual Durable Functions orchestration with agents diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/demo.http b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/demo.http new file mode 100644 index 0000000000..2c81ddc9bc --- /dev/null +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/demo.http @@ -0,0 +1,32 @@ +### Start Workflow Orchestration - Spam Email +POST http://localhost:7071/api/workflow/run +Content-Type: application/json + +{ + "email_id": "email-001", + "email_content": "URGENT! You've won $1,000,000! Click here immediately to claim your prize! Limited time offer - act now!" +} + +### + +### Start Workflow Orchestration - Legitimate Email +POST http://localhost:7071/api/workflow/run +Content-Type: application/json + +{ + "email_id": "email-002", + "email_content": "Hi team, just a reminder about our sprint planning meeting tomorrow at 10 AM. Please review the agenda in Jira." +} + +### + +### Get Workflow Status +# Replace {instanceId} with the actual instance ID from the start response +GET http://localhost:7071/api/workflow/status/{instanceId} + +### + +### Health Check +GET http://localhost:7071/api/health + +### diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py new file mode 100644 index 0000000000..b55fef58b8 --- /dev/null +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py @@ -0,0 +1,244 @@ +# Copyright (c) Microsoft. All rights reserved. +"""Workflow Execution within Durable Functions Orchestrator. + +This sample demonstrates running agent framework WorkflowBuilder workflows inside +a Durable Functions orchestrator by manually traversing the workflow graph and +delegating execution to Durable Entities (for agents) and Activities (for other logic). + +Key architectural points: +- AgentFunctionApp registers agents as DurableAIAgents. +- WorkflowBuilder uses `DurableAgentDefinition` (a placeholder) to define the graph. +- The orchestrator (`workflow_orchestration`) iterates through the workflow graph. +- When an agent node is encountered, it calls the corresponding `DurableAIAgent` entity. +- When a standard executor node is encountered, it calls an Activity (`ExecuteExecutor`). + +This approach allows using the rich structure of `WorkflowBuilder` while leveraging +the statefulness and durability of `DurableAIAgent`s. 
+""" + +import logging +import os +from typing import Any, Dict + +from pathlib import Path +from agent_framework import ( + AgentExecutorResponse, + Case, + Default, + Executor, + Workflow, + WorkflowBuilder, + WorkflowContext, + handler, +) +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import AzureCliCredential +from pydantic import BaseModel +from agent_framework_azurefunctions import AgentFunctionApp +from typing_extensions import Never + +logger = logging.getLogger(__name__) + +AZURE_OPENAI_ENDPOINT_ENV = "AZURE_OPENAI_ENDPOINT" +AZURE_OPENAI_DEPLOYMENT_ENV = "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME" +AZURE_OPENAI_API_KEY_ENV = "AZURE_OPENAI_API_KEY" +SPAM_AGENT_NAME = "SpamDetectionAgent" +EMAIL_AGENT_NAME = "EmailAssistantAgent" + +SPAM_DETECTION_INSTRUCTIONS = ( + "You are a spam detection assistant that identifies spam emails.\n\n" + "Analyze the email content for spam indicators including:\n" + "1. Suspicious language (urgent, limited time, act now, free money, etc.)\n" + "2. Suspicious links or requests for personal information\n" + "3. Poor grammar or spelling\n" + "4. Requests for money or financial information\n" + "5. Impersonation attempts\n\n" + "Return a JSON response with:\n" + "- is_spam: boolean indicating if it's spam\n" + "- confidence: float between 0.0 and 1.0\n" + "- reason: detailed explanation of your classification" +) + +EMAIL_ASSISTANT_INSTRUCTIONS = ( + "You are an email assistant that helps users draft responses to legitimate emails.\n\n" + "When you receive an email that has been verified as legitimate:\n" + "1. Draft a professional and appropriate response\n" + "2. Match the tone and formality of the original email\n" + "3. Be helpful and courteous\n" + "4. Keep the response concise but complete\n\n" + "Return a JSON response with:\n" + "- response: the drafted email response" +) + + +class SpamDetectionResult(BaseModel): + is_spam: bool + confidence: float + reason: str + + +class EmailResponse(BaseModel): + response: str + + +class EmailPayload(BaseModel): + email_id: str + email_content: str + + +def _build_client_kwargs() -> dict[str, Any]: + endpoint = os.getenv(AZURE_OPENAI_ENDPOINT_ENV) + if not endpoint: + raise RuntimeError(f"{AZURE_OPENAI_ENDPOINT_ENV} environment variable is required.") + + deployment = os.getenv(AZURE_OPENAI_DEPLOYMENT_ENV) + if not deployment: + raise RuntimeError(f"{AZURE_OPENAI_DEPLOYMENT_ENV} environment variable is required.") + + client_kwargs: dict[str, Any] = { + "endpoint": endpoint, + "deployment_name": deployment, + } + + api_key = os.getenv(AZURE_OPENAI_API_KEY_ENV) + if api_key: + client_kwargs["api_key"] = api_key + else: + client_kwargs["credential"] = AzureCliCredential() + + return client_kwargs + + +# Executors for non-AI activities (defined at module level) +class SpamHandlerExecutor(Executor): + """Executor that handles spam emails (non-AI activity).""" + + @handler + async def handle_spam_result( + self, + agent_response: AgentExecutorResponse, + ctx: WorkflowContext[Never, str], + ) -> None: + """Mark email as spam and log the reason.""" + text = agent_response.agent_run_response.text + spam_result = SpamDetectionResult.model_validate_json(text) + message = f"Email marked as spam: {spam_result.reason}" + await ctx.yield_output(message) + + +class EmailSenderExecutor(Executor): + """Executor that sends email responses (non-AI activity).""" + + @handler + async def handle_email_response( + self, + agent_response: AgentExecutorResponse, + ctx: WorkflowContext[Never, str], + ) -> 
None: + """Send the drafted email response.""" + text = agent_response.agent_run_response.text + email_response = EmailResponse.model_validate_json(text) + message = f"Email sent: {email_response.response}" + await ctx.yield_output(message) + + +# Condition function for routing +def is_spam_detected(message: Any) -> bool: + """Check if spam was detected in the email.""" + if not isinstance(message, AgentExecutorResponse): + return False + try: + result = SpamDetectionResult.model_validate_json(message.agent_run_response.text) + return result.is_spam + except Exception: + return False + + +def _create_workflow() -> Workflow: + """Create the workflow definition.""" + client_kwargs = _build_client_kwargs() + chat_client = AzureOpenAIChatClient(**client_kwargs) + + spam_agent = chat_client.create_agent( + name=SPAM_AGENT_NAME, + instructions=SPAM_DETECTION_INSTRUCTIONS, + response_format=SpamDetectionResult, + ) + + email_agent = chat_client.create_agent( + name=EMAIL_AGENT_NAME, + instructions=EMAIL_ASSISTANT_INSTRUCTIONS, + response_format=EmailResponse, + ) + + # Executors + spam_handler = SpamHandlerExecutor(id="spam_handler") + email_sender = EmailSenderExecutor(id="email_sender") + + # Build workflow + workflow = ( + WorkflowBuilder() + .set_start_executor(spam_agent) + .add_switch_case_edge_group( + spam_agent, + [ + Case(condition=is_spam_detected, target=spam_handler), + Default(target=email_agent), + ], + ) + .add_edge(email_agent, email_sender) + .build() + ) + return workflow + + +def launch(durable: bool = True) -> AgentFunctionApp | None: + workflow: Workflow | None = None + + if durable: + # Initialize app + workflow = _create_workflow() + + + app = AgentFunctionApp(workflow=workflow) + + + return app + else: + # Launch the spam detection workflow in DevUI + from agent_framework.devui import serve + from dotenv import load_dotenv + + # Load environment variables from .env file + env_path = Path(__file__).parent / ".env" + load_dotenv(dotenv_path=env_path) + + logger.info("Starting Multi-Agent Spam Detection Workflow") + logger.info("Available at: http://localhost:8094") + logger.info("\nThis workflow demonstrates:") + logger.info("- Conditional routing based on spam detection") + logger.info("- Mixing AI agents with non-AI executors (like activity functions)") + logger.info("- Path 1 (spam): SpamDetector Agent → SpamHandler Executor") + logger.info("- Path 2 (legitimate): SpamDetector Agent → EmailAssistant Agent → EmailSender Executor") + + workflow = _create_workflow() + serve(entities=[workflow], port=8094, auto_open=True) + + return None + + +# Default: Azure Functions mode +# Run with `python function_app.py --maf` for pure MAF mode with DevUI +app = launch(durable=True) + + +if __name__ == "__main__": + import sys + + if "--maf" in sys.argv: + # Run in pure MAF mode with DevUI + launch(durable=False) + else: + print("Usage: python function_app.py --maf") + print(" --maf Run in pure MAF mode with DevUI (http://localhost:8096)") + print("\nFor Azure Functions mode, use: func start") diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/host.json b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/host.json new file mode 100644 index 0000000000..292562af8e --- /dev/null +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/host.json @@ -0,0 +1,16 @@ +{ + "version": "2.0", + "extensionBundle": { + "id": "Microsoft.Azure.Functions.ExtensionBundle", + "version": "[4.*, 5.0.0)" + }, + 
"extensions": { + "durableTask": { + "hubName": "%TASKHUB_NAME%", + "storageProvider": { + "type": "AzureManaged", + "connectionStringName": "DURABLE_TASK_SCHEDULER_CONNECTION_STRING" + } + } + } +} diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/local.settings.json.sample b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/local.settings.json.sample new file mode 100644 index 0000000000..30edea6c08 --- /dev/null +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/local.settings.json.sample @@ -0,0 +1,12 @@ +{ + "IsEncrypted": false, + "Values": { + "FUNCTIONS_WORKER_RUNTIME": "python", + "AzureWebJobsStorage": "UseDevelopmentStorage=true", + "DURABLE_TASK_SCHEDULER_CONNECTION_STRING": "Endpoint=http://localhost:8080;TaskHub=default;Authentication=None", + "TASKHUB_NAME": "default", + "AZURE_OPENAI_ENDPOINT": "https://.openai.azure.com/", + "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME": "", + "AZURE_OPENAI_API_KEY": "" + } +} diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/requirements.txt b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/requirements.txt new file mode 100644 index 0000000000..792ae4864e --- /dev/null +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/requirements.txt @@ -0,0 +1,3 @@ +agent-framework-azurefunctions +agent-framework +azure-identity diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/.env.template b/python/samples/getting_started/azure_functions/11_workflow_parallel/.env.template new file mode 100644 index 0000000000..1ef634f442 --- /dev/null +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/.env.template @@ -0,0 +1,14 @@ +# Azure Functions Runtime Configuration +FUNCTIONS_WORKER_RUNTIME=python +AzureWebJobsStorage=UseDevelopmentStorage=true + +# Durable Task Scheduler Configuration +# For local development with DTS emulator: Endpoint=http://localhost:8080;TaskHub=default;Authentication=None +# For Azure: Get connection string from Azure portal +DURABLE_TASK_SCHEDULER_CONNECTION_STRING=Endpoint=http://localhost:8080;TaskHub=default;Authentication=None +TASKHUB_NAME=default + +# Azure OpenAI Configuration +AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/ +AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=your-deployment-name +AZURE_OPENAI_API_KEY=your-api-key diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/.gitignore b/python/samples/getting_started/azure_functions/11_workflow_parallel/.gitignore new file mode 100644 index 0000000000..41f350a67c --- /dev/null +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/.gitignore @@ -0,0 +1,4 @@ +.venv/ +__pycache__/ +local.settings.json +.env diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/README.md b/python/samples/getting_started/azure_functions/11_workflow_parallel/README.md new file mode 100644 index 0000000000..07c48b73e6 --- /dev/null +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/README.md @@ -0,0 +1,193 @@ +# Parallel Workflow Execution Sample + +This sample demonstrates **parallel execution** of executors and agents in Azure Durable Functions workflows. + +## Overview + +This sample showcases three different parallel execution patterns: + +1. **Two Executors in Parallel** - Fan-out to multiple activities +2. **Two Agents in Parallel** - Fan-out to multiple entities +3. 
**Mixed Execution** - Agents and executors can run concurrently + +## Workflow Architecture + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ PARALLEL WORKFLOW │ +├─────────────────────────────────────────────────────────────────────────┤ +│ │ +│ Pattern 1: Two Executors in Parallel (Activities) │ +│ ───────────────────────────────────────────────── │ +│ │ +│ input_router ──┬──> [word_count_processor] ────┐ │ +│ │ │ │ +│ └──> [format_analyzer_processor]┴──> [aggregator] │ +│ │ +│ Pattern 2: Two Agents in Parallel (Entities) │ +│ ───────────────────────────────────────────── │ +│ │ +│ [prepare_for_agents] ──┬──> [SentimentAgent] ──────┐ │ +│ │ │ │ +│ └──> [KeywordAgent] ────────┴──> [prepare_for_│ +│ mixed] │ +│ │ +│ Pattern 3: Mixed Agent + Executor in Parallel │ +│ ──────────────────────────────────────────────── │ +│ │ +│ [prepare_for_mixed] ──┬──> [SummaryAgent] ─────────┐ │ +│ │ │ │ +│ └──> [statistics_processor] ─┴──> [final_report│ +│ _executor] │ +│ │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +## How Parallel Execution Works + +### Activities (Executors) +When multiple executors are pending in the same iteration (e.g., after a fan-out edge), they are batched and executed using `task_all()`: + +```python +# In _workflow.py - activities execute in parallel +activity_tasks = [context.call_activity("ExecuteExecutor", input) for ...] +results = yield context.task_all(activity_tasks) # All run concurrently! +``` + +### Agents (Entities) +Different agents can also run in parallel when they're pending in the same iteration: + +```python +# Different agents run in parallel +agent_tasks = [agent_a.run(...), agent_b.run(...)] +responses = yield context.task_all(agent_tasks) # Both agents run concurrently! +``` + +**Note:** Multiple messages to the *same* agent are processed sequentially to maintain conversation coherence. + +## Components + +| Component | Type | Description | +|-----------|------|-------------| +| `input_router` | Executor | Routes input JSON to parallel processors | +| `word_count_processor` | Executor | Counts words and characters | +| `format_analyzer_processor` | Executor | Analyzes document format | +| `aggregator` | Executor | Combines results from parallel processors | +| `prepare_for_agents` | Executor | Prepares content for agent analysis | +| `SentimentAnalysisAgent` | AI Agent | Analyzes text sentiment | +| `KeywordExtractionAgent` | AI Agent | Extracts keywords and categories | +| `prepare_for_mixed` | Executor | Prepares content for mixed parallel execution | +| `SummaryAgent` | AI Agent | Summarizes the document | +| `statistics_processor` | Executor | Computes document statistics | +| `FinalReportExecutor` | Executor | Compiles final report from all analyses | + +## Prerequisites + +1. **Azure OpenAI** - Endpoint and deployment configured +2. **DTS Emulator** - For durable task scheduling (recommended) +3. **Azurite** - For Azure Functions internal storage + +## Setup + +### Option 1: DevUI Mode (Local Development - No Durable Functions) + +The sample can run locally without Azure Functions infrastructure using DevUI: + +1. Copy the environment template: + ```bash + cp .env.template .env + ``` + +2. Configure `.env` with your Azure OpenAI credentials + +3. Install dependencies: + ```bash + pip install -r requirements.txt + ``` + +4. Run in DevUI mode (set `durable=False` in `function_app.py`): + ```bash + python function_app.py + ``` + +5. 
Open `http://localhost:8095` and provide input: + ```json + { + "document_id": "doc-001", + "content": "Your document text here..." + } + ``` + +### Option 2: Durable Functions Mode (Full Azure Functions) + +1. Copy configuration files: + ```bash + cp .env.template .env + cp local.settings.json.sample local.settings.json + ``` + +2. Configure `local.settings.json` with your Azure OpenAI credentials + +3. Install dependencies: + ```bash + pip install -r requirements.txt + ``` + +4. Start DTS Emulator: + ```bash + docker run -d --name dts-emulator -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest + ``` + +5. Start Azurite (or use VS Code extension): + ```bash + azurite --silent + ``` + +6. Run the function app (ensure `durable=True` in `function_app.py`): + ```bash + func start + ``` + +## Testing + +Use the `demo.http` file with REST Client extension or curl: + +### Analyze a Document +```bash +curl -X POST http://localhost:7071/api/workflow/run \ + -H "Content-Type: application/json" \ + -d '{ + "document_id": "doc-001", + "content": "The quarterly earnings report shows strong growth in cloud services. Revenue increased by 25%." + }' +``` + +### Check Status +```bash +curl http://localhost:7071/api/workflow/status/{instanceId} +``` + +## Observing Parallel Execution + +Open the DTS Dashboard at `http://localhost:8082` to observe: + +1. **Activity Execution Timeline** - You'll see `word_count_processor` and `format_analyzer_processor` starting at approximately the same time +2. **Agent Execution Timeline** - `SentimentAnalysisAgent` and `KeywordExtractionAgent` also start concurrently +3. **Sequential vs Parallel** - Compare with non-parallel samples to see the time savings + +## Expected Output + +```json +{ + "output": [ + "=== Document Analysis Report ===\n\n--- SentimentAnalysisAgent ---\n{\"sentiment\": \"positive\", \"confidence\": 0.85, \"explanation\": \"...\"}\n\n--- KeywordExtractionAgent ---\n{\"keywords\": [\"earnings\", \"growth\", \"cloud\"], \"categories\": [\"finance\", \"technology\"]}" + ] +} +``` + +## Key Takeaways + +1. **Parallel execution is automatic** - When multiple executors/agents are pending in the same iteration, they run in parallel +2. **Workflow graph determines parallelism** - Fan-out edges create parallel execution opportunities +3. **Mixed parallelism** - Agents and executors can run concurrently if they're in the same iteration +4. **Same-agent messages are sequential** - To maintain conversation coherence diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/demo.http b/python/samples/getting_started/azure_functions/11_workflow_parallel/demo.http new file mode 100644 index 0000000000..a8ae96e452 --- /dev/null +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/demo.http @@ -0,0 +1,29 @@ +### Analyze a document (triggers parallel workflow) +POST http://localhost:7071/api/workflow/run +Content-Type: application/json + +{ + "document_id": "doc-001", + "content": "The quarterly earnings report shows strong growth in our cloud services division. Revenue increased by 25% compared to last year, driven by enterprise adoption. Customer satisfaction remains high at 92%. However, we face challenges in the mobile segment where competition is intense. Overall, the outlook is positive with expected continued growth in the coming quarters." 
+} + +### + +### Short document test +POST http://localhost:7071/api/workflow/run +Content-Type: application/json + +{ + "document_id": "doc-002", + "content": "Quick update: Project completed successfully. Team performance exceeded expectations." +} + +### + +### Check workflow status +GET http://localhost:7071/api/workflow/status/{{instanceId}} + +### + +### Health check +GET http://localhost:7071/api/health diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py b/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py new file mode 100644 index 0000000000..a51a1b6a04 --- /dev/null +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py @@ -0,0 +1,538 @@ +# Copyright (c) Microsoft. All rights reserved. +"""Parallel Workflow Execution Sample. + +This sample demonstrates parallel execution of executors and agents in Azure Durable Functions. +It showcases three different parallel execution patterns: + +1. Two executors running concurrently (fan-out to activities) +2. Two agents running concurrently (fan-out to entities) +3. One executor and one agent running concurrently (mixed fan-out) + +The workflow simulates a document processing pipeline where: +- A document is analyzed by multiple processors in parallel +- Results are aggregated and then processed by agents +- A summary agent and statistics executor run in parallel +- Finally, combined into a single output + +Key architectural points: +- FanOut edges enable parallel execution +- Different agents run in parallel when they're in the same iteration +- Activities (executors) also run in parallel when pending together +- Mixed agent/executor fan-outs execute concurrently +""" + +import json +import logging +import os +from dataclasses import dataclass +from typing import Any + +from agent_framework import ( + AgentExecutorResponse, + Executor, + Workflow, + WorkflowBuilder, + WorkflowContext, + executor, + handler, +) +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import AzureCliCredential +from pydantic import BaseModel +from typing_extensions import Never + +from agent_framework_azurefunctions import AgentFunctionApp + +logger = logging.getLogger(__name__) + +AZURE_OPENAI_ENDPOINT_ENV = "AZURE_OPENAI_ENDPOINT" +AZURE_OPENAI_DEPLOYMENT_ENV = "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME" +AZURE_OPENAI_API_KEY_ENV = "AZURE_OPENAI_API_KEY" + +# Agent names +SENTIMENT_AGENT_NAME = "SentimentAnalysisAgent" +KEYWORD_AGENT_NAME = "KeywordExtractionAgent" +SUMMARY_AGENT_NAME = "SummaryAgent" +RECOMMENDATION_AGENT_NAME = "RecommendationAgent" + + +# ============================================================================ +# Pydantic Models for structured outputs +# ============================================================================ + + +class SentimentResult(BaseModel): + """Result from sentiment analysis.""" + sentiment: str # positive, negative, neutral + confidence: float + explanation: str + + +class KeywordResult(BaseModel): + """Result from keyword extraction.""" + keywords: list[str] + categories: list[str] + + +class SummaryResult(BaseModel): + """Result from summarization.""" + summary: str + key_points: list[str] + + +class RecommendationResult(BaseModel): + """Result from recommendation engine.""" + recommendations: list[str] + priority: str + + +@dataclass +class DocumentInput: + """Input document to be processed.""" + document_id: str + content: str + + +@dataclass +class ProcessorResult: + 
"""Result from a document processor (executor).""" + processor_name: str + document_id: str + content: str + word_count: int + char_count: int + has_numbers: bool + + +@dataclass +class AggregatedResults: + """Aggregated results from parallel processors.""" + document_id: str + content: str + processor_results: list[ProcessorResult] + + +@dataclass +class AgentAnalysis: + """Analysis result from an agent.""" + agent_name: str + result: str + + +@dataclass +class FinalReport: + """Final combined report.""" + document_id: str + analyses: list[AgentAnalysis] + + +# ============================================================================ +# Executor Definitions (Activities - run in parallel when pending together) +# ============================================================================ + + +@executor(id="input_router") +async def input_router( + doc: str, + ctx: WorkflowContext[DocumentInput] +) -> None: + """Route input document to parallel processors. + + Accepts a JSON string from the HTTP request and converts to DocumentInput. + """ + # Parse the JSON string input + data = json.loads(doc) if isinstance(doc, str) else doc + document = DocumentInput( + document_id=data.get("document_id", "unknown"), + content=data.get("content", ""), + ) + logger.info("[input_router] Routing document: %s", document.document_id) + await ctx.send_message(document) + + +@executor(id="word_count_processor") +async def word_count_processor( + doc: DocumentInput, + ctx: WorkflowContext[ProcessorResult] +) -> None: + """Process document and count words - runs as an activity.""" + logger.info("[word_count_processor] Processing document: %s", doc.document_id) + + word_count = len(doc.content.split()) + char_count = len(doc.content) + has_numbers = any(c.isdigit() for c in doc.content) + + result = ProcessorResult( + processor_name="word_count", + document_id=doc.document_id, + content=doc.content, + word_count=word_count, + char_count=char_count, + has_numbers=has_numbers, + ) + + await ctx.send_message(result) + + +@executor(id="format_analyzer_processor") +async def format_analyzer_processor( + doc: DocumentInput, + ctx: WorkflowContext[ProcessorResult] +) -> None: + """Analyze document format - runs as an activity in parallel with word_count.""" + logger.info("[format_analyzer_processor] Processing document: %s", doc.document_id) + + # Simple format analysis + lines = doc.content.split('\n') + word_count = len(lines) # Using line count as "word count" for this processor + char_count = sum(len(line) for line in lines) + has_numbers = doc.content.count('.') > 0 # Check for sentences + + result = ProcessorResult( + processor_name="format_analyzer", + document_id=doc.document_id, + content=doc.content, + word_count=word_count, + char_count=char_count, + has_numbers=has_numbers, + ) + + await ctx.send_message(result) + + +@executor(id="aggregator") +async def aggregator( + results: list[ProcessorResult], + ctx: WorkflowContext[AggregatedResults] +) -> None: + """Aggregate results from parallel processors - receives fan-in input.""" + logger.info("[aggregator] Aggregating %d results", len(results)) + + # Extract document info from the first result (all have the same content) + document_id = results[0].document_id if results else "unknown" + content = results[0].content if results else "" + + aggregated = AggregatedResults( + document_id=document_id, + content=content, + processor_results=results, + ) + + await ctx.send_message(aggregated) + + +@executor(id="prepare_for_agents") +async def prepare_for_agents( + 
aggregated: AggregatedResults, + ctx: WorkflowContext[str] +) -> None: + """Prepare content for agent analysis - broadcasts to multiple agents.""" + logger.info("[prepare_for_agents] Preparing content for agents") + + # Send the original content to agents for analysis + await ctx.send_message(aggregated.content) + + +@executor(id="prepare_for_mixed") +async def prepare_for_mixed( + analyses: list[AgentExecutorResponse], + ctx: WorkflowContext[str] +) -> None: + """Prepare results for mixed agent+executor parallel processing. + + Combines agent analysis results into a string that can be consumed by + both the SummaryAgent and the statistics_processor in parallel. + """ + logger.info("[prepare_for_mixed] Preparing for mixed parallel pattern") + + sentiment_text = "" + keyword_text = "" + + for analysis in analyses: + executor_id = analysis.executor_id + text = analysis.agent_run_response.text if analysis.agent_run_response else "" + + if executor_id == SENTIMENT_AGENT_NAME: + sentiment_text = text + elif executor_id == KEYWORD_AGENT_NAME: + keyword_text = text + + # Combine into a string that both agent and executor can process + combined = f"Sentiment Analysis: {sentiment_text}\n\nKeyword Extraction: {keyword_text}" + await ctx.send_message(combined) + + +@executor(id="statistics_processor") +async def statistics_processor( + analysis_text: str, + ctx: WorkflowContext[ProcessorResult] +) -> None: + """Calculate statistics from the analysis - runs in parallel with SummaryAgent.""" + logger.info("[statistics_processor] Calculating statistics") + + # Calculate some statistics from the combined analysis + word_count = len(analysis_text.split()) + char_count = len(analysis_text) + has_numbers = any(c.isdigit() for c in analysis_text) + + result = ProcessorResult( + processor_name="statistics", + document_id="analysis", + content=analysis_text, + word_count=word_count, + char_count=char_count, + has_numbers=has_numbers, + ) + await ctx.send_message(result) + + +class FinalReportExecutor(Executor): + """Executor that compiles the final report from agent analyses.""" + + @handler + async def compile_report( + self, + analyses: list[AgentExecutorResponse | ProcessorResult], + ctx: WorkflowContext[Never, str], + ) -> None: + """Compile final report from mixed agent + processor results.""" + logger.info("[final_report] Compiling report from %d analyses", len(analyses)) + + report_parts = ["=== Document Analysis Report ===\n"] + + for analysis in analyses: + if isinstance(analysis, AgentExecutorResponse): + agent_name = analysis.executor_id + text = analysis.agent_run_response.text if analysis.agent_run_response else "No response" + elif isinstance(analysis, ProcessorResult): + agent_name = f"Processor: {analysis.processor_name}" + text = f"Words: {analysis.word_count}, Chars: {analysis.char_count}" + else: + continue + + report_parts.append(f"\n--- {agent_name} ---") + report_parts.append(text) + + final_report = "\n".join(report_parts) + await ctx.yield_output(final_report) + + +class MixedResultCollector(Executor): + """Collector for mixed agent/executor results.""" + + @handler + async def collect_mixed_results( + self, + results: list[Any], + ctx: WorkflowContext[Never, str], + ) -> None: + """Collect and format results from mixed parallel execution.""" + logger.info("[mixed_collector] Collecting %d mixed results", len(results)) + + output_parts = ["=== Mixed Parallel Execution Results ===\n"] + + for result in results: + if isinstance(result, AgentExecutorResponse): + 
output_parts.append(f"[Agent: {result.executor_id}]") + output_parts.append(result.agent_run_response.text if result.agent_run_response else "No response") + elif isinstance(result, ProcessorResult): + output_parts.append(f"[Processor: {result.processor_name}]") + output_parts.append(f" Words: {result.word_count}, Chars: {result.char_count}") + + await ctx.yield_output("\n".join(output_parts)) + + +# ============================================================================ +# Workflow Construction +# ============================================================================ + + +def _build_client_kwargs() -> dict[str, Any]: + """Build Azure OpenAI client kwargs from environment variables.""" + endpoint = os.getenv(AZURE_OPENAI_ENDPOINT_ENV) + if not endpoint: + raise RuntimeError(f"{AZURE_OPENAI_ENDPOINT_ENV} environment variable is required.") + + deployment = os.getenv(AZURE_OPENAI_DEPLOYMENT_ENV) + if not deployment: + raise RuntimeError(f"{AZURE_OPENAI_DEPLOYMENT_ENV} environment variable is required.") + + client_kwargs: dict[str, Any] = { + "endpoint": endpoint, + "deployment_name": deployment, + } + + api_key = os.getenv(AZURE_OPENAI_API_KEY_ENV) + if api_key: + client_kwargs["api_key"] = api_key + else: + client_kwargs["credential"] = AzureCliCredential() + + return client_kwargs + + +def _create_workflow() -> Workflow: + """Create the parallel workflow definition. + + Workflow structure demonstrating three parallel patterns: + + Pattern 1: Two Executors in Parallel (Fan-out/Fan-in to activities) + ──────────────────────────────────────────────────────────────────── + ┌─> word_count_processor ─────┐ + input_router ──┤ ├──> aggregator + └─> format_analyzer_processor ─┘ + + Pattern 2: Two Agents in Parallel (Fan-out to entities) + ──────────────────────────────────────────────────────── + prepare_for_agents ─┬─> SentimentAgent ──┐ + └─> KeywordAgent ────┤ + └──> prepare_for_mixed + + Pattern 3: Mixed Agent + Executor in Parallel + ────────────────────────────────────────────── + prepare_for_mixed ─┬─> SummaryAgent ────────┐ + └─> statistics_processor ─┤ + └──> final_report + """ + client_kwargs = _build_client_kwargs() + chat_client = AzureOpenAIChatClient(**client_kwargs) + + # Create agents for parallel analysis + sentiment_agent = chat_client.create_agent( + name=SENTIMENT_AGENT_NAME, + instructions=( + "You are a sentiment analysis expert. Analyze the sentiment of the given text. " + "Return JSON with fields: sentiment (positive/negative/neutral), " + "confidence (0.0-1.0), and explanation (brief reasoning)." + ), + response_format=SentimentResult, + ) + + keyword_agent = chat_client.create_agent( + name=KEYWORD_AGENT_NAME, + instructions=( + "You are a keyword extraction expert. Extract important keywords and categories " + "from the given text. Return JSON with fields: keywords (list of strings), " + "and categories (list of topic categories)." + ), + response_format=KeywordResult, + ) + + # Create summary agent for Pattern 3 (mixed parallel) + summary_agent = chat_client.create_agent( + name=SUMMARY_AGENT_NAME, + instructions=( + "You are a summarization expert. Given analysis results (sentiment and keywords), " + "provide a concise summary. Return JSON with fields: summary (brief text), " + "and key_points (list of main takeaways)." 
+ ), + response_format=SummaryResult, + ) + + # Create executor instances + final_report_executor = FinalReportExecutor(id="final_report") + + # Build workflow with parallel patterns + workflow = ( + WorkflowBuilder() + # Start: Route input to parallel processors + .set_start_executor(input_router) + + # Pattern 1: Fan-out to two executors (run in parallel) + .add_fan_out_edges( + source=input_router, + targets=[word_count_processor, format_analyzer_processor], + ) + + # Fan-in: Both processors send results to aggregator + .add_fan_in_edges( + sources=[word_count_processor, format_analyzer_processor], + target=aggregator, + ) + + # Prepare content for agent analysis + .add_edge(aggregator, prepare_for_agents) + + # Pattern 2: Fan-out to two agents (run in parallel) + .add_fan_out_edges( + source=prepare_for_agents, + targets=[sentiment_agent, keyword_agent], + ) + + # Fan-in: Collect agent results into prepare_for_mixed + .add_fan_in_edges( + sources=[sentiment_agent, keyword_agent], + target=prepare_for_mixed, + ) + + # Pattern 3: Fan-out to one agent + one executor (mixed parallel) + .add_fan_out_edges( + source=prepare_for_mixed, + targets=[summary_agent, statistics_processor], + ) + + # Final fan-in: Collect mixed results + .add_fan_in_edges( + sources=[summary_agent, statistics_processor], + target=final_report_executor, + ) + + .build() + ) + + return workflow + + +# ============================================================================ +# Application Entry Point +# ============================================================================ + + +def launch(durable: bool = True) -> AgentFunctionApp | None: + """Launch the function app or DevUI.""" + workflow: Workflow | None = None + + if durable: + workflow = _create_workflow() + app = AgentFunctionApp( + workflow=workflow, + enable_health_check=True, + ) + return app + else: + from pathlib import Path + from agent_framework.devui import serve + from dotenv import load_dotenv + + env_path = Path(__file__).parent / ".env" + load_dotenv(dotenv_path=env_path) + + logger.info("Starting Parallel Workflow Sample") + logger.info("Available at: http://localhost:8095") + logger.info("\nThis workflow demonstrates:") + logger.info("- Pattern 1: Two executors running in parallel") + logger.info("- Pattern 2: Two agents running in parallel") + logger.info("- Pattern 3: Mixed agent + executor running in parallel") + logger.info("- Fan-in aggregation of parallel results") + + workflow = _create_workflow() + serve(entities=[workflow], port=8095, auto_open=True) + + return None + + +# Default: Azure Functions mode +# Run with `python function_app.py --maf` for pure MAF mode with DevUI +app = launch(durable=True) + + +if __name__ == "__main__": + import sys + + if "--maf" in sys.argv: + # Run in pure MAF mode with DevUI + launch(durable=False) + else: + print("Usage: python function_app.py --maf") + print(" --maf Run in pure MAF mode with DevUI (http://localhost:8095)") + print("\nFor Azure Functions mode, use: func start") diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/host.json b/python/samples/getting_started/azure_functions/11_workflow_parallel/host.json new file mode 100644 index 0000000000..292562af8e --- /dev/null +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/host.json @@ -0,0 +1,16 @@ +{ + "version": "2.0", + "extensionBundle": { + "id": "Microsoft.Azure.Functions.ExtensionBundle", + "version": "[4.*, 5.0.0)" + }, + "extensions": { + "durableTask": { + "hubName": 
"%TASKHUB_NAME%", + "storageProvider": { + "type": "AzureManaged", + "connectionStringName": "DURABLE_TASK_SCHEDULER_CONNECTION_STRING" + } + } + } +} diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/local.settings.json.sample b/python/samples/getting_started/azure_functions/11_workflow_parallel/local.settings.json.sample new file mode 100644 index 0000000000..30edea6c08 --- /dev/null +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/local.settings.json.sample @@ -0,0 +1,12 @@ +{ + "IsEncrypted": false, + "Values": { + "FUNCTIONS_WORKER_RUNTIME": "python", + "AzureWebJobsStorage": "UseDevelopmentStorage=true", + "DURABLE_TASK_SCHEDULER_CONNECTION_STRING": "Endpoint=http://localhost:8080;TaskHub=default;Authentication=None", + "TASKHUB_NAME": "default", + "AZURE_OPENAI_ENDPOINT": "https://.openai.azure.com/", + "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME": "", + "AZURE_OPENAI_API_KEY": "" + } +} diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/requirements.txt b/python/samples/getting_started/azure_functions/11_workflow_parallel/requirements.txt new file mode 100644 index 0000000000..792ae4864e --- /dev/null +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/requirements.txt @@ -0,0 +1,3 @@ +agent-framework-azurefunctions +agent-framework +azure-identity diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/.gitignore b/python/samples/getting_started/azure_functions/12_workflow_hitl/.gitignore new file mode 100644 index 0000000000..7097fe0170 --- /dev/null +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/.gitignore @@ -0,0 +1,5 @@ +# Local settings - copy from local.settings.json.sample and fill in your values +local.settings.json +__pycache__/ +*.pyc +.venv/ diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/README.md b/python/samples/getting_started/azure_functions/12_workflow_hitl/README.md new file mode 100644 index 0000000000..2bb84f16dc --- /dev/null +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/README.md @@ -0,0 +1,141 @@ +# 12. Workflow with Human-in-the-Loop (HITL) + +This sample demonstrates how to integrate human approval into a MAF workflow running on Azure Durable Functions using the MAF `request_info` and `@response_handler` pattern. + +## Overview + +The sample implements a content moderation pipeline: + +1. **User starts workflow** with content for publication via HTTP endpoint +2. **AI Agent analyzes** the content for policy compliance +3. **Workflow pauses** and requests human reviewer approval +4. **Human responds** via HTTP endpoint with approval/rejection +5. **Workflow resumes** and publishes or rejects the content + +## Key Concepts + +### MAF HITL Pattern + +This sample uses MAF's built-in human-in-the-loop pattern: + +```python +# In an executor, request human input +await ctx.request_info( + request_data=HumanApprovalRequest(...), + response_type=HumanApprovalResponse, +) + +# Handle the response in a separate method +@response_handler +async def handle_approval_response( + self, + original_request: HumanApprovalRequest, + response: HumanApprovalResponse, + ctx: WorkflowContext, +) -> None: + # Process the human's decision + ... 
+``` + +### Automatic HITL Endpoints + +`AgentFunctionApp` automatically provides all the HTTP endpoints needed for HITL: + +| Endpoint | Description | +|----------|-------------| +| `POST /api/workflow/run` | Start the workflow | +| `GET /api/workflow/status/{instanceId}` | Check status and pending HITL requests | +| `POST /api/workflow/respond/{instanceId}/{requestId}` | Send human response | +| `GET /api/health` | Health check | + +### Durable Functions Integration + +When running on Durable Functions, the HITL pattern maps to: + +| MAF Concept | Durable Functions | +|-------------|-------------------| +| `ctx.request_info()` | Workflow pauses, custom status updated | +| `RequestInfoEvent` | Exposed via status endpoint | +| HTTP response | `client.raise_event(instance_id, request_id, data)` | +| `@response_handler` | Workflow resumes, handler invoked | + +## Workflow Architecture + +``` +┌─────────────────┐ ┌──────────────────────┐ ┌────────────────────────┐ +│ Input Router │ ──► │ Content Analyzer │ ──► │ Content Analyzer │ +│ Executor │ │ Agent (AI) │ │ Executor (Parse JSON) │ +└─────────────────┘ └──────────────────────┘ └────────────────────────┘ + │ + ▼ +┌─────────────────┐ ┌──────────────────────┐ +│ Publish │ ◄── │ Human Review │ ◄── HITL PAUSE +│ Executor │ │ Executor │ (wait for external event) +└─────────────────┘ └──────────────────────┘ +``` + +## Prerequisites + +1. **Azure OpenAI** - Access to Azure OpenAI with a deployed chat model +2. **Durable Task Scheduler** - Local emulator or Azure deployment +3. **Azurite** - Local Azure Storage emulator +4. **Azure CLI** - For authentication (`az login`) + +## Setup + +1. Copy the sample settings file: + ```bash + cp local.settings.json.sample local.settings.json + ``` + +2. Update `local.settings.json` with your Azure OpenAI credentials: + ```json + { + "Values": { + "AZURE_OPENAI_ENDPOINT": "https://your-resource.openai.azure.com/", + "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME": "gpt-4o" + } + } + ``` + +3. Start the local emulators: + ```bash + # Terminal 1: Start Azurite + azurite --silent --location . + + # Terminal 2: Start Durable Task Scheduler (if using local emulator) + # Follow Durable Task Scheduler setup instructions + ``` + +4. Start the Function App: + ```bash + func start + ``` + +## Running in Pure MAF Mode + +You can also run this sample in pure MAF mode (without Durable Functions) using the DevUI: + +```bash +python function_app.py --maf +``` + +This launches the DevUI at http://localhost:8096 where you can interact with the workflow directly. This is useful for: +- Local development and debugging +- Testing the HITL pattern without Durable Functions infrastructure +- Comparing behavior between MAF and Durable modes + +## Testing + +Use the `demo.http` file with the VS Code REST Client extension: + +1. **Start workflow** - `POST /api/workflow/run` with content payload +2. **Check status** - `GET /api/workflow/status/{instanceId}` to see pending HITL requests +3. **Send response** - `POST /api/workflow/respond/{instanceId}/{requestId}` with approval +4. 
**Check result** - `GET /api/workflow/status/{instanceId}` to see final output + +## Related Samples + +- [07_single_agent_orchestration_hitl](../07_single_agent_orchestration_hitl/) - HITL at orchestrator level (not using MAF pattern) +- [09_workflow_shared_state](../09_workflow_shared_state/) - Workflow with shared state +- [guessing_game_with_human_input](../../workflows/human-in-the-loop/guessing_game_with_human_input.py) - MAF HITL pattern (non-durable) diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/demo.http b/python/samples/getting_started/azure_functions/12_workflow_hitl/demo.http new file mode 100644 index 0000000000..b59ae8b61c --- /dev/null +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/demo.http @@ -0,0 +1,123 @@ +### ============================================================================ +### Workflow HITL Sample - Content Moderation with Human Approval +### ============================================================================ +### This sample demonstrates MAF workflows with human-in-the-loop using the +### request_info / @response_handler pattern on Azure Durable Functions. +### +### The AgentFunctionApp automatically provides all HITL endpoints. +### +### Prerequisites: +### 1. Start Azurite: azurite --silent --location . +### 2. Start Durable Task Scheduler emulator +### 3. Configure local.settings.json with Azure OpenAI credentials +### 4. Run: func start +### ============================================================================ + + +### ============================================================================ +### 1. Start the Workflow with Content for Moderation +### ============================================================================ +### This starts the workflow. The AI will analyze the content, then the workflow +### will pause waiting for human approval. + +POST http://localhost:7071/api/workflow/run +Content-Type: application/json + +{ + "content_id": "article-001", + "title": "Introduction to AI in Healthcare", + "body": "Artificial intelligence is revolutionizing healthcare by enabling faster diagnosis, personalized treatment plans, and improved patient outcomes. Machine learning algorithms can analyze medical images with remarkable accuracy, often detecting issues that human radiologists might miss.", + "author": "Dr. Jane Smith" +} + + +### ============================================================================ +### 2. Start Workflow with Potentially Problematic Content +### ============================================================================ +### This content should trigger higher risk assessment from the AI analyzer. + +POST http://localhost:7071/api/workflow/run +Content-Type: application/json + +{ + "content_id": "article-002", + "title": "Get Rich Quick Scheme", + "body": "Click here NOW to make $10,000 overnight! This SECRET method is GUARANTEED to work! Limited time offer - act NOW before it's too late! Send your bank details immediately!", + "author": "Definitely Not Spam" +} + + +### ============================================================================ +### 3. Check Workflow Status +### ============================================================================ +### Replace INSTANCE_ID with the value returned from the run call. +### The status will show pending HITL requests if waiting for human approval. 
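+### The run call above returns JSON containing "instanceId" and
+### "statusQueryGetUri"; copy the instanceId value into the variable below.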
+ +@instanceId = 3130c486c9374e4e87125cbd9a238dfc + +GET http://localhost:7071/api/workflow/status/{{instanceId}} + + +### ============================================================================ +### 4. Send Human Approval +### ============================================================================ +### Approve the content for publication. +### Replace INSTANCE_ID and REQUEST_ID with values from the status response. + +@requestId = 1682e5f8-0917-4b68-aa04-d4688cfa2e69 + +POST http://localhost:7071/api/workflow/respond/{{instanceId}}/{{requestId}} +Content-Type: application/json + +{ + "approved": true, + "reviewer_notes": "Content is appropriate and well-written. Approved for publication." +} + + +### ============================================================================ +### 5. Send Human Rejection +### ============================================================================ +### Reject the content with feedback. + +POST http://localhost:7071/api/workflow/respond/{{instanceId}}/{{requestId}} +Content-Type: application/json + +{ + "approved": false, + "reviewer_notes": "Content appears to be spam. Contains multiple spam indicators including urgency language, promises of easy money, and requests for personal information." +} + + +### ============================================================================ +### Example Workflow - Complete Happy Path +### ============================================================================ +### +### Step 1: Start workflow with content +### POST http://localhost:7071/api/workflow/run +### -> Returns instanceId: "abc123..." +### +### Step 2: Check status (workflow is waiting for human input) +### GET http://localhost:7071/api/workflow/status/abc123 +### -> Returns pendingHumanInputRequests with requestId: "req-456..." +### +### Step 3: Approve content +### POST http://localhost:7071/api/workflow/respond/abc123/req-456 +### { +### "approved": true, +### "reviewer_notes": "Looks good!" +### } +### -> Returns success +### +### Step 4: Check final status +### GET http://localhost:7071/api/workflow/status/abc123 +### -> Returns runtimeStatus: "Completed", output: "✅ Content approved..." +### +### ============================================================================ + + +### ============================================================================ +### Health Check +### ============================================================================ + +GET http://localhost:7071/api/health diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py b/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py new file mode 100644 index 0000000000..bb36832b17 --- /dev/null +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py @@ -0,0 +1,468 @@ +# Copyright (c) Microsoft. All rights reserved. +"""Workflow with Human-in-the-Loop (HITL) using MAF request_info Pattern. + +This sample demonstrates how to integrate human approval into a MAF workflow +running on Azure Durable Functions. It uses the MAF `request_info` and +`@response_handler` pattern for structured HITL interactions. + +The workflow simulates a content moderation pipeline: +1. User submits content for publication +2. An AI agent analyzes the content for policy compliance +3. A human reviewer is prompted to approve/reject the content +4. 
Based on approval, content is either published or rejected + +Key architectural points: +- Uses MAF's `ctx.request_info()` to pause workflow and request human input +- Uses `@response_handler` decorator to handle the human's response +- AgentFunctionApp automatically provides HITL endpoints for status and response +- Durable Functions provides durability while waiting for human input + +Prerequisites: +- Azure OpenAI configured with required environment variables +- Durable Task Scheduler connection string +- Authentication via Azure CLI (az login) +""" + +import json +import logging +import os +from dataclasses import dataclass +from typing import Any + +from agent_framework import ( + AgentExecutorRequest, + AgentExecutorResponse, + ChatMessage, + Executor, + Role, + Workflow, + WorkflowBuilder, + WorkflowContext, + handler, + response_handler, +) +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import AzureCliCredential +from pydantic import BaseModel +from typing_extensions import Never + +from agent_framework_azurefunctions import AgentFunctionApp + +logger = logging.getLogger(__name__) + +# Environment variable names +AZURE_OPENAI_ENDPOINT_ENV = "AZURE_OPENAI_ENDPOINT" +AZURE_OPENAI_DEPLOYMENT_ENV = "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME" +AZURE_OPENAI_API_KEY_ENV = "AZURE_OPENAI_API_KEY" + +# Agent names +CONTENT_ANALYZER_AGENT_NAME = "ContentAnalyzerAgent" + + +# ============================================================================ +# Data Models +# ============================================================================ + + +class ContentAnalysisResult(BaseModel): + """Structured output from the content analysis agent.""" + + is_appropriate: bool + risk_level: str # low, medium, high + concerns: list[str] + recommendation: str + + +@dataclass +class ContentSubmission: + """Content submitted for moderation.""" + + content_id: str + title: str + body: str + author: str + + +@dataclass +class HumanApprovalRequest: + """Request sent to human reviewer for approval. + + This is the payload passed to ctx.request_info() and will be + exposed via the orchestration status for external systems to retrieve. + """ + + content_id: str + title: str + body: str + author: str + ai_analysis: ContentAnalysisResult + prompt: str + + +class HumanApprovalResponse(BaseModel): + """Response from human reviewer. + + This is what the external system must send back via the HITL response endpoint. + """ + + approved: bool + reviewer_notes: str = "" + + +@dataclass +class ModerationResult: + """Final result of the moderation workflow.""" + + content_id: str + status: str # "approved", "rejected" + ai_analysis: ContentAnalysisResult | None + reviewer_notes: str + + +# ============================================================================ +# Agent Instructions +# ============================================================================ + +CONTENT_ANALYZER_INSTRUCTIONS = """You are a content moderation assistant that analyzes user-submitted content +for policy compliance. Evaluate the content for: + +1. Appropriateness - Is the content suitable for a general audience? +2. Risk level - Rate as 'low', 'medium', or 'high' based on potential issues +3. Concerns - List any specific issues found (empty list if none) +4. 
Recommendation - Provide a brief recommendation for human reviewers + +Return a JSON response with: +- is_appropriate: boolean +- risk_level: string ('low', 'medium', 'high') +- concerns: list of strings +- recommendation: string + +Be thorough but fair in your analysis.""" + + +# ============================================================================ +# Executors +# ============================================================================ + + +@dataclass +class AnalysisWithSubmission: + """Combines the AI analysis with the original submission for downstream processing.""" + + submission: ContentSubmission + analysis: ContentAnalysisResult + + +class ContentAnalyzerExecutor(Executor): + """Parses the AI agent's response and prepares for human review.""" + + def __init__(self): + super().__init__(id="content_analyzer_executor") + + @handler + async def handle_analysis( + self, + response: AgentExecutorResponse, + ctx: WorkflowContext[AnalysisWithSubmission], + ) -> None: + """Parse the AI analysis and forward with submission context.""" + analysis = ContentAnalysisResult.model_validate_json(response.agent_run_response.text) + + # Retrieve the original submission from shared state + submission: ContentSubmission = await ctx.get_shared_state("current_submission") + + await ctx.send_message(AnalysisWithSubmission(submission=submission, analysis=analysis)) + + +class HumanReviewExecutor(Executor): + """Requests human approval using MAF's request_info pattern. + + This executor demonstrates the core HITL pattern: + 1. Receives the AI analysis result + 2. Calls ctx.request_info() to pause and request human input + 3. The @response_handler method processes the human's response + """ + + def __init__(self): + super().__init__(id="human_review_executor") + + @handler + async def request_review( + self, + data: AnalysisWithSubmission, + ctx: WorkflowContext, + ) -> None: + """Request human review for the content. + + This method: + 1. Constructs the approval request with all context + 2. Calls request_info to pause the workflow + 3. The workflow will resume when a response is provided via the HITL endpoint + """ + submission = data.submission + analysis = data.analysis + + # Construct the human-readable prompt + prompt = ( + f"Please review the following content for publication:\n\n" + f"Title: {submission.title}\n" + f"Author: {submission.author}\n" + f"Content: {submission.body}\n\n" + f"AI Analysis:\n" + f"- Appropriate: {analysis.is_appropriate}\n" + f"- Risk Level: {analysis.risk_level}\n" + f"- Concerns: {', '.join(analysis.concerns) if analysis.concerns else 'None'}\n" + f"- Recommendation: {analysis.recommendation}\n\n" + f"Please approve or reject this content." + ) + + approval_request = HumanApprovalRequest( + content_id=submission.content_id, + title=submission.title, + body=submission.body, + author=submission.author, + ai_analysis=analysis, + prompt=prompt, + ) + + # Store analysis in shared state for the response handler + await ctx.set_shared_state("pending_analysis", data) + + # Request human input - workflow will pause here + # The response_type specifies what we expect back + await ctx.request_info( + request_data=approval_request, + response_type=HumanApprovalResponse, + ) + + @response_handler + async def handle_approval_response( + self, + original_request: HumanApprovalRequest, + response: HumanApprovalResponse, + ctx: WorkflowContext[ModerationResult], + ) -> None: + """Process the human reviewer's decision. 
+ + This method is called automatically when a response to request_info is received. + The original_request contains the HumanApprovalRequest we sent. + The response contains the HumanApprovalResponse from the reviewer. + """ + logger.info( + "Human review received for content %s: approved=%s, notes=%s", + original_request.content_id, + response.approved, + response.reviewer_notes, + ) + + # Create the final moderation result + status = "approved" if response.approved else "rejected" + result = ModerationResult( + content_id=original_request.content_id, + status=status, + ai_analysis=original_request.ai_analysis, + reviewer_notes=response.reviewer_notes, + ) + + await ctx.send_message(result) + + +class PublishExecutor(Executor): + """Handles the final publication or rejection of content.""" + + def __init__(self): + super().__init__(id="publish_executor") + + @handler + async def handle_result( + self, + result: ModerationResult, + ctx: WorkflowContext[Never, str], + ) -> None: + """Finalize the moderation and yield output.""" + if result.status == "approved": + message = ( + f"✅ Content '{result.content_id}' has been APPROVED and published.\n" + f"Reviewer notes: {result.reviewer_notes or 'None'}" + ) + else: + message = ( + f"❌ Content '{result.content_id}' has been REJECTED.\n" + f"Reviewer notes: {result.reviewer_notes or 'None'}" + ) + + logger.info(message) + await ctx.yield_output(message) + + +# ============================================================================ +# Input Router Executor +# ============================================================================ + + +def _build_client_kwargs() -> dict[str, Any]: + """Build Azure OpenAI client configuration from environment variables.""" + endpoint = os.getenv(AZURE_OPENAI_ENDPOINT_ENV) + if not endpoint: + raise RuntimeError(f"{AZURE_OPENAI_ENDPOINT_ENV} environment variable is required.") + + deployment = os.getenv(AZURE_OPENAI_DEPLOYMENT_ENV) + if not deployment: + raise RuntimeError(f"{AZURE_OPENAI_DEPLOYMENT_ENV} environment variable is required.") + + client_kwargs: dict[str, Any] = { + "endpoint": endpoint, + "deployment_name": deployment, + } + + api_key = os.getenv(AZURE_OPENAI_API_KEY_ENV) + if api_key: + client_kwargs["api_key"] = api_key + else: + client_kwargs["credential"] = AzureCliCredential() + + return client_kwargs + + +class InputRouterExecutor(Executor): + """Routes incoming content submission to the analysis agent.""" + + def __init__(self): + super().__init__(id="input_router") + + @handler + async def route_input( + self, + input_json: str, + ctx: WorkflowContext[AgentExecutorRequest], + ) -> None: + """Parse input and create agent request.""" + data = json.loads(input_json) if isinstance(input_json, str) else input_json + + submission = ContentSubmission( + content_id=data.get("content_id", "unknown"), + title=data.get("title", "Untitled"), + body=data.get("body", ""), + author=data.get("author", "Anonymous"), + ) + + # Store submission in shared state for later retrieval + await ctx.set_shared_state("current_submission", submission) + + # Create the agent request + message = ( + f"Please analyze the following content for policy compliance:\n\n" + f"Title: {submission.title}\n" + f"Author: {submission.author}\n" + f"Content:\n{submission.body}" + ) + + await ctx.send_message( + AgentExecutorRequest( + messages=[ChatMessage(Role.USER, text=message)], + should_respond=True, + ) + ) + + +# ============================================================================ +# Workflow Creation +# 
============================================================================ + + +def _create_workflow() -> Workflow: + """Create the content moderation workflow with HITL.""" + client_kwargs = _build_client_kwargs() + chat_client = AzureOpenAIChatClient(**client_kwargs) + + # Create the content analysis agent + content_analyzer_agent = chat_client.create_agent( + name=CONTENT_ANALYZER_AGENT_NAME, + instructions=CONTENT_ANALYZER_INSTRUCTIONS, + response_format=ContentAnalysisResult, + ) + + # Create executors + input_router = InputRouterExecutor() + content_analyzer_executor = ContentAnalyzerExecutor() + human_review_executor = HumanReviewExecutor() + publish_executor = PublishExecutor() + + # Build the workflow graph + # Flow: + # input_router -> content_analyzer_agent -> content_analyzer_executor + # -> human_review_executor (HITL pause here) -> publish_executor + workflow = ( + WorkflowBuilder() + .set_start_executor(input_router) + .add_edge(input_router, content_analyzer_agent) + .add_edge(content_analyzer_agent, content_analyzer_executor) + .add_edge(content_analyzer_executor, human_review_executor) + .add_edge(human_review_executor, publish_executor) + .build() + ) + + return workflow + + +# ============================================================================ +# Application Entry Point +# ============================================================================ + + +def launch(durable: bool = True) -> AgentFunctionApp | None: + """Launch the function app or DevUI. + + Args: + durable: If True, returns AgentFunctionApp for Azure Functions. + If False, launches DevUI for local MAF development. + """ + if durable: + # Azure Functions mode with Durable Functions + # The app automatically provides HITL endpoints: + # - POST /api/workflow/run - Start the workflow + # - GET /api/workflow/status/{instanceId} - Check status and pending HITL requests + # - POST /api/workflow/respond/{instanceId}/{requestId} - Send HITL response + # - GET /api/health - Health check + workflow = _create_workflow() + app = AgentFunctionApp(workflow=workflow, enable_health_check=True) + return app + else: + # Pure MAF mode with DevUI for local development + from pathlib import Path + + from agent_framework.devui import serve + from dotenv import load_dotenv + + env_path = Path(__file__).parent / ".env" + load_dotenv(dotenv_path=env_path) + + logger.info("Starting Workflow HITL Sample in MAF mode") + logger.info("Available at: http://localhost:8096") + logger.info("\nThis workflow demonstrates:") + logger.info("- Human-in-the-loop using request_info / @response_handler pattern") + logger.info("- AI content analysis with structured output") + logger.info("- Human approval workflow integration") + logger.info("\nFlow: InputRouter -> ContentAnalyzer Agent -> HumanReview -> Publish") + + workflow = _create_workflow() + serve(entities=[workflow], port=8096, auto_open=True) + + return None + + +# Default: Azure Functions mode +# Run with `python function_app.py --maf` for pure MAF mode with DevUI +app = launch(durable=True) + + +if __name__ == "__main__": + import sys + + if "--maf" in sys.argv: + # Run in pure MAF mode with DevUI + launch(durable=False) + else: + print("Usage: python function_app.py --maf") + print(" --maf Run in pure MAF mode with DevUI (http://localhost:8096)") + print("\nFor Azure Functions mode, use: func start") diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/host.json b/python/samples/getting_started/azure_functions/12_workflow_hitl/host.json new file mode 
100644 index 0000000000..292562af8e --- /dev/null +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/host.json @@ -0,0 +1,16 @@ +{ + "version": "2.0", + "extensionBundle": { + "id": "Microsoft.Azure.Functions.ExtensionBundle", + "version": "[4.*, 5.0.0)" + }, + "extensions": { + "durableTask": { + "hubName": "%TASKHUB_NAME%", + "storageProvider": { + "type": "AzureManaged", + "connectionStringName": "DURABLE_TASK_SCHEDULER_CONNECTION_STRING" + } + } + } +} diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/local.settings.json.sample b/python/samples/getting_started/azure_functions/12_workflow_hitl/local.settings.json.sample new file mode 100644 index 0000000000..69c08a3386 --- /dev/null +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/local.settings.json.sample @@ -0,0 +1,11 @@ +{ + "IsEncrypted": false, + "Values": { + "AzureWebJobsStorage": "UseDevelopmentStorage=true", + "DURABLE_TASK_SCHEDULER_CONNECTION_STRING": "Endpoint=http://localhost:8080;TaskHub=default;Authentication=None", + "TASKHUB_NAME": "default", + "FUNCTIONS_WORKER_RUNTIME": "python", + "AZURE_OPENAI_ENDPOINT": "", + "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME": "" + } +} diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/requirements.txt b/python/samples/getting_started/azure_functions/12_workflow_hitl/requirements.txt new file mode 100644 index 0000000000..85e158b8d4 --- /dev/null +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/requirements.txt @@ -0,0 +1,3 @@ +agent-framework-azurefunctions +azure-identity +agents-maf From fa1cf84324ee9d664641301232cc627d253f89e1 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Mon, 2 Feb 2026 15:50:35 -0600 Subject: [PATCH 02/29] fix compatability with latest framework changes and add integration tests --- .../agent_framework_azurefunctions/_utils.py | 14 +- .../_workflow.py | 62 +++-- python/packages/azurefunctions/pyproject.toml | 2 +- .../test_09_workflow_shared_state.py | 91 ++++++++ .../test_10_workflow_no_shared_state.py | 107 +++++++++ .../test_11_workflow_parallel.py | 134 +++++++++++ .../test_12_workflow_hitl.py | 215 ++++++++++++++++++ .../packages/azurefunctions/tests/test_app.py | 6 +- .../azurefunctions/tests/test_utils.py | 30 +-- .../azurefunctions/tests/test_workflow.py | 26 +-- .../09_workflow_shared_state/function_app.py | 19 +- .../09_workflow_shared_state/host.json | 6 +- .../function_app.py | 28 ++- .../10_workflow_no_shared_state/host.json | 6 +- .../11_workflow_parallel/function_app.py | 18 +- .../11_workflow_parallel/host.json | 6 +- .../12_workflow_hitl/function_app.py | 16 +- .../12_workflow_hitl/host.json | 6 +- 18 files changed, 691 insertions(+), 101 deletions(-) create mode 100644 python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py create mode 100644 python/packages/azurefunctions/tests/integration_tests/test_10_workflow_no_shared_state.py create mode 100644 python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py create mode 100644 python/packages/azurefunctions/tests/integration_tests/test_12_workflow_hitl.py diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py index 3b25f5db85..3cb10a1ae0 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py @@ -17,7 +17,7 @@ from 
agent_framework import ( AgentExecutorRequest, AgentExecutorResponse, - AgentRunResponse, + AgentResponse, ChatMessage, CheckpointStorage, Message, @@ -254,7 +254,7 @@ def deserialize_value(data: Any, type_registry: dict[str, type] | None = None) - except Exception: logger.debug("Could not reconstruct as AgentExecutorRequest, trying next strategy") - if type_name == "AgentExecutorResponse" or ("executor_id" in data and "agent_run_response" in data): + if type_name == "AgentExecutorResponse" or ("executor_id" in data and "agent_response" in data): try: return reconstruct_agent_executor_response(data) except Exception: @@ -418,9 +418,9 @@ def reconstruct_agent_executor_request(data: dict[str, Any]) -> AgentExecutorReq def reconstruct_agent_executor_response(data: dict[str, Any]) -> AgentExecutorResponse: """Helper to reconstruct AgentExecutorResponse from dict.""" - # Reconstruct AgentRunResponse - arr_data = data.get("agent_run_response", {}) - agent_run_response = AgentRunResponse.from_dict(arr_data) if isinstance(arr_data, dict) else arr_data + # Reconstruct AgentResponse + arr_data = data.get("agent_response", {}) + agent_response = AgentResponse.from_dict(arr_data) if isinstance(arr_data, dict) else arr_data # Reconstruct full_conversation fc_data = data.get("full_conversation", []) @@ -429,7 +429,7 @@ def reconstruct_agent_executor_response(data: dict[str, Any]) -> AgentExecutorRe full_conversation = [ChatMessage.from_dict(m) if isinstance(m, dict) else m for m in fc_data] return AgentExecutorResponse( - executor_id=data["executor_id"], agent_run_response=agent_run_response, full_conversation=full_conversation + executor_id=data["executor_id"], agent_response=agent_response, full_conversation=full_conversation ) @@ -484,7 +484,7 @@ def reconstruct_message_for_handler(data: Any, input_types: list[type[Any]]) -> return data # Try AgentExecutorResponse first - it needs special handling for nested objects - if "executor_id" in data and "agent_run_response" in data: + if "executor_id" in data and "agent_response" in data: try: return reconstruct_agent_executor_response(data) except Exception: diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 20bb12db63..87cb7d92fd 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -31,7 +31,7 @@ AgentExecutor, AgentExecutorRequest, AgentExecutorResponse, - AgentRunResponse, + AgentResponse, ChatMessage, Workflow, ) @@ -112,6 +112,40 @@ class PendingHITLRequest: # ============================================================================ +def _evaluate_edge_condition_sync(edge: Any, message: Any) -> bool: + """Evaluate an edge's condition synchronously. + + This is needed because Durable Functions orchestrators use generators, + not async/await, so we cannot call async methods like edge.should_route(). 
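+    Synchronous conditions are evaluated inline; if a condition returns an
+    awaitable, a RuntimeWarning is emitted and the edge is treated as matching.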
+ + Args: + edge: The Edge object with a _condition attribute + message: The message to evaluate against the condition + + Returns: + True if the edge should be traversed, False otherwise + """ + # Access the internal condition directly since should_route is async + condition = getattr(edge, "_condition", None) + if condition is None: + return True + result = condition(message) + # If the condition is async, we cannot await it in a generator context + # Log a warning and assume True (or False for safety) + if hasattr(result, "__await__"): + import warnings + + warnings.warn( + f"Edge condition for {edge.source_id}->{edge.target_id} is async, " + "which is not supported in Durable Functions orchestrators. " + "The edge will be traversed unconditionally.", + RuntimeWarning, + stacklevel=2, + ) + return True + return bool(result) + + def route_message_through_edge_groups( edge_groups: list[EdgeGroup], source_id: str, @@ -147,7 +181,7 @@ def route_message_through_edge_groups( elif isinstance(group, SingleEdgeGroup): # SingleEdgeGroup has exactly one edge edge = group.edges[0] - if edge.should_route(message): + if _evaluate_edge_condition_sync(edge, message): targets.append(edge.target_id) elif isinstance(group, FanInEdgeGroup): @@ -158,7 +192,7 @@ def route_message_through_edge_groups( else: # Generic EdgeGroup: check each edge's condition for edge in group.edges: - if edge.source_id == source_id and edge.should_route(message): + if edge.source_id == source_id and _evaluate_edge_condition_sync(edge, message): targets.append(edge.target_id) return targets @@ -189,7 +223,7 @@ def build_agent_executor_response( assistant_message = ChatMessage(role="assistant", text=final_text) - agent_run_response = AgentRunResponse( + agent_response = AgentResponse( messages=[assistant_message], ) @@ -204,7 +238,7 @@ def build_agent_executor_response( return AgentExecutorResponse( executor_id=executor_id, - agent_run_response=agent_run_response, + agent_response=agent_response, full_conversation=full_conversation, ) @@ -275,7 +309,7 @@ def _prepare_activity_task( def _process_agent_response( - agent_response: AgentRunResponse, + agent_response: AgentResponse, executor_id: str, message: Any, ) -> ExecutorResult: @@ -619,7 +653,7 @@ def run_workflow_orchestrator( for executor_id, message, _source_executor_id in remaining_agent_messages: logger.debug("Processing sequential message for agent: %s", executor_id) task = _prepare_agent_task(context, executor_id, message) - agent_response: AgentRunResponse = yield task + agent_response: AgentResponse = yield task logger.debug("Agent %s sequential response completed", executor_id) result = _process_agent_response(agent_response, executor_id, message) @@ -800,11 +834,11 @@ def _prepare_all_tasks( def _extract_message_content(message: Any) -> str: """Extract text content from various message types.""" message_content = "" - if isinstance(message, AgentExecutorResponse) and message.agent_run_response: - if message.agent_run_response.text: - message_content = message.agent_run_response.text - elif message.agent_run_response.messages: - message_content = message.agent_run_response.messages[-1].text or "" + if isinstance(message, AgentExecutorResponse) and message.agent_response: + if message.agent_response.text: + message_content = message.agent_response.text + elif message.agent_response.messages: + message_content = message.agent_response.messages[-1].text or "" elif isinstance(message, AgentExecutorRequest) and message.messages: # Extract text from the last message in the 
request message_content = message.messages[-1].text or "" @@ -835,9 +869,9 @@ def _extract_message_content_from_dict(message: dict[str, Any]) -> str: message_content = last_msg.get("text") or last_msg.get("_text") or "" elif hasattr(last_msg, "text"): message_content = last_msg.text or "" - elif "agent_run_response" in message: + elif "agent_response" in message: # AgentExecutorResponse dict - arr = message.get("agent_run_response", {}) + arr = message.get("agent_response", {}) if isinstance(arr, dict): message_content = arr.get("text") or "" if not message_content and arr.get("messages"): diff --git a/python/packages/azurefunctions/pyproject.toml b/python/packages/azurefunctions/pyproject.toml index 6246b686a6..2272285971 100644 --- a/python/packages/azurefunctions/pyproject.toml +++ b/python/packages/azurefunctions/pyproject.toml @@ -50,7 +50,7 @@ asyncio_default_fixture_loop_scope = "function" filterwarnings = [ "ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*" ] -timeout = 120 +timeout = 300 markers = [ "integration: marks tests as integration tests (require running function app)", "orchestration: marks tests that use orchestrations (require Azurite)", diff --git a/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py b/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py new file mode 100644 index 0000000000..c9e6b16644 --- /dev/null +++ b/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft. All rights reserved. +""" +Integration Tests for Workflow Shared State Sample + +Tests the workflow shared state sample for conditional email processing +with shared state management. + +The function app is automatically started by the test fixture. + +Prerequisites: +- Azure OpenAI credentials configured (see packages/azurefunctions/tests/integration_tests/.env.example) +- Azurite running for durable orchestrations (or Azure Storage account configured) + +Usage: + # Start Azurite (if not already running) + azurite & + + # Run tests + uv run pytest packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py -v +""" + +import pytest +from conftest import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled + +# Module-level markers - applied to all tests in this file +pytestmark = [ + pytest.mark.sample("09_workflow_shared_state"), + pytest.mark.usefixtures("function_app_for_test"), + skip_if_azure_functions_integration_tests_disabled, +] + + +@pytest.mark.orchestration +class TestWorkflowSharedState: + """Tests for 09_workflow_shared_state sample.""" + + def test_workflow_with_spam_email(self, base_url: str) -> None: + """Test workflow with spam email content - should be detected and handled as spam.""" + spam_content = "URGENT! You have won $1,000,000! Click here to claim your prize now before it expires!" 
+ + # Start orchestration with spam email + response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", spam_content) + assert response.status_code == 202 + data = response.json() + assert "instanceId" in data + assert "statusQueryGetUri" in data + + # Wait for completion + status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) + assert status["runtimeStatus"] == "Completed" + assert "output" in status + + def test_workflow_with_legitimate_email(self, base_url: str) -> None: + """Test workflow with legitimate email content - should generate response.""" + legitimate_content = ( + "Hi team, just a reminder about the sprint planning meeting tomorrow at 10 AM. " + "Please review the agenda items in Jira before the call." + ) + + # Start orchestration with legitimate email + response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", legitimate_content) + assert response.status_code == 202 + data = response.json() + assert "instanceId" in data + assert "statusQueryGetUri" in data + + # Wait for completion + status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) + assert status["runtimeStatus"] == "Completed" + assert "output" in status + + def test_workflow_with_phishing_email(self, base_url: str) -> None: + """Test workflow with phishing email - should be detected as spam.""" + phishing_content = ( + "Dear Customer, Your account has been compromised! " + "Click this link immediately to secure your account: http://totallylegit.suspicious.com/secure" + ) + + # Start orchestration with phishing email + response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", phishing_content) + assert response.status_code == 202 + data = response.json() + assert "instanceId" in data + + # Wait for completion + status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) + assert status["runtimeStatus"] == "Completed" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/python/packages/azurefunctions/tests/integration_tests/test_10_workflow_no_shared_state.py b/python/packages/azurefunctions/tests/integration_tests/test_10_workflow_no_shared_state.py new file mode 100644 index 0000000000..0ceb4c72eb --- /dev/null +++ b/python/packages/azurefunctions/tests/integration_tests/test_10_workflow_no_shared_state.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft. All rights reserved. +""" +Integration Tests for Workflow No Shared State Sample + +Tests the workflow sample that runs without shared state, +demonstrating conditional routing with spam detection and email response. + +The function app is automatically started by the test fixture. 
+ +Prerequisites: +- Azure OpenAI credentials configured (see packages/azurefunctions/tests/integration_tests/.env.example) +- Azurite running for durable orchestrations (or Azure Storage account configured) + +Usage: + # Start Azurite (if not already running) + azurite & + + # Run tests + uv run pytest packages/azurefunctions/tests/integration_tests/test_10_workflow_no_shared_state.py -v +""" + +import pytest +from conftest import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled + +# Module-level markers - applied to all tests in this file +pytestmark = [ + pytest.mark.sample("10_workflow_no_shared_state"), + pytest.mark.usefixtures("function_app_for_test"), + skip_if_azure_functions_integration_tests_disabled, +] + + +@pytest.mark.orchestration +class TestWorkflowNoSharedState: + """Tests for 10_workflow_no_shared_state sample.""" + + def test_workflow_with_spam_email(self, base_url: str) -> None: + """Test workflow with spam email - should detect and handle as spam.""" + payload = { + "email_id": "email-test-001", + "email_content": ( + "URGENT! You've won $1,000,000! Click here immediately to claim your prize! " + "Limited time offer - act now!" + ), + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + assert "instanceId" in data + assert "statusQueryGetUri" in data + + # Wait for completion + status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) + assert status["runtimeStatus"] == "Completed" + assert "output" in status + + def test_workflow_with_legitimate_email(self, base_url: str) -> None: + """Test workflow with legitimate email - should draft a response.""" + payload = { + "email_id": "email-test-002", + "email_content": ( + "Hi team, just a reminder about our sprint planning meeting tomorrow at 10 AM. " + "Please review the agenda in Jira." 
+ ), + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + assert "instanceId" in data + assert "statusQueryGetUri" in data + + # Wait for completion + status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) + assert status["runtimeStatus"] == "Completed" + assert "output" in status + + def test_workflow_status_endpoint(self, base_url: str) -> None: + """Test that the status endpoint works correctly.""" + payload = { + "email_id": "email-test-003", + "email_content": "Quick question: When is the next team meeting scheduled?", + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + instance_id = data["instanceId"] + + # Check status using the workflow status endpoint + status_response = SampleTestHelper.get(f"{base_url}/api/workflow/status/{instance_id}") + assert status_response.status_code == 200 + status = status_response.json() + assert "instanceId" in status + assert status["instanceId"] == instance_id + assert "runtimeStatus" in status + + # Wait for completion to clean up + SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"]) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py b/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py new file mode 100644 index 0000000000..7430cca96f --- /dev/null +++ b/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft. All rights reserved. +""" +Integration Tests for Parallel Workflow Sample + +Tests the parallel workflow execution sample demonstrating: +- Two executors running concurrently (fan-out to activities) +- Two agents running concurrently (fan-out to entities) +- Mixed agent + executor running concurrently + +The function app is automatically started by the test fixture. + +Prerequisites: +- Azure OpenAI credentials configured (see packages/azurefunctions/tests/integration_tests/.env.example) +- Azurite running for durable orchestrations (or Azure Storage account configured) + +Usage: + # Start Azurite (if not already running) + azurite & + + # Run tests + uv run pytest packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py -v +""" + +import pytest +from conftest import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled + +# Module-level markers - applied to all tests in this file +pytestmark = [ + pytest.mark.sample("11_workflow_parallel"), + pytest.mark.usefixtures("function_app_for_test"), + skip_if_azure_functions_integration_tests_disabled, +] + + +@pytest.mark.orchestration +class TestWorkflowParallel: + """Tests for 11_workflow_parallel sample.""" + + def test_parallel_workflow_document_analysis(self, base_url: str) -> None: + """Test parallel workflow with a standard document.""" + payload = { + "document_id": "doc-test-001", + "content": ( + "The quarterly earnings report shows strong growth in our cloud services division. " + "Revenue increased by 25% compared to last year, driven by enterprise adoption. " + "Customer satisfaction remains high at 92%. However, we face challenges in the " + "mobile segment where competition is intense. 
Overall, the outlook is positive " + "with expected continued growth in the coming quarters." + ), + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + assert "instanceId" in data + assert "statusQueryGetUri" in data + + # Wait for completion - parallel workflows may take longer + status = SampleTestHelper.wait_for_orchestration_with_output( + data["statusQueryGetUri"], + max_wait=300, # 5 minutes for parallel execution + ) + assert status["runtimeStatus"] == "Completed" + assert "output" in status + + def test_parallel_workflow_short_document(self, base_url: str) -> None: + """Test parallel workflow with a short document.""" + payload = { + "document_id": "doc-test-002", + "content": "Quick update: Project completed successfully. Team performance exceeded expectations.", + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + assert "instanceId" in data + assert "statusQueryGetUri" in data + + # Wait for completion + status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"], max_wait=300) + assert status["runtimeStatus"] == "Completed" + assert "output" in status + + def test_parallel_workflow_technical_document(self, base_url: str) -> None: + """Test parallel workflow with a technical document.""" + payload = { + "document_id": "doc-test-003", + "content": ( + "The new microservices architecture has been deployed to production. " + "Key improvements include: reduced latency by 40%, improved scalability " + "to handle 10x traffic spikes, and enhanced monitoring with distributed tracing. " + "The Kubernetes cluster is now running on version 1.28 with auto-scaling enabled. " + "Next steps include implementing service mesh and improving CI/CD pipelines." 
+ ), + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + assert "instanceId" in data + + # Wait for completion + status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"], max_wait=300) + assert status["runtimeStatus"] == "Completed" + + def test_workflow_status_endpoint(self, base_url: str) -> None: + """Test that the workflow status endpoint works correctly.""" + payload = { + "document_id": "doc-test-004", + "content": "Brief status update for testing purposes.", + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + instance_id = data["instanceId"] + + # Check status + status_response = SampleTestHelper.get(f"{base_url}/api/workflow/status/{instance_id}") + assert status_response.status_code == 200 + status = status_response.json() + assert "instanceId" in status + assert status["instanceId"] == instance_id + + # Wait for completion + SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"], max_wait=300) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/python/packages/azurefunctions/tests/integration_tests/test_12_workflow_hitl.py b/python/packages/azurefunctions/tests/integration_tests/test_12_workflow_hitl.py new file mode 100644 index 0000000000..713e28d63e --- /dev/null +++ b/python/packages/azurefunctions/tests/integration_tests/test_12_workflow_hitl.py @@ -0,0 +1,215 @@ +# Copyright (c) Microsoft. All rights reserved. +""" +Integration Tests for Workflow Human-in-the-Loop (HITL) Sample + +Tests the workflow HITL sample demonstrating content moderation with human approval +using the MAF request_info / @response_handler pattern. + +The function app is automatically started by the test fixture. 
+ +Prerequisites: +- Azure OpenAI credentials configured (see packages/azurefunctions/tests/integration_tests/.env.example) +- Azurite running for durable orchestrations (or Azure Storage account configured) + +Usage: + # Start Azurite (if not already running) + azurite & + + # Run tests + uv run pytest packages/azurefunctions/tests/integration_tests/test_12_workflow_hitl.py -v +""" + +import time + +import pytest +from conftest import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled + +# Module-level markers - applied to all tests in this file +pytestmark = [ + pytest.mark.sample("12_workflow_hitl"), + pytest.mark.usefixtures("function_app_for_test"), + skip_if_azure_functions_integration_tests_disabled, +] + + +@pytest.mark.orchestration +class TestWorkflowHITL: + """Tests for 12_workflow_hitl sample.""" + + @pytest.fixture(autouse=True) + def _set_base_url(self, base_url: str) -> None: + """Store the base URL for tests.""" + self.base_url = base_url + + def _wait_for_hitl_request(self, instance_id: str, timeout: int = 40) -> dict: + """Polls for a pending HITL request.""" + start_time = time.time() + while time.time() - start_time < timeout: + status_response = SampleTestHelper.get(f"{self.base_url}/api/workflow/status/{instance_id}") + if status_response.status_code == 200: + status = status_response.json() + pending_requests = status.get("pendingHumanInputRequests", []) + if pending_requests: + return status + time.sleep(2) + raise AssertionError(f"Timed out waiting for HITL request for instance {instance_id}") + + def test_hitl_workflow_approval(self) -> None: + """Test HITL workflow with human approval.""" + payload = { + "content_id": "article-test-001", + "title": "Introduction to AI in Healthcare", + "body": ( + "Artificial intelligence is revolutionizing healthcare by enabling faster diagnosis, " + "personalized treatment plans, and improved patient outcomes. Machine learning algorithms " + "can analyze medical images with remarkable accuracy." + ), + "author": "Dr. Jane Smith", + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{self.base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + assert "instanceId" in data + assert "statusQueryGetUri" in data + instance_id = data["instanceId"] + + # Wait for the workflow to reach the HITL pause point + status = self._wait_for_hitl_request(instance_id) + + # Confirm status is valid + assert status["runtimeStatus"] in ["Running", "Pending"] + + # Get the request ID from pending requests + pending_requests = status.get("pendingHumanInputRequests", []) + assert len(pending_requests) > 0, "Expected pending HITL request" + request_id = pending_requests[0]["requestId"] + + # Send approval + approval_response = SampleTestHelper.post_json( + f"{self.base_url}/api/workflow/respond/{instance_id}/{request_id}", + {"approved": True, "reviewer_notes": "Content is appropriate and well-written."}, + ) + assert approval_response.status_code == 200 + + # Wait for orchestration to complete + final_status = SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"]) + assert final_status["runtimeStatus"] == "Completed" + assert "output" in final_status + + def test_hitl_workflow_rejection(self) -> None: + """Test HITL workflow with human rejection.""" + payload = { + "content_id": "article-test-002", + "title": "Get Rich Quick Scheme", + "body": ( + "Click here NOW to make $10,000 overnight! This SECRET method is GUARANTEED to work! 
" + "Limited time offer - act NOW before it's too late!" + ), + "author": "Definitely Not Spam", + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{self.base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + instance_id = data["instanceId"] + + # Wait for the workflow to reach the HITL pause point + status = self._wait_for_hitl_request(instance_id) + + # Get the request ID from pending requests + pending_requests = status.get("pendingHumanInputRequests", []) + assert len(pending_requests) > 0, "Expected pending HITL request" + request_id = pending_requests[0]["requestId"] + + # Send rejection + rejection_response = SampleTestHelper.post_json( + f"{self.base_url}/api/workflow/respond/{instance_id}/{request_id}", + {"approved": False, "reviewer_notes": "Content appears to be spam/scam material."}, + ) + assert rejection_response.status_code == 200 + + # Wait for orchestration to complete + final_status = SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"]) + assert final_status["runtimeStatus"] == "Completed" + assert "output" in final_status + # The output should indicate rejection + output = final_status["output"] + assert "rejected" in str(output).lower() + + def test_hitl_workflow_status_endpoint(self) -> None: + """Test that the workflow status endpoint shows pending HITL requests.""" + payload = { + "content_id": "article-test-003", + "title": "Test Article", + "body": "This is a test article for checking status endpoint functionality.", + "author": "Test Author", + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{self.base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + instance_id = data["instanceId"] + + # Wait for HITL pause + status = self._wait_for_hitl_request(instance_id) + + # Check status + assert "instanceId" in status + assert status["instanceId"] == instance_id + assert "runtimeStatus" in status + assert "pendingHumanInputRequests" in status + + # Clean up: approve to complete + pending_requests = status.get("pendingHumanInputRequests", []) + if pending_requests: + request_id = pending_requests[0]["requestId"] + SampleTestHelper.post_json( + f"{self.base_url}/api/workflow/respond/{instance_id}/{request_id}", + {"approved": True, "reviewer_notes": ""}, + ) + + # Wait for completion + SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"]) + + def test_hitl_workflow_with_neutral_content(self) -> None: + """Test HITL workflow with neutral content that should get medium risk.""" + payload = { + "content_id": "article-test-004", + "title": "Product Review", + "body": ( + "This product works as advertised. The build quality is average and the price " + "is reasonable. I would recommend it for basic use cases but not for professional work." 
+ ), + "author": "Regular User", + } + + # Start orchestration + response = SampleTestHelper.post_json(f"{self.base_url}/api/workflow/run", payload) + assert response.status_code == 202 + data = response.json() + instance_id = data["instanceId"] + + # Wait for HITL pause + status = self._wait_for_hitl_request(instance_id) + + pending_requests = status.get("pendingHumanInputRequests", []) + assert len(pending_requests) > 0 + request_id = pending_requests[0]["requestId"] + + # Approve + SampleTestHelper.post_json( + f"{self.base_url}/api/workflow/respond/{instance_id}/{request_id}", + {"approved": True, "reviewer_notes": "Approved after review."}, + ) + + # Wait for completion + final_status = SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"]) + assert final_status["runtimeStatus"] == "Completed" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 68bcb3b9b0..2bfd833279 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1357,8 +1357,12 @@ def test_init_with_workflow_extracts_agents(self) -> None: def test_init_with_workflow_calls_setup_methods(self) -> None: """Test that workflow setup methods are called.""" + mock_executor = Mock() + mock_executor.id = "TestExecutor" + mock_workflow = Mock() - mock_workflow.executors = {} + # Include a non-AgentExecutor so _setup_executor_activity is called + mock_workflow.executors = {"TestExecutor": mock_executor} with ( patch.object(AgentFunctionApp, "_setup_executor_activity") as setup_exec, diff --git a/python/packages/azurefunctions/tests/test_utils.py b/python/packages/azurefunctions/tests/test_utils.py index c95b663161..67b4354a9d 100644 --- a/python/packages/azurefunctions/tests/test_utils.py +++ b/python/packages/azurefunctions/tests/test_utils.py @@ -9,7 +9,7 @@ from agent_framework import ( AgentExecutorRequest, AgentExecutorResponse, - AgentRunResponse, + AgentResponse, ChatMessage, Message, WorkflowOutputEvent, @@ -84,7 +84,7 @@ async def test_has_messages_returns_correct_status(self, context: CapturingRunne @pytest.mark.asyncio async def test_add_event_queues_event(self, context: CapturingRunnerContext) -> None: """Test that add_event queues events correctly.""" - event = WorkflowOutputEvent(data="output", source_executor_id="exec_1") + event = WorkflowOutputEvent(data="output", executor_id="exec_1") await context.add_event(event) @@ -96,7 +96,7 @@ async def test_add_event_queues_event(self, context: CapturingRunnerContext) -> @pytest.mark.asyncio async def test_drain_events_clears_queue(self, context: CapturingRunnerContext) -> None: """Test that drain_events clears the event queue.""" - await context.add_event(WorkflowOutputEvent(data="test", source_executor_id="e")) + await context.add_event(WorkflowOutputEvent(data="test", executor_id="e")) await context.drain_events() # First drain events = await context.drain_events() # Second drain @@ -108,14 +108,14 @@ async def test_has_events_returns_correct_status(self, context: CapturingRunnerC """Test has_events returns correct boolean.""" assert await context.has_events() is False - await context.add_event(WorkflowOutputEvent(data="test", source_executor_id="e")) + await context.add_event(WorkflowOutputEvent(data="test", executor_id="e")) assert await context.has_events() is True @pytest.mark.asyncio async def test_next_event_waits_for_event(self, context: CapturingRunnerContext) -> 
None: """Test that next_event returns queued events.""" - event = WorkflowOutputEvent(data="waited", source_executor_id="e") + event = WorkflowOutputEvent(data="waited", executor_id="e") await context.add_event(event) result = await context.next_event() @@ -147,7 +147,7 @@ def test_set_workflow_id(self, context: CapturingRunnerContext) -> None: async def test_reset_for_new_run_clears_state(self, context: CapturingRunnerContext) -> None: """Test that reset_for_new_run clears all state.""" await context.send_message(Message(data="test", target_id="t", source_id="s")) - await context.add_event(WorkflowOutputEvent(data="event", source_executor_id="e")) + await context.add_event(WorkflowOutputEvent(data="event", executor_id="e")) context.set_streaming(True) context.reset_for_new_run() @@ -294,8 +294,8 @@ def test_deserialize_agent_executor_response(self) -> None: """Test deserializing AgentExecutorResponse.""" data = { "executor_id": "test_exec", - "agent_run_response": { - "type": "agent_run_response", + "agent_response": { + "type": "agent_response", "messages": [ {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "Hi there"}]} ], @@ -352,12 +352,12 @@ def test_reconstruct_defaults_should_respond_to_true(self) -> None: class TestReconstructAgentExecutorResponse: """Test suite for reconstruct_agent_executor_response function.""" - def test_reconstruct_with_agent_run_response(self) -> None: - """Test reconstructing response with agent_run_response.""" + def test_reconstruct_with_agent_response(self) -> None: + """Test reconstructing response with agent_response.""" data = { "executor_id": "my_executor", - "agent_run_response": { - "type": "agent_run_response", + "agent_response": { + "type": "agent_response", "messages": [ {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "Response"}]} ], @@ -369,13 +369,13 @@ def test_reconstruct_with_agent_run_response(self) -> None: assert isinstance(result, AgentExecutorResponse) assert result.executor_id == "my_executor" - assert isinstance(result.agent_run_response, AgentRunResponse) + assert isinstance(result.agent_response, AgentResponse) def test_reconstruct_with_full_conversation(self) -> None: """Test reconstructing response with full_conversation.""" data = { "executor_id": "exec", - "agent_run_response": {"type": "agent_run_response", "messages": []}, + "agent_response": {"type": "agent_response", "messages": []}, "full_conversation": [ {"type": "chat_message", "role": "user", "contents": [{"type": "text", "text": "Q"}]}, {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "A"}]}, @@ -400,7 +400,7 @@ def test_reconstruct_agent_executor_response(self) -> None: """Test reconstructing AgentExecutorResponse.""" data = { "executor_id": "exec", - "agent_run_response": {"type": "agent_run_response", "messages": []}, + "agent_response": {"type": "agent_response", "messages": []}, } result = reconstruct_message_for_handler(data, [AgentExecutorResponse]) diff --git a/python/packages/azurefunctions/tests/test_workflow.py b/python/packages/azurefunctions/tests/test_workflow.py index f401fc96c2..f35e0f13c0 100644 --- a/python/packages/azurefunctions/tests/test_workflow.py +++ b/python/packages/azurefunctions/tests/test_workflow.py @@ -9,7 +9,7 @@ from agent_framework import ( AgentExecutorRequest, AgentExecutorResponse, - AgentRunResponse, + AgentResponse, ChatMessage, ) from agent_framework._workflows._edge import ( @@ -142,7 +142,7 @@ def 
test_builds_response_with_text(self) -> None: ) assert response.executor_id == "my_executor" - assert response.agent_run_response.text == "Hello, world!" + assert response.agent_response.text == "Hello, world!" assert len(response.full_conversation) == 2 # User + Assistant def test_builds_response_with_structured_response(self) -> None: @@ -157,7 +157,7 @@ def test_builds_response_with_structured_response(self) -> None: ) # Structured response overrides text - assert response.agent_run_response.text == json.dumps(structured) + assert response.agent_response.text == json.dumps(structured) def test_conversation_includes_previous_string_message(self) -> None: """Test that string previous_message is included in conversation.""" @@ -178,7 +178,7 @@ def test_conversation_extends_previous_agent_executor_response(self) -> None: # Create a previous response with conversation history previous = AgentExecutorResponse( executor_id="prev", - agent_run_response=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Previous")]), + agent_response=AgentResponse(messages=[ChatMessage(role="assistant", text="Previous")]), full_conversation=[ ChatMessage(role="user", text="First"), ChatMessage(role="assistant", text="Previous"), @@ -212,7 +212,7 @@ def test_extract_from_agent_executor_response_with_text(self) -> None: """Test extracting from AgentExecutorResponse with text.""" response = AgentExecutorResponse( executor_id="exec", - agent_run_response=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response text")]), + agent_response=AgentResponse(messages=[ChatMessage(role="assistant", text="Response text")]), ) result = _extract_message_content(response) @@ -223,7 +223,7 @@ def test_extract_from_agent_executor_response_with_messages(self) -> None: """Test extracting from AgentExecutorResponse with messages.""" response = AgentExecutorResponse( executor_id="exec", - agent_run_response=AgentRunResponse( + agent_response=AgentResponse( messages=[ ChatMessage(role="user", text="First"), ChatMessage(role="assistant", text="Last message"), @@ -233,7 +233,7 @@ def test_extract_from_agent_executor_response_with_messages(self) -> None: result = _extract_message_content(response) - # AgentRunResponse.text concatenates all message texts + # AgentResponse.text concatenates all message texts assert result == "FirstLast message" def test_extract_from_agent_executor_request(self) -> None: @@ -290,17 +290,17 @@ def test_extract_from_messages_with_direct_text(self) -> None: assert result == "Direct text" - def test_extract_from_agent_run_response(self) -> None: - """Test extracting from agent_run_response dict.""" - msg_dict = {"agent_run_response": {"text": "Response text"}} + def test_extract_from_agent_response(self) -> None: + """Test extracting from agent_response dict.""" + msg_dict = {"agent_response": {"text": "Response text"}} result = _extract_message_content_from_dict(msg_dict) assert result == "Response text" - def test_extract_from_agent_run_response_with_messages(self) -> None: - """Test extracting from agent_run_response with messages.""" - msg_dict = {"agent_run_response": {"messages": [{"contents": [{"type": "text", "text": "Nested content"}]}]}} + def test_extract_from_agent_response_with_messages(self) -> None: + """Test extracting from agent_response with messages.""" + msg_dict = {"agent_response": {"messages": [{"contents": [{"type": "text", "text": "Nested content"}]}]}} result = _extract_message_content_from_dict(msg_dict) diff --git 
a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py b/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py index bf38dfc72b..520647b0c3 100644 --- a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py @@ -36,7 +36,7 @@ ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing_extensions import Never from agent_framework_azurefunctions import AgentFunctionApp @@ -125,7 +125,12 @@ async def to_detection_result(response: AgentExecutorResponse, ctx: WorkflowCont 2) Retrieve the current email_id from shared state. 3) Send a typed DetectionResult for conditional routing. """ - parsed = DetectionResultAgent.model_validate_json(response.agent_run_response.text) + try: + parsed = DetectionResultAgent.model_validate_json(response.agent_response.text) + except ValidationError: + # Fallback for empty or invalid response (e.g. due to content filtering) + parsed = DetectionResultAgent(is_spam=True, reason="Agent execution failed or yielded invalid JSON.") + email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY) await ctx.send_message(DetectionResult(is_spam=parsed.is_spam, reason=parsed.reason, email_id=email_id)) @@ -150,7 +155,7 @@ async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowCon @executor(id="finalize_and_send") async def finalize_and_send(response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None: """Validate the drafted reply and yield the final output.""" - parsed = EmailResponse.model_validate_json(response.agent_run_response.text) + parsed = EmailResponse.model_validate_json(response.agent_response.text) await ctx.yield_output(f"Email sent: {parsed.response}") @@ -197,21 +202,21 @@ def _create_workflow() -> Workflow: client_kwargs = _build_client_kwargs() chat_client = AzureOpenAIChatClient(**client_kwargs) - spam_detection_agent = chat_client.create_agent( + spam_detection_agent = chat_client.as_agent( instructions=( "You are a spam detection assistant that identifies spam emails. " "Always return JSON with fields is_spam (bool) and reason (string)." ), - response_format=DetectionResultAgent, + default_options={"response_format": DetectionResultAgent}, name="spam_detection_agent", ) - email_assistant_agent = chat_client.create_agent( + email_assistant_agent = chat_client.as_agent( instructions=( "You are an email assistant that helps users draft responses to emails with professionalism. " "Return JSON with a single field 'response' containing the drafted reply." 
), - response_format=EmailResponse, + default_options={"response_format": EmailResponse}, name="email_assistant_agent", ) diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/host.json b/python/samples/getting_started/azure_functions/09_workflow_shared_state/host.json index 292562af8e..9e7fd873dd 100644 --- a/python/samples/getting_started/azure_functions/09_workflow_shared_state/host.json +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/host.json @@ -6,11 +6,7 @@ }, "extensions": { "durableTask": { - "hubName": "%TASKHUB_NAME%", - "storageProvider": { - "type": "AzureManaged", - "connectionStringName": "DURABLE_TASK_SCHEDULER_CONNECTION_STRING" - } + "hubName": "%TASKHUB_NAME%" } } } diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py index b55fef58b8..768d5b138d 100644 --- a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py @@ -33,7 +33,7 @@ ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from agent_framework_azurefunctions import AgentFunctionApp from typing_extensions import Never @@ -120,8 +120,12 @@ async def handle_spam_result( ctx: WorkflowContext[Never, str], ) -> None: """Mark email as spam and log the reason.""" - text = agent_response.agent_run_response.text - spam_result = SpamDetectionResult.model_validate_json(text) + text = agent_response.agent_response.text + try: + spam_result = SpamDetectionResult.model_validate_json(text) + except ValidationError: + spam_result = SpamDetectionResult(is_spam=True, reason="Invalid JSON from agent") + message = f"Email marked as spam: {spam_result.reason}" await ctx.yield_output(message) @@ -136,8 +140,12 @@ async def handle_email_response( ctx: WorkflowContext[Never, str], ) -> None: """Send the drafted email response.""" - text = agent_response.agent_run_response.text - email_response = EmailResponse.model_validate_json(text) + text = agent_response.agent_response.text + try: + email_response = EmailResponse.model_validate_json(text) + except ValidationError: + email_response = EmailResponse(response="Error generating response.") + message = f"Email sent: {email_response.response}" await ctx.yield_output(message) @@ -148,7 +156,7 @@ def is_spam_detected(message: Any) -> bool: if not isinstance(message, AgentExecutorResponse): return False try: - result = SpamDetectionResult.model_validate_json(message.agent_run_response.text) + result = SpamDetectionResult.model_validate_json(message.agent_response.text) return result.is_spam except Exception: return False @@ -159,16 +167,16 @@ def _create_workflow() -> Workflow: client_kwargs = _build_client_kwargs() chat_client = AzureOpenAIChatClient(**client_kwargs) - spam_agent = chat_client.create_agent( + spam_agent = chat_client.as_agent( name=SPAM_AGENT_NAME, instructions=SPAM_DETECTION_INSTRUCTIONS, - response_format=SpamDetectionResult, + default_options={"response_format": SpamDetectionResult}, ) - email_agent = chat_client.create_agent( + email_agent = chat_client.as_agent( name=EMAIL_AGENT_NAME, instructions=EMAIL_ASSISTANT_INSTRUCTIONS, - response_format=EmailResponse, + default_options={"response_format": 
EmailResponse}, ) # Executors diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/host.json b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/host.json index 292562af8e..9e7fd873dd 100644 --- a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/host.json +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/host.json @@ -6,11 +6,7 @@ }, "extensions": { "durableTask": { - "hubName": "%TASKHUB_NAME%", - "storageProvider": { - "type": "AzureManaged", - "connectionStringName": "DURABLE_TASK_SCHEDULER_CONNECTION_STRING" - } + "hubName": "%TASKHUB_NAME%" } } } diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py b/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py index a51a1b6a04..1c307d4ab4 100644 --- a/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py @@ -250,7 +250,7 @@ async def prepare_for_mixed( for analysis in analyses: executor_id = analysis.executor_id - text = analysis.agent_run_response.text if analysis.agent_run_response else "" + text = analysis.agent_response.text if analysis.agent_response else "" if executor_id == SENTIMENT_AGENT_NAME: sentiment_text = text @@ -303,7 +303,7 @@ async def compile_report( for analysis in analyses: if isinstance(analysis, AgentExecutorResponse): agent_name = analysis.executor_id - text = analysis.agent_run_response.text if analysis.agent_run_response else "No response" + text = analysis.agent_response.text if analysis.agent_response else "No response" elif isinstance(analysis, ProcessorResult): agent_name = f"Processor: {analysis.processor_name}" text = f"Words: {analysis.word_count}, Chars: {analysis.char_count}" @@ -334,7 +334,7 @@ async def collect_mixed_results( for result in results: if isinstance(result, AgentExecutorResponse): output_parts.append(f"[Agent: {result.executor_id}]") - output_parts.append(result.agent_run_response.text if result.agent_run_response else "No response") + output_parts.append(result.agent_response.text if result.agent_response else "No response") elif isinstance(result, ProcessorResult): output_parts.append(f"[Processor: {result.processor_name}]") output_parts.append(f" Words: {result.word_count}, Chars: {result.char_count}") @@ -398,35 +398,35 @@ def _create_workflow() -> Workflow: chat_client = AzureOpenAIChatClient(**client_kwargs) # Create agents for parallel analysis - sentiment_agent = chat_client.create_agent( + sentiment_agent = chat_client.as_agent( name=SENTIMENT_AGENT_NAME, instructions=( "You are a sentiment analysis expert. Analyze the sentiment of the given text. " "Return JSON with fields: sentiment (positive/negative/neutral), " "confidence (0.0-1.0), and explanation (brief reasoning)." ), - response_format=SentimentResult, + default_options={"response_format": SentimentResult}, ) - keyword_agent = chat_client.create_agent( + keyword_agent = chat_client.as_agent( name=KEYWORD_AGENT_NAME, instructions=( "You are a keyword extraction expert. Extract important keywords and categories " "from the given text. Return JSON with fields: keywords (list of strings), " "and categories (list of topic categories)." 
), - response_format=KeywordResult, + default_options={"response_format": KeywordResult}, ) # Create summary agent for Pattern 3 (mixed parallel) - summary_agent = chat_client.create_agent( + summary_agent = chat_client.as_agent( name=SUMMARY_AGENT_NAME, instructions=( "You are a summarization expert. Given analysis results (sentiment and keywords), " "provide a concise summary. Return JSON with fields: summary (brief text), " "and key_points (list of main takeaways)." ), - response_format=SummaryResult, + default_options={"response_format": SummaryResult}, ) # Create executor instances diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/host.json b/python/samples/getting_started/azure_functions/11_workflow_parallel/host.json index 292562af8e..9e7fd873dd 100644 --- a/python/samples/getting_started/azure_functions/11_workflow_parallel/host.json +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/host.json @@ -6,11 +6,7 @@ }, "extensions": { "durableTask": { - "hubName": "%TASKHUB_NAME%", - "storageProvider": { - "type": "AzureManaged", - "connectionStringName": "DURABLE_TASK_SCHEDULER_CONNECTION_STRING" - } + "hubName": "%TASKHUB_NAME%" } } } diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py b/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py index bb36832b17..11fb48c7db 100644 --- a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py @@ -43,7 +43,7 @@ ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing_extensions import Never from agent_framework_azurefunctions import AgentFunctionApp @@ -166,7 +166,15 @@ async def handle_analysis( ctx: WorkflowContext[AnalysisWithSubmission], ) -> None: """Parse the AI analysis and forward with submission context.""" - analysis = ContentAnalysisResult.model_validate_json(response.agent_run_response.text) + try: + analysis = ContentAnalysisResult.model_validate_json(response.agent_response.text) + except ValidationError: + analysis = ContentAnalysisResult( + is_appropriate=False, + risk_level="high", + concerns=["Agent execution failed or yielded invalid JSON (possible content filter)."], + recommendation="Manual review required", + ) # Retrieve the original submission from shared state submission: ContentSubmission = await ctx.get_shared_state("current_submission") @@ -376,10 +384,10 @@ def _create_workflow() -> Workflow: chat_client = AzureOpenAIChatClient(**client_kwargs) # Create the content analysis agent - content_analyzer_agent = chat_client.create_agent( + content_analyzer_agent = chat_client.as_agent( name=CONTENT_ANALYZER_AGENT_NAME, instructions=CONTENT_ANALYZER_INSTRUCTIONS, - response_format=ContentAnalysisResult, + default_options={"response_format": ContentAnalysisResult}, ) # Create executors diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/host.json b/python/samples/getting_started/azure_functions/12_workflow_hitl/host.json index 292562af8e..9e7fd873dd 100644 --- a/python/samples/getting_started/azure_functions/12_workflow_hitl/host.json +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/host.json @@ -6,11 +6,7 @@ }, "extensions": { "durableTask": { - "hubName": "%TASKHUB_NAME%", - "storageProvider": { - "type": 
"AzureManaged", - "connectionStringName": "DURABLE_TASK_SCHEDULER_CONNECTION_STRING" - } + "hubName": "%TASKHUB_NAME%" } } } From 8e6c0d2311e1973861f4d1b99cbe2408837d3471 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Tue, 3 Feb 2026 10:41:01 -0600 Subject: [PATCH 03/29] refactor code --- .../agent_framework_azurefunctions/_app.py | 9 +- .../_context.py | 173 +++++++++++ .../{_utils.py => _serialization.py} | 293 +----------------- .../_workflow.py | 120 ++++++- .../azurefunctions/tests/test_utils.py | 4 +- 5 files changed, 314 insertions(+), 285 deletions(-) create mode 100644 python/packages/azurefunctions/agent_framework_azurefunctions/_context.py rename python/packages/azurefunctions/agent_framework_azurefunctions/{_utils.py => _serialization.py} (56%) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 7b3863eeda..88bed775fc 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -40,17 +40,16 @@ RunRequest, ) +from ._context import CapturingRunnerContext from ._entities import create_agent_entity from ._errors import IncomingRequestError from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor -from ._utils import ( - CapturingRunnerContext, - _execute_hitl_response_handler, +from ._serialization import ( deserialize_value, reconstruct_message_for_handler, serialize_message, ) -from ._workflow import run_workflow_orchestrator +from ._workflow import execute_hitl_response_handler, run_workflow_orchestrator logger = logging.getLogger("agent_framework.azurefunctions") @@ -306,7 +305,7 @@ async def run() -> dict[str, Any]: if is_hitl_response: # Handle HITL response by calling the executor's @response_handler - await _execute_hitl_response_handler( + await execute_hitl_response_handler( executor=executor, hitl_message=message_data, shared_state=shared_state, diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py new file mode 100644 index 0000000000..6ceaa88fb8 --- /dev/null +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py @@ -0,0 +1,173 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Runner context for Azure Functions activity execution. + +This module provides the CapturingRunnerContext class that captures messages +and events produced during executor execution within Azure Functions activities. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from agent_framework import ( + CheckpointStorage, + Message, + RequestInfoEvent, + RunnerContext, + SharedState, + WorkflowCheckpoint, + WorkflowEvent, +) + + +class CapturingRunnerContext(RunnerContext): + """A RunnerContext implementation that captures messages and events for Azure Functions activities. + + This context is designed for executing standard Executors within Azure Functions activities. + It captures all messages and events produced during execution without requiring durable + entity storage, allowing the results to be returned to the orchestrator. 
+ + Unlike the full InProcRunnerContext, this implementation: + - Does NOT support checkpointing (always returns False for has_checkpointing) + - Does NOT support streaming (always returns False for is_streaming) + - Captures messages and events in memory for later retrieval + + The orchestrator manages state coordination; this context just captures execution output. + """ + + def __init__(self) -> None: + """Initialize the capturing runner context.""" + self._messages: dict[str, list[Message]] = {} + self._event_queue: asyncio.Queue[WorkflowEvent] = asyncio.Queue() + self._pending_request_info_events: dict[str, RequestInfoEvent] = {} + self._workflow_id: str | None = None + self._streaming: bool = False + + # region Messaging + + async def send_message(self, message: Message) -> None: + """Capture a message sent by an executor.""" + self._messages.setdefault(message.source_id, []) + self._messages[message.source_id].append(message) + + async def drain_messages(self) -> dict[str, list[Message]]: + """Drain and return all captured messages.""" + from copy import copy + + messages = copy(self._messages) + self._messages.clear() + return messages + + async def has_messages(self) -> bool: + """Check if there are any captured messages.""" + return bool(self._messages) + + # endregion Messaging + + # region Events + + async def add_event(self, event: WorkflowEvent) -> None: + """Capture an event produced during execution.""" + await self._event_queue.put(event) + + async def drain_events(self) -> list[WorkflowEvent]: + """Drain all currently queued events without blocking.""" + events: list[WorkflowEvent] = [] + while True: + try: + events.append(self._event_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return events + + async def has_events(self) -> bool: + """Check if there are any queued events.""" + return not self._event_queue.empty() + + async def next_event(self) -> WorkflowEvent: + """Wait for and return the next event.""" + return await self._event_queue.get() + + # endregion Events + + # region Checkpointing (not supported in activity context) + + def has_checkpointing(self) -> bool: + """Checkpointing is not supported in activity context.""" + return False + + def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None: + """No-op: checkpointing not supported in activity context.""" + pass + + def clear_runtime_checkpoint_storage(self) -> None: + """No-op: checkpointing not supported in activity context.""" + pass + + async def create_checkpoint( + self, + shared_state: SharedState, + iteration_count: int, + metadata: dict[str, Any] | None = None, + ) -> str: + """Checkpointing not supported in activity context.""" + raise NotImplementedError("Checkpointing is not supported in Azure Functions activity context") + + async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: + """Checkpointing not supported in activity context.""" + raise NotImplementedError("Checkpointing is not supported in Azure Functions activity context") + + async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: + """Checkpointing not supported in activity context.""" + raise NotImplementedError("Checkpointing is not supported in Azure Functions activity context") + + # endregion Checkpointing + + # region Workflow Configuration + + def set_workflow_id(self, workflow_id: str) -> None: + """Set the workflow ID.""" + self._workflow_id = workflow_id + + def reset_for_new_run(self) -> None: + """Reset the context for a new run.""" + 
self._messages.clear() + self._event_queue = asyncio.Queue() + self._pending_request_info_events.clear() + self._streaming = False + + def set_streaming(self, streaming: bool) -> None: + """Set streaming mode (not used in activity context).""" + self._streaming = streaming + + def is_streaming(self) -> bool: + """Check if streaming mode is enabled (always False in activity context).""" + return self._streaming + + # endregion Workflow Configuration + + # region Request Info Events + + async def add_request_info_event(self, event: RequestInfoEvent) -> None: + """Add a RequestInfoEvent and track it for correlation.""" + self._pending_request_info_events[event.request_id] = event + await self.add_event(event) + + async def send_request_info_response(self, request_id: str, response: Any) -> None: + """Send a response correlated to a pending request. + + Note: This is not supported in activity context since human-in-the-loop + scenarios require orchestrator-level coordination. + """ + raise NotImplementedError( + "send_request_info_response is not supported in Azure Functions activity context. " + "Human-in-the-loop scenarios should be handled at the orchestrator level." + ) + + async def get_pending_request_info_events(self) -> dict[str, RequestInfoEvent]: + """Get the mapping of request IDs to their corresponding RequestInfoEvent.""" + return dict(self._pending_request_info_events) + + # endregion Request Info Events diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py similarity index 56% rename from python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py rename to python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index 3cb10a1ae0..11034558b2 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_utils.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -1,14 +1,13 @@ # Copyright (c) Microsoft. All rights reserved. -"""Utility functions for workflow execution. +"""Serialization and deserialization utilities for workflow execution. -This module provides helper functions for serialization, deserialization, and -context management used by the workflow orchestrator and executors. +This module provides helper functions for serializing and deserializing messages, +dataclasses, and Pydantic models for cross-activity communication in Azure Functions. """ from __future__ import annotations -import asyncio import logging import types from dataclasses import asdict, fields, is_dataclass @@ -19,168 +18,15 @@ AgentExecutorResponse, AgentResponse, ChatMessage, - CheckpointStorage, - Message, - RequestInfoEvent, - RunnerContext, - SharedState, - WorkflowCheckpoint, - WorkflowEvent, ) from pydantic import BaseModel logger = logging.getLogger(__name__) -class CapturingRunnerContext(RunnerContext): - """A RunnerContext implementation that captures messages and events for Azure Functions activities. - - This context is designed for executing standard Executors within Azure Functions activities. - It captures all messages and events produced during execution without requiring durable - entity storage, allowing the results to be returned to the orchestrator. 
- - Unlike the full InProcRunnerContext, this implementation: - - Does NOT support checkpointing (always returns False for has_checkpointing) - - Does NOT support streaming (always returns False for is_streaming) - - Captures messages and events in memory for later retrieval - - The orchestrator manages state coordination; this context just captures execution output. - """ - - def __init__(self) -> None: - """Initialize the capturing runner context.""" - self._messages: dict[str, list[Message]] = {} - self._event_queue: asyncio.Queue[WorkflowEvent] = asyncio.Queue() - self._pending_request_info_events: dict[str, RequestInfoEvent] = {} - self._workflow_id: str | None = None - self._streaming: bool = False - - # region Messaging - - async def send_message(self, message: Message) -> None: - """Capture a message sent by an executor.""" - self._messages.setdefault(message.source_id, []) - self._messages[message.source_id].append(message) - - async def drain_messages(self) -> dict[str, list[Message]]: - """Drain and return all captured messages.""" - from copy import copy - - messages = copy(self._messages) - self._messages.clear() - return messages - - async def has_messages(self) -> bool: - """Check if there are any captured messages.""" - return bool(self._messages) - - # endregion Messaging - - # region Events - - async def add_event(self, event: WorkflowEvent) -> None: - """Capture an event produced during execution.""" - await self._event_queue.put(event) - - async def drain_events(self) -> list[WorkflowEvent]: - """Drain all currently queued events without blocking.""" - events: list[WorkflowEvent] = [] - while True: - try: - events.append(self._event_queue.get_nowait()) - except asyncio.QueueEmpty: - break - return events - - async def has_events(self) -> bool: - """Check if there are any queued events.""" - return not self._event_queue.empty() - - async def next_event(self) -> WorkflowEvent: - """Wait for and return the next event.""" - return await self._event_queue.get() - - # endregion Events - - # region Checkpointing (not supported in activity context) - - def has_checkpointing(self) -> bool: - """Checkpointing is not supported in activity context.""" - return False - - def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None: - """No-op: checkpointing not supported in activity context.""" - pass - - def clear_runtime_checkpoint_storage(self) -> None: - """No-op: checkpointing not supported in activity context.""" - pass - - async def create_checkpoint( - self, - shared_state: SharedState, - iteration_count: int, - metadata: dict[str, Any] | None = None, - ) -> str: - """Checkpointing not supported in activity context.""" - raise NotImplementedError("Checkpointing is not supported in Azure Functions activity context") - - async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: - """Checkpointing not supported in activity context.""" - raise NotImplementedError("Checkpointing is not supported in Azure Functions activity context") - - async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: - """Checkpointing not supported in activity context.""" - raise NotImplementedError("Checkpointing is not supported in Azure Functions activity context") - - # endregion Checkpointing - - # region Workflow Configuration - - def set_workflow_id(self, workflow_id: str) -> None: - """Set the workflow ID.""" - self._workflow_id = workflow_id - - def reset_for_new_run(self) -> None: - """Reset the context for a new run.""" - 
self._messages.clear() - self._event_queue = asyncio.Queue() - self._pending_request_info_events.clear() - self._streaming = False - - def set_streaming(self, streaming: bool) -> None: - """Set streaming mode (not used in activity context).""" - self._streaming = streaming - - def is_streaming(self) -> bool: - """Check if streaming mode is enabled (always False in activity context).""" - return self._streaming - - # endregion Workflow Configuration - - # region Request Info Events - - async def add_request_info_event(self, event: RequestInfoEvent) -> None: - """Add a RequestInfoEvent and track it for correlation.""" - self._pending_request_info_events[event.request_id] = event - await self.add_event(event) - - async def send_request_info_response(self, request_id: str, response: Any) -> None: - """Send a response correlated to a pending request. - - Note: This is not supported in activity context since human-in-the-loop - scenarios require orchestrator-level coordination. - """ - raise NotImplementedError( - "send_request_info_response is not supported in Azure Functions activity context. " - "Human-in-the-loop scenarios should be handled at the orchestrator level." - ) - - async def get_pending_request_info_events(self) -> dict[str, RequestInfoEvent]: - """Get the mapping of request IDs to their corresponding RequestInfoEvent.""" - return dict(self._pending_request_info_events) - - # endregion Request Info Events +# ============================================================================ +# Serialization +# ============================================================================ def _serialize_value(value: Any) -> Any: @@ -231,6 +77,11 @@ def serialize_message(message: Any) -> Any: return _serialize_value(message) +# ============================================================================ +# Deserialization +# ============================================================================ + + def deserialize_value(data: Any, type_registry: dict[str, type] | None = None) -> Any: """Attempt to deserialize a value using embedded type metadata. @@ -407,6 +258,11 @@ def _reconstruct_typed_value(value: Any, target_type: type) -> Any: return value +# ============================================================================ +# MAF Type Reconstruction +# ============================================================================ + + def reconstruct_agent_executor_request(data: dict[str, Any]) -> AgentExecutorRequest: """Helper to reconstruct AgentExecutorRequest from dict.""" # Reconstruct ChatMessage objects in messages @@ -521,120 +377,3 @@ def reconstruct_message_for_handler(data: Any, input_types: list[type[Any]]) -> logger.debug("Could not construct %s from matching fields", msg_type.__name__) return data - - -# ============================================================================ -# HITL Response Handler Execution -# ============================================================================ - - -async def _execute_hitl_response_handler( - executor: Any, - hitl_message: dict[str, Any], - shared_state: SharedState, - runner_context: CapturingRunnerContext, -) -> None: - """Execute a HITL response handler on an executor. - - This function handles the delivery of a HITL response to the executor's - @response_handler method. It: - 1. Deserializes the original request and response - 2. Finds the matching response handler based on types - 3. 
Creates a WorkflowContext and invokes the handler - - Args: - executor: The executor instance that has a @response_handler - hitl_message: The HITL response message containing original_request and response - shared_state: The shared state for the workflow context - runner_context: The runner context for capturing outputs - """ - from agent_framework._workflows._workflow_context import WorkflowContext - - # Extract the response data - original_request_data = hitl_message.get("original_request") - response_data = hitl_message.get("response") - response_type_str = hitl_message.get("response_type") - - # Deserialize the original request - original_request = deserialize_value(original_request_data) - - # Deserialize the response - try to match expected type - response = _deserialize_hitl_response(response_data, response_type_str) - - # Find the matching response handler - handler = executor._find_response_handler(original_request, response) - - if handler is None: - logger.warning( - "No response handler found for HITL response in executor %s. Request type: %s, Response type: %s", - executor.id, - type(original_request).__name__, - type(response).__name__, - ) - return - - # Create a WorkflowContext for the handler - # Use a special source ID to indicate this is a HITL response - ctx = WorkflowContext( - executor=executor, - source_executor_ids=["__hitl_response__"], - runner_context=runner_context, - shared_state=shared_state, - ) - - # Call the response handler - # Note: handler is already a partial with original_request bound - logger.debug( - "Invoking response handler for HITL request in executor %s", - executor.id, - ) - await handler(response, ctx) - - -def _deserialize_hitl_response(response_data: Any, response_type_str: str | None) -> Any: - """Deserialize a HITL response to its expected type. - - Args: - response_data: The raw response data (typically a dict from JSON) - response_type_str: The fully qualified type name (module:classname) - - Returns: - The deserialized response, or the original data if deserialization fails - """ - logger.debug( - "Deserializing HITL response. 
response_type_str=%s, response_data type=%s", - response_type_str, - type(response_data).__name__, - ) - - if response_data is None: - return None - - # If already a primitive, return as-is - if not isinstance(response_data, dict): - logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__) - return response_data - - # Try to deserialize using the type hint - if response_type_str: - try: - module_name, class_name = response_type_str.rsplit(":", 1) - import importlib - - module = importlib.import_module(module_name) - response_type = getattr(module, class_name, None) - - if response_type: - logger.debug("Found response type %s, attempting reconstruction", response_type) - # Use the shared reconstruction logic which handles nested objects - result = _reconstruct_typed_value(response_data, response_type) - logger.debug("Reconstructed response type: %s", type(result).__name__) - return result - logger.warning("Could not find class %s in module %s", class_name, module_name) - - except Exception as e: - logger.warning("Could not deserialize HITL response to %s: %s", response_type_str, e) - - # Fall back to generic deserialization - logger.debug("Falling back to generic deserialization") - return deserialize_value(response_data) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 87cb7d92fd..d778eeeffa 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -45,8 +45,9 @@ from agent_framework_durabletask import AgentSessionId, DurableAgentThread, DurableAIAgent from azure.durable_functions import DurableOrchestrationContext +from ._context import CapturingRunnerContext from ._orchestration import AzureFunctionsAgentExecutor -from ._utils import deserialize_value, serialize_message +from ._serialization import _reconstruct_typed_value, deserialize_value, serialize_message logger = logging.getLogger(__name__) @@ -886,3 +887,120 @@ def _extract_message_content_from_dict(message: dict[str, Any]) -> str: message_content = last_msg.get("text") or last_msg.get("_text") or "" return message_content + + +# ============================================================================ +# HITL Response Handler Execution +# ============================================================================ + + +async def execute_hitl_response_handler( + executor: Any, + hitl_message: dict[str, Any], + shared_state: Any, + runner_context: CapturingRunnerContext, +) -> None: + """Execute a HITL response handler on an executor. + + This function handles the delivery of a HITL response to the executor's + @response_handler method. It: + 1. Deserializes the original request and response + 2. Finds the matching response handler based on types + 3. 
Creates a WorkflowContext and invokes the handler + + Args: + executor: The executor instance that has a @response_handler + hitl_message: The HITL response message containing original_request and response + shared_state: The shared state for the workflow context + runner_context: The runner context for capturing outputs + """ + from agent_framework._workflows._workflow_context import WorkflowContext + + # Extract the response data + original_request_data = hitl_message.get("original_request") + response_data = hitl_message.get("response") + response_type_str = hitl_message.get("response_type") + + # Deserialize the original request + original_request = deserialize_value(original_request_data) + + # Deserialize the response - try to match expected type + response = _deserialize_hitl_response(response_data, response_type_str) + + # Find the matching response handler + handler = executor._find_response_handler(original_request, response) + + if handler is None: + logger.warning( + "No response handler found for HITL response in executor %s. Request type: %s, Response type: %s", + executor.id, + type(original_request).__name__, + type(response).__name__, + ) + return + + # Create a WorkflowContext for the handler + # Use a special source ID to indicate this is a HITL response + ctx = WorkflowContext( + executor=executor, + source_executor_ids=["__hitl_response__"], + runner_context=runner_context, + shared_state=shared_state, + ) + + # Call the response handler + # Note: handler is already a partial with original_request bound + logger.debug( + "Invoking response handler for HITL request in executor %s", + executor.id, + ) + await handler(response, ctx) + + +def _deserialize_hitl_response(response_data: Any, response_type_str: str | None) -> Any: + """Deserialize a HITL response to its expected type. + + Args: + response_data: The raw response data (typically a dict from JSON) + response_type_str: The fully qualified type name (module:classname) + + Returns: + The deserialized response, or the original data if deserialization fails + """ + logger.debug( + "Deserializing HITL response. 
response_type_str=%s, response_data type=%s", + response_type_str, + type(response_data).__name__, + ) + + if response_data is None: + return None + + # If already a primitive, return as-is + if not isinstance(response_data, dict): + logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__) + return response_data + + # Try to deserialize using the type hint + if response_type_str: + try: + module_name, class_name = response_type_str.rsplit(":", 1) + import importlib + + module = importlib.import_module(module_name) + response_type = getattr(module, class_name, None) + + if response_type: + logger.debug("Found response type %s, attempting reconstruction", response_type) + # Use the shared reconstruction logic which handles nested objects + result = _reconstruct_typed_value(response_data, response_type) + logger.debug("Reconstructed response type: %s", type(result).__name__) + return result + logger.warning("Could not find class %s in module %s", class_name, module_name) + + except Exception as e: + logger.warning("Could not deserialize HITL response to %s: %s", response_type_str, e) + + # Fall back to generic deserialization + logger.debug("Falling back to generic deserialization") + return deserialize_value(response_data) diff --git a/python/packages/azurefunctions/tests/test_utils.py b/python/packages/azurefunctions/tests/test_utils.py index 67b4354a9d..68a0c9dd99 100644 --- a/python/packages/azurefunctions/tests/test_utils.py +++ b/python/packages/azurefunctions/tests/test_utils.py @@ -16,8 +16,8 @@ ) from pydantic import BaseModel -from agent_framework_azurefunctions._utils import ( - CapturingRunnerContext, +from agent_framework_azurefunctions._context import CapturingRunnerContext +from agent_framework_azurefunctions._serialization import ( deserialize_value, reconstruct_agent_executor_request, reconstruct_agent_executor_response, From d4337da748cdcaa5ab04d4d91a306dcc1e0bee5b Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin <36454324+ahmedmuhsin@users.noreply.github.com> Date: Wed, 4 Feb 2026 12:46:34 -0600 Subject: [PATCH 04/29] remove white space Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../10_workflow_no_shared_state/function_app.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py index 768d5b138d..4f8d031e3b 100644 --- a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py @@ -206,11 +206,7 @@ def launch(durable: bool = True) -> AgentFunctionApp | None: if durable: # Initialize app workflow = _create_workflow() - - app = AgentFunctionApp(workflow=workflow) - - return app else: # Launch the spam detection workflow in DevUI From 9bbb9d6bf43afd067fc076d4ca851e81250323b2 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin <36454324+ahmedmuhsin@users.noreply.github.com> Date: Wed, 4 Feb 2026 12:47:01 -0600 Subject: [PATCH 05/29] align help text with actual port used Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../azure_functions/10_workflow_no_shared_state/function_app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py 
b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py index 4f8d031e3b..227d6d9ca7 100644 --- a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py @@ -244,5 +244,5 @@ def launch(durable: bool = True) -> AgentFunctionApp | None: launch(durable=False) else: print("Usage: python function_app.py --maf") - print(" --maf Run in pure MAF mode with DevUI (http://localhost:8096)") + print(" --maf Run in pure MAF mode with DevUI (http://localhost:8094)") print("\nFor Azure Functions mode, use: func start") From 7357cbf5c09610ae997774339ba00d26a7d40c33 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin <36454324+ahmedmuhsin@users.noreply.github.com> Date: Wed, 4 Feb 2026 12:48:40 -0600 Subject: [PATCH 06/29] replace instance id with a place holder Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../07_single_agent_orchestration_hitl/demo.http | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/demo.http b/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/demo.http index 28231a08a8..42f93b8543 100644 --- a/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/demo.http +++ b/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/demo.http @@ -20,7 +20,7 @@ Content-Type: application/json ### Replace INSTANCE_ID_GOES_HERE below with the value returned from the POST call -@instanceId=ccf3950407b5496893df93d1357a5afa +@instanceId= ### Check the status of the orchestration GET http://localhost:7071/api/hitl/status/{{instanceId}} From e6a5035f92530f828deec71420d8c66556172ad4 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin <36454324+ahmedmuhsin@users.noreply.github.com> Date: Wed, 4 Feb 2026 12:49:10 -0600 Subject: [PATCH 07/29] remove unused import Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../azure_functions/10_workflow_no_shared_state/function_app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py index 227d6d9ca7..eacdaad5f3 100644 --- a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py @@ -18,7 +18,7 @@ import logging import os -from typing import Any, Dict +from typing import Any from pathlib import Path from agent_framework import ( From d9702de43779a83bd3644ca8d1e5136890d5b92d Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 4 Feb 2026 13:25:55 -0600 Subject: [PATCH 08/29] remove redundant typing import and fix SIM115 --- .../agent_framework_azurefunctions/_serialization.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index 11034558b2..8b9c31f0fa 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -11,7 +11,7 @@ import logging import types from dataclasses import asdict, 
fields, is_dataclass -from typing import Any, Union, get_args, get_origin +from typing import Any, Union, get_args, get_origin, get_type_hints from agent_framework import ( AgentExecutorRequest, @@ -169,9 +169,7 @@ def _reconstruct_dataclass_fields(dataclass_type: type, data: dict[str, Any]) -> # Get type hints for the dataclass try: - import typing - - type_hints = typing.get_type_hints(dataclass_type) + type_hints = get_type_hints(dataclass_type) except Exception: # Fall back to field annotations if get_type_hints fails for f in fields(dataclass_type): From d5d1af822cb311e5006022c6bbf57c05c4a301ff Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 4 Feb 2026 14:00:24 -0600 Subject: [PATCH 09/29] fix latest breaking changes --- .../azurefunctions/agent_framework_azurefunctions/_app.py | 4 ++-- .../agent_framework_azurefunctions/_context.py | 4 ++-- python/packages/azurefunctions/tests/test_workflow.py | 4 ++-- python/packages/core/agent_framework/_workflows/__init__.py | 2 ++ .../azure_functions/09_workflow_shared_state/function_app.py | 5 ++--- .../azure_functions/12_workflow_hitl/function_app.py | 3 +-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 88bed775fc..74ab766b32 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -273,7 +273,7 @@ def executor_activity(inputData: str) -> str: """ import json as json_module - from agent_framework import SharedState + from agent_framework import State data = json_module.loads(inputData) message_data = data["message"] @@ -296,7 +296,7 @@ def executor_activity(inputData: str) -> str: async def run() -> dict[str, Any]: # Create runner context and shared state runner_context = CapturingRunnerContext() - shared_state = SharedState() + shared_state = State() # Deserialize shared state values to reconstruct dataclasses/Pydantic models deserialized_state = {k: deserialize_value(v) for k, v in (shared_state_snapshot or {}).items()} diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py index 6ceaa88fb8..af8ca077d9 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py @@ -16,7 +16,7 @@ Message, RequestInfoEvent, RunnerContext, - SharedState, + State, WorkflowCheckpoint, WorkflowEvent, ) @@ -108,7 +108,7 @@ def clear_runtime_checkpoint_storage(self) -> None: async def create_checkpoint( self, - shared_state: SharedState, + shared_state: State, iteration_count: int, metadata: dict[str, Any] | None = None, ) -> str: diff --git a/python/packages/azurefunctions/tests/test_workflow.py b/python/packages/azurefunctions/tests/test_workflow.py index f35e0f13c0..bbcd00e849 100644 --- a/python/packages/azurefunctions/tests/test_workflow.py +++ b/python/packages/azurefunctions/tests/test_workflow.py @@ -169,9 +169,9 @@ def test_conversation_includes_previous_string_message(self) -> None: ) assert len(response.full_conversation) == 2 - assert response.full_conversation[0].role.value == "user" + assert response.full_conversation[0].role == "user" assert response.full_conversation[0].text == "User said this" - assert response.full_conversation[1].role.value == "assistant" + assert 
response.full_conversation[1].role == "assistant" def test_conversation_extends_previous_agent_executor_response(self) -> None: """Test that previous AgentExecutorResponse's conversation is extended.""" diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index 2dfa029840..803ff73ba7 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -67,6 +67,7 @@ RunnerContext, WorkflowMessage, ) +from ._state import State from ._validation import ( EdgeDuplicationError, GraphConnectivityError, @@ -106,6 +107,7 @@ "InProcRunnerContext", "Runner", "RunnerContext", + "State", "SingleEdgeGroup", "SubWorkflowRequestMessage", "SubWorkflowResponseMessage", diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py b/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py index 520647b0c3..c2176e2832 100644 --- a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py @@ -28,7 +28,6 @@ AgentExecutorRequest, AgentExecutorResponse, ChatMessage, - Role, Workflow, WorkflowBuilder, WorkflowContext, @@ -112,7 +111,7 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest await ctx.set_shared_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) await ctx.send_message( - AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=new_email.email_content)], should_respond=True) + AgentExecutorRequest(messages=[ChatMessage(role="user", text=new_email.email_content)], should_respond=True) ) @@ -148,7 +147,7 @@ async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowCon # Load the original content by id from shared state and forward it to the assistant. 
email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") await ctx.send_message( - AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=email.email_content)], should_respond=True) + AgentExecutorRequest(messages=[ChatMessage(role="user", text=email.email_content)], should_respond=True) ) diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py b/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py index 11fb48c7db..1f856c7ea2 100644 --- a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py @@ -34,7 +34,6 @@ AgentExecutorResponse, ChatMessage, Executor, - Role, Workflow, WorkflowBuilder, WorkflowContext, @@ -367,7 +366,7 @@ async def route_input( await ctx.send_message( AgentExecutorRequest( - messages=[ChatMessage(Role.USER, text=message)], + messages=[ChatMessage(role="user", text=message)], should_respond=True, ) ) From 27bef116665d98f39a8c9966a78b36015ed27825 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 4 Feb 2026 14:26:20 -0600 Subject: [PATCH 10/29] fix mypy issues --- .../azurefunctions/agent_framework_azurefunctions/_app.py | 5 ++++- .../agent_framework_azurefunctions/_serialization.py | 3 ++- .../agent_framework_azurefunctions/_workflow.py | 3 ++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 74ab766b32..f1cd6bc290 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -389,8 +389,11 @@ def _setup_workflow_orchestration(self) -> None: """Register the workflow orchestration and related HTTP endpoints.""" @self.orchestration_trigger(context_name="context") - def workflow_orchestrator(context: df.DurableOrchestrationContext): # type: ignore[type-arg] + def workflow_orchestrator(context: df.DurableOrchestrationContext) -> Any: # type: ignore[type-arg] """Generic orchestrator for running the configured workflow.""" + if self.workflow is None: + raise RuntimeError("Workflow not initialized in AgentFunctionApp") + input_data = context.get_input() # Ensure input is a string for the agent diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index 8b9c31f0fa..b48b1af987 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -10,6 +10,7 @@ import logging import types +from collections.abc import Sequence from dataclasses import asdict, fields, is_dataclass from typing import Any, Union, get_args, get_origin, get_type_hints @@ -287,7 +288,7 @@ def reconstruct_agent_executor_response(data: dict[str, Any]) -> AgentExecutorRe ) -def reconstruct_message_for_handler(data: Any, input_types: list[type[Any]]) -> Any: +def reconstruct_message_for_handler(data: Any, input_types: Sequence[type[Any] | types.UnionType]) -> Any: """Attempt to reconstruct a message to match one of the handler's expected types. 
Handles: diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index d778eeeffa..71d4bb56c4 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -22,6 +22,7 @@ import json import logging from collections import defaultdict +from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta from enum import Enum @@ -581,7 +582,7 @@ def run_workflow_orchestrator( initial_message: Any, shared_state: dict[str, Any] | None = None, hitl_timeout_hours: float = DEFAULT_HITL_TIMEOUT_HOURS, -): +) -> Generator[Any, Any, list[Any]]: """Traverse and execute the workflow graph using Durable Functions. This orchestrator reuses MAF's edge group routing logic while adapting From 6a0adb67a74312c6e881320d30b2883024eaba8f Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 6 Feb 2026 10:57:12 -0600 Subject: [PATCH 11/29] clean up imports --- .../azurefunctions/agent_framework_azurefunctions/_app.py | 4 +--- .../agent_framework_azurefunctions/_context.py | 3 +-- .../agent_framework_azurefunctions/_workflow.py | 5 +---- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index f1cd6bc290..a84b5782cf 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -271,11 +271,9 @@ def executor_activity(inputData: str) -> str: Note: We use str type annotations instead of dict to work around Azure Functions worker type validation issues with dict[str, Any]. 
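The note above shapes the whole activity contract: structured data crosses the worker boundary as JSON text in both directions. A stripped-down sketch of that contract (input field names follow the patch; the result shape here is illustrative, since the real activity also returns serialized updates, outputs, and pending requests):

import json


def executor_activity(inputData: str) -> str:
    # `str` in, `str` out keeps the Azure Functions worker's annotation checks happy;
    # the real payload is a JSON document decoded and re-encoded here.
    data = json.loads(inputData)
    message_data = data["message"]
    shared_state_snapshot = data.get("shared_state_snapshot", {})
    # ... run the target executor against message_data and shared_state_snapshot ...
    result = {"sent_messages": []}  # illustrative; see the full return dict in _app.py
    return json.dumps(result)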
""" - import json as json_module - from agent_framework import State - data = json_module.loads(inputData) + data = json.loads(inputData) message_data = data["message"] shared_state_snapshot = data.get("shared_state_snapshot", {}) source_executor_ids = data.get("source_executor_ids", ["__orchestrator__"]) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py index af8ca077d9..2a7e901faa 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py @@ -9,6 +9,7 @@ from __future__ import annotations import asyncio +from copy import copy from typing import Any from agent_framework import ( @@ -54,8 +55,6 @@ async def send_message(self, message: Message) -> None: async def drain_messages(self) -> dict[str, list[Message]]: """Drain and return all captured messages.""" - from copy import copy - messages = copy(self._messages) self._messages.clear() return messages diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 71d4bb56c4..6f37af01c3 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -19,6 +19,7 @@ from __future__ import annotations +import importlib import json import logging from collections import defaultdict @@ -720,8 +721,6 @@ def run_workflow_orchestrator( # Durable Functions may return a JSON string; parse it if so if isinstance(raw_response, str): try: - import json - raw_response = json.loads(raw_response) logger.debug("Parsed JSON string response to: %s", type(raw_response).__name__) except (json.JSONDecodeError, TypeError): @@ -986,8 +985,6 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None if response_type_str: try: module_name, class_name = response_type_str.rsplit(":", 1) - import importlib - module = importlib.import_module(module_name) response_type = getattr(module, class_name, None) From a5b2a836baf2a8be27dca31b627c7e3970d73d49 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 6 Feb 2026 11:18:16 -0600 Subject: [PATCH 12/29] define source marker strings as constants --- .../agent_framework_azurefunctions/_app.py | 13 ++++++--- .../_workflow.py | 27 ++++++++++++++++--- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index a84b5782cf..485cef2588 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -49,7 +49,12 @@ reconstruct_message_for_handler, serialize_message, ) -from ._workflow import execute_hitl_response_handler, run_workflow_orchestrator +from ._workflow import ( + SOURCE_HITL_RESPONSE, + SOURCE_ORCHESTRATOR, + execute_hitl_response_handler, + run_workflow_orchestrator, +) logger = logging.getLogger("agent_framework.azurefunctions") @@ -276,7 +281,7 @@ def executor_activity(inputData: str) -> str: data = json.loads(inputData) message_data = data["message"] shared_state_snapshot = data.get("shared_state_snapshot", {}) - source_executor_ids = data.get("source_executor_ids", ["__orchestrator__"]) + source_executor_ids = 
data.get("source_executor_ids", [SOURCE_ORCHESTRATOR]) if not self.workflow: raise RuntimeError("Workflow not initialized in AgentFunctionApp") @@ -288,8 +293,8 @@ def executor_activity(inputData: str) -> str: # Reconstruct message - try to match handler's expected types using public input_types message = reconstruct_message_for_handler(message_data, executor.input_types) - # Check if this is a HITL response message - is_hitl_response = isinstance(message_data, dict) and message_data.get("__hitl_response__") + # Check if this is a HITL response message by examining source_executor_ids + is_hitl_response = any(s.startswith(SOURCE_HITL_RESPONSE) for s in source_executor_ids) async def run() -> dict[str, Any]: # Create runner context and shared state diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 6f37af01c3..f4e5f67077 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -54,6 +54,25 @@ logger = logging.getLogger(__name__) +# ============================================================================ +# Source Marker Constants +# ============================================================================ +# These markers identify the origin of messages in the workflow orchestration. +# They are used to track message provenance and handle special cases like HITL. + +# Marker indicating the message originated from the workflow start (initial user input) +SOURCE_WORKFLOW_START = "__workflow_start__" + +# Marker indicating the message originated from the orchestrator itself +# (used as default when executor is called directly by orchestrator, not via another executor) +SOURCE_ORCHESTRATOR = "__orchestrator__" + +# Marker indicating the message is a human-in-the-loop response. +# Used as a source ID prefix. To detect HITL responses, check if any source_executor_id +# starts with this prefix. 
+SOURCE_HITL_RESPONSE = "__hitl_response__" + + # ============================================================================ # Task Types and Data Structures # ============================================================================ @@ -549,8 +568,8 @@ def _route_hitl_response( """ # Create a message structure that the executor can recognize # This mimics what the InProcRunnerContext does for request_info responses + # Note: HITL origin is identified via source_executor_ids (starting with SOURCE_HITL_RESPONSE) response_message = { - "__hitl_response__": True, "request_id": hitl_request.request_id, "original_request": hitl_request.request_data, "response": raw_response, @@ -562,7 +581,7 @@ def _route_hitl_response( pending_messages[target_id] = [] # Use a special source ID to indicate this is a HITL response - source_id = f"__hitl_response__{hitl_request.request_id}" + source_id = f"{SOURCE_HITL_RESPONSE}_{hitl_request.request_id}" pending_messages[target_id].append((response_message, source_id)) logger.debug( @@ -614,7 +633,7 @@ def run_workflow_orchestrator( List of workflow outputs collected from executor activities """ pending_messages: dict[str, list[tuple[Any, str]]] = { - workflow.start_executor_id: [(initial_message, "__workflow_start__")] + workflow.start_executor_id: [(initial_message, SOURCE_WORKFLOW_START)] } workflow_outputs: list[Any] = [] iteration = 0 @@ -943,7 +962,7 @@ async def execute_hitl_response_handler( # Use a special source ID to indicate this is a HITL response ctx = WorkflowContext( executor=executor, - source_executor_ids=["__hitl_response__"], + source_executor_ids=[SOURCE_HITL_RESPONSE], runner_context=runner_context, shared_state=shared_state, ) From 38f6ff6a9193b7fc771ca464d296baa30051c4b9 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 6 Feb 2026 14:33:37 -0600 Subject: [PATCH 13/29] fix json module name --- .../azurefunctions/agent_framework_azurefunctions/_app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 485cef2588..f8aad6566e 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -383,7 +383,7 @@ async def run() -> dict[str, Any]: } result = asyncio.run(run()) - return json_module.dumps(result) + return json.dumps(result) # Ensure the function is registered (prevents garbage collection) _ = executor_activity From 8b66a8b47ce812a2e8fb8d1cd8fec14cde38f8c2 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 6 Feb 2026 14:35:47 -0600 Subject: [PATCH 14/29] refactor _extract_message_content_from_dict --- .../_workflow.py | 63 +++++++++---------- 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index f4e5f67077..809a9982df 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -871,41 +871,38 @@ def _extract_message_content(message: Any) -> str: def _extract_message_content_from_dict(message: dict[str, Any]) -> str: - """Extract text content from serialized message dictionaries.""" - message_content = "" + """Extract text content from serialized message dictionaries. 
- if message.get("messages"): - # AgentExecutorRequest dict - messages is a list of ChatMessage dicts - last_msg = message["messages"][-1] - if isinstance(last_msg, dict): - # ChatMessage serialized via to_dict() has structure: - # {"type": "chat_message", "contents": [{"type": "text", "text": "..."}], ...} - if last_msg.get("contents"): - first_content = last_msg["contents"][0] - if isinstance(first_content, dict): - message_content = first_content.get("text") or "" - # Fallback to direct text field if not in contents structure - if not message_content: - message_content = last_msg.get("text") or last_msg.get("_text") or "" - elif hasattr(last_msg, "text"): - message_content = last_msg.text or "" - elif "agent_response" in message: - # AgentExecutorResponse dict - arr = message.get("agent_response", {}) - if isinstance(arr, dict): - message_content = arr.get("text") or "" - if not message_content and arr.get("messages"): - last_msg = arr["messages"][-1] - if isinstance(last_msg, dict): - # Check for contents structure first - if last_msg.get("contents"): - first_content = last_msg["contents"][0] - if isinstance(first_content, dict): - message_content = first_content.get("text") or "" - if not message_content: - message_content = last_msg.get("text") or last_msg.get("_text") or "" + Uses MAF's from_dict() methods to reconstruct objects before extracting text. + Returns empty string if the message format is not recognized. + """ + # Try to reconstruct as AgentExecutorResponse + if "executor_id" in message and "agent_response" in message: + try: + reconstructed = AgentExecutorResponse.from_dict(message) + return _extract_message_content(reconstructed) + except Exception: + logger.debug("Could not reconstruct AgentExecutorResponse") - return message_content + # Try to reconstruct as AgentExecutorRequest + if "messages" in message and "should_respond" in message: + try: + reconstructed = AgentExecutorRequest.from_dict(message) + return _extract_message_content(reconstructed) + except Exception: + logger.debug("Could not reconstruct AgentExecutorRequest") + + # Try to reconstruct as ChatMessage + if message.get("type") == "chat_message" or "contents" in message: + try: + reconstructed = ChatMessage.from_dict(message) + return reconstructed.text or "" + except Exception: + logger.debug("Could not reconstruct ChatMessage") + + # Unrecognized format - return empty string + logger.debug("Unrecognized message format, returning empty string. 
Keys: %s", list(message.keys())) + return "" # ============================================================================ From 1c2f50bdd16691c526c16968e2873c37782d67fb Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 6 Feb 2026 17:06:55 -0600 Subject: [PATCH 15/29] refactor serialization --- .../agent_framework_azurefunctions/_app.py | 19 +- .../_serialization.py | 455 ++++++------------ .../_workflow.py | 27 +- .../azurefunctions/tests/test_utils.py | 448 +++++++---------- .../agent_framework/_workflows/__init__.py | 7 + 5 files changed, 350 insertions(+), 606 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index f8aad6566e..228b53852c 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -44,11 +44,7 @@ from ._entities import create_agent_entity from ._errors import IncomingRequestError from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor -from ._serialization import ( - deserialize_value, - reconstruct_message_for_handler, - serialize_message, -) +from ._serialization import deserialize_value, serialize_value from ._workflow import ( SOURCE_HITL_RESPONSE, SOURCE_ORCHESTRATOR, @@ -290,8 +286,9 @@ def executor_activity(inputData: str) -> str: if not executor: raise ValueError(f"Unknown executor: {captured_executor_id}") - # Reconstruct message - try to match handler's expected types using public input_types - message = reconstruct_message_for_handler(message_data, executor.input_types) + # Reconstruct message - deserialize_value restores the original typed objects + # from the encoded data (with type markers) + message = deserialize_value(message_data) # Check if this is a HITL response message by examining source_executor_ids is_hitl_response = any(s.startswith(SOURCE_HITL_RESPONSE) for s in source_executor_ids) @@ -344,7 +341,7 @@ async def run() -> dict[str, Any]: outputs: list[Any] = [] for event in events: if isinstance(event, WorkflowOutputEvent): - outputs.append(serialize_message(event.data)) + outputs.append(serialize_value(event.data)) # Get pending request info events for HITL pending_request_info_events = await runner_context.get_pending_request_info_events() @@ -355,7 +352,7 @@ async def run() -> dict[str, Any]: serialized_pending_requests.append({ "request_id": event.request_id, "source_executor_id": event.source_executor_id, - "data": serialize_message(event.data), + "data": serialize_value(event.data), "request_type": f"{type(event.data).__module__}:{type(event.data).__name__}", "response_type": f"{event.response_type.__module__}:{event.response_type.__name__}" if event.response_type @@ -367,12 +364,12 @@ async def run() -> dict[str, Any]: for _source_id, msg_list in sent_messages.items(): for msg in msg_list: serialized_sent_messages.append({ - "message": serialize_message(msg.data), + "message": serialize_value(msg.data), "target_id": msg.target_id, "source_id": msg.source_id, }) - serialized_updates = {k: serialize_message(v) for k, v in updates.items()} + serialized_updates = {k: serialize_value(v) for k, v in updates.items()} return { "sent_messages": serialized_sent_messages, diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index b48b1af987..dd72437547 100644 --- 
a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -1,378 +1,217 @@ # Copyright (c) Microsoft. All rights reserved. -"""Serialization and deserialization utilities for workflow execution. - -This module provides helper functions for serializing and deserializing messages, -dataclasses, and Pydantic models for cross-activity communication in Azure Functions. +"""Serialization utilities for workflow execution. + +This module provides thin wrappers around the core checkpoint encoding system +(encode_checkpoint_value / decode_checkpoint_value) from agent_framework._workflows, +adding Pydantic model support. + +The core checkpoint encoding handles type-safe roundtripping of: +- Objects with to_dict/from_dict (ChatMessage, AgentResponse, etc.) +- Dataclasses (AgentExecutorRequest/Response, custom dataclasses) +- Objects with to_json/from_json +- Primitives, lists, dicts + +This module adds: +- serialize_value / deserialize_value: wrappers that also handle Pydantic BaseModel instances +- reconstruct_to_type: for HITL responses where external data (without type markers) + needs to be reconstructed to a known type """ from __future__ import annotations +import importlib import logging -import types -from collections.abc import Sequence -from dataclasses import asdict, fields, is_dataclass -from typing import Any, Union, get_args, get_origin, get_type_hints - -from agent_framework import ( - AgentExecutorRequest, - AgentExecutorResponse, - AgentResponse, - ChatMessage, -) +from dataclasses import fields as dc_fields +from dataclasses import is_dataclass +from typing import Any + +from agent_framework._workflows import decode_checkpoint_value, encode_checkpoint_value +from agent_framework._workflows._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER from pydantic import BaseModel logger = logging.getLogger(__name__) +# Marker for Pydantic models serialized by this module. +# Core checkpoint encoding only supports to_dict/from_dict protocol; Pydantic v2 +# uses model_dump/model_validate, so we handle it here with a compatible marker format. +PYDANTIC_MARKER = "__af_pydantic__" -# ============================================================================ -# Serialization -# ============================================================================ - - -def _serialize_value(value: Any) -> Any: - """Recursively serialize a value for JSON compatibility.""" - # Handle None - if value is None: - return None - - # Handle objects with to_dict() method (like ChatMessage) - if hasattr(value, "to_dict") and callable(value.to_dict): - return value.to_dict() - - # Handle dataclasses - if is_dataclass(value) and not isinstance(value, type): - d: dict[str, Any] = {} - for k, v in asdict(value).items(): - d[k] = _serialize_value(v) - d["__type__"] = type(value).__name__ - d["__module__"] = type(value).__module__ - return d - - # Handle Pydantic models - if isinstance(value, BaseModel): - d = value.model_dump() - d["__type__"] = type(value).__name__ - d["__module__"] = type(value).__module__ - return d - - # Handle lists - if isinstance(value, list): - return [_serialize_value(item) for item in value] - - # Handle dicts - if isinstance(value, dict): - return {k: _serialize_value(v) for k, v in value.items()} - - # Handle primitives and other types - return value +def _resolve_type(type_key: str) -> type | None: + """Resolve a 'module:class' type key to its Python type. 
-def serialize_message(message: Any) -> Any: - """Helper to serialize messages for activity input. + Args: + type_key: Fully qualified type reference in 'module_name:class_name' format. - Adds type metadata (__type__, __module__) to dataclasses and Pydantic models - to enable reconstruction on the receiving end. Handles nested ChatMessage - and other objects with to_dict() methods. + Returns: + The resolved type, or None if resolution fails. """ - return _serialize_value(message) + try: + module_name, class_name = type_key.split(":", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name, None) + except Exception: + logger.debug("Could not resolve type %s", type_key) + return None # ============================================================================ -# Deserialization +# Serialize / Deserialize # ============================================================================ -def deserialize_value(data: Any, type_registry: dict[str, type] | None = None) -> Any: - """Attempt to deserialize a value using embedded type metadata. +def serialize_value(value: Any) -> Any: + """Serialize a value for JSON-compatible cross-activity communication. + + Extends core checkpoint encoding with Pydantic BaseModel support. + The output is JSON-serializable and can be deserialized with deserialize_value(). + + Dataclasses are handled here (rather than delegating to encode_checkpoint_value) + because their fields may contain nested Pydantic models that core encoding + does not recognise. Args: - data: The serialized data (could be dict with __type__ metadata) - type_registry: Optional dict mapping type names to types for reconstruction + value: Any Python value (primitive, dataclass, Pydantic model, ChatMessage, etc.) Returns: - Reconstructed object if type metadata found and type available, otherwise original data + A JSON-serializable representation with embedded type metadata for reconstruction. """ - if not isinstance(data, dict): - return data - - type_name = data.get("__type__") - module_name = data.get("__module__") + if isinstance(value, BaseModel): + cls = type(value) + return { + PYDANTIC_MARKER: f"{cls.__module__}:{cls.__name__}", + "value": encode_checkpoint_value(value.model_dump()), + } + + # Handle dataclasses ourselves so that nested Pydantic models get the + # PYDANTIC_MARKER treatment instead of being str()'d by core encoding. 
+ if is_dataclass(value) and not isinstance(value, type): + cls = type(value) + return { + DATACLASS_MARKER: f"{cls.__module__}:{cls.__name__}", + **{field.name: serialize_value(getattr(value, field.name)) for field in dc_fields(value)}, + } - # Special handling for MAF types with nested objects - if type_name == "AgentExecutorRequest" or ("messages" in data and "should_respond" in data): - try: - return reconstruct_agent_executor_request(data) - except Exception: - logger.debug("Could not reconstruct as AgentExecutorRequest, trying next strategy") + # Handle lists and dicts recursively to catch nested Pydantic models + if isinstance(value, list): + return [serialize_value(item) for item in value] + if isinstance(value, dict): + return {k: serialize_value(v) for k, v in value.items()} - if type_name == "AgentExecutorResponse" or ("executor_id" in data and "agent_response" in data): - try: - return reconstruct_agent_executor_response(data) - except Exception: - logger.debug("Could not reconstruct as AgentExecutorResponse, trying next strategy") - - if not type_name: - return data - - # Try to find the type - target_type = None - - # First check the registry - if type_registry and type_name in type_registry: - target_type = type_registry[type_name] - else: - # Try to import from module - if module_name: - try: - import importlib - - module = importlib.import_module(module_name) - target_type = getattr(module, type_name, None) - except Exception: - logger.debug("Could not import module %s for type %s", module_name, type_name) - - if target_type: - # Remove metadata before reconstruction - clean_data = {k: v for k, v in data.items() if not k.startswith("__")} - try: - if is_dataclass(target_type): - # Recursively reconstruct nested fields for dataclasses - reconstructed_data = _reconstruct_dataclass_fields(target_type, clean_data) - return target_type(**reconstructed_data) - if issubclass(target_type, BaseModel): - # Pydantic handles nested model validation automatically - return target_type.model_validate(clean_data) - except Exception: - logger.debug("Could not reconstruct type %s from data", type_name) + return encode_checkpoint_value(value) - return data +def deserialize_value(value: Any) -> Any: + """Deserialize a value previously serialized with serialize_value(). -def _reconstruct_dataclass_fields(dataclass_type: type, data: dict[str, Any]) -> dict[str, Any]: - """Recursively reconstruct nested dataclass and Pydantic fields. + Handles core checkpoint markers (__af_model__, __af_dataclass__) and + Pydantic markers (__af_pydantic__) to reconstruct the original typed objects. - This function processes each field of a dataclass, looking up the expected type - from type hints and reconstructing nested objects (dataclasses, Pydantic models, lists). + Dataclasses are reconstructed here (rather than delegating to + decode_checkpoint_value) so that fields containing PYDANTIC_MARKER dicts + are properly deserialized. Args: - dataclass_type: The dataclass type being constructed - data: The dict of field values + value: The serialized data (dict with type markers, list, or primitive) Returns: - Dict with nested objects properly reconstructed + Reconstructed typed object if type metadata found, otherwise original value. 
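Taken together, serialize_value and deserialize_value give a JSON-safe roundtrip for the mixed dataclass/Pydantic payloads that the activities exchange. A small usage sketch (the Ticket and Approval types are illustrative; both helpers live in the package-private _serialization module, and the classes must be defined at module level so the embedded module:class markers can be resolved on decode):

import json
from dataclasses import dataclass

from pydantic import BaseModel

from agent_framework_azurefunctions._serialization import deserialize_value, serialize_value


class Approval(BaseModel):
    approved: bool
    reason: str


@dataclass
class Ticket:
    ticket_id: str
    approval: Approval


ticket = Ticket(ticket_id="t-1", approval=Approval(approved=True, reason="ok"))
encoded = serialize_value(ticket)                 # dict carrying __af_dataclass__ / __af_pydantic__ markers
payload = json.dumps(encoded)                     # safe to hand across the activity boundary
decoded = deserialize_value(json.loads(payload))  # typed objects restored from the markers
assert isinstance(decoded, Ticket) and isinstance(decoded.approval, Approval)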
""" - if not is_dataclass(dataclass_type): - return data + if isinstance(value, dict): + # Handle Pydantic marker + if PYDANTIC_MARKER in value and "value" in value: + type_key: str = value[PYDANTIC_MARKER] + payload = decode_checkpoint_value(value["value"]) + cls = _resolve_type(type_key) + if cls is not None and hasattr(cls, "model_validate"): + try: + return cls.model_validate(payload) + except Exception: + logger.debug("Could not reconstruct Pydantic model %s", type_key) + return payload + + # Handle dataclass marker — deserialize fields ourselves so that nested + # PYDANTIC_MARKER dicts are properly handled. + if DATACLASS_MARKER in value: + type_key = value[DATACLASS_MARKER] + cls = _resolve_type(type_key) + if cls is not None and is_dataclass(cls): + try: + field_data = {k: deserialize_value(v) for k, v in value.items() if k != DATACLASS_MARKER} + return cls(**field_data) + except Exception: + logger.debug("Could not reconstruct dataclass %s, falling back to core decode", type_key) + return decode_checkpoint_value(value) - result = {} - type_hints = {} + # Handle model marker (to_dict/from_dict objects like ChatMessage) — core + # handles these fully since the object's own serialisation manages nesting. + if MODEL_MARKER in value: + return decode_checkpoint_value(value) - # Get type hints for the dataclass - try: - type_hints = get_type_hints(dataclass_type) - except Exception: - # Fall back to field annotations if get_type_hints fails - for f in fields(dataclass_type): - type_hints[f.name] = f.type + # Recurse into plain dicts to catch nested markers + return {k: deserialize_value(v) for k, v in value.items()} - for key, value in data.items(): - if key not in type_hints: - result[key] = value - continue + if isinstance(value, list): + return [deserialize_value(item) for item in value] - field_type = type_hints[key] + return decode_checkpoint_value(value) - # Handle Optional types (Union with None) - origin = get_origin(field_type) - if origin is Union or isinstance(field_type, types.UnionType): - args = get_args(field_type) - # Filter out NoneType to get the actual type - non_none_types = [t for t in args if t is not type(None)] - if len(non_none_types) == 1: - field_type = non_none_types[0] - # Recursively reconstruct the value - result[key] = _reconstruct_typed_value(value, field_type) +# ============================================================================ +# HITL Type Reconstruction +# ============================================================================ - return result +def reconstruct_to_type(value: Any, target_type: type) -> Any: + """Reconstruct a value to a known target type. -def _reconstruct_typed_value(value: Any, target_type: type) -> Any: - """Reconstruct a single value to the target type. + Used for HITL responses where external data (without checkpoint type markers) + needs to be reconstructed to a specific type determined by the response_type hint. - Handles dataclasses, Pydantic models, and lists with typed elements. + Tries strategies in order: + 1. Return as-is if already the correct type + 2. deserialize_value (for data with any type markers) + 3. Pydantic model_validate (for Pydantic models) + 4. 
Dataclass constructor (for dataclasses) Args: - value: The value to reconstruct - target_type: The expected type + value: The value to reconstruct (typically a dict from JSON) + target_type: The expected type to reconstruct to Returns: - The reconstructed value + Reconstructed value if possible, otherwise the original value """ if value is None: return None - # If already the correct type, return as-is try: if isinstance(value, target_type): return value except TypeError: - # target_type might not be a valid type for isinstance pass - # Handle dict values that need reconstruction - if isinstance(value, dict): - # First try deserialize_value which uses embedded type metadata - if "__type__" in value: - deserialized = deserialize_value(value) - if deserialized is not value: - return deserialized - - # Handle Pydantic models - if hasattr(target_type, "model_validate"): - try: - return target_type.model_validate(value) - except Exception: - logger.debug("Could not validate Pydantic model %s", target_type) - - # Handle dataclasses - if is_dataclass(target_type) and isinstance(target_type, type): - try: - # Recursively reconstruct nested fields - reconstructed = _reconstruct_dataclass_fields(target_type, value) - return target_type(**reconstructed) - except Exception: - logger.debug("Could not construct dataclass %s", target_type) - - # Handle list values - if isinstance(value, list): - origin = get_origin(target_type) - if origin is list: - args = get_args(target_type) - if args: - element_type = args[0] - return [_reconstruct_typed_value(item, element_type) for item in value] + if not isinstance(value, dict): + return value - return value - - -# ============================================================================ -# MAF Type Reconstruction -# ============================================================================ + # Try marker-based decoding if data has type markers + if MODEL_MARKER in value or DATACLASS_MARKER in value or PYDANTIC_MARKER in value: + decoded = deserialize_value(value) + if not isinstance(decoded, dict): + return decoded - -def reconstruct_agent_executor_request(data: dict[str, Any]) -> AgentExecutorRequest: - """Helper to reconstruct AgentExecutorRequest from dict.""" - # Reconstruct ChatMessage objects in messages - messages_data = data.get("messages", []) - messages = [ChatMessage.from_dict(m) if isinstance(m, dict) else m for m in messages_data] - - return AgentExecutorRequest(messages=messages, should_respond=data.get("should_respond", True)) - - -def reconstruct_agent_executor_response(data: dict[str, Any]) -> AgentExecutorResponse: - """Helper to reconstruct AgentExecutorResponse from dict.""" - # Reconstruct AgentResponse - arr_data = data.get("agent_response", {}) - agent_response = AgentResponse.from_dict(arr_data) if isinstance(arr_data, dict) else arr_data - - # Reconstruct full_conversation - fc_data = data.get("full_conversation", []) - full_conversation = None - if fc_data: - full_conversation = [ChatMessage.from_dict(m) if isinstance(m, dict) else m for m in fc_data] - - return AgentExecutorResponse( - executor_id=data["executor_id"], agent_response=agent_response, full_conversation=full_conversation - ) - - -def reconstruct_message_for_handler(data: Any, input_types: Sequence[type[Any] | types.UnionType]) -> Any: - """Attempt to reconstruct a message to match one of the handler's expected types. 
- - Handles: - - Dicts with __type__ metadata -> reconstructs to original dataclass/Pydantic model - - Lists (from fan-in) -> recursively reconstructs each item - - Union types (T | U) -> tries each type in the union - - AgentExecutorRequest/Response -> special handling for nested ChatMessage objects - - Args: - data: The serialized message data (could be dict, str, list, etc.) - input_types: List of message types the executor can accept - - Returns: - Reconstructed message if possible, otherwise the original data - """ - # Flatten union types in input_types (e.g., T | U becomes [T, U]) - flattened_types: list[type[Any]] = [] - for input_type in input_types: - origin = get_origin(input_type) - # Handle both typing.Union and types.UnionType (Python 3.10+ | syntax) - if origin is Union or isinstance(input_type, types.UnionType): - # This is a Union type (T | U), extract the component types - flattened_types.extend(get_args(input_type)) - else: - flattened_types.append(input_type) - - # Handle lists (fan-in aggregation) - recursively reconstruct each item - if isinstance(data, list): - # Extract element types from list[T] annotations in input_types if possible - element_types: list[type[Any]] = [] - for input_type in input_types: - origin = get_origin(input_type) - if origin is list: - args = get_args(input_type) - if args: - # Handle union types inside list[T | U] - for arg in args: - arg_origin = get_origin(arg) - if arg_origin is Union or isinstance(arg, types.UnionType): - element_types.extend(get_args(arg)) - else: - element_types.append(arg) - - # Recursively reconstruct each item in the list - return [reconstruct_message_for_handler(item, element_types or flattened_types) for item in data] - - if not isinstance(data, dict): - return data - - # Try AgentExecutorResponse first - it needs special handling for nested objects - if "executor_id" in data and "agent_response" in data: + # Try Pydantic model validation (for unmarked dicts, e.g., external HITL data) + if hasattr(target_type, "model_validate"): try: - return reconstruct_agent_executor_response(data) + return target_type.model_validate(value) except Exception: - logger.debug("Could not reconstruct as AgentExecutorResponse in handler context") + logger.debug("Could not validate Pydantic model %s", target_type) - # Try AgentExecutorRequest - also needs special handling for nested ChatMessage objects - if "messages" in data and "should_respond" in data: + # Try dataclass construction (for unmarked dicts, e.g., external HITL data) + if is_dataclass(target_type) and isinstance(target_type, type): try: - return reconstruct_agent_executor_request(data) + return target_type(**value) except Exception: - logger.debug("Could not reconstruct as AgentExecutorRequest in handler context") - - # Try deserialize_value which uses embedded type metadata (__type__, __module__) - if "__type__" in data: - deserialized = deserialize_value(data) - if deserialized is not data: - return deserialized - - # Try to match against input types by checking dict keys vs dataclass fields - # Filter out metadata keys when comparing - data_keys = {k for k in data if not k.startswith("__")} - for msg_type in flattened_types: - if is_dataclass(msg_type): - # Check if the dict keys match the dataclass fields - field_names = {f.name for f in fields(msg_type)} - if field_names == data_keys or field_names.issubset(data_keys): - try: - # Remove metadata before constructing - clean_data = {k: v for k, v in data.items() if not k.startswith("__")} - # Recursively reconstruct 
nested objects based on field types - reconstructed_data = _reconstruct_dataclass_fields(msg_type, clean_data) - return msg_type(**reconstructed_data) - except Exception: - logger.debug("Could not construct %s from matching fields", msg_type.__name__) + logger.debug("Could not construct dataclass %s", target_type) - return data + return value diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 809a9982df..8ebd53d0ad 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -19,7 +19,6 @@ from __future__ import annotations -import importlib import json import logging from collections import defaultdict @@ -49,7 +48,7 @@ from ._context import CapturingRunnerContext from ._orchestration import AzureFunctionsAgentExecutor -from ._serialization import _reconstruct_typed_value, deserialize_value, serialize_message +from ._serialization import _resolve_type, deserialize_value, reconstruct_to_type, serialize_value logger = logging.getLogger(__name__) @@ -315,7 +314,7 @@ def _prepare_activity_task( """ activity_input = { "executor_id": executor_id, - "message": serialize_message(message), + "message": serialize_value(message), "shared_state_snapshot": shared_state_snapshot, "source_executor_ids": [source_executor_id], } @@ -999,21 +998,13 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None # Try to deserialize using the type hint if response_type_str: - try: - module_name, class_name = response_type_str.rsplit(":", 1) - module = importlib.import_module(module_name) - response_type = getattr(module, class_name, None) - - if response_type: - logger.debug("Found response type %s, attempting reconstruction", response_type) - # Use the shared reconstruction logic which handles nested objects - result = _reconstruct_typed_value(response_data, response_type) - logger.debug("Reconstructed response type: %s", type(result).__name__) - return result - logger.warning("Could not find class %s in module %s", class_name, module_name) - - except Exception as e: - logger.warning("Could not deserialize HITL response to %s: %s", response_type_str, e) + response_type = _resolve_type(response_type_str) + if response_type: + logger.debug("Found response type %s, attempting reconstruction", response_type) + result = reconstruct_to_type(response_data, response_type) + logger.debug("Reconstructed response type: %s", type(result).__name__) + return result + logger.warning("Could not resolve response type: %s", response_type_str) # Fall back to generic deserialization logger.debug("Falling back to generic deserialization") diff --git a/python/packages/azurefunctions/tests/test_utils.py b/python/packages/azurefunctions/tests/test_utils.py index 68a0c9dd99..506710db98 100644 --- a/python/packages/azurefunctions/tests/test_utils.py +++ b/python/packages/azurefunctions/tests/test_utils.py @@ -19,13 +19,35 @@ from agent_framework_azurefunctions._context import CapturingRunnerContext from agent_framework_azurefunctions._serialization import ( deserialize_value, - reconstruct_agent_executor_request, - reconstruct_agent_executor_response, - reconstruct_message_for_handler, - serialize_message, + reconstruct_to_type, + serialize_value, ) +# Module-level test types (must be importable for checkpoint encoding roundtrip) +@dataclass +class SampleData: + """Sample dataclass for 
testing checkpoint encoding roundtrip.""" + + name: str + value: int + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing checkpoint encoding roundtrip.""" + + title: str + count: int + + +@dataclass +class DataclassWithPydanticField: + """Dataclass containing a Pydantic model field for testing nested serialization.""" + + label: str + model: SampleModel + + class TestCapturingRunnerContext: """Test suite for CapturingRunnerContext.""" @@ -177,287 +199,175 @@ async def test_apply_checkpoint_raises_not_implemented(self, context: CapturingR await context.apply_checkpoint(Mock()) -class TestSerializeMessage: - """Test suite for serialize_message function.""" - - def test_serialize_none(self) -> None: - """Test serializing None.""" - assert serialize_message(None) is None - - def test_serialize_primitive_types(self) -> None: - """Test serializing primitive types.""" - assert serialize_message("hello") == "hello" - assert serialize_message(42) == 42 - assert serialize_message(3.14) == 3.14 - assert serialize_message(True) is True - - def test_serialize_list(self) -> None: - """Test serializing lists.""" - result = serialize_message([1, 2, 3]) - assert result == [1, 2, 3] - - def test_serialize_dict(self) -> None: - """Test serializing dicts.""" - result = serialize_message({"key": "value", "num": 42}) - assert result == {"key": "value", "num": 42} - - def test_serialize_dataclass(self) -> None: - """Test serializing dataclasses with type metadata.""" - - @dataclass - class TestData: - name: str - value: int - - data = TestData(name="test", value=123) - result = serialize_message(data) - - assert result["name"] == "test" - assert result["value"] == 123 - assert result["__type__"] == "TestData" - assert "__module__" in result - - def test_serialize_pydantic_model(self) -> None: - """Test serializing Pydantic models with type metadata.""" - - class TestModel(BaseModel): - title: str - count: int - - model = TestModel(title="Hello", count=5) - result = serialize_message(model) - - assert result["title"] == "Hello" - assert result["count"] == 5 - assert result["__type__"] == "TestModel" - assert "__module__" in result - - def test_serialize_nested_structures(self) -> None: - """Test serializing nested structures.""" - - @dataclass - class Inner: - x: int - - @dataclass - class Outer: - inner: Inner - items: list[int] - - outer = Outer(inner=Inner(x=10), items=[1, 2, 3]) - result = serialize_message(outer) - - assert result["__type__"] == "Outer" - # Nested dataclass is serialized via asdict, which doesn't add __type__ recursively - assert result["inner"]["x"] == 10 - assert result["items"] == [1, 2, 3] - - def test_serialize_object_with_to_dict(self) -> None: - """Test serializing objects with to_dict method.""" - message = ChatMessage(role="user", text="Hello") - result = serialize_message(message) - - # ChatMessage has to_dict() method which returns a specific structure - assert isinstance(result, dict) - assert "contents" in result # ChatMessage uses contents structure - - -class TestDeserializeValue: - """Test suite for deserialize_value function.""" - - def test_deserialize_non_dict_returns_original(self) -> None: +class TestSerializationRoundtrip: + """Test that serialization roundtrips correctly for types used in Azure Functions workflows.""" + + def test_roundtrip_chat_message(self) -> None: + """Test ChatMessage survives encode → decode roundtrip.""" + original = ChatMessage(role="user", text="Hello") + encoded = serialize_value(original) + decoded = 
deserialize_value(encoded) + + assert isinstance(decoded, ChatMessage) + assert decoded.role == "user" + + def test_roundtrip_agent_executor_request(self) -> None: + """Test AgentExecutorRequest with nested ChatMessages roundtrips.""" + original = AgentExecutorRequest( + messages=[ChatMessage(role="user", text="Hi")], + should_respond=True, + ) + encoded = serialize_value(original) + decoded = deserialize_value(encoded) + + assert isinstance(decoded, AgentExecutorRequest) + assert len(decoded.messages) == 1 + assert isinstance(decoded.messages[0], ChatMessage) + assert decoded.should_respond is True + + def test_roundtrip_agent_executor_response(self) -> None: + """Test AgentExecutorResponse with nested AgentResponse roundtrips.""" + original = AgentExecutorResponse( + executor_id="test_exec", + agent_response=AgentResponse(messages=[ChatMessage(role="assistant", text="Reply")]), + ) + encoded = serialize_value(original) + decoded = deserialize_value(encoded) + + assert isinstance(decoded, AgentExecutorResponse) + assert decoded.executor_id == "test_exec" + assert isinstance(decoded.agent_response, AgentResponse) + + def test_roundtrip_dataclass(self) -> None: + """Test custom dataclass roundtrips.""" + original = SampleData(name="test", value=42) + encoded = serialize_value(original) + decoded = deserialize_value(encoded) + + assert isinstance(decoded, SampleData) + assert decoded.name == "test" + assert decoded.value == 42 + + def test_roundtrip_pydantic_model(self) -> None: + """Test Pydantic model roundtrips.""" + original = SampleModel(title="Hello", count=5) + encoded = serialize_value(original) + decoded = deserialize_value(encoded) + + assert isinstance(decoded, SampleModel) + assert decoded.title == "Hello" + assert decoded.count == 5 + + def test_roundtrip_primitives(self) -> None: + """Test primitives pass through unchanged.""" + assert serialize_value(None) is None + assert serialize_value("hello") == "hello" + assert serialize_value(42) == 42 + assert serialize_value(3.14) == 3.14 + assert serialize_value(True) is True + + def test_roundtrip_list_of_objects(self) -> None: + """Test list of typed objects roundtrips.""" + original = [ + ChatMessage(role="user", text="Q"), + ChatMessage(role="assistant", text="A"), + ] + encoded = serialize_value(original) + decoded = deserialize_value(encoded) + + assert isinstance(decoded, list) + assert len(decoded) == 2 + assert all(isinstance(m, ChatMessage) for m in decoded) + + def test_roundtrip_dict_of_objects(self) -> None: + """Test dict with typed values roundtrips (used for shared state).""" + original = {"count": 42, "msg": ChatMessage(role="user", text="Hi")} + encoded = serialize_value(original) + decoded = deserialize_value(encoded) + + assert decoded["count"] == 42 + assert isinstance(decoded["msg"], ChatMessage) + + def test_roundtrip_dataclass_with_nested_pydantic(self) -> None: + """Test dataclass containing a Pydantic model field roundtrips correctly. + + This covers the HITL pattern where AnalysisWithSubmission (dataclass) + contains a ContentAnalysisResult (Pydantic BaseModel) field. 
+ """ + original = DataclassWithPydanticField(label="test", model=SampleModel(title="Nested", count=99)) + encoded = serialize_value(original) + decoded = deserialize_value(encoded) + + assert isinstance(decoded, DataclassWithPydanticField) + assert decoded.label == "test" + assert isinstance(decoded.model, SampleModel) + assert decoded.model.title == "Nested" + assert decoded.model.count == 99 + + +class TestReconstructToType: + """Test suite for reconstruct_to_type function (used for HITL responses).""" + + def test_none_returns_none(self) -> None: + """Test that None input returns None.""" + assert reconstruct_to_type(None, str) is None + + def test_already_correct_type(self) -> None: + """Test that values already of the correct type are returned as-is.""" + assert reconstruct_to_type("hello", str) == "hello" + assert reconstruct_to_type(42, int) == 42 + + def test_non_dict_returns_original(self) -> None: """Test that non-dict values are returned as-is.""" - assert deserialize_value("string") == "string" - assert deserialize_value(42) == 42 - assert deserialize_value([1, 2, 3]) == [1, 2, 3] - - def test_deserialize_dict_without_type_returns_original(self) -> None: - """Test that dicts without type metadata are returned as-is.""" - data = {"key": "value", "num": 42} - result = deserialize_value(data) - assert result == data - - def test_deserialize_agent_executor_request(self) -> None: - """Test deserializing AgentExecutorRequest.""" - data = { - "messages": [{"type": "chat_message", "role": "user", "contents": [{"type": "text", "text": "Hello"}]}], - "should_respond": True, - } - - result = deserialize_value(data) + assert reconstruct_to_type("hello", int) == "hello" + assert reconstruct_to_type([1, 2], dict) == [1, 2] - assert isinstance(result, AgentExecutorRequest) - assert len(result.messages) == 1 - assert result.should_respond is True + def test_reconstruct_pydantic_model(self) -> None: + """Test reconstruction of Pydantic model from plain dict.""" - def test_deserialize_agent_executor_response(self) -> None: - """Test deserializing AgentExecutorResponse.""" - data = { - "executor_id": "test_exec", - "agent_response": { - "type": "agent_response", - "messages": [ - {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "Hi there"}]} - ], - }, - } + class ApprovalResponse(BaseModel): + approved: bool + reason: str - result = deserialize_value(data) + data = {"approved": True, "reason": "Looks good"} + result = reconstruct_to_type(data, ApprovalResponse) - assert isinstance(result, AgentExecutorResponse) - assert result.executor_id == "test_exec" + assert isinstance(result, ApprovalResponse) + assert result.approved is True + assert result.reason == "Looks good" - def test_deserialize_with_type_registry(self) -> None: - """Test deserializing with type registry.""" + def test_reconstruct_dataclass(self) -> None: + """Test reconstruction of dataclass from plain dict.""" @dataclass - class CustomType: - name: str + class Feedback: + score: int + comment: str - data = {"name": "test", "__type__": "CustomType"} - result = deserialize_value(data, type_registry={"CustomType": CustomType}) + data = {"score": 5, "comment": "Great"} + result = reconstruct_to_type(data, Feedback) - assert isinstance(result, CustomType) - assert result.name == "test" + assert isinstance(result, Feedback) + assert result.score == 5 + assert result.comment == "Great" + def test_reconstruct_from_checkpoint_markers(self) -> None: + """Test that data with checkpoint markers is decoded 
via deserialize_value.""" + original = SampleData(value=99, name="marker-test") + encoded = serialize_value(original) -class TestReconstructAgentExecutorRequest: - """Test suite for reconstruct_agent_executor_request function.""" + result = reconstruct_to_type(encoded, SampleData) + assert isinstance(result, SampleData) + assert result.value == 99 - def test_reconstruct_with_chat_messages(self) -> None: - """Test reconstructing request with ChatMessage dicts.""" - data = { - "messages": [ - {"type": "chat_message", "role": "user", "contents": [{"type": "text", "text": "Hello"}]}, - {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "Hi"}]}, - ], - "should_respond": True, - } - - result = reconstruct_agent_executor_request(data) - - assert isinstance(result, AgentExecutorRequest) - assert len(result.messages) == 2 - assert result.should_respond is True - - def test_reconstruct_defaults_should_respond_to_true(self) -> None: - """Test that should_respond defaults to True.""" - data = {"messages": []} - - result = reconstruct_agent_executor_request(data) - - assert result.should_respond is True - - -class TestReconstructAgentExecutorResponse: - """Test suite for reconstruct_agent_executor_response function.""" - - def test_reconstruct_with_agent_response(self) -> None: - """Test reconstructing response with agent_response.""" - data = { - "executor_id": "my_executor", - "agent_response": { - "type": "agent_response", - "messages": [ - {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "Response"}]} - ], - }, - "full_conversation": [], - } - - result = reconstruct_agent_executor_response(data) - - assert isinstance(result, AgentExecutorResponse) - assert result.executor_id == "my_executor" - assert isinstance(result.agent_response, AgentResponse) - - def test_reconstruct_with_full_conversation(self) -> None: - """Test reconstructing response with full_conversation.""" - data = { - "executor_id": "exec", - "agent_response": {"type": "agent_response", "messages": []}, - "full_conversation": [ - {"type": "chat_message", "role": "user", "contents": [{"type": "text", "text": "Q"}]}, - {"type": "chat_message", "role": "assistant", "contents": [{"type": "text", "text": "A"}]}, - ], - } - - result = reconstruct_agent_executor_response(data) - - assert result.full_conversation is not None - assert len(result.full_conversation) == 2 - - -class TestReconstructMessageForHandler: - """Test suite for reconstruct_message_for_handler function.""" - - def test_reconstruct_non_dict_returns_original(self) -> None: - """Test that non-dict messages are returned as-is.""" - assert reconstruct_message_for_handler("string", []) == "string" - assert reconstruct_message_for_handler(42, []) == 42 - - def test_reconstruct_agent_executor_response(self) -> None: - """Test reconstructing AgentExecutorResponse.""" - data = { - "executor_id": "exec", - "agent_response": {"type": "agent_response", "messages": []}, - } - - result = reconstruct_message_for_handler(data, [AgentExecutorResponse]) - - assert isinstance(result, AgentExecutorResponse) - - def test_reconstruct_agent_executor_request(self) -> None: - """Test reconstructing AgentExecutorRequest.""" - data = { - "messages": [{"type": "chat_message", "role": "user", "contents": [{"type": "text", "text": "Hi"}]}], - "should_respond": True, - } - - result = reconstruct_message_for_handler(data, [AgentExecutorRequest]) - - assert isinstance(result, AgentExecutorRequest) - - def 
test_reconstruct_with_type_metadata(self) -> None: - """Test reconstructing using __type__ metadata.""" - - @dataclass - class CustomMsg: - content: str - - # Serialize includes type metadata - serialized = serialize_message(CustomMsg(content="test")) - - result = reconstruct_message_for_handler(serialized, [CustomMsg]) - - assert isinstance(result, CustomMsg) - assert result.content == "test" - - def test_reconstruct_matches_dataclass_fields(self) -> None: - """Test reconstruction by matching dataclass field names.""" + def test_unrecognized_dict_returns_original(self) -> None: + """Test that unrecognized dicts are returned as-is.""" @dataclass - class MyData: - field_a: str - field_b: int - - data = {"field_a": "hello", "field_b": 42} - - result = reconstruct_message_for_handler(data, [MyData]) - - assert isinstance(result, MyData) - assert result.field_a == "hello" - assert result.field_b == 42 - - def test_reconstruct_returns_original_if_no_match(self) -> None: - """Test that original dict is returned if no type matches.""" - - @dataclass - class UnrelatedType: - completely_different_field: str + class Unrelated: + completely_different: str data = {"some_key": "some_value"} - - result = reconstruct_message_for_handler(data, [UnrelatedType]) + result = reconstruct_to_type(data, Unrelated) assert result == data diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index 803ff73ba7..ea31d5a718 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -26,6 +26,11 @@ InMemoryCheckpointStorage, WorkflowCheckpoint, ) +from ._checkpoint_encoding import ( + decode_checkpoint_value, + encode_checkpoint_value, +) +from ._checkpoint_summary import WorkflowCheckpointSummary, get_checkpoint_summary from ._const import ( DEFAULT_MAX_ITERATIONS, ) @@ -94,9 +99,11 @@ "Case", "CheckpointStorage", "Default", + "decode_checkpoint_value", "Edge", "EdgeCondition", "EdgeDuplicationError", + "encode_checkpoint_value", "Executor", "FanInEdgeGroup", "FanOutEdgeGroup", From 7d9554749d6d64150fc7bef96adddb537756432f Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 6 Feb 2026 17:37:47 -0600 Subject: [PATCH 16/29] add helper method for error response construction and remove _extract_message_content_from_dict since it is not needed --- .../agent_framework_azurefunctions/_app.py | 33 ++++------ .../_workflow.py | 37 +---------- .../azurefunctions/tests/test_workflow.py | 64 ++----------------- .../agent_framework/_workflows/__init__.py | 4 +- 4 files changed, 20 insertions(+), 118 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 228b53852c..f2935df8f3 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -415,11 +415,7 @@ async def start_workflow_orchestration( try: req_body = req.get_json() except ValueError: - return func.HttpResponse( - json.dumps({"error": "Invalid JSON body"}), - status_code=400, - mimetype="application/json", - ) + return self._build_error_response("Invalid JSON body") instance_id = await client.start_new("workflow_orchestrator", client_input=req_body) @@ -447,11 +443,7 @@ async def get_workflow_status( status = await client.get_status(instance_id) if not status: - return func.HttpResponse( - 
json.dumps({"error": "Instance not found"}), - status_code=404, - mimetype="application/json", - ) + return self._build_error_response("Instance not found", status_code=404) response = { "instanceId": status.instance_id, @@ -497,20 +489,12 @@ async def send_hitl_response(req: func.HttpRequest, client: df.DurableOrchestrat request_id = req.route_params.get("requestId") if not instance_id or not request_id: - return func.HttpResponse( - json.dumps({"error": "Instance ID and Request ID are required."}), - status_code=400, - mimetype="application/json", - ) + return self._build_error_response("Instance ID and Request ID are required.") try: response_data = req.get_json() except ValueError: - return func.HttpResponse( - json.dumps({"error": "Request body must be valid JSON."}), - status_code=400, - mimetype="application/json", - ) + return self._build_error_response("Request body must be valid JSON.") # Send the response as an external event # The request_id is used as the event name for correlation @@ -1229,6 +1213,15 @@ def _build_json_response(self, payload: dict[str, Any] | str, status_code: int) body_json = payload if isinstance(payload, str) else json.dumps(payload) return func.HttpResponse(body_json, status_code=status_code, mimetype=MIMETYPE_APPLICATION_JSON) + @staticmethod + def _build_error_response(message: str, status_code: int = 400) -> func.HttpResponse: + """Return a JSON error response with the given message and status code.""" + return func.HttpResponse( + json.dumps({"error": message}), + status_code=status_code, + mimetype=MIMETYPE_APPLICATION_JSON, + ) + def _convert_payload_to_text(self, payload: dict[str, Any]) -> str: """Convert a structured payload into a human-readable text response.""" for key in ("response", "error", "message"): diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 8ebd53d0ad..9fe967483f 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -862,48 +862,13 @@ def _extract_message_content(message: Any) -> str: # Extract text from the last message in the request message_content = message.messages[-1].text or "" elif isinstance(message, dict): - message_content = _extract_message_content_from_dict(message) + logger.warning("Unexpected dict message in _extract_message_content. Keys: %s", list(message.keys())) elif isinstance(message, str): message_content = message return message_content -def _extract_message_content_from_dict(message: dict[str, Any]) -> str: - """Extract text content from serialized message dictionaries. - - Uses MAF's from_dict() methods to reconstruct objects before extracting text. - Returns empty string if the message format is not recognized. 
- """ - # Try to reconstruct as AgentExecutorResponse - if "executor_id" in message and "agent_response" in message: - try: - reconstructed = AgentExecutorResponse.from_dict(message) - return _extract_message_content(reconstructed) - except Exception: - logger.debug("Could not reconstruct AgentExecutorResponse") - - # Try to reconstruct as AgentExecutorRequest - if "messages" in message and "should_respond" in message: - try: - reconstructed = AgentExecutorRequest.from_dict(message) - return _extract_message_content(reconstructed) - except Exception: - logger.debug("Could not reconstruct AgentExecutorRequest") - - # Try to reconstruct as ChatMessage - if message.get("type") == "chat_message" or "contents" in message: - try: - reconstructed = ChatMessage.from_dict(message) - return reconstructed.text or "" - except Exception: - logger.debug("Could not reconstruct ChatMessage") - - # Unrecognized format - return empty string - logger.debug("Unrecognized message format, returning empty string. Keys: %s", list(message.keys())) - return "" - - # ============================================================================ # HITL Response Handler Execution # ============================================================================ diff --git a/python/packages/azurefunctions/tests/test_workflow.py b/python/packages/azurefunctions/tests/test_workflow.py index bbcd00e849..b8e246d85c 100644 --- a/python/packages/azurefunctions/tests/test_workflow.py +++ b/python/packages/azurefunctions/tests/test_workflow.py @@ -23,7 +23,6 @@ from agent_framework_azurefunctions._workflow import ( _extract_message_content, - _extract_message_content_from_dict, build_agent_executor_response, route_message_through_edge_groups, ) @@ -249,20 +248,13 @@ def test_extract_from_agent_executor_request(self) -> None: assert result == "Last request" - def test_extract_from_dict_agent_executor_request(self) -> None: - """Test extracting from serialized AgentExecutorRequest dict.""" - msg_dict = { - "messages": [ - { - "type": "chat_message", - "contents": [{"type": "text", "text": "Hello from dict"}], - } - ] - } + def test_extract_from_dict_returns_empty(self) -> None: + """Test that dict messages return empty string (unexpected input).""" + msg_dict = {"messages": [{"text": "Hello"}]} result = _extract_message_content(msg_dict) - assert result == "Hello from dict" + assert result == "" def test_extract_returns_empty_for_unknown_type(self) -> None: """Test that unknown types return empty string.""" @@ -271,54 +263,6 @@ def test_extract_returns_empty_for_unknown_type(self) -> None: assert result == "" -class TestExtractMessageContentFromDict: - """Test suite for _extract_message_content_from_dict function.""" - - def test_extract_from_messages_with_contents(self) -> None: - """Test extracting from messages with contents structure.""" - msg_dict = {"messages": [{"contents": [{"type": "text", "text": "Content text"}]}]} - - result = _extract_message_content_from_dict(msg_dict) - - assert result == "Content text" - - def test_extract_from_messages_with_direct_text(self) -> None: - """Test extracting from messages with direct text field.""" - msg_dict = {"messages": [{"text": "Direct text"}]} - - result = _extract_message_content_from_dict(msg_dict) - - assert result == "Direct text" - - def test_extract_from_agent_response(self) -> None: - """Test extracting from agent_response dict.""" - msg_dict = {"agent_response": {"text": "Response text"}} - - result = _extract_message_content_from_dict(msg_dict) - - assert result == 
"Response text" - - def test_extract_from_agent_response_with_messages(self) -> None: - """Test extracting from agent_response with messages.""" - msg_dict = {"agent_response": {"messages": [{"contents": [{"type": "text", "text": "Nested content"}]}]}} - - result = _extract_message_content_from_dict(msg_dict) - - assert result == "Nested content" - - def test_extract_returns_empty_for_empty_dict(self) -> None: - """Test that empty dict returns empty string.""" - result = _extract_message_content_from_dict({}) - - assert result == "" - - def test_extract_returns_empty_for_empty_messages(self) -> None: - """Test that empty messages list returns empty string.""" - result = _extract_message_content_from_dict({"messages": []}) - - assert result == "" - - class TestEdgeGroupIntegration: """Integration tests for edge group routing with realistic scenarios.""" diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index ea31d5a718..c89c707284 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -99,11 +99,9 @@ "Case", "CheckpointStorage", "Default", - "decode_checkpoint_value", "Edge", "EdgeCondition", "EdgeDuplicationError", - "encode_checkpoint_value", "Executor", "FanInEdgeGroup", "FanOutEdgeGroup", @@ -143,6 +141,8 @@ "WorkflowValidationError", "WorkflowViz", "create_edge_runner", + "decode_checkpoint_value", + "encode_checkpoint_value", "executor", "handler", "resolve_agent_id", From e861b2b260ffd618fb8c60ba2b0deb2dfb006d9b Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 6 Feb 2026 17:53:30 -0600 Subject: [PATCH 17/29] use strict tpe checking for edges --- .../agent_framework_azurefunctions/_workflow.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 9fe967483f..f5b99c2f68 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -37,6 +37,7 @@ Workflow, ) from agent_framework._workflows._edge import ( + Edge, EdgeGroup, FanInEdgeGroup, FanOutEdgeGroup, @@ -133,21 +134,21 @@ class PendingHITLRequest: # ============================================================================ -def _evaluate_edge_condition_sync(edge: Any, message: Any) -> bool: +def _evaluate_edge_condition_sync(edge: Edge, message: Any) -> bool: """Evaluate an edge's condition synchronously. This is needed because Durable Functions orchestrators use generators, not async/await, so we cannot call async methods like edge.should_route(). 
Args: - edge: The Edge object with a _condition attribute + edge: The Edge with an optional _condition callable message: The message to evaluate against the condition Returns: True if the edge should be traversed, False otherwise """ # Access the internal condition directly since should_route is async - condition = getattr(edge, "_condition", None) + condition = edge._condition if condition is None: return True result = condition(message) From b75fa5af4b27ae8c7ec23d4d488f9d020eaf20b4 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 6 Feb 2026 18:02:22 -0600 Subject: [PATCH 18/29] change how duplicate agent registrations are handled --- .../agent_framework_azurefunctions/_app.py | 6 ++--- .../packages/azurefunctions/tests/test_app.py | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index f2935df8f3..7970cfde4e 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -554,8 +554,7 @@ def add_agent( The app level enable_mcp_tool_trigger setting will override this setting. Raises: - ValueError: If the agent doesn't have a 'name' attribute or if an agent - with the same name is already registered + ValueError: If the agent doesn't have a 'name' attribute. """ # Get agent name from the agent's name attribute name = getattr(agent, "name", None) @@ -563,7 +562,8 @@ def add_agent( raise ValueError("Agent does not have a 'name' attribute. All agents must have a 'name' attribute.") if name in self._agent_metadata: - raise ValueError(f"Agent with name '{name}' is already registered. 
Each agent must have a unique name.") + logger.warning("[AgentFunctionApp] Agent '%s' is already registered, skipping duplicate.", name) + return effective_enable_http_endpoint = ( self.enable_http_endpoints if enable_http_endpoint is None else self._coerce_to_bool(enable_http_endpoint) diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 2bfd833279..f4b86ba2d7 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1387,6 +1387,29 @@ def test_init_without_workflow_does_not_call_workflow_setup(self) -> None: setup_exec.assert_not_called() setup_orch.assert_not_called() + def test_init_with_workflow_deduplicates_agents(self) -> None: + """Test that agents in both 'agents' and workflow are not double-registered.""" + from agent_framework import AgentExecutor + + mock_agent = Mock() + mock_agent.name = "SharedAgent" + + mock_executor = Mock(spec=AgentExecutor) + mock_executor.agent = mock_agent + + mock_workflow = Mock() + mock_workflow.executors = {"SharedAgent": mock_executor} + + with ( + patch.object(AgentFunctionApp, "_setup_executor_activity"), + patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), + patch.object(AgentFunctionApp, "_setup_agent_functions"), + ): + # Same agent passed explicitly AND present in workflow — should not raise + app = AgentFunctionApp(agents=[mock_agent], workflow=mock_workflow) + + assert "SharedAgent" in app.agents + def test_build_status_url(self) -> None: """Test _build_status_url constructs correct URL.""" mock_workflow = Mock() From f824d56ddf13a7eb6b23761d8d3352ea7a69ac06 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 6 Feb 2026 18:06:25 -0600 Subject: [PATCH 19/29] cancel approval_task on HITL timeout --- .../azurefunctions/agent_framework_azurefunctions/_workflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index f5b99c2f68..259e52e0af 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -755,7 +755,8 @@ def run_workflow_orchestrator( pending_messages, ) else: - # Timeout occurred + # Timeout occurred — cancel the dangling external event listener + approval_task.cancel() logger.warning("HITL request %s timed out after %s hours", request_id, hitl_timeout_hours) raise TimeoutError( f"Human-in-the-loop request '{request_id}' timed out after {hitl_timeout_hours} hours." 
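The PATCH 19 change above closes a small leak: when the timeout timer wins the race against the human-approval event, the still-pending external-event task is cancelled before the orchestrator raises. A minimal sketch of that race-and-cancel shape in an azure-durable-functions orchestrator follows; the event name, the 24-hour deadline, and the orchestrator name are illustrative assumptions rather than values taken from this patch.

```python
# Sketch only: the wait-for-event vs. timer race that PATCH 19 hardens.
# Event name, deadline, and orchestrator name are illustrative, not from the patch.
from datetime import timedelta

import azure.durable_functions as df


def approval_orchestrator(context: df.DurableOrchestrationContext):
    # Listen for the human response and start a timeout timer in parallel.
    approval_task = context.wait_for_external_event("approval-request-1")
    deadline = context.current_utc_datetime + timedelta(hours=24)
    timer_task = context.create_timer(deadline)

    winner = yield context.task_any([approval_task, timer_task])

    if winner == approval_task:
        # Human answered in time: cancel the timer so it cannot keep the instance alive.
        timer_task.cancel()
        return approval_task.result

    # Timer fired first: cancel the dangling event listener (mirroring the patch), then fail.
    approval_task.cancel()
    raise TimeoutError("Human-in-the-loop request timed out.")


main = df.Orchestrator.create(approval_orchestrator)
```

Cancelling the losing task in either branch keeps the orchestration history clean and lets the instance finish instead of waiting on a listener that can never fire.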
From 537e869a92c4805a5ec7b409c17956c36a616cb8 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Fri, 6 Feb 2026 18:12:07 -0600 Subject: [PATCH 20/29] update docstring --- .../agent_framework_azurefunctions/_context.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py index 2a7e901faa..e50e159dae 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py @@ -30,12 +30,9 @@ class CapturingRunnerContext(RunnerContext): It captures all messages and events produced during execution without requiring durable entity storage, allowing the results to be returned to the orchestrator. - Unlike the full InProcRunnerContext, this implementation: - - Does NOT support checkpointing (always returns False for has_checkpointing) - - Does NOT support streaming (always returns False for is_streaming) - - Captures messages and events in memory for later retrieval - - The orchestrator manages state coordination; this context just captures execution output. + Unlike InProcRunnerContext, this implementation does NOT support checkpointing + (always returns False for has_checkpointing). The orchestrator manages state + coordination; this context just captures execution output. """ def __init__(self) -> None: From b5d1e26994456cdfdff4938011d4fb78d19e1011 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Mon, 9 Feb 2026 17:29:15 -0600 Subject: [PATCH 21/29] fix: align azurefunctions package with core API changes after rebase - State.import_state/export_state are now sync (removed await) - Add State.commit() before export_state() in activity execution - Rename executor parameter shared_state -> state - Rename ctx.set_shared_state/get_shared_state -> set_state/get_state (sync) - WorkflowBuilder now takes start_executor as constructor kwarg - Update WorkflowOutputEvent -> WorkflowEvent with type='output' - Update RequestInfoEvent -> WorkflowEvent[Any] - Update SharedState -> State in test imports - Update duplicate agent name tests to match new warning behavior - Update sample README API references --- .../agent_framework_azurefunctions/_app.py | 15 ++++++------- .../_context.py | 11 +++++----- .../_workflow.py | 2 +- .../azurefunctions/tests/test_multi_agent.py | 21 ++++++++++++------- .../azurefunctions/tests/test_utils.py | 19 +++++++++-------- .../09_workflow_shared_state/README.md | 8 +++---- .../09_workflow_shared_state/function_app.py | 11 +++++----- .../function_app.py | 3 +-- .../11_workflow_parallel/function_app.py | 4 +--- .../12_workflow_hitl/function_app.py | 9 ++++---- 10 files changed, 52 insertions(+), 51 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 7970cfde4e..dd95007a8c 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -20,7 +20,7 @@ import azure.durable_functions as df import azure.functions as func -from agent_framework import AgentExecutor, SupportsAgentRun, Workflow, WorkflowOutputEvent, get_logger +from agent_framework import AgentExecutor, SupportsAgentRun, Workflow, WorkflowEvent, get_logger from agent_framework_durabletask import ( DEFAULT_MAX_POLL_RETRIES, 
DEFAULT_POLL_INTERVAL_SECONDS, @@ -301,7 +301,7 @@ async def run() -> dict[str, Any]: # Deserialize shared state values to reconstruct dataclasses/Pydantic models deserialized_state = {k: deserialize_value(v) for k, v in (shared_state_snapshot or {}).items()} original_snapshot = dict(deserialized_state) - await shared_state.import_state(deserialized_state) + shared_state.import_state(deserialized_state) if is_hitl_response: # Handle HITL response by calling the executor's @response_handler @@ -316,12 +316,13 @@ async def run() -> dict[str, Any]: await executor.execute( message=message, source_executor_ids=source_executor_ids, - shared_state=shared_state, + state=shared_state, runner_context=runner_context, ) - # Export current state and compute changes - current_state = await shared_state.export_state() + # Commit pending state changes and export + shared_state.commit() + current_state = shared_state.export_state() original_keys = set(original_snapshot.keys()) current_keys = set(current_state.keys()) @@ -337,10 +338,10 @@ async def run() -> dict[str, Any]: sent_messages = await runner_context.drain_messages() events = await runner_context.drain_events() - # Extract outputs from WorkflowOutputEvent instances + # Extract outputs from WorkflowEvent instances with type='output' outputs: list[Any] = [] for event in events: - if isinstance(event, WorkflowOutputEvent): + if isinstance(event, WorkflowEvent) and event.type == "output": outputs.append(serialize_value(event.data)) # Get pending request info events for HITL diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py index e50e159dae..2c13e3d2c5 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py @@ -15,7 +15,6 @@ from agent_framework import ( CheckpointStorage, Message, - RequestInfoEvent, RunnerContext, State, WorkflowCheckpoint, @@ -39,7 +38,7 @@ def __init__(self) -> None: """Initialize the capturing runner context.""" self._messages: dict[str, list[Message]] = {} self._event_queue: asyncio.Queue[WorkflowEvent] = asyncio.Queue() - self._pending_request_info_events: dict[str, RequestInfoEvent] = {} + self._pending_request_info_events: dict[str, WorkflowEvent[Any]] = {} self._workflow_id: str | None = None self._streaming: bool = False @@ -146,8 +145,8 @@ def is_streaming(self) -> bool: # region Request Info Events - async def add_request_info_event(self, event: RequestInfoEvent) -> None: - """Add a RequestInfoEvent and track it for correlation.""" + async def add_request_info_event(self, event: WorkflowEvent[Any]) -> None: + """Add a request_info WorkflowEvent and track it for correlation.""" self._pending_request_info_events[event.request_id] = event await self.add_event(event) @@ -162,8 +161,8 @@ async def send_request_info_response(self, request_id: str, response: Any) -> No "Human-in-the-loop scenarios should be handled at the orchestrator level." 
) - async def get_pending_request_info_events(self) -> dict[str, RequestInfoEvent]: - """Get the mapping of request IDs to their corresponding RequestInfoEvent.""" + async def get_pending_request_info_events(self) -> dict[str, WorkflowEvent[Any]]: + """Get the mapping of request IDs to their corresponding request_info events.""" return dict(self._pending_request_info_events) # endregion Request Info Events diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 259e52e0af..fc11ad807b 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -927,7 +927,7 @@ async def execute_hitl_response_handler( executor=executor, source_executor_ids=[SOURCE_HITL_RESPONSE], runner_context=runner_context, - shared_state=shared_state, + state=shared_state, ) # Call the response handler diff --git a/python/packages/azurefunctions/tests/test_multi_agent.py b/python/packages/azurefunctions/tests/test_multi_agent.py index 0c0be7f35d..c03e00dd3e 100644 --- a/python/packages/azurefunctions/tests/test_multi_agent.py +++ b/python/packages/azurefunctions/tests/test_multi_agent.py @@ -40,14 +40,17 @@ def test_init_with_no_agents(self) -> None: assert len(app.agents) == 0 def test_init_with_duplicate_agent_names(self) -> None: - """Test initialization with agents having the same name raises error.""" + """Test initialization with duplicate agent names deduplicates with warning.""" agent1 = Mock() agent1.name = "TestAgent" agent2 = Mock() agent2.name = "TestAgent" - with pytest.raises(ValueError, match="already registered"): - AgentFunctionApp(agents=[agent1, agent2]) + app = AgentFunctionApp(agents=[agent1, agent2]) + + # Duplicate is skipped, only the first agent is registered + assert len(app.agents) == 1 + assert "TestAgent" in app.agents def test_init_with_agent_without_name(self) -> None: """Test initialization with agent missing name attribute raises error.""" @@ -91,8 +94,8 @@ def test_add_multiple_agents(self) -> None: assert "Agent1" in app.agents assert "Agent2" in app.agents - def test_add_agent_with_duplicate_name_raises_error(self) -> None: - """Test that adding agent with duplicate name raises ValueError.""" + def test_add_agent_with_duplicate_name_skips(self) -> None: + """Test that adding agent with duplicate name logs warning and skips.""" agent1 = Mock() agent1.name = "MyAgent" agent2 = Mock() @@ -100,9 +103,11 @@ def test_add_agent_with_duplicate_name_raises_error(self) -> None: app = AgentFunctionApp(agents=[agent1]) - # Try to add another agent with the same name - with pytest.raises(ValueError, match="already registered"): - app.add_agent(agent2) + # Duplicate is silently skipped with a warning + app.add_agent(agent2) + + # Only the original agent remains + assert len(app.agents) == 1 def test_add_agent_to_app_with_existing_agents(self) -> None: """Test adding agent to app that already has agents.""" diff --git a/python/packages/azurefunctions/tests/test_utils.py b/python/packages/azurefunctions/tests/test_utils.py index 506710db98..4ecae82063 100644 --- a/python/packages/azurefunctions/tests/test_utils.py +++ b/python/packages/azurefunctions/tests/test_utils.py @@ -12,7 +12,7 @@ AgentResponse, ChatMessage, Message, - WorkflowOutputEvent, + WorkflowEvent, ) from pydantic import BaseModel @@ -106,19 +106,20 @@ async def test_has_messages_returns_correct_status(self, 
context: CapturingRunne @pytest.mark.asyncio async def test_add_event_queues_event(self, context: CapturingRunnerContext) -> None: """Test that add_event queues events correctly.""" - event = WorkflowOutputEvent(data="output", executor_id="exec_1") + event = WorkflowEvent.output(executor_id="exec_1", data="output") await context.add_event(event) events = await context.drain_events() assert len(events) == 1 - assert isinstance(events[0], WorkflowOutputEvent) + assert isinstance(events[0], WorkflowEvent) + assert events[0].type == "output" assert events[0].data == "output" @pytest.mark.asyncio async def test_drain_events_clears_queue(self, context: CapturingRunnerContext) -> None: """Test that drain_events clears the event queue.""" - await context.add_event(WorkflowOutputEvent(data="test", executor_id="e")) + await context.add_event(WorkflowEvent.output(executor_id="e", data="test")) await context.drain_events() # First drain events = await context.drain_events() # Second drain @@ -130,14 +131,14 @@ async def test_has_events_returns_correct_status(self, context: CapturingRunnerC """Test has_events returns correct boolean.""" assert await context.has_events() is False - await context.add_event(WorkflowOutputEvent(data="test", executor_id="e")) + await context.add_event(WorkflowEvent.output(executor_id="e", data="test")) assert await context.has_events() is True @pytest.mark.asyncio async def test_next_event_waits_for_event(self, context: CapturingRunnerContext) -> None: """Test that next_event returns queued events.""" - event = WorkflowOutputEvent(data="waited", executor_id="e") + event = WorkflowEvent.output(executor_id="e", data="waited") await context.add_event(event) result = await context.next_event() @@ -169,7 +170,7 @@ def test_set_workflow_id(self, context: CapturingRunnerContext) -> None: async def test_reset_for_new_run_clears_state(self, context: CapturingRunnerContext) -> None: """Test that reset_for_new_run clears all state.""" await context.send_message(Message(data="test", target_id="t", source_id="s")) - await context.add_event(WorkflowOutputEvent(data="event", executor_id="e")) + await context.add_event(WorkflowEvent.output(executor_id="e", data="event")) context.set_streaming(True) context.reset_for_new_run() @@ -181,10 +182,10 @@ async def test_reset_for_new_run_clears_state(self, context: CapturingRunnerCont @pytest.mark.asyncio async def test_create_checkpoint_raises_not_implemented(self, context: CapturingRunnerContext) -> None: """Test that checkpointing methods raise NotImplementedError.""" - from agent_framework import SharedState + from agent_framework import State with pytest.raises(NotImplementedError): - await context.create_checkpoint(SharedState(), 1) + await context.create_checkpoint(State(), 1) @pytest.mark.asyncio async def test_load_checkpoint_raises_not_implemented(self, context: CapturingRunnerContext) -> None: diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/README.md b/python/samples/getting_started/azure_functions/09_workflow_shared_state/README.md index bd6e33c916..8e3593b6d0 100644 --- a/python/samples/getting_started/azure_functions/09_workflow_shared_state/README.md +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/README.md @@ -9,7 +9,7 @@ This sample shows how to use `AgentFunctionApp` to execute a `WorkflowBuilder` w ## What This Sample Demonstrates 1. **Workflow Execution** - Running `WorkflowBuilder` workflows in Azure Durable Functions -2. 
**SharedState APIs** - Using `ctx.set_shared_state()` and `ctx.get_shared_state()` to share data +2. **State APIs** - Using `ctx.set_state()` and `ctx.get_state()` to share data 3. **Conditional Routing** - Routing messages based on spam detection results 4. **Agent + Executor Composition** - Combining AI agents with non-AI function executors @@ -25,9 +25,9 @@ store_email → spam_detector (agent) → to_detection_result → [branch]: | Executor | SharedState Operations | |----------|----------------------| -| `store_email` | `set_shared_state("email:{id}", email)`, `set_shared_state("current_email_id", id)` | -| `to_detection_result` | `get_shared_state("current_email_id")` | -| `submit_to_email_assistant` | `get_shared_state("email:{id}")` | +| `store_email` | `set_state("email:{id}", email)`, `set_state("current_email_id", id)` | +| `to_detection_result` | `get_state("current_email_id")` | +| `submit_to_email_assistant` | `get_state("email:{id}")` | SharedState allows executors to pass large payloads (like email content) by reference rather than through message routing. diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py b/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py index c2176e2832..2ad0211818 100644 --- a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py @@ -107,8 +107,8 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest - Emit an AgentExecutorRequest asking the detector to respond. """ new_email = Email(email_id=str(uuid4()), email_content=email_text) - await ctx.set_shared_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email) - await ctx.set_shared_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) + ctx.set_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email) + ctx.set_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) await ctx.send_message( AgentExecutorRequest(messages=[ChatMessage(role="user", text=new_email.email_content)], should_respond=True) @@ -130,7 +130,7 @@ async def to_detection_result(response: AgentExecutorResponse, ctx: WorkflowCont # Fallback for empty or invalid response (e.g. due to content filtering) parsed = DetectionResultAgent(is_spam=True, reason="Agent execution failed or yielded invalid JSON.") - email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY) + email_id: str = ctx.get_state(CURRENT_EMAIL_ID_KEY) await ctx.send_message(DetectionResult(is_spam=parsed.is_spam, reason=parsed.reason, email_id=email_id)) @@ -145,7 +145,7 @@ async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowCon raise RuntimeError("This executor should only handle non-spam messages.") # Load the original content by id from shared state and forward it to the assistant. 
- email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") + email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") await ctx.send_message( AgentExecutorRequest(messages=[ChatMessage(role="user", text=email.email_content)], should_respond=True) ) @@ -225,8 +225,7 @@ def _create_workflow() -> Workflow: # False -> submit_to_email_assistant -> email_assistant_agent -> finalize_and_send # True -> handle_spam workflow = ( - WorkflowBuilder() - .set_start_executor(store_email) + WorkflowBuilder(start_executor=store_email) .add_edge(store_email, spam_detection_agent) .add_edge(spam_detection_agent, to_detection_result) .add_edge(to_detection_result, submit_to_email_assistant, condition=get_condition(False)) diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py index eacdaad5f3..9fccf5bc1c 100644 --- a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py @@ -185,8 +185,7 @@ def _create_workflow() -> Workflow: # Build workflow workflow = ( - WorkflowBuilder() - .set_start_executor(spam_agent) + WorkflowBuilder(start_executor=spam_agent) .add_switch_case_edge_group( spam_agent, [ diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py b/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py index 1c307d4ab4..0535536951 100644 --- a/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py @@ -434,9 +434,7 @@ def _create_workflow() -> Workflow: # Build workflow with parallel patterns workflow = ( - WorkflowBuilder() - # Start: Route input to parallel processors - .set_start_executor(input_router) + WorkflowBuilder(start_executor=input_router) # Pattern 1: Fan-out to two executors (run in parallel) .add_fan_out_edges( diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py b/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py index 1f856c7ea2..e38ed5a0e5 100644 --- a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py @@ -176,7 +176,7 @@ async def handle_analysis( ) # Retrieve the original submission from shared state - submission: ContentSubmission = await ctx.get_shared_state("current_submission") + submission: ContentSubmission = ctx.get_state("current_submission") await ctx.send_message(AnalysisWithSubmission(submission=submission, analysis=analysis)) @@ -233,7 +233,7 @@ async def request_review( ) # Store analysis in shared state for the response handler - await ctx.set_shared_state("pending_analysis", data) + ctx.set_state("pending_analysis", data) # Request human input - workflow will pause here # The response_type specifies what we expect back @@ -354,7 +354,7 @@ async def route_input( ) # Store submission in shared state for later retrieval - await ctx.set_shared_state("current_submission", submission) + ctx.set_state("current_submission", submission) # Create the agent request message = ( @@ -400,8 +400,7 @@ def _create_workflow() -> Workflow: # input_router -> content_analyzer_agent -> content_analyzer_executor 
# -> human_review_executor (HITL pause here) -> publish_executor workflow = ( - WorkflowBuilder() - .set_start_executor(input_router) + WorkflowBuilder(start_executor=input_router) .add_edge(input_router, content_analyzer_agent) .add_edge(content_analyzer_agent, content_analyzer_executor) .add_edge(content_analyzer_executor, human_review_executor) From 02315c116baae705adc75efaf12828c5c3ea663d Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Mon, 9 Feb 2026 17:47:32 -0600 Subject: [PATCH 22/29] fix sample check errors --- .../agent_framework/_workflows/__init__.py | 2 +- .../09_workflow_shared_state/function_app.py | 43 +++---- .../function_app.py | 55 ++++---- .../11_workflow_parallel/function_app.py | 118 ++++++++---------- .../12_workflow_hitl/function_app.py | 55 ++++---- 5 files changed, 124 insertions(+), 149 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index c89c707284..42e70a8c96 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -112,8 +112,8 @@ "InProcRunnerContext", "Runner", "RunnerContext", - "State", "SingleEdgeGroup", + "State", "SubWorkflowRequestMessage", "SubWorkflowResponseMessage", "SwitchCaseEdgeGroup", diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py b/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py index 2ad0211818..82673dd329 100644 --- a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py +++ b/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py @@ -34,12 +34,11 @@ executor, ) from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_azurefunctions import AgentFunctionApp from azure.identity import AzureCliCredential from pydantic import BaseModel, ValidationError from typing_extensions import Never -from agent_framework_azurefunctions import AgentFunctionApp - logger = logging.getLogger(__name__) # Environment variable names @@ -224,7 +223,7 @@ def _create_workflow() -> Workflow: # store_email -> spam_detection_agent -> to_detection_result -> branch: # False -> submit_to_email_assistant -> email_assistant_agent -> finalize_and_send # True -> handle_spam - workflow = ( + return ( WorkflowBuilder(start_executor=store_email) .add_edge(store_email, spam_detection_agent) .add_edge(spam_detection_agent, to_detection_result) @@ -235,8 +234,6 @@ def _create_workflow() -> Workflow: .build() ) - return workflow - # ============================================================================ # Application Entry Point @@ -254,30 +251,28 @@ def launch(durable: bool = True) -> AgentFunctionApp | None: # Azure Functions mode with Durable Functions # SharedState is enabled by default, which this sample requires for storing emails workflow = _create_workflow() - app = AgentFunctionApp(workflow=workflow, enable_health_check=True) - return app - else: - # Pure MAF mode with DevUI for local development - from pathlib import Path + return AgentFunctionApp(workflow=workflow, enable_health_check=True) + # Pure MAF mode with DevUI for local development + from pathlib import Path - from agent_framework.devui import serve - from dotenv import load_dotenv + from agent_framework.devui import serve + from dotenv import load_dotenv - env_path = Path(__file__).parent / ".env" - load_dotenv(dotenv_path=env_path) + env_path = 
Path(__file__).parent / ".env" + load_dotenv(dotenv_path=env_path) - logger.info("Starting Workflow Shared State Sample in MAF mode") - logger.info("Available at: http://localhost:8096") - logger.info("\nThis workflow demonstrates:") - logger.info("- Shared state to decouple large payloads from messages") - logger.info("- Structured agent outputs with Pydantic models") - logger.info("- Conditional routing based on detection results") - logger.info("\nFlow: store_email -> spam_detection -> branch (spam/not spam)") + logger.info("Starting Workflow Shared State Sample in MAF mode") + logger.info("Available at: http://localhost:8096") + logger.info("\nThis workflow demonstrates:") + logger.info("- Shared state to decouple large payloads from messages") + logger.info("- Structured agent outputs with Pydantic models") + logger.info("- Conditional routing based on detection results") + logger.info("\nFlow: store_email -> spam_detection -> branch (spam/not spam)") - workflow = _create_workflow() - serve(entities=[workflow], port=8096, auto_open=True) + workflow = _create_workflow() + serve(entities=[workflow], port=8096, auto_open=True) - return None + return None # Default: Azure Functions mode diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py index 9fccf5bc1c..831d860806 100644 --- a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py +++ b/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py @@ -18,9 +18,9 @@ import logging import os +from pathlib import Path from typing import Any -from pathlib import Path from agent_framework import ( AgentExecutorResponse, Case, @@ -32,9 +32,9 @@ handler, ) from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_azurefunctions import AgentFunctionApp from azure.identity import AzureCliCredential from pydantic import BaseModel, ValidationError -from agent_framework_azurefunctions import AgentFunctionApp from typing_extensions import Never logger = logging.getLogger(__name__) @@ -125,7 +125,7 @@ async def handle_spam_result( spam_result = SpamDetectionResult.model_validate_json(text) except ValidationError: spam_result = SpamDetectionResult(is_spam=True, reason="Invalid JSON from agent") - + message = f"Email marked as spam: {spam_result.reason}" await ctx.yield_output(message) @@ -145,7 +145,7 @@ async def handle_email_response( email_response = EmailResponse.model_validate_json(text) except ValidationError: email_response = EmailResponse(response="Error generating response.") - + message = f"Email sent: {email_response.response}" await ctx.yield_output(message) @@ -184,7 +184,7 @@ def _create_workflow() -> Workflow: email_sender = EmailSenderExecutor(id="email_sender") # Build workflow - workflow = ( + return ( WorkflowBuilder(start_executor=spam_agent) .add_switch_case_edge_group( spam_agent, @@ -196,7 +196,6 @@ def _create_workflow() -> Workflow: .add_edge(email_agent, email_sender) .build() ) - return workflow def launch(durable: bool = True) -> AgentFunctionApp | None: @@ -205,31 +204,29 @@ def launch(durable: bool = True) -> AgentFunctionApp | None: if durable: # Initialize app workflow = _create_workflow() - app = AgentFunctionApp(workflow=workflow) - return app - else: - # Launch the spam detection workflow in DevUI - from agent_framework.devui import serve - from dotenv import load_dotenv - - # Load 
environment variables from .env file - env_path = Path(__file__).parent / ".env" - load_dotenv(dotenv_path=env_path) - - logger.info("Starting Multi-Agent Spam Detection Workflow") - logger.info("Available at: http://localhost:8094") - logger.info("\nThis workflow demonstrates:") - logger.info("- Conditional routing based on spam detection") - logger.info("- Mixing AI agents with non-AI executors (like activity functions)") - logger.info("- Path 1 (spam): SpamDetector Agent → SpamHandler Executor") - logger.info("- Path 2 (legitimate): SpamDetector Agent → EmailAssistant Agent → EmailSender Executor") + return AgentFunctionApp(workflow=workflow) + # Launch the spam detection workflow in DevUI + from agent_framework.devui import serve + from dotenv import load_dotenv + + # Load environment variables from .env file + env_path = Path(__file__).parent / ".env" + load_dotenv(dotenv_path=env_path) + + logger.info("Starting Multi-Agent Spam Detection Workflow") + logger.info("Available at: http://localhost:8094") + logger.info("\nThis workflow demonstrates:") + logger.info("- Conditional routing based on spam detection") + logger.info("- Mixing AI agents with non-AI executors (like activity functions)") + logger.info("- Path 1 (spam): SpamDetector Agent → SpamHandler Executor") + logger.info("- Path 2 (legitimate): SpamDetector Agent → EmailAssistant Agent → EmailSender Executor") + + workflow = _create_workflow() + serve(entities=[workflow], port=8094, auto_open=True) + + return None - workflow = _create_workflow() - serve(entities=[workflow], port=8094, auto_open=True) - return None - - # Default: Azure Functions mode # Run with `python function_app.py --maf` for pure MAF mode with DevUI app = launch(durable=True) diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py b/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py index 0535536951..38b3f481a9 100644 --- a/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py @@ -37,12 +37,11 @@ handler, ) from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_azurefunctions import AgentFunctionApp from azure.identity import AzureCliCredential from pydantic import BaseModel from typing_extensions import Never -from agent_framework_azurefunctions import AgentFunctionApp - logger = logging.getLogger(__name__) AZURE_OPENAI_ENDPOINT_ENV = "AZURE_OPENAI_ENDPOINT" @@ -104,7 +103,7 @@ class ProcessorResult: has_numbers: bool -@dataclass +@dataclass class AggregatedResults: """Aggregated results from parallel processors.""" document_id: str @@ -137,7 +136,7 @@ async def input_router( ctx: WorkflowContext[DocumentInput] ) -> None: """Route input document to parallel processors. - + Accepts a JSON string from the HTTP request and converts to DocumentInput. 
""" # Parse the JSON string input @@ -152,16 +151,16 @@ async def input_router( @executor(id="word_count_processor") async def word_count_processor( - doc: DocumentInput, + doc: DocumentInput, ctx: WorkflowContext[ProcessorResult] ) -> None: """Process document and count words - runs as an activity.""" logger.info("[word_count_processor] Processing document: %s", doc.document_id) - + word_count = len(doc.content.split()) char_count = len(doc.content) has_numbers = any(c.isdigit() for c in doc.content) - + result = ProcessorResult( processor_name="word_count", document_id=doc.document_id, @@ -170,7 +169,7 @@ async def word_count_processor( char_count=char_count, has_numbers=has_numbers, ) - + await ctx.send_message(result) @@ -181,13 +180,13 @@ async def format_analyzer_processor( ) -> None: """Analyze document format - runs as an activity in parallel with word_count.""" logger.info("[format_analyzer_processor] Processing document: %s", doc.document_id) - + # Simple format analysis - lines = doc.content.split('\n') + lines = doc.content.split("\n") word_count = len(lines) # Using line count as "word count" for this processor char_count = sum(len(line) for line in lines) - has_numbers = doc.content.count('.') > 0 # Check for sentences - + has_numbers = doc.content.count(".") > 0 # Check for sentences + result = ProcessorResult( processor_name="format_analyzer", document_id=doc.document_id, @@ -196,7 +195,7 @@ async def format_analyzer_processor( char_count=char_count, has_numbers=has_numbers, ) - + await ctx.send_message(result) @@ -207,17 +206,17 @@ async def aggregator( ) -> None: """Aggregate results from parallel processors - receives fan-in input.""" logger.info("[aggregator] Aggregating %d results", len(results)) - + # Extract document info from the first result (all have the same content) document_id = results[0].document_id if results else "unknown" content = results[0].content if results else "" - + aggregated = AggregatedResults( document_id=document_id, content=content, processor_results=results, ) - + await ctx.send_message(aggregated) @@ -228,7 +227,7 @@ async def prepare_for_agents( ) -> None: """Prepare content for agent analysis - broadcasts to multiple agents.""" logger.info("[prepare_for_agents] Preparing content for agents") - + # Send the original content to agents for analysis await ctx.send_message(aggregated.content) @@ -239,24 +238,24 @@ async def prepare_for_mixed( ctx: WorkflowContext[str] ) -> None: """Prepare results for mixed agent+executor parallel processing. - + Combines agent analysis results into a string that can be consumed by both the SummaryAgent and the statistics_processor in parallel. 
""" logger.info("[prepare_for_mixed] Preparing for mixed parallel pattern") - + sentiment_text = "" keyword_text = "" - + for analysis in analyses: executor_id = analysis.executor_id text = analysis.agent_response.text if analysis.agent_response else "" - + if executor_id == SENTIMENT_AGENT_NAME: sentiment_text = text elif executor_id == KEYWORD_AGENT_NAME: keyword_text = text - + # Combine into a string that both agent and executor can process combined = f"Sentiment Analysis: {sentiment_text}\n\nKeyword Extraction: {keyword_text}" await ctx.send_message(combined) @@ -269,12 +268,12 @@ async def statistics_processor( ) -> None: """Calculate statistics from the analysis - runs in parallel with SummaryAgent.""" logger.info("[statistics_processor] Calculating statistics") - + # Calculate some statistics from the combined analysis word_count = len(analysis_text.split()) char_count = len(analysis_text) has_numbers = any(c.isdigit() for c in analysis_text) - + result = ProcessorResult( processor_name="statistics", document_id="analysis", @@ -297,9 +296,9 @@ async def compile_report( ) -> None: """Compile final report from mixed agent + processor results.""" logger.info("[final_report] Compiling report from %d analyses", len(analyses)) - + report_parts = ["=== Document Analysis Report ===\n"] - + for analysis in analyses: if isinstance(analysis, AgentExecutorResponse): agent_name = analysis.executor_id @@ -309,10 +308,10 @@ async def compile_report( text = f"Words: {analysis.word_count}, Chars: {analysis.char_count}" else: continue - + report_parts.append(f"\n--- {agent_name} ---") report_parts.append(text) - + final_report = "\n".join(report_parts) await ctx.yield_output(final_report) @@ -328,9 +327,9 @@ async def collect_mixed_results( ) -> None: """Collect and format results from mixed parallel execution.""" logger.info("[mixed_collector] Collecting %d mixed results", len(results)) - + output_parts = ["=== Mixed Parallel Execution Results ===\n"] - + for result in results: if isinstance(result, AgentExecutorResponse): output_parts.append(f"[Agent: {result.executor_id}]") @@ -338,7 +337,7 @@ async def collect_mixed_results( elif isinstance(result, ProcessorResult): output_parts.append(f"[Processor: {result.processor_name}]") output_parts.append(f" Words: {result.word_count}, Chars: {result.char_count}") - + await ctx.yield_output("\n".join(output_parts)) @@ -373,21 +372,21 @@ def _build_client_kwargs() -> dict[str, Any]: def _create_workflow() -> Workflow: """Create the parallel workflow definition. 
- + Workflow structure demonstrating three parallel patterns: - + Pattern 1: Two Executors in Parallel (Fan-out/Fan-in to activities) ──────────────────────────────────────────────────────────────────── ┌─> word_count_processor ─────┐ input_router ──┤ ├──> aggregator └─> format_analyzer_processor ─┘ - + Pattern 2: Two Agents in Parallel (Fan-out to entities) ──────────────────────────────────────────────────────── prepare_for_agents ─┬─> SentimentAgent ──┐ └─> KeywordAgent ────┤ └──> prepare_for_mixed - + Pattern 3: Mixed Agent + Executor in Parallel ────────────────────────────────────────────── prepare_for_mixed ─┬─> SummaryAgent ────────┐ @@ -433,52 +432,42 @@ def _create_workflow() -> Workflow: final_report_executor = FinalReportExecutor(id="final_report") # Build workflow with parallel patterns - workflow = ( + return ( WorkflowBuilder(start_executor=input_router) - # Pattern 1: Fan-out to two executors (run in parallel) .add_fan_out_edges( source=input_router, targets=[word_count_processor, format_analyzer_processor], ) - # Fan-in: Both processors send results to aggregator .add_fan_in_edges( sources=[word_count_processor, format_analyzer_processor], target=aggregator, ) - # Prepare content for agent analysis .add_edge(aggregator, prepare_for_agents) - # Pattern 2: Fan-out to two agents (run in parallel) .add_fan_out_edges( source=prepare_for_agents, targets=[sentiment_agent, keyword_agent], ) - # Fan-in: Collect agent results into prepare_for_mixed .add_fan_in_edges( sources=[sentiment_agent, keyword_agent], target=prepare_for_mixed, ) - # Pattern 3: Fan-out to one agent + one executor (mixed parallel) .add_fan_out_edges( source=prepare_for_mixed, targets=[summary_agent, statistics_processor], ) - # Final fan-in: Collect mixed results .add_fan_in_edges( sources=[summary_agent, statistics_processor], target=final_report_executor, ) - .build() ) - - return workflow # ============================================================================ @@ -492,31 +481,30 @@ def launch(durable: bool = True) -> AgentFunctionApp | None: if durable: workflow = _create_workflow() - app = AgentFunctionApp( - workflow=workflow, + return AgentFunctionApp( + workflow=workflow, enable_health_check=True, ) - return app - else: - from pathlib import Path - from agent_framework.devui import serve - from dotenv import load_dotenv + from pathlib import Path - env_path = Path(__file__).parent / ".env" - load_dotenv(dotenv_path=env_path) + from agent_framework.devui import serve + from dotenv import load_dotenv - logger.info("Starting Parallel Workflow Sample") - logger.info("Available at: http://localhost:8095") - logger.info("\nThis workflow demonstrates:") - logger.info("- Pattern 1: Two executors running in parallel") - logger.info("- Pattern 2: Two agents running in parallel") - logger.info("- Pattern 3: Mixed agent + executor running in parallel") - logger.info("- Fan-in aggregation of parallel results") + env_path = Path(__file__).parent / ".env" + load_dotenv(dotenv_path=env_path) - workflow = _create_workflow() - serve(entities=[workflow], port=8095, auto_open=True) + logger.info("Starting Parallel Workflow Sample") + logger.info("Available at: http://localhost:8095") + logger.info("\nThis workflow demonstrates:") + logger.info("- Pattern 1: Two executors running in parallel") + logger.info("- Pattern 2: Two agents running in parallel") + logger.info("- Pattern 3: Mixed agent + executor running in parallel") + logger.info("- Fan-in aggregation of parallel results") + + workflow = _create_workflow() + 
serve(entities=[workflow], port=8095, auto_open=True) - return None + return None # Default: Azure Functions mode diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py b/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py index e38ed5a0e5..c18b55163d 100644 --- a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py @@ -41,12 +41,11 @@ response_handler, ) from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_azurefunctions import AgentFunctionApp from azure.identity import AzureCliCredential from pydantic import BaseModel, ValidationError from typing_extensions import Never -from agent_framework_azurefunctions import AgentFunctionApp - logger = logging.getLogger(__name__) # Environment variable names @@ -85,7 +84,7 @@ class ContentSubmission: @dataclass class HumanApprovalRequest: """Request sent to human reviewer for approval. - + This is the payload passed to ctx.request_info() and will be exposed via the orchestration status for external systems to retrieve. """ @@ -100,7 +99,7 @@ class HumanApprovalRequest: class HumanApprovalResponse(BaseModel): """Response from human reviewer. - + This is what the external system must send back via the HITL response endpoint. """ @@ -183,7 +182,7 @@ async def handle_analysis( class HumanReviewExecutor(Executor): """Requests human approval using MAF's request_info pattern. - + This executor demonstrates the core HITL pattern: 1. Receives the AI analysis result 2. Calls ctx.request_info() to pause and request human input @@ -200,7 +199,7 @@ async def request_review( ctx: WorkflowContext, ) -> None: """Request human review for the content. - + This method: 1. Constructs the approval request with all context 2. Calls request_info to pause the workflow @@ -250,7 +249,7 @@ async def handle_approval_response( ctx: WorkflowContext[ModerationResult], ) -> None: """Process the human reviewer's decision. - + This method is called automatically when a response to request_info is received. The original_request contains the HumanApprovalRequest we sent. The response contains the HumanApprovalResponse from the reviewer. @@ -399,7 +398,7 @@ def _create_workflow() -> Workflow: # Flow: # input_router -> content_analyzer_agent -> content_analyzer_executor # -> human_review_executor (HITL pause here) -> publish_executor - workflow = ( + return ( WorkflowBuilder(start_executor=input_router) .add_edge(input_router, content_analyzer_agent) .add_edge(content_analyzer_agent, content_analyzer_executor) @@ -408,8 +407,6 @@ def _create_workflow() -> Workflow: .build() ) - return workflow - # ============================================================================ # Application Entry Point @@ -418,7 +415,7 @@ def _create_workflow() -> Workflow: def launch(durable: bool = True) -> AgentFunctionApp | None: """Launch the function app or DevUI. - + Args: durable: If True, returns AgentFunctionApp for Azure Functions. If False, launches DevUI for local MAF development. 
@@ -431,30 +428,28 @@ def launch(durable: bool = True) -> AgentFunctionApp | None: # - POST /api/workflow/respond/{instanceId}/{requestId} - Send HITL response # - GET /api/health - Health check workflow = _create_workflow() - app = AgentFunctionApp(workflow=workflow, enable_health_check=True) - return app - else: - # Pure MAF mode with DevUI for local development - from pathlib import Path + return AgentFunctionApp(workflow=workflow, enable_health_check=True) + # Pure MAF mode with DevUI for local development + from pathlib import Path - from agent_framework.devui import serve - from dotenv import load_dotenv + from agent_framework.devui import serve + from dotenv import load_dotenv - env_path = Path(__file__).parent / ".env" - load_dotenv(dotenv_path=env_path) + env_path = Path(__file__).parent / ".env" + load_dotenv(dotenv_path=env_path) - logger.info("Starting Workflow HITL Sample in MAF mode") - logger.info("Available at: http://localhost:8096") - logger.info("\nThis workflow demonstrates:") - logger.info("- Human-in-the-loop using request_info / @response_handler pattern") - logger.info("- AI content analysis with structured output") - logger.info("- Human approval workflow integration") - logger.info("\nFlow: InputRouter -> ContentAnalyzer Agent -> HumanReview -> Publish") + logger.info("Starting Workflow HITL Sample in MAF mode") + logger.info("Available at: http://localhost:8096") + logger.info("\nThis workflow demonstrates:") + logger.info("- Human-in-the-loop using request_info / @response_handler pattern") + logger.info("- AI content analysis with structured output") + logger.info("- Human approval workflow integration") + logger.info("\nFlow: InputRouter -> ContentAnalyzer Agent -> HumanReview -> Publish") - workflow = _create_workflow() - serve(entities=[workflow], port=8096, auto_open=True) + workflow = _create_workflow() + serve(entities=[workflow], port=8096, auto_open=True) - return None + return None # Default: Azure Functions mode From 440ffa8d73d526f5c86d440057403c479baba00d Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Mon, 9 Feb 2026 17:51:41 -0600 Subject: [PATCH 23/29] fix mypy issues --- .../agent_framework_azurefunctions/_serialization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index dd72437547..86bcddd357 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -87,9 +87,9 @@ def serialize_value(value: Any) -> Any: # Handle dataclasses ourselves so that nested Pydantic models get the # PYDANTIC_MARKER treatment instead of being str()'d by core encoding. 
if is_dataclass(value) and not isinstance(value, type): - cls = type(value) + dc_cls = type(value) return { - DATACLASS_MARKER: f"{cls.__module__}:{cls.__name__}", + DATACLASS_MARKER: f"{dc_cls.__module__}:{dc_cls.__name__}", **{field.name: serialize_value(getattr(value, field.name)) for field in dc_fields(value)}, } From 0f6b00779bf8266a5a18115e94f35a17fcb6a445 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Mon, 9 Feb 2026 18:09:37 -0600 Subject: [PATCH 24/29] fix trailing white spaces --- .../azure_functions/11_workflow_parallel/README.md | 2 +- .../getting_started/azure_functions/12_workflow_hitl/demo.http | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/README.md b/python/samples/getting_started/azure_functions/11_workflow_parallel/README.md index 07c48b73e6..9d0c8a1878 100644 --- a/python/samples/getting_started/azure_functions/11_workflow_parallel/README.md +++ b/python/samples/getting_started/azure_functions/11_workflow_parallel/README.md @@ -7,7 +7,7 @@ This sample demonstrates **parallel execution** of executors and agents in Azure This sample showcases three different parallel execution patterns: 1. **Two Executors in Parallel** - Fan-out to multiple activities -2. **Two Agents in Parallel** - Fan-out to multiple entities +2. **Two Agents in Parallel** - Fan-out to multiple entities 3. **Mixed Execution** - Agents and executors can run concurrently ## Workflow Architecture diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/demo.http b/python/samples/getting_started/azure_functions/12_workflow_hitl/demo.http index b59ae8b61c..9ed4c368c9 100644 --- a/python/samples/getting_started/azure_functions/12_workflow_hitl/demo.http +++ b/python/samples/getting_started/azure_functions/12_workflow_hitl/demo.http @@ -92,7 +92,7 @@ Content-Type: application/json ### ============================================================================ ### Example Workflow - Complete Happy Path ### ============================================================================ -### +### ### Step 1: Start workflow with content ### POST http://localhost:7071/api/workflow/run ### -> Returns instanceId: "abc123..." 
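[Editor's note, not part of the patch] For orientation while reading the demo.http happy path above and the integration tests in the next commit, here is how an external reviewer system could drive the whole HITL flow over HTTP. This is an illustrative client sketch only: it assumes the Functions host is listening on localhost:7071, uses the requests library, and the run payload keys are placeholders whose authoritative shape is the sample's ContentSubmission dataclass.

import time

import requests

BASE_URL = "http://localhost:7071"

# Placeholder payload; the real field names come from the sample's ContentSubmission dataclass.
payload = {"content": "Draft blog post about our new release.", "author": "demo-user"}

# 1. Start the workflow. The host replies 202 with the orchestration instance id.
start = requests.post(f"{BASE_URL}/api/workflow/run", json=payload, timeout=30)
start.raise_for_status()
instance_id = start.json()["instanceId"]

# 2. Poll the status endpoint until the workflow pauses on a human-input request.
request_id = None
deadline = time.time() + 120
while request_id is None and time.time() < deadline:
    status = requests.get(f"{BASE_URL}/api/workflow/status/{instance_id}", timeout=30).json()
    pending = status.get("pendingHumanInputRequests", [])
    if pending:
        request_id = pending[0]["requestId"]
    else:
        time.sleep(2)

# 3. Send the reviewer's decision back to the paused workflow.
requests.post(
    f"{BASE_URL}/api/workflow/respond/{instance_id}/{request_id}",
    json={"approved": True, "reviewer_notes": "Content is appropriate."},
    timeout=30,
)

# 4. Wait for the orchestration to finish and print the outcome.
while True:
    status = requests.get(f"{BASE_URL}/api/workflow/status/{instance_id}", timeout=30).json()
    if status.get("runtimeStatus") in ("Completed", "Failed", "Terminated"):
        print(status.get("runtimeStatus"), status.get("output"))
        break
    time.sleep(2)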
From a44af9c83397833558c9d11546de8c645a744ae9 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Tue, 10 Feb 2026 11:01:50 -0600 Subject: [PATCH 25/29] fix test imports --- .../test_09_workflow_shared_state.py | 26 ++++++++------ .../test_10_workflow_no_shared_state.py | 28 ++++++++------- .../test_11_workflow_parallel.py | 34 +++++++++++-------- .../test_12_workflow_hitl.py | 33 +++++++++--------- 4 files changed, 66 insertions(+), 55 deletions(-) diff --git a/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py b/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py index c9e6b16644..26bb20e5b4 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py @@ -20,13 +20,11 @@ """ import pytest -from conftest import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ pytest.mark.sample("09_workflow_shared_state"), pytest.mark.usefixtures("function_app_for_test"), - skip_if_azure_functions_integration_tests_disabled, ] @@ -34,23 +32,29 @@ class TestWorkflowSharedState: """Tests for 09_workflow_shared_state sample.""" - def test_workflow_with_spam_email(self, base_url: str) -> None: + @pytest.fixture(autouse=True) + def _setup(self, base_url: str, sample_helper) -> None: + """Provide the helper and base URL for each test.""" + self.base_url = base_url + self.helper = sample_helper + + def test_workflow_with_spam_email(self) -> None: """Test workflow with spam email content - should be detected and handled as spam.""" spam_content = "URGENT! You have won $1,000,000! Click here to claim your prize now before it expires!" # Start orchestration with spam email - response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", spam_content) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", spam_content) assert response.status_code == 202 data = response.json() assert "instanceId" in data assert "statusQueryGetUri" in data # Wait for completion - status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) + status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) assert status["runtimeStatus"] == "Completed" assert "output" in status - def test_workflow_with_legitimate_email(self, base_url: str) -> None: + def test_workflow_with_legitimate_email(self) -> None: """Test workflow with legitimate email content - should generate response.""" legitimate_content = ( "Hi team, just a reminder about the sprint planning meeting tomorrow at 10 AM. 
" @@ -58,18 +62,18 @@ def test_workflow_with_legitimate_email(self, base_url: str) -> None: ) # Start orchestration with legitimate email - response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", legitimate_content) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", legitimate_content) assert response.status_code == 202 data = response.json() assert "instanceId" in data assert "statusQueryGetUri" in data # Wait for completion - status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) + status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) assert status["runtimeStatus"] == "Completed" assert "output" in status - def test_workflow_with_phishing_email(self, base_url: str) -> None: + def test_workflow_with_phishing_email(self) -> None: """Test workflow with phishing email - should be detected as spam.""" phishing_content = ( "Dear Customer, Your account has been compromised! " @@ -77,13 +81,13 @@ def test_workflow_with_phishing_email(self, base_url: str) -> None: ) # Start orchestration with phishing email - response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", phishing_content) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", phishing_content) assert response.status_code == 202 data = response.json() assert "instanceId" in data # Wait for completion - status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) + status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) assert status["runtimeStatus"] == "Completed" diff --git a/python/packages/azurefunctions/tests/integration_tests/test_10_workflow_no_shared_state.py b/python/packages/azurefunctions/tests/integration_tests/test_10_workflow_no_shared_state.py index 0ceb4c72eb..88b610ac70 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_10_workflow_no_shared_state.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_10_workflow_no_shared_state.py @@ -20,13 +20,11 @@ """ import pytest -from conftest import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ pytest.mark.sample("10_workflow_no_shared_state"), pytest.mark.usefixtures("function_app_for_test"), - skip_if_azure_functions_integration_tests_disabled, ] @@ -34,7 +32,13 @@ class TestWorkflowNoSharedState: """Tests for 10_workflow_no_shared_state sample.""" - def test_workflow_with_spam_email(self, base_url: str) -> None: + @pytest.fixture(autouse=True) + def _setup(self, base_url: str, sample_helper) -> None: + """Provide the helper and base URL for each test.""" + self.base_url = base_url + self.helper = sample_helper + + def test_workflow_with_spam_email(self) -> None: """Test workflow with spam email - should detect and handle as spam.""" payload = { "email_id": "email-test-001", @@ -45,18 +49,18 @@ def test_workflow_with_spam_email(self, base_url: str) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() assert "instanceId" in data assert "statusQueryGetUri" in data # Wait for completion - status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) + status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) assert 
status["runtimeStatus"] == "Completed" assert "output" in status - def test_workflow_with_legitimate_email(self, base_url: str) -> None: + def test_workflow_with_legitimate_email(self) -> None: """Test workflow with legitimate email - should draft a response.""" payload = { "email_id": "email-test-002", @@ -67,18 +71,18 @@ def test_workflow_with_legitimate_email(self, base_url: str) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() assert "instanceId" in data assert "statusQueryGetUri" in data # Wait for completion - status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) + status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"]) assert status["runtimeStatus"] == "Completed" assert "output" in status - def test_workflow_status_endpoint(self, base_url: str) -> None: + def test_workflow_status_endpoint(self) -> None: """Test that the status endpoint works correctly.""" payload = { "email_id": "email-test-003", @@ -86,13 +90,13 @@ def test_workflow_status_endpoint(self, base_url: str) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() instance_id = data["instanceId"] # Check status using the workflow status endpoint - status_response = SampleTestHelper.get(f"{base_url}/api/workflow/status/{instance_id}") + status_response = self.helper.get(f"{self.base_url}/api/workflow/status/{instance_id}") assert status_response.status_code == 200 status = status_response.json() assert "instanceId" in status @@ -100,7 +104,7 @@ def test_workflow_status_endpoint(self, base_url: str) -> None: assert "runtimeStatus" in status # Wait for completion to clean up - SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"]) + self.helper.wait_for_orchestration(data["statusQueryGetUri"]) if __name__ == "__main__": diff --git a/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py b/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py index 7430cca96f..81f7466e5d 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py @@ -22,13 +22,11 @@ """ import pytest -from conftest import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ pytest.mark.sample("11_workflow_parallel"), pytest.mark.usefixtures("function_app_for_test"), - skip_if_azure_functions_integration_tests_disabled, ] @@ -36,7 +34,13 @@ class TestWorkflowParallel: """Tests for 11_workflow_parallel sample.""" - def test_parallel_workflow_document_analysis(self, base_url: str) -> None: + @pytest.fixture(autouse=True) + def _setup(self, base_url: str, sample_helper) -> None: + """Provide the helper and base URL for each test.""" + self.base_url = base_url + self.helper = sample_helper + + def test_parallel_workflow_document_analysis(self) -> None: """Test parallel workflow with a standard document.""" payload = { "document_id": "doc-test-001", @@ -50,21 +54,21 @@ def test_parallel_workflow_document_analysis(self, base_url: 
str) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() assert "instanceId" in data assert "statusQueryGetUri" in data # Wait for completion - parallel workflows may take longer - status = SampleTestHelper.wait_for_orchestration_with_output( + status = self.helper.wait_for_orchestration_with_output( data["statusQueryGetUri"], max_wait=300, # 5 minutes for parallel execution ) assert status["runtimeStatus"] == "Completed" assert "output" in status - def test_parallel_workflow_short_document(self, base_url: str) -> None: + def test_parallel_workflow_short_document(self) -> None: """Test parallel workflow with a short document.""" payload = { "document_id": "doc-test-002", @@ -72,18 +76,18 @@ def test_parallel_workflow_short_document(self, base_url: str) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() assert "instanceId" in data assert "statusQueryGetUri" in data # Wait for completion - status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"], max_wait=300) + status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"], max_wait=300) assert status["runtimeStatus"] == "Completed" assert "output" in status - def test_parallel_workflow_technical_document(self, base_url: str) -> None: + def test_parallel_workflow_technical_document(self) -> None: """Test parallel workflow with a technical document.""" payload = { "document_id": "doc-test-003", @@ -97,16 +101,16 @@ def test_parallel_workflow_technical_document(self, base_url: str) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() assert "instanceId" in data # Wait for completion - status = SampleTestHelper.wait_for_orchestration_with_output(data["statusQueryGetUri"], max_wait=300) + status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"], max_wait=300) assert status["runtimeStatus"] == "Completed" - def test_workflow_status_endpoint(self, base_url: str) -> None: + def test_workflow_status_endpoint(self) -> None: """Test that the workflow status endpoint works correctly.""" payload = { "document_id": "doc-test-004", @@ -114,20 +118,20 @@ def test_workflow_status_endpoint(self, base_url: str) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() instance_id = data["instanceId"] # Check status - status_response = SampleTestHelper.get(f"{base_url}/api/workflow/status/{instance_id}") + status_response = self.helper.get(f"{self.base_url}/api/workflow/status/{instance_id}") assert status_response.status_code == 200 status = status_response.json() assert "instanceId" in status assert status["instanceId"] == instance_id # Wait for completion - SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"], max_wait=300) + self.helper.wait_for_orchestration(data["statusQueryGetUri"], 
max_wait=300) if __name__ == "__main__": diff --git a/python/packages/azurefunctions/tests/integration_tests/test_12_workflow_hitl.py b/python/packages/azurefunctions/tests/integration_tests/test_12_workflow_hitl.py index 713e28d63e..8f3c87e339 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_12_workflow_hitl.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_12_workflow_hitl.py @@ -22,13 +22,11 @@ import time import pytest -from conftest import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ pytest.mark.sample("12_workflow_hitl"), pytest.mark.usefixtures("function_app_for_test"), - skip_if_azure_functions_integration_tests_disabled, ] @@ -37,15 +35,16 @@ class TestWorkflowHITL: """Tests for 12_workflow_hitl sample.""" @pytest.fixture(autouse=True) - def _set_base_url(self, base_url: str) -> None: - """Store the base URL for tests.""" + def _setup(self, base_url: str, sample_helper) -> None: + """Provide the helper and base URL for each test.""" self.base_url = base_url + self.helper = sample_helper def _wait_for_hitl_request(self, instance_id: str, timeout: int = 40) -> dict: """Polls for a pending HITL request.""" start_time = time.time() while time.time() - start_time < timeout: - status_response = SampleTestHelper.get(f"{self.base_url}/api/workflow/status/{instance_id}") + status_response = self.helper.get(f"{self.base_url}/api/workflow/status/{instance_id}") if status_response.status_code == 200: status = status_response.json() pending_requests = status.get("pendingHumanInputRequests", []) @@ -68,7 +67,7 @@ def test_hitl_workflow_approval(self) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{self.base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() assert "instanceId" in data @@ -87,14 +86,14 @@ def test_hitl_workflow_approval(self) -> None: request_id = pending_requests[0]["requestId"] # Send approval - approval_response = SampleTestHelper.post_json( + approval_response = self.helper.post_json( f"{self.base_url}/api/workflow/respond/{instance_id}/{request_id}", {"approved": True, "reviewer_notes": "Content is appropriate and well-written."}, ) assert approval_response.status_code == 200 # Wait for orchestration to complete - final_status = SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"]) + final_status = self.helper.wait_for_orchestration(data["statusQueryGetUri"]) assert final_status["runtimeStatus"] == "Completed" assert "output" in final_status @@ -111,7 +110,7 @@ def test_hitl_workflow_rejection(self) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{self.base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() instance_id = data["instanceId"] @@ -125,14 +124,14 @@ def test_hitl_workflow_rejection(self) -> None: request_id = pending_requests[0]["requestId"] # Send rejection - rejection_response = SampleTestHelper.post_json( + rejection_response = self.helper.post_json( f"{self.base_url}/api/workflow/respond/{instance_id}/{request_id}", {"approved": False, "reviewer_notes": "Content appears to be spam/scam material."}, ) assert rejection_response.status_code == 200 # Wait for orchestration to complete - final_status = 
SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"]) + final_status = self.helper.wait_for_orchestration(data["statusQueryGetUri"]) assert final_status["runtimeStatus"] == "Completed" assert "output" in final_status # The output should indicate rejection @@ -149,7 +148,7 @@ def test_hitl_workflow_status_endpoint(self) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{self.base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() instance_id = data["instanceId"] @@ -167,13 +166,13 @@ def test_hitl_workflow_status_endpoint(self) -> None: pending_requests = status.get("pendingHumanInputRequests", []) if pending_requests: request_id = pending_requests[0]["requestId"] - SampleTestHelper.post_json( + self.helper.post_json( f"{self.base_url}/api/workflow/respond/{instance_id}/{request_id}", {"approved": True, "reviewer_notes": ""}, ) # Wait for completion - SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"]) + self.helper.wait_for_orchestration(data["statusQueryGetUri"]) def test_hitl_workflow_with_neutral_content(self) -> None: """Test HITL workflow with neutral content that should get medium risk.""" @@ -188,7 +187,7 @@ def test_hitl_workflow_with_neutral_content(self) -> None: } # Start orchestration - response = SampleTestHelper.post_json(f"{self.base_url}/api/workflow/run", payload) + response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() instance_id = data["instanceId"] @@ -201,13 +200,13 @@ def test_hitl_workflow_with_neutral_content(self) -> None: request_id = pending_requests[0]["requestId"] # Approve - SampleTestHelper.post_json( + self.helper.post_json( f"{self.base_url}/api/workflow/respond/{instance_id}/{request_id}", {"approved": True, "reviewer_notes": "Approved after review."}, ) # Wait for completion - final_status = SampleTestHelper.wait_for_orchestration(data["statusQueryGetUri"]) + final_status = self.helper.wait_for_orchestration(data["statusQueryGetUri"]) assert final_status["runtimeStatus"] == "Completed" From 6894bfb45c3c113685fab8a20d69a7d8b8910416 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Thu, 12 Feb 2026 14:00:34 -0600 Subject: [PATCH 26/29] feat: add durable workflow samples and adapt to main branch changes - Add workflow samples 09-12 to 04-hosting/azure_functions/ - Adapt to ChatMessage -> Message rename from main - Adapt to pickle-based checkpoint encoding from main - Simplify _serialization.py to delegate to core encode/decode - Fix Message -> WorkflowMessage disambiguation in _context.py - Remove non-existent _checkpoint_summary import --- .../_context.py | 8 +- .../_serialization.py | 110 +++--------------- .../_workflow.py | 8 +- .../azurefunctions/tests/test_utils.py | 40 +++---- .../azurefunctions/tests/test_workflow.py | 18 +-- .../agent_framework/_workflows/__init__.py | 1 - .../09_workflow_shared_state/.gitignore | 0 .../09_workflow_shared_state/README.md | 0 .../09_workflow_shared_state/demo.http | 0 .../09_workflow_shared_state/function_app.py | 6 +- .../09_workflow_shared_state/host.json | 0 .../local.settings.json.sample | 0 .../09_workflow_shared_state/requirements.txt | 0 .../10_workflow_no_shared_state/.env.sample | 0 .../10_workflow_no_shared_state/.gitignore | 0 .../10_workflow_no_shared_state/README.md | 0 .../10_workflow_no_shared_state/demo.http | 0 .../function_app.py | 0 
.../10_workflow_no_shared_state/host.json | 0 .../local.settings.json.sample | 0 .../requirements.txt | 0 .../11_workflow_parallel/.env.template | 0 .../11_workflow_parallel/.gitignore | 0 .../11_workflow_parallel/README.md | 0 .../11_workflow_parallel/demo.http | 0 .../11_workflow_parallel/function_app.py | 0 .../11_workflow_parallel/host.json | 0 .../local.settings.json.sample | 0 .../11_workflow_parallel/requirements.txt | 0 .../12_workflow_hitl/.gitignore | 0 .../12_workflow_hitl/README.md | 0 .../12_workflow_hitl/demo.http | 0 .../12_workflow_hitl/function_app.py | 4 +- .../12_workflow_hitl/host.json | 0 .../local.settings.json.sample | 0 .../12_workflow_hitl/requirements.txt | 0 36 files changed, 58 insertions(+), 137 deletions(-) rename python/samples/{getting_started => 04-hosting}/azure_functions/09_workflow_shared_state/.gitignore (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/09_workflow_shared_state/README.md (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/09_workflow_shared_state/demo.http (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/09_workflow_shared_state/function_app.py (97%) rename python/samples/{getting_started => 04-hosting}/azure_functions/09_workflow_shared_state/host.json (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/09_workflow_shared_state/local.settings.json.sample (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/09_workflow_shared_state/requirements.txt (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/10_workflow_no_shared_state/.env.sample (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/10_workflow_no_shared_state/.gitignore (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/10_workflow_no_shared_state/README.md (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/10_workflow_no_shared_state/demo.http (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/10_workflow_no_shared_state/function_app.py (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/10_workflow_no_shared_state/host.json (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/10_workflow_no_shared_state/local.settings.json.sample (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/10_workflow_no_shared_state/requirements.txt (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/11_workflow_parallel/.env.template (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/11_workflow_parallel/.gitignore (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/11_workflow_parallel/README.md (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/11_workflow_parallel/demo.http (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/11_workflow_parallel/function_app.py (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/11_workflow_parallel/host.json (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/11_workflow_parallel/local.settings.json.sample (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/11_workflow_parallel/requirements.txt (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/12_workflow_hitl/.gitignore (100%) rename 
python/samples/{getting_started => 04-hosting}/azure_functions/12_workflow_hitl/README.md (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/12_workflow_hitl/demo.http (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/12_workflow_hitl/function_app.py (99%) rename python/samples/{getting_started => 04-hosting}/azure_functions/12_workflow_hitl/host.json (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/12_workflow_hitl/local.settings.json.sample (100%) rename python/samples/{getting_started => 04-hosting}/azure_functions/12_workflow_hitl/requirements.txt (100%) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py index 2c13e3d2c5..d3642f9ce1 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py @@ -14,11 +14,11 @@ from agent_framework import ( CheckpointStorage, - Message, RunnerContext, State, WorkflowCheckpoint, WorkflowEvent, + WorkflowMessage, ) @@ -36,7 +36,7 @@ class CapturingRunnerContext(RunnerContext): def __init__(self) -> None: """Initialize the capturing runner context.""" - self._messages: dict[str, list[Message]] = {} + self._messages: dict[str, list[WorkflowMessage]] = {} self._event_queue: asyncio.Queue[WorkflowEvent] = asyncio.Queue() self._pending_request_info_events: dict[str, WorkflowEvent[Any]] = {} self._workflow_id: str | None = None @@ -44,12 +44,12 @@ def __init__(self) -> None: # region Messaging - async def send_message(self, message: Message) -> None: + async def send_message(self, message: WorkflowMessage) -> None: """Capture a message sent by an executor.""" self._messages.setdefault(message.source_id, []) self._messages[message.source_id].append(message) - async def drain_messages(self) -> dict[str, list[Message]]: + async def drain_messages(self) -> dict[str, list[WorkflowMessage]]: """Drain and return all captured messages.""" messages = copy(self._messages) self._messages.clear() diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index 86bcddd357..f479ff2250 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -3,40 +3,30 @@ """Serialization utilities for workflow execution. This module provides thin wrappers around the core checkpoint encoding system -(encode_checkpoint_value / decode_checkpoint_value) from agent_framework._workflows, -adding Pydantic model support. +(encode_checkpoint_value / decode_checkpoint_value) from agent_framework._workflows. -The core checkpoint encoding handles type-safe roundtripping of: -- Objects with to_dict/from_dict (ChatMessage, AgentResponse, etc.) -- Dataclasses (AgentExecutorRequest/Response, custom dataclasses) -- Objects with to_json/from_json -- Primitives, lists, dicts +The core checkpoint encoding uses pickle + base64 for type-safe roundtripping of +arbitrary Python objects (dataclasses, Pydantic models, Message, etc.) while +keeping JSON-native types (str, int, float, bool, None) as-is. 
This module adds: -- serialize_value / deserialize_value: wrappers that also handle Pydantic BaseModel instances +- serialize_value / deserialize_value: convenience aliases for encode/decode - reconstruct_to_type: for HITL responses where external data (without type markers) needs to be reconstructed to a known type +- _resolve_type: resolves 'module:class' type keys to Python types """ from __future__ import annotations import importlib import logging -from dataclasses import fields as dc_fields from dataclasses import is_dataclass from typing import Any from agent_framework._workflows import decode_checkpoint_value, encode_checkpoint_value -from agent_framework._workflows._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER -from pydantic import BaseModel logger = logging.getLogger(__name__) -# Marker for Pydantic models serialized by this module. -# Core checkpoint encoding only supports to_dict/from_dict protocol; Pydantic v2 -# uses model_dump/model_validate, so we handle it here with a compatible marker format. -PYDANTIC_MARKER = "__af_pydantic__" - def _resolve_type(type_key: str) -> type | None: """Resolve a 'module:class' type key to its Python type. @@ -64,97 +54,30 @@ def _resolve_type(type_key: str) -> type | None: def serialize_value(value: Any) -> Any: """Serialize a value for JSON-compatible cross-activity communication. - Extends core checkpoint encoding with Pydantic BaseModel support. - The output is JSON-serializable and can be deserialized with deserialize_value(). - - Dataclasses are handled here (rather than delegating to encode_checkpoint_value) - because their fields may contain nested Pydantic models that core encoding - does not recognise. + Delegates to core checkpoint encoding which uses pickle + base64 for + non-JSON-native types (dataclasses, Pydantic models, Message, etc.). Args: - value: Any Python value (primitive, dataclass, Pydantic model, ChatMessage, etc.) + value: Any Python value (primitive, dataclass, Pydantic model, Message, etc.) Returns: A JSON-serializable representation with embedded type metadata for reconstruction. """ - if isinstance(value, BaseModel): - cls = type(value) - return { - PYDANTIC_MARKER: f"{cls.__module__}:{cls.__name__}", - "value": encode_checkpoint_value(value.model_dump()), - } - - # Handle dataclasses ourselves so that nested Pydantic models get the - # PYDANTIC_MARKER treatment instead of being str()'d by core encoding. - if is_dataclass(value) and not isinstance(value, type): - dc_cls = type(value) - return { - DATACLASS_MARKER: f"{dc_cls.__module__}:{dc_cls.__name__}", - **{field.name: serialize_value(getattr(value, field.name)) for field in dc_fields(value)}, - } - - # Handle lists and dicts recursively to catch nested Pydantic models - if isinstance(value, list): - return [serialize_value(item) for item in value] - if isinstance(value, dict): - return {k: serialize_value(v) for k, v in value.items()} - return encode_checkpoint_value(value) def deserialize_value(value: Any) -> Any: """Deserialize a value previously serialized with serialize_value(). - Handles core checkpoint markers (__af_model__, __af_dataclass__) and - Pydantic markers (__af_pydantic__) to reconstruct the original typed objects. - - Dataclasses are reconstructed here (rather than delegating to - decode_checkpoint_value) so that fields containing PYDANTIC_MARKER dicts - are properly deserialized. + Delegates to core checkpoint decoding which unpickles base64-encoded values + and verifies type integrity. 
Args: - value: The serialized data (dict with type markers, list, or primitive) + value: The serialized data (dict with pickle markers, list, or primitive) Returns: Reconstructed typed object if type metadata found, otherwise original value. """ - if isinstance(value, dict): - # Handle Pydantic marker - if PYDANTIC_MARKER in value and "value" in value: - type_key: str = value[PYDANTIC_MARKER] - payload = decode_checkpoint_value(value["value"]) - cls = _resolve_type(type_key) - if cls is not None and hasattr(cls, "model_validate"): - try: - return cls.model_validate(payload) - except Exception: - logger.debug("Could not reconstruct Pydantic model %s", type_key) - return payload - - # Handle dataclass marker — deserialize fields ourselves so that nested - # PYDANTIC_MARKER dicts are properly handled. - if DATACLASS_MARKER in value: - type_key = value[DATACLASS_MARKER] - cls = _resolve_type(type_key) - if cls is not None and is_dataclass(cls): - try: - field_data = {k: deserialize_value(v) for k, v in value.items() if k != DATACLASS_MARKER} - return cls(**field_data) - except Exception: - logger.debug("Could not reconstruct dataclass %s, falling back to core decode", type_key) - return decode_checkpoint_value(value) - - # Handle model marker (to_dict/from_dict objects like ChatMessage) — core - # handles these fully since the object's own serialisation manages nesting. - if MODEL_MARKER in value: - return decode_checkpoint_value(value) - - # Recurse into plain dicts to catch nested markers - return {k: deserialize_value(v) for k, v in value.items()} - - if isinstance(value, list): - return [deserialize_value(item) for item in value] - return decode_checkpoint_value(value) @@ -194,11 +117,10 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any: if not isinstance(value, dict): return value - # Try marker-based decoding if data has type markers - if MODEL_MARKER in value or DATACLASS_MARKER in value or PYDANTIC_MARKER in value: - decoded = deserialize_value(value) - if not isinstance(decoded, dict): - return decoded + # Try decoding if data has pickle markers (from checkpoint encoding) + decoded = deserialize_value(value) + if not isinstance(decoded, dict): + return decoded # Try Pydantic model validation (for unmarked dicts, e.g., external HITL data) if hasattr(target_type, "model_validate"): diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index fc11ad807b..9a5ab2a3d0 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -33,7 +33,7 @@ AgentExecutorRequest, AgentExecutorResponse, AgentResponse, - ChatMessage, + Message, Workflow, ) from agent_framework._workflows._edge import ( @@ -243,18 +243,18 @@ def build_agent_executor_response( if structured_response: final_text = json.dumps(structured_response) - assistant_message = ChatMessage(role="assistant", text=final_text) + assistant_message = Message(role="assistant", text=final_text) agent_response = AgentResponse( messages=[assistant_message], ) # Build conversation history - full_conversation: list[ChatMessage] = [] + full_conversation: list[Message] = [] if isinstance(previous_message, AgentExecutorResponse) and previous_message.full_conversation: full_conversation.extend(previous_message.full_conversation) elif isinstance(previous_message, str): - full_conversation.append(ChatMessage(role="user", 
text=previous_message)) + full_conversation.append(Message(role="user", text=previous_message)) full_conversation.append(assistant_message) diff --git a/python/packages/azurefunctions/tests/test_utils.py b/python/packages/azurefunctions/tests/test_utils.py index 4ecae82063..be9af18e04 100644 --- a/python/packages/azurefunctions/tests/test_utils.py +++ b/python/packages/azurefunctions/tests/test_utils.py @@ -10,9 +10,9 @@ AgentExecutorRequest, AgentExecutorResponse, AgentResponse, - ChatMessage, Message, WorkflowEvent, + WorkflowMessage, ) from pydantic import BaseModel @@ -59,7 +59,7 @@ def context(self) -> CapturingRunnerContext: @pytest.mark.asyncio async def test_send_message_captures_message(self, context: CapturingRunnerContext) -> None: """Test that send_message captures messages correctly.""" - message = Message(data="test data", target_id="target_1", source_id="source_1") + message = WorkflowMessage(data="test data", target_id="target_1", source_id="source_1") await context.send_message(message) @@ -71,9 +71,9 @@ async def test_send_message_captures_message(self, context: CapturingRunnerConte @pytest.mark.asyncio async def test_send_multiple_messages_groups_by_source(self, context: CapturingRunnerContext) -> None: """Test that messages are grouped by source_id.""" - msg1 = Message(data="msg1", target_id="target", source_id="source_a") - msg2 = Message(data="msg2", target_id="target", source_id="source_a") - msg3 = Message(data="msg3", target_id="target", source_id="source_b") + msg1 = WorkflowMessage(data="msg1", target_id="target", source_id="source_a") + msg2 = WorkflowMessage(data="msg2", target_id="target", source_id="source_a") + msg3 = WorkflowMessage(data="msg3", target_id="target", source_id="source_b") await context.send_message(msg1) await context.send_message(msg2) @@ -86,7 +86,7 @@ async def test_send_multiple_messages_groups_by_source(self, context: CapturingR @pytest.mark.asyncio async def test_drain_messages_clears_messages(self, context: CapturingRunnerContext) -> None: """Test that drain_messages clears the message store.""" - message = Message(data="test", target_id="t", source_id="s") + message = WorkflowMessage(data="test", target_id="t", source_id="s") await context.send_message(message) await context.drain_messages() # First drain @@ -99,7 +99,7 @@ async def test_has_messages_returns_correct_status(self, context: CapturingRunne """Test has_messages returns correct boolean.""" assert await context.has_messages() is False - await context.send_message(Message(data="test", target_id="t", source_id="s")) + await context.send_message(WorkflowMessage(data="test", target_id="t", source_id="s")) assert await context.has_messages() is True @@ -169,7 +169,7 @@ def test_set_workflow_id(self, context: CapturingRunnerContext) -> None: @pytest.mark.asyncio async def test_reset_for_new_run_clears_state(self, context: CapturingRunnerContext) -> None: """Test that reset_for_new_run clears all state.""" - await context.send_message(Message(data="test", target_id="t", source_id="s")) + await context.send_message(WorkflowMessage(data="test", target_id="t", source_id="s")) await context.add_event(WorkflowEvent.output(executor_id="e", data="event")) context.set_streaming(True) @@ -204,18 +204,18 @@ class TestSerializationRoundtrip: """Test that serialization roundtrips correctly for types used in Azure Functions workflows.""" def test_roundtrip_chat_message(self) -> None: - """Test ChatMessage survives encode → decode roundtrip.""" - original = ChatMessage(role="user", text="Hello") 
+ """Test Message survives encode → decode roundtrip.""" + original = Message(role="user", text="Hello") encoded = serialize_value(original) decoded = deserialize_value(encoded) - assert isinstance(decoded, ChatMessage) + assert isinstance(decoded, Message) assert decoded.role == "user" def test_roundtrip_agent_executor_request(self) -> None: - """Test AgentExecutorRequest with nested ChatMessages roundtrips.""" + """Test AgentExecutorRequest with nested Messages roundtrips.""" original = AgentExecutorRequest( - messages=[ChatMessage(role="user", text="Hi")], + messages=[Message(role="user", text="Hi")], should_respond=True, ) encoded = serialize_value(original) @@ -223,14 +223,14 @@ def test_roundtrip_agent_executor_request(self) -> None: assert isinstance(decoded, AgentExecutorRequest) assert len(decoded.messages) == 1 - assert isinstance(decoded.messages[0], ChatMessage) + assert isinstance(decoded.messages[0], Message) assert decoded.should_respond is True def test_roundtrip_agent_executor_response(self) -> None: """Test AgentExecutorResponse with nested AgentResponse roundtrips.""" original = AgentExecutorResponse( executor_id="test_exec", - agent_response=AgentResponse(messages=[ChatMessage(role="assistant", text="Reply")]), + agent_response=AgentResponse(messages=[Message(role="assistant", text="Reply")]), ) encoded = serialize_value(original) decoded = deserialize_value(encoded) @@ -270,24 +270,24 @@ def test_roundtrip_primitives(self) -> None: def test_roundtrip_list_of_objects(self) -> None: """Test list of typed objects roundtrips.""" original = [ - ChatMessage(role="user", text="Q"), - ChatMessage(role="assistant", text="A"), + Message(role="user", text="Q"), + Message(role="assistant", text="A"), ] encoded = serialize_value(original) decoded = deserialize_value(encoded) assert isinstance(decoded, list) assert len(decoded) == 2 - assert all(isinstance(m, ChatMessage) for m in decoded) + assert all(isinstance(m, Message) for m in decoded) def test_roundtrip_dict_of_objects(self) -> None: """Test dict with typed values roundtrips (used for shared state).""" - original = {"count": 42, "msg": ChatMessage(role="user", text="Hi")} + original = {"count": 42, "msg": Message(role="user", text="Hi")} encoded = serialize_value(original) decoded = deserialize_value(encoded) assert decoded["count"] == 42 - assert isinstance(decoded["msg"], ChatMessage) + assert isinstance(decoded["msg"], Message) def test_roundtrip_dataclass_with_nested_pydantic(self) -> None: """Test dataclass containing a Pydantic model field roundtrips correctly. 
diff --git a/python/packages/azurefunctions/tests/test_workflow.py b/python/packages/azurefunctions/tests/test_workflow.py index b8e246d85c..4c26c980b2 100644 --- a/python/packages/azurefunctions/tests/test_workflow.py +++ b/python/packages/azurefunctions/tests/test_workflow.py @@ -10,7 +10,7 @@ AgentExecutorRequest, AgentExecutorResponse, AgentResponse, - ChatMessage, + Message, ) from agent_framework._workflows._edge import ( FanInEdgeGroup, @@ -177,10 +177,10 @@ def test_conversation_extends_previous_agent_executor_response(self) -> None: # Create a previous response with conversation history previous = AgentExecutorResponse( executor_id="prev", - agent_response=AgentResponse(messages=[ChatMessage(role="assistant", text="Previous")]), + agent_response=AgentResponse(messages=[Message(role="assistant", text="Previous")]), full_conversation=[ - ChatMessage(role="user", text="First"), - ChatMessage(role="assistant", text="Previous"), + Message(role="user", text="First"), + Message(role="assistant", text="Previous"), ], ) @@ -211,7 +211,7 @@ def test_extract_from_agent_executor_response_with_text(self) -> None: """Test extracting from AgentExecutorResponse with text.""" response = AgentExecutorResponse( executor_id="exec", - agent_response=AgentResponse(messages=[ChatMessage(role="assistant", text="Response text")]), + agent_response=AgentResponse(messages=[Message(role="assistant", text="Response text")]), ) result = _extract_message_content(response) @@ -224,8 +224,8 @@ def test_extract_from_agent_executor_response_with_messages(self) -> None: executor_id="exec", agent_response=AgentResponse( messages=[ - ChatMessage(role="user", text="First"), - ChatMessage(role="assistant", text="Last message"), + Message(role="user", text="First"), + Message(role="assistant", text="Last message"), ] ), ) @@ -239,8 +239,8 @@ def test_extract_from_agent_executor_request(self) -> None: """Test extracting from AgentExecutorRequest.""" request = AgentExecutorRequest( messages=[ - ChatMessage(role="user", text="First"), - ChatMessage(role="user", text="Last request"), + Message(role="user", text="First"), + Message(role="user", text="Last request"), ] ) diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index 42e70a8c96..e7f2c7d966 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -30,7 +30,6 @@ decode_checkpoint_value, encode_checkpoint_value, ) -from ._checkpoint_summary import WorkflowCheckpointSummary, get_checkpoint_summary from ._const import ( DEFAULT_MAX_ITERATIONS, ) diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/.gitignore b/python/samples/04-hosting/azure_functions/09_workflow_shared_state/.gitignore similarity index 100% rename from python/samples/getting_started/azure_functions/09_workflow_shared_state/.gitignore rename to python/samples/04-hosting/azure_functions/09_workflow_shared_state/.gitignore diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/README.md b/python/samples/04-hosting/azure_functions/09_workflow_shared_state/README.md similarity index 100% rename from python/samples/getting_started/azure_functions/09_workflow_shared_state/README.md rename to python/samples/04-hosting/azure_functions/09_workflow_shared_state/README.md diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/demo.http 
b/python/samples/04-hosting/azure_functions/09_workflow_shared_state/demo.http similarity index 100% rename from python/samples/getting_started/azure_functions/09_workflow_shared_state/demo.http rename to python/samples/04-hosting/azure_functions/09_workflow_shared_state/demo.http diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py b/python/samples/04-hosting/azure_functions/09_workflow_shared_state/function_app.py similarity index 97% rename from python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py rename to python/samples/04-hosting/azure_functions/09_workflow_shared_state/function_app.py index 82673dd329..7b2317c58e 100644 --- a/python/samples/getting_started/azure_functions/09_workflow_shared_state/function_app.py +++ b/python/samples/04-hosting/azure_functions/09_workflow_shared_state/function_app.py @@ -27,7 +27,7 @@ from agent_framework import ( AgentExecutorRequest, AgentExecutorResponse, - ChatMessage, + Message, Workflow, WorkflowBuilder, WorkflowContext, @@ -110,7 +110,7 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest ctx.set_state(CURRENT_EMAIL_ID_KEY, new_email.email_id) await ctx.send_message( - AgentExecutorRequest(messages=[ChatMessage(role="user", text=new_email.email_content)], should_respond=True) + AgentExecutorRequest(messages=[Message(role="user", text=new_email.email_content)], should_respond=True) ) @@ -146,7 +146,7 @@ async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowCon # Load the original content by id from shared state and forward it to the assistant. email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}") await ctx.send_message( - AgentExecutorRequest(messages=[ChatMessage(role="user", text=email.email_content)], should_respond=True) + AgentExecutorRequest(messages=[Message(role="user", text=email.email_content)], should_respond=True) ) diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/host.json b/python/samples/04-hosting/azure_functions/09_workflow_shared_state/host.json similarity index 100% rename from python/samples/getting_started/azure_functions/09_workflow_shared_state/host.json rename to python/samples/04-hosting/azure_functions/09_workflow_shared_state/host.json diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/local.settings.json.sample b/python/samples/04-hosting/azure_functions/09_workflow_shared_state/local.settings.json.sample similarity index 100% rename from python/samples/getting_started/azure_functions/09_workflow_shared_state/local.settings.json.sample rename to python/samples/04-hosting/azure_functions/09_workflow_shared_state/local.settings.json.sample diff --git a/python/samples/getting_started/azure_functions/09_workflow_shared_state/requirements.txt b/python/samples/04-hosting/azure_functions/09_workflow_shared_state/requirements.txt similarity index 100% rename from python/samples/getting_started/azure_functions/09_workflow_shared_state/requirements.txt rename to python/samples/04-hosting/azure_functions/09_workflow_shared_state/requirements.txt diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.env.sample b/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/.env.sample similarity index 100% rename from python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.env.sample rename to 
python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/.env.sample diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.gitignore b/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/.gitignore similarity index 100% rename from python/samples/getting_started/azure_functions/10_workflow_no_shared_state/.gitignore rename to python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/.gitignore diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/README.md b/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/README.md similarity index 100% rename from python/samples/getting_started/azure_functions/10_workflow_no_shared_state/README.md rename to python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/README.md diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/demo.http b/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/demo.http similarity index 100% rename from python/samples/getting_started/azure_functions/10_workflow_no_shared_state/demo.http rename to python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/demo.http diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py b/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/function_app.py similarity index 100% rename from python/samples/getting_started/azure_functions/10_workflow_no_shared_state/function_app.py rename to python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/function_app.py diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/host.json b/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/host.json similarity index 100% rename from python/samples/getting_started/azure_functions/10_workflow_no_shared_state/host.json rename to python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/host.json diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/local.settings.json.sample b/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/local.settings.json.sample similarity index 100% rename from python/samples/getting_started/azure_functions/10_workflow_no_shared_state/local.settings.json.sample rename to python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/local.settings.json.sample diff --git a/python/samples/getting_started/azure_functions/10_workflow_no_shared_state/requirements.txt b/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/requirements.txt similarity index 100% rename from python/samples/getting_started/azure_functions/10_workflow_no_shared_state/requirements.txt rename to python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/requirements.txt diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/.env.template b/python/samples/04-hosting/azure_functions/11_workflow_parallel/.env.template similarity index 100% rename from python/samples/getting_started/azure_functions/11_workflow_parallel/.env.template rename to python/samples/04-hosting/azure_functions/11_workflow_parallel/.env.template diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/.gitignore b/python/samples/04-hosting/azure_functions/11_workflow_parallel/.gitignore similarity index 100% rename from 
python/samples/getting_started/azure_functions/11_workflow_parallel/.gitignore rename to python/samples/04-hosting/azure_functions/11_workflow_parallel/.gitignore diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/README.md b/python/samples/04-hosting/azure_functions/11_workflow_parallel/README.md similarity index 100% rename from python/samples/getting_started/azure_functions/11_workflow_parallel/README.md rename to python/samples/04-hosting/azure_functions/11_workflow_parallel/README.md diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/demo.http b/python/samples/04-hosting/azure_functions/11_workflow_parallel/demo.http similarity index 100% rename from python/samples/getting_started/azure_functions/11_workflow_parallel/demo.http rename to python/samples/04-hosting/azure_functions/11_workflow_parallel/demo.http diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py b/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py similarity index 100% rename from python/samples/getting_started/azure_functions/11_workflow_parallel/function_app.py rename to python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/host.json b/python/samples/04-hosting/azure_functions/11_workflow_parallel/host.json similarity index 100% rename from python/samples/getting_started/azure_functions/11_workflow_parallel/host.json rename to python/samples/04-hosting/azure_functions/11_workflow_parallel/host.json diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/local.settings.json.sample b/python/samples/04-hosting/azure_functions/11_workflow_parallel/local.settings.json.sample similarity index 100% rename from python/samples/getting_started/azure_functions/11_workflow_parallel/local.settings.json.sample rename to python/samples/04-hosting/azure_functions/11_workflow_parallel/local.settings.json.sample diff --git a/python/samples/getting_started/azure_functions/11_workflow_parallel/requirements.txt b/python/samples/04-hosting/azure_functions/11_workflow_parallel/requirements.txt similarity index 100% rename from python/samples/getting_started/azure_functions/11_workflow_parallel/requirements.txt rename to python/samples/04-hosting/azure_functions/11_workflow_parallel/requirements.txt diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/.gitignore b/python/samples/04-hosting/azure_functions/12_workflow_hitl/.gitignore similarity index 100% rename from python/samples/getting_started/azure_functions/12_workflow_hitl/.gitignore rename to python/samples/04-hosting/azure_functions/12_workflow_hitl/.gitignore diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/README.md b/python/samples/04-hosting/azure_functions/12_workflow_hitl/README.md similarity index 100% rename from python/samples/getting_started/azure_functions/12_workflow_hitl/README.md rename to python/samples/04-hosting/azure_functions/12_workflow_hitl/README.md diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/demo.http b/python/samples/04-hosting/azure_functions/12_workflow_hitl/demo.http similarity index 100% rename from python/samples/getting_started/azure_functions/12_workflow_hitl/demo.http rename to python/samples/04-hosting/azure_functions/12_workflow_hitl/demo.http diff --git 
a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py b/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py
similarity index 99%
rename from python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py
rename to python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py
index c18b55163d..f747d60919 100644
--- a/python/samples/getting_started/azure_functions/12_workflow_hitl/function_app.py
+++ b/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py
@@ -32,7 +32,7 @@
 from agent_framework import (
     AgentExecutorRequest,
     AgentExecutorResponse,
-    ChatMessage,
+    Message,
     Executor,
     Workflow,
     WorkflowBuilder,
@@ -365,7 +365,7 @@ async def route_input(
 
         await ctx.send_message(
             AgentExecutorRequest(
-                messages=[ChatMessage(role="user", text=message)],
+                messages=[Message(role="user", text=message)],
                 should_respond=True,
             )
         )
diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/host.json b/python/samples/04-hosting/azure_functions/12_workflow_hitl/host.json
similarity index 100%
rename from python/samples/getting_started/azure_functions/12_workflow_hitl/host.json
rename to python/samples/04-hosting/azure_functions/12_workflow_hitl/host.json
diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/local.settings.json.sample b/python/samples/04-hosting/azure_functions/12_workflow_hitl/local.settings.json.sample
similarity index 100%
rename from python/samples/getting_started/azure_functions/12_workflow_hitl/local.settings.json.sample
rename to python/samples/04-hosting/azure_functions/12_workflow_hitl/local.settings.json.sample
diff --git a/python/samples/getting_started/azure_functions/12_workflow_hitl/requirements.txt b/python/samples/04-hosting/azure_functions/12_workflow_hitl/requirements.txt
similarity index 100%
rename from python/samples/getting_started/azure_functions/12_workflow_hitl/requirements.txt
rename to python/samples/04-hosting/azure_functions/12_workflow_hitl/requirements.txt

From ad23d32444c77c9496cecd38d3e5680f0ad5a9f4 Mon Sep 17 00:00:00 2001
From: Ahmed Muhsin
Date: Thu, 12 Feb 2026 14:25:42 -0600
Subject: [PATCH 27/29] fix: update create_checkpoint signature to match superclass

---
 .../agent_framework_azurefunctions/_context.py     | 5 ++++-
 python/packages/azurefunctions/tests/test_utils.py | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py
index d3642f9ce1..14634f098d 100644
--- a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py
+++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py
@@ -103,7 +103,10 @@ def clear_runtime_checkpoint_storage(self) -> None:
 
     async def create_checkpoint(
         self,
-        shared_state: State,
+        workflow_name: str,
+        graph_signature_hash: str,
+        state: State,
+        previous_checkpoint_id: str | None,
         iteration_count: int,
         metadata: dict[str, Any] | None = None,
     ) -> str:
diff --git a/python/packages/azurefunctions/tests/test_utils.py b/python/packages/azurefunctions/tests/test_utils.py
index be9af18e04..98218c1082 100644
--- a/python/packages/azurefunctions/tests/test_utils.py
+++ b/python/packages/azurefunctions/tests/test_utils.py
@@ -185,7 +185,7 @@ async def test_create_checkpoint_raises_not_implemented(self, context: Capturing
         from agent_framework import State
 
         with pytest.raises(NotImplementedError):
-            await context.create_checkpoint(State(), 1)
+            await context.create_checkpoint("test_workflow", "abc123", State(), None, 1)
 
     @pytest.mark.asyncio
     async def test_load_checkpoint_raises_not_implemented(self, context: CapturingRunnerContext) -> None:

From 3d0e91fc45965f39603faeef8a9fc35e938bc614 Mon Sep 17 00:00:00 2001
From: Ahmed Muhsin
Date: Thu, 12 Feb 2026 14:26:02 -0600
Subject: [PATCH 28/29] fix: correct relative link in HITL sample README

---
 .../04-hosting/azure_functions/12_workflow_hitl/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/samples/04-hosting/azure_functions/12_workflow_hitl/README.md b/python/samples/04-hosting/azure_functions/12_workflow_hitl/README.md
index 2bb84f16dc..1b9f9eff87 100644
--- a/python/samples/04-hosting/azure_functions/12_workflow_hitl/README.md
+++ b/python/samples/04-hosting/azure_functions/12_workflow_hitl/README.md
@@ -138,4 +138,4 @@ Use the `demo.http` file with the VS Code REST Client extension:
 
 - [07_single_agent_orchestration_hitl](../07_single_agent_orchestration_hitl/) - HITL at orchestrator level (not using MAF pattern)
 - [09_workflow_shared_state](../09_workflow_shared_state/) - Workflow with shared state
-- [guessing_game_with_human_input](../../workflows/human-in-the-loop/guessing_game_with_human_input.py) - MAF HITL pattern (non-durable)
+- [guessing_game_with_human_input](../../../03-workflows/human-in-the-loop/guessing_game_with_human_input.py) - MAF HITL pattern (non-durable)

From ea704ce5c237194d7167089302026c94463a10ba Mon Sep 17 00:00:00 2001
From: Ahmed Muhsin
Date: Tue, 17 Feb 2026 13:57:53 -0600
Subject: [PATCH 29/29] fix: resolve import breakage after rebase (State, DurableAgentThread, get_logger)

---
 .../azurefunctions/agent_framework_azurefunctions/_app.py | 4 ++--
 .../agent_framework_azurefunctions/_context.py            | 2 +-
 .../agent_framework_azurefunctions/_workflow.py           | 6 +++---
 python/packages/azurefunctions/tests/test_utils.py        | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py
index dd95007a8c..631975351d 100644
--- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py
+++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py
@@ -20,7 +20,7 @@
 import azure.durable_functions as df
 import azure.functions as func
 
-from agent_framework import AgentExecutor, SupportsAgentRun, Workflow, WorkflowEvent, get_logger
+from agent_framework import AgentExecutor, SupportsAgentRun, Workflow, WorkflowEvent
 from agent_framework_durabletask import (
     DEFAULT_MAX_POLL_RETRIES,
     DEFAULT_POLL_INTERVAL_SECONDS,
@@ -272,7 +272,7 @@ def executor_activity(inputData: str) -> str:
             Note: We use str type annotations instead of dict to work around
             Azure Functions worker type validation issues with dict[str, Any].
             """
-            from agent_framework import State
+            from agent_framework._workflows import State
 
             data = json.loads(inputData)
             message_data = data["message"]
diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py
index 14634f098d..1346c8c498 100644
--- a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py
+++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py
@@ -15,11 +15,11 @@
 from agent_framework import (
     CheckpointStorage,
     RunnerContext,
-    State,
     WorkflowCheckpoint,
     WorkflowEvent,
     WorkflowMessage,
 )
+from agent_framework._workflows import State
 
 
 class CapturingRunnerContext(RunnerContext):
diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py
index 9a5ab2a3d0..a0e0f04185 100644
--- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py
+++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py
@@ -44,7 +44,7 @@
     SingleEdgeGroup,
     SwitchCaseEdgeGroup,
 )
-from agent_framework_durabletask import AgentSessionId, DurableAgentThread, DurableAIAgent
+from agent_framework_durabletask import AgentSessionId, DurableAgentSession, DurableAIAgent
 from azure.durable_functions import DurableOrchestrationContext
 
 from ._context import CapturingRunnerContext
@@ -287,11 +287,11 @@ def _prepare_agent_task(
     """
     message_content = _extract_message_content(message)
     session_id = AgentSessionId(name=executor_id, key=context.instance_id)
-    thread = DurableAgentThread(session_id=session_id)
+    session = DurableAgentSession(durable_session_id=session_id)
 
     az_executor = AzureFunctionsAgentExecutor(context)
     agent = DurableAIAgent(az_executor, executor_id)
-    return agent.run(message_content, thread=thread)
+    return agent.run(message_content, session=session)
 
 
 def _prepare_activity_task(
diff --git a/python/packages/azurefunctions/tests/test_utils.py b/python/packages/azurefunctions/tests/test_utils.py
index 98218c1082..35bed9057e 100644
--- a/python/packages/azurefunctions/tests/test_utils.py
+++ b/python/packages/azurefunctions/tests/test_utils.py
@@ -182,7 +182,7 @@ async def test_reset_for_new_run_clears_state(self, context: CapturingRunnerCont
     @pytest.mark.asyncio
     async def test_create_checkpoint_raises_not_implemented(self, context: CapturingRunnerContext) -> None:
         """Test that checkpointing methods raise NotImplementedError."""
-        from agent_framework import State
+        from agent_framework._workflows import State
 
         with pytest.raises(NotImplementedError):
             await context.create_checkpoint("test_workflow", "abc123", State(), None, 1)