Skip to content

Commit a507771

Browse files
committed
backport: harness H2 — _connect.py imports/annotations + conftest marker registration
- conftest.py: add pytest_configure registering the 'requirement' marker (round-1 adv-3 S1 fix; without this every file is a PytestUnknownMarkWarning collection error under filterwarnings=['error']). - _connect.py imports: drop Client/MCPServer/jsonrpc_message_adapter (not in v1); add timedelta, ClientSession, FastMCP. Mount KEPT (adversarial-v2-gate S1: SSE build_sse_app needs it). - _connect.py annotations: Connect Protocol + all factory signatures retyped to Server[Any]|FastMCP / timedelta / ClientSession; kwarg order matches v1 ClientSession.__init__; add _lowlevel() helper. Function bodies untouched per plan-v2 H2; the 7 dead Client/MCPServer/adapter refs carry temporary noqa:F821 until H3/H4/H5 rewrite them.
1 parent 92f2bb4 commit a507771

2 files changed

Lines changed: 57 additions & 40 deletions

File tree

tests/interaction/_connect.py

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
"""Transport-parametrized connection factories for the interaction suite.
22
33
The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body
4-
runs over each transport without naming any of them: the factory is a drop-in replacement for
5-
constructing `Client(server, ...)` and yields the connected client. The HTTP factories drive the
6-
server's real Starlette app through the in-process streaming bridge, so the full transport layer
7-
(session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses.
4+
runs over each transport without naming any of them: the factory yields an initialized
5+
`ClientSession` connected to the given server. v1 has no high-level `Client` class —
6+
`ClientSession` *is* the client. The HTTP factories drive the server's real Starlette app through
7+
the in-process streaming bridge, so the full transport layer (session ids, SSE encoding, session
8+
management) runs with no sockets, threads, or subprocesses.
89
"""
910

1011
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
1112
from contextlib import AbstractAsyncContextManager, asynccontextmanager
13+
from datetime import timedelta
1214
from typing import Any, Protocol
1315

1416
import httpx
@@ -18,14 +20,13 @@
1820
from starlette.responses import Response
1921
from starlette.routing import Mount, Route
2022

21-
from mcp.client.client import Client
22-
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
23+
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
2324
from mcp.client.sse import sse_client
2425
from mcp.client.streamable_http import streamable_http_client
2526
from mcp.server import Server
2627
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
2728
from mcp.server.auth.settings import AuthSettings
28-
from mcp.server.mcpserver import MCPServer
29+
from mcp.server.fastmcp import FastMCP
2930
from mcp.server.sse import SseServerTransport
3031
from mcp.server.streamable_http import EventStore
3132
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
@@ -38,7 +39,6 @@
3839
JSONRPCMessage,
3940
JSONRPCRequest,
4041
JSONRPCResponse,
41-
jsonrpc_message_adapter,
4242
)
4343
from tests.interaction.transports._bridge import StreamingASGITransport
4444

@@ -52,40 +52,50 @@
5252
NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings(enable_dns_rebinding_protection=False)
5353

5454

