Skip to content

Commit bce6dd2

Browse files
fix: suppress GeneratorExit during client cleanup
GeneratorExit can leak from sse_client and streamablehttp_client during cleanup, causing RuntimeError in downstream code. This handles both direct GeneratorExit and BaseExceptionGroup wrapping (cpython#95571). Fixes #1214 Signed-off-by: Adrian Cole <adrian@tetrate.io>
1 parent 0da9a07 commit bce6dd2

File tree

6 files changed

+183
-20
lines changed

6 files changed

+183
-20
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies = [
3939
"pyjwt[crypto]>=2.10.1",
4040
"typing-extensions>=4.9.0",
4141
"typing-inspection>=0.4.1",
42+
"exceptiongroup>=1.0.0; python_version < '3.11'",
4243
]
4344

4445
[project.optional-dependencies]

src/mcp/client/sse.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import logging
2+
import sys
23
from collections.abc import Callable
34
from contextlib import asynccontextmanager
45
from typing import Any
56
from urllib.parse import parse_qs, urljoin, urlparse
67

78
import anyio
9+
10+
if sys.version_info >= (3, 11):
11+
from builtins import BaseExceptionGroup # pragma: no cover
12+
else:
13+
from exceptiongroup import BaseExceptionGroup # pragma: no cover
814
import httpx
915
from anyio.abc import TaskStatus
1016
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -157,8 +163,19 @@ async def post_writer(endpoint_url: str):
157163

158164
try:
159165
yield read_stream, write_stream
166+
# Suppress GeneratorExit to prevent "generator didn't stop after athrow()"
167+
# when client code exits the context manager during cancellation.
168+
# See https://github.com/python/cpython/issues/95571
169+
except GeneratorExit:
170+
pass
171+
# anyio wraps GeneratorExit in BaseExceptionGroup; extract and re-raise other exceptions
172+
except BaseExceptionGroup as eg:
173+
_, rest = eg.split(GeneratorExit)
174+
if rest:
175+
raise rest from None
160176
finally:
161177
tg.cancel_scope.cancel()
162178
finally:
163179
await read_stream_writer.aclose()
180+
await read_stream.aclose()
164181
await write_stream.aclose()

src/mcp/client/streamable_http.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import contextlib
1010
import logging
11+
import sys
1112
from collections.abc import AsyncGenerator, Awaitable, Callable
1213
from contextlib import asynccontextmanager
1314
from dataclasses import dataclass
@@ -16,6 +17,11 @@
1617
from warnings import warn
1718

1819
import anyio
20+
21+
if sys.version_info >= (3, 11):
22+
from builtins import BaseExceptionGroup # pragma: no cover
23+
else:
24+
from exceptiongroup import BaseExceptionGroup # pragma: no cover
1925
import httpx
2026
from anyio.abc import TaskGroup
2127
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -672,12 +678,23 @@ def start_get_stream() -> None:
672678
write_stream,
673679
transport.get_session_id,
674680
)
681+
# Suppress GeneratorExit to prevent "generator didn't stop after athrow()"
682+
# when client code exits the context manager during cancellation.
683+
# See https://github.com/python/cpython/issues/95571
684+
except GeneratorExit:
685+
pass
686+
# anyio wraps GeneratorExit in BaseExceptionGroup; extract and re-raise other exceptions
687+
except BaseExceptionGroup as eg:
688+
_, rest = eg.split(GeneratorExit)
689+
if rest:
690+
raise rest from None
675691
finally:
676692
if transport.session_id and terminate_on_close:
677693
await transport.terminate_session(client)
678694
tg.cancel_scope.cancel()
679695
finally:
680696
await read_stream_writer.aclose()
697+
await read_stream.aclose()
681698
await write_stream.aclose()
682699

683700

tests/client/conftest.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,86 @@
1-
from collections.abc import Callable, Generator
1+
import multiprocessing
2+
import socket
3+
from collections.abc import AsyncGenerator, Callable, Generator
24
from contextlib import asynccontextmanager
35
from typing import Any
46
from unittest.mock import patch
57

