From 5d5fbf2043272a4785fa528883a5dffcf70d4bc4 Mon Sep 17 00:00:00 2001
From: Tao Chen
Date: Fri, 20 Feb 2026 23:06:33 -0800
Subject: [PATCH] Fix workflow runner concurrent processing

---
 .../agent_framework/_workflows/_runner.py     | 21 +++--
 .../core/tests/workflow/test_runner.py        | 76 +++++++++++++++++++
 2 files changed, 90 insertions(+), 7 deletions(-)

diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py
index ce1661cc7f..295a34e246 100644
--- a/python/packages/core/agent_framework/_workflows/_runner.py
+++ b/python/packages/core/agent_framework/_workflows/_runner.py
@@ -158,7 +158,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
             self._running = False
 
     async def _run_iteration(self) -> None:
-        async def _deliver_messages(source_executor_id: str, messages: list[WorkflowMessage]) -> None:
+        async def _deliver_messages(source_executor_id: str, source_messages: list[WorkflowMessage]) -> None:
             """Outer loop to concurrently deliver messages from all sources to their targets."""
 
             async def _deliver_message_inner(edge_runner: EdgeRunner, message: WorkflowMessage) -> bool:
@@ -172,13 +172,20 @@ async def _deliver_message_inner(edge_runner: EdgeRunner, message: WorkflowMessa
                 logger.debug(f"No outgoing edges found for executor {source_executor_id}; dropping messages.")
                 return
 
-            for message in messages:
-                # Deliver a message through all edge runners associated with the source executor concurrently.
-                tasks = [_deliver_message_inner(edge_runner, message) for edge_runner in associated_edge_runners]
-                await asyncio.gather(*tasks)
+            async def _deliver_messages_for_edge_runner(edge_runner: EdgeRunner) -> None:
+                # Preserve message order per edge runner (and therefore per routed target path)
+                # while still allowing parallelism across different edge runners.
+                for message in source_messages:
+                    await _deliver_message_inner(edge_runner, message)
 
-        messages = await self._ctx.drain_messages()
-        tasks = [_deliver_messages(source_executor_id, messages) for source_executor_id, messages in messages.items()]
+            tasks = [_deliver_messages_for_edge_runner(edge_runner) for edge_runner in associated_edge_runners]
+            await asyncio.gather(*tasks)
+
+        message_batches = await self._ctx.drain_messages()
+        tasks = [
+            _deliver_messages(source_executor_id, source_messages)
+            for source_executor_id, source_messages in message_batches.items()
+        ]
         await asyncio.gather(*tasks)
 
     async def _create_checkpoint_if_enabled(self, previous_checkpoint_id: CheckpointID | None) -> CheckpointID | None:
diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py
index 039c61b07d..e1041d63d9 100644
--- a/python/packages/core/tests/workflow/test_runner.py
+++ b/python/packages/core/tests/workflow/test_runner.py
@@ -150,6 +150,82 @@ async def test_runner_run_until_convergence_not_completed():
         assert event.type != "status" or event.state != WorkflowRunState.IDLE
 
 
+async def test_runner_run_iteration_preserves_message_order_per_edge_runner() -> None:
+    """Test that _run_iteration preserves message order to the same target path."""
+
+    class RecordingEdgeRunner:
+        def __init__(self) -> None:
+            self.received: list[int] = []
+
+        async def send_message(self, message: WorkflowMessage, state: State, ctx: RunnerContext) -> bool:
+            message_data = message.data
+            assert isinstance(message_data, MockMessage)
+            self.received.append(message_data.data)
+            await asyncio.sleep(0.005)
+            return True
+
+    ctx = InProcRunnerContext()
+    state = State()
+    runner = Runner([], {}, state, ctx, "test_name", graph_signature_hash="test_hash")
+
+    edge_runner = RecordingEdgeRunner()
+    runner._edge_runner_map = {"source": [edge_runner]}  # type: ignore[assignment]
+
+    for index in range(5):
+        await ctx.send_message(WorkflowMessage(data=MockMessage(data=index), source_id="source"))
+
+    await runner._run_iteration()
+
+    assert edge_runner.received == [0, 1, 2, 3, 4]
+
+
+async def test_runner_run_iteration_delivers_different_edge_runners_concurrently() -> None:
+    """Test that different edge runners for the same source are executed concurrently."""
+
+    class BlockingEdgeRunner:
+        def __init__(self) -> None:
+            self.started = asyncio.Event()
+            self.release = asyncio.Event()
+            self.call_count = 0
+
+        async def send_message(self, message: WorkflowMessage, state: State, ctx: RunnerContext) -> bool:
+            self.call_count += 1
+            self.started.set()
+            await self.release.wait()
+            return True
+
+    class ProbeEdgeRunner:
+        def __init__(self) -> None:
+            self.probe_completed = asyncio.Event()
+            self.call_count = 0
+
+        async def send_message(self, message: WorkflowMessage, state: State, ctx: RunnerContext) -> bool:
+            self.call_count += 1
+            self.probe_completed.set()
+            return True
+
+    ctx = InProcRunnerContext()
+    state = State()
+    runner = Runner([], {}, state, ctx, "test_name", graph_signature_hash="test_hash")
+
+    blocking_edge_runner = BlockingEdgeRunner()
+    probe_edge_runner = ProbeEdgeRunner()
+    runner._edge_runner_map = {"source": [blocking_edge_runner, probe_edge_runner]}  # type: ignore[assignment]
+
+    await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id="source"))
+
+    iteration_task = asyncio.create_task(runner._run_iteration())
+
+    await blocking_edge_runner.started.wait()
+    await asyncio.wait_for(probe_edge_runner.probe_completed.wait(), timeout=0.2)
+
+    blocking_edge_runner.release.set()
+    await iteration_task
+
+    assert blocking_edge_runner.call_count == 1
+    assert probe_edge_runner.call_count == 1
+
+
 async def test_runner_already_running():
     """Test that running the runner while it is already running raises an error."""
     executor_a = MockExecutor(id="executor_a")