Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,24 @@ async def _handle_response(
try:
# Resume the sub-workflow with all collected responses
result = await self.workflow.send_responses(responses_to_send)
# Remove handled requests from result. The result may contain the original
# RequestInfoEvents that were already handled. This is due to checkpointing
# and rehydration of the workflow that re-adds the RequestInfoEvents to the
# workflow's _runner_context thus the event queue. When the workflow is resumed,
# those events will be emitted at the very beginning of the superstep, prior to
# processing messages/responses, creating the illusion that the workflow is
# requesting the same information again.
for request_id in responses_to_send:
event_to_remove = next(
(
event
for event in result
if isinstance(event, RequestInfoEvent) and event.request_id == request_id
),
None,
)
if event_to_remove:
result.remove(event_to_remove)

# Process the workflow result using shared logic
await self._process_workflow_result(result, execution_context, ctx)
Expand Down
159 changes: 159 additions & 0 deletions python/packages/core/tests/workflow/test_sub_workflow.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) Microsoft. All rights reserved.

from dataclasses import dataclass, field
from typing import Any
from uuid import uuid4

from typing_extensions import Never

from agent_framework import (
Executor,
RequestInfoEvent,
SubWorkflowRequestMessage,
SubWorkflowResponseMessage,
Workflow,
Expand All @@ -16,6 +18,7 @@
handler,
response_handler,
)
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage


# Test message types
Expand Down Expand Up @@ -461,3 +464,159 @@ async def collect_result(self, result: ValidationResult, ctx: WorkflowContext) -

# Verify that concurrent executions were properly isolated
# (This is implicitly tested by the fact that we got correct results for all emails)


# region Checkpoint-related message types and executors for sub-workflow tests


def _make_request_id() -> str:
    """Return a fresh unique identifier for a CheckpointRequest."""
    return str(uuid4())


@dataclass
class CheckpointRequest:
    """A single request exchanged during the two-step checkpoint test."""

    # Human-readable prompt describing what information is being requested.
    prompt: str
    # Unique identifier used to correlate this request with its response.
    id: str = field(default_factory=_make_request_id)


class TwoStepSubWorkflowExecutor(Executor):
    """Sub-workflow executor that issues two requests back to back.

    The first request is emitted when the workflow starts. Once its response
    arrives, a second request is issued; after the second response the
    executor yields ``True`` as the final output.
    """

    def __init__(self) -> None:
        super().__init__(id="two_step_executor")
        # Responses collected so far, in arrival order.
        self._responses: list[str] = []

    @handler
    async def handle_start(self, msg: str, ctx: WorkflowContext) -> None:
        """Kick off the sequence by emitting the first request."""
        first = CheckpointRequest(prompt=f"First request for: {msg}")
        await ctx.request_info(request_data=first, response_type=str)

    @response_handler
    async def handle_response(
        self,
        original_request: CheckpointRequest,
        response: str,
        ctx: WorkflowContext[Never, bool],
    ) -> None:
        """Record a response; issue the second request or finish."""
        self._responses.append(response)
        if len(self._responses) > 1:
            # Both responses are in -- emit the final output and stop.
            await ctx.yield_output(True)
            return
        # Only the first response so far: ask for the second piece of info.
        await ctx.request_info(
            request_data=CheckpointRequest(prompt="Second request"),
            response_type=str,
        )

    async def on_checkpoint_save(self) -> dict[str, Any]:
        """Persist collected responses so a restore resumes mid-sequence."""
        return {"responses": self._responses}

    async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
        """Reload previously collected responses from checkpoint state."""
        self._responses = state.get("responses", [])


