Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/claude_agent_sdk/_internal/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Internal client implementation."""

import logging
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import aclosing
from dataclasses import replace
from typing import Any

Expand All @@ -15,6 +17,8 @@
from .transport import Transport
from .transport.subprocess_cli import SubprocessCLITransport

logger = logging.getLogger(__name__)


class InternalClient:
"""Internal client implementation."""
Expand Down Expand Up @@ -117,8 +121,14 @@ async def process_query(
# For string prompts, the prompt is already passed via CLI args

# Yield parsed messages
async for data in query.receive_messages():
yield parse_message(data)

# Use aclosing() for proper async generator cleanup
async with aclosing(query.receive_messages()) as messages:
async for data in messages:
yield parse_message(data)

except GeneratorExit:
# Handle early termination of the async generator gracefully
# This occurs when the caller breaks out of the async for loop
logger.debug("process_query generator closed early by caller")
finally:
await query.close()
97 changes: 90 additions & 7 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Any

import anyio
from anyio.abc import CancelScope
from mcp.types import (
CallToolRequest,
CallToolRequestParams,
Expand Down Expand Up @@ -113,6 +114,15 @@ def __init__(
float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0
) # Convert ms to seconds

# Cancel scope for the reader task - can be cancelled from any task context
# This fixes the RuntimeError when async generator cleanup happens in a different task
self._reader_cancel_scope: CancelScope | None = None
self._reader_task_started = anyio.Event()

# Track whether we entered the task group in this task
# Used to determine if we can safely call __aexit__()
self._tg_entered_in_current_task = False

async def initialize(self) -> dict[str, Any] | None:
"""Initialize control protocol if in streaming mode.

Expand Down Expand Up @@ -158,11 +168,33 @@ async def initialize(self) -> dict[str, Any] | None:
return response

async def start(self) -> None:
"""Start reading messages from transport."""
"""Start reading messages from transport.

This method starts background tasks for reading messages. The task lifecycle
is managed using a CancelScope that can be safely cancelled from any async
task context, avoiding the RuntimeError that occurs when task group
__aexit__() is called from a different task than __aenter__().
"""
if self._tg is None:
# Create a task group for spawning background tasks
self._tg = anyio.create_task_group()
await self._tg.__aenter__()
self._tg.start_soon(self._read_messages)
self._tg_entered_in_current_task = True

# Start the reader with its own cancel scope that can be cancelled safely
self._tg.start_soon(self._read_messages_with_cancel_scope)

async def _read_messages_with_cancel_scope(self) -> None:
"""Wrapper for _read_messages that sets up a cancellable scope.

This wrapper creates a CancelScope that can be cancelled from any task
context, solving the issue where async generator cleanup happens in a
different task than where the task group was entered.
"""
self._reader_cancel_scope = anyio.CancelScope()
self._reader_task_started.set()
with self._reader_cancel_scope:
await self._read_messages()

async def _read_messages(self) -> None:
"""Read messages from transport and route them."""
Expand Down Expand Up @@ -604,15 +636,66 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
yield message

async def close(self) -> None:
"""Close the query and transport."""
"""Close the query and transport.

This method safely cleans up resources, handling the case where cleanup
happens in a different async task context than where start() was called.
This commonly occurs during async generator cleanup (e.g., when breaking
out of an `async for` loop or when asyncio.run() shuts down).

The fix uses two mechanisms:
1. A CancelScope for the reader task that can be cancelled from any context
2. Suppressing the RuntimeError that occurs when task group __aexit__()
is called from a different task than __aenter__()
"""
if self._closed:
return
self._closed = True
if self._tg:

# Cancel the reader task via its cancel scope (safe from any task context)
if self._reader_cancel_scope is not None:
self._reader_cancel_scope.cancel()

# Handle task group cleanup
if self._tg is not None:
# Always cancel the task group's scope to stop any running tasks
self._tg.cancel_scope.cancel()
# Wait for task group to complete cancellation
with suppress(anyio.get_cancelled_exc_class()):
await self._tg.__aexit__(None, None, None)

# Try to properly exit the task group, but handle the case where
# we're in a different task context than where __aenter__() was called
try:
with suppress(anyio.get_cancelled_exc_class()):
await self._tg.__aexit__(None, None, None)
except RuntimeError as e:
# Handle "Attempted to exit cancel scope in a different task"
# This happens during async generator cleanup when Python's GC
# runs the finally block in a different task context.
if "different task" in str(e):
logger.debug(
"Task group cleanup skipped due to cross-task context "
"(this is expected during async generator cleanup)"
)
else:
raise
finally:
self._tg = None
self._tg_entered_in_current_task = False

