Skip to content

Commit f7daf85

Browse files
committed
backport: harness H3 — connect_in_memory, build_sse_app, connect_over_sse, parse_sse_messages, initialize_body bodies
1 parent a507771 commit f7daf85

1 file changed

Lines changed: 33 additions & 25 deletions

File tree

tests/interaction/_connect.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from mcp.server.streamable_http import EventStore
3232
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
3333
from mcp.server.transport_security import TransportSecuritySettings
34+
from mcp.shared.memory import create_connected_server_and_client_session
3435
from mcp.types import (
3536
LATEST_PROTOCOL_VERSION,
3637
ClientCapabilities,
@@ -94,8 +95,13 @@ async def connect_in_memory(
9495
message_handler: MessageHandlerFnT | None = None,
9596
client_info: Implementation | None = None,
9697
) -> AsyncIterator[ClientSession]:
97-
"""Yield an initialized `ClientSession` connected to the server over the in-memory transport."""
98-
async with Client( # noqa: F821 -- body rewritten in H3
98+
"""Yield an initialized `ClientSession` connected to the server over the in-memory transport.
99+
100+
This is exactly `mcp.shared.memory.create_connected_server_and_client_session` — the
101+
canonical v1 in-memory idiom — re-exported under the suite's `Connect` shape so the
102+
transport matrix can parametrize over it.
103+
"""
104+
async with create_connected_server_and_client_session(
99105
server,
100106
read_timeout_seconds=read_timeout_seconds,
101107
sampling_callback=sampling_callback,
@@ -104,8 +110,8 @@ async def connect_in_memory(
104110
message_handler=message_handler,
105111
client_info=client_info,
106112
elicitation_callback=elicitation_callback,
107-
) as client:
108-
yield client
113+
) as session:
114+
yield session
109115

110116

111117
@asynccontextmanager
@@ -229,7 +235,7 @@ async def client_via_http(
229235

230236
def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]:
231237
"""Decode SSE events into JSON-RPC messages, skipping priming events that carry no data."""
232-
return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] # noqa: F821 -- body rewritten in H3
238+
return [JSONRPCMessage.model_validate_json(event.data) for event in events if event.data]
233239

234240

235241
async def post_jsonrpc(
@@ -268,9 +274,9 @@ def base_headers(*, session_id: str | None = None) -> dict[str, str]:
268274
def initialize_body(request_id: int = 1) -> dict[str, object]:
269275
"""A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it."""
270276
params = InitializeRequestParams(
271-
protocol_version=LATEST_PROTOCOL_VERSION,
277+
protocolVersion=LATEST_PROTOCOL_VERSION,
272278
capabilities=ClientCapabilities(),
273-
client_info=Implementation(name="raw", version="0.0.0"),
279+
clientInfo=Implementation(name="raw", version="0.0.0"),
274280
)
275281
return JSONRPCRequest(
276282
jsonrpc="2.0", id=request_id, method="initialize", params=params.model_dump(by_alias=True, exclude_none=True)
@@ -302,17 +308,15 @@ async def initialize_via_http(http: httpx.AsyncClient) -> str:
302308
def build_sse_app(server: Server[Any] | FastMCP) -> tuple[Starlette, SseServerTransport]:
303309
"""Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/.
304310
305-
`MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which
311+
`FastMCP.sse_app()` exists but does not expose the underlying `SseServerTransport`, which
306312
the SSE-specific tests need; building the app explicitly here gives both server flavours the
307313
same routing while keeping that handle.
308314
"""
309-
sse = SseServerTransport(
310-
"/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False)
311-
)
312-
lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server # noqa: F821 -- body rewritten in H3
315+
sse = SseServerTransport("/messages/", security_settings=NO_DNS_REBINDING_PROTECTION)
316+
lowlevel = _lowlevel(server)
313317

314318
async def handle_sse(request: Request) -> Response:
315-
async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write):
319+
async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): # type: ignore[reportPrivateUsage]
316320
await lowlevel.run(read, write, lowlevel.create_initialization_options())
317321
return Response()
318322

@@ -356,15 +360,19 @@ def httpx_client_factory(
356360
auth=auth,
357361
)
358362

359-
transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory)
360-
async with Client( # noqa: F821 -- body rewritten in H3
361-
transport,
362-
read_timeout_seconds=read_timeout_seconds,
363-
sampling_callback=sampling_callback,
364-
list_roots_callback=list_roots_callback,
365-
logging_callback=logging_callback,
366-
message_handler=message_handler,
367-
client_info=client_info,
368-
elicitation_callback=elicitation_callback,
369-
) as client:
370-
yield client
363+
async with (
364+
sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) as (read, write),
365+
ClientSession(
366+
read,
367+
write,
368+
read_timeout_seconds=read_timeout_seconds,
369+
sampling_callback=sampling_callback,
370+
list_roots_callback=list_roots_callback,
371+
logging_callback=logging_callback,
372+
message_handler=message_handler,
373+
client_info=client_info,
374+
elicitation_callback=elicitation_callback,
375+
) as session,
376+
):
377+
await session.initialize()
378+
yield session

0 commit comments

Comments
 (0)