class CheckpointTestCoordinator(Executor):
    """Coordinator that surfaces sub-workflow requests as external requests.

    In-flight sub-workflow requests are tracked by the embedded
    ``CheckpointRequest.id`` so external responses can be routed back to the
    originating sub-workflow message.
    """

    def __init__(self) -> None:
        super().__init__(id="checkpoint_coordinator")
        # In-flight requests keyed by CheckpointRequest.id.
        self._pending_requests: dict[str, SubWorkflowRequestMessage] = {}

    @handler
    async def start(self, value: str, ctx: WorkflowContext[str]) -> None:
        """Forward the initial value into the sub-workflow."""
        await ctx.send_message(value)

    @handler
    async def handle_sub_workflow_request(
        self,
        request: SubWorkflowRequestMessage,
        ctx: WorkflowContext,
    ) -> None:
        """Intercept a sub-workflow request and re-issue it externally."""
        payload = request.source_event.data
        if not isinstance(payload, CheckpointRequest):
            return
        # Remember the originating message so the response can be routed back.
        self._pending_requests[payload.id] = request
        await ctx.request_info(payload, str)

    @response_handler
    async def handle_response(
        self,
        original_request: CheckpointRequest,
        response: str,
        ctx: WorkflowContext[SubWorkflowResponseMessage],
    ) -> None:
        """Route an external response back to the waiting sub-workflow."""
        pending = self._pending_requests.pop(original_request.id, None)
        if pending is None:
            raise ValueError(f"No pending request for ID: {original_request.id}")
        await ctx.send_message(pending.create_response(response))

    async def on_checkpoint_save(self) -> dict[str, Any]:
        """Persist the in-flight request map."""
        return {"pending_requests": self._pending_requests}

    async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
        """Restore the in-flight request map from checkpoint state."""
        self._pending_requests = state.get("pending_requests", {})


def _build_checkpoint_test_workflow(storage: InMemoryCheckpointStorage) -> Workflow:
    """Assemble the checkpoint-enabled main workflow wrapping the sub-workflow.

    The coordinator and the sub-workflow executor are wired with edges in
    both directions so requests can flow out of the sub-workflow and the
    matching responses can flow back in.
    """
    inner = WorkflowBuilder().set_start_executor(TwoStepSubWorkflowExecutor()).build()
    wrapper = WorkflowExecutor(inner, id="sub_workflow_executor")
    coordinator = CheckpointTestCoordinator()

    builder = WorkflowBuilder()
    builder = builder.set_start_executor(coordinator)
    builder = builder.add_edge(coordinator, wrapper)
    builder = builder.add_edge(wrapper, coordinator)
    builder = builder.with_checkpointing(storage)
    return builder.build()


async def test_sub_workflow_checkpoint_restore_no_duplicate_requests() -> None:
    """Resuming a sub-workflow from a checkpoint must not re-emit handled requests.

    Regression test: checkpoint rehydration re-adds RequestInfoEvents to the
    workflow's event queue, so resuming used to re-emit already-answered
    requests alongside any genuinely new ones. After the fix, sending a
    response to the first request must surface only the second request.
    """
    storage = InMemoryCheckpointStorage()

    # Phase 1: run a fresh workflow until the first RequestInfoEvent appears.
    workflow1 = _build_checkpoint_test_workflow(storage)
    first_request_id: str | None = None
    async for ev in workflow1.run_stream("test_value"):
        if isinstance(ev, RequestInfoEvent):
            first_request_id = ev.request_id
    assert first_request_id is not None

    # Pick the most recent checkpoint to resume from.
    checkpoints = await storage.list_checkpoints(workflow1.id)
    latest = max(checkpoints, key=lambda cp: cp.timestamp)

    # Phase 2: rebuild the workflow and resume from the checkpoint. The
    # original (still unanswered) request is expected to be re-emitted once.
    workflow2 = _build_checkpoint_test_workflow(storage)
    resumed_id: str | None = None
    async for ev in workflow2.run_stream(checkpoint_id=latest.checkpoint_id):
        if isinstance(ev, RequestInfoEvent):
            resumed_id = ev.request_id
    assert resumed_id is not None
    assert resumed_id == first_request_id

    # Phase 3: answer the first request; only the *second* request may
    # surface -- a duplicate of the first would indicate the bug.
    seen: list[RequestInfoEvent] = []
    async for ev in workflow2.send_responses_streaming({resumed_id: "first_answer"}):
        if isinstance(ev, RequestInfoEvent):
            seen.append(ev)
    assert len(seen) == 1
    assert seen[0].data.prompt == "Second request"
Loading