|
10 | 10 | from mcp.server.lowlevel import NotificationOptions |
11 | 11 | from mcp.server.models import InitializationOptions |
12 | 12 | from mcp.server.session import ServerSession |
13 | | -from mcp.shared._context import RequestContext |
14 | 13 | from mcp.shared.message import SessionMessage |
15 | | -from mcp.shared.progress import progress |
16 | 14 | from mcp.shared.session import RequestResponder |
17 | 15 |
|
18 | 16 |
|
@@ -198,117 +196,6 @@ async def handle_client_message( |
198 | 196 | assert server_progress_updates[2]["progress"] == 1.0 |
199 | 197 |
|
200 | 198 |
|
201 | | -@pytest.mark.anyio |
202 | | -async def test_progress_context_manager(): |
203 | | - """Test client using progress context manager for sending progress notifications.""" |
204 | | - # Create memory streams for client/server |
205 | | - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) |
206 | | - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) |
207 | | - |
208 | | - # Track progress updates |
209 | | - server_progress_updates: list[dict[str, Any]] = [] |
210 | | - |
211 | | - progress_token = None |
212 | | - |
213 | | - # Register progress handler |
214 | | - async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: |
215 | | - server_progress_updates.append( |
216 | | - { |
217 | | - "token": params.progress_token, |
218 | | - "progress": params.progress, |
219 | | - "total": params.total, |
220 | | - "message": params.message, |
221 | | - } |
222 | | - ) |
223 | | - |
224 | | - server = Server(name="ProgressContextTestServer", on_progress=handle_progress) |
225 | | - |
226 | | - # Run server session to receive progress updates |
227 | | - async def run_server(): |
228 | | - # Create a server session |
229 | | - async with ServerSession( |
230 | | - client_to_server_receive, |
231 | | - server_to_client_send, |
232 | | - InitializationOptions( |
233 | | - server_name="ProgressContextTestServer", |
234 | | - server_version="0.1.0", |
235 | | - capabilities=server.get_capabilities(NotificationOptions(), {}), |
236 | | - ), |
237 | | - ) as server_session: |
238 | | - async for message in server_session.incoming_messages: |
239 | | - try: |
240 | | - await server._handle_message(message, server_session, {}) |
241 | | - except Exception as e: # pragma: no cover |
242 | | - raise e |
243 | | - |
244 | | - # Client message handler |
245 | | - async def handle_client_message( |
246 | | - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, |
247 | | - ) -> None: |
248 | | - if isinstance(message, Exception): # pragma: no cover |
249 | | - raise message |
250 | | - |
251 | | - # run client session |
252 | | - async with ( |
253 | | - ClientSession( |
254 | | - server_to_client_receive, |
255 | | - client_to_server_send, |
256 | | - message_handler=handle_client_message, |
257 | | - ) as client_session, |
258 | | - anyio.create_task_group() as tg, |
259 | | - ): |
260 | | - tg.start_soon(run_server) |
261 | | - |
262 | | - await client_session.initialize() |
263 | | - |
264 | | - progress_token = "client_token_456" |
265 | | - |
266 | | - # Create request context |
267 | | - request_context = RequestContext( |
268 | | - request_id="test-request", |
269 | | - session=client_session, |
270 | | - meta={"progress_token": progress_token}, |
271 | | - ) |
272 | | - |
273 | | - # Utilize progress context manager |
274 | | - with progress(request_context, total=100) as p: |
275 | | - await p.progress(10, message="Loading configuration...") |
276 | | - await p.progress(30, message="Connecting to database...") |
277 | | - await p.progress(40, message="Fetching data...") |
278 | | - await p.progress(20, message="Processing results...") |
279 | | - |
280 | | - # Wait for all messages to be processed |
281 | | - await anyio.sleep(0.5) |
282 | | - tg.cancel_scope.cancel() |
283 | | - |
284 | | - # Verify progress updates were received by server |
285 | | - assert len(server_progress_updates) == 4 |
286 | | - |
287 | | - # first update |
288 | | - assert server_progress_updates[0]["token"] == progress_token |
289 | | - assert server_progress_updates[0]["progress"] == 10 |
290 | | - assert server_progress_updates[0]["total"] == 100 |
291 | | - assert server_progress_updates[0]["message"] == "Loading configuration..." |
292 | | - |
293 | | - # second update |
294 | | - assert server_progress_updates[1]["token"] == progress_token |
295 | | - assert server_progress_updates[1]["progress"] == 40 |
296 | | - assert server_progress_updates[1]["total"] == 100 |
297 | | - assert server_progress_updates[1]["message"] == "Connecting to database..." |
298 | | - |
299 | | - # third update |
300 | | - assert server_progress_updates[2]["token"] == progress_token |
301 | | - assert server_progress_updates[2]["progress"] == 80 |
302 | | - assert server_progress_updates[2]["total"] == 100 |
303 | | - assert server_progress_updates[2]["message"] == "Fetching data..." |
304 | | - |
305 | | - # final update |
306 | | - assert server_progress_updates[3]["token"] == progress_token |
307 | | - assert server_progress_updates[3]["progress"] == 100 |
308 | | - assert server_progress_updates[3]["total"] == 100 |
309 | | - assert server_progress_updates[3]["message"] == "Processing results..." |
310 | | - |
311 | | - |
312 | 199 | @pytest.mark.anyio |
313 | 200 | async def test_progress_callback_exception_logging(): |
314 | 201 | """Test that exceptions in progress callbacks are logged and \ |
|
0 commit comments