From 2aecd41199e3dbeb8a22427ff03197873499f22a Mon Sep 17 00:00:00 2001 From: Alexander Alderman Webb Date: Fri, 30 Jan 2026 13:29:20 +0100 Subject: [PATCH 1/3] test(fastmcp): Use AsyncClient for SSE --- tests/conftest.py | 201 +++++++++++++++++- tests/integrations/fastmcp/test_fastmcp.py | 87 ++++++-- .../mcp/streaming_asgi_transport.py | 85 -------- tests/integrations/mcp/test_mcp.py | 98 +-------- 4 files changed, 275 insertions(+), 196 deletions(-) delete mode 100644 tests/integrations/mcp/streaming_asgi_transport.py diff --git a/tests/conftest.py b/tests/conftest.py index 2c8aa7bdf8..e581c38d5c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ import json import os +import asyncio +from urllib.parse import urlparse, parse_qs import socket import warnings import brotli @@ -51,24 +53,33 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional + from typing import Any, Callable, MutableMapping, Optional from collections.abc import Iterator try: - from anyio import create_memory_object_stream, create_task_group + from anyio import create_memory_object_stream, create_task_group, EndOfStream from mcp.types import ( JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, ) from mcp.shared.message import SessionMessage + from httpx import ASGITransport, Request, Response, AsyncByteStream, AsyncClient except ImportError: create_memory_object_stream = None create_task_group = None + EndOfStream = None + JSONRPCMessage = None JSONRPCRequest = None SessionMessage = None + ASGITransport = None + Request = None + Response = None + AsyncByteStream = None + AsyncClient = None + SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json" @@ -786,6 +797,192 @@ def inner(events): return inner +@pytest.fixture() +def json_rpc_sse(is_structured_content: bool = True): + class StreamingASGITransport(ASGITransport): + """ + Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing + tests involving SSE interactions to run in-process. + """ + + def __init__( + self, + app: "Callable", + keep_sse_alive: "asyncio.Event", + ) -> None: + self.keep_sse_alive = keep_sse_alive + super().__init__(app) + + async def handle_async_request(self, request: "Request") -> "Response": + scope = { + "type": "http", + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "path": request.url.path, + "query_string": request.url.query, + } + + is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse" + if not is_streaming_sse: + return await super().handle_async_request(request) + + request_body = b"" + if request.content: + request_body = await request.aread() + + body_sender, body_receiver = create_memory_object_stream[bytes](0) # type: ignore + + async def receive() -> "dict[str, Any]": + if self.keep_sse_alive.is_set(): + return {"type": "http.disconnect"} + + await self.keep_sse_alive.wait() # Keep alive :) + return { + "type": "http.request", + "body": request_body, + "more_body": False, + } + + async def send(message: "MutableMapping[str, Any]") -> None: + if message["type"] == "http.response.body": + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body == b"" and not more_body: + return + + if body: + await body_sender.send(body) + + if not more_body: + await body_sender.aclose() + + async def run_app(): + await self.app(scope, receive, send) + + class StreamingBodyStream(AsyncByteStream): # type: ignore + def __init__(self, receiver, task): + self.receiver = receiver + self.task = task + + async def __aiter__(self): + try: + async for chunk in self.receiver: + yield chunk + except EndOfStream: # type: ignore + pass + + stream = StreamingBodyStream(body_receiver, asyncio.create_task(run_app())) + response = Response(status_code=200, headers=[], stream=stream) # type: ignore + + return response + + def parse_sse_data_package(sse_chunk): + sse_text = sse_chunk.decode("utf-8") + json_str = sse_text.split("data: ")[1] + return json.loads(json_str) + + async def inner( + app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event" + ): + context = {} + + stream_complete = asyncio.Event() + endpoint_parsed = asyncio.Event() + + # https://github.com/Kludex/starlette/issues/104#issuecomment-729087925 + async with AsyncClient( # type: ignore + transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive), + base_url="http://test", + ) as client: + + async def parse_stream(): + async with client.stream("GET", "/sse") as stream: + # Read directly from stream.stream instead of aiter_bytes() + async for chunk in stream.stream: + if b"event: endpoint" in chunk: + sse_text = chunk.decode("utf-8") + url = sse_text.split("data: ")[1] + + parsed = urlparse(url) + query_params = parse_qs(parsed.query) + context["session_id"] = query_params["session_id"][0] + endpoint_parsed.set() + continue + + if ( + is_structured_content + and b"event: message" in chunk + and b"structuredContent" in chunk + ): + context["response"] = parse_sse_data_package(chunk) + stream_complete.set() + break + elif ( + "result" in parse_sse_data_package(chunk) + and "content" in parse_sse_data_package(chunk)["result"] + ): + context["response"] = parse_sse_data_package(chunk) + stream_complete.set() + break + + task = asyncio.create_task(parse_stream()) + await endpoint_parsed.wait() + + await client.post( + f"/messages/?session_id={context['session_id']}", + headers={ + "Content-Type": "application/json", + }, + json={ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-11-25", + "capabilities": {}, + }, + "id": request_id, + }, + ) + + # Notification response is mandatory. + # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle + await client.post( + f"/messages/?session_id={context['session_id']}", + headers={ + "Content-Type": "application/json", + "mcp-session-id": context["session_id"], + }, + json={ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {}, + }, + ) + + await client.post( + f"/messages/?session_id={context['session_id']}", + headers={ + "Content-Type": "application/json", + "mcp-session-id": context["session_id"], + }, + json={ + "jsonrpc": "2.0", + "method": method, + "params": params, + "id": request_id, + }, + ) + + await stream_complete.wait() + keep_sse_alive.set() + + return task, context["session_id"], context["response"] + + return inner + + class MockServerRequestHandler(BaseHTTPRequestHandler): def do_GET(self): # noqa: N802 # Process an HTTP GET request and return a response. diff --git a/tests/integrations/fastmcp/test_fastmcp.py b/tests/integrations/fastmcp/test_fastmcp.py index f2619b5104..ead395e6c0 100644 --- a/tests/integrations/fastmcp/test_fastmcp.py +++ b/tests/integrations/fastmcp/test_fastmcp.py @@ -21,6 +21,7 @@ accurate testing of the integration's behavior in real MCP Server scenarios. """ +import anyio import asyncio import json import pytest @@ -39,9 +40,12 @@ async def __call__(self, *args, **kwargs): from sentry_sdk.consts import SPANDATA, OP from sentry_sdk.integrations.mcp import MCPIntegration +from mcp.server.lowlevel import Server +from mcp.server.sse import SseServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.routing import Mount +from starlette.routing import Mount, Route +from starlette.responses import Response from starlette.applications import Starlette # Try to import both FastMCP implementations @@ -1029,8 +1033,11 @@ def test_tool_no_ctx(x: int) -> dict: # ============================================================================= +@pytest.mark.asyncio @pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids) -def test_fastmcp_sse_transport(sentry_init, capture_events, FastMCP): +async def test_fastmcp_sse_transport( + sentry_init, capture_events, FastMCP, json_rpc_sse +): """Test that FastMCP correctly detects SSE transport""" sentry_init( integrations=[MCPIntegration()], @@ -1039,25 +1046,81 @@ def test_fastmcp_sse_transport(sentry_init, capture_events, FastMCP): events = capture_events() mcp = FastMCP("Test Server") + sse = SseServerTransport("/messages/") - # Set up mock request context with SSE transport - if request_ctx is not None: - mock_ctx = MockRequestContext( - request_id="req-sse", session_id="session-sse-123", transport="sse" - ) - request_ctx.set(mock_ctx) + sse_connection_closed = asyncio.Event() + + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + async with anyio.create_task_group() as tg: + + async def run_server(): + await mcp._mcp_server.run( + streams[0], + streams[1], + mcp._mcp_server.create_initialization_options(), + ) + + tg.start_soon(run_server) + + sse_connection_closed.set() + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ], + ) @mcp.tool() def sse_tool(value: str) -> dict: """Tool for SSE transport test""" return {"message": f"Received: {value}"} - with start_transaction(name="fastmcp tx"): - result = call_tool_through_mcp(mcp, "sse_tool", {"value": "hello"}) + keep_sse_alive = asyncio.Event() + app_task, _, result = await json_rpc_sse( + app, + method="tools/call", + params={ + "name": "sse_tool", + "arguments": {"value": "hello"}, + }, + request_id="req-sse", + keep_sse_alive=keep_sse_alive, + ) - assert result == {"message": "Received: hello"} + await sse_connection_closed.wait() + await app_task - (tx,) = events + if ( + isinstance(mcp, StandaloneFastMCP) + and FASTMCP_VERSION is not None + and FASTMCP_VERSION.startswith("2") + ): + assert result["result"]["content"][0]["text"] == json.dumps( + {"message": "Received: hello"}, separators=(",", ":") + ) + elif ( + isinstance(mcp, StandaloneFastMCP) and FASTMCP_VERSION is not None + ): # Checking for None is not precise. + assert result["result"]["content"][0]["text"] == json.dumps( + {"message": "Received: hello"} + ) + else: + assert result["result"]["content"][0]["text"] == json.dumps( + {"message": "Received: hello"}, indent=2 + ) + + transactions = [ + event + for event in events + if event["type"] == "transaction" and event["transaction"] == "/sse" + ] + assert len(transactions) == 1 + tx = transactions[0] # Find MCP spans mcp_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER] diff --git a/tests/integrations/mcp/streaming_asgi_transport.py b/tests/integrations/mcp/streaming_asgi_transport.py deleted file mode 100644 index 03f84b0e91..0000000000 --- a/tests/integrations/mcp/streaming_asgi_transport.py +++ /dev/null @@ -1,85 +0,0 @@ -import asyncio -from httpx import ASGITransport, Request, Response, AsyncByteStream -import anyio - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Any, Callable, MutableMapping - - -class StreamingASGITransport(ASGITransport): - """ - Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing - tests involving SSE interactions to run in-process. - """ - - def __init__( - self, - app: "Callable", - keep_sse_alive: "asyncio.Event", - ) -> None: - self.keep_sse_alive = keep_sse_alive - super().__init__(app) - - async def handle_async_request(self, request: "Request") -> "Response": - scope = { - "type": "http", - "method": request.method, - "headers": [(k.lower(), v) for (k, v) in request.headers.raw], - "path": request.url.path, - "query_string": request.url.query, - } - - is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse" - if not is_streaming_sse: - return await super().handle_async_request(request) - - request_body = b"" - if request.content: - request_body = await request.aread() - - body_sender, body_receiver = anyio.create_memory_object_stream[bytes](0) - - async def receive() -> "dict[str, Any]": - if self.keep_sse_alive.is_set(): - return {"type": "http.disconnect"} - - await self.keep_sse_alive.wait() # Keep alive :) - return {"type": "http.request", "body": request_body, "more_body": False} - - async def send(message: "MutableMapping[str, Any]") -> None: - if message["type"] == "http.response.body": - body = message.get("body", b"") - more_body = message.get("more_body", False) - - if body == b"" and not more_body: - return - - if body: - await body_sender.send(body) - - if not more_body: - await body_sender.aclose() - - async def run_app(): - await self.app(scope, receive, send) - - class StreamingBodyStream(AsyncByteStream): - def __init__(self, receiver, task): - self.receiver = receiver - self.task = task - - async def __aiter__(self): - try: - async for chunk in self.receiver: - yield chunk - except anyio.EndOfStream: - pass - finally: - await self.task - - stream = StreamingBodyStream(body_receiver, asyncio.create_task(run_app())) - response = Response(status_code=200, headers=[], stream=stream) - - return response diff --git a/tests/integrations/mcp/test_mcp.py b/tests/integrations/mcp/test_mcp.py index ab3c2cf73d..4e83c7939c 100644 --- a/tests/integrations/mcp/test_mcp.py +++ b/tests/integrations/mcp/test_mcp.py @@ -16,12 +16,8 @@ """ import sentry_sdk - -from urllib.parse import urlparse, parse_qs import anyio import asyncio -import httpx -from .streaming_asgi_transport import StreamingASGITransport import pytest import json @@ -138,98 +134,6 @@ def __init__(self, messages): self.messages = messages -async def json_rpc_sse( - app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event" -): - context = {} - - stream_complete = asyncio.Event() - endpoint_parsed = asyncio.Event() - - # https://github.com/Kludex/starlette/issues/104#issuecomment-729087925 - async with httpx.AsyncClient( - transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive), - base_url="http://test", - ) as client: - - async def parse_stream(): - async with client.stream("GET", "/sse") as stream: - # Read directly from stream.stream instead of aiter_bytes() - async for chunk in stream.stream: - if b"event: endpoint" in chunk: - sse_text = chunk.decode("utf-8") - url = sse_text.split("data: ")[1] - - parsed = urlparse(url) - query_params = parse_qs(parsed.query) - context["session_id"] = query_params["session_id"][0] - endpoint_parsed.set() - continue - - if b"event: message" in chunk and b"structuredContent" in chunk: - sse_text = chunk.decode("utf-8") - - json_str = sse_text.split("data: ")[1] - context["response"] = json.loads(json_str) - break - - stream_complete.set() - - task = asyncio.create_task(parse_stream()) - await endpoint_parsed.wait() - - await client.post( - f"/messages/?session_id={context['session_id']}", - headers={ - "Content-Type": "application/json", - }, - json={ - "jsonrpc": "2.0", - "method": "initialize", - "params": { - "clientInfo": {"name": "test-client", "version": "1.0"}, - "protocolVersion": "2025-11-25", - "capabilities": {}, - }, - "id": request_id, - }, - ) - - # Notification response is mandatory. - # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle - await client.post( - f"/messages/?session_id={context['session_id']}", - headers={ - "Content-Type": "application/json", - "mcp-session-id": context["session_id"], - }, - json={ - "jsonrpc": "2.0", - "method": "notifications/initialized", - "params": {}, - }, - ) - - await client.post( - f"/messages/?session_id={context['session_id']}", - headers={ - "Content-Type": "application/json", - "mcp-session-id": context["session_id"], - }, - json={ - "jsonrpc": "2.0", - "method": method, - "params": params, - "id": request_id, - }, - ) - - await stream_complete.wait() - keep_sse_alive.set() - - return task, context["session_id"], context["response"] - - def test_integration_patches_server(sentry_init): """Test that MCPIntegration patches the Server class""" # Get original methods before integration @@ -1186,7 +1090,7 @@ async def async_tool(tool_name, arguments): @pytest.mark.asyncio -async def test_sse_transport_detection(sentry_init, capture_events): +async def test_sse_transport_detection(sentry_init, capture_events, json_rpc_sse): """Test that SSE transport is correctly detected via query parameter""" sentry_init( integrations=[MCPIntegration()], From 7484fd859a0956748f5400f86c8103bbcd8a1e39 Mon Sep 17 00:00:00 2001 From: Alexander Alderman Webb Date: Fri, 30 Jan 2026 15:04:13 +0100 Subject: [PATCH 2/3] re-organize stream parsing --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d4ec78a0eb..9e76bf2d2f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -917,16 +917,16 @@ async def parse_stream(): and b"structuredContent" in chunk ): context["response"] = parse_sse_data_package(chunk) - stream_complete.set() break elif ( "result" in parse_sse_data_package(chunk) and "content" in parse_sse_data_package(chunk)["result"] ): context["response"] = parse_sse_data_package(chunk) - stream_complete.set() break + stream_complete.set() + task = asyncio.create_task(parse_stream()) await endpoint_parsed.wait() From 7c9d6026901e833486d857573cab7ae6b775bfec Mon Sep 17 00:00:00 2001 From: Alexander Alderman Webb Date: Fri, 30 Jan 2026 15:16:12 +0100 Subject: [PATCH 3/3] stop shadowing werkzeug types --- tests/conftest.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9e76bf2d2f..f5e7c67809 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,7 +64,13 @@ JSONRPCRequest, ) from mcp.shared.message import SessionMessage - from httpx import ASGITransport, Request, Response, AsyncByteStream, AsyncClient + from httpx import ( + ASGITransport, + Request as HttpxRequest, + Response as HttpxResponse, + AsyncByteStream, + AsyncClient, + ) except ImportError: create_memory_object_stream = None create_task_group = None @@ -76,8 +82,8 @@ SessionMessage = None ASGITransport = None - Request = None - Response = None + HttpxRequest = None + HttpxResponse = None AsyncByteStream = None AsyncClient = None @@ -814,7 +820,9 @@ def __init__( self.keep_sse_alive = keep_sse_alive super().__init__(app) - async def handle_async_request(self, request: "Request") -> "Response": + async def handle_async_request( + self, request: "HttpxRequest" + ) -> "HttpxResponse": scope = { "type": "http", "method": request.method, @@ -873,7 +881,7 @@ async def __aiter__(self): pass stream = StreamingBodyStream(body_receiver) - response = Response(status_code=200, headers=[], stream=stream) # type: ignore + response = HttpxResponse(status_code=200, headers=[], stream=stream) # type: ignore asyncio.create_task(run_app()) return response