Skip to content

Commit f73a9c5

Browse files
committed
feat: add auth parameter to ClientSessionGroup server parameters
Github-Issue: #1723
1 parent 1e0b5c0 commit f73a9c5

File tree

2 files changed

+105
-1
lines changed

2 files changed

+105
-1
lines changed

src/mcp/client/session_group.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import anyio
1717
import httpx
18-
from pydantic import BaseModel, Field
18+
from pydantic import BaseModel, ConfigDict, Field
1919
from typing_extensions import Self
2020

2121
import mcp
@@ -32,6 +32,8 @@
3232
class SseServerParameters(BaseModel):
3333
"""Parameters for initializing a sse_client."""
3434

35+
model_config = ConfigDict(arbitrary_types_allowed=True)
36+
3537
# The endpoint URL.
3638
url: str
3739

@@ -44,10 +46,15 @@ class SseServerParameters(BaseModel):
4446
# Timeout for SSE read operations (in seconds).
4547
sse_read_timeout: float = 300.0
4648

49+
# Optional HTTPX authentication handler.
50+
auth: httpx.Auth | None = None
51+
4752

4853
class StreamableHttpParameters(BaseModel):
4954
"""Parameters for initializing a streamable_http_client."""
5055

56+
model_config = ConfigDict(arbitrary_types_allowed=True)
57+
5158
# The endpoint URL.
5259
url: str
5360

@@ -63,6 +70,9 @@ class StreamableHttpParameters(BaseModel):
6370
# Close the client session when the transport closes.
6471
terminate_on_close: bool = True
6572

73+
# Optional HTTPX authentication handler.
74+
auth: httpx.Auth | None = None
75+
6676

6777
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
6878

@@ -279,6 +289,7 @@ async def _establish_session(
279289
headers=server_params.headers,
280290
timeout=server_params.timeout,
281291
sse_read_timeout=server_params.sse_read_timeout,
292+
auth=server_params.auth,
282293
)
283294
read, write = await session_stack.enter_async_context(client)
284295
else:
@@ -288,6 +299,7 @@ async def _establish_session(
288299
server_params.timeout,
289300
read=server_params.sse_read_timeout,
290301
),
302+
auth=server_params.auth,
291303
)
292304
await session_stack.enter_async_context(httpx_client)
293305

tests/client/test_session_group.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ async def test_client_session_group_establish_session_parameterized(
355355
headers=server_params_instance.headers,
356356
timeout=server_params_instance.timeout,
357357
sse_read_timeout=server_params_instance.sse_read_timeout,
358+
auth=server_params_instance.auth,
358359
)
359360
elif client_type_name == "streamablehttp": # pragma: no branch
360361
assert isinstance(server_params_instance, StreamableHttpParameters)
@@ -385,3 +386,94 @@ async def test_client_session_group_establish_session_parameterized(
385386
# 3. Assert returned values
386387
assert returned_server_info is mock_initialize_result.server_info
387388
assert returned_session is mock_entered_session
389+
390+
391+
@pytest.mark.anyio
392+
async def test_establish_session_sse_passes_auth():
393+
"""_establish_session should pass auth to sse_client for SseServerParameters."""
394+
mock_auth = mock.Mock(spec=httpx.Auth)
395+
server_params = SseServerParameters(url="http://test.com/sse", auth=mock_auth)
396+
397+
with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
398+
with mock.patch("mcp.client.session_group.sse_client") as mock_sse_client:
399+
# --- Mock sse_client context manager ---
400+
mock_client_cm = mock.AsyncMock()
401+
mock_read = mock.AsyncMock()
402+
mock_write = mock.AsyncMock()
403+
mock_client_cm.__aenter__.return_value = (mock_read, mock_write)
404+
mock_client_cm.__aexit__ = mock.AsyncMock(return_value=None)
405+
mock_sse_client.return_value = mock_client_cm
406+
407+
# --- Mock mcp.ClientSession ---
408+
mock_session_cm = mock.AsyncMock()
409+
mock_ClientSession_class.return_value = mock_session_cm
410+
mock_session = mock.AsyncMock()
411+
mock_session_cm.__aenter__.return_value = mock_session
412+
mock_session_cm.__aexit__ = mock.AsyncMock(return_value=None)
413+
414+
# Mock session.initialize()
415+
mock_result = mock.AsyncMock()
416+
mock_result.server_info = types.Implementation(name="test", version="1")
417+
mock_session.initialize.return_value = mock_result
418+
419+
# --- Test Execution ---
420+
group = ClientSessionGroup()
421+
async with contextlib.AsyncExitStack() as stack:
422+
group._exit_stack = stack
423+
await group._establish_session(server_params, ClientSessionParameters())
424+
425+
# --- Assert auth was passed through to sse_client ---
426+
mock_sse_client.assert_called_once_with(
427+
url="http://test.com/sse",
428+
headers=None,
429+
timeout=5.0,
430+
sse_read_timeout=300.0,
431+
auth=mock_auth,
432+
)
433+
434+
435+
@pytest.mark.anyio
436+
async def test_establish_session_streamable_http_passes_auth():
437+
"""_establish_session should pass auth to create_mcp_http_client for StreamableHttpParameters."""
438+
mock_auth = mock.Mock(spec=httpx.Auth)
439+
server_params = StreamableHttpParameters(url="http://test.com/stream", auth=mock_auth)
440+
441+
with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
442+
with mock.patch("mcp.client.session_group.streamable_http_client") as mock_streamable_client:
443+
with mock.patch("mcp.client.session_group.create_mcp_http_client") as mock_create_client:
444+
# --- Mock create_mcp_http_client ---
445+
mock_httpx_client = mock.AsyncMock(spec=httpx.AsyncClient)
446+
mock_httpx_client.__aenter__ = mock.AsyncMock(return_value=mock_httpx_client)
447+
mock_httpx_client.__aexit__ = mock.AsyncMock(return_value=None)
448+
mock_create_client.return_value = mock_httpx_client
449+
450+
# --- Mock streamable_http_client context manager ---
451+
mock_client_cm = mock.AsyncMock()
452+
mock_read = mock.AsyncMock()
453+
mock_write = mock.AsyncMock()
454+
mock_client_cm.__aenter__.return_value = (mock_read, mock_write)
455+
mock_client_cm.__aexit__ = mock.AsyncMock(return_value=None)
456+
mock_streamable_client.return_value = mock_client_cm
457+
458+
# --- Mock mcp.ClientSession ---
459+
mock_session_cm = mock.AsyncMock()
460+
mock_ClientSession_class.return_value = mock_session_cm
461+
mock_session = mock.AsyncMock()
462+
mock_session_cm.__aenter__.return_value = mock_session
463+
mock_session_cm.__aexit__ = mock.AsyncMock(return_value=None)
464+
465+
# Mock session.initialize()
466+
mock_result = mock.AsyncMock()
467+
mock_result.server_info = types.Implementation(name="test", version="1")
468+
mock_session.initialize.return_value = mock_result
469+
470+
# --- Test Execution ---
471+
group = ClientSessionGroup()
472+
async with contextlib.AsyncExitStack() as stack:
473+
group._exit_stack = stack
474+
await group._establish_session(server_params, ClientSessionParameters())
475+
476+
# --- Assert auth was passed through to create_mcp_http_client ---
477+
mock_create_client.assert_called_once()
478+
call_kwargs = mock_create_client.call_args.kwargs
479+
assert call_kwargs["auth"] is mock_auth

0 commit comments

Comments
 (0)