11"""Shared helpers for the interaction suite.
22
33Keep this module small: it exists only for (a) types that every test would otherwise have to
4- assemble from the SDK's internals to annotate a client callback, and (b) the recording transport
4+ assemble from the SDK's internals to annotate a client callback, and (b) the recording wrapper
55used by the wire-level tests. Server fixtures and assertion helpers belong in the test that uses
66them.
77"""
88
99from types import TracebackType
1010
1111import anyio
12+ from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
1213from typing_extensions import Self
1314
14- from mcp .client ._transport import ReadStream , Transport , TransportStreams , WriteStream
1515from mcp .shared .message import SessionMessage
1616from mcp .shared .session import RequestResponder
1717from mcp .types import ClientResult , ServerNotification , ServerRequest
2424IncomingMessage = RequestResponder [ServerRequest , ClientResult ] | ServerNotification | Exception
2525"""Everything a client message handler can receive."""
2626
27+ ReadStream = MemoryObjectReceiveStream [SessionMessage | Exception ]
28+ WriteStream = MemoryObjectSendStream [SessionMessage ]
29+ """Local aliases for the v1 SDK's session-stream types (v1 has no exported `ReadStream`/
30+ `WriteStream` names); exported so wire-level / scripted-peer tests can annotate without
31+ reaching into anyio."""
32+
2733
2834class _RecordingReadStream :
2935 """Delegates to a read stream, appending every received message to a log."""
3036
31- def __init__ (self , inner : ReadStream [ SessionMessage | Exception ] , log : list [SessionMessage | Exception ]) -> None :
37+ def __init__ (self , inner : ReadStream , log : list [SessionMessage | Exception ]) -> None :
3238 self ._inner = inner
3339 self ._log = log
3440
@@ -62,7 +68,7 @@ async def __aexit__(
6268class _RecordingWriteStream :
6369 """Delegates to a write stream, appending every sent message to a log."""
6470
65- def __init__ (self , inner : WriteStream [ SessionMessage ] , log : list [SessionMessage ]) -> None :
71+ def __init__ (self , inner : WriteStream , log : list [SessionMessage ]) -> None :
6672 self ._inner = inner
6773 self ._log = log
6874
@@ -83,25 +89,22 @@ async def __aexit__(
8389 return None
8490
8591
86- class RecordingTransport :
87- """Wraps a Transport and records every message crossing the client's transport boundary .
92+ class Recording :
93+ """Wraps a (read, write) stream pair and records every message crossing it .
8894
8995 `sent` holds everything the client wrote towards the server; `received` holds everything the
9096 server delivered to the client. The recording sits at the transport seam -- the exact payloads
9197 a real transport would serialise -- and never touches the session, so wire-level assertions
9298 written against it survive changes to the receive path.
99+
100+ v1 has no `Transport` abstraction; tests insert this between
101+ `create_client_server_memory_streams()` and `ClientSession`.
93102 """
94103
95- def __init__ (self , inner : Transport ) -> None :
96- self .inner = inner
104+ def __init__ (self , read : ReadStream , write : WriteStream ) -> None :
97105 self .sent : list [SessionMessage ] = []
98106 self .received : list [SessionMessage | Exception ] = []
99-
100- async def __aenter__ (self ) -> TransportStreams :
101- read_stream , write_stream = await self .inner .__aenter__ ()
102- return _RecordingReadStream (read_stream , self .received ), _RecordingWriteStream (write_stream , self .sent )
103-
104- async def __aexit__ (
105- self , exc_type : type [BaseException ] | None , exc_val : BaseException | None , exc_tb : TracebackType | None
106- ) -> bool | None :
107- return await self .inner .__aexit__ (exc_type , exc_val , exc_tb )
107+ # Duck-typed stand-ins for the anyio stream classes; ClientSession only calls
108+ # .receive()/.send()/.aclose() so the runtime contract holds.
109+ self .read : ReadStream = _RecordingReadStream (read , self .received ) # type: ignore[assignment]
110+ self .write : WriteStream = _RecordingWriteStream (write , self .sent ) # type: ignore[assignment]
0 commit comments