diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 227f0f7fe7..cdd3cd690c 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio +import contextlib import logging from collections import defaultdict from collections.abc import AsyncGenerator, Sequence @@ -106,14 +107,21 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: # Run iteration concurrently with live event streaming: we poll # for new events while the iteration coroutine progresses. iteration_task = asyncio.create_task(self._run_iteration()) - while not iteration_task.done(): - try: - # Wait briefly for any new event; timeout allows progress checks - event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05) - yield event - except asyncio.TimeoutError: - # Periodically continue to let iteration advance - continue + try: + while not iteration_task.done(): + try: + # Wait briefly for any new event; timeout allows progress checks + event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05) + yield event + except asyncio.TimeoutError: + # Periodically continue to let iteration advance + continue + except asyncio.CancelledError: + # Propagate cancellation to the iteration task to avoid orphaned work + iteration_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await iteration_task + raise # Propagate errors from iteration, but first surface any pending events try: diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index f6a031e5a3..fc21ba049d 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -191,3 +191,73 @@ async def 
test_runner_emits_runner_completion_for_agent_response_without_targets # The runner should complete without errors when handling AgentExecutorResponse without targets # No specific events are expected since there are no executors to process the message assert isinstance(events, list) # Just verify the runner completed without errors


class SlowExecutor(Executor):
    """An executor that takes time to process, used for cancellation testing.

    Each handled message increments ``started_count``, sleeps for
    ``work_duration`` seconds, then increments ``completed_count``. A message
    whose ``data`` is below 2 is forwarded with ``data + 1``; otherwise the
    value is yielded as workflow output.

    Args:
        id: Identifier passed through to the base ``Executor``.
        work_duration: Seconds the handler sleeps to simulate work.
    """

    def __init__(self, id: str, work_duration: float = 0.5):
        super().__init__(id=id)
        self.started_count = 0
        self.completed_count = 0
        self.work_duration = work_duration
        # Set the moment a handler invocation begins, so tests can wait for
        # "work is in flight" deterministically instead of guessing with sleeps.
        self.handling_started = asyncio.Event()

    @handler
    async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, int]) -> None:
        """Simulate slow work, then forward an incremented message or yield the result."""
        self.started_count += 1
        self.handling_started.set()
        await asyncio.sleep(self.work_duration)
        # Only reached if the sleep above was not cancelled.
        self.completed_count += 1
        if message.data < 2:
            await ctx.send_message(MockMessage(data=message.data + 1))
        else:
            await ctx.yield_output(message.data)


async def test_runner_cancellation_stops_active_executor():
    """Test that cancelling a workflow properly cancels the active executor."""
    executor_a = SlowExecutor(id="executor_a", work_duration=0.3)
    executor_b = SlowExecutor(id="executor_b", work_duration=1.0)

    # Two executors wired in a cycle: a -> b -> a.
    edges = [
        SingleEdgeGroup(executor_a.id, executor_b.id),
        SingleEdgeGroup(executor_b.id, executor_a.id),
    ]

    executors: dict[str, Executor] = {
        executor_a.id: executor_a,
        executor_b.id: executor_b,
    }
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    runner = Runner(edges, executors, shared_state, ctx)

    # Seed the workflow: executor_a handles the first message and forwards
    # data=1, which the runner will deliver to executor_b.
    await executor_a.execute(
        MockMessage(data=0),
        ["START"],
        shared_state,
        ctx,
    )

    async def run_workflow():
        async for _ in runner.run_until_convergence():
            pass

    task = asyncio.create_task(run_workflow())

    # executor_b only receives a message after executor_a has finished and
    # forwarded one, so waiting on its start event guarantees executor_b is
    # mid-execution (it sleeps for 1.0s) regardless of scheduler timing.
    # The timeout only bounds the test if the message is never delivered.
    await asyncio.wait_for(executor_b.handling_started.wait(), timeout=5.0)

    # Cancel while executor_b is mid-execution
    task.cancel()

    with pytest.raises(asyncio.CancelledError):
        await task

    # Give time for any leaked tasks to complete (if cancellation didn't work):
    # an orphaned executor_b would finish its 1.0s sleep within this window.
    await asyncio.sleep(1.5)

    # executor_a should have completed once, executor_b should have started but not completed
    assert executor_a.completed_count == 1
    assert executor_b.started_count == 1
    assert executor_b.completed_count == 0  # Should NOT have completed due to cancellation