From b5de0997aa0972dd483c4b0b3451a2131a0d34c7 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Tue, 20 Jan 2026 09:17:04 +0900 Subject: [PATCH] fix(core): prevent WorkflowExecutor from re-sending answered requests after checkpoint restore --- .../_workflows/_workflow_executor.py | 39 ++- .../core/tests/workflow/test_sub_workflow.py | 253 ++++++++++++++++++ 2 files changed, 290 insertions(+), 2 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 69f24bcf2c..3fcda1ccb6 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -299,6 +299,9 @@ def __init__( self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext # Map request_id to execution_id for response routing self._request_to_execution: dict[str, str] = {} # request_id -> execution_id + # Track request IDs that have been responded to, to filter out duplicates + # after checkpoint restore (when old requests may reappear in the event stream) + self._responded_request_ids: set[str] = set() self._propagate_request = propagate_request @property @@ -458,6 +461,7 @@ async def on_checkpoint_save(self) -> dict[str, Any]: for execution_id, execution_context in self._execution_contexts.items() }, "request_to_execution": dict(self._request_to_execution), + "responded_request_ids": list(self._responded_request_ids), } @override @@ -497,6 +501,14 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: self._execution_contexts = execution_contexts self._request_to_execution = request_to_execution + # Restore responded request IDs (defaults to empty for checkpoints created before this field existed) + responded_request_ids = state.get("responded_request_ids", []) + if not isinstance(responded_request_ids, list): + raise ValueError("'responded_request_ids' must be a list.") + if not all(isinstance(rid, str) for rid in responded_request_ids): + raise ValueError("All items in 'responded_request_ids' must be strings.") + self._responded_request_ids = set(responded_request_ids) + # Add the `request_info_event`s back to the sub workflow. # This is only a temporary solution to rehydrate the sub workflow with the requests. # The proper way would be to rehydrate the workflow from a checkpoint on a Workflow @@ -548,11 +560,30 @@ async def _process_workflow_result( await asyncio.gather(*[ctx.send_message(output) for output in outputs]) # Process request info events + # Filter out requests that have already been processed to avoid re-sending + # after checkpoint restore. This can happen because apply_checkpoint() re-adds + # pending RequestInfoEvents to the event queue, which then appear in subsequent + # workflow results even though they were already handled. + new_request_count = 0 for event in request_info_events: + # Skip if this request has already been responded to (e.g., after checkpoint restore) + if event.request_id in self._responded_request_ids: + logger.debug(f"WorkflowExecutor {self.id} skipping already-responded request {event.request_id}") + continue + # Skip if this request is already being tracked (already processed) + if event.request_id in execution_context.pending_requests: + logger.debug(f"WorkflowExecutor {self.id} skipping already-tracked request {event.request_id}") + continue + if event.request_id in self._request_to_execution: + logger.debug(f"WorkflowExecutor {self.id} skipping already-mapped request {event.request_id}") + continue + # Track the pending request in execution context execution_context.pending_requests[event.request_id] = event # Map request to execution for response routing self._request_to_execution[event.request_id] = execution_context.execution_id + new_request_count += 1 + if self._propagate_request: # In a workflow where the parent workflow does not handle the request, the request # should be propagated via the `request_info` mechanism to an external source. And @@ -563,8 +594,8 @@ async def _process_workflow_result( # request and handle it directly, a message should be sent. await ctx.send_message(SubWorkflowRequestMessage(source_event=event, executor_id=self.id)) - # Update expected response count for this execution - execution_context.expected_response_count = len(request_info_events) + # Update expected response count for this execution (only count NEW requests) + execution_context.expected_response_count = new_request_count # Handle final state if workflow_run_state == WorkflowRunState.FAILED: @@ -634,6 +665,10 @@ async def _handle_response( execution_context.pending_requests.pop(request_id, None) self._request_to_execution.pop(request_id, None) + # Track that this request has been responded to, so we can filter it out + # if it reappears in the event stream after checkpoint restore + self._responded_request_ids.add(request_id) + # Accumulate the response in this execution's context execution_context.collected_responses[request_id] = response # Check if we have all expected responses for this execution diff --git a/python/packages/core/tests/workflow/test_sub_workflow.py b/python/packages/core/tests/workflow/test_sub_workflow.py index cb2733b653..67ab5ab688 100644 --- a/python/packages/core/tests/workflow/test_sub_workflow.py +++ b/python/packages/core/tests/workflow/test_sub_workflow.py @@ -1,18 +1,23 @@ # Copyright (c) Microsoft. All rights reserved. from dataclasses import dataclass, field +from typing import Any from uuid import uuid4 from typing_extensions import Never from agent_framework import ( Executor, + InMemoryCheckpointStorage, + RequestInfoEvent, SubWorkflowRequestMessage, SubWorkflowResponseMessage, Workflow, WorkflowBuilder, WorkflowContext, WorkflowExecutor, + WorkflowRunState, + WorkflowStatusEvent, handler, response_handler, ) @@ -461,3 +466,251 @@ async def collect_result(self, result: ValidationResult, ctx: WorkflowContext) - # Verify that concurrent executions were properly isolated # (This is implicitly tested by the fact that we got correct results for all emails) + + +# ============================================================================= +# Checkpoint restore tests +# ============================================================================= + + +@dataclass +class CheckpointTestStartMessage: + """Initial message to start the checkpoint test workflow.""" + + value: str + + +@dataclass +class CheckpointTestFirstRequest: + """First request_info() call in checkpoint tests.""" + + id: str = field(default_factory=lambda: str(uuid4())) + prompt: str = "First request" + + +@dataclass +class CheckpointTestSecondRequest: + """Second request_info() call in checkpoint tests.""" + + id: str = field(default_factory=lambda: str(uuid4())) + prompt: str = "Second request" + + +@dataclass +class CheckpointTestFinalOutput: + """Final output from the checkpoint test workflow.""" + + first_response: str + second_response: str + + +class TwoStepExecutor(Executor): + """Executor that makes TWO sequential request_info() calls for checkpoint testing.""" + + def __init__(self) -> None: + super().__init__(id="two_step_executor") + self._first_response: str = "" + + @handler + async def handle_start(self, msg: CheckpointTestStartMessage, ctx: WorkflowContext) -> None: + await ctx.request_info( + request_data=CheckpointTestFirstRequest(prompt=f"First request for: {msg.value}"), + response_type=str, + ) + + @response_handler + async def handle_first_response( + self, + original_request: CheckpointTestFirstRequest, + response: str, + ctx: WorkflowContext, + ) -> None: + self._first_response = response + await ctx.request_info( + request_data=CheckpointTestSecondRequest(prompt="Second request"), + response_type=str, + ) + + @response_handler + async def handle_second_response( + self, + original_request: CheckpointTestSecondRequest, + response: str, + ctx: WorkflowContext[None, CheckpointTestFinalOutput], + ) -> None: + await ctx.yield_output( + CheckpointTestFinalOutput( + first_response=self._first_response, + second_response=response, + ) + ) + + async def on_checkpoint_save(self) -> dict[str, Any]: + return {"first_response": self._first_response} + + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + self._first_response = state.get("first_response", "") + + +class CheckpointTestCoordinator(Executor): + """Parent workflow coordinator for checkpoint tests.""" + + def __init__(self) -> None: + super().__init__(id="checkpoint_test_coordinator") + self.received_requests: list[SubWorkflowRequestMessage] = [] + self.output: CheckpointTestFinalOutput | None = None + self._pending_requests: dict[str, SubWorkflowRequestMessage] = {} + + @handler + async def start(self, value: str, ctx: WorkflowContext[CheckpointTestStartMessage]) -> None: + await ctx.send_message(CheckpointTestStartMessage(value=value)) + + @handler + async def handle_sub_workflow_request( + self, + request: SubWorkflowRequestMessage, + ctx: WorkflowContext, + ) -> None: + self.received_requests.append(request) + data = request.source_event.data + if isinstance(data, (CheckpointTestFirstRequest, CheckpointTestSecondRequest)): + self._pending_requests[data.id] = request + await ctx.request_info(data, str) + + @response_handler + async def handle_response( + self, + original_request: CheckpointTestFirstRequest | CheckpointTestSecondRequest, + response: str, + ctx: WorkflowContext[SubWorkflowResponseMessage], + ) -> None: + sub_request = self._pending_requests.pop(original_request.id, None) + if sub_request is None: + raise ValueError(f"No pending request for ID: {original_request.id}") + await ctx.send_message(sub_request.create_response(response)) + + @handler + async def collect_output(self, output: CheckpointTestFinalOutput, ctx: WorkflowContext) -> None: + self.output = output + + async def on_checkpoint_save(self) -> dict[str, Any]: + return {"pending_requests": self._pending_requests} + + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + self._pending_requests = state.get("pending_requests", {}) + + +def _build_checkpoint_test_sub_workflow() -> WorkflowExecutor: + sub_workflow = ( + WorkflowBuilder().register_executor(TwoStepExecutor, name="two_step").set_start_executor("two_step").build() + ) + return WorkflowExecutor(sub_workflow, id="sub_workflow") + + +def _build_checkpoint_test_workflow(storage: InMemoryCheckpointStorage) -> tuple[Workflow, CheckpointTestCoordinator]: + coordinator = CheckpointTestCoordinator() + workflow = ( + WorkflowBuilder() + .register_executor(lambda: coordinator, name="coordinator") + .register_executor(_build_checkpoint_test_sub_workflow, name="sub_executor") + .set_start_executor("coordinator") + .add_edge("coordinator", "sub_executor") + .add_edge("sub_executor", "coordinator") + .with_checkpointing(storage) + .build() + ) + return workflow, coordinator + + +async def test_checkpoint_restore_does_not_duplicate_requests(): + """ + Regression test: WorkflowExecutor should not re-send already-answered + RequestInfoEvents after checkpoint restore. + + This test verifies the fix for a bug where after checkpoint restore, + WorkflowExecutor._process_workflow_result() would process ALL RequestInfoEvents + from the workflow result, including ones that were already answered. + """ + storage = InMemoryCheckpointStorage() + + # Step 1: Run until first request + workflow, coordinator = _build_checkpoint_test_workflow(storage) + + first_request_id: str | None = None + async for event in workflow.run_stream("test_value"): + if isinstance(event, RequestInfoEvent): + first_request_id = event.request_id + if isinstance(event, WorkflowStatusEvent) and event.state is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS: + break + + assert first_request_id is not None + assert len(coordinator.received_requests) == 1 + assert isinstance(coordinator.received_requests[0].source_event.data, CheckpointTestFirstRequest) + + # Get checkpoint + checkpoints = await storage.list_checkpoints(workflow.id) + checkpoints.sort(key=lambda cp: cp.timestamp) + checkpoint_id = checkpoints[-1].checkpoint_id + + # Step 2: Resume from checkpoint + workflow2, coordinator2 = _build_checkpoint_test_workflow(storage) + + resumed_first_request_id: str | None = None + async for event in workflow2.run_stream(checkpoint_id=checkpoint_id): + if isinstance(event, RequestInfoEvent): + resumed_first_request_id = event.request_id + + assert resumed_first_request_id is not None + + # Respond to first request - this triggers second request + second_request_id: str | None = None + async for event in workflow2.send_responses_streaming({resumed_first_request_id: "first_answer"}): + if isinstance(event, RequestInfoEvent): + second_request_id = event.request_id + if isinstance(event, WorkflowStatusEvent) and event.state is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS: + break + + assert second_request_id is not None + + # After responding to first request, coordinator2 should only receive the new SecondRequest + assert len(coordinator2.received_requests) == 1 + assert isinstance(coordinator2.received_requests[0].source_event.data, CheckpointTestSecondRequest) + + +async def test_checkpoint_restore_full_flow(): + """Test that a sub-workflow with checkpointing can complete successfully.""" + storage = InMemoryCheckpointStorage() + + # Step 1: Run until first request + workflow, _ = _build_checkpoint_test_workflow(storage) + + result = await workflow.run("test_value") + request_events = result.get_request_info_events() + assert len(request_events) == 1 + + # Get checkpoint + checkpoints = await storage.list_checkpoints(workflow.id) + checkpoints.sort(key=lambda cp: cp.timestamp) + checkpoint_id = checkpoints[-1].checkpoint_id + + # Step 2: Resume and respond to first request + workflow2, coordinator2 = _build_checkpoint_test_workflow(storage) + + result2 = await workflow2.run(checkpoint_id=checkpoint_id) + request_events2 = result2.get_request_info_events() + assert len(request_events2) == 1 + resumed_request_id = request_events2[0].request_id + + # Respond to first request - triggers second request + result3 = await workflow2.send_responses({resumed_request_id: "first_answer"}) + request_events3 = result3.get_request_info_events() + assert len(request_events3) == 1 + second_request_id = request_events3[0].request_id + + # Step 3: Respond to second request to complete workflow + await workflow2.send_responses({second_request_id: "second_answer"}) + + # Verify final output + assert coordinator2.output is not None + assert coordinator2.output.first_response == "first_answer" + assert coordinator2.output.second_response == "second_answer"