|
27 | 27 | from starlette.requests import Request |
28 | 28 | from starlette.routing import Mount |
29 | 29 |
|
| 30 | +import mcp.client.streamable_http as streamable_http |
30 | 31 | from mcp import MCPError, types |
31 | 32 | from mcp.client.session import ClientSession |
32 | 33 | from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client |
@@ -139,6 +140,38 @@ async def replay_events_after( # pragma: no cover |
139 | 140 | return target_stream_id |
140 | 141 |
|
141 | 142 |
|
| 143 | +class FakeStreamResponse: |
| 144 | + def __init__(self) -> None: |
| 145 | + self.close_count = 0 |
| 146 | + |
| 147 | + def raise_for_status(self) -> None: |
| 148 | + pass |
| 149 | + |
| 150 | + async def aclose(self) -> None: |
| 151 | + self.close_count += 1 |
| 152 | + |
| 153 | + |
| 154 | +class FakeEventSource: |
| 155 | + def __init__(self, events: list[ServerSentEvent]) -> None: |
| 156 | + self.response = FakeStreamResponse() |
| 157 | + self.events = events |
| 158 | + self.seen = 0 |
| 159 | + |
| 160 | + async def aiter_sse(self) -> AsyncIterator[ServerSentEvent]: |
| 161 | + for event in self.events: |
| 162 | + self.seen += 1 |
| 163 | + yield event |
| 164 | + |
| 165 | + |
| 166 | +def jsonrpc_response_event(request_id: str, event_id: str) -> ServerSentEvent: |
| 167 | + return ServerSentEvent( |
| 168 | + event="message", |
| 169 | + data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}), |
| 170 | + id=event_id, |
| 171 | + retry=None, |
| 172 | + ) |
| 173 | + |
| 174 | + |
142 | 175 | @dataclass |
143 | 176 | class ServerState: |
144 | 177 | lock: anyio.Event = field(default_factory=anyio.Event) |
@@ -1803,6 +1836,88 @@ async def test_handle_sse_event_skips_empty_data(): |
1803 | 1836 | await read_stream.aclose() |
1804 | 1837 |
|
1805 | 1838 |
|
| 1839 | +@pytest.mark.anyio |
| 1840 | +async def test_handle_sse_response_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch): |
| 1841 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1842 | + response = FakeStreamResponse() |
| 1843 | + event_source = FakeEventSource( |
| 1844 | + [ |
| 1845 | + jsonrpc_response_event("request-1", "event-1"), |
| 1846 | + ServerSentEvent(event="message", data="", id="event-2", retry=None), |
| 1847 | + ] |
| 1848 | + ) |
| 1849 | + monkeypatch.setattr(streamable_http, "EventSource", lambda _response: event_source) |
| 1850 | + |
| 1851 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) |
| 1852 | + try: |
| 1853 | + async with httpx.AsyncClient() as client: |
| 1854 | + ctx = streamable_http.RequestContext( |
| 1855 | + client=client, |
| 1856 | + session_id=None, |
| 1857 | + session_message=SessionMessage( |
| 1858 | + JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={}) |
| 1859 | + ), |
| 1860 | + metadata=None, |
| 1861 | + read_stream_writer=write_stream, |
| 1862 | + ) |
| 1863 | + |
| 1864 | + await transport._handle_sse_response(response, ctx) |
| 1865 | + |
| 1866 | + received = await read_stream.receive() |
| 1867 | + assert isinstance(received.message, types.JSONRPCResponse) |
| 1868 | + assert received.message.id == "request-1" |
| 1869 | + assert event_source.seen == 2 |
| 1870 | + assert response.close_count == 0 |
| 1871 | + finally: |
| 1872 | + await write_stream.aclose() |
| 1873 | + await read_stream.aclose() |
| 1874 | + |
| 1875 | + |
| 1876 | +@pytest.mark.anyio |
| 1877 | +async def test_reconnection_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch): |
| 1878 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1879 | + event_source = FakeEventSource( |
| 1880 | + [ |
| 1881 | + jsonrpc_response_event("request-1", "event-2"), |
| 1882 | + ServerSentEvent(event="message", data="", id="event-3", retry=None), |
| 1883 | + ] |
| 1884 | + ) |
| 1885 | + |
| 1886 | + async def sleep_noop(_delay: float) -> None: |
| 1887 | + pass |
| 1888 | + |
| 1889 | + @asynccontextmanager |
| 1890 | + async def connect_sse(*args: Any, **kwargs: Any) -> AsyncIterator[FakeEventSource]: |
| 1891 | + yield event_source |
| 1892 | + |
| 1893 | + monkeypatch.setattr(streamable_http.anyio, "sleep", sleep_noop) |
| 1894 | + monkeypatch.setattr(streamable_http, "aconnect_sse", connect_sse) |
| 1895 | + |
| 1896 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) |
| 1897 | + try: |
| 1898 | + async with httpx.AsyncClient() as client: |
| 1899 | + ctx = streamable_http.RequestContext( |
| 1900 | + client=client, |
| 1901 | + session_id=None, |
| 1902 | + session_message=SessionMessage( |
| 1903 | + JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={}) |
| 1904 | + ), |
| 1905 | + metadata=None, |
| 1906 | + read_stream_writer=write_stream, |
| 1907 | + ) |
| 1908 | + |
| 1909 | + await transport._handle_reconnection(ctx, last_event_id="event-1") |
| 1910 | + |
| 1911 | + received = await read_stream.receive() |
| 1912 | + assert isinstance(received.message, types.JSONRPCResponse) |
| 1913 | + assert received.message.id == "request-1" |
| 1914 | + assert event_source.seen == 2 |
| 1915 | + assert event_source.response.close_count == 0 |
| 1916 | + finally: |
| 1917 | + await write_stream.aclose() |
| 1918 | + await read_stream.aclose() |
| 1919 | + |
| 1920 | + |
1806 | 1921 | @pytest.mark.anyio |
1807 | 1922 | async def test_priming_event_not_sent_for_old_protocol_version(): |
1808 | 1923 | """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" |
|
0 commit comments