Contributor:
I am not sure this is needed, or what the root cause actually is.

When a request is fulfilled, it's taken off record immediately and won't be available in the next checkpoint. My understanding is that a workflow emits a request, the executor captures it and sends it out as an event or a message, followed by a checkpoint (A), which will contain all the pending requests. When a response comes back, the request is taken off record and processing begins, followed by the next checkpoint (B).

If the workflow is resumed from checkpoint A, the request will be re-emitted.
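The lifecycle described here can be sketched as a minimal model (all names invented for illustration, not the framework's actual API — the real executor tracks pending requests on the instance, not in module-level state):

```python
# Minimal model of the request/checkpoint lifecycle described above.
pending: dict[str, str] = {}

def emit_request(request_id: str, prompt: str) -> None:
    pending[request_id] = prompt  # the request goes on record

def checkpoint() -> dict:
    return {"pending_requests": dict(pending)}  # snapshot of what's pending

def respond(request_id: str) -> None:
    pending.pop(request_id, None)  # fulfilled: taken off record immediately

emit_request("r1", "need input")
ckpt_a = checkpoint()  # checkpoint A: contains the pending request
respond("r1")
ckpt_b = checkpoint()  # checkpoint B: request already off record
print(ckpt_a["pending_requests"], ckpt_b["pending_requests"])
```

Under this model, resuming from checkpoint A re-emits `r1`, while resuming from checkpoint B does not.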

Contributor Author:
The fix is needed because on_checkpoint_restore() re-adds pending requests to the sub-workflow's event queue. When a response arrives, we remove the request from WorkflowExecutor's tracking, but it's still in the sub-workflow's event stream.

So when the sub-workflow continues and makes another request_info() call, result.get_request_info_events() returns both the old (answered) request and the new one, causing duplicate SubWorkflowRequestMessages and incorrect expected_response_count.

_responded_request_ids filters these out. The proper fix would be sub-workflow-level checkpoint tracking (Issue #1614), but this works until then.
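The filtering idea can be illustrated with a small sketch (a stand-in event type, not the framework's real RequestInfoEvent):

```python
from dataclasses import dataclass

@dataclass
class FakeRequestEvent:
    """Stand-in for the framework's RequestInfoEvent."""
    request_id: str

def filter_new_requests(
    events: list[FakeRequestEvent], responded_ids: set[str]
) -> list[FakeRequestEvent]:
    """Drop events whose request has already been answered."""
    return [e for e in events if e.request_id not in responded_ids]

# "old" was answered before the checkpoint restore but reappears in the
# event stream; only "new" should produce a SubWorkflowRequestMessage.
events = [FakeRequestEvent("old"), FakeRequestEvent("new")]
responded = {"old"}
print([e.request_id for e in filter_new_requests(events, responded)])
```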

Contributor:

It shouldn't be in the sub-workflow's event queue, because the event has already been emitted. When a checkpoint is created, the event queue should be empty; this is guaranteed by the runner.

Contributor Author:
You're right that the event queue is empty at checkpoint time. The issue is that on_checkpoint_restore() (lines 519-527) explicitly re-adds pending requests to the sub-workflow's event queue:

await asyncio.gather(*[
    self.workflow._runner_context.add_request_info_event(event)
    for event in request_info_events
])

When _handle_response() later calls send_responses() on the sub-workflow, run_until_convergence() drains these pre-loop events (lines 88-92 in _runner.py) and they end up in the WorkflowRunResult.

The rehydration is intentional (marked as "temporary solution" with TODO #1614), so we need _responded_request_ids to filter them out.
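The drain behavior can be modeled like this (invented names; the real runner's queue handling is more involved):

```python
from collections import deque

event_queue: deque[str] = deque()

def restore_from_checkpoint(pending: list[str]) -> None:
    # Rehydration re-adds old pending requests to the queue.
    event_queue.extend(pending)

def run_step(new_events: list[str]) -> list[str]:
    # The next run drains everything in the queue into its result,
    # rehydrated old events included.
    event_queue.extend(new_events)
    drained = list(event_queue)
    event_queue.clear()
    return drained

restore_from_checkpoint(["old_request"])
result = run_step(["new_request"])
print(result)
```

The run result contains both the old and the new request, which is exactly the duplication `_responded_request_ids` is meant to filter out.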

Contributor:
I don't see an issue with the flow you're describing here. If the checkpoint contains the pending requests, they should be re-emitted when the checkpoint is loaded.

@@ -299,6 +299,9 @@ def __init__(
self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext
# Map request_id to execution_id for response routing
self._request_to_execution: dict[str, str] = {} # request_id -> execution_id
# Track request IDs that have been responded to, to filter out duplicates
# after checkpoint restore (when old requests may reappear in the event stream)
self._responded_request_ids: set[str] = set()
self._propagate_request = propagate_request

@property
@@ -458,6 +461,7 @@ async def on_checkpoint_save(self) -> dict[str, Any]:
for execution_id, execution_context in self._execution_contexts.items()
},
"request_to_execution": dict(self._request_to_execution),
"responded_request_ids": list(self._responded_request_ids),
}

@override
@@ -497,6 +501,14 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
self._execution_contexts = execution_contexts
self._request_to_execution = request_to_execution

# Restore responded request IDs (defaults to empty for checkpoints created before this field existed)
responded_request_ids = state.get("responded_request_ids", [])
if not isinstance(responded_request_ids, list):
raise ValueError("'responded_request_ids' must be a list.")
if not all(isinstance(rid, str) for rid in responded_request_ids):
raise ValueError("All items in 'responded_request_ids' must be strings.")
self._responded_request_ids = set(responded_request_ids)

# Add the `request_info_event`s back to the sub workflow.
# This is only a temporary solution to rehydrate the sub workflow with the requests.
# The proper way would be to rehydrate the workflow from a checkpoint on a Workflow
@@ -548,11 +560,30 @@ async def _process_workflow_result(
await asyncio.gather(*[ctx.send_message(output) for output in outputs])

# Process request info events
# Filter out requests that have already been processed to avoid re-sending
# after checkpoint restore. This can happen because apply_checkpoint() re-adds
# pending RequestInfoEvents to the event queue, which then appear in subsequent
# workflow results even though they were already handled.
new_request_count = 0
for event in request_info_events:
# Skip if this request has already been responded to (e.g., after checkpoint restore)
if event.request_id in self._responded_request_ids:
logger.debug(f"WorkflowExecutor {self.id} skipping already-responded request {event.request_id}")
continue
# Skip if this request is already being tracked (already processed)
if event.request_id in execution_context.pending_requests:
logger.debug(f"WorkflowExecutor {self.id} skipping already-tracked request {event.request_id}")
continue
if event.request_id in self._request_to_execution:
logger.debug(f"WorkflowExecutor {self.id} skipping already-mapped request {event.request_id}")
continue

# Track the pending request in execution context
execution_context.pending_requests[event.request_id] = event
# Map request to execution for response routing
self._request_to_execution[event.request_id] = execution_context.execution_id
new_request_count += 1

if self._propagate_request:
# In a workflow where the parent workflow does not handle the request, the request
# should be propagated via the `request_info` mechanism to an external source. And
@@ -563,8 +594,8 @@
# request and handle it directly, a message should be sent.
await ctx.send_message(SubWorkflowRequestMessage(source_event=event, executor_id=self.id))

# Update expected response count for this execution
execution_context.expected_response_count = len(request_info_events)
# Update expected response count for this execution (only count NEW requests)
execution_context.expected_response_count = new_request_count

# Handle final state
if workflow_run_state == WorkflowRunState.FAILED:
@@ -634,6 +665,10 @@ async def _handle_response(
execution_context.pending_requests.pop(request_id, None)
self._request_to_execution.pop(request_id, None)

# Track that this request has been responded to, so we can filter it out
# if it reappears in the event stream after checkpoint restore
self._responded_request_ids.add(request_id)

# Accumulate the response in this execution's context
execution_context.collected_responses[request_id] = response
# Check if we have all expected responses for this execution
253 changes: 253 additions & 0 deletions python/packages/core/tests/workflow/test_sub_workflow.py
@@ -1,18 +1,23 @@
# Copyright (c) Microsoft. All rights reserved.

from dataclasses import dataclass, field
from typing import Any
from uuid import uuid4

from typing_extensions import Never

from agent_framework import (
Executor,
InMemoryCheckpointStorage,
RequestInfoEvent,
SubWorkflowRequestMessage,
SubWorkflowResponseMessage,
Workflow,
WorkflowBuilder,
WorkflowContext,
WorkflowExecutor,
WorkflowRunState,
WorkflowStatusEvent,
handler,
response_handler,
)
@@ -461,3 +466,251 @@ async def collect_result(self, result: ValidationResult, ctx: WorkflowContext) -

# Verify that concurrent executions were properly isolated
# (This is implicitly tested by the fact that we got correct results for all emails)


# =============================================================================
# Checkpoint restore tests
# =============================================================================


@dataclass
class CheckpointTestStartMessage:
"""Initial message to start the checkpoint test workflow."""

value: str


@dataclass
class CheckpointTestFirstRequest:
"""First request_info() call in checkpoint tests."""

id: str = field(default_factory=lambda: str(uuid4()))
prompt: str = "First request"


@dataclass
class CheckpointTestSecondRequest:
"""Second request_info() call in checkpoint tests."""

id: str = field(default_factory=lambda: str(uuid4()))
prompt: str = "Second request"


@dataclass
class CheckpointTestFinalOutput:
"""Final output from the checkpoint test workflow."""

first_response: str
second_response: str


class TwoStepExecutor(Executor):
"""Executor that makes TWO sequential request_info() calls for checkpoint testing."""

def __init__(self) -> None:
super().__init__(id="two_step_executor")
self._first_response: str = ""

@handler
async def handle_start(self, msg: CheckpointTestStartMessage, ctx: WorkflowContext) -> None:
await ctx.request_info(
request_data=CheckpointTestFirstRequest(prompt=f"First request for: {msg.value}"),
response_type=str,
)

@response_handler
async def handle_first_response(
self,
original_request: CheckpointTestFirstRequest,
response: str,
ctx: WorkflowContext,
) -> None:
self._first_response = response
await ctx.request_info(
request_data=CheckpointTestSecondRequest(prompt="Second request"),
response_type=str,
)

@response_handler
async def handle_second_response(
self,
original_request: CheckpointTestSecondRequest,
response: str,
ctx: WorkflowContext[None, CheckpointTestFinalOutput],
) -> None:
await ctx.yield_output(
CheckpointTestFinalOutput(
first_response=self._first_response,
second_response=response,
)
)

async def on_checkpoint_save(self) -> dict[str, Any]:
return {"first_response": self._first_response}

async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
self._first_response = state.get("first_response", "")


class CheckpointTestCoordinator(Executor):
"""Parent workflow coordinator for checkpoint tests."""

def __init__(self) -> None:
super().__init__(id="checkpoint_test_coordinator")
self.received_requests: list[SubWorkflowRequestMessage] = []
self.output: CheckpointTestFinalOutput | None = None
self._pending_requests: dict[str, SubWorkflowRequestMessage] = {}

@handler
async def start(self, value: str, ctx: WorkflowContext[CheckpointTestStartMessage]) -> None:
await ctx.send_message(CheckpointTestStartMessage(value=value))

@handler
async def handle_sub_workflow_request(
self,
request: SubWorkflowRequestMessage,
ctx: WorkflowContext,
) -> None:
self.received_requests.append(request)
data = request.source_event.data
if isinstance(data, (CheckpointTestFirstRequest, CheckpointTestSecondRequest)):
self._pending_requests[data.id] = request
await ctx.request_info(data, str)

@response_handler
async def handle_response(
self,
original_request: CheckpointTestFirstRequest | CheckpointTestSecondRequest,
response: str,
ctx: WorkflowContext[SubWorkflowResponseMessage],
) -> None:
sub_request = self._pending_requests.pop(original_request.id, None)
if sub_request is None:
raise ValueError(f"No pending request for ID: {original_request.id}")
await ctx.send_message(sub_request.create_response(response))

@handler
async def collect_output(self, output: CheckpointTestFinalOutput, ctx: WorkflowContext) -> None:
self.output = output

async def on_checkpoint_save(self) -> dict[str, Any]:
return {"pending_requests": self._pending_requests}

async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
self._pending_requests = state.get("pending_requests", {})


def _build_checkpoint_test_sub_workflow() -> WorkflowExecutor:
sub_workflow = (
WorkflowBuilder().register_executor(TwoStepExecutor, name="two_step").set_start_executor("two_step").build()
)
return WorkflowExecutor(sub_workflow, id="sub_workflow")


def _build_checkpoint_test_workflow(storage: InMemoryCheckpointStorage) -> tuple[Workflow, CheckpointTestCoordinator]:
coordinator = CheckpointTestCoordinator()
workflow = (
WorkflowBuilder()
.register_executor(lambda: coordinator, name="coordinator")
.register_executor(_build_checkpoint_test_sub_workflow, name="sub_executor")
.set_start_executor("coordinator")
.add_edge("coordinator", "sub_executor")
.add_edge("sub_executor", "coordinator")
.with_checkpointing(storage)
.build()
)
return workflow, coordinator


async def test_checkpoint_restore_does_not_duplicate_requests():
"""
Regression test: WorkflowExecutor should not re-send already-answered
RequestInfoEvents after checkpoint restore.

This test verifies the fix for a bug where after checkpoint restore,
WorkflowExecutor._process_workflow_result() would process ALL RequestInfoEvents
from the workflow result, including ones that were already answered.
"""
storage = InMemoryCheckpointStorage()

# Step 1: Run until first request
workflow, coordinator = _build_checkpoint_test_workflow(storage)

first_request_id: str | None = None
async for event in workflow.run_stream("test_value"):
if isinstance(event, RequestInfoEvent):
first_request_id = event.request_id
if isinstance(event, WorkflowStatusEvent) and event.state is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS:
break

assert first_request_id is not None
assert len(coordinator.received_requests) == 1
assert isinstance(coordinator.received_requests[0].source_event.data, CheckpointTestFirstRequest)

# Get checkpoint
checkpoints = await storage.list_checkpoints(workflow.id)
checkpoints.sort(key=lambda cp: cp.timestamp)
checkpoint_id = checkpoints[-1].checkpoint_id

# Step 2: Resume from checkpoint
workflow2, coordinator2 = _build_checkpoint_test_workflow(storage)

resumed_first_request_id: str | None = None
async for event in workflow2.run_stream(checkpoint_id=checkpoint_id):
if isinstance(event, RequestInfoEvent):
resumed_first_request_id = event.request_id

assert resumed_first_request_id is not None

# Respond to first request - this triggers second request
second_request_id: str | None = None
async for event in workflow2.send_responses_streaming({resumed_first_request_id: "first_answer"}):
if isinstance(event, RequestInfoEvent):
second_request_id = event.request_id
if isinstance(event, WorkflowStatusEvent) and event.state is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS:
break

assert second_request_id is not None

# After responding to first request, coordinator2 should only receive the new SecondRequest
assert len(coordinator2.received_requests) == 1
assert isinstance(coordinator2.received_requests[0].source_event.data, CheckpointTestSecondRequest)


async def test_checkpoint_restore_full_flow():
"""Test that a sub-workflow with checkpointing can complete successfully."""
storage = InMemoryCheckpointStorage()

# Step 1: Run until first request
workflow, _ = _build_checkpoint_test_workflow(storage)

result = await workflow.run("test_value")
request_events = result.get_request_info_events()
assert len(request_events) == 1

# Get checkpoint
checkpoints = await storage.list_checkpoints(workflow.id)
checkpoints.sort(key=lambda cp: cp.timestamp)
checkpoint_id = checkpoints[-1].checkpoint_id

# Step 2: Resume and respond to first request
workflow2, coordinator2 = _build_checkpoint_test_workflow(storage)

result2 = await workflow2.run(checkpoint_id=checkpoint_id)
request_events2 = result2.get_request_info_events()
assert len(request_events2) == 1
resumed_request_id = request_events2[0].request_id

# Respond to first request - triggers second request
result3 = await workflow2.send_responses({resumed_request_id: "first_answer"})
request_events3 = result3.get_request_info_events()
assert len(request_events3) == 1
second_request_id = request_events3[0].request_id

# Step 3: Respond to second request to complete workflow
await workflow2.send_responses({second_request_id: "second_answer"})

# Verify final output
assert coordinator2.output is not None
assert coordinator2.output.first_response == "first_answer"
assert coordinator2.output.second_response == "second_answer"