@@ -355,6 +355,7 @@ async def test_client_session_group_establish_session_parameterized(
355355 headers = server_params_instance .headers ,
356356 timeout = server_params_instance .timeout ,
357357 sse_read_timeout = server_params_instance .sse_read_timeout ,
358+ auth = server_params_instance .auth ,
358359 )
359360 elif client_type_name == "streamablehttp" : # pragma: no branch
360361 assert isinstance (server_params_instance , StreamableHttpParameters )
@@ -385,3 +386,96 @@ async def test_client_session_group_establish_session_parameterized(
385386 # 3. Assert returned values
386387 assert returned_server_info is mock_initialize_result .server_info
387388 assert returned_session is mock_entered_session
389+
390+
391+ @pytest .mark .anyio
392+ async def test_establish_session_sse_passes_auth ():
393+ """_establish_session should pass auth to sse_client for SseServerParameters."""
394+ mock_auth = mock .Mock (spec = httpx .Auth )
395+ server_params = SseServerParameters (url = "http://test.com/sse" , auth = mock_auth )
396+
397+ with mock .patch ("mcp.client.session_group.mcp.ClientSession" ) as mock_ClientSession_class :
398+ with mock .patch ("mcp.client.session_group.sse_client" ) as mock_sse_client :
399+ # --- Mock sse_client context manager ---
400+ mock_client_cm = mock .AsyncMock ()
401+ mock_read = mock .AsyncMock ()
402+ mock_write = mock .AsyncMock ()
403+ mock_client_cm .__aenter__ .return_value = (mock_read , mock_write )
404+ mock_client_cm .__aexit__ = mock .AsyncMock (return_value = None )
405+ mock_sse_client .return_value = mock_client_cm
406+
407+ # --- Mock mcp.ClientSession ---
408+ mock_session_cm = mock .AsyncMock ()
409+ mock_ClientSession_class .return_value = mock_session_cm
410+ mock_session = mock .AsyncMock ()
411+ mock_session_cm .__aenter__ .return_value = mock_session
412+ mock_session_cm .__aexit__ = mock .AsyncMock (return_value = None )
413+
414+ # Mock session.initialize()
415+ mock_result = mock .AsyncMock ()
416+ mock_result .server_info = types .Implementation (name = "test" , version = "1" )
417+ mock_session .initialize .return_value = mock_result
418+
419+ # --- Test Execution ---
420+ group = ClientSessionGroup ()
421+ async with contextlib .AsyncExitStack () as stack :
422+ group ._exit_stack = stack
423+ await group ._establish_session (server_params , ClientSessionParameters ())
424+
425+ # --- Assert auth was passed through to sse_client ---
426+ mock_sse_client .assert_called_once_with (
427+ url = "http://test.com/sse" ,
428+ headers = None ,
429+ timeout = 5.0 ,
430+ sse_read_timeout = 300.0 ,
431+ auth = mock_auth ,
432+ )
433+
434+
435+ @pytest .mark .anyio
436+ async def test_establish_session_streamable_http_passes_auth ():
437+ """_establish_session should pass auth to create_mcp_http_client for StreamableHttpParameters."""
438+ mock_auth = mock .Mock (spec = httpx .Auth )
439+ server_params = StreamableHttpParameters (url = "http://test.com/stream" , auth = mock_auth )
440+
441+ with mock .patch ("mcp.client.session_group.mcp.ClientSession" ) as mock_ClientSession_class :
442+ with mock .patch ("mcp.client.session_group.streamable_http_client" ) as mock_streamable_client :
443+ with mock .patch (
444+ "mcp.client.session_group.create_mcp_http_client"
445+ ) as mock_create_client : # pragma: no branch
446+ # --- Mock create_mcp_http_client ---
447+ mock_httpx_client = mock .AsyncMock (spec = httpx .AsyncClient )
448+ mock_httpx_client .__aenter__ = mock .AsyncMock (return_value = mock_httpx_client )
449+ mock_httpx_client .__aexit__ = mock .AsyncMock (return_value = None )
450+ mock_create_client .return_value = mock_httpx_client
451+
452+ # --- Mock streamable_http_client context manager ---
453+ mock_client_cm = mock .AsyncMock ()
454+ mock_read = mock .AsyncMock ()
455+ mock_write = mock .AsyncMock ()
456+ mock_client_cm .__aenter__ .return_value = (mock_read , mock_write )
457+ mock_client_cm .__aexit__ = mock .AsyncMock (return_value = None )
458+ mock_streamable_client .return_value = mock_client_cm
459+
460+ # --- Mock mcp.ClientSession ---
461+ mock_session_cm = mock .AsyncMock ()
462+ mock_ClientSession_class .return_value = mock_session_cm
463+ mock_session = mock .AsyncMock ()
464+ mock_session_cm .__aenter__ .return_value = mock_session
465+ mock_session_cm .__aexit__ = mock .AsyncMock (return_value = None )
466+
467+ # Mock session.initialize()
468+ mock_result = mock .AsyncMock ()
469+ mock_result .server_info = types .Implementation (name = "test" , version = "1" )
470+ mock_session .initialize .return_value = mock_result
471+
472+ # --- Test Execution ---
473+ group = ClientSessionGroup ()
474+ async with contextlib .AsyncExitStack () as stack :
475+ group ._exit_stack = stack
476+ await group ._establish_session (server_params , ClientSessionParameters ())
477+
478+ # --- Assert auth was passed through to create_mcp_http_client ---
479+ mock_create_client .assert_called_once ()
480+ call_kwargs = mock_create_client .call_args .kwargs
481+ assert call_kwargs ["auth" ] is mock_auth
0 commit comments