11from __future__ import annotations
22
3+ import contextvars
34import logging
45from collections .abc import Callable
56from contextlib import AsyncExitStack
@@ -79,11 +80,13 @@ def __init__(
7980 session : BaseSession [SendRequestT , SendNotificationT , SendResultT , ReceiveRequestT , ReceiveNotificationT ],
8081 on_complete : Callable [[RequestResponder [ReceiveRequestT , SendResultT ]], Any ],
8182 message_metadata : MessageMetadata = None ,
83+ context : contextvars .Context | None = None ,
8284 ) -> None :
8385 self .request_id = request_id
8486 self .request_meta = request_meta
8587 self .request = request
8688 self .message_metadata = message_metadata
89+ self .context = context
8790 self ._session = session
8891 self ._completed = False
8992 self ._cancel_scope = anyio .CancelScope ()
@@ -333,10 +336,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
333336 async def _receive_loop (self ) -> None :
334337 async with self ._read_stream , self ._write_stream :
335338 try :
336- async for message in self ._read_stream :
337- if isinstance (message , Exception ): # pragma: no cover
338- await self ._handle_incoming (message )
339- elif isinstance (message .message , JSONRPCRequest ):
339+
340+ async def handle_message (message : SessionMessage ) -> None :
341+ if isinstance (message .message , JSONRPCRequest ):
340342 try :
341343 validated_request = self ._receive_request_adapter .validate_python (
342344 message .message .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
@@ -349,6 +351,7 @@ async def _receive_loop(self) -> None:
349351 session = self ,
350352 on_complete = lambda r : self ._in_flight .pop (r .request_id , None ),
351353 message_metadata = message .metadata ,
354+ context = message .context ,
352355 )
353356 self ._in_flight [responder .request_id ] = responder
354357 await self ._received_request (responder )
@@ -397,7 +400,7 @@ async def _receive_loop(self) -> None:
397400 logging .exception ("Progress callback raised an exception" )
398401 await self ._received_notification (notification )
399402 await self ._handle_incoming (notification )
400- except Exception : # pragma: no cover
403+ except Exception : # pragma: lax no cover
401404 # For other validation errors, log and continue
402405 logging .warning (
403406 f"Failed to validate notification:. Message was: { message .message } " ,
@@ -406,6 +409,13 @@ async def _receive_loop(self) -> None:
406409 else : # Response or error
407410 await self ._handle_response (message )
408411
412+ async for message in self ._read_stream :
413+ if isinstance (message , Exception ): # pragma: no cover
414+ await self ._handle_incoming (message )
415+ else :
416+ async with anyio .create_task_group () as tg :
417+ message .context .run (tg .start_soon , handle_message , message )
418+
409419 except anyio .ClosedResourceError :
410420 # This is expected when the client disconnects abruptly.
411421 # Without this handler, the exception would propagate up and
0 commit comments