diff --git a/src/claude_agent_sdk/_internal/client.py b/src/claude_agent_sdk/_internal/client.py index 90f535fb..3da100f7 100644 --- a/src/claude_agent_sdk/_internal/client.py +++ b/src/claude_agent_sdk/_internal/client.py @@ -131,7 +131,31 @@ async def process_query( "parent_tool_use_id": None, } await chosen_transport.write(json.dumps(user_message) + "\n") - await chosen_transport.end_input() + + has_bidirectional_needs = bool(sdk_mcp_servers) or bool( + configured_options.hooks + ) + if not has_bidirectional_needs: + # No bidirectional control protocol needed: close stdin immediately + await chosen_transport.end_input() + elif query._tg: + # Defer stdin close until the conversation ends (result message). + # The CLI needs stdin open for the entire conversation to send + # tools/list, tools/call, and hook callbacks via control protocol. + # No timeout needed for string prompts — the result message + # always arrives when the CLI finishes, and task group + # cancellation triggers the finally block on abnormal exit. + async def _deferred_end_input() -> None: + try: + await query._first_result_event.wait() + finally: + await chosen_transport.end_input() + + query._tg.start_soon(_deferred_end_input) + else: + # _tg should always exist after start(), but close stdin + # defensively to prevent resource leaks + await chosen_transport.end_input() elif isinstance(prompt, AsyncIterable) and query._tg: # Stream input in background for async iterables query._tg.start_soon(query.stream_input, prompt) diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 8f278428..3966c89a 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -153,6 +153,8 @@ async def initialize(self) -> dict[str, Any] | None: } if self._agents: request["agents"] = self._agents + if self.sdk_mcp_servers: + request["sdkMcpServers"] = list(self.sdk_mcp_servers.keys()) # Use longer timeout for initialize since MCP servers may take time to start response = await self._send_control_request( diff --git a/src/claude_agent_sdk/types.py b/src/claude_agent_sdk/types.py index 3ea89d5a..387a5435 100644 --- a/src/claude_agent_sdk/types.py +++ b/src/claude_agent_sdk/types.py @@ -802,6 +802,7 @@ class SDKControlInitializeRequest(TypedDict): subtype: Literal["initialize"] hooks: dict[HookEvent, Any] | None agents: NotRequired[dict[str, dict[str, Any]]] + sdkMcpServers: NotRequired[list[str]] class SDKControlSetPermissionModeRequest(TypedDict): diff --git a/tests/test_deferred_end_input.py b/tests/test_deferred_end_input.py new file mode 100644 index 00000000..6514cbd4 --- /dev/null +++ b/tests/test_deferred_end_input.py @@ -0,0 +1,220 @@ +"""Tests for deferred end_input() with SDK MCP servers on string prompts. + +When SDK MCP servers or hooks are present and prompt is a string, +end_input() must be deferred until the first result message is received. +Closing stdin immediately prevents the CLI from completing tools/list +via control protocol, making SDK MCP tools invisible to the model. + +See: client.py process_query() string prompt handling. +""" + +from unittest.mock import AsyncMock, Mock, patch + +import anyio + +from claude_agent_sdk import ClaudeAgentOptions, query + +RESULT_MESSAGE = { + "type": "result", + "subtype": "success", + "duration_ms": 100, + "duration_api_ms": 80, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, +} + + +def _make_mock_transport() -> Mock: + """Create a mock transport with standard async methods.""" + mock_transport = Mock() + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) + + async def mock_receive() -> None: # type: ignore[return] + yield RESULT_MESSAGE + + mock_transport.read_messages = mock_receive + return mock_transport + + +class TestDeferredEndInput: + """Test that end_input() is deferred for string prompts with SDK MCP servers.""" + + def test_string_prompt_without_sdk_mcp_servers_closes_stdin_immediately( + self, + ) -> None: + """Without SDK MCP servers, end_input() should be called right after write.""" + + async def _test() -> None: + mock_transport = _make_mock_transport() + + # Track call order + call_order: list[str] = [] + original_write = mock_transport.write + + async def tracking_write(data: str) -> None: + call_order.append("write") + return await original_write(data) + + async def tracking_end_input() -> None: + call_order.append("end_input") + + mock_transport.write = AsyncMock(side_effect=tracking_write) + mock_transport.end_input = AsyncMock(side_effect=tracking_end_input) + + with ( + patch( + "claude_agent_sdk._internal.client.SubprocessCLITransport", + return_value=mock_transport, + ), + patch( + "claude_agent_sdk._internal.query.Query.initialize", + new_callable=AsyncMock, + ), + ): + options = ClaudeAgentOptions() + async for _ in query(prompt="test prompt", options=options): + pass + + # end_input should be called immediately after write + assert "write" in call_order + assert "end_input" in call_order + write_idx = call_order.index("write") + end_input_idx = call_order.index("end_input") + assert end_input_idx == write_idx + 1, ( + f"end_input should follow write immediately, got order: {call_order}" + ) + + anyio.run(_test) + + def test_string_prompt_with_sdk_mcp_servers_calls_end_input(self) -> None: + """With SDK MCP servers, end_input() must be called (deferred via task group).""" + + async def _test() -> None: + mock_transport = _make_mock_transport() + mock_mcp_server = Mock() + + with ( + patch( + "claude_agent_sdk._internal.client.SubprocessCLITransport", + return_value=mock_transport, + ), + patch( + "claude_agent_sdk._internal.query.Query.initialize", + new_callable=AsyncMock, + ), + ): + options = ClaudeAgentOptions( + mcp_servers={"team": {"type": "sdk", "instance": mock_mcp_server}}, # type: ignore[typeddict-item] + ) + async for _ in query(prompt="test prompt", options=options): + pass + + mock_transport.end_input.assert_called_once() + + anyio.run(_test) + + def test_deferred_end_input_waits_for_result_event(self) -> None: + """end_input() must not be called before the result event fires. + + Uses a delayed result message to verify that end_input waits + for _first_result_event rather than closing stdin immediately. + """ + + async def _test() -> None: + mock_transport = _make_mock_transport() + end_input_called_before_result = False + + # Override read_messages to delay the result + result_gate = anyio.Event() + + async def delayed_receive() -> None: # type: ignore[return] + await result_gate.wait() + yield RESULT_MESSAGE + + mock_transport.read_messages = delayed_receive + + original_end_input = mock_transport.end_input + + async def tracking_end_input() -> None: + nonlocal end_input_called_before_result + if not result_gate.is_set(): + end_input_called_before_result = True + return await original_end_input() + + mock_transport.end_input = AsyncMock(side_effect=tracking_end_input) + + mock_mcp_server = Mock() + + with ( + patch( + "claude_agent_sdk._internal.client.SubprocessCLITransport", + return_value=mock_transport, + ), + patch( + "claude_agent_sdk._internal.query.Query.initialize", + new_callable=AsyncMock, + ), + ): + options = ClaudeAgentOptions( + mcp_servers={"team": {"type": "sdk", "instance": mock_mcp_server}}, # type: ignore[typeddict-item] + ) + + async def consume_and_release() -> None: + # Give task group time to start _deferred_end_input + await anyio.sleep(0.05) + # Now release the result — end_input should NOT have been called yet + result_gate.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(consume_and_release) + async for _ in query(prompt="test prompt", options=options): + pass + + assert not end_input_called_before_result, ( + "end_input must not be called before the result event fires" + ) + mock_transport.end_input.assert_called_once() + + anyio.run(_test) + + def test_end_input_called_even_with_sdk_mcp_servers(self) -> None: + """end_input() must always eventually be called to avoid resource leaks.""" + + async def _test() -> None: + end_input_called = anyio.Event() + + async def tracking_end_input() -> None: + end_input_called.set() + + mock_transport = _make_mock_transport() + mock_transport.end_input = AsyncMock(side_effect=tracking_end_input) + + mock_mcp_server = Mock() + + with ( + patch( + "claude_agent_sdk._internal.client.SubprocessCLITransport", + return_value=mock_transport, + ), + patch( + "claude_agent_sdk._internal.query.Query.initialize", + new_callable=AsyncMock, + ), + ): + options = ClaudeAgentOptions( + mcp_servers={"team": {"type": "sdk", "instance": mock_mcp_server}}, # type: ignore[typeddict-item] + ) + async for _ in query(prompt="test prompt", options=options): + pass + + assert end_input_called.is_set(), ( + "end_input must be called even with SDK MCP servers (deferred, not skipped)" + ) + + anyio.run(_test) diff --git a/tests/test_initialize_sdk_mcp_servers.py b/tests/test_initialize_sdk_mcp_servers.py new file mode 100644 index 00000000..c51c3539 --- /dev/null +++ b/tests/test_initialize_sdk_mcp_servers.py @@ -0,0 +1,184 @@ +"""Tests for Query.initialize() sending sdkMcpServers field. + +The Claude Code CLI registers SDK MCP servers from the sdkMcpServers field +in the initialize control request. Without this field, SDK MCP tools are +invisible to the model. + +See: TypeScript SDK sends sdkMcpServers in initialize request. +""" + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import anyio + +from claude_agent_sdk._internal.query import Query + + +def _create_recording_transport() -> tuple[MagicMock, list[str]]: + """Create a mock transport that records written data.""" + mock_transport = MagicMock() + written_data: list[str] = [] + + async def mock_write(data: str) -> None: + written_data.append(data) + + mock_transport.write = AsyncMock(side_effect=mock_write) + return mock_transport, written_data + + +class TestInitializeSdkMcpServers: + """Test that Query.initialize() includes sdkMcpServers in the request.""" + + def test_initialize_includes_sdk_mcp_server_names(self) -> None: + """initialize() should include sdkMcpServers field when SDK MCP servers exist.""" + + async def _test() -> None: + mock_transport, written_data = _create_recording_transport() + + # Create mock MCP server instances + mock_server_a = MagicMock() + mock_server_b = MagicMock() + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + sdk_mcp_servers={"server-a": mock_server_a, "server-b": mock_server_b}, + ) + + # Intercept _send_control_request to capture the request dict + captured_requests: list[dict[str, Any]] = [] + + async def capturing_send( + request: dict[str, Any], timeout: float = 60.0 + ) -> dict[str, Any]: + captured_requests.append(request.copy()) + return {"supportedCommands": []} + + query._send_control_request = capturing_send # type: ignore[assignment] + + await query.initialize() + + # Verify sdkMcpServers was included in the request + assert len(captured_requests) == 1 + request = captured_requests[0] + assert "sdkMcpServers" in request, ( + "initialize() must include sdkMcpServers field" + ) + assert sorted(request["sdkMcpServers"]) == ["server-a", "server-b"] + + anyio.run(_test) + + def test_initialize_omits_sdk_mcp_servers_when_empty(self) -> None: + """initialize() should not include sdkMcpServers when no SDK MCP servers.""" + + async def _test() -> None: + mock_transport, _ = _create_recording_transport() + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + sdk_mcp_servers={}, + ) + + captured_requests: list[dict[str, Any]] = [] + + async def capturing_send( + request: dict[str, Any], timeout: float = 60.0 + ) -> dict[str, Any]: + captured_requests.append(request.copy()) + return {"supportedCommands": []} + + query._send_control_request = capturing_send # type: ignore[assignment] + + await query.initialize() + + assert len(captured_requests) == 1 + request = captured_requests[0] + assert "sdkMcpServers" not in request + + anyio.run(_test) + + def test_initialize_omits_sdk_mcp_servers_when_none(self) -> None: + """initialize() should not include sdkMcpServers when sdk_mcp_servers is None.""" + + async def _test() -> None: + mock_transport, _ = _create_recording_transport() + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + sdk_mcp_servers=None, + ) + + captured_requests: list[dict[str, Any]] = [] + + async def capturing_send( + request: dict[str, Any], timeout: float = 60.0 + ) -> dict[str, Any]: + captured_requests.append(request.copy()) + return {"supportedCommands": []} + + query._send_control_request = capturing_send # type: ignore[assignment] + + await query.initialize() + + assert len(captured_requests) == 1 + request = captured_requests[0] + assert "sdkMcpServers" not in request + + anyio.run(_test) + + def test_initialize_returns_none_when_not_streaming(self) -> None: + """initialize() should return None and not send request when not streaming.""" + + async def _test() -> None: + mock_transport, _ = _create_recording_transport() + + query = Query( + transport=mock_transport, + is_streaming_mode=False, + sdk_mcp_servers={"server": MagicMock()}, + ) + + result = await query.initialize() + + assert result is None + mock_transport.write.assert_not_called() + + anyio.run(_test) + + def test_initialize_includes_both_agents_and_sdk_mcp_servers(self) -> None: + """initialize() should include both agents and sdkMcpServers when both exist.""" + + async def _test() -> None: + mock_transport, _ = _create_recording_transport() + + mock_server = MagicMock() + agents = {"agent-1": {"name": "Agent One"}} + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + sdk_mcp_servers={"team": mock_server}, + agents=agents, + ) + + captured_requests: list[dict[str, Any]] = [] + + async def capturing_send( + request: dict[str, Any], timeout: float = 60.0 + ) -> dict[str, Any]: + captured_requests.append(request.copy()) + return {"supportedCommands": []} + + query._send_control_request = capturing_send # type: ignore[assignment] + + await query.initialize() + + assert len(captured_requests) == 1 + request = captured_requests[0] + assert request["sdkMcpServers"] == ["team"] + assert request["agents"] == agents + + anyio.run(_test)