diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index e0cd1ebc89..23f14421bd 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -99,12 +99,17 @@ class SseConnectionParams(BaseModel): server. sse_read_timeout: Timeout in seconds for reading data from the MCP SSE server. + httpx_client_factory: Factory function to create a custom HTTPX client. If + not provided, a default factory will be used. """ + model_config = ConfigDict(arbitrary_types_allowed=True) + url: str headers: dict[str, Any] | None = None timeout: float = 5.0 sse_read_timeout: float = 60 * 5.0 + httpx_client_factory: CheckableMcpHttpClientFactory = create_mcp_http_client @runtime_checkable @@ -398,6 +403,7 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): headers=merged_headers, timeout=self._connection_params.timeout, sse_read_timeout=self._connection_params.sse_read_timeout, + httpx_client_factory=self._connection_params.httpx_client_factory, ) elif isinstance(self._connection_params, StreamableHTTPConnectionParams): client = streamablehttp_client( diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 327df114a8..9b1a8f6b7a 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -134,6 +134,49 @@ def test_init_with_sse_connection_params(self): assert manager._connection_params == sse_params + @patch("google.adk.tools.mcp_tool.mcp_session_manager.sse_client") + def test_init_with_sse_custom_httpx_factory(self, mock_sse_client): + """Test that sse_client is called with custom httpx_client_factory.""" + custom_httpx_factory = Mock() + + sse_params = SseConnectionParams( + url="https://example.com/mcp", + timeout=10.0, + httpx_client_factory=custom_httpx_factory, + ) + manager = MCPSessionManager(sse_params) + + manager._create_client() + + mock_sse_client.assert_called_once_with( + url="https://example.com/mcp", + headers=None, + timeout=10.0, + sse_read_timeout=300.0, + httpx_client_factory=custom_httpx_factory, + ) + + @patch("google.adk.tools.mcp_tool.mcp_session_manager.sse_client") + def test_init_with_sse_default_httpx_factory(self, mock_sse_client): + """Test that sse_client is called with default httpx_client_factory.""" + sse_params = SseConnectionParams( + url="https://example.com/mcp", + timeout=10.0, + ) + manager = MCPSessionManager(sse_params) + + manager._create_client() + + mock_sse_client.assert_called_once_with( + url="https://example.com/mcp", + headers=None, + timeout=10.0, + sse_read_timeout=300.0, + httpx_client_factory=SseConnectionParams.model_fields[ + "httpx_client_factory" + ].get_default(), + ) + def test_init_with_streamable_http_params(self): """Test initialization with StreamableHTTPConnectionParams.""" http_params = StreamableHTTPConnectionParams(