Skip to content

Commit 9fcc3d9

Browse files
committed
fix: make ClientSessionGroup streamable HTTP errors catchable
1 parent f475344 commit 9fcc3d9

2 files changed

Lines changed: 121 additions & 5 deletions

File tree

src/mcp/client/session_group.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from types import TracebackType
1414
from typing import Any, TypeAlias
1515

16-
import anyio
1716
import httpx
1817
from pydantic import BaseModel, Field
1918
from typing_extensions import Self
@@ -165,10 +164,9 @@ async def __aexit__(
165164
if self._owns_exit_stack:
166165
await self._exit_stack.aclose()
167166

168-
# Concurrently close session stacks.
169-
async with anyio.create_task_group() as tg:
170-
for exit_stack in self._session_exit_stacks.values():
171-
tg.start_soon(exit_stack.aclose)
167+
# Sequentially close session stacks to preserve AnyIO task contexts.
168+
for exit_stack in list(self._session_exit_stacks.values()):
169+
await exit_stack.aclose()
172170

173171
@property
174172
def sessions(self) -> list[mcp.ClientSession]:

tests/client/test_session_group.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
import contextlib
2+
import socket
3+
import sys
4+
from typing import cast
25
from unittest import mock
36

47
import httpx
58
import pytest
69

10+
if sys.version_info >= (3, 11):
11+
from builtins import BaseExceptionGroup, ExceptionGroup
12+
else:
13+
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
14+
715
import mcp
816
from mcp import types
917
from mcp.client.session_group import (
@@ -385,3 +393,113 @@ async def test_client_session_group_establish_session_parameterized(
385393
# 3. Assert returned values
386394
assert returned_server_info is mock_initialize_result.server_info
387395
assert returned_session is mock_entered_session
396+
397+
398+
def _free_tcp_port() -> int:
399+
"""Return a TCP port number not currently bound on localhost."""
400+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
401+
sock.bind(("127.0.0.1", 0))
402+
return sock.getsockname()[1]
403+
404+
405+
def _is_cancel_scope_runtime_error(exc: BaseException | None) -> bool:
406+
"""Walk an exception chain looking for AnyIO cancel-scope RuntimeError."""
407+
seen: set[int] = set()
408+
409+
def _walk(current: BaseException | None) -> bool:
410+
if current is None or id(current) in seen:
411+
return False
412+
seen.add(id(current))
413+
414+
if isinstance(current, RuntimeError) and "cancel scope" in str(current).lower():
415+
return True
416+
if isinstance(current, BaseExceptionGroup):
417+
group = cast("BaseExceptionGroup[BaseException]", current)
418+
return any(_walk(child) for child in group.exceptions)
419+
return _walk(current.__cause__) or _walk(current.__context__)
420+
421+
return _walk(exc)
422+
423+
424+
@pytest.mark.anyio
425+
async def test_unreachable_streamable_http_error_is_catchable() -> None:
426+
"""Unreachable streamable-http servers raise catchable connection errors."""
427+
port = _free_tcp_port()
428+
server_params = StreamableHttpParameters(url=f"http://127.0.0.1:{port}/mcp/")
429+
430+
caught: BaseException | None = None
431+
432+
try:
433+
async with ClientSessionGroup() as group:
434+
try:
435+
await group.connect_to_server(server_params)
436+
except BaseException as inner: # noqa: BLE001
437+
caught = inner
438+
except BaseException as outer: # noqa: BLE001
439+
caught = outer
440+
441+
assert caught is not None, (
442+
"Expected to catch a connection error against an unreachable "
443+
"streamable-http server, but no exception was raised."
444+
)
445+
assert not _is_cancel_scope_runtime_error(caught), (
446+
"Regression of #915: connection error against an unreachable "
447+
"streamable-http server was masked by an anyio cancel-scope "
448+
f"RuntimeError. Got: {type(caught).__name__}: {caught}"
449+
)
450+
451+
452+
def test_is_cancel_scope_runtime_error_detected() -> None:
453+
exc = RuntimeError("Attempted to exit cancel scope in a different task")
454+
455+
assert _is_cancel_scope_runtime_error(exc)
456+
457+
458+
def test_is_cancel_scope_runtime_error_in_group_detected() -> None:
459+
exc = ExceptionGroup(
460+
"outer",
461+
[ValueError("other"), RuntimeError("Attempted to exit cancel scope in a different task")],
462+
)
463+
464+
assert _is_cancel_scope_runtime_error(exc)
465+
466+
467+
def test_is_cancel_scope_non_runtime_error_not_detected() -> None:
468+
assert not _is_cancel_scope_runtime_error(ValueError("cancel scope was mentioned"))
469+
470+
471+
def test_is_cancel_scope_none_is_false() -> None:
472+
assert not _is_cancel_scope_runtime_error(None)
473+
474+
475+
def test_is_cancel_scope_in_cause_chain() -> None:
476+
exc = ValueError("outer")
477+
exc.__cause__ = RuntimeError("Attempted to exit cancel scope in a different task")
478+
479+
assert _is_cancel_scope_runtime_error(exc)
480+
481+
482+
@pytest.mark.anyio
483+
async def test_session_group_with_external_exit_stack(
484+
mock_exit_stack: mock.MagicMock,
485+
) -> None:
486+
"""External exit stacks remain caller-managed."""
487+
group = ClientSessionGroup(exit_stack=mock_exit_stack)
488+
489+
async with group:
490+
pass
491+
492+
mock_exit_stack.__aenter__.assert_not_called()
493+
mock_exit_stack.aclose.assert_not_called()
494+
495+
496+
@pytest.mark.anyio
497+
async def test_session_group_teardown_closes_session_stacks() -> None:
498+
"""__aexit__ closes every session-level exit stack sequentially."""
499+
session = mock.MagicMock(spec=mcp.ClientSession)
500+
session_stack = mock.AsyncMock()
501+
502+
async with ClientSessionGroup() as group:
503+
group._session_exit_stacks[session] = session_stack
504+
505+
session_stack.aclose.assert_awaited_once()

0 commit comments

Comments
 (0)