diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 3958957596..63f8373244 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -138,3 +138,7 @@ class MyClient(BaseChatClient): async def _inner_get_streaming_response(self, *, messages, options, **kwargs): yield ChatResponseUpdate(...) ``` + +### Running tests + +Always use the virtual environment defined, and run with `uv run pytest ...` \ No newline at end of file diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 88281597a2..04aea23e5c 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -4,8 +4,8 @@ import contextlib import logging from collections import defaultdict -from collections.abc import AsyncGenerator, Sequence -from typing import Any +from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping, Sequence +from typing import Any, cast from ._checkpoint import CheckpointID, CheckpointStorage, WorkflowCheckpoint from ._const import EXECUTOR_STATE_KEY @@ -26,6 +26,9 @@ logger = logging.getLogger(__name__) +RequestHandler = Callable[[Any], Awaitable[Any]] +RequestHandlerMap = Mapping[type[Any], RequestHandler] + class Runner: """A class to run a workflow in Pregel supersteps.""" @@ -75,11 +78,88 @@ def reset_iteration_count(self) -> None: """Reset the iteration count to zero.""" self._iteration = 0 - async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: - """Run the workflow until no more messages are sent.""" + @staticmethod + def _type_name(t: type[Any]) -> str: + """Return a human-readable type name, handling union/generic aliases.""" + return getattr(t, "__name__", str(t)) + + async def _invoke_request_handler( + self, + handler_fn: RequestHandler, + request_id: str, + request_data: Any, + ) -> None: + """Invoke an async request 
handler and submit the response to the context.""" + try: + response = await handler_fn(request_data) + except Exception: + logger.exception(f"Response handler failed for request {request_id} (type={type(request_data).__name__})") + return + + try: + await self._ctx.send_request_info_response(request_id, response) + except Exception: + logger.exception( + f"Failed to submit response for request {request_id} " + f"(type={type(request_data).__name__}). " + f"Handler succeeded but response could not be delivered." + ) + + def _dispatch_request_handler( + self, + event: WorkflowEvent[Any], + request_handlers: RequestHandlerMap | None, + outstanding_handler_tasks: set[asyncio.Task[None]] | None, + ) -> None: + """Schedule a request handler task for request_info events when configured.""" + if event.type != "request_info" or request_handlers is None or outstanding_handler_tasks is None: + return + + request_data: Any = event.data + request_data_type: type[Any] = cast(type[Any], type(request_data)) + matched_handler = request_handlers.get(request_data_type) + if matched_handler is None: + registered_types = [self._type_name(t) for t in request_handlers] + logger.warning( + f"No response handler registered for request type " + f"{self._type_name(request_data_type)} (request_id={event.request_id}). 
" + f"Registered types: {registered_types}" + ) + return + + task = asyncio.create_task(self._invoke_request_handler(matched_handler, event.request_id, request_data)) + outstanding_handler_tasks.add(task) + task.add_done_callback(outstanding_handler_tasks.discard) + + async def _cancel_outstanding_handler_tasks( + self, + outstanding_handler_tasks: set[asyncio.Task[None]] | None, + ) -> None: + """Cancel and await outstanding handler tasks.""" + if not outstanding_handler_tasks: + return + + for task in outstanding_handler_tasks: + if not task.done(): + task.cancel() + await asyncio.gather(*outstanding_handler_tasks, return_exceptions=True) + + async def run_until_convergence( + self, + request_handlers: RequestHandlerMap | None = None, + ) -> AsyncGenerator[WorkflowEvent, None]: + """Run the workflow until no more messages are sent. + + Args: + request_handlers: Optional mapping of request data types to async handler + functions. When provided, request_info events are dispatched inline and + handler responses are submitted back through the runner context. + """ if self._running: raise WorkflowRunnerException("Runner is already running.") + outstanding_handler_tasks: set[asyncio.Task[None]] | None = set() if request_handlers else None + self._running = True previous_checkpoint_id: CheckpointID | None = None try: @@ -87,6 +167,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: if await self._ctx.has_events(): logger.info("Yielding pre-loop events") for event in await self._ctx.drain_events(): + self._dispatch_request_handler(event, request_handlers, outstanding_handler_tasks) yield event # Create the first checkpoint. 
Checkpoints are usually considered to be created at the end of an iteration, @@ -108,6 +189,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: try: # Wait briefly for any new event; timeout allows progress checks event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05) + self._dispatch_request_handler(event, request_handlers, outstanding_handler_tasks) yield event except asyncio.TimeoutError: # Periodically continue to let iteration advance @@ -126,6 +208,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: # Make sure failure-related events (like ExecutorFailedEvent) are surfaced if await self._ctx.has_events(): for event in await self._ctx.drain_events(): + self._dispatch_request_handler(event, request_handlers, outstanding_handler_tasks) yield event raise self._iteration += 1 @@ -133,6 +216,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: # Drain any straggler events emitted at tail end if await self._ctx.has_events(): for event in await self._ctx.drain_events(): + self._dispatch_request_handler(event, request_handlers, outstanding_handler_tasks) yield event logger.info(f"Completed superstep {self._iteration}") @@ -147,6 +231,23 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: # Check for convergence: no more messages to process if not await self._ctx.has_messages(): + if outstanding_handler_tasks is not None: + # Wait for in-flight request handlers before declaring idle. + # A completed handler may enqueue a response message that should + # be processed in another superstep within this same run. 
+ still_running: set[asyncio.Task[None]] = { + task for task in outstanding_handler_tasks if not task.done() + } + while still_running: + wait_result: tuple[set[asyncio.Task[None]], set[asyncio.Task[None]]] = await asyncio.wait( + still_running, + return_when=asyncio.FIRST_COMPLETED, + ) + still_running = wait_result[1] + if await self._ctx.has_messages(): + break + if await self._ctx.has_messages(): + continue break if self._iteration >= self._max_iterations and await self._ctx.has_messages(): @@ -155,6 +256,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: logger.info(f"Workflow completed after {self._iteration} supersteps") self._resumed_from_checkpoint = False # Reset resume flag for next run finally: + await self._cancel_outstanding_handler_tasks(outstanding_handler_tasks) self._running = False async def _run_iteration(self) -> None: diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index d52e135e91..9492f9ec12 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -453,25 +453,32 @@ async def send_request_info_response(self, request_id: str, response: Any) -> No if not event: raise ValueError(f"No pending request found for request_id: {request_id}") - # Validate response type if specified - if event.response_type and not is_instance_of(response, event.response_type): - raise TypeError( - f"Response type mismatch for request_id {request_id}: " - f"expected {event.response_type.__name__}, got {type(response).__name__}" + try: + # Validate response type if specified + if event.response_type and not is_instance_of(response, event.response_type): + expected_type_name = getattr(event.response_type, "__name__", str(event.response_type)) + raise TypeError( + f"Response type mismatch for request_id {request_id}: " + f"expected 
{expected_type_name}, got {type(response).__name__}" + ) + + source_executor_id = event.source_executor_id + + # Create ResponseMessage instance + response_msg = WorkflowMessage( + data=response, + source_id=INTERNAL_SOURCE_ID(source_executor_id), + target_id=source_executor_id, + type=MessageType.RESPONSE, + original_request_info_event=event, ) - source_executor_id = event.source_executor_id - - # Create ResponseMessage instance - response_msg = WorkflowMessage( - data=response, - source_id=INTERNAL_SOURCE_ID(source_executor_id), - target_id=source_executor_id, - type=MessageType.RESPONSE, - original_request_info_event=event, - ) - - await self.send_message(response_msg) + await self.send_message(response_msg) + except BaseException: + # Restore pending request so callers can retry (for example, type mismatch + # from an async request handler or cancellation while submitting). + self._pending_request_info_events[request_id] = event + raise async def get_pending_request_info_events(self) -> dict[str, WorkflowEvent[Any]]: """Get the mapping of request IDs to their corresponding request_info events. 
diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index cd7dbb4a68..34ae6a69cc 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -6,10 +6,9 @@ import functools import hashlib import json -import logging import types import uuid -from collections.abc import AsyncIterable, Awaitable, Callable, Sequence +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from typing import Any, Literal, overload from .._types import ResponseStream @@ -34,8 +33,6 @@ from ._state import State from ._typing_utils import is_instance_of, try_coerce_to_type -logger = logging.getLogger(__name__) - class WorkflowRunResult(list[WorkflowEvent]): """Container for events generated during non-streaming workflow execution. @@ -179,6 +176,7 @@ def __init__( description: str | None = None, max_iterations: int = DEFAULT_MAX_ITERATIONS, output_executors: list[str] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, **kwargs: Any, ): """Initialize the workflow with a list of edges. @@ -197,6 +195,8 @@ def __init__( WorkflowBuilder, this will be the description of the builder. output_executors: Optional list of executor IDs whose outputs will be considered workflow outputs. If None or empty, all executor outputs are treated as workflow outputs. + request_handlers: Optional default response handlers for automatic HITL request handling. + Can be overridden per-run via workflow.run(request_handlers=...). kwargs: Additional keyword arguments. Unused in this implementation. 
""" self.edge_groups = list(edge_groups) @@ -233,6 +233,9 @@ def __init__( # Flag to prevent concurrent workflow executions self._is_running = False + # Default response handlers (can be overridden per-run) + self._request_handlers = request_handlers + def _ensure_not_running(self) -> None: """Ensure the workflow is not already running.""" if self._is_running: @@ -300,6 +303,7 @@ async def _run_workflow_with_tracing( reset_context: bool = True, streaming: bool = False, run_kwargs: dict[str, Any] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Private method to run workflow with proper tracing. @@ -311,6 +315,7 @@ async def _run_workflow_with_tracing( reset_context: Whether to reset the context for a new run streaming: Whether to enable streaming mode for agents run_kwargs: Optional kwargs to store in State for agent invocations + request_handlers: Optional request handlers for automatic request_info responses. Yields: WorkflowEvent: The events generated during the workflow execution. 
@@ -326,7 +331,6 @@ async def _run_workflow_with_tracing( OtelAttr.WORKFLOW_RUN_SPAN, attributes, ) as span: - saw_request = False emitted_in_progress_pending = False try: # Add workflow started event (telemetry + surface state to consumers) @@ -358,10 +362,9 @@ async def _run_workflow_with_tracing( await initial_executor_fn() # All executor executions happen within workflow span - async for event in self._runner.run_until_convergence(): - # Track request events for final status determination - if event.type == "request_info": - saw_request = True + async for event in self._runner.run_until_convergence( + request_handlers=request_handlers, + ): yield event if event.type == "request_info" and not emitted_in_progress_pending: @@ -371,7 +374,8 @@ async def _run_workflow_with_tracing( yield pending_status # Workflow runs until idle - emit final status based on whether requests are pending - if saw_request: + pending = await self._runner_context.get_pending_request_info_events() + if pending: with _framework_event_origin(): terminal_status = WorkflowEvent.status(WorkflowRunState.IDLE_WITH_PENDING_REQUESTS) yield terminal_status @@ -455,6 +459,7 @@ def run( *, stream: Literal[True], responses: dict[str, Any] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, @@ -467,6 +472,7 @@ def run( *, stream: Literal[False] = ..., responses: dict[str, Any] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, @@ -479,6 +485,7 @@ def run( *, stream: bool = False, responses: dict[str, Any] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, 
include_status_events: bool = False, @@ -499,6 +506,13 @@ def run( request IDs and values are the corresponding response data. Mutually exclusive with message. Can be combined with checkpoint_id to restore a checkpoint and send responses in a single call. + request_handlers: Dict mapping request data types to async handler functions. + Mutually exclusive with ``responses``. Handlers are dispatched inline + as asyncio tasks when request_info events are emitted during execution. + The runner waits for outstanding handler tasks before declaring convergence, + so handler responses are processed in subsequent supersteps within the + same workflow run. If None, falls back to handlers set at build time via + WorkflowBuilder. Pass an empty dict to explicitly disable handlers for a run. checkpoint_id: ID of checkpoint to restore from. Can be used alone (resume from checkpoint), with message (not allowed), or with responses (restore then send responses). @@ -514,14 +528,24 @@ def run( Raises: ValueError: If parameter combination is invalid. """ - # Validate parameters and set running flag eagerly (before any async work) - self._validate_run_params(message, responses, checkpoint_id) + # Fall back to builder-level request_handlers if none provided at runtime. + # Skip fallback when responses are provided — manual submission bypasses automatic handling. 
+ if request_handlers is not None: + effective_handlers = request_handlers + elif responses is not None: + effective_handlers = None + else: + effective_handlers = self._request_handlers + + # Validate parameters eagerly (before any async work or setting running flag) + self._validate_run_params(message, responses, checkpoint_id, effective_handlers) self._ensure_not_running() response_stream = ResponseStream[WorkflowEvent, WorkflowRunResult]( self._run_core( message=message, responses=responses, + request_handlers=effective_handlers, checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, streaming=stream, @@ -542,6 +566,7 @@ async def _run_core( message: Any | None = None, *, responses: dict[str, Any] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, streaming: bool = False, @@ -549,6 +574,12 @@ async def _run_core( ) -> AsyncIterable[WorkflowEvent]: """Single core execution path for both streaming and non-streaming modes. + When request_handlers are provided, the runner dispatches handlers inline + as asyncio tasks when request_info events are emitted. The runner's + convergence check waits for outstanding handler tasks before deciding the + workflow is idle, allowing handler responses to be processed in subsequent + supersteps. + Yields: WorkflowEvent: The events generated during the workflow execution. """ @@ -565,6 +596,7 @@ async def _run_core( reset_context=reset_context, streaming=streaming, run_kwargs=kwargs if kwargs else None, + request_handlers=request_handlers, ): if event.type == "output" and not self._should_yield_output_event(event): continue @@ -617,6 +649,7 @@ def _validate_run_params( message: Any | None, responses: dict[str, Any] | None, checkpoint_id: str | None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, ) -> None: """Validate parameter combinations for run(). 
@@ -625,6 +658,7 @@ def _validate_run_params( - message and checkpoint_id are mutually exclusive - At least one of message, responses, or checkpoint_id must be provided - responses + checkpoint_id is allowed (restore then send) + - request_handlers and responses are mutually exclusive """ if message is not None and responses is not None: raise ValueError("Cannot provide both 'message' and 'responses'. Use one or the other.") @@ -638,6 +672,12 @@ def _validate_run_params( "or 'checkpoint_id' (resume from checkpoint)." ) + if request_handlers is not None and responses is not None: + raise ValueError( + "Cannot provide both 'request_handlers' and 'responses'. " + "Use 'request_handlers' for automatic handling or 'responses' for manual submission." + ) + def _resolve_execution_mode( self, message: Any | None, diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index 1a71b3a49b..ccbc0a0f51 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -3,7 +3,7 @@ import logging import sys import uuid -from collections.abc import Callable, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from typing import Any from .._agents import SupportsAgentRun @@ -84,6 +84,7 @@ def __init__( start_executor: Executor | SupportsAgentRun, checkpoint_storage: CheckpointStorage | None = None, output_executors: list[Executor | SupportsAgentRun] | None = None, + request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = None, ): """Initialize the WorkflowBuilder. @@ -100,6 +101,9 @@ def __init__( checkpoint_storage: Optional checkpoint storage for enabling workflow state persistence. output_executors: Optional list of executors whose outputs should be collected. If not provided, outputs from all executors are collected. 
+ request_handlers: Optional dict mapping request data types to async handler functions. + Keys must match request types declared in executor @response_handler decorators. + Validated at build time. Can be overridden per-run via workflow.run(request_handlers=...). """ self._edge_groups: list[EdgeGroup] = [] self._executors: dict[str, Executor] = {} @@ -116,6 +120,9 @@ def __init__( # Output executors filter; if set, only outputs from these executors are yielded self._output_executors: list[Executor | SupportsAgentRun] = output_executors if output_executors else [] + # Response handlers for automatic HITL request handling + self._request_handlers: Mapping[type, Callable[[Any], Awaitable[Any]]] | None = request_handlers + # Set the start executor self._set_start_executor(start_executor) @@ -584,6 +591,47 @@ def _set_start_executor(self, executor: Executor | SupportsAgentRun) -> None: if existing is not wrapped: self._add_executor(wrapped) + @staticmethod + def _type_name(t: type) -> str: + """Return a human-readable name for a type, handling UnionType safely.""" + return getattr(t, "__name__", str(t)) + + def _validate_request_handlers(self, executors: dict[str, Executor]) -> None: + """Validate that request_handlers keys match executor @response_handler request types. + + Collects all request types declared via @response_handler decorators across all + executors, then checks that each key in self._request_handlers maps to a known + request type. Union types in executor annotations are expanded into their + constituent types. Raises ValueError if an unrecognized type is found. + """ + if self._request_handlers is None: + return + + # Collect all request types from executor @response_handler annotations. + # Expand union types (e.g. str | int) into their constituent types so that + # request_handlers keys can match individual members. 
+ known_request_types: set[type] = set() + for executor in executors.values(): + if hasattr(executor, "_response_handlers"): + for request_type, _ in executor._response_handlers: # type: ignore + if hasattr(request_type, "__args__"): + # UnionType — expand into constituent types + known_request_types.update(request_type.__args__) + else: + known_request_types.add(request_type) + + # Check each handler key against known request types + for handler_type in self._request_handlers: + if handler_type not in known_request_types: + known_names = ( + sorted(self._type_name(t) for t in known_request_types) if known_request_types else ["(none)"] + ) + raise ValueError( + f"request_handlers key {self._type_name(handler_type)} does not match any " + f"@response_handler request type declared in workflow executors. " + f"Known request types: {known_names}" + ) + def build(self) -> Workflow: """Build and return the constructed workflow. @@ -652,6 +700,10 @@ async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None: output_executors, ) + # Validate request_handlers against executor @response_handler annotations + if self._request_handlers is not None: + self._validate_request_handlers(executors) + # Add validation completed event span.add_event(OtelAttr.BUILD_VALIDATION_COMPLETED) @@ -667,6 +719,7 @@ async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None: description=self._description, max_iterations=self._max_iterations, output_executors=output_executors, + request_handlers=self._request_handlers, ) build_attributes: dict[str, Any] = { OtelAttr.WORKFLOW_BUILDER_NAME: self._name, diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index 039c61b07d..bb444f0902 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -19,9 +19,10 @@ WorkflowRunnerException, WorkflowRunState, handler, + response_handler, ) from 
agent_framework._workflows._const import EXECUTOR_STATE_KEY -from agent_framework._workflows._edge import SingleEdgeGroup +from agent_framework._workflows._edge import InternalEdgeGroup, SingleEdgeGroup from agent_framework._workflows._runner import Runner from agent_framework._workflows._runner_context import ( InProcRunnerContext, @@ -38,6 +39,13 @@ class MockMessage: data: int +@dataclass +class MockRequest: + """A mock request_info payload.""" + + value: str + + class MockExecutor(Executor): """A mock executor for testing purposes.""" @@ -223,6 +231,28 @@ async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, i await ctx.yield_output(message.data) +class RequestInfoExecutor(Executor): + """Executor that requests external input and yields it once received.""" + + def __init__(self, id: str = "request_info_executor"): + super().__init__(id=id) + self.received_response: str | None = None + + @handler + async def handle(self, message: MockMessage, ctx: WorkflowContext[int]) -> None: + await ctx.request_info(MockRequest(value=f"request:{message.data}"), str) + + @response_handler + async def on_response( + self, + original_request: MockRequest, + response: str, + ctx: WorkflowContext[int, int], + ) -> None: + self.received_response = response + await ctx.yield_output(len(response)) + + async def test_runner_cancellation_stops_active_executor(): """Test that cancelling a workflow properly cancels the active executor.""" executor_a = SlowExecutor(id="executor_a", work_duration=0.3) @@ -273,6 +303,40 @@ async def run_workflow(): assert executor_b.completed_count == 0 # Should NOT have completed due to cancellation +async def test_runner_dispatches_request_handlers_inline(): + """Runner should dispatch request_handlers and process responses in the same run.""" + executor = RequestInfoExecutor() + state = State() + ctx = InProcRunnerContext() + runner = Runner( + [InternalEdgeGroup(executor.id)], + {executor.id: executor}, + state, + ctx, + 
"test_name", + graph_signature_hash="test_hash", + ) + + await executor.execute( + MockMessage(data=1), + ["START"], + state, + ctx, + ) + + async def handle_request(request: MockRequest) -> str: + await asyncio.sleep(0.05) + return "approved" + + result: int | None = None + async for event in runner.run_until_convergence(request_handlers={MockRequest: handle_request}): + if event.type == "output": + result = event.data + + assert executor.received_response == "approved" + assert result == len("approved") + + class FailingExecutor(Executor): """An executor that fails during execution.""" diff --git a/python/packages/core/tests/workflow/test_workflow_response_handlers.py b/python/packages/core/tests/workflow/test_workflow_response_handlers.py new file mode 100644 index 0000000000..90250fe5f1 --- /dev/null +++ b/python/packages/core/tests/workflow/test_workflow_response_handlers.py @@ -0,0 +1,707 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Tests for workflow request_handlers parameter. + +Verifies automatic HITL request handling: type-based dispatch, concurrent execution, +error handling, parameter validation, and inline handler dispatch. 
+""" + +import asyncio +import logging +import time +from dataclasses import dataclass +from typing import Any + +import pytest + +from agent_framework import ( + Executor, + WorkflowBuilder, + WorkflowContext, + WorkflowRunState, + handler, + response_handler, +) + +# --------------------------------------------------------------------------- +# Shared request / data types +# --------------------------------------------------------------------------- + + +@dataclass +class ReviewRequest: + """Request type for review.""" + + id: str + data: str + + +@dataclass +class ApprovalRequest: + """Request type for approval.""" + + id: str + task: str + + +@dataclass +class UnknownRequest: + """Request type with no registered handler.""" + + id: str + + +# --------------------------------------------------------------------------- +# Executors +# --------------------------------------------------------------------------- + + +class ReviewerExecutor(Executor): + """Requests a review and records whether the response was processed.""" + + def __init__(self): + super().__init__("reviewer") + self.feedback_received: bool = False + self.feedback_value: str | None = None + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + request = ReviewRequest(id="r1", data=f"Review: {message}") + await ctx.request_info(request_data=request, response_type=str) + + @response_handler + async def on_response( + self, + original_request: ReviewRequest, + response: str, + ctx: WorkflowContext[str], + ) -> None: + self.feedback_received = True + self.feedback_value = response + await ctx.send_message(f"review_done:{response}") + + +class ApproverExecutor(Executor): + """Requests approval and records whether the response was processed.""" + + def __init__(self): + super().__init__("approver") + self.approval_received: bool = False + self.approval_value: str | None = None + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + request = 
ApprovalRequest(id="a1", task=f"Approve: {message}") + await ctx.request_info(request_data=request, response_type=str) + + @response_handler + async def on_response( + self, + original_request: ApprovalRequest, + response: str, + ctx: WorkflowContext[str], + ) -> None: + self.approval_received = True + self.approval_value = response + await ctx.send_message(f"approval_done:{response}") + + +class NoneResponseExecutor(Executor): + """Requests info expecting a None response (valid value, not 'no response').""" + + def __init__(self): + super().__init__("none_requester") + self.handler_invoked: bool = False + self.received_value: Any = "NOT_SET" + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + request = ReviewRequest(id="n1", data="needs-none") + # Allow None as a valid response type + await ctx.request_info(request_data=request, response_type=str | None) + + @response_handler + async def on_response( + self, + original_request: ReviewRequest, + response: str | None, + ctx: WorkflowContext[str], + ) -> None: + self.handler_invoked = True + self.received_value = response + + +class UnknownRequestExecutor(Executor): + """Emits a request type with no handler registered.""" + + def __init__(self): + super().__init__("unknown_requester") + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + request = UnknownRequest(id="u1") + await ctx.request_info(request_data=request, response_type=str) + + @response_handler + async def on_response( + self, + original_request: UnknownRequest, + response: str, + ctx: WorkflowContext[str], + ) -> None: + pass # Should never be reached + + +class SimpleExecutor(Executor): + """Executor that just passes messages through without requesting info.""" + + def __init__(self): + super().__init__("simple") + + @handler + async def start(self, message: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(f"echo:{message}") + + +class CollectorExecutor(Executor): + 
"""Collects messages for verification.""" + + def __init__(self): + super().__init__("collector") + self.collected: list[str] = [] + + @handler + async def collect(self, message: str, ctx: WorkflowContext) -> None: + self.collected.append(message) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestResponseHandlers: + async def test_single_handler_submits_response(self): + """Handler response is submitted back and the executor's response_handler runs.""" + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).add_edge(reviewer, collector).build() + + async def handle_review(request: ReviewRequest) -> str: + return "lgtm" + + result = await workflow.run( + "test_msg", + request_handlers={ReviewRequest: handle_review}, + ) + + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "lgtm" + assert result.get_final_state() == WorkflowRunState.IDLE + # Collector should have received the message from the response handler + assert any("review_done:lgtm" in msg for msg in collector.collected) + + async def test_multiple_handlers_dispatched_by_type(self): + """Two different request types are dispatched to their respective handlers.""" + reviewer = ReviewerExecutor() + approver = ApproverExecutor() + collector = CollectorExecutor() + + # Use a dispatcher to fan out to both executors + class Dispatcher(Executor): + def __init__(self): + super().__init__("dispatcher") + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + await ctx.send_message(message) + + dispatcher = Dispatcher() + workflow = ( + WorkflowBuilder(start_executor=dispatcher) + .add_edge(dispatcher, reviewer) + .add_edge(dispatcher, approver) + .add_edge(reviewer, collector) + .add_edge(approver, collector) + .build() + ) + + async def handle_review(request: 
ReviewRequest) -> str: + return "review_ok" + + async def handle_approval(request: ApprovalRequest) -> str: + return "approved" + + result = await workflow.run( + "multi_test", + request_handlers={ + ReviewRequest: handle_review, + ApprovalRequest: handle_approval, + }, + ) + + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "review_ok" + assert approver.approval_received is True + assert approver.approval_value == "approved" + assert result.get_final_state() == WorkflowRunState.IDLE + + async def test_handler_returning_none_is_valid_response(self): + """A handler that returns None should still submit the response (sentinel fix).""" + none_exec = NoneResponseExecutor() + workflow = WorkflowBuilder(start_executor=none_exec).build() + + async def handle_returning_none(request: ReviewRequest) -> None: + return None + + result = await workflow.run( + "test_none", + request_handlers={ReviewRequest: handle_returning_none}, + ) + + # The handler matched and returned None — that None should be submitted + assert none_exec.handler_invoked is True + assert none_exec.received_value is None + assert result.get_final_state() == WorkflowRunState.IDLE + + async def test_handler_exception_leaves_request_pending(self, caplog): + """A handler that raises keeps the request pending and logs an error.""" + reviewer = ReviewerExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).build() + + async def failing_handler(request: ReviewRequest) -> str: + raise RuntimeError("handler boom") + + with caplog.at_level(logging.ERROR): + result = await workflow.run( + "fail_test", + request_handlers={ReviewRequest: failing_handler}, + ) + + assert reviewer.feedback_received is False + assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + assert "handler boom" in caplog.text + assert "ReviewRequest" in caplog.text + + async def test_handler_type_mismatch_leaves_request_pending(self, caplog): + """A handler returning wrong type 
should not drop the pending request.""" + reviewer = ReviewerExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).build() + + async def wrong_type_handler(request: ReviewRequest) -> int: + return 123 + + with caplog.at_level(logging.ERROR): + result1 = await workflow.run( + "bad_type_test", + request_handlers={ReviewRequest: wrong_type_handler}, + ) + + assert reviewer.feedback_received is False + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + assert "Response type mismatch" in caplog.text + + # Pending request should still be recoverable via manual response submission. + request_events = result1.get_request_info_events() + assert len(request_events) == 1 + + result2 = await workflow.run( + responses={request_events[0].request_id: "manual_retry"}, + ) + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "manual_retry" + assert result2.get_final_state() == WorkflowRunState.IDLE + + async def test_unmatched_type_logs_warning(self, caplog): + """A request type with no matching handler logs a warning and stays pending.""" + unknown_exec = UnknownRequestExecutor() + workflow = WorkflowBuilder(start_executor=unknown_exec).build() + + # Register a handler for a *different* type + async def handle_review(request: ReviewRequest) -> str: + return "unused" + + with caplog.at_level(logging.WARNING): + result = await workflow.run( + "unknown_test", + request_handlers={ReviewRequest: handle_review}, + ) + + assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + assert "UnknownRequest" in caplog.text + assert "ReviewRequest" in caplog.text + + async def test_stream_true_with_handlers_works(self): + """stream=True + request_handlers works with inline dispatch.""" + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).add_edge(reviewer, collector).build() + + async def handle_review(request: ReviewRequest) -> str: + 
return "stream_lgtm" + + events = [] + stream = workflow.run( + "stream_test", + stream=True, + request_handlers={ReviewRequest: handle_review}, + ) + async for event in stream: + events.append(event) + result = await stream.get_final_response() + + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "stream_lgtm" + assert result.get_final_state() == WorkflowRunState.IDLE + assert any("review_done:stream_lgtm" in msg for msg in collector.collected) + + async def test_responses_with_handlers_raises(self): + """request_handlers + responses raises ValueError immediately.""" + reviewer = ReviewerExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).build() + + async def handle_review(request: ReviewRequest) -> str: + return "x" + + with pytest.raises(ValueError, match="request_handlers.*responses"): + await workflow.run( + responses={"some_id": "value"}, + request_handlers={ReviewRequest: handle_review}, + ) + + async def test_no_request_info_handlers_unused(self): + """Handlers registered but no request_info events — workflow completes normally.""" + simple = SimpleExecutor() + collector = CollectorExecutor() + workflow = WorkflowBuilder(start_executor=simple).add_edge(simple, collector).build() + + async def handle_review(request: ReviewRequest) -> str: + return "unused" + + result = await workflow.run( + "hello", + request_handlers={ReviewRequest: handle_review}, + ) + + assert result.get_final_state() == WorkflowRunState.IDLE + assert len(result.get_request_info_events()) == 0 + + async def test_none_handlers_preserves_behavior(self): + """request_handlers=None preserves original IDLE_WITH_PENDING_REQUESTS state.""" + reviewer = ReviewerExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).build() + + result = await workflow.run("test_default", request_handlers=None) + + assert reviewer.feedback_received is False + assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + async def 
test_handlers_run_concurrently(self): + """Two 200ms handlers complete in under 500ms (proving concurrency).""" + reviewer = ReviewerExecutor() + approver = ApproverExecutor() + + class Dispatcher(Executor): + def __init__(self): + super().__init__("dispatcher") + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + await ctx.send_message(message) + + dispatcher = Dispatcher() + workflow = ( + WorkflowBuilder(start_executor=dispatcher) + .add_edge(dispatcher, reviewer) + .add_edge(dispatcher, approver) + .build() + ) + + async def slow_review(request: ReviewRequest) -> str: + await asyncio.sleep(0.2) + return "reviewed" + + async def slow_approval(request: ApprovalRequest) -> str: + await asyncio.sleep(0.2) + return "approved" + + start = time.monotonic() + await workflow.run( + "concurrent_test", + request_handlers={ + ReviewRequest: slow_review, + ApprovalRequest: slow_approval, + }, + ) + elapsed = time.monotonic() - start + + # Two 200ms handlers running concurrently should be well under 500ms + # (generous margin for CI/test overhead; sequential would be ~400ms+) + assert elapsed < 0.5, f"Expected < 500ms, got {elapsed * 1000:.0f}ms" + + async def test_events_include_request_and_response_handling(self): + """Result contains request_info events and executor events from response processing.""" + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + workflow = WorkflowBuilder(start_executor=reviewer).add_edge(reviewer, collector).build() + + async def handle_review(request: ReviewRequest) -> str: + return "feedback" + + result = await workflow.run( + "merge_test", + request_handlers={ReviewRequest: handle_review}, + ) + + # Should have produced request_info events + request_events = result.get_request_info_events() + assert len(request_events) >= 1 + assert isinstance(request_events[0].data, ReviewRequest) + + # Response handler should have triggered executor events (reviewer response handler ran) + executor_invoked_events 
= [e for e in result if e.type == "executor_invoked"] + assert len(executor_invoked_events) >= 2 # Initial + response handling + + # Final state should be IDLE (response was processed) + assert result.get_final_state() == WorkflowRunState.IDLE + + async def test_handler_with_looping_branch(self): + """Fan-out with self-looping branch + HITL handler: handler dispatches while loop runs.""" + + @dataclass + class LoopMessage: + iteration: int + content: str + + class LoopingProcessor(Executor): + """Loops a fixed number of times via self-send. Stops by not sending.""" + + def __init__(self): + super().__init__("looper") + self.iteration_count = 0 + + @handler + async def process(self, message: LoopMessage, ctx: WorkflowContext[LoopMessage]) -> None: + self.iteration_count += 1 + if self.iteration_count < 3: + await ctx.send_message(LoopMessage(iteration=self.iteration_count, content="processing")) + + class HITLExecutor(Executor): + """Requests info from external handler.""" + + def __init__(self): + super().__init__("hitl") + self.response_received = False + self.response_value: str | None = None + + @handler + async def start(self, message: LoopMessage, ctx: WorkflowContext) -> None: + request = ReviewRequest(id="hitl_req", data=f"Review: {message.content}") + await ctx.request_info(request_data=request, response_type=str) + + @response_handler + async def on_response( + self, + original_request: ReviewRequest, + response: str, + ctx: WorkflowContext[str], + ) -> None: + self.response_received = True + self.response_value = response + + class FanOutDispatcher(Executor): + def __init__(self): + super().__init__("fanout") + + @handler + async def start(self, message: str, ctx: WorkflowContext[LoopMessage]) -> None: + await ctx.send_message(LoopMessage(iteration=0, content=message)) + + fanout = FanOutDispatcher() + looper = LoopingProcessor() + hitl = HITLExecutor() + + workflow = ( + WorkflowBuilder(start_executor=fanout) + .add_edge(fanout, looper) + 
.add_edge(fanout, hitl) + .add_edge(looper, looper) # self-loop (stops when no message sent) + .build() + ) + + async def handle_review(request: ReviewRequest) -> str: + await asyncio.sleep(0.1) # Simulate async work + return "approved_by_handler" + + result = await workflow.run( + "fan_out_test", + request_handlers={ReviewRequest: handle_review}, + ) + + # Looping branch completed its iterations + assert looper.iteration_count == 3 + + # HITL handler was dispatched and response processed + assert hitl.response_received is True + assert hitl.response_value == "approved_by_handler" + + # Final state should be IDLE (all branches complete, all requests handled) + assert result.get_final_state() == WorkflowRunState.IDLE + + +class TestBuilderLevelResponseHandlers: + """Tests for request_handlers passed via WorkflowBuilder constructor.""" + + async def test_builder_handlers_used_as_default(self): + """Builder-level request_handlers are used when run() has none.""" + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + + async def handle_review(request: ReviewRequest) -> str: + return "builder_lgtm" + + workflow = ( + WorkflowBuilder( + start_executor=reviewer, + request_handlers={ReviewRequest: handle_review}, + ) + .add_edge(reviewer, collector) + .build() + ) + + result = await workflow.run("test_builder") + + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "builder_lgtm" + assert result.get_final_state() == WorkflowRunState.IDLE + assert any("review_done:builder_lgtm" in msg for msg in collector.collected) + + async def test_run_level_overrides_builder(self): + """run(request_handlers=...) 
overrides builder-level handlers.""" + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + + async def builder_handler(request: ReviewRequest) -> str: + return "from_builder" + + async def run_handler(request: ReviewRequest) -> str: + return "from_run" + + workflow = ( + WorkflowBuilder( + start_executor=reviewer, + request_handlers={ReviewRequest: builder_handler}, + ) + .add_edge(reviewer, collector) + .build() + ) + + result = await workflow.run( + "test_override", + request_handlers={ReviewRequest: run_handler}, + ) + + assert reviewer.feedback_value == "from_run" + assert result.get_final_state() == WorkflowRunState.IDLE + + async def test_validation_rejects_unknown_type(self): + """Builder rejects request_handlers with types not in any @response_handler.""" + reviewer = ReviewerExecutor() + + async def handle_unknown(request: UnknownRequest) -> str: + return "x" + + with pytest.raises(ValueError, match="UnknownRequest.*does not match"): + WorkflowBuilder( + start_executor=reviewer, + request_handlers={UnknownRequest: handle_unknown}, + ).build() + + async def test_validation_passes_for_known_types(self): + """Builder accepts request_handlers matching executor @response_handler types.""" + reviewer = ReviewerExecutor() + approver = ApproverExecutor() + + async def handle_review(request: ReviewRequest) -> str: + return "ok" + + async def handle_approval(request: ApprovalRequest) -> str: + return "approved" + + # Should not raise — both types are declared via @response_handler on executors + class Dispatcher(Executor): + def __init__(self): + super().__init__("dispatcher") + + @handler + async def start(self, message: str, ctx: WorkflowContext) -> None: + await ctx.send_message(message) + + dispatcher = Dispatcher() + workflow = ( + WorkflowBuilder( + start_executor=dispatcher, + request_handlers={ + ReviewRequest: handle_review, + ApprovalRequest: handle_approval, + }, + ) + .add_edge(dispatcher, reviewer) + .add_edge(dispatcher, approver) + 
.build() + ) + assert workflow is not None + + async def test_validation_partial_mismatch_rejected(self): + """Builder rejects if any key is unknown, even if others are valid.""" + reviewer = ReviewerExecutor() + + async def handle_review(request: ReviewRequest) -> str: + return "ok" + + async def handle_unknown(request: UnknownRequest) -> str: + return "x" + + with pytest.raises(ValueError, match="UnknownRequest.*does not match"): + WorkflowBuilder( + start_executor=reviewer, + request_handlers={ + ReviewRequest: handle_review, + UnknownRequest: handle_unknown, + }, + ).build() + + async def test_responses_bypass_builder_handlers(self): + """run(responses=...) works even when builder has default request_handlers. + + Regression: without the fix, builder-level handlers would be treated as + effective_handlers, causing _validate_run_params to raise ValueError + ('Cannot provide both request_handlers and responses'). + """ + reviewer = ReviewerExecutor() + collector = CollectorExecutor() + + async def failing_handler(request: ReviewRequest) -> str: + raise RuntimeError("intentional failure") + + # Build with handlers that always fail — first run ends with pending request + workflow = ( + WorkflowBuilder( + start_executor=reviewer, + request_handlers={ReviewRequest: failing_handler}, + ) + .add_edge(reviewer, collector) + .build() + ) + + result1 = await workflow.run("first") + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + request_events = result1.get_request_info_events() + assert len(request_events) == 1 + + # Manual response submission should NOT conflict with builder-level handlers + result2 = await workflow.run(responses={request_events[0].request_id: "manual_response"}) + assert reviewer.feedback_received is True + assert reviewer.feedback_value == "manual_response" + assert result2.get_final_state() == WorkflowRunState.IDLE diff --git a/python/samples/getting_started/workflows/README.md 
b/python/samples/getting_started/workflows/README.md index ce4aee4172..db54e30269 100644 --- a/python/samples/getting_started/workflows/README.md +++ b/python/samples/getting_started/workflows/README.md @@ -88,6 +88,7 @@ Once comfortable with these, explore the rest of the samples below. | SequentialBuilder Request Info | [human-in-the-loop/sequential_request_info.py](./human-in-the-loop/sequential_request_info.py) | Request info for agent responses mid-workflow using `.with_request_info()` on SequentialBuilder | | ConcurrentBuilder Request Info | [human-in-the-loop/concurrent_request_info.py](./human-in-the-loop/concurrent_request_info.py) | Review concurrent agent outputs before aggregation using `.with_request_info()` on ConcurrentBuilder | | GroupChatBuilder Request Info | [human-in-the-loop/group_chat_request_info.py](./human-in-the-loop/group_chat_request_info.py) | Steer group discussions with periodic guidance using `.with_request_info()` on GroupChatBuilder | +| Fan-Out Async with Request Handlers | [human-in-the-loop/fan_out_async_with_request_handlers.py](./human-in-the-loop/fan_out_async_with_request_handlers.py) | Type-based request handler dispatch for HITL requests in fan-out workflows; automatic async handler execution and response submission | ### tool-approval diff --git a/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_request_handlers.py b/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_request_handlers.py new file mode 100644 index 0000000000..bbade2d453 --- /dev/null +++ b/python/samples/getting_started/workflows/human-in-the-loop/fan_out_async_with_request_handlers.py @@ -0,0 +1,296 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Sample: Fan-out Async + HITL with request_handlers parameter + +Demonstrates automatic HITL request handling in fan-out workflows using the +request_handlers parameter. 
Handlers are dispatched inline as asyncio tasks +when request_info events are emitted during execution. The runner waits for +outstanding handler tasks before declaring convergence, so handler responses +are processed in subsequent supersteps within the same workflow run. + +Usage: + request_handlers = { + ReviewRequest: handle_review, + ApprovalRequest: handle_approval, + } + + result = await workflow.run( + initial_data, + request_handlers=request_handlers, + ) +""" + +import asyncio +import time +from dataclasses import dataclass + +from agent_framework import ( + Executor, + WorkflowBuilder, + WorkflowContext, + handler, + response_handler, +) + +_start_time: float = 0.0 + + +def _ts() -> str: + """Return elapsed seconds since workflow start.""" + return f"{time.monotonic() - _start_time:.1f}s" + + +@dataclass +class AnalysisData: + """Initial data for analysis.""" + + task: str + data: str + + +@dataclass +class DataPacket: + """Data flowing through the async processor.""" + + iteration: int + content: str + analysis: str = "" + + +@dataclass +class ReviewRequest: + """Request sent to human reviewer.""" + + packet: DataPacket + analysis: str + prompt: str = "Please review the analysis and provide feedback" + + +class Analyzer(Executor): + def __init__(self, id: str = "analyzer"): + super().__init__(id) + + @handler + async def start_analysis( + self, + data: AnalysisData, + ctx: WorkflowContext[DataPacket, DataPacket], + ) -> None: + packet = DataPacket( + iteration=0, + content=data.data, + analysis=f"Starting {data.task}...", + ) + await ctx.send_message(packet) + + +class AsyncProcessor(Executor): + """Async branch: loops and processes data.""" + + def __init__(self, id: str = "processor"): + super().__init__(id) + self.iteration_count = 0 + + @handler + async def process_data( + self, + packet: DataPacket, + ctx: WorkflowContext[DataPacket, DataPacket], + ) -> None: + self.iteration_count += 1 + await asyncio.sleep(0.5) + analysis = f"Iteration 
{self.iteration_count}: {packet.content[:50]}... processed" + print(f" [{_ts()}] Processor: iteration {self.iteration_count}/10 complete") + + if self.iteration_count < 10: + updated_packet = DataPacket( + iteration=self.iteration_count, + content=packet.content, + analysis=analysis, + ) + await ctx.send_message(updated_packet) + else: + # Stop sending to break the self-loop. Use yield_output for final result. + print(f" [{_ts()}] Processor: DONE after {self.iteration_count} iterations") + await ctx.yield_output( + DataPacket(iteration=self.iteration_count, content=packet.content, analysis=analysis) + ) + + +class Reviewer(Executor): + """HITL branch: requests external feedback.""" + + def __init__(self, id: str = "reviewer"): + super().__init__(id) + + @handler + async def review_data( + self, + packet: DataPacket, + ctx: WorkflowContext[str, str], + ) -> None: + review_request = ReviewRequest( + packet=packet, + analysis=packet.analysis or "No analysis yet", + prompt=f"Please review iteration {packet.iteration}", + ) + + await ctx.request_info( + request_data=review_request, + response_type=str, + ) + + @response_handler + async def handle_review_feedback( + self, + original_request: ReviewRequest, + feedback: str, + ctx: WorkflowContext[str], + ) -> None: + print(f" [{_ts()}] Reviewer: response_handler invoked with feedback") + result = f"Review feedback processed: {feedback}" + await ctx.send_message(result) + + +class FinalAggregator(Executor): + def __init__(self, id: str = "aggregator"): + super().__init__(id) + self.results: list[str] = [] + + @handler + async def aggregate( + self, + message: DataPacket | str, + ctx: WorkflowContext[str], + ) -> None: + self.results.append(str(message)) + + if len(self.results) >= 2: + summary = "=== WORKFLOW COMPLETE ===\n" + summary += "Results from async path: " + str(self.results[0]) + "\n" + summary += "Results from HITL path: " + str(self.results[1]) + await ctx.send_message(summary) + + +# 
============================================================================ +# Response Handlers (External) +# ============================================================================ +# These are registered via request_handlers dict, not as executor methods. +# They are dispatched inline as asyncio tasks when request_info events are emitted. + + +async def handle_review(request: ReviewRequest) -> str: + """Handle external review request. + + Dispatched as an asyncio task when the Reviewer emits a request_info + with ReviewRequest type. The response is submitted back to the workflow + and processed in a subsequent superstep. + """ + print(f" [{_ts()}] Handler: STARTED - reviewing iteration {request.packet.iteration}") + + # Simulate slow external API call (e.g., LLM inference, human approval system) + await asyncio.sleep(3.0) + + feedback = "Analysis looks good. Continue processing." + print(f" [{_ts()}] Handler: DONE - returning feedback") + + return feedback + + +# ============================================================================ +# Main +# ============================================================================ + + +async def main() -> None: + print("=" * 80) + print("HITL: request_handlers parameter in workflow.run()") + print("=" * 80) + + # Create executors + analyzer = Analyzer() + processor = AsyncProcessor() + reviewer = Reviewer() + aggregator = FinalAggregator() + + # Build workflow: processor self-loops, both branches feed into aggregator + workflow = ( + WorkflowBuilder(start_executor=analyzer) + .add_edge(analyzer, processor) + .add_edge(analyzer, reviewer) + .add_edge(processor, processor) # self-loop for async processing + .add_edge(processor, aggregator) + .add_edge(reviewer, aggregator) + .build() + ) + + initial_data = AnalysisData( + task="document analysis", + data="Sample document to process", + ) + + # ======================================================================== + # THE CLEAN API + # 
======================================================================== + # Define external response handlers (type-based dispatch) + request_handlers = { + ReviewRequest: handle_review, + # Can add more: ApprovalRequest: handle_approval, etc. + } + + global _start_time + _start_time = time.monotonic() + + print(f"\n[{_ts()}] Start: {initial_data.task}") + + # Run workflow with automatic response handling (inline dispatch) + # Handlers are dispatched as asyncio tasks when request_info events are emitted + # Responses are injected back and processed in subsequent supersteps + result = await workflow.run( + initial_data, + request_handlers=request_handlers, + ) + + elapsed = time.monotonic() - _start_time + + # Display results + print(f"\n[{_ts()}] Final state: {result.get_final_state()}") + print(f" Total time: {elapsed:.1f}s") + print(f" Total events: {len(result)}") + outputs = result.get_outputs() + if outputs: + print(f"Outputs ({len(outputs)}):") + for output in outputs: + print(f" - {output}") + + """ + Sample Output: + + [0.0s] Start: document analysis + + [0.0s] Handler: STARTED - reviewing iteration 0 + [0.5s] Processor: iteration 1/10 complete + [1.1s] Processor: iteration 2/10 complete + [1.6s] Processor: iteration 3/10 complete + [2.2s] Processor: iteration 4/10 complete + [2.7s] Processor: iteration 5/10 complete + [3.0s] Handler: DONE - returning feedback + [3.3s] Processor: iteration 6/10 complete + [3.3s] Reviewer: response_handler invoked with feedback + [3.8s] Processor: iteration 7/10 complete + [4.4s] Processor: iteration 8/10 complete + [4.9s] Processor: iteration 9/10 complete + [5.5s] Processor: iteration 10/10 complete + [5.5s] Processor: DONE after 10 iterations + + [5.6s] Final state: WorkflowRunState.IDLE + Total time: 5.6s + Total events: 70 + Outputs (1): + - DataPacket(iteration=10, content='Sample document to process', analysis='Iteration 10: Sample document to process... 
processed') + """ # noqa: E501 + + +if __name__ == "__main__": + asyncio.run(main())