|
1 | 1 | import contextlib |
| 2 | +import socket |
| 3 | +import sys |
| 4 | +from typing import cast |
2 | 5 | from unittest import mock |
3 | 6 |
|
4 | 7 | import httpx |
5 | 8 | import pytest |
6 | 9 |
|
| 10 | +if sys.version_info >= (3, 11): |
| 11 | + from builtins import BaseExceptionGroup, ExceptionGroup |
| 12 | +else: |
| 13 | + from exceptiongroup import BaseExceptionGroup, ExceptionGroup |
| 14 | + |
7 | 15 | import mcp |
8 | 16 | from mcp import types |
9 | 17 | from mcp.client.session_group import ( |
@@ -385,3 +393,113 @@ async def test_client_session_group_establish_session_parameterized( |
385 | 393 | # 3. Assert returned values |
386 | 394 | assert returned_server_info is mock_initialize_result.server_info |
387 | 395 | assert returned_session is mock_entered_session |
| 396 | + |
| 397 | + |
| 398 | +def _free_tcp_port() -> int: |
| 399 | + """Return a TCP port number not currently bound on localhost.""" |
| 400 | + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: |
| 401 | + sock.bind(("127.0.0.1", 0)) |
| 402 | + return sock.getsockname()[1] |
| 403 | + |
| 404 | + |
| 405 | +def _is_cancel_scope_runtime_error(exc: BaseException | None) -> bool: |
| 406 | + """Walk an exception chain looking for AnyIO cancel-scope RuntimeError.""" |
| 407 | + seen: set[int] = set() |
| 408 | + |
| 409 | + def _walk(current: BaseException | None) -> bool: |
| 410 | + if current is None or id(current) in seen: |
| 411 | + return False |
| 412 | + seen.add(id(current)) |
| 413 | + |
| 414 | + if isinstance(current, RuntimeError) and "cancel scope" in str(current).lower(): |
| 415 | + return True |
| 416 | + if isinstance(current, BaseExceptionGroup): |
| 417 | + group = cast("BaseExceptionGroup[BaseException]", current) |
| 418 | + return any(_walk(child) for child in group.exceptions) |
| 419 | + return _walk(current.__cause__) or _walk(current.__context__) |
| 420 | + |
| 421 | + return _walk(exc) |
| 422 | + |
| 423 | + |
| 424 | +@pytest.mark.anyio |
| 425 | +async def test_unreachable_streamable_http_error_is_catchable() -> None: |
| 426 | + """Unreachable streamable-http servers raise catchable connection errors.""" |
| 427 | + port = _free_tcp_port() |
| 428 | + server_params = StreamableHttpParameters(url=f"http://127.0.0.1:{port}/mcp/") |
| 429 | + |
| 430 | + caught: BaseException | None = None |
| 431 | + |
| 432 | + try: |
| 433 | + async with ClientSessionGroup() as group: |
| 434 | + try: |
| 435 | + await group.connect_to_server(server_params) |
| 436 | + except BaseException as inner: # noqa: BLE001 |
| 437 | + caught = inner |
| 438 | + except BaseException as outer: # noqa: BLE001 |
| 439 | + caught = outer |
| 440 | + |
| 441 | + assert caught is not None, ( |
| 442 | + "Expected to catch a connection error against an unreachable " |
| 443 | + "streamable-http server, but no exception was raised." |
| 444 | + ) |
| 445 | + assert not _is_cancel_scope_runtime_error(caught), ( |
| 446 | + "Regression of #915: connection error against an unreachable " |
| 447 | + "streamable-http server was masked by an anyio cancel-scope " |
| 448 | + f"RuntimeError. Got: {type(caught).__name__}: {caught}" |
| 449 | + ) |
| 450 | + |
| 451 | + |
| 452 | +def test_is_cancel_scope_runtime_error_detected() -> None: |
| 453 | + exc = RuntimeError("Attempted to exit cancel scope in a different task") |
| 454 | + |
| 455 | + assert _is_cancel_scope_runtime_error(exc) |
| 456 | + |
| 457 | + |
| 458 | +def test_is_cancel_scope_runtime_error_in_group_detected() -> None: |
| 459 | + exc = ExceptionGroup( |
| 460 | + "outer", |
| 461 | + [ValueError("other"), RuntimeError("Attempted to exit cancel scope in a different task")], |
| 462 | + ) |
| 463 | + |
| 464 | + assert _is_cancel_scope_runtime_error(exc) |
| 465 | + |
| 466 | + |
| 467 | +def test_is_cancel_scope_non_runtime_error_not_detected() -> None: |
| 468 | + assert not _is_cancel_scope_runtime_error(ValueError("cancel scope was mentioned")) |
| 469 | + |
| 470 | + |
| 471 | +def test_is_cancel_scope_none_is_false() -> None: |
| 472 | + assert not _is_cancel_scope_runtime_error(None) |
| 473 | + |
| 474 | + |
| 475 | +def test_is_cancel_scope_in_cause_chain() -> None: |
| 476 | + exc = ValueError("outer") |
| 477 | + exc.__cause__ = RuntimeError("Attempted to exit cancel scope in a different task") |
| 478 | + |
| 479 | + assert _is_cancel_scope_runtime_error(exc) |
| 480 | + |
| 481 | + |
| 482 | +@pytest.mark.anyio |
| 483 | +async def test_session_group_with_external_exit_stack( |
| 484 | + mock_exit_stack: mock.MagicMock, |
| 485 | +) -> None: |
| 486 | + """External exit stacks remain caller-managed.""" |
| 487 | + group = ClientSessionGroup(exit_stack=mock_exit_stack) |
| 488 | + |
| 489 | + async with group: |
| 490 | + pass |
| 491 | + |
| 492 | + mock_exit_stack.__aenter__.assert_not_called() |
| 493 | + mock_exit_stack.aclose.assert_not_called() |
| 494 | + |
| 495 | + |
| 496 | +@pytest.mark.anyio |
| 497 | +async def test_session_group_teardown_closes_session_stacks() -> None: |
| 498 | + """__aexit__ closes every session-level exit stack sequentially.""" |
| 499 | + session = mock.MagicMock(spec=mcp.ClientSession) |
| 500 | + session_stack = mock.AsyncMock() |
| 501 | + |
| 502 | + async with ClientSessionGroup() as group: |
| 503 | + group._session_exit_stacks[session] = session_stack |
| 504 | + |
| 505 | + session_stack.aclose.assert_awaited_once() |
0 commit comments