await self.transport.close()

# Make Query an async context manager
async def __aenter__(self) -> "Query":
"""Enter async context - starts reading messages."""
await self.start()
return self

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any
) -> bool:
"""Exit async context - closes the query."""
await self.close()
return False

# Make Query an async iterator
def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
"""Return async iterator for messages."""
Expand Down
189 changes: 189 additions & 0 deletions tests/test_streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,3 +833,192 @@ async def mock_receive():
assert isinstance(messages[-1], ResultMessage)

anyio.run(_test)


class TestAsyncGeneratorCleanup:
"""Tests for async generator cleanup behavior (issue #454).

These tests verify that the RuntimeError "Attempted to exit cancel scope
in a different task" does not occur during async generator cleanup.

The key behavior we're testing is that cleanup doesn't raise RuntimeError,
not that specific mock methods are called (which depends on mock setup).
"""

def test_streaming_client_early_disconnect(self):
"""Test ClaudeSDKClient early disconnect doesn't raise RuntimeError.

This is the primary test case from issue #454 - breaking out of an
async for loop should not cause RuntimeError during cleanup.
"""

async def _test():
with patch(
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport

async def mock_receive():
# Send init response
await asyncio.sleep(0.01)
written = mock_transport.write.call_args_list
for call in written:
if call:
data = call[0][0]
try:
msg = json.loads(data.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass

# Yield some messages
for i in range(5):
yield {
"type": "assistant",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": f"Message {i}"}],
"model": "claude-opus-4-1-20250805",
},
}

mock_transport.read_messages = mock_receive

# Connect, get one message, then disconnect early
client = ClaudeSDKClient()
await client.connect()

count = 0
async for msg in client.receive_messages():
count += 1
if count >= 2:
break # Early exit - this should NOT raise RuntimeError

# Early disconnect should not raise RuntimeError
# The key assertion is that we reach this point without exception
await client.disconnect()

assert count == 2
# Transport close is called by disconnect
mock_transport.close.assert_called()

anyio.run(_test)

def test_query_cancel_scope_can_be_cancelled(self):
"""Test that Query's cancel scope can be safely cancelled from any context.

This verifies the fix for issue #454 where the cancel scope mechanism
allows cleanup without RuntimeError.
"""

async def _test():
from claude_agent_sdk._internal.query import Query
from claude_agent_sdk._internal.transport import Transport

# Create a mock transport
mock_transport = AsyncMock(spec=Transport)
mock_transport.connect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.write = AsyncMock()

messages_to_yield = [
{"type": "system", "subtype": "init"},
{
"type": "assistant",
"message": {
"content": [{"type": "text", "text": "Hello"}],
"model": "test",
},
},
]
message_index = 0

async def mock_read():
nonlocal message_index
while message_index < len(messages_to_yield):
yield messages_to_yield[message_index]
message_index += 1
await asyncio.sleep(0.01)

mock_transport.read_messages = mock_read

# Create Query
q = Query(
transport=mock_transport,
is_streaming_mode=False,
)

# Start the query
await q.start()

# Give reader time to start
await asyncio.sleep(0.05)

# Cancel scope should exist
assert q._reader_cancel_scope is not None

# Close should work without RuntimeError
# This is the key test - close() used to raise RuntimeError
await q.close()

# Verify closed state
assert q._closed is True
mock_transport.close.assert_called()

anyio.run(_test)

def test_query_as_async_context_manager(self):
"""Test using Query as an async context manager for proper cleanup."""

async def _test():
from claude_agent_sdk._internal.query import Query
from claude_agent_sdk._internal.transport import Transport

mock_transport = AsyncMock(spec=Transport)
mock_transport.connect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.write = AsyncMock()

async def mock_read():
yield {"type": "system", "subtype": "init"}
yield {
"type": "assistant",
"message": {
"content": [{"type": "text", "text": "Hello"}],
"model": "test",
},
}

mock_transport.read_messages = mock_read

# Use Query as async context manager
q = Query(
transport=mock_transport,
is_streaming_mode=False,
)

async with q:
# Query should be started
assert q._tg is not None
# Get one message
msg = await q.__anext__()
assert msg["type"] == "system"

# After context exit, should be closed
assert q._closed is True

anyio.run(_test)