68
import pytest
9+
import uvicorn
710
from anyio.streams.memory import MemoryObjectSendStream
11+
from starlette.applications import Starlette
12+
from starlette.requests import Request
13+
from starlette.responses import Response
14+
from starlette.routing import Mount, Route
815

916
import mcp.shared.memory
17+
from mcp.server import Server
18+
from mcp.server.sse import SseServerTransport
19+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
1020
from mcp.shared.message import SessionMessage
1121
from mcp.types import JSONRPCNotification, JSONRPCRequest
22+
from tests.test_helpers import wait_for_server
23+
24+
25+
def run_server(port: int) -> None: # pragma: no cover
26+
"""Run server with SSE and Streamable HTTP endpoints."""
27+
server = Server(name="cleanup_test_server")
28+
session_manager = StreamableHTTPSessionManager(app=server, json_response=False)
29+
sse_transport = SseServerTransport("/messages/")
30+
31+
async def handle_sse(request: Request) -> Response:
32+
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
33+
if streams:
34+
await server.run(streams[0], streams[1], server.create_initialization_options())
35+
return Response()
36+
37+
@asynccontextmanager
38+
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
39+
async with session_manager.run():
40+
yield
41+
42+
app = Starlette(
43+
routes=[
44+
Route("/sse", endpoint=handle_sse),
45+
Mount("/messages/", app=sse_transport.handle_post_message),
46+
Mount("/mcp", app=session_manager.handle_request),
47+
],
48+
lifespan=lifespan,
49+
)
50+
uvicorn.Server(uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error")).run()
51+
52+
53+
@pytest.fixture
54+
def server_port() -> int:
55+
with socket.socket() as s:
56+
s.bind(("127.0.0.1", 0))
57+
return s.getsockname()[1]
58+
59+
60+
@pytest.fixture
61+
def test_server(server_port: int) -> Generator[str, None, None]:
62+
"""Start server with SSE and Streamable HTTP endpoints."""
63+
proc = multiprocessing.Process(target=run_server, kwargs={"port": server_port}, daemon=True)
64+
proc.start()
65+
wait_for_server(server_port)
66+
try:
67+
yield f"http://127.0.0.1:{server_port}"
68+
finally:
69+
proc.terminate()
70+
proc.join(timeout=2)
71+
if proc.is_alive(): # pragma: no cover
72+
proc.kill()
73+
proc.join(timeout=1)
74+
75+
76+
@pytest.fixture
77+
def sse_server_url(test_server: str) -> str:
78+
return f"{test_server}/sse"
79+
80+
81+
@pytest.fixture
82+
def streamable_server_url(test_server: str) -> str:
83+
return f"{test_server}/mcp"
1284

1385

1486
class SpyMemoryObjectSendStream:
Lines changed: 73 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,117 @@
1+
import sys
2+
from collections.abc import Callable
13
from typing import Any
4+
5+
if sys.version_info >= (3, 11):
6+
from builtins import BaseExceptionGroup # pragma: no cover
7+
else:
8+
from exceptiongroup import BaseExceptionGroup # pragma: no cover
9+
210
from unittest.mock import patch
311

412
import anyio
513
import pytest
14+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
615

16+
from mcp.client.sse import sse_client
17+
from mcp.client.streamable_http import streamable_http_client
718
from mcp.shared.message import SessionMessage
819
from mcp.shared.session import BaseSession, RequestId, SendResultT
920
from mcp.types import ClientNotification, ClientRequest, ClientResult, EmptyResult, ErrorData, PingRequest
1021

22+
ClientTransport = tuple[
23+
str,
24+
Callable[..., Any],
25+
Callable[[Any], tuple[MemoryObjectReceiveStream[Any], MemoryObjectSendStream[Any]]],
26+
]
27+
1128

1229
@pytest.mark.anyio
1330
async def test_send_request_stream_cleanup():
14-
"""
15-
Test that send_request properly cleans up streams when an exception occurs.
31+
"""Test that send_request properly cleans up streams when an exception occurs."""
1632

17-
This test mocks out most of the session functionality to focus on stream cleanup.
18-
"""
19-
20-
# Create a mock session with the minimal required functionality
2133
class TestSession(BaseSession[ClientRequest, ClientNotification, ClientResult, Any, Any]):
2234
async def _send_response(
2335
self, request_id: RequestId, response: SendResultT | ErrorData
2436
) -> None: # pragma: no cover
2537
pass
2638

27-
# Create streams
2839
write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)
2940
read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)
3041

