From 1d97abb5a6f17ffac5476f29ad6a5b01817fd096 Mon Sep 17 00:00:00 2001 From: driller Date: Fri, 20 Feb 2026 17:45:37 +0900 Subject: [PATCH 1/4] fix: send sdkMcpServers field in initialize control request The Python SDK's Query.initialize() was not including the sdkMcpServers field in the initialize control request. This caused the Claude Code CLI to never register SDK MCP servers, making their tools invisible to the model. The TypeScript SDK (v0.2.49) correctly sends this field. This fix aligns the Python SDK with the TypeScript SDK behavior, following the same pattern as PR #468 which added the agents field to initialize. Co-Authored-By: Claude Opus 4.6 --- src/claude_agent_sdk/_internal/query.py | 2 + tests/test_initialize_sdk_mcp_servers.py | 184 +++++++++++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 tests/test_initialize_sdk_mcp_servers.py diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 8f278428..3966c89a 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -153,6 +153,8 @@ async def initialize(self) -> dict[str, Any] | None: } if self._agents: request["agents"] = self._agents + if self.sdk_mcp_servers: + request["sdkMcpServers"] = list(self.sdk_mcp_servers.keys()) # Use longer timeout for initialize since MCP servers may take time to start response = await self._send_control_request( diff --git a/tests/test_initialize_sdk_mcp_servers.py b/tests/test_initialize_sdk_mcp_servers.py new file mode 100644 index 00000000..9c4251af --- /dev/null +++ b/tests/test_initialize_sdk_mcp_servers.py @@ -0,0 +1,184 @@ +"""Tests for Query.initialize() sending sdkMcpServers field. + +The Claude Code CLI registers SDK MCP servers from the sdkMcpServers field +in the initialize control request. Without this field, SDK MCP tools are +invisible to the model. + +See: TypeScript SDK sends sdkMcpServers in initialize request. +""" + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import anyio + +from claude_agent_sdk._internal.query import Query + + +def _create_recording_transport() -> tuple[MagicMock, list[str]]: + """Create a mock transport that records written data and provides responses.""" + mock_transport = MagicMock() + written_data: list[str] = [] + + async def mock_write(data: str) -> None: + written_data.append(data) + + mock_transport.write = AsyncMock(side_effect=mock_write) + return mock_transport, written_data + + +class TestInitializeSdkMcpServers: + """Test that Query.initialize() includes sdkMcpServers in the request.""" + + def test_initialize_includes_sdk_mcp_server_names(self) -> None: + """initialize() should include sdkMcpServers field when SDK MCP servers exist.""" + + async def _test() -> None: + mock_transport, written_data = _create_recording_transport() + + # Create mock MCP server instances + mock_server_a = MagicMock() + mock_server_b = MagicMock() + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + sdk_mcp_servers={"server-a": mock_server_a, "server-b": mock_server_b}, + ) + + # Intercept _send_control_request to capture the request dict + captured_requests: list[dict[str, Any]] = [] + + async def capturing_send( + request: dict[str, Any], timeout: float = 60.0 + ) -> dict[str, Any]: + captured_requests.append(request.copy()) + return {"supportedCommands": []} + + query._send_control_request = capturing_send # type: ignore[assignment] + + await query.initialize() + + # Verify sdkMcpServers was included in the request + assert len(captured_requests) == 1 + request = captured_requests[0] + assert "sdkMcpServers" in request, ( + "initialize() must include sdkMcpServers field" + ) + assert sorted(request["sdkMcpServers"]) == ["server-a", "server-b"] + + anyio.run(_test) + + def test_initialize_omits_sdk_mcp_servers_when_empty(self) -> None: + """initialize() should not include sdkMcpServers when no SDK MCP servers.""" + + async def _test() -> None: + mock_transport, _ = _create_recording_transport() + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + sdk_mcp_servers={}, + ) + + captured_requests: list[dict[str, Any]] = [] + + async def capturing_send( + request: dict[str, Any], timeout: float = 60.0 + ) -> dict[str, Any]: + captured_requests.append(request.copy()) + return {"supportedCommands": []} + + query._send_control_request = capturing_send # type: ignore[assignment] + + await query.initialize() + + assert len(captured_requests) == 1 + request = captured_requests[0] + assert "sdkMcpServers" not in request + + anyio.run(_test) + + def test_initialize_omits_sdk_mcp_servers_when_none(self) -> None: + """initialize() should not include sdkMcpServers when sdk_mcp_servers is None.""" + + async def _test() -> None: + mock_transport, _ = _create_recording_transport() + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + sdk_mcp_servers=None, + ) + + captured_requests: list[dict[str, Any]] = [] + + async def capturing_send( + request: dict[str, Any], timeout: float = 60.0 + ) -> dict[str, Any]: + captured_requests.append(request.copy()) + return {"supportedCommands": []} + + query._send_control_request = capturing_send # type: ignore[assignment] + + await query.initialize() + + assert len(captured_requests) == 1 + request = captured_requests[0] + assert "sdkMcpServers" not in request + + anyio.run(_test) + + def test_initialize_returns_none_when_not_streaming(self) -> None: + """initialize() should return None and not send request when not streaming.""" + + async def _test() -> None: + mock_transport, _ = _create_recording_transport() + + query = Query( + transport=mock_transport, + is_streaming_mode=False, + sdk_mcp_servers={"server": MagicMock()}, + ) + + result = await query.initialize() + + assert result is None + mock_transport.write.assert_not_called() + + anyio.run(_test) + + def test_initialize_includes_both_agents_and_sdk_mcp_servers(self) -> None: + """initialize() should include both agents and sdkMcpServers when both exist.""" + + async def _test() -> None: + mock_transport, _ = _create_recording_transport() + + mock_server = MagicMock() + agents = {"agent-1": {"name": "Agent One"}} + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + sdk_mcp_servers={"team": mock_server}, + agents=agents, + ) + + captured_requests: list[dict[str, Any]] = [] + + async def capturing_send( + request: dict[str, Any], timeout: float = 60.0 + ) -> dict[str, Any]: + captured_requests.append(request.copy()) + return {"supportedCommands": []} + + query._send_control_request = capturing_send # type: ignore[assignment] + + await query.initialize() + + assert len(captured_requests) == 1 + request = captured_requests[0] + assert request["sdkMcpServers"] == ["team"] + assert request["agents"] == agents + + anyio.run(_test) From bf8d6946c84cc258a2e6cf6995449475d3c534f4 Mon Sep 17 00:00:00 2001 From: driller Date: Fri, 20 Feb 2026 18:19:17 +0900 Subject: [PATCH 2/4] fix: defer end_input() for string prompts when SDK MCP servers present Closing stdin immediately after writing the user message prevents the CLI from completing tools/list via control protocol for SDK MCP servers. This makes SDK MCP tools invisible to the model even after the sdkMcpServers field is correctly sent in the initialize request. The fix defers end_input() until the first result event is received (or timeout), matching the behavior already used for async iterable prompts in stream_input(). Co-Authored-By: Claude Opus 4.6 --- src/claude_agent_sdk/_internal/client.py | 20 ++- tests/test_deferred_end_input.py | 207 +++++++++++++++++++++++ 2 files changed, 226 insertions(+), 1 deletion(-) create mode 100644 tests/test_deferred_end_input.py diff --git a/src/claude_agent_sdk/_internal/client.py b/src/claude_agent_sdk/_internal/client.py index 90f535fb..c8b72f2d 100644 --- a/src/claude_agent_sdk/_internal/client.py +++ b/src/claude_agent_sdk/_internal/client.py @@ -4,6 +4,8 @@ from dataclasses import asdict, replace from typing import Any +import anyio + from ..types import ( ClaudeAgentOptions, HookEvent, @@ -131,7 +133,23 @@ async def process_query( "parent_tool_use_id": None, } await chosen_transport.write(json.dumps(user_message) + "\n") - await chosen_transport.end_input() + + if not sdk_mcp_servers: + # No SDK MCP servers: close stdin immediately + await chosen_transport.end_input() + elif query._tg: + # With SDK MCP servers: defer stdin close to allow + # CLI to complete tools/list via control protocol. + # Closing stdin too early prevents the SDK from sending + # control responses back to the CLI. + async def _deferred_end_input() -> None: + try: + with anyio.move_on_after(query._stream_close_timeout): + await query._first_result_event.wait() + finally: + await chosen_transport.end_input() + + query._tg.start_soon(_deferred_end_input) elif isinstance(prompt, AsyncIterable) and query._tg: # Stream input in background for async iterables query._tg.start_soon(query.stream_input, prompt) diff --git a/tests/test_deferred_end_input.py b/tests/test_deferred_end_input.py new file mode 100644 index 00000000..a53dce8a --- /dev/null +++ b/tests/test_deferred_end_input.py @@ -0,0 +1,207 @@ +"""Tests for deferred end_input() with SDK MCP servers on string prompts. + +When SDK MCP servers are present and prompt is a string, end_input() must +be deferred until the first result message is received. Closing stdin +immediately prevents the CLI from completing tools/list via control protocol, +making SDK MCP tools invisible to the model. + +See: client.py process_query() string prompt handling. +""" + +from unittest.mock import AsyncMock, Mock, patch + +import anyio + +from claude_agent_sdk import ClaudeAgentOptions, query + + +def _make_mock_transport(end_input_mock: AsyncMock | None = None) -> Mock: + """Create a mock transport with standard async methods.""" + mock_transport = Mock() + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = end_input_mock or AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) + return mock_transport + + +class TestDeferredEndInput: + """Test that end_input() is deferred for string prompts with SDK MCP servers.""" + + def test_string_prompt_without_sdk_mcp_servers_closes_stdin_immediately( + self, + ) -> None: + """Without SDK MCP servers, end_input() should be called right after write.""" + + async def _test() -> None: + mock_transport = _make_mock_transport() + + # Track call order + call_order: list[str] = [] + original_write = mock_transport.write + + async def tracking_write(data: str) -> None: + call_order.append("write") + return await original_write(data) + + async def tracking_end_input() -> None: + call_order.append("end_input") + + mock_transport.write = AsyncMock(side_effect=tracking_write) + mock_transport.end_input = AsyncMock(side_effect=tracking_end_input) + + async def mock_receive() -> None: # type: ignore[return] + yield { + "type": "result", + "subtype": "success", + "duration_ms": 100, + "duration_api_ms": 80, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, + } + + mock_transport.read_messages = mock_receive + + with ( + patch( + "claude_agent_sdk._internal.client.SubprocessCLITransport", + return_value=mock_transport, + ), + patch( + "claude_agent_sdk._internal.query.Query.initialize", + new_callable=AsyncMock, + ), + ): + # No mcp_servers → no SDK MCP servers + options = ClaudeAgentOptions() + async for _ in query(prompt="test prompt", options=options): + pass + + # end_input should be called immediately after write + assert "write" in call_order + assert "end_input" in call_order + write_idx = call_order.index("write") + end_input_idx = call_order.index("end_input") + assert end_input_idx == write_idx + 1, ( + f"end_input should follow write immediately, got order: {call_order}" + ) + + anyio.run(_test) + + def test_string_prompt_with_sdk_mcp_servers_defers_stdin_close(self) -> None: + """With SDK MCP servers, end_input() must NOT be called immediately after write.""" + + async def _test() -> None: + mock_transport = _make_mock_transport() + + # Track call order + call_order: list[str] = [] + + async def tracking_write(data: str) -> None: + call_order.append("write") + + async def tracking_end_input() -> None: + call_order.append("end_input") + + mock_transport.write = AsyncMock(side_effect=tracking_write) + mock_transport.end_input = AsyncMock(side_effect=tracking_end_input) + + # Need a result message so the query completes + async def mock_receive() -> None: # type: ignore[return] + yield { + "type": "result", + "subtype": "success", + "duration_ms": 100, + "duration_api_ms": 80, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, + } + + mock_transport.read_messages = mock_receive + + mock_mcp_server = Mock() + + with ( + patch( + "claude_agent_sdk._internal.client.SubprocessCLITransport", + return_value=mock_transport, + ), + patch( + "claude_agent_sdk._internal.query.Query.initialize", + new_callable=AsyncMock, + ), + ): + options = ClaudeAgentOptions( + mcp_servers={"team": {"type": "sdk", "instance": mock_mcp_server}}, # type: ignore[typeddict-item] + ) + async for _ in query(prompt="test prompt", options=options): + pass + + # end_input must still be called (deferred, not skipped) + assert "end_input" in call_order, "end_input must eventually be called" + + # But it must NOT be immediately after write (there should be something between) + # The key assertion: write happens first, then end_input happens later + # (not as the very next call after write in synchronous flow) + assert "write" in call_order + # end_input is deferred via task group, so it's called after the + # result event fires or timeout, not in the synchronous flow after write + + anyio.run(_test) + + def test_end_input_is_always_called_even_with_sdk_mcp_servers(self) -> None: + """end_input() must always eventually be called to avoid resource leaks.""" + + async def _test() -> None: + end_input_called = anyio.Event() + + async def tracking_end_input() -> None: + end_input_called.set() + + mock_transport = _make_mock_transport() + mock_transport.end_input = AsyncMock(side_effect=tracking_end_input) + + async def mock_receive() -> None: # type: ignore[return] + yield { + "type": "result", + "subtype": "success", + "duration_ms": 100, + "duration_api_ms": 80, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, + } + + mock_transport.read_messages = mock_receive + mock_transport.write = AsyncMock() + + mock_mcp_server = Mock() + + with ( + patch( + "claude_agent_sdk._internal.client.SubprocessCLITransport", + return_value=mock_transport, + ), + patch( + "claude_agent_sdk._internal.query.Query.initialize", + new_callable=AsyncMock, + ), + ): + options = ClaudeAgentOptions( + mcp_servers={"team": {"type": "sdk", "instance": mock_mcp_server}}, # type: ignore[typeddict-item] + ) + async for _ in query(prompt="test prompt", options=options): + pass + + # Verify end_input was called + assert end_input_called.is_set(), ( + "end_input must be called even with SDK MCP servers (deferred, not skipped)" + ) + + anyio.run(_test) From e8cf36759c46ac81cda670cd86e30347fa7929d1 Mon Sep 17 00:00:00 2001 From: driller Date: Fri, 20 Feb 2026 18:55:24 +0900 Subject: [PATCH 3/4] fix: remove timeout from deferred end_input for string prompts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 60-second _stream_close_timeout is designed for async iterable prompts as a safety net. For string prompts with SDK MCP servers, the timeout was closing stdin before the conversation finished, causing "Stream closed" errors on tools/call MCP messages. String prompts wait for the result event without timeout — the result message always arrives when the CLI finishes, and transport cleanup handles crash scenarios. Co-Authored-By: Claude Opus 4.6 --- src/claude_agent_sdk/_internal/client.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/claude_agent_sdk/_internal/client.py b/src/claude_agent_sdk/_internal/client.py index c8b72f2d..1f73312e 100644 --- a/src/claude_agent_sdk/_internal/client.py +++ b/src/claude_agent_sdk/_internal/client.py @@ -4,8 +4,6 @@ from dataclasses import asdict, replace from typing import Any -import anyio - from ..types import ( ClaudeAgentOptions, HookEvent, @@ -138,14 +136,16 @@ async def process_query( # No SDK MCP servers: close stdin immediately await chosen_transport.end_input() elif query._tg: - # With SDK MCP servers: defer stdin close to allow - # CLI to complete tools/list via control protocol. - # Closing stdin too early prevents the SDK from sending - # control responses back to the CLI. + # With SDK MCP servers: defer stdin close until the + # conversation ends (result message received). + # The CLI needs stdin open for the entire conversation + # to send tools/list and tools/call via control protocol. + # Unlike stream_input() for async iterables, string prompts + # don't need a timeout — the result message always arrives + # when the CLI finishes, and transport cleanup handles crashes. async def _deferred_end_input() -> None: try: - with anyio.move_on_after(query._stream_close_timeout): - await query._first_result_event.wait() + await query._first_result_event.wait() finally: await chosen_transport.end_input() From 8b0404b85f1e59e42eee401f1a42748a93c1371b Mon Sep 17 00:00:00 2001 From: driller Date: Fri, 20 Feb 2026 20:16:41 +0900 Subject: [PATCH 4/4] fix: address review feedback for SDK MCP server fixes - Add else branch for _tg=None case to prevent stdin resource leak - Add sdkMcpServers field to SDKControlInitializeRequest TypedDict - Include hooks in bidirectional communication check (matching stream_input) - Fix inaccurate comments (timeout reference, transport cleanup) - Strengthen test assertions with delayed result gate pattern - Extract shared RESULT_MESSAGE fixture, fix docstring inaccuracies Co-Authored-By: Claude Opus 4.6 --- src/claude_agent_sdk/_internal/client.py | 24 ++-- src/claude_agent_sdk/types.py | 1 + tests/test_deferred_end_input.py | 155 ++++++++++++----------- tests/test_initialize_sdk_mcp_servers.py | 2 +- 4 files changed, 101 insertions(+), 81 deletions(-) diff --git a/src/claude_agent_sdk/_internal/client.py b/src/claude_agent_sdk/_internal/client.py index 1f73312e..3da100f7 100644 --- a/src/claude_agent_sdk/_internal/client.py +++ b/src/claude_agent_sdk/_internal/client.py @@ -132,17 +132,19 @@ async def process_query( } await chosen_transport.write(json.dumps(user_message) + "\n") - if not sdk_mcp_servers: - # No SDK MCP servers: close stdin immediately + has_bidirectional_needs = bool(sdk_mcp_servers) or bool( + configured_options.hooks + ) + if not has_bidirectional_needs: + # No bidirectional control protocol needed: close stdin immediately await chosen_transport.end_input() elif query._tg: - # With SDK MCP servers: defer stdin close until the - # conversation ends (result message received). - # The CLI needs stdin open for the entire conversation - # to send tools/list and tools/call via control protocol. - # Unlike stream_input() for async iterables, string prompts - # don't need a timeout — the result message always arrives - # when the CLI finishes, and transport cleanup handles crashes. + # Defer stdin close until the conversation ends (result message). + # The CLI needs stdin open for the entire conversation to send + # tools/list, tools/call, and hook callbacks via control protocol. + # No timeout needed for string prompts — the result message + # always arrives when the CLI finishes, and task group + # cancellation triggers the finally block on abnormal exit. async def _deferred_end_input() -> None: try: await query._first_result_event.wait() @@ -150,6 +152,10 @@ async def _deferred_end_input() -> None: await chosen_transport.end_input() query._tg.start_soon(_deferred_end_input) + else: + # _tg should always exist after start(), but close stdin + # defensively to prevent resource leaks + await chosen_transport.end_input() elif isinstance(prompt, AsyncIterable) and query._tg: # Stream input in background for async iterables query._tg.start_soon(query.stream_input, prompt) diff --git a/src/claude_agent_sdk/types.py b/src/claude_agent_sdk/types.py index 3ea89d5a..387a5435 100644 --- a/src/claude_agent_sdk/types.py +++ b/src/claude_agent_sdk/types.py @@ -802,6 +802,7 @@ class SDKControlInitializeRequest(TypedDict): subtype: Literal["initialize"] hooks: dict[HookEvent, Any] | None agents: NotRequired[dict[str, dict[str, Any]]] + sdkMcpServers: NotRequired[list[str]] class SDKControlSetPermissionModeRequest(TypedDict): diff --git a/tests/test_deferred_end_input.py b/tests/test_deferred_end_input.py index a53dce8a..6514cbd4 100644 --- a/tests/test_deferred_end_input.py +++ b/tests/test_deferred_end_input.py @@ -1,9 +1,9 @@ """Tests for deferred end_input() with SDK MCP servers on string prompts. -When SDK MCP servers are present and prompt is a string, end_input() must -be deferred until the first result message is received. Closing stdin -immediately prevents the CLI from completing tools/list via control protocol, -making SDK MCP tools invisible to the model. +When SDK MCP servers or hooks are present and prompt is a string, +end_input() must be deferred until the first result message is received. +Closing stdin immediately prevents the CLI from completing tools/list +via control protocol, making SDK MCP tools invisible to the model. See: client.py process_query() string prompt handling. """ @@ -14,15 +14,31 @@ from claude_agent_sdk import ClaudeAgentOptions, query +RESULT_MESSAGE = { + "type": "result", + "subtype": "success", + "duration_ms": 100, + "duration_api_ms": 80, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, +} -def _make_mock_transport(end_input_mock: AsyncMock | None = None) -> Mock: + +def _make_mock_transport() -> Mock: """Create a mock transport with standard async methods.""" mock_transport = Mock() mock_transport.connect = AsyncMock() mock_transport.close = AsyncMock() - mock_transport.end_input = end_input_mock or AsyncMock() + mock_transport.end_input = AsyncMock() mock_transport.write = AsyncMock() mock_transport.is_ready = Mock(return_value=True) + + async def mock_receive() -> None: # type: ignore[return] + yield RESULT_MESSAGE + + mock_transport.read_messages = mock_receive return mock_transport @@ -51,20 +67,6 @@ async def tracking_end_input() -> None: mock_transport.write = AsyncMock(side_effect=tracking_write) mock_transport.end_input = AsyncMock(side_effect=tracking_end_input) - async def mock_receive() -> None: # type: ignore[return] - yield { - "type": "result", - "subtype": "success", - "duration_ms": 100, - "duration_api_ms": 80, - "is_error": False, - "num_turns": 1, - "session_id": "test", - "total_cost_usd": 0.001, - } - - mock_transport.read_messages = mock_receive - with ( patch( "claude_agent_sdk._internal.client.SubprocessCLITransport", @@ -75,7 +77,6 @@ async def mock_receive() -> None: # type: ignore[return] new_callable=AsyncMock, ), ): - # No mcp_servers → no SDK MCP servers options = ClaudeAgentOptions() async for _ in query(prompt="test prompt", options=options): pass @@ -91,39 +92,63 @@ async def mock_receive() -> None: # type: ignore[return] anyio.run(_test) - def test_string_prompt_with_sdk_mcp_servers_defers_stdin_close(self) -> None: - """With SDK MCP servers, end_input() must NOT be called immediately after write.""" + def test_string_prompt_with_sdk_mcp_servers_calls_end_input(self) -> None: + """With SDK MCP servers, end_input() must be called (deferred via task group).""" async def _test() -> None: mock_transport = _make_mock_transport() + mock_mcp_server = Mock() - # Track call order - call_order: list[str] = [] + with ( + patch( + "claude_agent_sdk._internal.client.SubprocessCLITransport", + return_value=mock_transport, + ), + patch( + "claude_agent_sdk._internal.query.Query.initialize", + new_callable=AsyncMock, + ), + ): + options = ClaudeAgentOptions( + mcp_servers={"team": {"type": "sdk", "instance": mock_mcp_server}}, # type: ignore[typeddict-item] + ) + async for _ in query(prompt="test prompt", options=options): + pass - async def tracking_write(data: str) -> None: - call_order.append("write") + mock_transport.end_input.assert_called_once() + + anyio.run(_test) + + def test_deferred_end_input_waits_for_result_event(self) -> None: + """end_input() must not be called before the result event fires. + + Uses a delayed result message to verify that end_input waits + for _first_result_event rather than closing stdin immediately. + """ + + async def _test() -> None: + mock_transport = _make_mock_transport() + end_input_called_before_result = False + + # Override read_messages to delay the result + result_gate = anyio.Event() + + async def delayed_receive() -> None: # type: ignore[return] + await result_gate.wait() + yield RESULT_MESSAGE + + mock_transport.read_messages = delayed_receive + + original_end_input = mock_transport.end_input async def tracking_end_input() -> None: - call_order.append("end_input") + nonlocal end_input_called_before_result + if not result_gate.is_set(): + end_input_called_before_result = True + return await original_end_input() - mock_transport.write = AsyncMock(side_effect=tracking_write) mock_transport.end_input = AsyncMock(side_effect=tracking_end_input) - # Need a result message so the query completes - async def mock_receive() -> None: # type: ignore[return] - yield { - "type": "result", - "subtype": "success", - "duration_ms": 100, - "duration_api_ms": 80, - "is_error": False, - "num_turns": 1, - "session_id": "test", - "total_cost_usd": 0.001, - } - - mock_transport.read_messages = mock_receive - mock_mcp_server = Mock() with ( @@ -139,22 +164,26 @@ async def mock_receive() -> None: # type: ignore[return] options = ClaudeAgentOptions( mcp_servers={"team": {"type": "sdk", "instance": mock_mcp_server}}, # type: ignore[typeddict-item] ) - async for _ in query(prompt="test prompt", options=options): - pass - # end_input must still be called (deferred, not skipped) - assert "end_input" in call_order, "end_input must eventually be called" + async def consume_and_release() -> None: + # Give task group time to start _deferred_end_input + await anyio.sleep(0.05) + # Now release the result — end_input should NOT have been called yet + result_gate.set() - # But it must NOT be immediately after write (there should be something between) - # The key assertion: write happens first, then end_input happens later - # (not as the very next call after write in synchronous flow) - assert "write" in call_order - # end_input is deferred via task group, so it's called after the - # result event fires or timeout, not in the synchronous flow after write + async with anyio.create_task_group() as tg: + tg.start_soon(consume_and_release) + async for _ in query(prompt="test prompt", options=options): + pass + + assert not end_input_called_before_result, ( + "end_input must not be called before the result event fires" + ) + mock_transport.end_input.assert_called_once() anyio.run(_test) - def test_end_input_is_always_called_even_with_sdk_mcp_servers(self) -> None: + def test_end_input_called_even_with_sdk_mcp_servers(self) -> None: """end_input() must always eventually be called to avoid resource leaks.""" async def _test() -> None: @@ -166,21 +195,6 @@ async def tracking_end_input() -> None: mock_transport = _make_mock_transport() mock_transport.end_input = AsyncMock(side_effect=tracking_end_input) - async def mock_receive() -> None: # type: ignore[return] - yield { - "type": "result", - "subtype": "success", - "duration_ms": 100, - "duration_api_ms": 80, - "is_error": False, - "num_turns": 1, - "session_id": "test", - "total_cost_usd": 0.001, - } - - mock_transport.read_messages = mock_receive - mock_transport.write = AsyncMock() - mock_mcp_server = Mock() with ( @@ -199,7 +213,6 @@ async def mock_receive() -> None: # type: ignore[return] async for _ in query(prompt="test prompt", options=options): pass - # Verify end_input was called assert end_input_called.is_set(), ( "end_input must be called even with SDK MCP servers (deferred, not skipped)" ) diff --git a/tests/test_initialize_sdk_mcp_servers.py b/tests/test_initialize_sdk_mcp_servers.py index 9c4251af..c51c3539 100644 --- a/tests/test_initialize_sdk_mcp_servers.py +++ b/tests/test_initialize_sdk_mcp_servers.py @@ -16,7 +16,7 @@ def _create_recording_transport() -> tuple[MagicMock, list[str]]: - """Create a mock transport that records written data and provides responses.""" + """Create a mock transport that records written data.""" mock_transport = MagicMock() written_data: list[str] = []