1- """Cancellation interactions against the low-level Server, driven through the public Client API.
1+ """Cancellation interactions against the low-level Server, driven through the public client API.
22
33There is no client-side cancellation API: cancelling means sending a CancelledNotification
4- carrying the request id, which only the server-side handler can observe (`ctx.request_id`), so
5- these tests capture the id from inside the blocked handler before cancelling. The handler blocks
6- on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`.
4+ carrying the request id, which only the server-side handler can observe (via
5+ `server.request_context.request_id`), so these tests capture the id from inside the blocked
6+ handler before cancelling. The handler blocks on an Event rather than a sleep, and every wait
7+ is bounded by `anyio.fail_after`.
78"""
89
10+ from typing import Any
11+
912import anyio
1013import pytest
1114from inline_snapshot import snapshot
1215
13- from mcp import MCPError , types
14- from mcp .client import ClientSession
15- from mcp .server import Server , ServerRequestContext
16+ from mcp import McpError , types
17+ from mcp .client . session import ClientSession
18+ from mcp .server . lowlevel import Server
1619from mcp .shared .memory import MessageStream , create_client_server_memory_streams
1720from mcp .shared .message import SessionMessage
1821from mcp .types import (
1922 CallToolResult ,
23+ ClientNotification ,
24+ ClientRequest ,
2025 EmptyResult ,
2126 ErrorData ,
2227 Implementation ,
2328 InitializeResult ,
29+ JSONRPCMessage ,
2430 JSONRPCNotification ,
2531 JSONRPCRequest ,
2632 JSONRPCResponse ,
@@ -49,8 +55,12 @@ async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None:
4955 request_ids : list [types .RequestId ] = []
5056 errors : list [ErrorData ] = []
5157
52- async def call_tool (ctx : ServerRequestContext , params : types .CallToolRequestParams ) -> CallToolResult :
53- assert params .name == "block"
58+ server : Server [Any ] = Server ("blocker" )
59+
60+ @server .call_tool ()
61+ async def call_tool (name : str , arguments : dict [str , Any ]) -> CallToolResult :
62+ assert name == "block"
63+ ctx = server .request_context
5464 assert ctx .request_id is not None
5565 request_ids .append (ctx .request_id )
5666 started .set ()
@@ -61,22 +71,22 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara
6171 raise
6272 raise NotImplementedError # unreachable: the wait above never completes normally
6373
64- server = Server ("blocker" , on_call_tool = call_tool )
65-
6674 async with connect (server ) as client :
6775 with anyio .fail_after (5 ):
6876 async with anyio .create_task_group () as task_group :
6977
7078 async def call_and_capture_error () -> None :
71- with pytest .raises (MCPError ) as exc_info :
79+ with pytest .raises (McpError ) as exc_info :
7280 await client .call_tool ("block" , {})
7381 errors .append (exc_info .value .error )
7482
7583 task_group .start_soon (call_and_capture_error )
7684 await started .wait ()
77- await client .session .send_notification (
78- types .CancelledNotification (
79- params = types .CancelledNotificationParams (requestId = request_ids [0 ], reason = "user aborted" )
85+ await client .send_notification (
86+ ClientNotification (
87+ types .CancelledNotification (
88+ params = types .CancelledNotificationParams (requestId = request_ids [0 ], reason = "user aborted" )
89+ )
8090 )
8191 )
8292
@@ -91,39 +101,40 @@ async def test_session_serves_requests_after_cancellation(connect: Connect) -> N
91101 started = anyio .Event ()
92102 request_ids : list [types .RequestId ] = []
93103
94- async def list_tools (
95- ctx : ServerRequestContext , params : types .PaginatedRequestParams | None
96- ) -> types .ListToolsResult :
97- return types .ListToolsResult (
98- tools = [
99- types .Tool (name = "block" , inputSchema = {"type" : "object" }),
100- types .Tool (name = "echo" , inputSchema = {"type" : "object" }),
101- ]
102- )
104+ server : Server [Any ] = Server ("blocker" )
105+
106+ @server .list_tools ()
107+ async def list_tools () -> list [types .Tool ]:
108+ return [
109+ types .Tool (name = "block" , inputSchema = {"type" : "object" }),
110+ types .Tool (name = "echo" , inputSchema = {"type" : "object" }),
111+ ]
103112
104- async def call_tool (ctx : ServerRequestContext , params : types .CallToolRequestParams ) -> CallToolResult :
105- if params .name == "echo" :
113+ @server .call_tool ()
114+ async def call_tool (name : str , arguments : dict [str , Any ]) -> CallToolResult :
115+ if name == "echo" :
106116 return CallToolResult (content = [TextContent (type = "text" , text = "still alive" )])
117+ ctx = server .request_context
107118 assert ctx .request_id is not None
108119 request_ids .append (ctx .request_id )
109120 started .set ()
110121 await anyio .Event ().wait () # blocks until cancelled
111122 raise NotImplementedError # unreachable
112123
113- server = Server ("blocker" , on_list_tools = list_tools , on_call_tool = call_tool )
114-
115124 async with connect (server ) as client :
116125 with anyio .fail_after (5 ):
117126 async with anyio .create_task_group () as task_group :
118127
119128 async def call_and_swallow_cancellation_error () -> None :
120- with pytest .raises (MCPError ):
129+ with pytest .raises (McpError ):
121130 await client .call_tool ("block" , {})
122131
123132 task_group .start_soon (call_and_swallow_cancellation_error )
124133 await started .wait ()
125- await client .session .send_notification (
126- types .CancelledNotification (params = types .CancelledNotificationParams (requestId = request_ids [0 ]))
134+ await client .send_notification (
135+ ClientNotification (
136+ types .CancelledNotification (params = types .CancelledNotificationParams (requestId = request_ids [0 ]))
137+ )
127138 )
128139
129140 result = await client .call_tool ("echo" , {})
@@ -135,20 +146,20 @@ async def call_and_swallow_cancellation_error() -> None:
135146async def test_cancellation_for_unknown_request_is_ignored (connect : Connect ) -> None :
136147 """A cancellation referencing a request id that is not in flight is ignored without error."""
137148
138- async def list_tools (
139- ctx : ServerRequestContext , params : types .PaginatedRequestParams | None
140- ) -> types .ListToolsResult :
141- return types .ListToolsResult (tools = [types .Tool (name = "echo" , inputSchema = {"type" : "object" })])
149+ server : Server [Any ] = Server ("calm" )
142150
143- async def call_tool ( ctx : ServerRequestContext , params : types . CallToolRequestParams ) -> CallToolResult :
144- assert params . name == "echo"
145- return CallToolResult ( content = [ TextContent ( type = "text " , text = "unbothered" )])
151+ @ server . list_tools ()
152+ async def list_tools () -> list [ types . Tool ]:
153+ return [ types . Tool ( name = "echo " , inputSchema = { "type" : "object" })]
146154
147- server = Server ("calm" , on_list_tools = list_tools , on_call_tool = call_tool )
155+ @server .call_tool ()
156+ async def call_tool (name : str , arguments : dict [str , Any ]) -> CallToolResult :
157+ assert name == "echo"
158+ return CallToolResult (content = [TextContent (type = "text" , text = "unbothered" )])
148159
149160 async with connect (server ) as client :
150- await client .session . send_notification (
151- types .CancelledNotification (params = types .CancelledNotificationParams (requestId = 9999 ))
161+ await client .send_notification (
162+ ClientNotification ( types .CancelledNotification (params = types .CancelledNotificationParams (requestId = 9999 ) ))
152163 )
153164 result = await client .call_tool ("echo" , {})
154165
@@ -176,21 +187,23 @@ async def scripted_server(streams: MessageStream) -> None:
176187
177188 def respond (request_id : types .RequestId , result : types .Result ) -> SessionMessage :
178189 return SessionMessage (
179- JSONRPCResponse (
180- jsonrpc = "2.0" ,
181- id = request_id ,
182- # Serialized exactly as a real server serializes results onto the wire.
183- result = result .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
190+ JSONRPCMessage (
191+ JSONRPCResponse (
192+ jsonrpc = "2.0" ,
193+ id = request_id ,
194+ # Serialized exactly as a real server serializes results onto the wire.
195+ result = result .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
196+ )
184197 )
185198 )
186199
187200 init = await server_read .receive ()
188201 assert isinstance (init , SessionMessage )
189- assert isinstance (init .message , JSONRPCRequest )
190- assert init .message .method == "initialize"
202+ assert isinstance (init .message . root , JSONRPCRequest )
203+ assert init .message .root . method == "initialize"
191204 await server_write .send (
192205 respond (
193- init .message .id ,
206+ init .message .root . id ,
194207 InitializeResult (
195208 protocolVersion = "2025-11-25" ,
196209 capabilities = ServerCapabilities (),
@@ -201,16 +214,16 @@ def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage
201214
202215 initialized = await server_read .receive ()
203216 assert isinstance (initialized , SessionMessage )
204- assert isinstance (initialized .message , JSONRPCNotification )
205- assert initialized .message .method == "notifications/initialized"
217+ assert isinstance (initialized .message . root , JSONRPCNotification )
218+ assert initialized .message .root . method == "notifications/initialized"
206219
207220 ping = await server_read .receive ()
208221 assert isinstance (ping , SessionMessage )
209- assert isinstance (ping .message , JSONRPCRequest )
210- assert ping .message .method == "ping"
222+ assert isinstance (ping .message . root , JSONRPCRequest )
223+ assert ping .message .root . method == "ping"
211224 # First answer with a fabricated id that matches nothing in flight, then the real id.
212225 await server_write .send (respond (9999 , EmptyResult ()))
213- await server_write .send (respond (ping .message .id , EmptyResult ()))
226+ await server_write .send (respond (ping .message .root . id , EmptyResult ()))
214227
215228 incoming : list [IncomingMessage ] = []
216229
@@ -225,7 +238,7 @@ async def message_handler(message: IncomingMessage) -> None:
225238 task_group .start_soon (scripted_server , server_streams )
226239 with anyio .fail_after (5 ):
227240 await session .initialize ()
228- pong = await session .send_request (PingRequest (), EmptyResult )
241+ pong = await session .send_request (ClientRequest ( PingRequest () ), EmptyResult )
229242
230243 assert pong == snapshot (EmptyResult ())
231244 assert len (incoming ) == 1
0 commit comments