31-
# Create the session
3242
session = TestSession(
3343
read_stream_receive,
3444
write_stream_send,
35-
object, # Request type doesn't matter for this test
36-
object, # Notification type doesn't matter for this test
45+
object,
46+
object,
3747
)
3848

39-
# Create a test request
4049
request = ClientRequest(PingRequest())
4150

42-
# Patch the _write_stream.send method to raise an exception
4351
async def mock_send(*args: Any, **kwargs: Any):
4452
raise RuntimeError("Simulated network error")
4553

46-
# Record the response streams before the test
4754
initial_stream_count = len(session._response_streams)
4855

49-
# Run the test with the patched method
5056
with patch.object(session._write_stream, "send", mock_send):
5157
with pytest.raises(RuntimeError):
5258
await session.send_request(request, EmptyResult)
5359

54-
# Verify that no response streams were leaked
55-
assert len(session._response_streams) == initial_stream_count, (
56-
f"Expected {initial_stream_count} response streams after request, but found {len(session._response_streams)}"
57-
)
60+
assert len(session._response_streams) == initial_stream_count
5861

59-
# Clean up
6062
await write_stream_send.aclose()
6163
await write_stream_receive.aclose()
6264
await read_stream_send.aclose()
6365
await read_stream_receive.aclose()
66+
67+
68+
@pytest.fixture(params=["sse", "streamable"])
69+
def client_transport(
70+
request: pytest.FixtureRequest, sse_server_url: str, streamable_server_url: str
71+
) -> ClientTransport:
72+
if request.param == "sse":
73+
return (sse_server_url, sse_client, lambda x: (x[0], x[1]))
74+
else:
75+
return (streamable_server_url, streamable_http_client, lambda x: (x[0], x[1]))
76+
77+
78+
@pytest.mark.anyio
79+
async def test_generator_exit_on_gc_cleanup(client_transport: ClientTransport) -> None:
80+
"""Suppress GeneratorExit from aclose() during GC cleanup (python/cpython#95571)."""
81+
url, client_func, unpack = client_transport
82+
cm = client_func(url)
83+
result = await cm.__aenter__()
84+
read_stream, write_stream = unpack(result)
85+
await cm.gen.aclose()
86+
await read_stream.aclose()
87+
await write_stream.aclose()
88+
89+
90+
@pytest.mark.anyio
91+
async def test_generator_exit_in_exception_group(client_transport: ClientTransport) -> None:
92+
"""Extract GeneratorExit from BaseExceptionGroup (python/cpython#135736)."""
93+
url, client_func, unpack = client_transport
94+
async with client_func(url) as result:
95+
unpack(result)
96+
raise BaseExceptionGroup("unhandled errors in a TaskGroup", [GeneratorExit()])
97+
98+
99+
@pytest.mark.anyio
100+
async def test_generator_exit_mixed_group(client_transport: ClientTransport) -> None:
101+
"""Extract GeneratorExit from BaseExceptionGroup, re-raise other exceptions (python/cpython#135736)."""
102+
url, client_func, unpack = client_transport
103+
with pytest.raises(BaseExceptionGroup) as exc_info:
104+
async with client_func(url) as result:
105+
unpack(result)
106+
raise BaseExceptionGroup("errors", [GeneratorExit(), ValueError("real error")])
107+
108+
def has_generator_exit(eg: BaseExceptionGroup[Any]) -> bool:
109+
for e in eg.exceptions:
110+
if isinstance(e, GeneratorExit):
111+
return True # pragma: no cover
112+
if isinstance(e, BaseExceptionGroup):
113+
if has_generator_exit(eg=e): # type: ignore[arg-type]
114+
return True # pragma: no cover
115+
return False
116+
117+
assert not has_generator_exit(exc_info.value)

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)