Skip to content

Commit 528abfa

Browse files
authored
tests: remove lax-no-cover pragmas by moving assertions before cancellation (#2206)
1 parent b3149d2 commit 528abfa

File tree

2 files changed

+36
-46
lines changed

2 files changed

+36
-46
lines changed

tests/shared/test_sse.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,19 +203,15 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
203203

204204
@pytest.mark.anyio
205205
async def test_sse_client_on_session_created(server: None, server_url: str) -> None:
206-
captured_session_id: str | None = None
207-
208-
def on_session_created(session_id: str) -> None:
209-
nonlocal captured_session_id
210-
captured_session_id = session_id
206+
captured: list[str] = []
211207

212-
async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams:
208+
async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams:
213209
async with ClientSession(*streams) as session:
214210
result = await session.initialize()
215211
assert isinstance(result, InitializeResult)
216-
217-
assert captured_session_id is not None # pragma: lax no cover
218-
assert len(captured_session_id) > 0 # pragma: lax no cover
212+
# Callback fires when the endpoint event arrives, before sse_client yields.
213+
assert len(captured) == 1
214+
assert len(captured[0]) > 0
219215

220216

221217
@pytest.mark.parametrize(
@@ -248,8 +244,9 @@ def mock_extract(url: str) -> None:
248244
async with ClientSession(*streams) as session:
249245
result = await session.initialize()
250246
assert isinstance(result, InitializeResult)
251-
252-
callback_mock.assert_not_called() # pragma: lax no cover
247+
# Callback would have fired by now (endpoint event arrives before
248+
# sse_client yields); if it hasn't, it won't.
249+
callback_mock.assert_not_called()
253250

254251

255252
@pytest.fixture

tests/shared/test_streamable_http.py

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,22 +1132,19 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba
11321132
read_stream,
11331133
write_stream,
11341134
):
1135-
async with ClientSession(read_stream, write_stream) as session:
1135+
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
11361136
# Initialize the session
11371137
result = await session.initialize()
11381138
assert isinstance(result, InitializeResult)
11391139
assert len(captured_ids) > 0
11401140
captured_session_id = captured_ids[0]
11411141
assert captured_session_id is not None
1142+
headers = {MCP_SESSION_ID_HEADER: captured_session_id}
11421143

11431144
# Make a request to confirm session is working
11441145
tools = await session.list_tools()
11451146
assert len(tools.tools) == 10
11461147

1147-
headers: dict[str, str] = {} # pragma: lax no cover
1148-
if captured_session_id: # pragma: lax no cover
1149-
headers[MCP_SESSION_ID_HEADER] = captured_session_id
1150-
11511148
async with create_mcp_http_client(headers=headers) as httpx_client2:
11521149
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as (
11531150
read_stream,
@@ -1196,22 +1193,19 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11961193
read_stream,
11971194
write_stream,
11981195
):
1199-
async with ClientSession(read_stream, write_stream) as session:
1196+
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
12001197
# Initialize the session
12011198
result = await session.initialize()
12021199
assert isinstance(result, InitializeResult)
12031200
assert len(captured_ids) > 0
12041201
captured_session_id = captured_ids[0]
12051202
assert captured_session_id is not None
1203+
headers = {MCP_SESSION_ID_HEADER: captured_session_id}
12061204

12071205
# Make a request to confirm session is working
12081206
tools = await session.list_tools()
12091207
assert len(tools.tools) == 10
12101208

1211-
headers: dict[str, str] = {} # pragma: lax no cover
1212-
if captured_session_id: # pragma: lax no cover
1213-
headers[MCP_SESSION_ID_HEADER] = captured_session_id
1214-
12151209
async with create_mcp_http_client(headers=headers) as httpx_client2:
12161210
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as (
12171211
read_stream,
@@ -1231,7 +1225,6 @@ async def test_streamable_http_client_resumption(event_server: tuple[SimpleEvent
12311225
# Variables to track the state
12321226
captured_resumption_token: str | None = None
12331227
captured_notifications: list[types.ServerNotification] = []
1234-
captured_protocol_version: str | int | None = None
12351228
first_notification_received = False
12361229

12371230
async def message_handler( # pragma: no branch
@@ -1258,15 +1251,20 @@ async def on_resumption_token_update(token: str) -> None:
12581251
read_stream,
12591252
write_stream,
12601253
):
1261-
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
1254+
async with ClientSession( # pragma: no branch
1255+
read_stream, write_stream, message_handler=message_handler
1256+
) as session:
12621257
# Initialize the session
12631258
result = await session.initialize()
12641259
assert isinstance(result, InitializeResult)
12651260
assert len(captured_ids) > 0
12661261
captured_session_id = captured_ids[0]
12671262
assert captured_session_id is not None
1268-
# Capture the negotiated protocol version
1269-
captured_protocol_version = result.protocol_version
1263+
# Build phase-2 headers now while both values are in scope
1264+
headers: dict[str, Any] = {
1265+
MCP_SESSION_ID_HEADER: captured_session_id,
1266+
MCP_PROTOCOL_VERSION_HEADER: result.protocol_version,
1267+
}
12701268

12711269
# Start the tool that will wait on lock in a task
12721270
async with anyio.create_task_group() as tg: # pragma: no branch
@@ -1291,25 +1289,19 @@ async def run_tool():
12911289
while not first_notification_received or not captured_resumption_token:
12921290
await anyio.sleep(0.1)
12931291

1292+
# The while loop only exits after first_notification_received=True,
1293+
# which is set by message_handler immediately after appending to
1294+
# captured_notifications. The server tool is blocked on its lock,
1295+
# so nothing else can arrive before we cancel.
1296+
assert len(captured_notifications) == 1
1297+
assert isinstance(captured_notifications[0], types.LoggingMessageNotification)
1298+
assert captured_notifications[0].params.data == "First notification before lock"
1299+
# Reset for phase 2 before cancelling
1300+
captured_notifications.clear()
1301+
12941302
# Kill the client session while tool is waiting on lock
12951303
tg.cancel_scope.cancel()
12961304

1297-
# Verify we received exactly one notification (inside ClientSession
1298-
# so coverage tracks these on Python 3.11, see PR #1897 for details)
1299-
assert len(captured_notifications) == 1 # pragma: lax no cover
1300-
assert isinstance(captured_notifications[0], types.LoggingMessageNotification) # pragma: lax no cover
1301-
assert captured_notifications[0].params.data == "First notification before lock" # pragma: lax no cover
1302-
1303-
# Clear notifications and set up headers for phase 2 (between connections,
1304-
# not tracked by coverage on Python 3.11 due to cancel scope + sys.settrace bug)
1305-
captured_notifications = [] # pragma: lax no cover
1306-
assert captured_session_id is not None # pragma: lax no cover
1307-
assert captured_protocol_version is not None # pragma: lax no cover
1308-
headers: dict[str, Any] = { # pragma: lax no cover
1309-
MCP_SESSION_ID_HEADER: captured_session_id,
1310-
MCP_PROTOCOL_VERSION_HEADER: captured_protocol_version,
1311-
}
1312-
13131305
async with create_mcp_http_client(headers=headers) as httpx_client2:
13141306
async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client2) as (
13151307
read_stream,
@@ -2092,11 +2084,12 @@ async def on_resumption_token(token: str) -> None:
20922084
assert isinstance(result.content[0], TextContent)
20932085
assert "Completed 3 checkpoints" in result.content[0].text
20942086

2095-
# 4 priming + 3 notifications + 1 response = 8 tokens
2096-
assert len(resumption_tokens) == 8, ( # pragma: lax no cover
2097-
f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), "
2098-
f"got {len(resumption_tokens)}: {resumption_tokens}"
2099-
)
2087+
# 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are
2088+
# captured before send_request returns, so this is safe to check here.
2089+
assert len(resumption_tokens) == 8, (
2090+
f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), "
2091+
f"got {len(resumption_tokens)}: {resumption_tokens}"
2092+
)
21002093

21012094

21022095
@pytest.mark.anyio

0 commit comments

Comments
 (0)