From c8285732294031a42f1179b6f1692d578cc73e90 Mon Sep 17 00:00:00 2001
From: Tao Chen
Date: Wed, 4 Feb 2026 22:40:55 -0800
Subject: [PATCH] Fix subworkflow duplicate request info events

---
 .../_workflows/_workflow_executor.py          |  18 ++
 .../core/tests/workflow/test_sub_workflow.py  | 159 ++++++++++++++++++
 2 files changed, 177 insertions(+)

diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py
index 2453620cfd..80200b9e73 100644
--- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py
+++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py
@@ -652,6 +652,24 @@ async def _handle_response(
         try:
             # Resume the sub-workflow with all collected responses
             result = await self.workflow.send_responses(responses_to_send)
+            # Remove already-handled requests from the result. The result may still contain
+            # the original RequestInfoEvents that were already handled. Checkpointing and
+            # rehydration of the workflow re-add those RequestInfoEvents to the workflow's
+            # _runner_context, and thus to the event queue. When the workflow is resumed,
+            # those events are emitted at the very beginning of the superstep, prior to
+            # processing messages/responses, creating the illusion that the workflow is
+            # requesting the same information again.
+            for request_id in responses_to_send:
+                event_to_remove = next(
+                    (
+                        event
+                        for event in result
+                        if isinstance(event, RequestInfoEvent) and event.request_id == request_id
+                    ),
+                    None,
+                )
+                if event_to_remove:
+                    result.remove(event_to_remove)
 
             # Process the workflow result using shared logic
             await self._process_workflow_result(result, execution_context, ctx)
diff --git a/python/packages/core/tests/workflow/test_sub_workflow.py b/python/packages/core/tests/workflow/test_sub_workflow.py
index cb2733b653..b77ddeb1b8 100644
--- a/python/packages/core/tests/workflow/test_sub_workflow.py
+++ b/python/packages/core/tests/workflow/test_sub_workflow.py
@@ -1,12 +1,14 @@
 # Copyright (c) Microsoft. All rights reserved.
 
 from dataclasses import dataclass, field
+from typing import Any
 from uuid import uuid4
 
 from typing_extensions import Never
 
 from agent_framework import (
     Executor,
+    RequestInfoEvent,
     SubWorkflowRequestMessage,
     SubWorkflowResponseMessage,
     Workflow,
@@ -16,6 +18,7 @@
     handler,
     response_handler,
 )
+from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
 
 
 # Test message types
@@ -461,3 +464,159 @@ async def collect_result(self, result: ValidationResult, ctx: WorkflowContext) -
 
     # Verify that concurrent executions were properly isolated
     # (This is implicitly tested by the fact that we got correct results for all emails)
+
+
+# region Checkpoint-related message types and executors for sub-workflow tests
+
+
+@dataclass
+class CheckpointRequest:
+    """Request in a two-step checkpoint test."""
+
+    prompt: str
+    id: str = field(default_factory=lambda: str(uuid4()))
+
+
+class TwoStepSubWorkflowExecutor(Executor):
+    """Sub-workflow executor that makes two sequential requests."""
+
+    def __init__(self) -> None:
+        super().__init__(id="two_step_executor")
+        self._responses: list[str] = []
+
+    @handler
+    async def handle_start(self, msg: str, ctx: WorkflowContext) -> None:
+        await ctx.request_info(
+            request_data=CheckpointRequest(prompt=f"First request for: {msg}"),
+            response_type=str,
+        )
+
+    @response_handler
+    async def handle_response(
+        self,
+        original_request: CheckpointRequest,
+        response: str,
+        ctx: WorkflowContext[Never, bool],
+    ) -> None:
+        self._responses.append(response)
+        if len(self._responses) == 1:
+            # First response received, make second request
+            await ctx.request_info(
+                request_data=CheckpointRequest(prompt="Second request"),
+                response_type=str,
+            )
+        else:
+            # Second response received, yield final output
+            await ctx.yield_output(True)
+
+    async def on_checkpoint_save(self) -> dict[str, Any]:
+        return {"responses": self._responses}
+
+    async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
+        self._responses = state.get("responses", [])
+
+
+class CheckpointTestCoordinator(Executor):
+    """Coordinator for checkpoint sub-workflow tests."""
+
+    def __init__(self) -> None:
+        super().__init__(id="checkpoint_coordinator")
+        self._pending_requests: dict[str, SubWorkflowRequestMessage] = {}
+
+    @handler
+    async def start(self, value: str, ctx: WorkflowContext[str]) -> None:
+        await ctx.send_message(value)
+
+    @handler
+    async def handle_sub_workflow_request(
+        self,
+        request: SubWorkflowRequestMessage,
+        ctx: WorkflowContext,
+    ) -> None:
+        data = request.source_event.data
+        if isinstance(data, CheckpointRequest):
+            self._pending_requests[data.id] = request
+            await ctx.request_info(data, str)
+
+    @response_handler
+    async def handle_response(
+        self,
+        original_request: CheckpointRequest,
+        response: str,
+        ctx: WorkflowContext[SubWorkflowResponseMessage],
+    ) -> None:
+        sub_request = self._pending_requests.pop(original_request.id, None)
+        if sub_request is None:
+            raise ValueError(f"No pending request for ID: {original_request.id}")
+        await ctx.send_message(sub_request.create_response(response))
+
+    async def on_checkpoint_save(self) -> dict[str, Any]:
+        return {"pending_requests": self._pending_requests}
+
+    async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
+        self._pending_requests = state.get("pending_requests", {})
+
+
+def _build_checkpoint_test_workflow(storage: InMemoryCheckpointStorage) -> Workflow:
+    """Build the main workflow with checkpointing for testing."""
+    two_step_executor = TwoStepSubWorkflowExecutor()
+    sub_workflow = WorkflowBuilder().set_start_executor(two_step_executor).build()
+    sub_workflow_executor = WorkflowExecutor(sub_workflow, id="sub_workflow_executor")
+
+    coordinator = CheckpointTestCoordinator()
+    return (
+        WorkflowBuilder()
+        .set_start_executor(coordinator)
+        .add_edge(coordinator, sub_workflow_executor)
+        .add_edge(sub_workflow_executor, coordinator)
+        .with_checkpointing(storage)
+        .build()
+    )
+
+
+async def test_sub_workflow_checkpoint_restore_no_duplicate_requests() -> None:
+    """Test that resuming a sub-workflow from a checkpoint does not emit duplicate requests.
+
+    This test verifies the fix for an issue where, after a checkpoint restore, sending a
+    response to a sub-workflow caused duplicate RequestInfoEvents to be emitted. The bug
+    occurred because checkpoint rehydration re-added RequestInfoEvents to the event queue;
+    when the workflow was resumed, those events were emitted again along with any new requests.
+
+    The fix ensures that already-handled requests are filtered out of the result when
+    the sub-workflow is resumed with responses.
+    """
+    storage = InMemoryCheckpointStorage()
+
+    # Step 1: Run workflow until first request
+    workflow1 = _build_checkpoint_test_workflow(storage)
+
+    first_request_id: str | None = None
+    async for event in workflow1.run_stream("test_value"):
+        if isinstance(event, RequestInfoEvent):
+            first_request_id = event.request_id
+
+    assert first_request_id is not None
+
+    # Get checkpoint
+    checkpoints = await storage.list_checkpoints(workflow1.id)
+    checkpoint_id = max(checkpoints, key=lambda cp: cp.timestamp).checkpoint_id
+
+    # Step 2: Resume workflow from checkpoint
+    workflow2 = _build_checkpoint_test_workflow(storage)
+
+    resumed_first_request_id: str | None = None
+    async for event in workflow2.run_stream(checkpoint_id=checkpoint_id):
+        if isinstance(event, RequestInfoEvent):
+            resumed_first_request_id = event.request_id
+
+    assert resumed_first_request_id is not None
+    assert resumed_first_request_id == first_request_id
+
+    request_events: list[RequestInfoEvent] = []
+    async for event in workflow2.send_responses_streaming({resumed_first_request_id: "first_answer"}):
+        if isinstance(event, RequestInfoEvent):
+            request_events.append(event)
+
+    # Key assertion: Only the second request should be received, not a duplicate of the first
+    assert len(request_events) == 1
+    assert request_events[0].data.prompt == "Second request"
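
Note (illustrative sketch, not part of the patch): the removal loop added to _handle_response above amounts to filtering out RequestInfoEvents whose request_id has already been answered. The self-contained snippet below restates that dedup idea with hypothetical stand-in names (FakeRequestInfoEvent, drop_handled_requests); it does not use real agent_framework APIs.

# Self-contained sketch of the dedup performed by the fix (illustrative only).
# FakeRequestInfoEvent is a hypothetical stand-in for agent_framework's RequestInfoEvent.
from dataclasses import dataclass


@dataclass
class FakeRequestInfoEvent:
    request_id: str


def drop_handled_requests(events: list[object], responses_to_send: dict[str, str]) -> list[object]:
    """Return events minus any request event whose request_id has already been answered."""
    handled_ids = set(responses_to_send)  # keys are the request ids that were just responded to
    return [
        event
        for event in events
        if not (isinstance(event, FakeRequestInfoEvent) and event.request_id in handled_ids)
    ]


# Only the not-yet-answered request survives the filter.
events = [FakeRequestInfoEvent("req-1"), FakeRequestInfoEvent("req-2")]
assert drop_handled_requests(events, {"req-1": "first_answer"}) == [FakeRequestInfoEvent("req-2")]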