@@ -1132,22 +1132,19 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba
11321132 read_stream ,
11331133 write_stream ,
11341134 ):
1135- async with ClientSession (read_stream , write_stream ) as session :
1135+ async with ClientSession (read_stream , write_stream ) as session : # pragma: no branch
11361136 # Initialize the session
11371137 result = await session .initialize ()
11381138 assert isinstance (result , InitializeResult )
11391139 assert len (captured_ids ) > 0
11401140 captured_session_id = captured_ids [0 ]
11411141 assert captured_session_id is not None
1142+ headers = {MCP_SESSION_ID_HEADER : captured_session_id }
11421143
11431144 # Make a request to confirm session is working
11441145 tools = await session .list_tools ()
11451146 assert len (tools .tools ) == 10
11461147
1147- headers : dict [str , str ] = {} # pragma: lax no cover
1148- if captured_session_id : # pragma: lax no cover
1149- headers [MCP_SESSION_ID_HEADER ] = captured_session_id
1150-
11511148 async with create_mcp_http_client (headers = headers ) as httpx_client2 :
11521149 async with streamable_http_client (f"{ basic_server_url } /mcp" , http_client = httpx_client2 ) as (
11531150 read_stream ,
@@ -1196,22 +1193,19 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11961193 read_stream ,
11971194 write_stream ,
11981195 ):
1199- async with ClientSession (read_stream , write_stream ) as session :
1196+ async with ClientSession (read_stream , write_stream ) as session : # pragma: no branch
12001197 # Initialize the session
12011198 result = await session .initialize ()
12021199 assert isinstance (result , InitializeResult )
12031200 assert len (captured_ids ) > 0
12041201 captured_session_id = captured_ids [0 ]
12051202 assert captured_session_id is not None
1203+ headers = {MCP_SESSION_ID_HEADER : captured_session_id }
12061204
12071205 # Make a request to confirm session is working
12081206 tools = await session .list_tools ()
12091207 assert len (tools .tools ) == 10
12101208
1211- headers : dict [str , str ] = {} # pragma: lax no cover
1212- if captured_session_id : # pragma: lax no cover
1213- headers [MCP_SESSION_ID_HEADER ] = captured_session_id
1214-
12151209 async with create_mcp_http_client (headers = headers ) as httpx_client2 :
12161210 async with streamable_http_client (f"{ basic_server_url } /mcp" , http_client = httpx_client2 ) as (
12171211 read_stream ,
@@ -1231,7 +1225,6 @@ async def test_streamable_http_client_resumption(event_server: tuple[SimpleEvent
12311225 # Variables to track the state
12321226 captured_resumption_token : str | None = None
12331227 captured_notifications : list [types .ServerNotification ] = []
1234- captured_protocol_version : str | int | None = None
12351228 first_notification_received = False
12361229
12371230 async def message_handler ( # pragma: no branch
@@ -1258,15 +1251,20 @@ async def on_resumption_token_update(token: str) -> None:
12581251 read_stream ,
12591252 write_stream ,
12601253 ):
1261- async with ClientSession (read_stream , write_stream , message_handler = message_handler ) as session :
1254+ async with ClientSession ( # pragma: no branch
1255+ read_stream , write_stream , message_handler = message_handler
1256+ ) as session :
12621257 # Initialize the session
12631258 result = await session .initialize ()
12641259 assert isinstance (result , InitializeResult )
12651260 assert len (captured_ids ) > 0
12661261 captured_session_id = captured_ids [0 ]
12671262 assert captured_session_id is not None
1268- # Capture the negotiated protocol version
1269- captured_protocol_version = result .protocol_version
1263+ # Build phase-2 headers now while both values are in scope
1264+ headers : dict [str , Any ] = {
1265+ MCP_SESSION_ID_HEADER : captured_session_id ,
1266+ MCP_PROTOCOL_VERSION_HEADER : result .protocol_version ,
1267+ }
12701268
12711269 # Start the tool that will wait on lock in a task
12721270 async with anyio .create_task_group () as tg : # pragma: no branch
@@ -1291,25 +1289,19 @@ async def run_tool():
12911289 while not first_notification_received or not captured_resumption_token :
12921290 await anyio .sleep (0.1 )
12931291
1292+ # The while loop only exits after first_notification_received=True,
1293+ # which is set by message_handler immediately after appending to
1294+ # captured_notifications. The server tool is blocked on its lock,
1295+ # so nothing else can arrive before we cancel.
1296+ assert len (captured_notifications ) == 1
1297+ assert isinstance (captured_notifications [0 ], types .LoggingMessageNotification )
1298+ assert captured_notifications [0 ].params .data == "First notification before lock"
1299+ # Reset for phase 2 before cancelling
1300+ captured_notifications .clear ()
1301+
12941302 # Kill the client session while tool is waiting on lock
12951303 tg .cancel_scope .cancel ()
12961304
1297- # Verify we received exactly one notification (inside ClientSession
1298- # so coverage tracks these on Python 3.11, see PR #1897 for details)
1299- assert len (captured_notifications ) == 1 # pragma: lax no cover
1300- assert isinstance (captured_notifications [0 ], types .LoggingMessageNotification ) # pragma: lax no cover
1301- assert captured_notifications [0 ].params .data == "First notification before lock" # pragma: lax no cover
1302-
1303- # Clear notifications and set up headers for phase 2 (between connections,
1304- # not tracked by coverage on Python 3.11 due to cancel scope + sys.settrace bug)
1305- captured_notifications = [] # pragma: lax no cover
1306- assert captured_session_id is not None # pragma: lax no cover
1307- assert captured_protocol_version is not None # pragma: lax no cover
1308- headers : dict [str , Any ] = { # pragma: lax no cover
1309- MCP_SESSION_ID_HEADER : captured_session_id ,
1310- MCP_PROTOCOL_VERSION_HEADER : captured_protocol_version ,
1311- }
1312-
13131305 async with create_mcp_http_client (headers = headers ) as httpx_client2 :
13141306 async with streamable_http_client (f"{ server_url } /mcp" , http_client = httpx_client2 ) as (
13151307 read_stream ,
@@ -2092,11 +2084,12 @@ async def on_resumption_token(token: str) -> None:
20922084 assert isinstance (result .content [0 ], TextContent )
20932085 assert "Completed 3 checkpoints" in result .content [0 ].text
20942086
2095- # 4 priming + 3 notifications + 1 response = 8 tokens
2096- assert len (resumption_tokens ) == 8 , ( # pragma: lax no cover
2097- f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), "
2098- f"got { len (resumption_tokens )} : { resumption_tokens } "
2099- )
2087+ # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are
2088+ # captured before send_request returns, so this is safe to check here.
2089+ assert len (resumption_tokens ) == 8 , (
2090+ f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), "
2091+ f"got { len (resumption_tokens )} : { resumption_tokens } "
2092+ )
21002093
21012094
21022095@pytest .mark .anyio
0 commit comments