55+
def _lowlevel(server: Server[Any] | FastMCP) -> Server[Any]:
56+
"""Return the lowlevel `Server` for either flavour.
57+
58+
Reaching `FastMCP._mcp_server` is the v1 idiom — `mcp.shared.memory` itself does exactly
59+
this (with the same `# type: ignore`).
60+
"""
61+
return server._mcp_server if isinstance(server, FastMCP) else server # type: ignore[reportPrivateUsage]
62+
63+
5564
class Connect(Protocol):
56-
"""Connect a Client to a server over the transport selected by the `connect` fixture.
65+
"""Connect a `ClientSession` to a server over the transport selected by the `connect` fixture.
5766
58-
Accepts the same keyword arguments as `Client` and yields the connected client.
67+
Accepts the same callback keyword arguments as `ClientSession` and yields the connected,
68+
initialized session.
5969
"""
6070

6171
def __call__(
6272
self,
63-
server: Server | MCPServer,
73+
server: Server[Any] | FastMCP,
6474
*,
65-
read_timeout_seconds: float | None = None,
75+
read_timeout_seconds: timedelta | None = None,
6676
sampling_callback: SamplingFnT | None = None,
77+
elicitation_callback: ElicitationFnT | None = None,
6778
list_roots_callback: ListRootsFnT | None = None,
6879
logging_callback: LoggingFnT | None = None,
6980
message_handler: MessageHandlerFnT | None = None,
7081
client_info: Implementation | None = None,
71-
elicitation_callback: ElicitationFnT | None = None,
72-
) -> AbstractAsyncContextManager[Client]: ...
82+
) -> AbstractAsyncContextManager[ClientSession]: ...
7383

7484

7585
@asynccontextmanager
7686
async def connect_in_memory(
77-
server: Server | MCPServer,
87+
server: Server[Any] | FastMCP,
7888
*,
79-
read_timeout_seconds: float | None = None,
89+
read_timeout_seconds: timedelta | None = None,
8090
sampling_callback: SamplingFnT | None = None,
91+
elicitation_callback: ElicitationFnT | None = None,
8192
list_roots_callback: ListRootsFnT | None = None,
8293
logging_callback: LoggingFnT | None = None,
8394
message_handler: MessageHandlerFnT | None = None,
8495
client_info: Implementation | None = None,
85-
elicitation_callback: ElicitationFnT | None = None,
86-
) -> AsyncIterator[Client]:
87-
"""Yield a Client connected to the server over the in-memory transport."""
88-
async with Client(
96+
) -> AsyncIterator[ClientSession]:
97+
"""Yield an initialized `ClientSession` connected to the server over the in-memory transport."""
98+
async with Client( # noqa: F821 -- body rewritten in H3
8999
server,
90100
read_timeout_seconds=read_timeout_seconds,
91101
sampling_callback=sampling_callback,
@@ -100,21 +110,21 @@ async def connect_in_memory(
100110

101111
@asynccontextmanager
102112
async def connect_over_streamable_http(
103-
server: Server | MCPServer,
113+
server: Server[Any] | FastMCP,
104114
*,
105115
stateless_http: bool = False,
106116
json_response: bool = False,
107117
event_store: EventStore | None = None,
108118
retry_interval: int | None = None,
109-
read_timeout_seconds: float | None = None,
119+
read_timeout_seconds: timedelta | None = None,
110120
sampling_callback: SamplingFnT | None = None,
121+
elicitation_callback: ElicitationFnT | None = None,
111122
list_roots_callback: ListRootsFnT | None = None,
112123
logging_callback: LoggingFnT | None = None,
113124
message_handler: MessageHandlerFnT | None = None,
114125
client_info: Implementation | None = None,
115-
elicitation_callback: ElicitationFnT | None = None,
116-
) -> AsyncIterator[Client]:
117-
"""Yield a Client connected to the server's streamable HTTP app, entirely in process.
126+
) -> AsyncIterator[ClientSession]:
127+
"""Yield an initialized `ClientSession` over the server's streamable HTTP app, entirely in process.
118128
119129
With the defaults this is the matrix leg (stateful sessions, SSE responses); the
120130
transport-specific tests pass `stateless_http` or `json_response` to select the other
@@ -131,7 +141,7 @@ async def connect_over_streamable_http(
131141
async with (
132142
server.session_manager.run(),
133143
httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client,
134-
Client(
144+
Client( # noqa: F821 -- body rewritten in H4
135145
streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client),
136146
read_timeout_seconds=read_timeout_seconds,
137147
sampling_callback=sampling_callback,
@@ -147,7 +157,7 @@ async def connect_over_streamable_http(
147157

148158
@asynccontextmanager
149159
async def mounted_app(
150-
server: Server | MCPServer,
160+
server: Server[Any] | FastMCP,
151161
*,
152162
stateless_http: bool = False,
153163
json_response: bool = False,
@@ -172,7 +182,7 @@ async def mounted_app(
172182
DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the
173183
localhost auto-enable behaviour) to test the protection itself.
174184
"""
175-
lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server
185+
lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server # noqa: F821 -- body rewritten in H5
176186
app = lowlevel.streamable_http_app(
177187
stateless_http=stateless_http,
178188
json_response=json_response,
@@ -200,15 +210,15 @@ async def client_via_http(
200210
logging_callback: LoggingFnT | None = None,
201211
message_handler: MessageHandlerFnT | None = None,
202212
elicitation_callback: ElicitationFnT | None = None,
203-
) -> AsyncIterator[Client]:
204-
"""Connect a `Client` over an already-mounted streamable HTTP app.
213+
) -> AsyncIterator[ClientSession]:
214+
"""Connect a `ClientSession` over an already-mounted streamable HTTP app.
205215
206216
Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a
207217
client-driven assertion can sit alongside raw-httpx assertions in the same test. The
208218
underlying `httpx.AsyncClient` is left open when the `Client` exits.
209219
"""
210220
transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)
211-
async with Client(
221+
async with Client( # noqa: F821 -- body rewritten in H4
212222
transport,
213223
logging_callback=logging_callback,
214224
message_handler=message_handler,
@@ -219,7 +229,7 @@ async def client_via_http(
219229

220230
def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]:
221231
"""Decode SSE events into JSON-RPC messages, skipping priming events that carry no data."""
222-
return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data]
232+
return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] # noqa: F821 -- body rewritten in H3
223233

224234

225235
async def post_jsonrpc(
@@ -289,7 +299,7 @@ async def initialize_via_http(http: httpx.AsyncClient) -> str:
289299
return session_id
290300

291301

292-
def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]:
302+
def build_sse_app(server: Server[Any] | FastMCP) -> tuple[Starlette, SseServerTransport]:
293303
"""Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/.
294304
295305
`MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which
@@ -299,7 +309,7 @@ def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTrans
299309
sse = SseServerTransport(
300310
"/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False)
301311
)
302-
lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server
312+
lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server # noqa: F821 -- body rewritten in H3
303313

304314
async def handle_sse(request: Request) -> Response:
305315
async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write):
@@ -317,17 +327,17 @@ async def handle_sse(request: Request) -> Response:
317327

318328
@asynccontextmanager
319329
async def connect_over_sse(
320-
server: Server | MCPServer,
330+
server: Server[Any] | FastMCP,
321331
*,
322-
read_timeout_seconds: float | None = None,
332+
read_timeout_seconds: timedelta | None = None,
323333
sampling_callback: SamplingFnT | None = None,
334+
elicitation_callback: ElicitationFnT | None = None,
324335
list_roots_callback: ListRootsFnT | None = None,
325336
logging_callback: LoggingFnT | None = None,
326337
message_handler: MessageHandlerFnT | None = None,
327338
client_info: Implementation | None = None,
328-
elicitation_callback: ElicitationFnT | None = None,
329-
) -> AsyncIterator[Client]:
330-
"""Yield a Client connected to the server's legacy SSE transport, entirely in process."""
339+
) -> AsyncIterator[ClientSession]:
340+
"""Yield an initialized `ClientSession` over the server's legacy SSE transport, entirely in process."""
331341
app, _ = build_sse_app(server)
332342

333343
def httpx_client_factory(
@@ -347,7 +357,7 @@ def httpx_client_factory(
347357
)
348358

349359
transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory)
350-
async with Client(
360+
async with Client( # noqa: F821 -- body rewritten in H3
351361
transport,
352362
read_timeout_seconds=read_timeout_seconds,
353363
sampling_callback=sampling_callback,

tests/interaction/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44

55
from tests.interaction._connect import Connect, connect_in_memory, connect_over_sse, connect_over_streamable_http
66

7+
8+
def pytest_configure(config: pytest.Config) -> None:
9+
config.addinivalue_line(
10+
"markers", "requirement(id): tag a test as covering an entry in tests/interaction/_requirements.py"
11+
)
12+
13+
714
_FACTORIES: dict[str, Connect] = {
815
"in-memory": connect_in_memory,
916
"streamable-http": connect_over_streamable_http,

0 commit comments

Comments
 (0)