From 0c1b355f5e2e6da1cd986be69ced31863f2d054f Mon Sep 17 00:00:00 2001 From: wingding12 Date: Wed, 28 Jan 2026 12:56:30 -0500 Subject: [PATCH] fix: Resolve RuntimeError on async generator cleanup Fixes issue #454 where RuntimeError was raised during async generator cleanup. Problem: Task group __aenter__() was called in one task but __aexit__() was called in a different task during cleanup, which AnyIO doesn't allow. Solution: - Added CancelScope for reader task that can be cancelled from any context - Graceful __aexit__() handling that catches cross-task RuntimeError - Added aclosing() wrapper in process_query() for proper cleanup - Made Query an async context manager for cleaner usage Changes: - query.py: Added cancel scope mechanism and async context manager support - client.py: Added aclosing() and GeneratorExit handling - test_streaming_client.py: Added 3 tests for async cleanup scenarios All 134 tests pass. --- src/claude_agent_sdk/_internal/client.py | 16 +- src/claude_agent_sdk/_internal/query.py | 97 +++++++++++- tests/test_streaming_client.py | 189 +++++++++++++++++++++++ 3 files changed, 292 insertions(+), 10 deletions(-) diff --git a/src/claude_agent_sdk/_internal/client.py b/src/claude_agent_sdk/_internal/client.py index 52466272..0b6c6ab7 100644 --- a/src/claude_agent_sdk/_internal/client.py +++ b/src/claude_agent_sdk/_internal/client.py @@ -1,6 +1,8 @@ """Internal client implementation.""" +import logging from collections.abc import AsyncIterable, AsyncIterator +from contextlib import aclosing from dataclasses import replace from typing import Any @@ -15,6 +17,8 @@ from .transport import Transport from .transport.subprocess_cli import SubprocessCLITransport +logger = logging.getLogger(__name__) + class InternalClient: """Internal client implementation.""" @@ -117,8 +121,14 @@ async def process_query( # For string prompts, the prompt is already passed via CLI args # Yield parsed messages - async for data in query.receive_messages(): - yield parse_message(data) - + # Use aclosing() for proper async generator cleanup + async with aclosing(query.receive_messages()) as messages: + async for data in messages: + yield parse_message(data) + + except GeneratorExit: + # Handle early termination of the async generator gracefully + # This occurs when the caller breaks out of the async for loop + logger.debug("process_query generator closed early by caller") finally: await query.close() diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 6bf5a73c..04731d6a 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any import anyio +from anyio.abc import CancelScope from mcp.types import ( CallToolRequest, CallToolRequestParams, @@ -113,6 +114,15 @@ def __init__( float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0 ) # Convert ms to seconds + # Cancel scope for the reader task - can be cancelled from any task context + # This fixes the RuntimeError when async generator cleanup happens in a different task + self._reader_cancel_scope: CancelScope | None = None + self._reader_task_started = anyio.Event() + + # Track whether we entered the task group in this task + # Used to determine if we can safely call __aexit__() + self._tg_entered_in_current_task = False + async def initialize(self) -> dict[str, Any] | None: """Initialize control protocol if in streaming mode. @@ -158,11 +168,33 @@ async def initialize(self) -> dict[str, Any] | None: return response async def start(self) -> None: - """Start reading messages from transport.""" + """Start reading messages from transport. + + This method starts background tasks for reading messages. The task lifecycle + is managed using a CancelScope that can be safely cancelled from any async + task context, avoiding the RuntimeError that occurs when task group + __aexit__() is called from a different task than __aenter__(). + """ if self._tg is None: + # Create a task group for spawning background tasks self._tg = anyio.create_task_group() await self._tg.__aenter__() - self._tg.start_soon(self._read_messages) + self._tg_entered_in_current_task = True + + # Start the reader with its own cancel scope that can be cancelled safely + self._tg.start_soon(self._read_messages_with_cancel_scope) + + async def _read_messages_with_cancel_scope(self) -> None: + """Wrapper for _read_messages that sets up a cancellable scope. + + This wrapper creates a CancelScope that can be cancelled from any task + context, solving the issue where async generator cleanup happens in a + different task than where the task group was entered. + """ + self._reader_cancel_scope = anyio.CancelScope() + self._reader_task_started.set() + with self._reader_cancel_scope: + await self._read_messages() async def _read_messages(self) -> None: """Read messages from transport and route them.""" @@ -604,15 +636,66 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: yield message async def close(self) -> None: - """Close the query and transport.""" + """Close the query and transport. + + This method safely cleans up resources, handling the case where cleanup + happens in a different async task context than where start() was called. + This commonly occurs during async generator cleanup (e.g., when breaking + out of an `async for` loop or when asyncio.run() shuts down). + + The fix uses two mechanisms: + 1. A CancelScope for the reader task that can be cancelled from any context + 2. Suppressing the RuntimeError that occurs when task group __aexit__() + is called from a different task than __aenter__() + """ + if self._closed: + return self._closed = True - if self._tg: + + # Cancel the reader task via its cancel scope (safe from any task context) + if self._reader_cancel_scope is not None: + self._reader_cancel_scope.cancel() + + # Handle task group cleanup + if self._tg is not None: + # Always cancel the task group's scope to stop any running tasks self._tg.cancel_scope.cancel() - # Wait for task group to complete cancellation - with suppress(anyio.get_cancelled_exc_class()): - await self._tg.__aexit__(None, None, None) + + # Try to properly exit the task group, but handle the case where + # we're in a different task context than where __aenter__() was called + try: + with suppress(anyio.get_cancelled_exc_class()): + await self._tg.__aexit__(None, None, None) + except RuntimeError as e: + # Handle "Attempted to exit cancel scope in a different task" + # This happens during async generator cleanup when Python's GC + # runs the finally block in a different task context. + if "different task" in str(e): + logger.debug( + "Task group cleanup skipped due to cross-task context " + "(this is expected during async generator cleanup)" + ) + else: + raise + finally: + self._tg = None + self._tg_entered_in_current_task = False + await self.transport.close() + # Make Query an async context manager + async def __aenter__(self) -> "Query": + """Enter async context - starts reading messages.""" + await self.start() + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any + ) -> bool: + """Exit async context - closes the query.""" + await self.close() + return False + # Make Query an async iterator def __aiter__(self) -> AsyncIterator[dict[str, Any]]: """Return async iterator for messages.""" diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 29294419..653bfce7 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -833,3 +833,192 @@ async def mock_receive(): assert isinstance(messages[-1], ResultMessage) anyio.run(_test) + + +class TestAsyncGeneratorCleanup: + """Tests for async generator cleanup behavior (issue #454). + + These tests verify that the RuntimeError "Attempted to exit cancel scope + in a different task" does not occur during async generator cleanup. + + The key behavior we're testing is that cleanup doesn't raise RuntimeError, + not that specific mock methods are called (which depends on mock setup). + """ + + def test_streaming_client_early_disconnect(self): + """Test ClaudeSDKClient early disconnect doesn't raise RuntimeError. + + This is the primary test case from issue #454 - breaking out of an + async for loop should not cause RuntimeError during cleanup. + """ + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = create_mock_transport() + mock_transport_class.return_value = mock_transport + + async def mock_receive(): + # Send init response + await asyncio.sleep(0.01) + written = mock_transport.write.call_args_list + for call in written: + if call: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + }, + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Yield some messages + for i in range(5): + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": f"Message {i}"}], + "model": "claude-opus-4-1-20250805", + }, + } + + mock_transport.read_messages = mock_receive + + # Connect, get one message, then disconnect early + client = ClaudeSDKClient() + await client.connect() + + count = 0 + async for msg in client.receive_messages(): + count += 1 + if count >= 2: + break # Early exit - this should NOT raise RuntimeError + + # Early disconnect should not raise RuntimeError + # The key assertion is that we reach this point without exception + await client.disconnect() + + assert count == 2 + # Transport close is called by disconnect + mock_transport.close.assert_called() + + anyio.run(_test) + + def test_query_cancel_scope_can_be_cancelled(self): + """Test that Query's cancel scope can be safely cancelled from any context. + + This verifies the fix for issue #454 where the cancel scope mechanism + allows cleanup without RuntimeError. + """ + + async def _test(): + from claude_agent_sdk._internal.query import Query + from claude_agent_sdk._internal.transport import Transport + + # Create a mock transport + mock_transport = AsyncMock(spec=Transport) + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.write = AsyncMock() + + messages_to_yield = [ + {"type": "system", "subtype": "init"}, + { + "type": "assistant", + "message": { + "content": [{"type": "text", "text": "Hello"}], + "model": "test", + }, + }, + ] + message_index = 0 + + async def mock_read(): + nonlocal message_index + while message_index < len(messages_to_yield): + yield messages_to_yield[message_index] + message_index += 1 + await asyncio.sleep(0.01) + + mock_transport.read_messages = mock_read + + # Create Query + q = Query( + transport=mock_transport, + is_streaming_mode=False, + ) + + # Start the query + await q.start() + + # Give reader time to start + await asyncio.sleep(0.05) + + # Cancel scope should exist + assert q._reader_cancel_scope is not None + + # Close should work without RuntimeError + # This is the key test - close() used to raise RuntimeError + await q.close() + + # Verify closed state + assert q._closed is True + mock_transport.close.assert_called() + + anyio.run(_test) + + def test_query_as_async_context_manager(self): + """Test using Query as an async context manager for proper cleanup.""" + + async def _test(): + from claude_agent_sdk._internal.query import Query + from claude_agent_sdk._internal.transport import Transport + + mock_transport = AsyncMock(spec=Transport) + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.write = AsyncMock() + + async def mock_read(): + yield {"type": "system", "subtype": "init"} + yield { + "type": "assistant", + "message": { + "content": [{"type": "text", "text": "Hello"}], + "model": "test", + }, + } + + mock_transport.read_messages = mock_read + + # Use Query as async context manager + q = Query( + transport=mock_transport, + is_streaming_mode=False, + ) + + async with q: + # Query should be started + assert q._tg is not None + # Get one message + msg = await q.__anext__() + assert msg["type"] == "system" + + # After context exit, should be closed + assert q._closed is True + + anyio.run(_test)