@@ -355,6 +355,7 @@ async def test_client_session_group_establish_session_parameterized(
355355 headers = server_params_instance .headers ,
356356 timeout = server_params_instance .timeout ,
357357 sse_read_timeout = server_params_instance .sse_read_timeout ,
358+ auth = server_params_instance .auth ,
358359 )
359360 elif client_type_name == "streamablehttp" : # pragma: no branch
360361 assert isinstance (server_params_instance , StreamableHttpParameters )
@@ -385,3 +386,94 @@ async def test_client_session_group_establish_session_parameterized(
385386 # 3. Assert returned values
386387 assert returned_server_info is mock_initialize_result .server_info
387388 assert returned_session is mock_entered_session
389+
390+
391+ @pytest .mark .anyio
392+ async def test_establish_session_sse_passes_auth ():
393+ """_establish_session should pass auth to sse_client for SseServerParameters."""
394+ mock_auth = mock .Mock (spec = httpx .Auth )
395+ server_params = SseServerParameters (url = "http://test.com/sse" , auth = mock_auth )
396+
397+ with mock .patch ("mcp.client.session_group.mcp.ClientSession" ) as mock_ClientSession_class :
398+ with mock .patch ("mcp.client.session_group.sse_client" ) as mock_sse_client :
399+ # --- Mock sse_client context manager ---
400+ mock_client_cm = mock .AsyncMock ()
401+ mock_read = mock .AsyncMock ()
402+ mock_write = mock .AsyncMock ()
403+ mock_client_cm .__aenter__ .return_value = (mock_read , mock_write )
404+ mock_client_cm .__aexit__ = mock .AsyncMock (return_value = None )
405+ mock_sse_client .return_value = mock_client_cm
406+
407+ # --- Mock mcp.ClientSession ---
408+ mock_session_cm = mock .AsyncMock ()
409+ mock_ClientSession_class .return_value = mock_session_cm
410+ mock_session = mock .AsyncMock ()
411+ mock_session_cm .__aenter__ .return_value = mock_session
412+ mock_session_cm .__aexit__ = mock .AsyncMock (return_value = None )
413+
414+ # Mock session.initialize()
415+ mock_result = mock .AsyncMock ()
416+ mock_result .server_info = types .Implementation (name = "test" , version = "1" )
417+ mock_session .initialize .return_value = mock_result
418+
419+ # --- Test Execution ---
420+ group = ClientSessionGroup ()
421+ async with contextlib .AsyncExitStack () as stack :
422+ group ._exit_stack = stack
423+ await group ._establish_session (server_params , ClientSessionParameters ())
424+
425+ # --- Assert auth was passed through to sse_client ---
426+ mock_sse_client .assert_called_once_with (
427+ url = "http://test.com/sse" ,
428+ headers = None ,
429+ timeout = 5.0 ,
430+ sse_read_timeout = 300.0 ,
431+ auth = mock_auth ,
432+ )
433+
434+
435+ @pytest .mark .anyio
436+ async def test_establish_session_streamable_http_passes_auth ():
437+ """_establish_session should pass auth to create_mcp_http_client for StreamableHttpParameters."""
438+ mock_auth = mock .Mock (spec = httpx .Auth )
439+ server_params = StreamableHttpParameters (url = "http://test.com/stream" , auth = mock_auth )
440+
441+ with mock .patch ("mcp.client.session_group.mcp.ClientSession" ) as mock_ClientSession_class :
442+ with mock .patch ("mcp.client.session_group.streamable_http_client" ) as mock_streamable_client :
443+ with mock .patch ("mcp.client.session_group.create_mcp_http_client" ) as mock_create_client :
444+ # --- Mock create_mcp_http_client ---
445+ mock_httpx_client = mock .AsyncMock (spec = httpx .AsyncClient )
446+ mock_httpx_client .__aenter__ = mock .AsyncMock (return_value = mock_httpx_client )
447+ mock_httpx_client .__aexit__ = mock .AsyncMock (return_value = None )
448+ mock_create_client .return_value = mock_httpx_client
449+
450+ # --- Mock streamable_http_client context manager ---
451+ mock_client_cm = mock .AsyncMock ()
452+ mock_read = mock .AsyncMock ()
453+ mock_write = mock .AsyncMock ()
454+ mock_client_cm .__aenter__ .return_value = (mock_read , mock_write )
455+ mock_client_cm .__aexit__ = mock .AsyncMock (return_value = None )
456+ mock_streamable_client .return_value = mock_client_cm
457+
458+ # --- Mock mcp.ClientSession ---
459+ mock_session_cm = mock .AsyncMock ()
460+ mock_ClientSession_class .return_value = mock_session_cm
461+ mock_session = mock .AsyncMock ()
462+ mock_session_cm .__aenter__ .return_value = mock_session
463+ mock_session_cm .__aexit__ = mock .AsyncMock (return_value = None )
464+
465+ # Mock session.initialize()
466+ mock_result = mock .AsyncMock ()
467+ mock_result .server_info = types .Implementation (name = "test" , version = "1" )
468+ mock_session .initialize .return_value = mock_result
469+
470+ # --- Test Execution ---
471+ group = ClientSessionGroup ()
472+ async with contextlib .AsyncExitStack () as stack :
473+ group ._exit_stack = stack
474+ await group ._establish_session (server_params , ClientSessionParameters ())
475+
476+ # --- Assert auth was passed through to create_mcp_http_client ---
477+ mock_create_client .assert_called_once ()
478+ call_kwargs = mock_create_client .call_args .kwargs
479+ assert call_kwargs ["auth" ] is mock_auth
0 commit comments