From a7d9e968d21ab700ecbc00d19aa3db4c72b9fa34 Mon Sep 17 00:00:00 2001 From: Ray Date: Sun, 15 Mar 2026 23:11:33 +0800 Subject: [PATCH] fix(mcp-tool): add httpx_client_factory support to SseConnectionParams - Update SseConnectionParams model to include httpx_client_factory parameter (matching StreamableHTTPConnectionParams pattern) - Pass httpx_client_factory from SseConnectionParams to underlying sse_client() call - Add model_config to SseConnectionParams to allow arbitrary types for the factory function - Update docstrings to document the new httpx_client_factory attribute This change aligns SSE transport behavior with Streamable HTTP transport for MCP tools, enabling users to customize the httpx.AsyncClient for SSE-based MCP connections (e.g., proxy, auth, TLS config). Closes https://github.com/google/adk-python/issues/4841 --- .../adk/tools/mcp_tool/mcp_session_manager.py | 6 +++ .../mcp_tool/test_mcp_session_manager.py | 43 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index e0cd1ebc89..23f14421bd 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -99,12 +99,17 @@ class SseConnectionParams(BaseModel): server. sse_read_timeout: Timeout in seconds for reading data from the MCP SSE server. + httpx_client_factory: Factory function to create a custom HTTPX client. If + not provided, a default factory will be used. """ + model_config = ConfigDict(arbitrary_types_allowed=True) + url: str headers: dict[str, Any] | None = None timeout: float = 5.0 sse_read_timeout: float = 60 * 5.0 + httpx_client_factory: CheckableMcpHttpClientFactory = create_mcp_http_client @runtime_checkable @@ -398,6 +403,7 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): headers=merged_headers, timeout=self._connection_params.timeout, sse_read_timeout=self._connection_params.sse_read_timeout, + httpx_client_factory=self._connection_params.httpx_client_factory, ) elif isinstance(self._connection_params, StreamableHTTPConnectionParams): client = streamablehttp_client( diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 327df114a8..9b1a8f6b7a 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -134,6 +134,49 @@ def test_init_with_sse_connection_params(self): assert manager._connection_params == sse_params + @patch("google.adk.tools.mcp_tool.mcp_session_manager.sse_client") + def test_init_with_sse_custom_httpx_factory(self, mock_sse_client): + """Test that sse_client is called with custom httpx_client_factory.""" + custom_httpx_factory = Mock() + + sse_params = SseConnectionParams( + url="https://example.com/mcp", + timeout=10.0, + httpx_client_factory=custom_httpx_factory, + ) + manager = MCPSessionManager(sse_params) + + manager._create_client() + + mock_sse_client.assert_called_once_with( + url="https://example.com/mcp", + headers=None, + timeout=10.0, + sse_read_timeout=300.0, + httpx_client_factory=custom_httpx_factory, + ) + + @patch("google.adk.tools.mcp_tool.mcp_session_manager.sse_client") + def test_init_with_sse_default_httpx_factory(self, mock_sse_client): + """Test that sse_client is called with default httpx_client_factory.""" + sse_params = SseConnectionParams( + url="https://example.com/mcp", + timeout=10.0, + ) + manager = MCPSessionManager(sse_params) + + manager._create_client() + + mock_sse_client.assert_called_once_with( + url="https://example.com/mcp", + headers=None, + timeout=10.0, + sse_read_timeout=300.0, + httpx_client_factory=SseConnectionParams.model_fields[ + "httpx_client_factory" + ].get_default(), + ) + def test_init_with_streamable_http_params(self): """Test initialization with StreamableHTTPConnectionParams.""" http_params = StreamableHTTPConnectionParams(