Skip to content

Commit 90e8720

Browse files
feat: add message middleware support for ClientSession and ServerSession
Add a middleware pattern that allows transforming JSON-RPC messages before sending and after receiving. This provides a clean way to extend protocol messages (e.g., adding custom capabilities to initialize requests) without needing to subclass or override session methods. Middleware functions receive a JSONRPCMessage and return a (possibly transformed) JSONRPCMessage. Both sync and async middleware are supported. Usage example: def add_capabilities(message: JSONRPCMessage) -> JSONRPCMessage: if isinstance(message.root, JSONRPCRequest): # Transform the message... pass return message session = ClientSession( read_stream, write_stream, send_middleware=[add_capabilities], ) Changes: - Add MessageMiddleware type alias in mcp.shared.session - Add send_middleware and receive_middleware parameters to BaseSession - Apply middleware in send_request, send_notification, _send_response - Apply middleware in _receive_loop after receiving messages - Export MessageMiddleware, JSONRPCMessage, JSONRPCNotification from mcp - Add tests for sync and async middleware
1 parent ef96a31 commit 90e8720

File tree

5 files changed

+208
-8
lines changed

5 files changed

+208
-8
lines changed

src/mcp/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .server.session import ServerSession
55
from .server.stdio import stdio_server
66
from .shared.exceptions import McpError, UrlElicitationRequiredError
7+
from .shared.session import MessageMiddleware
78
from .types import (
89
CallToolRequest,
910
ClientCapabilities,
@@ -23,6 +24,8 @@
2324
InitializeRequest,
2425
InitializeResult,
2526
JSONRPCError,
27+
JSONRPCMessage,
28+
JSONRPCNotification,
2629
JSONRPCRequest,
2730
JSONRPCResponse,
2831
ListPromptsRequest,
@@ -87,8 +90,11 @@
8790
"InitializeResult",
8891
"InitializedNotification",
8992
"JSONRPCError",
93+
"JSONRPCMessage",
94+
"JSONRPCNotification",
9095
"JSONRPCRequest",
9196
"JSONRPCResponse",
97+
"MessageMiddleware",
9298
"ListPromptsRequest",
9399
"ListPromptsResult",
94100
"ListResourcesRequest",

src/mcp/client/session.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
1313
from mcp.shared.context import RequestContext
1414
from mcp.shared.message import SessionMessage
15-
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
15+
from mcp.shared.session import BaseSession, MessageMiddleware, ProgressFnT, RequestResponder
1616
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1717

1818
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
@@ -123,13 +123,17 @@ def __init__(
123123
*,
124124
sampling_capabilities: types.SamplingCapability | None = None,
125125
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
126+
send_middleware: list["MessageMiddleware"] | None = None,
127+
receive_middleware: list["MessageMiddleware"] | None = None,
126128
) -> None:
127129
super().__init__(
128130
read_stream,
129131
write_stream,
130132
types.ServerRequest,
131133
types.ServerNotification,
132134
read_timeout_seconds=read_timeout_seconds,
135+
send_middleware=send_middleware,
136+
receive_middleware=receive_middleware,
133137
)
134138
self._client_info = client_info or DEFAULT_CLIENT_INFO
135139
self._sampling_callback = sampling_callback or _default_sampling_callback

src/mcp/server/session.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
5454
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5555
from mcp.shared.session import (
5656
BaseSession,
57+
MessageMiddleware,
5758
RequestResponder,
5859
)
5960
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -91,8 +92,18 @@ def __init__(
9192
write_stream: MemoryObjectSendStream[SessionMessage],
9293
init_options: InitializationOptions,
9394
stateless: bool = False,
95+
*,
96+
send_middleware: list["MessageMiddleware"] | None = None,
97+
receive_middleware: list["MessageMiddleware"] | None = None,
9498
) -> None:
95-
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
99+
super().__init__(
100+
read_stream,
101+
write_stream,
102+
types.ClientRequest,
103+
types.ClientNotification,
104+
send_middleware=send_middleware,
105+
receive_middleware=receive_middleware,
106+
)
96107
self._initialization_state = (
97108
InitializationState.Initialized if stateless else InitializationState.NotInitialized
98109
)

src/mcp/shared/session.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from collections.abc import Callable
2+
from collections.abc import Awaitable, Callable
33
from contextlib import AsyncExitStack
44
from datetime import timedelta
55
from types import TracebackType
@@ -43,6 +43,10 @@
4343

4444
RequestId = str | int
4545

46+
# Middleware type for transforming messages before sending or after receiving.
47+
# Can be sync (returns JSONRPCMessage) or async (returns Awaitable[JSONRPCMessage]).
48+
MessageMiddleware = Callable[[JSONRPCMessage], JSONRPCMessage | Awaitable[JSONRPCMessage]]
49+
4650

4751
class ProgressFnT(Protocol):
4852
"""Protocol for progress notification callbacks."""
@@ -190,6 +194,9 @@ def __init__(
190194
receive_notification_type: type[ReceiveNotificationT],
191195
# If none, reading will never time out
192196
read_timeout_seconds: timedelta | None = None,
197+
*,
198+
send_middleware: list[MessageMiddleware] | None = None,
199+
receive_middleware: list[MessageMiddleware] | None = None,
193200
) -> None:
194201
self._read_stream = read_stream
195202
self._write_stream = write_stream
@@ -202,6 +209,22 @@ def __init__(
202209
self._progress_callbacks = {}
203210
self._response_routers = []
204211
self._exit_stack = AsyncExitStack()
212+
self._send_middleware = send_middleware or []
213+
self._receive_middleware = receive_middleware or []
214+
215+
async def _apply_middleware(
216+
self, message: JSONRPCMessage, middleware_list: list[MessageMiddleware]
217+
) -> JSONRPCMessage:
218+
"""Apply a list of middleware functions to a message."""
219+
import inspect
220+
221+
for middleware in middleware_list:
222+
result = middleware(message)
223+
if inspect.isawaitable(result):
224+
message = await result
225+
else:
226+
message = result # type: ignore[assignment]
227+
return message
205228

206229
def add_response_router(self, router: ResponseRouter) -> None:
207230
"""
@@ -278,7 +301,9 @@ async def send_request(
278301
**request_data,
279302
)
280303

281-
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
304+
message = JSONRPCMessage(jsonrpc_request)
305+
message = await self._apply_middleware(message, self._send_middleware)
306+
await self._write_stream.send(SessionMessage(message=message, metadata=metadata))
282307

283308
# request read timeout takes precedence over session read timeout
284309
timeout = None
@@ -328,24 +353,30 @@ async def send_notification(
328353
jsonrpc="2.0",
329354
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
330355
)
356+
message = JSONRPCMessage(jsonrpc_notification)
357+
message = await self._apply_middleware(message, self._send_middleware)
331358
session_message = SessionMessage( # pragma: no cover
332-
message=JSONRPCMessage(jsonrpc_notification),
359+
message=message,
333360
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
334361
)
335362
await self._write_stream.send(session_message)
336363

337364
async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
338365
if isinstance(response, ErrorData):
339366
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
340-
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
367+
message = JSONRPCMessage(jsonrpc_error)
368+
message = await self._apply_middleware(message, self._send_middleware)
369+
session_message = SessionMessage(message=message)
341370
await self._write_stream.send(session_message)
342371
else:
343372
jsonrpc_response = JSONRPCResponse(
344373
jsonrpc="2.0",
345374
id=request_id,
346375
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
347376
)
348-
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
377+
message = JSONRPCMessage(jsonrpc_response)
378+
message = await self._apply_middleware(message, self._send_middleware)
379+
session_message = SessionMessage(message=message)
349380
await self._write_stream.send(session_message)
350381

351382
async def _receive_loop(self) -> None:
@@ -357,7 +388,14 @@ async def _receive_loop(self) -> None:
357388
async for message in self._read_stream:
358389
if isinstance(message, Exception): # pragma: no cover
359390
await self._handle_incoming(message)
360-
elif isinstance(message.message.root, JSONRPCRequest):
391+
continue
392+
393+
# Apply receive middleware to transform the message
394+
if self._receive_middleware:
395+
transformed_msg = await self._apply_middleware(message.message, self._receive_middleware)
396+
message = SessionMessage(message=transformed_msg, metadata=message.metadata) # noqa: PLW2901
397+
398+
if isinstance(message.message.root, JSONRPCRequest):
361399
try:
362400
validated_request = self._receive_request_type.model_validate(
363401
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)

tests/client/test_session.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,3 +768,144 @@ async def mock_server():
768768
await session.initialize()
769769

770770
await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)
771+
772+
773+
@pytest.mark.anyio
774+
async def test_client_session_send_middleware():
775+
"""Test that send middleware can transform outgoing messages."""
776+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
777+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
778+
779+
received_request = None
780+
middleware_called = False
781+
782+
def add_custom_field(message: JSONRPCMessage) -> JSONRPCMessage:
783+
"""Middleware that adds a custom field to initialize request params."""
784+
nonlocal middleware_called
785+
middleware_called = True
786+
787+
if isinstance(message.root, JSONRPCRequest):
788+
# Add custom extension to the capabilities
789+
data = message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
790+
if data.get("method") == "initialize" and "params" in data:
791+
if "capabilities" not in data["params"]:
792+
data["params"]["capabilities"] = {}
793+
# Add a custom extension field
794+
data["params"]["capabilities"]["customExtension"] = {"enabled": True}
795+
return JSONRPCMessage(JSONRPCRequest(**data))
796+
return message
797+
798+
async def mock_server():
799+
nonlocal received_request
800+
801+
session_message = await client_to_server_receive.receive()
802+
jsonrpc_request = session_message.message
803+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
804+
received_request = jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
805+
806+
result = ServerResult(
807+
InitializeResult(
808+
protocolVersion=LATEST_PROTOCOL_VERSION,
809+
capabilities=ServerCapabilities(),
810+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
811+
)
812+
)
813+
814+
async with server_to_client_send:
815+
await server_to_client_send.send(
816+
SessionMessage(
817+
JSONRPCMessage(
818+
JSONRPCResponse(
819+
jsonrpc="2.0",
820+
id=jsonrpc_request.root.id,
821+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
822+
)
823+
)
824+
)
825+
)
826+
# Receive initialized notification
827+
await client_to_server_receive.receive()
828+
829+
async with (
830+
ClientSession(
831+
server_to_client_receive,
832+
client_to_server_send,
833+
send_middleware=[add_custom_field],
834+
) as session,
835+
anyio.create_task_group() as tg,
836+
client_to_server_send,
837+
client_to_server_receive,
838+
server_to_client_send,
839+
server_to_client_receive,
840+
):
841+
tg.start_soon(mock_server)
842+
await session.initialize()
843+
844+
# Verify middleware was called and transformed the request
845+
assert middleware_called
846+
assert received_request is not None
847+
assert "params" in received_request
848+
assert "capabilities" in received_request["params"]
849+
assert "customExtension" in received_request["params"]["capabilities"]
850+
assert received_request["params"]["capabilities"]["customExtension"] == {"enabled": True}
851+
852+
853+
@pytest.mark.anyio
854+
async def test_client_session_async_middleware():
855+
"""Test that async middleware works correctly."""
856+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
857+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
858+
859+
middleware_called = False
860+
861+
async def async_middleware(message: JSONRPCMessage) -> JSONRPCMessage:
862+
"""Async middleware that just passes through."""
863+
nonlocal middleware_called
864+
middleware_called = True
865+
# Simulate some async work
866+
await anyio.sleep(0)
867+
return message
868+
869+
async def mock_server():
870+
session_message = await client_to_server_receive.receive()
871+
jsonrpc_request = session_message.message
872+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
873+
874+
result = ServerResult(
875+
InitializeResult(
876+
protocolVersion=LATEST_PROTOCOL_VERSION,
877+
capabilities=ServerCapabilities(),
878+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
879+
)
880+
)
881+
882+
async with server_to_client_send:
883+
await server_to_client_send.send(
884+
SessionMessage(
885+
JSONRPCMessage(
886+
JSONRPCResponse(
887+
jsonrpc="2.0",
888+
id=jsonrpc_request.root.id,
889+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
890+
)
891+
)
892+
)
893+
)
894+
await client_to_server_receive.receive()
895+
896+
async with (
897+
ClientSession(
898+
server_to_client_receive,
899+
client_to_server_send,
900+
send_middleware=[async_middleware],
901+
) as session,
902+
anyio.create_task_group() as tg,
903+
client_to_server_send,
904+
client_to_server_receive,
905+
server_to_client_send,
906+
server_to_client_receive,
907+
):
908+
tg.start_soon(mock_server)
909+
await session.initialize()
910+
911+
assert middleware_called

0 commit comments

Comments
 (0)