Skip to content

Commit bd43714

Browse files
committed
Propagate contextvars through anyio streams
This fix covers in memory, streamable http, and sse transports.
1 parent 11caa72 commit bd43714

File tree

5 files changed

+42
-11
lines changed

5 files changed

+42
-11
lines changed

src/mcp/client/sse.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
132132
async def post_writer(endpoint_url: str):
133133
try:
134134
async with write_stream_reader:
135-
async for session_message in write_stream_reader:
135+
136+
async def handle_message(session_message: SessionMessage) -> None:
136137
logger.debug(f"Sending client message: {session_message}")
137138
response = await client.post(
138139
endpoint_url,
@@ -144,6 +145,13 @@ async def post_writer(endpoint_url: str):
144145
)
145146
response.raise_for_status()
146147
logger.debug(f"Client message sent successfully: {response.status_code}")
148+
149+
async for session_message in write_stream_reader:
150+
async with anyio.create_task_group() as tg_local:
151+
session_message.context.run(
152+
tg_local.start_soon, handle_message, session_message
153+
)
154+
147155
except Exception: # pragma: lax no cover
148156
logger.exception("Error in post_writer")
149157
finally:

src/mcp/client/streamable_http.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,8 @@ async def post_writer(
441441
"""Handle writing requests to the server."""
442442
try:
443443
async with write_stream_reader:
444-
async for session_message in write_stream_reader:
444+
445+
async def handle_message(session_message: SessionMessage) -> None:
445446
message = session_message.message
446447
metadata = (
447448
session_message.metadata
@@ -478,8 +479,12 @@ async def handle_request_async():
478479
else:
479480
await handle_request_async()
480481

481-
except Exception: # pragma: lax no cover
482-
logger.exception("Error in post_writer")
482+
async for session_message in write_stream_reader:
483+
async with anyio.create_task_group() as tg_local:
484+
session_message.context.run(tg_local.start_soon, handle_message, session_message)
485+
486+
except Exception:
487+
logger.exception("Error in post_writer") # pragma: no cover
483488
finally:
484489
await read_stream_writer.aclose()
485490
await write_stream.aclose()

src/mcp/server/lowlevel/server.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,13 @@ async def run(
393393
async for message in session.incoming_messages:
394394
logger.debug("Received message: %s", message)
395395

396-
tg.start_soon(
396+
if isinstance(message, RequestResponder) and message.context is not None:
397+
context = message.context
398+
else:
399+
context = contextvars.copy_context()
400+
401+
context.run(
402+
tg.start_soon,
397403
self._handle_message,
398404
message,
399405
session,

src/mcp/shared/message.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
to support transport-specific features like resumability.
55
"""
66

7+
import contextvars
78
from collections.abc import Awaitable, Callable
8-
from dataclasses import dataclass
9+
from dataclasses import dataclass, field
910
from typing import Any
1011

1112
from mcp.types import JSONRPCMessage, RequestId
@@ -49,4 +50,5 @@ class SessionMessage:
4950
"""A message with specific metadata for transport-specific features."""
5051

5152
message: JSONRPCMessage
53+
context: contextvars.Context = field(default_factory=contextvars.copy_context)
5254
metadata: MessageMetadata = None

src/mcp/shared/session.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextvars
34
import logging
45
from collections.abc import Callable
56
from contextlib import AsyncExitStack
@@ -79,11 +80,13 @@ def __init__(
7980
session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
8081
on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any],
8182
message_metadata: MessageMetadata = None,
83+
context: contextvars.Context | None = None,
8284
) -> None:
8385
self.request_id = request_id
8486
self.request_meta = request_meta
8587
self.request = request
8688
self.message_metadata = message_metadata
89+
self.context = context
8790
self._session = session
8891
self._completed = False
8992
self._cancel_scope = anyio.CancelScope()
@@ -333,10 +336,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
333336
async def _receive_loop(self) -> None:
334337
async with self._read_stream, self._write_stream:
335338
try:
336-
async for message in self._read_stream:
337-
if isinstance(message, Exception): # pragma: no cover
338-
await self._handle_incoming(message)
339-
elif isinstance(message.message, JSONRPCRequest):
339+
340+
async def handle_message(message: SessionMessage) -> None:
341+
if isinstance(message.message, JSONRPCRequest):
340342
try:
341343
validated_request = self._receive_request_adapter.validate_python(
342344
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
@@ -349,6 +351,7 @@ async def _receive_loop(self) -> None:
349351
session=self,
350352
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
351353
message_metadata=message.metadata,
354+
context=message.context,
352355
)
353356
self._in_flight[responder.request_id] = responder
354357
await self._received_request(responder)
@@ -397,7 +400,7 @@ async def _receive_loop(self) -> None:
397400
logging.exception("Progress callback raised an exception")
398401
await self._received_notification(notification)
399402
await self._handle_incoming(notification)
400-
except Exception: # pragma: no cover
403+
except Exception: # pragma: lax no cover
401404
# For other validation errors, log and continue
402405
logging.warning(
403406
f"Failed to validate notification:. Message was: {message.message}",
@@ -406,6 +409,13 @@ async def _receive_loop(self) -> None:
406409
else: # Response or error
407410
await self._handle_response(message)
408411

412+
async for message in self._read_stream:
413+
if isinstance(message, Exception): # pragma: no cover
414+
await self._handle_incoming(message)
415+
else:
416+
async with anyio.create_task_group() as tg:
417+
message.context.run(tg.start_soon, handle_message, message)
418+
409419
except anyio.ClosedResourceError:
410420
# This is expected when the client disconnects abruptly.
411421
# Without this handler, the exception would propagate up and

0 commit comments

Comments
 (0)