From 44f6f34d9808c38264ee5ef92c88e85a99e2c4fa Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 11 Feb 2026 07:37:12 +0900 Subject: [PATCH 1/7] wip --- python/packages/core/AGENTS.md | 4 + .../agent_framework/_workflows/_workflow.py | 159 ++++++- .../test_workflow_response_handlers.py | 429 ++++++++++++++++++ .../getting_started/workflows/README.md | 1 + .../fan_out_async_with_response_handlers.py | 256 +++++++++++ 5 files changed, 846 insertions(+), 3 deletions(-) create mode 100644 python/packages/core/tests/workflow/test_workflow_response_handlers.py create mode 100644 python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 2a308c245d..fd0e0b089b 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -138,3 +138,7 @@ class MyClient(BaseChatClient): async def _inner_get_streaming_response(self, *, messages, options, **kwargs): yield ChatResponseUpdate(...) ``` + +### Running tests + +Always use the virtual environment defined, and run with `uv run pytest ...` \ No newline at end of file diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 5f93644035..5704d6cf8d 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -9,7 +9,7 @@ import logging import types import uuid -from collections.abc import AsyncIterable, Awaitable, Callable, Sequence +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from typing import Any, Literal, overload from .._types import ResponseStream @@ -37,6 +37,15 @@ logger = logging.getLogger(__name__) +class _NoResponseType: + """Sentinel type to distinguish 'handler failed/unmatched' from 'handler returned None'.""" + + pass + + +_NO_RESPONSE = _NoResponseType() + + class WorkflowRunResult(list[WorkflowEvent]): """Container for events generated during non-streaming workflow execution. @@ -222,6 +231,9 @@ def __init__( # Flag to prevent concurrent workflow executions self._is_running = False + # Store responses collected from response_handlers for submission after stream completes + self._pending_handler_responses: dict[str, Any] | None = None + # Capture a canonical fingerprint of the workflow graph so checkpoints # can assert they are resumed with an equivalent topology. self._graph_signature = self._compute_graph_signature() @@ -452,6 +464,7 @@ def run( *, stream: Literal[True], responses: dict[str, Any] | None = None, + response_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, @@ -464,6 +477,7 @@ def run( *, stream: Literal[False] = ..., responses: dict[str, Any] | None = None, + response_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, @@ -476,6 +490,7 @@ def run( *, stream: bool = False, responses: dict[str, Any] | None = None, + response_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, @@ -496,6 +511,14 @@ def run( request IDs and values are the corresponding response data. Mutually exclusive with message. Can be combined with checkpoint_id to restore a checkpoint and send responses in a single call. + response_handlers: Dict mapping request data types to async handler functions. + Requires stream=False (default). Mutually exclusive with ``responses``. + Execution is sequential in two phases: + 1. Phase 1: Workflow runs to idle, collecting request_info events. + 2. All matching handlers are called concurrently via asyncio.gather. + 3. Phase 2: Collected responses are submitted back and the workflow + runs again to process them. + The combined WorkflowRunResult contains events from both phases. checkpoint_id: ID of checkpoint to restore from. Can be used alone (resume from checkpoint), with message (not allowed), or with responses (restore then send responses). @@ -511,14 +534,21 @@ def run( Raises: ValueError: If parameter combination is invalid. """ - # Validate parameters and set running flag eagerly (before any async work) - self._validate_run_params(message, responses, checkpoint_id) + # Validate parameters eagerly (before any async work or setting running flag) + self._validate_run_params(message, responses, checkpoint_id, response_handlers) + if stream and response_handlers is not None: + raise ValueError( + "Cannot use response_handlers with stream=True. " + "Response handlers require non-streaming mode (stream=False) " + "to automatically submit responses after execution completes." + ) self._ensure_not_running() response_stream = ResponseStream[WorkflowEvent, WorkflowRunResult]( self._run_core( message=message, responses=responses, + response_handlers=response_handlers, checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, streaming=stream, @@ -532,13 +562,75 @@ def run( if stream: return response_stream + if response_handlers is not None: + # Wrap to handle automatic response submission after stream completes + return self._run_with_response_submission( + response_stream, + checkpoint_storage=checkpoint_storage, + include_status_events=include_status_events, + **kwargs, + ) return response_stream.get_final_response() + async def _run_with_response_submission( + self, + response_stream: ResponseStream[WorkflowEvent, WorkflowRunResult], + checkpoint_storage: CheckpointStorage | None = None, + include_status_events: bool = False, + **kwargs: Any, + ) -> WorkflowRunResult: + """Execute response_stream and submit collected handler responses as Phase 2. + + After the main stream completes, if response_handlers collected any responses, + this method submits them back to the workflow via ``run(responses=...)``. + + Args: + response_stream: The ResponseStream from Phase 1 to execute. + checkpoint_storage: Checkpoint storage (forwarded to Phase 2). + include_status_events: Whether to include status events. + **kwargs: Additional kwargs forwarded to Phase 2. + + Returns: + WorkflowRunResult with merged events from Phase 1 and Phase 2. + """ + # Clear stale state from any prior run + self._pending_handler_responses = None + + # Phase 1: execute the main stream (handlers populate _pending_handler_responses) + result = await response_stream.get_final_response() + + # Phase 2: submit collected responses into the current workflow state + if self._pending_handler_responses: + logger.debug( + f"Response handlers collected {len(self._pending_handler_responses)} " + f"responses. Submitting back to workflow for continued processing." + ) + collected_responses = self._pending_handler_responses + self._pending_handler_responses = None + + continuation_result = await self.run( + responses=collected_responses, + checkpoint_storage=checkpoint_storage, + stream=False, + include_status_events=include_status_events, + **kwargs, + ) + + merged_events = list(result) + list(continuation_result) + merged_status_events = result._status_events + continuation_result._status_events + return WorkflowRunResult( + events=merged_events, + status_events=merged_status_events, + ) + + return result + async def _run_core( self, message: Any | None = None, *, responses: dict[str, Any] | None = None, + response_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, streaming: bool = False, @@ -549,6 +641,9 @@ async def _run_core( Yields: WorkflowEvent: The events generated during the workflow execution. """ + # Clear stale handler responses from any prior run + self._pending_handler_responses = None + # Enable runtime checkpointing if storage provided if checkpoint_storage is not None: self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage) @@ -557,6 +652,9 @@ async def _run_core( message, responses, checkpoint_id, checkpoint_storage ) + # Collect pending requests if handlers provided + pending_requests: dict[str, Any] = {} + async for event in self._run_workflow_with_tracing( initial_executor_fn=initial_executor_fn, reset_context=reset_context, @@ -565,8 +663,55 @@ async def _run_core( ): if event.type == "output" and not self._should_yield_output_event(event): continue + + # Collect request_info events for external handling + if event.type == "request_info" and response_handlers: + pending_requests[event.request_id] = event.data + yield event + # After main execution, dispatch response handlers for pending requests + if response_handlers and pending_requests: + + async def get_response(request_id: str, request_data: Any) -> tuple[str, Any]: + request_data_type = type(request_data) + matched_handler = response_handlers.get(request_data_type) + if matched_handler: + try: + response = await matched_handler(request_data) + return (request_id, response) + except Exception: + logger.exception( + f"Error in response handler for request {request_id} (type={request_data_type.__name__})" + ) + return (request_id, _NO_RESPONSE) + else: + registered_types = [t.__name__ for t in response_handlers] + logger.warning( + f"No response handler registered for request type " + f"{request_data_type.__name__} (request_id={request_id}). " + f"Registered types: {registered_types}" + ) + return (request_id, _NO_RESPONSE) + + # Run all handlers concurrently, capturing exceptions as return values + response_tasks = [get_response(req_id, req_data) for req_id, req_data in pending_requests.items()] + responses_list = await asyncio.gather(*response_tasks, return_exceptions=True) + + # Filter out failed/unmatched results + collected_responses: dict[str, Any] = {} + for item in responses_list: + if isinstance(item, BaseException): + logger.error(f"Unexpected exception in response handler gather: {item}") + continue + req_id, resp = item + if not isinstance(resp, _NoResponseType): + collected_responses[req_id] = resp + + # Store for use in _run_with_response_submission + if collected_responses: + self._pending_handler_responses = collected_responses + async def _run_cleanup(self, checkpoint_storage: CheckpointStorage | None) -> None: """Cleanup hook called after stream consumption.""" if checkpoint_storage is not None: @@ -605,6 +750,7 @@ def _validate_run_params( message: Any | None, responses: dict[str, Any] | None, checkpoint_id: str | None, + response_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, ) -> None: """Validate parameter combinations for run(). @@ -613,6 +759,7 @@ def _validate_run_params( - message and checkpoint_id are mutually exclusive - At least one of message, responses, or checkpoint_id must be provided - responses + checkpoint_id is allowed (restore then send) + - response_handlers and responses are mutually exclusive """ if message is not None and responses is not None: raise ValueError("Cannot provide both 'message' and 'responses'. Use one or the other.") @@ -626,6 +773,12 @@ def _validate_run_params( "or 'checkpoint_id' (resume from checkpoint)." ) + if response_handlers is not None and responses is not None: + raise ValueError( + "Cannot provide both 'response_handlers' and 'responses'. " + "Use 'response_handlers' for automatic handling or 'responses' for manual submission." + ) + def _resolve_execution_mode( self, message: Any | None, diff --git a/python/packages/core/tests/workflow/test_workflow_response_handlers.py b/python/packages/core/tests/workflow/test_workflow_response_handlers.py new file mode 100644 index 0000000000..6ad63e711e --- /dev/null +++ b/python/packages/core/tests/workflow/test_workflow_response_handlers.py @@ -0,0 +1,429 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Tests for workflow response_handlers parameter. + +Verifies automatic HITL request handling: type-based dispatch, concurrent execution, +error handling, parameter validation, and two-phase event merging. +""" + +import asyncio +import logging +import time +from dataclasses import dataclass +from typing import Any + +import pytest + +from agent_framework import ( + Executor, + WorkflowBuilder, + WorkflowContext, + WorkflowRunState, + handler, + response_handler, +) + +# --------------------------------------------------------------------------- +# Shared request / data types +# --------------------------------------------------------------------------- + + +@dataclass +class ReviewRequest: + """Request type for review.""" + + id: str + data: str + + +@dataclass +class ApprovalRequest: + """Request type for approval.""" + + id: str + task: str + + +@dataclass +class UnknownRequest: + """Request type with no registered handler.""" + + id: str + + +# --------------------------------------------------------------------------- +# Executors +# --------------------------------------------------------------------------- + + +class ReviewerExecutor(Executor): + """Requests a review and records whether the response was processed.""" + + def __init__(self): + super().__init__("reviewer") + self.feedback_received: bool = False + self.feedback_value: str | None = None + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + request = ReviewRequest(id="r1", data=f"Review: {message}") + await ctx.request_info(request_data=request, response_type=str) + + @response_handler + async def on_response( + self, + original_request: ReviewRequest, + response: str, + ctx: WorkflowContext[str], + ) -> None: + self.feedback_received = True + self.feedback_value = response + await ctx.send_message(f"review_done:{response}") + + +class ApproverExecutor(Executor): + """Requests approval and records whether the response was processed.""" + + def __init__(self): + super().__init__("approver") + self.approval_received: bool = False + self.approval_value: str | None = None + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + request = ApprovalRequest(id="a1", task=f"Approve: {message}") + await ctx.request_info(request_data=request, response_type=str) + + @response_handler + async def on_response( + self, + original_request: ApprovalRequest, + response: str, + ctx: WorkflowContext[str], + ) -> None: + self.approval_received = True + self.approval_value = response + await ctx.send_message(f"approval_done:{response}") + + +class NoneResponseExecutor(Executor): + """Requests info expecting a None response (valid value, not 'no response').""" + + def __init__(self): + super().__init__("none_requester") + self.handler_invoked: bool = False + self.received_value: Any = "NOT_SET" + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + request = ReviewRequest(id="n1", data="needs-none") + # Allow None as a valid response type + await ctx.request_info(request_data=request, response_type=str | None) + + @response_handler + async def on_response( + self, + original_request: ReviewRequest, + response: str | None, + ctx: WorkflowContext[str], + ) -> None: + self.handler_invoked = True + self.received_value = response + + +class UnknownRequestExecutor(Executor): + """Emits a request type with no handler registered.""" + + def __init__(self): + super().__init__("unknown_requester") + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + request = UnknownRequest(id="u1") + await ctx.request_info(request_data=request, response_type=str) + + @response_handler + async def on_response( + self, + original_request: UnknownRequest, + response: str, + ctx: WorkflowContext[str], + ) -> None: + pass # Should never be reached + + +class SimpleExecutor(Executor): + """Executor that just passes messages through without requesting info.""" + + def __init__(self): + super().__init__("simple") + + @handler + async def start(self, message: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(f"echo:{message}") + + +class CollectorExecutor(Executor): + """Collects messages for verification.""" + + def __init__(self): + super().__init__("collector") + self.collected: list[str] = [] + + @handler + async def collect(self, message: str, ctx: WorkflowContext) -> None: + self.collected.append(message) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestResponseHandlers: + async def test_single_handler_submits_and_triggers_phase2(self): + """Handler response is submitted back and the executor's response_handler runs.""" + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).add_edge(reviewer, collector).build() + + async def handle_review(request: ReviewRequest) -> str: + return "lgtm" + + result = await workflow.run( + "test_msg", + response_handlers={ReviewRequest: handle_review}, + ) + + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "lgtm" + assert result.get_final_state() == WorkflowRunState.IDLE + # Collector should have received the message from Phase 2 + assert any("review_done:lgtm" in msg for msg in collector.collected) + + async def test_multiple_handlers_dispatched_by_type(self): + """Two different request types are dispatched to their respective handlers.""" + reviewer = ReviewerExecutor() + approver = ApproverExecutor() + collector = CollectorExecutor() + + # Use a dispatcher to fan out to both executors + class Dispatcher(Executor): + def __init__(self): + super().__init__("dispatcher") + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + await ctx.send_message(message) + + dispatcher = Dispatcher() + workflow = ( + WorkflowBuilder(start_executor=dispatcher) + .add_edge(dispatcher, reviewer) + .add_edge(dispatcher, approver) + .add_edge(reviewer, collector) + .add_edge(approver, collector) + .build() + ) + + async def handle_review(request: ReviewRequest) -> str: + return "review_ok" + + async def handle_approval(request: ApprovalRequest) -> str: + return "approved" + + result = await workflow.run( + "multi_test", + response_handlers={ + ReviewRequest: handle_review, + ApprovalRequest: handle_approval, + }, + ) + + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "review_ok" + assert approver.approval_received is True + assert approver.approval_value == "approved" + assert result.get_final_state() == WorkflowRunState.IDLE + + async def test_handler_returning_none_is_valid_response(self): + """A handler that returns None should still submit the response (sentinel fix).""" + none_exec = NoneResponseExecutor() + workflow = WorkflowBuilder(start_executor=none_exec).build() + + async def handle_returning_none(request: ReviewRequest) -> None: + return None + + result = await workflow.run( + "test_none", + response_handlers={ReviewRequest: handle_returning_none}, + ) + + # The handler matched and returned None — that None should be submitted + assert none_exec.handler_invoked is True + assert none_exec.received_value is None + assert result.get_final_state() == WorkflowRunState.IDLE + + async def test_handler_exception_leaves_request_pending(self, caplog): + """A handler that raises keeps the request pending and logs an error.""" + reviewer = ReviewerExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).build() + + async def failing_handler(request: ReviewRequest) -> str: + raise RuntimeError("handler boom") + + with caplog.at_level(logging.ERROR): + result = await workflow.run( + "fail_test", + response_handlers={ReviewRequest: failing_handler}, + ) + + assert reviewer.feedback_received is False + assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + assert "handler boom" in caplog.text + assert "ReviewRequest" in caplog.text + + async def test_unmatched_type_logs_warning(self, caplog): + """A request type with no matching handler logs a warning and stays pending.""" + unknown_exec = UnknownRequestExecutor() + workflow = WorkflowBuilder(start_executor=unknown_exec).build() + + # Register a handler for a *different* type + async def handle_review(request: ReviewRequest) -> str: + return "unused" + + with caplog.at_level(logging.WARNING): + result = await workflow.run( + "unknown_test", + response_handlers={ReviewRequest: handle_review}, + ) + + assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + assert "UnknownRequest" in caplog.text + assert "ReviewRequest" in caplog.text + + async def test_stream_true_with_handlers_raises(self): + """stream=True + response_handlers raises ValueError immediately.""" + reviewer = ReviewerExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).build() + + async def handle_review(request: ReviewRequest) -> str: + return "x" + + with pytest.raises(ValueError, match="stream=True"): + workflow.run( + "test", + stream=True, + response_handlers={ReviewRequest: handle_review}, + ) + + async def test_responses_with_handlers_raises(self): + """response_handlers + responses raises ValueError immediately.""" + reviewer = ReviewerExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).build() + + async def handle_review(request: ReviewRequest) -> str: + return "x" + + with pytest.raises(ValueError, match="response_handlers.*responses"): + await workflow.run( + responses={"some_id": "value"}, + response_handlers={ReviewRequest: handle_review}, + ) + + async def test_no_request_info_handlers_unused(self): + """Handlers registered but no request_info events — workflow completes normally.""" + simple = SimpleExecutor() + collector = CollectorExecutor() + workflow = WorkflowBuilder(start_executor=simple).add_edge(simple, collector).build() + + async def handle_review(request: ReviewRequest) -> str: + return "unused" + + result = await workflow.run( + "hello", + response_handlers={ReviewRequest: handle_review}, + ) + + assert result.get_final_state() == WorkflowRunState.IDLE + assert len(result.get_request_info_events()) == 0 + + async def test_none_handlers_preserves_behavior(self): + """response_handlers=None preserves original IDLE_WITH_PENDING_REQUESTS state.""" + reviewer = ReviewerExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).build() + + result = await workflow.run("test_default", response_handlers=None) + + assert reviewer.feedback_received is False + assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + async def test_handlers_run_concurrently(self): + """Two 200ms handlers complete in under 350ms (proving concurrency).""" + reviewer = ReviewerExecutor() + approver = ApproverExecutor() + + class Dispatcher(Executor): + def __init__(self): + super().__init__("dispatcher") + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + await ctx.send_message(message) + + dispatcher = Dispatcher() + workflow = ( + WorkflowBuilder(start_executor=dispatcher) + .add_edge(dispatcher, reviewer) + .add_edge(dispatcher, approver) + .build() + ) + + async def slow_review(request: ReviewRequest) -> str: + await asyncio.sleep(0.2) + return "reviewed" + + async def slow_approval(request: ApprovalRequest) -> str: + await asyncio.sleep(0.2) + return "approved" + + start = time.monotonic() + await workflow.run( + "concurrent_test", + response_handlers={ + ReviewRequest: slow_review, + ApprovalRequest: slow_approval, + }, + ) + elapsed = time.monotonic() - start + + # Two 200ms handlers running concurrently should be well under 500ms + # (generous margin for CI/test overhead; sequential would be ~400ms+) + assert elapsed < 0.5, f"Expected < 500ms, got {elapsed * 1000:.0f}ms" + + async def test_events_merged_from_both_phases(self): + """Result contains request_info events from Phase 1 AND executor events from Phase 2.""" + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).add_edge(reviewer, collector).build() + + async def handle_review(request: ReviewRequest) -> str: + return "feedback" + + result = await workflow.run( + "merge_test", + response_handlers={ReviewRequest: handle_review}, + ) + + # Phase 1 should have produced request_info events + request_events = result.get_request_info_events() + assert len(request_events) >= 1 + assert isinstance(request_events[0].data, ReviewRequest) + + # Phase 2 should have produced executor events (reviewer response handler ran) + executor_invoked_events = [e for e in result if e.type == "executor_invoked"] + assert len(executor_invoked_events) >= 2 # Phase 1 + Phase 2 + + # Final state should be IDLE (Phase 2 completed) + assert result.get_final_state() == WorkflowRunState.IDLE diff --git a/python/samples/getting_started/workflows/README.md b/python/samples/getting_started/workflows/README.md index 7b368335a3..f27c058189 100644 --- a/python/samples/getting_started/workflows/README.md +++ b/python/samples/getting_started/workflows/README.md @@ -88,6 +88,7 @@ Once comfortable with these, explore the rest of the samples below. | SequentialBuilder Request Info | [human-in-the-loop/sequential_request_info.py](./human-in-the-loop/sequential_request_info.py) | Request info for agent responses mid-workflow using `.with_request_info()` on SequentialBuilder | | ConcurrentBuilder Request Info | [human-in-the-loop/concurrent_request_info.py](./human-in-the-loop/concurrent_request_info.py) | Review concurrent agent outputs before aggregation using `.with_request_info()` on ConcurrentBuilder | | GroupChatBuilder Request Info | [human-in-the-loop/group_chat_request_info.py](./human-in-the-loop/group_chat_request_info.py) | Steer group discussions with periodic guidance using `.with_request_info()` on GroupChatBuilder | +| Fan-Out Async with Response Handlers | [human-in-the-loop/fan_out_async_with_response_handlers.py](./human-in-the-loop/fan_out_async_with_response_handlers.py) | Type-based response handler dispatch for HITL requests in fan-out workflows; automatic async handler execution and response submission | ### tool-approval diff --git a/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py b/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py new file mode 100644 index 0000000000..7bd548552c --- /dev/null +++ b/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py @@ -0,0 +1,256 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Sample: Fan-out Async + HITL with response_handlers parameter + +Demonstrates automatic HITL request handling in fan-out workflows using the +response_handlers parameter. Execution is sequential in two phases: + +1. Phase 1: The workflow runs to idle, collecting any request_info events. +2. Matching handlers are called concurrently via asyncio.gather. +3. Phase 2: Collected responses are submitted back and the workflow runs + again to process them. + +Usage: + response_handlers = { + ReviewRequest: handle_review, + ApprovalRequest: handle_approval, + } + + result = await workflow.run( + initial_data, + response_handlers=response_handlers, + ) +""" + +import asyncio +from dataclasses import dataclass + +from agent_framework import ( + Executor, + WorkflowBuilder, + WorkflowContext, + handler, + response_handler, +) + + +@dataclass +class AnalysisData: + """Initial data for analysis.""" + + task: str + data: str + + +@dataclass +class DataPacket: + """Data flowing through the async processor.""" + + iteration: int + content: str + analysis: str = "" + + +@dataclass +class ReviewRequest: + """Request sent to human reviewer.""" + + packet: DataPacket + analysis: str + prompt: str = "Please review the analysis and provide feedback" + + +class Analyzer(Executor): + def __init__(self, id: str = "analyzer"): + super().__init__(id) + + @handler + async def start_analysis( + self, + data: AnalysisData, + ctx: WorkflowContext[DataPacket, DataPacket], + ) -> None: + packet = DataPacket( + iteration=0, + content=data.data, + analysis=f"Starting {data.task}...", + ) + await ctx.send_message(packet) + + +class AsyncProcessor(Executor): + """Async branch: loops and processes data.""" + + def __init__(self, id: str = "processor"): + super().__init__(id) + self.iteration_count = 0 + + @handler + async def process_data( + self, + packet: DataPacket, + ctx: WorkflowContext[DataPacket, DataPacket], + ) -> None: + self.iteration_count += 1 + await asyncio.sleep(0.5) + analysis = f"Iteration {self.iteration_count}: {packet.content[:50]}... processed" + + if self.iteration_count < 3: + updated_packet = DataPacket( + iteration=self.iteration_count, + content=packet.content, + analysis=analysis, + ) + await ctx.send_message(updated_packet) + else: + await ctx.send_message( + DataPacket( + iteration=self.iteration_count, + content=packet.content, + analysis=analysis, + ) + ) + + +class Reviewer(Executor): + """HITL branch: requests external feedback.""" + + def __init__(self, id: str = "reviewer"): + super().__init__(id) + + @handler + async def review_data( + self, + packet: DataPacket, + ctx: WorkflowContext[str, str], + ) -> None: + review_request = ReviewRequest( + packet=packet, + analysis=packet.analysis or "No analysis yet", + prompt=f"Please review iteration {packet.iteration}", + ) + + await ctx.request_info( + request_data=review_request, + response_type=str, + ) + + @response_handler + async def handle_review_feedback( + self, + original_request: ReviewRequest, + feedback: str, + ctx: WorkflowContext[str], + ) -> None: + result = f"Review feedback processed: {feedback}" + await ctx.send_message(result) + + +class FinalAggregator(Executor): + def __init__(self, id: str = "aggregator"): + super().__init__(id) + self.results: list[str] = [] + + @handler + async def aggregate( + self, + message: DataPacket | str, + ctx: WorkflowContext[str], + ) -> None: + self.results.append(str(message)) + + if len(self.results) >= 2: + summary = "=== WORKFLOW COMPLETE ===\n" + summary += "Results from async path: " + str(self.results[0]) + "\n" + summary += "Results from HITL path: " + str(self.results[1]) + await ctx.send_message(summary) + + +# ============================================================================ +# Response Handlers (External) +# ============================================================================ +# These are registered via response_handlers dict, not as executor methods. +# They run after Phase 1 completes and their responses are submitted in Phase 2. + + +async def handle_review(request: ReviewRequest) -> str: + """Handle external review request. + + Called after Phase 1 when the Reviewer emits a request_info with + ReviewRequest type. The response is submitted back to the workflow + for Phase 2 processing. + """ + print(f"\n[Handler] Processing review request for iteration {request.packet.iteration}") + + # Simulate external API call + await asyncio.sleep(2.0) + + feedback = "Analysis looks good. Continue processing." + print(f"[Handler] Review feedback: {feedback}\n") + + return feedback + + +# ============================================================================ +# Main +# ============================================================================ + + +async def main() -> None: + print("=" * 80) + print("HITL: response_handlers parameter in workflow.run()") + print("=" * 80) + + # Create executors + analyzer = Analyzer() + processor = AsyncProcessor() + reviewer = Reviewer() + aggregator = FinalAggregator() + + # Build workflow + workflow = ( + WorkflowBuilder(start_executor=analyzer) + .add_edge(analyzer, processor) + .add_edge(analyzer, reviewer) + .add_edge(processor, aggregator) + .add_edge(reviewer, aggregator) + .build() + ) + + initial_data = AnalysisData( + task="document analysis", + data="Sample document to process", + ) + + # ======================================================================== + # THE CLEAN API + # ======================================================================== + # Define external response handlers (type-based dispatch) + response_handlers = { + ReviewRequest: handle_review, + # Can add more: ApprovalRequest: handle_approval, etc. + } + + print(f"\n[Start] Task: {initial_data.task}\n") + + # Run workflow with automatic response handling (two-phase execution) + # Phase 1: workflow runs to idle, collecting request_info events + # Handlers run concurrently, then responses are submitted for Phase 2 + result = await workflow.run( + initial_data, + response_handlers=response_handlers, + ) + + # Display results + print(f"\nFinal state: {result.get_final_state()}") + print(f"Total events: {len(result)}") + outputs = result.get_outputs() + if outputs: + print(f"Outputs ({len(outputs)}):") + for output in outputs: + print(f" - {output}") + + +if __name__ == "__main__": + asyncio.run(main()) From 2632a7acebaba2e4f93a04c533a8af803d502025 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 11 Feb 2026 15:56:45 +0900 Subject: [PATCH 2/7] Refactor response_handlers to inline dispatch for fan-out HITL workflows --- .../agent_framework/_workflows/_runner.py | 19 +- .../agent_framework/_workflows/_workflow.py | 220 ++++++------------ .../test_workflow_response_handlers.py | 130 +++++++++-- .../fan_out_async_with_response_handlers.py | 96 +++++--- 4 files changed, 272 insertions(+), 193 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index f3a475e034..7e935c17a3 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -79,7 +79,10 @@ def reset_iteration_count(self) -> None: """Reset the iteration count to zero.""" self._iteration = 0 - async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: + async def run_until_convergence( + self, + outstanding_handler_tasks: set[asyncio.Task[None]] | None = None, + ) -> AsyncGenerator[WorkflowEvent, None]: """Run the workflow until no more messages are sent.""" if self._running: raise WorkflowRunnerException("Runner is already running.") @@ -121,6 +124,12 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: iteration_task.cancel() with contextlib.suppress(asyncio.CancelledError): await iteration_task + if outstanding_handler_tasks: + for t in outstanding_handler_tasks: + if not t.done(): + t.cancel() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather(*outstanding_handler_tasks, return_exceptions=True) raise # Propagate errors from iteration, but first surface any pending events @@ -151,6 +160,14 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: # Check for convergence: no more messages to process if not await self._ctx.has_messages(): + if outstanding_handler_tasks: + still_running = {t for t in outstanding_handler_tasks if not t.done()} + while still_running: + _, still_running = await asyncio.wait(still_running, return_when=asyncio.FIRST_COMPLETED) + if await self._ctx.has_messages(): + break + if await self._ctx.has_messages(): + continue break if self._iteration >= self._max_iterations and await self._ctx.has_messages(): diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 5704d6cf8d..c453ebadb6 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -37,15 +37,6 @@ logger = logging.getLogger(__name__) -class _NoResponseType: - """Sentinel type to distinguish 'handler failed/unmatched' from 'handler returned None'.""" - - pass - - -_NO_RESPONSE = _NoResponseType() - - class WorkflowRunResult(list[WorkflowEvent]): """Container for events generated during non-streaming workflow execution. @@ -231,9 +222,6 @@ def __init__( # Flag to prevent concurrent workflow executions self._is_running = False - # Store responses collected from response_handlers for submission after stream completes - self._pending_handler_responses: dict[str, Any] | None = None - # Capture a canonical fingerprint of the workflow graph so checkpoints # can assert they are resumed with an equivalent topology. self._graph_signature = self._compute_graph_signature() @@ -309,6 +297,7 @@ async def _run_workflow_with_tracing( reset_context: bool = True, streaming: bool = False, run_kwargs: dict[str, Any] | None = None, + outstanding_handler_tasks: set[asyncio.Task[None]] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Private method to run workflow with proper tracing. @@ -320,6 +309,7 @@ async def _run_workflow_with_tracing( reset_context: Whether to reset the context for a new run streaming: Whether to enable streaming mode for agents run_kwargs: Optional kwargs to store in State for agent invocations + outstanding_handler_tasks: Optional set to track outstanding response handler tasks Yields: WorkflowEvent: The events generated during the workflow execution. @@ -335,7 +325,6 @@ async def _run_workflow_with_tracing( OtelAttr.WORKFLOW_RUN_SPAN, attributes, ) as span: - saw_request = False emitted_in_progress_pending = False try: # Add workflow started event (telemetry + surface state to consumers) @@ -367,10 +356,9 @@ async def _run_workflow_with_tracing( await initial_executor_fn() # All executor executions happen within workflow span - async for event in self._runner.run_until_convergence(): - # Track request events for final status determination - if event.type == "request_info": - saw_request = True + async for event in self._runner.run_until_convergence( + outstanding_handler_tasks=outstanding_handler_tasks, + ): yield event if event.type == "request_info" and not emitted_in_progress_pending: @@ -380,7 +368,8 @@ async def _run_workflow_with_tracing( yield pending_status # Workflow runs until idle - emit final status based on whether requests are pending - if saw_request: + pending = await self._runner_context.get_pending_request_info_events() + if pending: with _framework_event_origin(): terminal_status = WorkflowEvent.status(WorkflowRunState.IDLE_WITH_PENDING_REQUESTS) yield terminal_status @@ -512,13 +501,11 @@ def run( exclusive with message. Can be combined with checkpoint_id to restore a checkpoint and send responses in a single call. response_handlers: Dict mapping request data types to async handler functions. - Requires stream=False (default). Mutually exclusive with ``responses``. - Execution is sequential in two phases: - 1. Phase 1: Workflow runs to idle, collecting request_info events. - 2. All matching handlers are called concurrently via asyncio.gather. - 3. Phase 2: Collected responses are submitted back and the workflow - runs again to process them. - The combined WorkflowRunResult contains events from both phases. + Mutually exclusive with ``responses``. Handlers are dispatched inline + as asyncio tasks when request_info events are emitted during execution. + The runner waits for outstanding handler tasks before declaring convergence, + so handler responses are processed in subsequent supersteps within the + same workflow run. checkpoint_id: ID of checkpoint to restore from. Can be used alone (resume from checkpoint), with message (not allowed), or with responses (restore then send responses). @@ -536,12 +523,6 @@ def run( """ # Validate parameters eagerly (before any async work or setting running flag) self._validate_run_params(message, responses, checkpoint_id, response_handlers) - if stream and response_handlers is not None: - raise ValueError( - "Cannot use response_handlers with stream=True. " - "Response handlers require non-streaming mode (stream=False) " - "to automatically submit responses after execution completes." - ) self._ensure_not_running() response_stream = ResponseStream[WorkflowEvent, WorkflowRunResult]( @@ -562,69 +543,36 @@ def run( if stream: return response_stream - if response_handlers is not None: - # Wrap to handle automatic response submission after stream completes - return self._run_with_response_submission( - response_stream, - checkpoint_storage=checkpoint_storage, - include_status_events=include_status_events, - **kwargs, - ) return response_stream.get_final_response() - async def _run_with_response_submission( + async def _invoke_response_handler( self, - response_stream: ResponseStream[WorkflowEvent, WorkflowRunResult], - checkpoint_storage: CheckpointStorage | None = None, - include_status_events: bool = False, - **kwargs: Any, - ) -> WorkflowRunResult: - """Execute response_stream and submit collected handler responses as Phase 2. - - After the main stream completes, if response_handlers collected any responses, - this method submits them back to the workflow via ``run(responses=...)``. - - Args: - response_stream: The ResponseStream from Phase 1 to execute. - checkpoint_storage: Checkpoint storage (forwarded to Phase 2). - include_status_events: Whether to include status events. - **kwargs: Additional kwargs forwarded to Phase 2. + handler_fn: Callable[[Any], Awaitable[Any]], + request_id: str, + request_data: Any, + ) -> None: + """Invoke a response handler and submit the result back to the workflow. - Returns: - WorkflowRunResult with merged events from Phase 1 and Phase 2. + On success, the response is submitted via send_request_info_response which + pops the pending request and injects a response message for the next superstep. + On failure, the request stays pending and the workflow converges as + IDLE_WITH_PENDING_REQUESTS. """ - # Clear stale state from any prior run - self._pending_handler_responses = None - - # Phase 1: execute the main stream (handlers populate _pending_handler_responses) - result = await response_stream.get_final_response() - - # Phase 2: submit collected responses into the current workflow state - if self._pending_handler_responses: - logger.debug( - f"Response handlers collected {len(self._pending_handler_responses)} " - f"responses. Submitting back to workflow for continued processing." - ) - collected_responses = self._pending_handler_responses - self._pending_handler_responses = None - - continuation_result = await self.run( - responses=collected_responses, - checkpoint_storage=checkpoint_storage, - stream=False, - include_status_events=include_status_events, - **kwargs, - ) - - merged_events = list(result) + list(continuation_result) - merged_status_events = result._status_events + continuation_result._status_events - return WorkflowRunResult( - events=merged_events, - status_events=merged_status_events, + try: + response = await handler_fn(request_data) + except Exception: + logger.exception(f"Response handler failed for request {request_id} (type={type(request_data).__name__})") + return + + try: + await self._runner_context.send_request_info_response(request_id, response) + except Exception: + logger.exception( + f"Failed to submit response for request {request_id} " + f"(type={type(request_data).__name__}). " + f"Handler succeeded but response could not be delivered." ) - return result - async def _run_core( self, message: Any | None = None, @@ -638,12 +586,14 @@ async def _run_core( ) -> AsyncIterable[WorkflowEvent]: """Single core execution path for both streaming and non-streaming modes. + When response_handlers are provided, handlers are dispatched inline as + asyncio tasks when request_info events are emitted. The runner's convergence + check waits for outstanding handler tasks before deciding the workflow is idle, + allowing handler responses to be processed in subsequent supersteps. + Yields: WorkflowEvent: The events generated during the workflow execution. """ - # Clear stale handler responses from any prior run - self._pending_handler_responses = None - # Enable runtime checkpointing if storage provided if checkpoint_storage is not None: self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage) @@ -652,65 +602,45 @@ async def _run_core( message, responses, checkpoint_id, checkpoint_storage ) - # Collect pending requests if handlers provided - pending_requests: dict[str, Any] = {} + outstanding_tasks: set[asyncio.Task[None]] | None = set() if response_handlers else None - async for event in self._run_workflow_with_tracing( - initial_executor_fn=initial_executor_fn, - reset_context=reset_context, - streaming=streaming, - run_kwargs=kwargs if kwargs else None, - ): - if event.type == "output" and not self._should_yield_output_event(event): - continue + try: + async for event in self._run_workflow_with_tracing( + initial_executor_fn=initial_executor_fn, + reset_context=reset_context, + streaming=streaming, + run_kwargs=kwargs if kwargs else None, + outstanding_handler_tasks=outstanding_tasks, + ): + if event.type == "output" and not self._should_yield_output_event(event): + continue - # Collect request_info events for external handling - if event.type == "request_info" and response_handlers: - pending_requests[event.request_id] = event.data - - yield event - - # After main execution, dispatch response handlers for pending requests - if response_handlers and pending_requests: - - async def get_response(request_id: str, request_data: Any) -> tuple[str, Any]: - request_data_type = type(request_data) - matched_handler = response_handlers.get(request_data_type) - if matched_handler: - try: - response = await matched_handler(request_data) - return (request_id, response) - except Exception: - logger.exception( - f"Error in response handler for request {request_id} (type={request_data_type.__name__})" + # Dispatch response handlers inline as tasks + if event.type == "request_info" and response_handlers is not None and outstanding_tasks is not None: + request_data_type = type(event.data) + matched_handler = response_handlers.get(request_data_type) + if matched_handler: + task = asyncio.create_task( + self._invoke_response_handler(matched_handler, event.request_id, event.data) + ) + outstanding_tasks.add(task) + task.add_done_callback(outstanding_tasks.discard) + else: + registered_types = [t.__name__ for t in response_handlers] + logger.warning( + f"No response handler registered for request type " + f"{request_data_type.__name__} (request_id={event.request_id}). " + f"Registered types: {registered_types}" ) - return (request_id, _NO_RESPONSE) - else: - registered_types = [t.__name__ for t in response_handlers] - logger.warning( - f"No response handler registered for request type " - f"{request_data_type.__name__} (request_id={request_id}). " - f"Registered types: {registered_types}" - ) - return (request_id, _NO_RESPONSE) - - # Run all handlers concurrently, capturing exceptions as return values - response_tasks = [get_response(req_id, req_data) for req_id, req_data in pending_requests.items()] - responses_list = await asyncio.gather(*response_tasks, return_exceptions=True) - - # Filter out failed/unmatched results - collected_responses: dict[str, Any] = {} - for item in responses_list: - if isinstance(item, BaseException): - logger.error(f"Unexpected exception in response handler gather: {item}") - continue - req_id, resp = item - if not isinstance(resp, _NoResponseType): - collected_responses[req_id] = resp - # Store for use in _run_with_response_submission - if collected_responses: - self._pending_handler_responses = collected_responses + yield event + finally: + # Cancel any outstanding handler tasks on error/cancellation + if outstanding_tasks: + for t in outstanding_tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*outstanding_tasks, return_exceptions=True) async def _run_cleanup(self, checkpoint_storage: CheckpointStorage | None) -> None: """Cleanup hook called after stream consumption.""" diff --git a/python/packages/core/tests/workflow/test_workflow_response_handlers.py b/python/packages/core/tests/workflow/test_workflow_response_handlers.py index 6ad63e711e..1da3dd6b8a 100644 --- a/python/packages/core/tests/workflow/test_workflow_response_handlers.py +++ b/python/packages/core/tests/workflow/test_workflow_response_handlers.py @@ -4,7 +4,7 @@ Tests for workflow response_handlers parameter. Verifies automatic HITL request handling: type-based dispatch, concurrent execution, -error handling, parameter validation, and two-phase event merging. +error handling, parameter validation, and inline handler dispatch. """ import asyncio @@ -182,7 +182,7 @@ async def collect(self, message: str, ctx: WorkflowContext) -> None: class TestResponseHandlers: - async def test_single_handler_submits_and_triggers_phase2(self): + async def test_single_handler_submits_response(self): """Handler response is submitted back and the executor's response_handler runs.""" reviewer = ReviewerExecutor() collector = CollectorExecutor() @@ -199,7 +199,7 @@ async def handle_review(request: ReviewRequest) -> str: assert reviewer.feedback_received is True assert reviewer.feedback_value == "lgtm" assert result.get_final_state() == WorkflowRunState.IDLE - # Collector should have received the message from Phase 2 + # Collector should have received the message from the response handler assert any("review_done:lgtm" in msg for msg in collector.collected) async def test_multiple_handlers_dispatched_by_type(self): @@ -303,20 +303,29 @@ async def handle_review(request: ReviewRequest) -> str: assert "UnknownRequest" in caplog.text assert "ReviewRequest" in caplog.text - async def test_stream_true_with_handlers_raises(self): - """stream=True + response_handlers raises ValueError immediately.""" + async def test_stream_true_with_handlers_works(self): + """stream=True + response_handlers works with inline dispatch.""" reviewer = ReviewerExecutor() - workflow = WorkflowBuilder(start_executor=reviewer).build() + collector = CollectorExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).add_edge(reviewer, collector).build() async def handle_review(request: ReviewRequest) -> str: - return "x" + return "stream_lgtm" - with pytest.raises(ValueError, match="stream=True"): - workflow.run( - "test", - stream=True, - response_handlers={ReviewRequest: handle_review}, - ) + events = [] + stream = workflow.run( + "stream_test", + stream=True, + response_handlers={ReviewRequest: handle_review}, + ) + async for event in stream: + events.append(event) + result = await stream.get_final_response() + + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "stream_lgtm" + assert result.get_final_state() == WorkflowRunState.IDLE + assert any("review_done:stream_lgtm" in msg for msg in collector.collected) async def test_responses_with_handlers_raises(self): """response_handlers + responses raises ValueError immediately.""" @@ -402,8 +411,8 @@ async def slow_approval(request: ApprovalRequest) -> str: # (generous margin for CI/test overhead; sequential would be ~400ms+) assert elapsed < 0.5, f"Expected < 500ms, got {elapsed * 1000:.0f}ms" - async def test_events_merged_from_both_phases(self): - """Result contains request_info events from Phase 1 AND executor events from Phase 2.""" + async def test_events_include_request_and_response_handling(self): + """Result contains request_info events and executor events from response processing.""" reviewer = ReviewerExecutor() collector = CollectorExecutor() workflow = WorkflowBuilder(start_executor=reviewer).add_edge(reviewer, collector).build() @@ -416,14 +425,97 @@ async def handle_review(request: ReviewRequest) -> str: response_handlers={ReviewRequest: handle_review}, ) - # Phase 1 should have produced request_info events + # Should have produced request_info events request_events = result.get_request_info_events() assert len(request_events) >= 1 assert isinstance(request_events[0].data, ReviewRequest) - # Phase 2 should have produced executor events (reviewer response handler ran) + # Response handler should have triggered executor events (reviewer response handler ran) executor_invoked_events = [e for e in result if e.type == "executor_invoked"] - assert len(executor_invoked_events) >= 2 # Phase 1 + Phase 2 + assert len(executor_invoked_events) >= 2 # Initial + response handling + + # Final state should be IDLE (response was processed) + assert result.get_final_state() == WorkflowRunState.IDLE + + async def test_handler_with_looping_branch(self): + """Fan-out with self-looping branch + HITL handler: handler dispatches while loop runs.""" + + @dataclass + class LoopMessage: + iteration: int + content: str + + class LoopingProcessor(Executor): + """Loops a fixed number of times via self-send. Stops by not sending.""" + + def __init__(self): + super().__init__("looper") + self.iteration_count = 0 + + @handler + async def process(self, message: LoopMessage, ctx: WorkflowContext[LoopMessage]) -> None: + self.iteration_count += 1 + if self.iteration_count < 3: + await ctx.send_message(LoopMessage(iteration=self.iteration_count, content="processing")) + + class HITLExecutor(Executor): + """Requests info from external handler.""" + + def __init__(self): + super().__init__("hitl") + self.response_received = False + self.response_value: str | None = None + + @handler + async def start(self, message: LoopMessage, ctx: WorkflowContext) -> None: + request = ReviewRequest(id="hitl_req", data=f"Review: {message.content}") + await ctx.request_info(request_data=request, response_type=str) + + @response_handler + async def on_response( + self, + original_request: ReviewRequest, + response: str, + ctx: WorkflowContext[str], + ) -> None: + self.response_received = True + self.response_value = response + + class FanOutDispatcher(Executor): + def __init__(self): + super().__init__("fanout") + + @handler + async def start(self, message: str, ctx: WorkflowContext[LoopMessage]) -> None: + await ctx.send_message(LoopMessage(iteration=0, content=message)) + + fanout = FanOutDispatcher() + looper = LoopingProcessor() + hitl = HITLExecutor() + + workflow = ( + WorkflowBuilder(start_executor=fanout) + .add_edge(fanout, looper) + .add_edge(fanout, hitl) + .add_edge(looper, looper) # self-loop (stops when no message sent) + .build() + ) + + async def handle_review(request: ReviewRequest) -> str: + await asyncio.sleep(0.1) # Simulate async work + return "approved_by_handler" + + result = await workflow.run( + "fan_out_test", + response_handlers={ReviewRequest: handle_review}, + ) + + # Looping branch completed its iterations + assert looper.iteration_count == 3 + + # HITL handler was dispatched and response processed + assert hitl.response_received is True + assert hitl.response_value == "approved_by_handler" - # Final state should be IDLE (Phase 2 completed) + # Final state should be IDLE (all branches complete, all requests handled) assert result.get_final_state() == WorkflowRunState.IDLE diff --git a/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py b/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py index 7bd548552c..babb2d749a 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py @@ -4,12 +4,10 @@ Sample: Fan-out Async + HITL with response_handlers parameter Demonstrates automatic HITL request handling in fan-out workflows using the -response_handlers parameter. Execution is sequential in two phases: - -1. Phase 1: The workflow runs to idle, collecting any request_info events. -2. Matching handlers are called concurrently via asyncio.gather. -3. Phase 2: Collected responses are submitted back and the workflow runs - again to process them. +response_handlers parameter. Handlers are dispatched inline as asyncio tasks +when request_info events are emitted during execution. The runner waits for +outstanding handler tasks before declaring convergence, so handler responses +are processed in subsequent supersteps within the same workflow run. Usage: response_handlers = { @@ -24,6 +22,7 @@ """ import asyncio +import time from dataclasses import dataclass from agent_framework import ( @@ -34,6 +33,13 @@ response_handler, ) +_start_time: float = 0.0 + + +def _ts() -> str: + """Return elapsed seconds since workflow start.""" + return f"{time.monotonic() - _start_time:.1f}s" + @dataclass class AnalysisData: @@ -95,8 +101,9 @@ async def process_data( self.iteration_count += 1 await asyncio.sleep(0.5) analysis = f"Iteration {self.iteration_count}: {packet.content[:50]}... processed" + print(f" [{_ts()}] Processor: iteration {self.iteration_count}/10 complete") - if self.iteration_count < 3: + if self.iteration_count < 10: updated_packet = DataPacket( iteration=self.iteration_count, content=packet.content, @@ -104,12 +111,10 @@ async def process_data( ) await ctx.send_message(updated_packet) else: - await ctx.send_message( - DataPacket( - iteration=self.iteration_count, - content=packet.content, - analysis=analysis, - ) + # Stop sending to break the self-loop. Use yield_output for final result. + print(f" [{_ts()}] Processor: DONE after {self.iteration_count} iterations") + await ctx.yield_output( + DataPacket(iteration=self.iteration_count, content=packet.content, analysis=analysis) ) @@ -143,6 +148,7 @@ async def handle_review_feedback( feedback: str, ctx: WorkflowContext[str], ) -> None: + print(f" [{_ts()}] Reviewer: response_handler invoked with feedback") result = f"Review feedback processed: {feedback}" await ctx.send_message(result) @@ -171,23 +177,23 @@ async def aggregate( # Response Handlers (External) # ============================================================================ # These are registered via response_handlers dict, not as executor methods. -# They run after Phase 1 completes and their responses are submitted in Phase 2. +# They are dispatched inline as asyncio tasks when request_info events are emitted. async def handle_review(request: ReviewRequest) -> str: """Handle external review request. - Called after Phase 1 when the Reviewer emits a request_info with - ReviewRequest type. The response is submitted back to the workflow - for Phase 2 processing. + Dispatched as an asyncio task when the Reviewer emits a request_info + with ReviewRequest type. The response is submitted back to the workflow + and processed in a subsequent superstep. """ - print(f"\n[Handler] Processing review request for iteration {request.packet.iteration}") + print(f" [{_ts()}] Handler: STARTED - reviewing iteration {request.packet.iteration}") - # Simulate external API call - await asyncio.sleep(2.0) + # Simulate slow external API call (e.g., LLM inference, human approval system) + await asyncio.sleep(3.0) feedback = "Analysis looks good. Continue processing." - print(f"[Handler] Review feedback: {feedback}\n") + print(f" [{_ts()}] Handler: DONE - returning feedback") return feedback @@ -208,11 +214,12 @@ async def main() -> None: reviewer = Reviewer() aggregator = FinalAggregator() - # Build workflow + # Build workflow: processor self-loops, both branches feed into aggregator workflow = ( WorkflowBuilder(start_executor=analyzer) .add_edge(analyzer, processor) .add_edge(analyzer, reviewer) + .add_edge(processor, processor) # self-loop for async processing .add_edge(processor, aggregator) .add_edge(reviewer, aggregator) .build() @@ -232,25 +239,58 @@ async def main() -> None: # Can add more: ApprovalRequest: handle_approval, etc. } - print(f"\n[Start] Task: {initial_data.task}\n") + global _start_time + _start_time = time.monotonic() - # Run workflow with automatic response handling (two-phase execution) - # Phase 1: workflow runs to idle, collecting request_info events - # Handlers run concurrently, then responses are submitted for Phase 2 + print(f"\n[{_ts()}] Start: {initial_data.task}") + + # Run workflow with automatic response handling (inline dispatch) + # Handlers are dispatched as asyncio tasks when request_info events are emitted + # Responses are injected back and processed in subsequent supersteps result = await workflow.run( initial_data, response_handlers=response_handlers, ) + elapsed = time.monotonic() - _start_time + # Display results - print(f"\nFinal state: {result.get_final_state()}") - print(f"Total events: {len(result)}") + print(f"\n[{_ts()}] Final state: {result.get_final_state()}") + print(f" Total time: {elapsed:.1f}s") + print(f" Total events: {len(result)}") outputs = result.get_outputs() if outputs: print(f"Outputs ({len(outputs)}):") for output in outputs: print(f" - {output}") + """ + Sample Output: + + [0.0s] Start: document analysis + + [0.0s] Handler: STARTED - reviewing iteration 0 + [0.5s] Processor: iteration 1/10 complete + [1.1s] Processor: iteration 2/10 complete + [1.6s] Processor: iteration 3/10 complete + [2.2s] Processor: iteration 4/10 complete + [2.7s] Processor: iteration 5/10 complete + [3.0s] Handler: DONE - returning feedback + [3.3s] Processor: iteration 6/10 complete + [3.3s] Reviewer: response_handler invoked with feedback + [3.8s] Processor: iteration 7/10 complete + [4.4s] Processor: iteration 8/10 complete + [4.9s] Processor: iteration 9/10 complete + [5.5s] Processor: iteration 10/10 complete + [5.5s] Processor: DONE after 10 iterations + + [5.6s] Final state: WorkflowRunState.IDLE + Total time: 5.6s + Total events: 70 + Outputs (1): + - DataPacket(iteration=10, content='Sample document to process', analysis='Iteration 10: Sample document to process... processed') + """ # noqa: E501 + if __name__ == "__main__": asyncio.run(main()) From 7791c87613b5200737ef34774ce8710dc7044450 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 11 Feb 2026 18:11:33 +0900 Subject: [PATCH 3/7] Rename to request handlers --- .../agent_framework/_workflows/_workflow.py | 50 +++--- .../_workflows/_workflow_builder.py | 41 ++++- .../test_workflow_response_handlers.py | 157 ++++++++++++++++-- .../fan_out_async_with_response_handlers.py | 16 +- 4 files changed, 218 insertions(+), 46 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index c453ebadb6..7e26dc28d8 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -179,6 +179,7 @@ def __init__( name: str | None = None, description: str | None = None, output_executors: list[str] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, **kwargs: Any, ): """Initialize the workflow with a list of edges. @@ -193,6 +194,8 @@ def __init__( description: Optional description of what the workflow does. output_executors: Optional list of executor IDs whose outputs will be considered workflow outputs. If None or empty, all executor outputs are treated as workflow outputs. + request_handlers: Optional default response handlers for automatic HITL request handling. + Can be overridden per-run via workflow.run(request_handlers=...). kwargs: Additional keyword arguments. Unused in this implementation. """ self.edge_groups = list(edge_groups) @@ -222,6 +225,9 @@ def __init__( # Flag to prevent concurrent workflow executions self._is_running = False + # Default response handlers (can be overridden per-run) + self._request_handlers = request_handlers + # Capture a canonical fingerprint of the workflow graph so checkpoints # can assert they are resumed with an equivalent topology. self._graph_signature = self._compute_graph_signature() @@ -453,7 +459,7 @@ def run( *, stream: Literal[True], responses: dict[str, Any] | None = None, - response_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, @@ -466,7 +472,7 @@ def run( *, stream: Literal[False] = ..., responses: dict[str, Any] | None = None, - response_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, @@ -479,7 +485,7 @@ def run( *, stream: bool = False, responses: dict[str, Any] | None = None, - response_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, @@ -500,12 +506,13 @@ def run( request IDs and values are the corresponding response data. Mutually exclusive with message. Can be combined with checkpoint_id to restore a checkpoint and send responses in a single call. - response_handlers: Dict mapping request data types to async handler functions. + request_handlers: Dict mapping request data types to async handler functions. Mutually exclusive with ``responses``. Handlers are dispatched inline as asyncio tasks when request_info events are emitted during execution. The runner waits for outstanding handler tasks before declaring convergence, so handler responses are processed in subsequent supersteps within the - same workflow run. + same workflow run. If None, falls back to handlers set at build time via + WorkflowBuilder. Pass an empty dict to explicitly disable handlers for a run. checkpoint_id: ID of checkpoint to restore from. Can be used alone (resume from checkpoint), with message (not allowed), or with responses (restore then send responses). @@ -521,15 +528,18 @@ def run( Raises: ValueError: If parameter combination is invalid. """ + # Fall back to builder-level request_handlers if none provided at runtime + effective_handlers = request_handlers if request_handlers is not None else self._request_handlers + # Validate parameters eagerly (before any async work or setting running flag) - self._validate_run_params(message, responses, checkpoint_id, response_handlers) + self._validate_run_params(message, responses, checkpoint_id, effective_handlers) self._ensure_not_running() response_stream = ResponseStream[WorkflowEvent, WorkflowRunResult]( self._run_core( message=message, responses=responses, - response_handlers=response_handlers, + request_handlers=effective_handlers, checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, streaming=stream, @@ -545,7 +555,7 @@ def run( return response_stream return response_stream.get_final_response() - async def _invoke_response_handler( + async def _invoke_request_handler( self, handler_fn: Callable[[Any], Awaitable[Any]], request_id: str, @@ -578,7 +588,7 @@ async def _run_core( message: Any | None = None, *, responses: dict[str, Any] | None = None, - response_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, streaming: bool = False, @@ -586,7 +596,7 @@ async def _run_core( ) -> AsyncIterable[WorkflowEvent]: """Single core execution path for both streaming and non-streaming modes. - When response_handlers are provided, handlers are dispatched inline as + When request_handlers are provided, handlers are dispatched inline as asyncio tasks when request_info events are emitted. The runner's convergence check waits for outstanding handler tasks before deciding the workflow is idle, allowing handler responses to be processed in subsequent supersteps. @@ -602,7 +612,7 @@ async def _run_core( message, responses, checkpoint_id, checkpoint_storage ) - outstanding_tasks: set[asyncio.Task[None]] | None = set() if response_handlers else None + outstanding_tasks: set[asyncio.Task[None]] | None = set() if request_handlers else None try: async for event in self._run_workflow_with_tracing( @@ -616,17 +626,17 @@ async def _run_core( continue # Dispatch response handlers inline as tasks - if event.type == "request_info" and response_handlers is not None and outstanding_tasks is not None: + if event.type == "request_info" and request_handlers is not None and outstanding_tasks is not None: request_data_type = type(event.data) - matched_handler = response_handlers.get(request_data_type) + matched_handler = request_handlers.get(request_data_type) if matched_handler: task = asyncio.create_task( - self._invoke_response_handler(matched_handler, event.request_id, event.data) + self._invoke_request_handler(matched_handler, event.request_id, event.data) ) outstanding_tasks.add(task) task.add_done_callback(outstanding_tasks.discard) else: - registered_types = [t.__name__ for t in response_handlers] + registered_types = [t.__name__ for t in request_handlers] logger.warning( f"No response handler registered for request type " f"{request_data_type.__name__} (request_id={event.request_id}). " @@ -680,7 +690,7 @@ def _validate_run_params( message: Any | None, responses: dict[str, Any] | None, checkpoint_id: str | None, - response_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, ) -> None: """Validate parameter combinations for run(). @@ -689,7 +699,7 @@ def _validate_run_params( - message and checkpoint_id are mutually exclusive - At least one of message, responses, or checkpoint_id must be provided - responses + checkpoint_id is allowed (restore then send) - - response_handlers and responses are mutually exclusive + - request_handlers and responses are mutually exclusive """ if message is not None and responses is not None: raise ValueError("Cannot provide both 'message' and 'responses'. Use one or the other.") @@ -703,10 +713,10 @@ def _validate_run_params( "or 'checkpoint_id' (resume from checkpoint)." ) - if response_handlers is not None and responses is not None: + if request_handlers is not None and responses is not None: raise ValueError( - "Cannot provide both 'response_handlers' and 'responses'. " - "Use 'response_handlers' for automatic handling or 'responses' for manual submission." + "Cannot provide both 'request_handlers' and 'responses'. " + "Use 'request_handlers' for automatic handling or 'responses' for manual submission." ) def _resolve_execution_mode( diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index 14fd512e17..5afe0ed3bb 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -2,7 +2,7 @@ import logging import sys -from collections.abc import Callable, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass from typing import Any @@ -159,6 +159,7 @@ def __init__( start_executor: Executor | SupportsAgentRun | str, checkpoint_storage: CheckpointStorage | None = None, output_executors: list[Executor | SupportsAgentRun | str] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, ): """Initialize the WorkflowBuilder. @@ -171,6 +172,9 @@ def __init__( checkpoint_storage: Optional checkpoint storage for enabling workflow state persistence. output_executors: Optional list of executors whose outputs should be collected. If not provided, outputs from all executors are collected. + request_handlers: Optional dict mapping request data types to async handler functions. + Keys must match request types declared in executor @response_handler decorators. + Validated at build time. Can be overridden per-run via workflow.run(request_handlers=...). """ self._edge_groups: list[EdgeGroup] = [] self._executors: dict[str, Executor] = {} @@ -197,6 +201,9 @@ def __init__( # Output executors filter; if set, only outputs from these executors are yielded self._output_executors: list[Executor | SupportsAgentRun | str] = output_executors if output_executors else [] + # Response handlers for automatic HITL request handling + self._request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = request_handlers + # Set the start executor self._set_start_executor(start_executor) @@ -1023,6 +1030,33 @@ def _get_executor(name: str) -> Executor: return (start_executor, factory_name_to_instance, deferred_edge_groups) + def _validate_request_handlers(self, executors: dict[str, Executor]) -> None: + """Validate that request_handlers keys match executor @response_handler request types. + + Collects all request types declared via @response_handler decorators across all + executors, then checks that each key in self._request_handlers maps to a known + request type. Raises ValueError if an unrecognized type is found. + """ + if self._request_handlers is None: + return + + # Collect all request types from executor @response_handler annotations + known_request_types: set[type] = set() + for executor in executors.values(): + if hasattr(executor, "_response_handlers"): + for request_type, _ in executor._response_handlers: # type: ignore + known_request_types.add(request_type) + + # Check each handler key against known request types + for handler_type in self._request_handlers: + if handler_type not in known_request_types: + known_names = sorted(t.__name__ for t in known_request_types) if known_request_types else ["(none)"] + raise ValueError( + f"request_handlers key {handler_type.__name__} does not match any " + f"@response_handler request type declared in workflow executors. " + f"Known request types: {known_names}" + ) + def build(self) -> Workflow: """Build and return the constructed workflow. @@ -1100,6 +1134,10 @@ async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None: output_executors, ) + # Validate request_handlers against executor @response_handler annotations + if self._request_handlers is not None: + self._validate_request_handlers(executors) + # Add validation completed event span.add_event(OtelAttr.BUILD_VALIDATION_COMPLETED) @@ -1115,6 +1153,7 @@ async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None: name=self._name, description=self._description, output_executors=output_executors, + request_handlers=self._request_handlers, ) build_attributes: dict[str, Any] = { OtelAttr.WORKFLOW_ID: workflow.id, diff --git a/python/packages/core/tests/workflow/test_workflow_response_handlers.py b/python/packages/core/tests/workflow/test_workflow_response_handlers.py index 1da3dd6b8a..6babffdca2 100644 --- a/python/packages/core/tests/workflow/test_workflow_response_handlers.py +++ b/python/packages/core/tests/workflow/test_workflow_response_handlers.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. """ -Tests for workflow response_handlers parameter. +Tests for workflow request_handlers parameter. Verifies automatic HITL request handling: type-based dispatch, concurrent execution, error handling, parameter validation, and inline handler dispatch. @@ -193,7 +193,7 @@ async def handle_review(request: ReviewRequest) -> str: result = await workflow.run( "test_msg", - response_handlers={ReviewRequest: handle_review}, + request_handlers={ReviewRequest: handle_review}, ) assert reviewer.feedback_received is True @@ -235,7 +235,7 @@ async def handle_approval(request: ApprovalRequest) -> str: result = await workflow.run( "multi_test", - response_handlers={ + request_handlers={ ReviewRequest: handle_review, ApprovalRequest: handle_approval, }, @@ -257,7 +257,7 @@ async def handle_returning_none(request: ReviewRequest) -> None: result = await workflow.run( "test_none", - response_handlers={ReviewRequest: handle_returning_none}, + request_handlers={ReviewRequest: handle_returning_none}, ) # The handler matched and returned None — that None should be submitted @@ -276,7 +276,7 @@ async def failing_handler(request: ReviewRequest) -> str: with caplog.at_level(logging.ERROR): result = await workflow.run( "fail_test", - response_handlers={ReviewRequest: failing_handler}, + request_handlers={ReviewRequest: failing_handler}, ) assert reviewer.feedback_received is False @@ -296,7 +296,7 @@ async def handle_review(request: ReviewRequest) -> str: with caplog.at_level(logging.WARNING): result = await workflow.run( "unknown_test", - response_handlers={ReviewRequest: handle_review}, + request_handlers={ReviewRequest: handle_review}, ) assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS @@ -304,7 +304,7 @@ async def handle_review(request: ReviewRequest) -> str: assert "ReviewRequest" in caplog.text async def test_stream_true_with_handlers_works(self): - """stream=True + response_handlers works with inline dispatch.""" + """stream=True + request_handlers works with inline dispatch.""" reviewer = ReviewerExecutor() collector = CollectorExecutor() workflow = WorkflowBuilder(start_executor=reviewer).add_edge(reviewer, collector).build() @@ -316,7 +316,7 @@ async def handle_review(request: ReviewRequest) -> str: stream = workflow.run( "stream_test", stream=True, - response_handlers={ReviewRequest: handle_review}, + request_handlers={ReviewRequest: handle_review}, ) async for event in stream: events.append(event) @@ -328,17 +328,17 @@ async def handle_review(request: ReviewRequest) -> str: assert any("review_done:stream_lgtm" in msg for msg in collector.collected) async def test_responses_with_handlers_raises(self): - """response_handlers + responses raises ValueError immediately.""" + """request_handlers + responses raises ValueError immediately.""" reviewer = ReviewerExecutor() workflow = WorkflowBuilder(start_executor=reviewer).build() async def handle_review(request: ReviewRequest) -> str: return "x" - with pytest.raises(ValueError, match="response_handlers.*responses"): + with pytest.raises(ValueError, match="request_handlers.*responses"): await workflow.run( responses={"some_id": "value"}, - response_handlers={ReviewRequest: handle_review}, + request_handlers={ReviewRequest: handle_review}, ) async def test_no_request_info_handlers_unused(self): @@ -352,18 +352,18 @@ async def handle_review(request: ReviewRequest) -> str: result = await workflow.run( "hello", - response_handlers={ReviewRequest: handle_review}, + request_handlers={ReviewRequest: handle_review}, ) assert result.get_final_state() == WorkflowRunState.IDLE assert len(result.get_request_info_events()) == 0 async def test_none_handlers_preserves_behavior(self): - """response_handlers=None preserves original IDLE_WITH_PENDING_REQUESTS state.""" + """request_handlers=None preserves original IDLE_WITH_PENDING_REQUESTS state.""" reviewer = ReviewerExecutor() workflow = WorkflowBuilder(start_executor=reviewer).build() - result = await workflow.run("test_default", response_handlers=None) + result = await workflow.run("test_default", request_handlers=None) assert reviewer.feedback_received is False assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS @@ -400,7 +400,7 @@ async def slow_approval(request: ApprovalRequest) -> str: start = time.monotonic() await workflow.run( "concurrent_test", - response_handlers={ + request_handlers={ ReviewRequest: slow_review, ApprovalRequest: slow_approval, }, @@ -422,7 +422,7 @@ async def handle_review(request: ReviewRequest) -> str: result = await workflow.run( "merge_test", - response_handlers={ReviewRequest: handle_review}, + request_handlers={ReviewRequest: handle_review}, ) # Should have produced request_info events @@ -507,7 +507,7 @@ async def handle_review(request: ReviewRequest) -> str: result = await workflow.run( "fan_out_test", - response_handlers={ReviewRequest: handle_review}, + request_handlers={ReviewRequest: handle_review}, ) # Looping branch completed its iterations @@ -519,3 +519,126 @@ async def handle_review(request: ReviewRequest) -> str: # Final state should be IDLE (all branches complete, all requests handled) assert result.get_final_state() == WorkflowRunState.IDLE + + +class TestBuilderLevelResponseHandlers: + """Tests for request_handlers passed via WorkflowBuilder constructor.""" + + async def test_builder_handlers_used_as_default(self): + """Builder-level request_handlers are used when run() has none.""" + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + + async def handle_review(request: ReviewRequest) -> str: + return "builder_lgtm" + + workflow = ( + WorkflowBuilder( + start_executor=reviewer, + request_handlers={ReviewRequest: handle_review}, + ) + .add_edge(reviewer, collector) + .build() + ) + + result = await workflow.run("test_builder") + + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "builder_lgtm" + assert result.get_final_state() == WorkflowRunState.IDLE + assert any("review_done:builder_lgtm" in msg for msg in collector.collected) + + async def test_run_level_overrides_builder(self): + """run(request_handlers=...) overrides builder-level handlers.""" + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + + async def builder_handler(request: ReviewRequest) -> str: + return "from_builder" + + async def run_handler(request: ReviewRequest) -> str: + return "from_run" + + workflow = ( + WorkflowBuilder( + start_executor=reviewer, + request_handlers={ReviewRequest: builder_handler}, + ) + .add_edge(reviewer, collector) + .build() + ) + + result = await workflow.run( + "test_override", + request_handlers={ReviewRequest: run_handler}, + ) + + assert reviewer.feedback_value == "from_run" + assert result.get_final_state() == WorkflowRunState.IDLE + + async def test_validation_rejects_unknown_type(self): + """Builder rejects request_handlers with types not in any @response_handler.""" + reviewer = ReviewerExecutor() + + async def handle_unknown(request: UnknownRequest) -> str: + return "x" + + with pytest.raises(ValueError, match="UnknownRequest.*does not match"): + WorkflowBuilder( + start_executor=reviewer, + request_handlers={UnknownRequest: handle_unknown}, + ).build() + + async def test_validation_passes_for_known_types(self): + """Builder accepts request_handlers matching executor @response_handler types.""" + reviewer = ReviewerExecutor() + approver = ApproverExecutor() + + async def handle_review(request: ReviewRequest) -> str: + return "ok" + + async def handle_approval(request: ApprovalRequest) -> str: + return "approved" + + # Should not raise — both types are declared via @response_handler on executors + class Dispatcher(Executor): + def __init__(self): + super().__init__("dispatcher") + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + await ctx.send_message(message) + + dispatcher = Dispatcher() + workflow = ( + WorkflowBuilder( + start_executor=dispatcher, + request_handlers={ + ReviewRequest: handle_review, + ApprovalRequest: handle_approval, + }, + ) + .add_edge(dispatcher, reviewer) + .add_edge(dispatcher, approver) + .build() + ) + assert workflow is not None + + async def test_validation_partial_mismatch_rejected(self): + """Builder rejects if any key is unknown, even if others are valid.""" + reviewer = ReviewerExecutor() + + async def handle_review(request: ReviewRequest) -> str: + return "ok" + + async def handle_unknown(request: UnknownRequest) -> str: + return "x" + + with pytest.raises(ValueError, match="UnknownRequest.*does not match"): + WorkflowBuilder( + start_executor=reviewer, + request_handlers={ + ReviewRequest: handle_review, + UnknownRequest: handle_unknown, + }, + ).build() diff --git a/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py b/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py index babb2d749a..bbade2d453 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py @@ -1,23 +1,23 @@ # Copyright (c) Microsoft. All rights reserved. """ -Sample: Fan-out Async + HITL with response_handlers parameter +Sample: Fan-out Async + HITL with request_handlers parameter Demonstrates automatic HITL request handling in fan-out workflows using the -response_handlers parameter. Handlers are dispatched inline as asyncio tasks +request_handlers parameter. Handlers are dispatched inline as asyncio tasks when request_info events are emitted during execution. The runner waits for outstanding handler tasks before declaring convergence, so handler responses are processed in subsequent supersteps within the same workflow run. Usage: - response_handlers = { + request_handlers = { ReviewRequest: handle_review, ApprovalRequest: handle_approval, } result = await workflow.run( initial_data, - response_handlers=response_handlers, + request_handlers=request_handlers, ) """ @@ -176,7 +176,7 @@ async def aggregate( # ============================================================================ # Response Handlers (External) # ============================================================================ -# These are registered via response_handlers dict, not as executor methods. +# These are registered via request_handlers dict, not as executor methods. # They are dispatched inline as asyncio tasks when request_info events are emitted. @@ -205,7 +205,7 @@ async def handle_review(request: ReviewRequest) -> str: async def main() -> None: print("=" * 80) - print("HITL: response_handlers parameter in workflow.run()") + print("HITL: request_handlers parameter in workflow.run()") print("=" * 80) # Create executors @@ -234,7 +234,7 @@ async def main() -> None: # THE CLEAN API # ======================================================================== # Define external response handlers (type-based dispatch) - response_handlers = { + request_handlers = { ReviewRequest: handle_review, # Can add more: ApprovalRequest: handle_approval, etc. } @@ -249,7 +249,7 @@ async def main() -> None: # Responses are injected back and processed in subsequent supersteps result = await workflow.run( initial_data, - response_handlers=response_handlers, + request_handlers=request_handlers, ) elapsed = time.monotonic() - _start_time From fbd500bcae90e186c9c93558506db0ed7b562fd4 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 11 Feb 2026 18:12:35 +0900 Subject: [PATCH 4/7] Rename sample --- python/samples/getting_started/workflows/README.md | 2 +- ...ponse_handlers.py => fan_out_async_with_request_handlers.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename python/samples/getting_started/workflows/human-in-the-loop/{fan_out_async_with_response_handlers.py => fan_out_async_with_request_handlers.py} (100%) diff --git a/python/samples/getting_started/workflows/README.md b/python/samples/getting_started/workflows/README.md index f27c058189..6b9004c105 100644 --- a/python/samples/getting_started/workflows/README.md +++ b/python/samples/getting_started/workflows/README.md @@ -88,7 +88,7 @@ Once comfortable with these, explore the rest of the samples below. | SequentialBuilder Request Info | [human-in-the-loop/sequential_request_info.py](./human-in-the-loop/sequential_request_info.py) | Request info for agent responses mid-workflow using `.with_request_info()` on SequentialBuilder | | ConcurrentBuilder Request Info | [human-in-the-loop/concurrent_request_info.py](./human-in-the-loop/concurrent_request_info.py) | Review concurrent agent outputs before aggregation using `.with_request_info()` on ConcurrentBuilder | | GroupChatBuilder Request Info | [human-in-the-loop/group_chat_request_info.py](./human-in-the-loop/group_chat_request_info.py) | Steer group discussions with periodic guidance using `.with_request_info()` on GroupChatBuilder | -| Fan-Out Async with Response Handlers | [human-in-the-loop/fan_out_async_with_response_handlers.py](./human-in-the-loop/fan_out_async_with_response_handlers.py) | Type-based response handler dispatch for HITL requests in fan-out workflows; automatic async handler execution and response submission | +| Fan-Out Async with Request Handlers | [human-in-the-loop/fan_out_async_with_request_handlers.py](./human-in-the-loop/fan_out_async_with_request_handlers.py) | Type-based request handler dispatch for HITL requests in fan-out workflows; automatic async handler execution and response submission | ### tool-approval diff --git a/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py b/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_request_handlers.py similarity index 100% rename from python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_response_handlers.py rename to python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_request_handlers.py From 8b1c9ad062e2d66c4f50f2b7474a64a9194e6566 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 11 Feb 2026 18:57:44 +0900 Subject: [PATCH 5/7] Address copilot feedback --- .../agent_framework/_workflows/_workflow.py | 10 ++++-- .../_workflows/_workflow_builder.py | 24 ++++++++++--- .../test_workflow_response_handlers.py | 34 +++++++++++++++++++ 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index e3501799eb..304891c21a 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -528,8 +528,14 @@ def run( Raises: ValueError: If parameter combination is invalid. """ - # Fall back to builder-level request_handlers if none provided at runtime - effective_handlers = request_handlers if request_handlers is not None else self._request_handlers + # Fall back to builder-level request_handlers if none provided at runtime. + # Skip fallback when responses are provided — manual submission bypasses automatic handling. + if request_handlers is not None: + effective_handlers = request_handlers + elif responses is not None: + effective_handlers = None + else: + effective_handlers = self._request_handlers # Validate parameters eagerly (before any async work or setting running flag) self._validate_run_params(message, responses, checkpoint_id, effective_handlers) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index f9179cb3ec..eb68c8329e 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -586,29 +586,43 @@ def _set_start_executor(self, executor: Executor | SupportsAgentRun) -> None: if existing is not wrapped: self._add_executor(wrapped) + @staticmethod + def _type_name(t: type) -> str: + """Return a human-readable name for a type, handling UnionType safely.""" + return getattr(t, "__name__", str(t)) + def _validate_request_handlers(self, executors: dict[str, Executor]) -> None: """Validate that request_handlers keys match executor @response_handler request types. Collects all request types declared via @response_handler decorators across all executors, then checks that each key in self._request_handlers maps to a known - request type. Raises ValueError if an unrecognized type is found. + request type. Union types in executor annotations are expanded into their + constituent types. Raises ValueError if an unrecognized type is found. """ if self._request_handlers is None: return - # Collect all request types from executor @response_handler annotations + # Collect all request types from executor @response_handler annotations. + # Expand union types (e.g. str | int) into their constituent types so that + # request_handlers keys can match individual members. known_request_types: set[type] = set() for executor in executors.values(): if hasattr(executor, "_response_handlers"): for request_type, _ in executor._response_handlers: # type: ignore - known_request_types.add(request_type) + if hasattr(request_type, "__args__"): + # UnionType — expand into constituent types + known_request_types.update(request_type.__args__) + else: + known_request_types.add(request_type) # Check each handler key against known request types for handler_type in self._request_handlers: if handler_type not in known_request_types: - known_names = sorted(t.__name__ for t in known_request_types) if known_request_types else ["(none)"] + known_names = ( + sorted(self._type_name(t) for t in known_request_types) if known_request_types else ["(none)"] + ) raise ValueError( - f"request_handlers key {handler_type.__name__} does not match any " + f"request_handlers key {self._type_name(handler_type)} does not match any " f"@response_handler request type declared in workflow executors. " f"Known request types: {known_names}" ) diff --git a/python/packages/core/tests/workflow/test_workflow_response_handlers.py b/python/packages/core/tests/workflow/test_workflow_response_handlers.py index 6babffdca2..f6a270fafe 100644 --- a/python/packages/core/tests/workflow/test_workflow_response_handlers.py +++ b/python/packages/core/tests/workflow/test_workflow_response_handlers.py @@ -642,3 +642,37 @@ async def handle_unknown(request: UnknownRequest) -> str: UnknownRequest: handle_unknown, }, ).build() + + async def test_responses_bypass_builder_handlers(self): + """run(responses=...) works even when builder has default request_handlers. + + Regression: without the fix, builder-level handlers would be treated as + effective_handlers, causing _validate_run_params to raise ValueError + ('Cannot provide both request_handlers and responses'). + """ + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + + async def failing_handler(request: ReviewRequest) -> str: + raise RuntimeError("intentional failure") + + # Build with handlers that always fail — first run ends with pending request + workflow = ( + WorkflowBuilder( + start_executor=reviewer, + request_handlers={ReviewRequest: failing_handler}, + ) + .add_edge(reviewer, collector) + .build() + ) + + result1 = await workflow.run("first") + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + request_events = result1.get_request_info_events() + assert len(request_events) == 1 + + # Manual response submission should NOT conflict with builder-level handlers + result2 = await workflow.run(responses={request_events[0].request_id: "manual_response"}) + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "manual_response" + assert result2.get_final_state() == WorkflowRunState.IDLE From 8966dfcbdb9fcab7bb9286267af5801971968cb9 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 12 Feb 2026 07:13:09 +0900 Subject: [PATCH 6/7] restores the pending request if response submission fails --- .../_workflows/_runner_context.py | 41 +++++++++++-------- .../test_workflow_response_handlers.py | 29 +++++++++++++ 2 files changed, 53 insertions(+), 17 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index db6558306a..caa6c21aa2 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -494,25 +494,32 @@ async def send_request_info_response(self, request_id: str, response: Any) -> No if not event: raise ValueError(f"No pending request found for request_id: {request_id}") - # Validate response type if specified - if event.response_type and not is_instance_of(response, event.response_type): - raise TypeError( - f"Response type mismatch for request_id {request_id}: " - f"expected {event.response_type.__name__}, got {type(response).__name__}" + try: + # Validate response type if specified + if event.response_type and not is_instance_of(response, event.response_type): + expected_type_name = getattr(event.response_type, "__name__", str(event.response_type)) + raise TypeError( + f"Response type mismatch for request_id {request_id}: " + f"expected {expected_type_name}, got {type(response).__name__}" + ) + + source_executor_id = event.source_executor_id + + # Create ResponseMessage instance + response_msg = WorkflowMessage( + data=response, + source_id=INTERNAL_SOURCE_ID(source_executor_id), + target_id=source_executor_id, + type=MessageType.RESPONSE, + original_request_info_event=event, ) - source_executor_id = event.source_executor_id - - # Create ResponseMessage instance - response_msg = WorkflowMessage( - data=response, - source_id=INTERNAL_SOURCE_ID(source_executor_id), - target_id=source_executor_id, - type=MessageType.RESPONSE, - original_request_info_event=event, - ) - - await self.send_message(response_msg) + await self.send_message(response_msg) + except BaseException: + # Restore pending request so callers can retry (for example, type mismatch + # from an async request handler or cancellation while submitting). + self._pending_request_info_events[request_id] = event + raise async def get_pending_request_info_events(self) -> dict[str, WorkflowEvent[Any]]: """Get the mapping of request IDs to their corresponding request_info events. diff --git a/python/packages/core/tests/workflow/test_workflow_response_handlers.py b/python/packages/core/tests/workflow/test_workflow_response_handlers.py index f6a270fafe..90250fe5f1 100644 --- a/python/packages/core/tests/workflow/test_workflow_response_handlers.py +++ b/python/packages/core/tests/workflow/test_workflow_response_handlers.py @@ -284,6 +284,35 @@ async def failing_handler(request: ReviewRequest) -> str: assert "handler boom" in caplog.text assert "ReviewRequest" in caplog.text + async def test_handler_type_mismatch_leaves_request_pending(self, caplog): + """A handler returning wrong type should not drop the pending request.""" + reviewer = ReviewerExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).build() + + async def wrong_type_handler(request: ReviewRequest) -> int: + return 123 + + with caplog.at_level(logging.ERROR): + result1 = await workflow.run( + "bad_type_test", + request_handlers={ReviewRequest: wrong_type_handler}, + ) + + assert reviewer.feedback_received is False + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + assert "Response type mismatch" in caplog.text + + # Pending request should still be recoverable via manual response submission. + request_events = result1.get_request_info_events() + assert len(request_events) == 1 + + result2 = await workflow.run( + responses={request_events[0].request_id: "manual_retry"}, + ) + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "manual_retry" + assert result2.get_final_state() == WorkflowRunState.IDLE + async def test_unmatched_type_logs_warning(self, caplog): """A request type with no matching handler logs a warning and stays pending.""" unknown_exec = UnknownRequestExecutor() From f016b2eb2009af3386219a3f34ab3ea55df2bb69 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 12 Feb 2026 07:55:39 +0900 Subject: [PATCH 7/7] PR feedback --- .../agent_framework/_workflows/_runner.py | 111 ++++++++++++++++-- .../agent_framework/_workflows/_workflow.py | 95 +++------------ .../core/tests/workflow/test_runner.py | 59 +++++++++- 3 files changed, 174 insertions(+), 91 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 316e7630b6..ce9353e4ac 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -4,8 +4,8 @@ import contextlib import logging from collections import defaultdict -from collections.abc import AsyncGenerator, Sequence -from typing import Any +from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping, Sequence +from typing import Any, cast from ._checkpoint import CheckpointStorage, WorkflowCheckpoint from ._checkpoint_encoding import ( @@ -31,6 +31,9 @@ logger = logging.getLogger(__name__) +RequestHandler = Callable[[Any], Awaitable[Any]] +RequestHandlerMap = Mapping[type[Any], RequestHandler] + class Runner: """A class to run a workflow in Pregel supersteps.""" @@ -79,20 +82,95 @@ def reset_iteration_count(self) -> None: """Reset the iteration count to zero.""" self._iteration = 0 + @staticmethod + def _type_name(t: type[Any]) -> str: + """Return a human-readable type name, handling union/generic aliases.""" + return getattr(t, "__name__", str(t)) + + async def _invoke_request_handler( + self, + handler_fn: RequestHandler, + request_id: str, + request_data: Any, + ) -> None: + """Invoke an async request handler and submit the response to the context.""" + try: + response = await handler_fn(request_data) + except Exception: + logger.exception(f"Response handler failed for request {request_id} (type={type(request_data).__name__})") + return + + try: + await self._ctx.send_request_info_response(request_id, response) + except Exception: + logger.exception( + f"Failed to submit response for request {request_id} " + f"(type={type(request_data).__name__}). " + f"Handler succeeded but response could not be delivered." + ) + + def _dispatch_request_handler( + self, + event: WorkflowEvent[Any], + request_handlers: RequestHandlerMap | None, + outstanding_handler_tasks: set[asyncio.Task[None]] | None, + ) -> None: + """Schedule a request handler task for request_info events when configured.""" + if event.type != "request_info" or request_handlers is None or outstanding_handler_tasks is None: + return + + request_data: Any = event.data + request_data_type: type[Any] = cast(type[Any], type(request_data)) + matched_handler = request_handlers.get(request_data_type) + if matched_handler is None: + registered_types = [self._type_name(t) for t in request_handlers] + logger.warning( + f"No response handler registered for request type " + f"{self._type_name(request_data_type)} (request_id={event.request_id}). " + f"Registered types: {registered_types}" + ) + return + + task = asyncio.create_task(self._invoke_request_handler(matched_handler, event.request_id, request_data)) + outstanding_handler_tasks.add(task) + task.add_done_callback(outstanding_handler_tasks.discard) + + async def _cancel_outstanding_handler_tasks( + self, + outstanding_handler_tasks: set[asyncio.Task[None]] | None, + ) -> None: + """Cancel and await outstanding handler tasks.""" + if not outstanding_handler_tasks: + return + + for task in outstanding_handler_tasks: + if not task.done(): + task.cancel() + await asyncio.gather(*outstanding_handler_tasks, return_exceptions=True) + async def run_until_convergence( self, - outstanding_handler_tasks: set[asyncio.Task[None]] | None = None, + request_handlers: RequestHandlerMap | None = None, ) -> AsyncGenerator[WorkflowEvent, None]: - """Run the workflow until no more messages are sent.""" + """Run the workflow until no more messages are sent. + + Args: + request_handlers: Optional mapping of request data types to async handler + functions. When provided, request_info events are dispatched inline and + handler responses are submitted back through the runner context. + """ if self._running: raise WorkflowRunnerException("Runner is already running.") + outstanding_handler_tasks: set[asyncio.Task[None]] | None = set() if request_handlers else None + self._running = True try: # Emit any events already produced prior to entering loop if await self._ctx.has_events(): logger.info("Yielding pre-loop events") for event in await self._ctx.drain_events(): + self._dispatch_request_handler(event, request_handlers, outstanding_handler_tasks) yield event # Create first checkpoint if there are messages from initial execution @@ -115,6 +193,7 @@ async def run_until_convergence( try: # Wait briefly for any new event; timeout allows progress checks event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05) + self._dispatch_request_handler(event, request_handlers, outstanding_handler_tasks) yield event except asyncio.TimeoutError: # Periodically continue to let iteration advance @@ -124,12 +203,6 @@ async def run_until_convergence( iteration_task.cancel() with contextlib.suppress(asyncio.CancelledError): await iteration_task - if outstanding_handler_tasks: - for t in outstanding_handler_tasks: - if not t.done(): - t.cancel() - with contextlib.suppress(asyncio.CancelledError): - await asyncio.gather(*outstanding_handler_tasks, return_exceptions=True) raise # Propagate errors from iteration, but first surface any pending events @@ -139,6 +212,7 @@ async def run_until_convergence( # Make sure failure-related events (like ExecutorFailedEvent) are surfaced if await self._ctx.has_events(): for event in await self._ctx.drain_events(): + self._dispatch_request_handler(event, request_handlers, outstanding_handler_tasks) yield event raise self._iteration += 1 @@ -146,6 +220,7 @@ async def run_until_convergence( # Drain any straggler events emitted at tail end if await self._ctx.has_events(): for event in await self._ctx.drain_events(): + self._dispatch_request_handler(event, request_handlers, outstanding_handler_tasks) yield event logger.info(f"Completed superstep {self._iteration}") @@ -160,10 +235,19 @@ async def run_until_convergence( # Check for convergence: no more messages to process if not await self._ctx.has_messages(): - if outstanding_handler_tasks: - still_running = {t for t in outstanding_handler_tasks if not t.done()} + if outstanding_handler_tasks is not None: + # Wait for in-flight request handlers before declaring idle. + # A completed handler may enqueue a response message that should + # be processed in another superstep within this same run. + still_running: set[asyncio.Task[None]] = { + task for task in outstanding_handler_tasks if not task.done() + } while still_running: - _, still_running = await asyncio.wait(still_running, return_when=asyncio.FIRST_COMPLETED) + wait_result: tuple[set[asyncio.Task[None]], set[asyncio.Task[None]]] = await asyncio.wait( + still_running, + return_when=asyncio.FIRST_COMPLETED, + ) + still_running = wait_result[1] if await self._ctx.has_messages(): break if await self._ctx.has_messages(): @@ -176,6 +260,7 @@ async def run_until_convergence( logger.info(f"Workflow completed after {self._iteration} supersteps") self._resumed_from_checkpoint = False # Reset resume flag for next run finally: + await self._cancel_outstanding_handler_tasks(outstanding_handler_tasks) self._running = False async def _run_iteration(self) -> None: diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 304891c21a..944f5d8548 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -6,7 +6,6 @@ import functools import hashlib import json -import logging import types import uuid from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence @@ -34,8 +33,6 @@ from ._state import State from ._typing_utils import is_instance_of, try_coerce_to_type -logger = logging.getLogger(__name__) - class WorkflowRunResult(list[WorkflowEvent]): """Container for events generated during non-streaming workflow execution. @@ -303,7 +300,7 @@ async def _run_workflow_with_tracing( reset_context: bool = True, streaming: bool = False, run_kwargs: dict[str, Any] | None = None, - outstanding_handler_tasks: set[asyncio.Task[None]] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Private method to run workflow with proper tracing. @@ -315,7 +312,7 @@ async def _run_workflow_with_tracing( reset_context: Whether to reset the context for a new run streaming: Whether to enable streaming mode for agents run_kwargs: Optional kwargs to store in State for agent invocations - outstanding_handler_tasks: Optional set to track outstanding response handler tasks + request_handlers: Optional request handlers for automatic request_info responses. Yields: WorkflowEvent: The events generated during the workflow execution. @@ -363,7 +360,7 @@ async def _run_workflow_with_tracing( # All executor executions happen within workflow span async for event in self._runner.run_until_convergence( - outstanding_handler_tasks=outstanding_handler_tasks, + request_handlers=request_handlers, ): yield event @@ -561,34 +558,6 @@ def run( return response_stream return response_stream.get_final_response() - async def _invoke_request_handler( - self, - handler_fn: Callable[[Any], Awaitable[Any]], - request_id: str, - request_data: Any, - ) -> None: - """Invoke a response handler and submit the result back to the workflow. - - On success, the response is submitted via send_request_info_response which - pops the pending request and injects a response message for the next superstep. - On failure, the request stays pending and the workflow converges as - IDLE_WITH_PENDING_REQUESTS. - """ - try: - response = await handler_fn(request_data) - except Exception: - logger.exception(f"Response handler failed for request {request_id} (type={type(request_data).__name__})") - return - - try: - await self._runner_context.send_request_info_response(request_id, response) - except Exception: - logger.exception( - f"Failed to submit response for request {request_id} " - f"(type={type(request_data).__name__}). " - f"Handler succeeded but response could not be delivered." - ) - async def _run_core( self, message: Any | None = None, @@ -602,10 +571,11 @@ async def _run_core( ) -> AsyncIterable[WorkflowEvent]: """Single core execution path for both streaming and non-streaming modes. - When request_handlers are provided, handlers are dispatched inline as - asyncio tasks when request_info events are emitted. The runner's convergence - check waits for outstanding handler tasks before deciding the workflow is idle, - allowing handler responses to be processed in subsequent supersteps. + When request_handlers are provided, the runner dispatches handlers inline + as asyncio tasks when request_info events are emitted. The runner's + convergence check waits for outstanding handler tasks before deciding the + workflow is idle, allowing handler responses to be processed in subsequent + supersteps. Yields: WorkflowEvent: The events generated during the workflow execution. @@ -618,45 +588,16 @@ async def _run_core( message, responses, checkpoint_id, checkpoint_storage ) - outstanding_tasks: set[asyncio.Task[None]] | None = set() if request_handlers else None - - try: - async for event in self._run_workflow_with_tracing( - initial_executor_fn=initial_executor_fn, - reset_context=reset_context, - streaming=streaming, - run_kwargs=kwargs if kwargs else None, - outstanding_handler_tasks=outstanding_tasks, - ): - if event.type == "output" and not self._should_yield_output_event(event): - continue - - # Dispatch response handlers inline as tasks - if event.type == "request_info" and request_handlers is not None and outstanding_tasks is not None: - request_data_type = type(event.data) - matched_handler = request_handlers.get(request_data_type) - if matched_handler: - task = asyncio.create_task( - self._invoke_request_handler(matched_handler, event.request_id, event.data) - ) - outstanding_tasks.add(task) - task.add_done_callback(outstanding_tasks.discard) - else: - registered_types = [t.__name__ for t in request_handlers] - logger.warning( - f"No response handler registered for request type " - f"{request_data_type.__name__} (request_id={event.request_id}). " - f"Registered types: {registered_types}" - ) - - yield event - finally: - # Cancel any outstanding handler tasks on error/cancellation - if outstanding_tasks: - for t in outstanding_tasks: - if not t.done(): - t.cancel() - await asyncio.gather(*outstanding_tasks, return_exceptions=True) + async for event in self._run_workflow_with_tracing( + initial_executor_fn=initial_executor_fn, + reset_context=reset_context, + streaming=streaming, + run_kwargs=kwargs if kwargs else None, + request_handlers=request_handlers, + ): + if event.type == "output" and not self._should_yield_output_event(event): + continue + yield event async def _run_cleanup(self, checkpoint_storage: CheckpointStorage | None) -> None: """Cleanup hook called after stream consumption.""" diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index e527ba13fa..8c5dcb6b6a 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -15,8 +15,9 @@ WorkflowRunnerException, WorkflowRunState, handler, + response_handler, ) -from agent_framework._workflows._edge import SingleEdgeGroup +from agent_framework._workflows._edge import InternalEdgeGroup, SingleEdgeGroup from agent_framework._workflows._runner import Runner from agent_framework._workflows._runner_context import ( InProcRunnerContext, @@ -33,6 +34,13 @@ class MockMessage: data: int +@dataclass +class MockRequest: + """A mock request_info payload.""" + + value: str + + class MockExecutor(Executor): """A mock executor for testing purposes.""" @@ -211,6 +219,28 @@ async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, i await ctx.yield_output(message.data) +class RequestInfoExecutor(Executor): + """Executor that requests external input and yields it once received.""" + + def __init__(self, id: str = "request_info_executor"): + super().__init__(id=id) + self.received_response: str | None = None + + @handler + async def handle(self, message: MockMessage, ctx: WorkflowContext[int]) -> None: + await ctx.request_info(MockRequest(value=f"request:{message.data}"), str) + + @response_handler + async def on_response( + self, + original_request: MockRequest, + response: str, + ctx: WorkflowContext[int, int], + ) -> None: + self.received_response = response + await ctx.yield_output(len(response)) + + async def test_runner_cancellation_stops_active_executor(): """Test that cancelling a workflow properly cancels the active executor.""" executor_a = SlowExecutor(id="executor_a", work_duration=0.3) @@ -259,3 +289,30 @@ async def run_workflow(): assert executor_a.completed_count == 1 assert executor_b.started_count == 1 assert executor_b.completed_count == 0 # Should NOT have completed due to cancellation + + +async def test_runner_dispatches_request_handlers_inline(): + """Runner should dispatch request_handlers and process responses in the same run.""" + executor = RequestInfoExecutor() + state = State() + ctx = InProcRunnerContext() + runner = Runner([InternalEdgeGroup(executor.id)], {executor.id: executor}, state, ctx) + + await executor.execute( + MockMessage(data=1), + ["START"], + state, + ctx, + ) + + async def handle_request(request: MockRequest) -> str: + await asyncio.sleep(0.05) + return "approved" + + result: int | None = None + async for event in runner.run_until_convergence(request_handlers={MockRequest: handle_request}): + if event.type == "output": + result = event.data + + assert executor.received_response == "approved" + assert result == len("approved")