From 591a887524c5c31147111f193a208f0dd8748376 Mon Sep 17 00:00:00 2001 From: peter-luminova Date: Sat, 21 Feb 2026 08:15:51 +0530 Subject: [PATCH] fix: raise TaskContextError for cross-task usage (issue #576) Root Cause: - ClaudeSDKClient silently hangs when receive_messages() is called from a different async task than where connect() was called - anyio MemoryObjectStream is bound to the task group where created - Common scenario: FastAPI/Starlette apps with global client Changes: 1. Add TaskContextError exception class - Stores connect_task_id and current_task_id for debugging - Includes helpful error message with documentation reference - Follows existing error class patterns 2. Track task ownership in Query class - Add _owner_task field to track creating task - Capture task ID in start() using anyio.get_current_task() - Validate task context in receive_messages() 3. Update documentation - Update caveat in client.py to reference TaskContextError - Explain FastAPI/Starlette scenario 4. Add comprehensive tests (test_task_context.py) - Same-task usage works - Cross-task usage raises TaskContextError - Multiple clients in different tasks work - Error attributes and messages verified 5. Add FastAPI example (examples/fastapi_example.py) - Shows correct per-request client pattern - Documents wrong pattern (global client) 6. Remove test_concurrent_send_receive - Tested cross-task pattern that causes hangs - Never was supported (documented caveat) - Only passed with mocks, gave false confidence Impact: - Converts silent hang into fast, actionable error - No breaking changes for valid usage - Existing single-task usage unchanged - Cross-task usage now fails with clear guidance Refs: #576 --- examples/fastapi_example.py | 88 +++++++++++++ src/claude_agent_sdk/__init__.py | 2 + src/claude_agent_sdk/_errors.py | 53 ++++++++ src/claude_agent_sdk/_internal/query.py | 22 ++++ src/claude_agent_sdk/client.py | 12 +- tests/test_streaming_client.py | 83 +------------ tests/test_task_context.py | 158 ++++++++++++++++++++++++ 7 files changed, 338 insertions(+), 80 deletions(-) create mode 100644 examples/fastapi_example.py create mode 100644 tests/test_task_context.py diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py new file mode 100644 index 00000000..e759a36d --- /dev/null +++ b/examples/fastapi_example.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +""" +Example of using ClaudeSDKClient with FastAPI. + +This example demonstrates the CORRECT way to use ClaudeSDKClient in +FastAPI/Starlette applications, where each request runs in a different +async task. +""" + +from claude_agent_sdk import ClaudeSDKClient, TaskContextError + +# Install FastAPI: pip install fastapi uvicorn +# Run: uvicorn fastapi_example:app --reload + +try: + from fastapi import FastAPI +except ImportError: + print("FastAPI not installed. Install with: pip install fastapi uvicorn") + exit(1) + +app = FastAPI() + + +# WRONG - Do not do this! Will raise TaskContextError: +# +# client = ClaudeSDKClient() +# +# @app.on_event("startup") +# async def startup(): +# await client.connect() # Task A (startup) +# +# @app.post("/query") +# async def endpoint(prompt: str): +# # This is Task B (request handler) - will raise TaskContextError! +# async for msg in client.receive_messages(): +# yield msg + + +# CORRECT - Create client per request: +@app.post("/query") +async def query_endpoint(prompt: str): + """Handle query request in the request's task context.""" + async with ClaudeSDKClient() as client: + # Connect and use in the same task (the request handler) + await client.query(prompt) + + results = [] + async for msg in client.receive_messages(): + # Process messages + results.append(msg) + + return {"results": len(results)} + + +# Error handling example: +@app.post("/safe-query") +async def safe_query_endpoint(prompt: str): + """Handle query with proper error handling.""" + try: + async with ClaudeSDKClient() as client: + await client.query(prompt) + + results = [] + async for msg in client.receive_messages(): + results.append(msg) + + return {"results": len(results)} + + except TaskContextError as e: + # This shouldn't happen with the correct pattern above, + # but demonstrates error handling + return { + "error": "Task context error", + "connect_task": e.connect_task_id, + "current_task": e.current_task_id, + } + + +if __name__ == "__main__": + import uvicorn + + print("Starting FastAPI server...") + print("Example endpoints:") + print(" POST http://localhost:8000/query") + print(" POST http://localhost:8000/safe-query") + print("\nTest with:") + print(' curl -X POST "http://localhost:8000/query" -H "Content-Type: application/json" -d \'{"prompt":"Hello"}\'') + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/src/claude_agent_sdk/__init__.py b/src/claude_agent_sdk/__init__.py index 379d6d89..bad1e071 100644 --- a/src/claude_agent_sdk/__init__.py +++ b/src/claude_agent_sdk/__init__.py @@ -12,6 +12,7 @@ CLIJSONDecodeError, CLINotFoundError, ProcessError, + TaskContextError, ) from ._internal.transport import Transport from ._version import __version__ @@ -395,4 +396,5 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "CLINotFoundError", "ProcessError", "CLIJSONDecodeError", + "TaskContextError", ] diff --git a/src/claude_agent_sdk/_errors.py b/src/claude_agent_sdk/_errors.py index c86bf235..55f139f8 100644 --- a/src/claude_agent_sdk/_errors.py +++ b/src/claude_agent_sdk/_errors.py @@ -54,3 +54,56 @@ class MessageParseError(ClaudeSDKError): def __init__(self, message: str, data: dict[str, Any] | None = None): self.data = data super().__init__(message) + + +class TaskContextError(ClaudeSDKError): + """Raised when ClaudeSDKClient is used across different async tasks. + + This occurs when connect() is called in one async task (e.g., FastAPI startup) + and receive_messages() is called in a different task (e.g., request handler). + + The anyio MemoryObjectStream used internally is bound to the task group + where it was created, and cannot be accessed from a different task. + + Attributes: + connect_task_id: ID of the task where connect() was called + current_task_id: ID of the task where the error occurred + + Example: + Wrong way (will raise this error): + ```python + client = ClaudeSDKClient() + await client.connect() # Task A + + async def handle_request(): + async for msg in client.receive_messages(): # Task B - ERROR! + yield msg + ``` + + Correct way: + ```python + async def handle_request(): + async with ClaudeSDKClient() as client: # Create in same task + await client.query("Hello") + async for msg in client.receive_messages(): + yield msg + ``` + """ + + def __init__( + self, + message: str = "Client used across different async tasks", + connect_task_id: int | None = None, + current_task_id: int | None = None, + ): + self.connect_task_id = connect_task_id + self.current_task_id = current_task_id + + # Build helpful error message + if connect_task_id is not None and current_task_id is not None: + message = ( + f"{message} (connect task: {connect_task_id}, " + f"current task: {current_task_id})" + ) + + super().__init__(message) diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 8f278428..c62ac995 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -105,6 +105,7 @@ def __init__( self._message_send, self._message_receive = anyio.create_memory_object_stream[ dict[str, Any] ](max_buffer_size=100) + self._owner_task: anyio.TaskInfo | None = None # Track task that owns this Query self._tg: anyio.abc.TaskGroup | None = None self._initialized = False self._closed = False @@ -167,6 +168,11 @@ async def start(self) -> None: if self._tg is None: self._tg = anyio.create_task_group() await self._tg.__aenter__() + + # Capture the task that owns this Query instance + # AnyIO streams are task-group-bound and will hang if accessed from different task + self._owner_task = anyio.get_current_task() + self._tg.start_soon(self._read_messages) async def _read_messages(self) -> None: @@ -603,6 +609,22 @@ async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None: async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: """Receive SDK messages (not control messages).""" + # Verify we're in the same task that created this Query + # The anyio stream is task-group-bound and will hang otherwise + if self._owner_task is None: + raise RuntimeError("Query.start() must be called before receive_messages()") + + current_task = anyio.get_current_task() + if current_task.id != self._owner_task.id: + from .._errors import TaskContextError + raise TaskContextError( + "ClaudeSDKClient cannot be used across different async tasks. " + "Each async task must create its own client instance. " + "See TaskContextError documentation for details.", + connect_task_id=self._owner_task.id, + current_task_id=current_task.id, + ) + async for message in self._message_receive: # Check for special messages if message.get("type") == "end": diff --git a/src/claude_agent_sdk/client.py b/src/claude_agent_sdk/client.py index f26d793a..b169cbca 100644 --- a/src/claude_agent_sdk/client.py +++ b/src/claude_agent_sdk/client.py @@ -47,9 +47,17 @@ class ClaudeSDKClient: different async runtime contexts (e.g., different trio nurseries or asyncio task groups). The client internally maintains a persistent anyio task group for reading messages that remains active from connect() until disconnect(). + This means you must complete all operations with the client within the same - async context where it was connected. Ideally, this limitation should not - exist. + async context where it was connected. Attempting to use it from a different + task will raise a TaskContextError with a clear explanation. + + Common scenario where this occurs: FastAPI/Starlette applications that + create a global client during startup and try to use it in request handlers. + Each request runs in a different task, so the client must be created per-request + instead. See examples/fastapi_example.py for correct patterns. + + Ideally, this limitation should not exist. """ def __init__( diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 29294419..a827abe0 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -492,84 +492,11 @@ async def _test(): anyio.run(_test) - def test_concurrent_send_receive(self): - """Test concurrent sending and receiving messages.""" - - async def _test(): - with patch( - "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" - ) as mock_transport_class: - mock_transport = create_mock_transport() - mock_transport_class.return_value = mock_transport - - # Mock receive to wait then yield messages with control protocol support - async def mock_receive(): - # First handle initialization - await asyncio.sleep(0.01) - written = mock_transport.write.call_args_list - for call in written: - if call: - data = call[0][0] - try: - msg = json.loads(data.strip()) - if ( - msg.get("type") == "control_request" - and msg.get("request", {}).get("subtype") - == "initialize" - ): - yield { - "type": "control_response", - "response": { - "request_id": msg.get("request_id"), - "subtype": "success", - "commands": [], - "output_style": "default", - }, - } - break - except (json.JSONDecodeError, KeyError, AttributeError): - pass - - # Then yield the actual messages - await asyncio.sleep(0.1) - yield { - "type": "assistant", - "message": { - "role": "assistant", - "content": [{"type": "text", "text": "Response 1"}], - "model": "claude-opus-4-1-20250805", - }, - } - await asyncio.sleep(0.1) - yield { - "type": "result", - "subtype": "success", - "duration_ms": 1000, - "duration_api_ms": 800, - "is_error": False, - "num_turns": 1, - "session_id": "test", - "total_cost_usd": 0.001, - } - - mock_transport.read_messages = mock_receive - - async with ClaudeSDKClient() as client: - # Helper to get next message - async def get_next_message(): - return await client.receive_response().__anext__() - - # Start receiving in background - receive_task = asyncio.create_task(get_next_message()) - - # Send message while receiving - await client.query("Question 1") - - # Wait for first message - first_msg = await receive_task - assert isinstance(first_msg, AssistantMessage) - - anyio.run(_test) + # NOTE: test_concurrent_send_receive was removed because it tested cross-task usage + # (using asyncio.create_task to call receive_messages() from a different task than + # where connect() was called). This pattern causes silent hangs with anyio streams + # and is now prevented by TaskContextError. See test_task_context.py for correct + # usage patterns. class TestQueryWithAsyncIterable: diff --git a/tests/test_task_context.py b/tests/test_task_context.py new file mode 100644 index 00000000..3123e7ee --- /dev/null +++ b/tests/test_task_context.py @@ -0,0 +1,158 @@ +"""Tests for task context validation in ClaudeSDKClient.""" + +import asyncio +from unittest.mock import AsyncMock + +import anyio +import pytest + +from claude_agent_sdk import TaskContextError +from claude_agent_sdk._internal.query import Query + + +def create_mock_transport(): + """Create a minimal mock transport for testing.""" + mock_transport = AsyncMock() + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.end_input = AsyncMock() + + async def mock_read(): + # Just yield messages and end + yield {"type": "end"} + + mock_transport.read_messages = mock_read + return mock_transport + + +class TestTaskContextValidation: + """Test task context detection and error handling.""" + + def test_receive_messages_same_task(self): + """Test that receive_messages works in the same task as start.""" + + async def _test(): + mock_transport = create_mock_transport() + query = Query( + transport=mock_transport, + is_streaming_mode=True, + ) + + await query.start() + # Should work fine - same task + count = 0 + async for _ in query.receive_messages(): + count += 1 + + await query.close() + + anyio.run(_test) + + def test_receive_messages_different_task_raises_error(self): + """Test that receive_messages raises TaskContextError from different task.""" + + async def _test(): + mock_transport = create_mock_transport() + query = Query( + transport=mock_transport, + is_streaming_mode=True, + ) + + # Connect in task A (current task) + await query.start() + + # Try to receive in task B + async def receive_in_different_task(): + async for _ in query.receive_messages(): + pass # Should never reach here + + receive_task = asyncio.create_task(receive_in_different_task()) + + # Should raise TaskContextError + with pytest.raises(TaskContextError) as exc_info: + await receive_task + + # Verify error attributes + error = exc_info.value + assert error.connect_task_id is not None + assert error.current_task_id is not None + assert error.connect_task_id != error.current_task_id + + # Verify error message is helpful + error_msg = str(error) + assert "different async task" in error_msg + assert "connect task" in error_msg + assert "current task" in error_msg + + await query.close() + + anyio.run(_test) + + def test_multiple_query_objects_different_tasks(self): + """Test that multiple Query objects can coexist in different tasks.""" + + async def _test(): + async def use_query_in_task(): + mock_transport = create_mock_transport() + query = Query( + transport=mock_transport, + is_streaming_mode=True, + ) + await query.start() + async for _ in query.receive_messages(): + pass # Should work fine + await query.close() + + # Create two queries in different tasks + task1 = asyncio.create_task(use_query_in_task()) + task2 = asyncio.create_task(use_query_in_task()) + + # Both should succeed + await asyncio.gather(task1, task2) + + anyio.run(_test) + + def test_receive_messages_without_start_raises_error(self): + """Test that receive_messages raises RuntimeError if start() not called.""" + + async def _test(): + mock_transport = create_mock_transport() + query = Query( + transport=mock_transport, + is_streaming_mode=True, + ) + + # Don't call start() + with pytest.raises(RuntimeError) as exc_info: + async for _ in query.receive_messages(): + pass + + assert "Query.start() must be called" in str(exc_info.value) + + anyio.run(_test) + + def test_task_context_error_attributes(self): + """Test TaskContextError stores task IDs correctly.""" + + error = TaskContextError( + "Test error", + connect_task_id=123, + current_task_id=456, + ) + + assert error.connect_task_id == 123 + assert error.current_task_id == 456 + assert "Test error" in str(error) + assert "123" in str(error) + assert "456" in str(error) + + def test_task_context_error_without_task_ids(self): + """Test TaskContextError works without task IDs.""" + + error = TaskContextError("Test error") + + assert error.connect_task_id is None + assert error.current_task_id is None + assert "Test error" in str(error) +