Skip to content

Commit 54828a1

Browse files
committed
fix: drain terminal streamable HTTP responses
1 parent 616476f commit 54828a1

2 files changed

Lines changed: 137 additions & 7 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,19 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
240240
event_source.response.raise_for_status()
241241
logger.debug("Resumption GET SSE connection established")
242242

243+
response_complete = False
243244
async for sse in event_source.aiter_sse(): # pragma: no branch
245+
if response_complete:
246+
continue
247+
244248
is_complete = await self._handle_sse_event(
245249
sse,
246250
ctx.read_stream_writer,
247251
original_request_id,
248252
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
249253
)
250254
if is_complete:
251-
await event_source.response.aclose()
252-
break
255+
response_complete = True
253256

254257
async def _handle_post_request(self, ctx: RequestContext) -> None:
255258
"""Handle a POST request with response processing."""
@@ -342,7 +345,11 @@ async def _handle_sse_response(
342345

343346
try:
344347
event_source = EventSource(response)
348+
response_complete = False
345349
async for sse in event_source.aiter_sse(): # pragma: no branch
350+
if response_complete:
351+
continue
352+
346353
# Track last event ID for potential reconnection
347354
if sse.id:
348355
last_event_id = sse.id
@@ -359,13 +366,15 @@ async def _handle_sse_response(
359366
is_initialization=is_initialization,
360367
)
361368
# If the SSE event indicates completion, like returning response/error
362-
# break the loop
369+
# keep draining the response to EOF so the HTTP connection can be reused.
363370
if is_complete:
364-
await response.aclose()
365-
return # Normal completion, no reconnect needed
371+
response_complete = True
366372
except Exception:
367373
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover
368374

375+
if response_complete:
376+
return # Normal completion, no reconnect needed
377+
369378
# Stream ended without response - reconnect if we received an event with ID
370379
if last_event_id is not None: # pragma: no branch
371380
logger.info("SSE stream disconnected, reconnecting...")
@@ -405,7 +414,11 @@ async def _handle_reconnection(
405414
reconnect_last_event_id: str = last_event_id
406415
reconnect_retry_ms = retry_interval_ms
407416

417+
response_complete = False
408418
async for sse in event_source.aiter_sse():
419+
if response_complete:
420+
continue
421+
409422
if sse.id: # pragma: no branch
410423
reconnect_last_event_id = sse.id
411424
if sse.retry is not None:
@@ -418,8 +431,10 @@ async def _handle_reconnection(
418431
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
419432
)
420433
if is_complete:
421-
await event_source.response.aclose()
422-
return
434+
response_complete = True
435+
436+
if response_complete:
437+
return
423438

424439
# Stream ended again without response - reconnect again (reset attempt counter)
425440
logger.info("SSE stream disconnected, reconnecting...")

tests/shared/test_streamable_http.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from starlette.requests import Request
2828
from starlette.routing import Mount
2929

30+
import mcp.client.streamable_http as streamable_http
3031
from mcp import MCPError, types
3132
from mcp.client.session import ClientSession
3233
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
@@ -139,6 +140,38 @@ async def replay_events_after( # pragma: no cover
139140
return target_stream_id
140141

141142

143+
class FakeStreamResponse:
144+
def __init__(self) -> None:
145+
self.close_count = 0
146+
147+
def raise_for_status(self) -> None:
148+
pass
149+
150+
async def aclose(self) -> None:
151+
self.close_count += 1
152+
153+
154+
class FakeEventSource:
155+
def __init__(self, events: list[ServerSentEvent]) -> None:
156+
self.response = FakeStreamResponse()
157+
self.events = events
158+
self.seen = 0
159+
160+
async def aiter_sse(self) -> AsyncIterator[ServerSentEvent]:
161+
for event in self.events:
162+
self.seen += 1
163+
yield event
164+
165+
166+
def jsonrpc_response_event(request_id: str, event_id: str) -> ServerSentEvent:
167+
return ServerSentEvent(
168+
event="message",
169+
data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}),
170+
id=event_id,
171+
retry=None,
172+
)
173+
174+
142175
@dataclass
143176
class ServerState:
144177
lock: anyio.Event = field(default_factory=anyio.Event)
@@ -1803,6 +1836,88 @@ async def test_handle_sse_event_skips_empty_data():
18031836
await read_stream.aclose()
18041837

18051838

1839+
@pytest.mark.anyio
1840+
async def test_handle_sse_response_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch):
1841+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1842+
response = FakeStreamResponse()
1843+
event_source = FakeEventSource(
1844+
[
1845+
jsonrpc_response_event("request-1", "event-1"),
1846+
ServerSentEvent(event="message", data="", id="event-2", retry=None),
1847+
]
1848+
)
1849+
monkeypatch.setattr(streamable_http, "EventSource", lambda _response: event_source)
1850+
1851+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](1)
1852+
try:
1853+
async with httpx.AsyncClient() as client:
1854+
ctx = streamable_http.RequestContext(
1855+
client=client,
1856+
session_id=None,
1857+
session_message=SessionMessage(
1858+
JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={})
1859+
),
1860+
metadata=None,
1861+
read_stream_writer=write_stream,
1862+
)
1863+
1864+
await transport._handle_sse_response(response, ctx)
1865+
1866+
received = await read_stream.receive()
1867+
assert isinstance(received.message, types.JSONRPCResponse)
1868+
assert received.message.id == "request-1"
1869+
assert event_source.seen == 2
1870+
assert response.close_count == 0
1871+
finally:
1872+
await write_stream.aclose()
1873+
await read_stream.aclose()
1874+
1875+
1876+
@pytest.mark.anyio
1877+
async def test_reconnection_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch):
1878+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1879+
event_source = FakeEventSource(
1880+
[
1881+
jsonrpc_response_event("request-1", "event-2"),
1882+
ServerSentEvent(event="message", data="", id="event-3", retry=None),
1883+
]
1884+
)
1885+
1886+
async def sleep_noop(_delay: float) -> None:
1887+
pass
1888+
1889+
@asynccontextmanager
1890+
async def connect_sse(*args: Any, **kwargs: Any) -> AsyncIterator[FakeEventSource]:
1891+
yield event_source
1892+
1893+
monkeypatch.setattr(streamable_http.anyio, "sleep", sleep_noop)
1894+
monkeypatch.setattr(streamable_http, "aconnect_sse", connect_sse)
1895+
1896+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](1)
1897+
try:
1898+
async with httpx.AsyncClient() as client:
1899+
ctx = streamable_http.RequestContext(
1900+
client=client,
1901+
session_id=None,
1902+
session_message=SessionMessage(
1903+
JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={})
1904+
),
1905+
metadata=None,
1906+
read_stream_writer=write_stream,
1907+
)
1908+
1909+
await transport._handle_reconnection(ctx, last_event_id="event-1")
1910+
1911+
received = await read_stream.receive()
1912+
assert isinstance(received.message, types.JSONRPCResponse)
1913+
assert received.message.id == "request-1"
1914+
assert event_source.seen == 2
1915+
assert event_source.response.close_count == 0
1916+
finally:
1917+
await write_stream.aclose()
1918+
await read_stream.aclose()
1919+
1920+
18061921
@pytest.mark.anyio
18071922
async def test_priming_event_not_sent_for_old_protocol_version():
18081923
"""Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""

0 commit comments

Comments
 (0)