Skip to content

Commit 08c6c0d

Browse files
committed
fix: rebind auth context for notifications
1 parent a040142 commit 08c6c0d

File tree

5 files changed

+126
-30
lines changed

5 files changed

+126
-30
lines changed

src/mcp/client/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from mcp.client.experimental import ExperimentalClientFeatures
1212
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
1313
from mcp.shared._context import RequestContext
14-
from mcp.shared.message import SessionMessage
14+
from mcp.shared.message import MessageMetadata, SessionMessage
1515
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
1616
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1717
from mcp.types._types import RequestParamsMeta
@@ -461,6 +461,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
461461
async def _handle_incoming(
462462
self,
463463
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
464+
message_metadata: MessageMetadata = None,
464465
) -> None:
465466
"""Handle incoming messages by forwarding to the message handler."""
466467
await self._message_handler(req)

src/mcp/server/lowlevel/server.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ async def main():
6666
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
6767
from mcp.server.transport_security import TransportSecuritySettings
6868
from mcp.shared.exceptions import MCPError
69-
from mcp.shared.message import ServerMessageMetadata, SessionMessage
70-
from mcp.shared.session import RequestResponder
69+
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
70+
from mcp.shared.session import NotificationWithMetadata, RequestResponder
7171

7272
logger = logging.getLogger(__name__)
7373

@@ -424,7 +424,9 @@ async def run(
424424

425425
async def _handle_message(
426426
self,
427-
message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
427+
message: RequestResponder[types.ClientRequest, types.ServerResult]
428+
| NotificationWithMetadata[types.ClientNotification]
429+
| Exception,
428430
session: ServerSession,
429431
lifespan_context: LifespanResultT,
430432
raise_exceptions: bool = False,
@@ -436,6 +438,13 @@ async def _handle_message(
436438
await self._handle_request(
437439
message, responder.request, session, lifespan_context, raise_exceptions
438440
)
441+
case NotificationWithMetadata() as notification:
442+
await self._handle_notification(
443+
notification.notification,
444+
session,
445+
lifespan_context,
446+
notification.message_metadata,
447+
)
439448
case Exception():
440449
logger.error(f"Received exception from stream: {message}")
441450
if raise_exceptions:
@@ -532,24 +541,31 @@ async def _handle_notification(
532541
notify: types.ClientNotification,
533542
session: ServerSession,
534543
lifespan_context: LifespanResultT,
544+
message_metadata: MessageMetadata = None,
535545
) -> None:
536546
if handler := self._notification_handlers.get(notify.method):
537547
logger.debug("Dispatching notification of type %s", type(notify).__name__)
538548

539549
try:
540-
client_capabilities = session.client_params.capabilities if session.client_params else None
541-
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
542-
ctx = ServerRequestContext(
543-
session=session,
544-
lifespan_context=lifespan_context,
545-
experimental=Experimental(
546-
task_metadata=None,
547-
_client_capabilities=client_capabilities,
548-
_session=session,
549-
_task_support=task_support,
550-
),
551-
)
552-
await handler(ctx, notify.params)
550+
request_data = None
551+
if isinstance(message_metadata, ServerMessageMetadata):
552+
request_data = message_metadata.request_context
553+
554+
with _bind_request_auth_context(request_data):
555+
client_capabilities = session.client_params.capabilities if session.client_params else None
556+
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
557+
ctx = ServerRequestContext(
558+
session=session,
559+
lifespan_context=lifespan_context,
560+
experimental=Experimental(
561+
task_metadata=None,
562+
_client_capabilities=client_capabilities,
563+
_session=session,
564+
_task_support=task_support,
565+
),
566+
request=request_data,
567+
)
568+
await handler(ctx, notify.params)
553569
except Exception: # pragma: no cover
554570
logger.exception("Uncaught exception in notification handler")
555571

src/mcp/server/session.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
4343
from mcp.shared.exceptions import StatelessModeNotSupported
4444
from mcp.shared.experimental.tasks.capabilities import check_tasks_capability
4545
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY
46-
from mcp.shared.message import ServerMessageMetadata, SessionMessage
46+
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
4747
from mcp.shared.session import (
4848
BaseSession,
49+
NotificationWithMetadata,
4950
RequestResponder,
5051
)
5152
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -60,7 +61,9 @@ class InitializationState(Enum):
6061
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
6162

6263
ServerRequestResponder = (
63-
RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception
64+
RequestResponder[types.ClientRequest, types.ServerResult]
65+
| NotificationWithMetadata[types.ClientNotification]
66+
| Exception
6467
)
6568

6669

@@ -683,7 +686,15 @@ async def send_message(self, message: SessionMessage) -> None:
683686
"""
684687
await self._write_stream.send(message)
685688

686-
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
689+
async def _handle_incoming(
690+
self,
691+
req: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
692+
message_metadata: MessageMetadata = None,
693+
) -> None:
694+
if isinstance(req, types.ClientNotification):
695+
await self._incoming_message_stream_writer.send(NotificationWithMetadata(req, message_metadata))
696+
return
697+
687698
await self._incoming_message_stream_writer.send(req)
688699

689700
@property

src/mcp/shared/session.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from collections.abc import Callable
55
from contextlib import AsyncExitStack
6+
from dataclasses import dataclass
67
from types import TracebackType
78
from typing import Any, Generic, Protocol, TypeVar
89

@@ -53,6 +54,14 @@ async def __call__(
5354
) -> None: ... # pragma: no branch
5455

5556

57+
@dataclass
58+
class NotificationWithMetadata(Generic[ReceiveNotificationT]):
59+
"""A validated notification paired with its transport metadata."""
60+
61+
notification: ReceiveNotificationT
62+
message_metadata: MessageMetadata = None
63+
64+
5665
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
5766
"""Handles responding to MCP requests and manages request lifecycle.
5867
@@ -396,7 +405,7 @@ async def _receive_loop(self) -> None:
396405
except Exception:
397406
logging.exception("Progress callback raised an exception")
398407
await self._received_notification(notification)
399-
await self._handle_incoming(notification)
408+
await self._handle_incoming(notification, message.metadata)
400409
except Exception:
401410
# For other validation errors, log and continue
402411
logging.warning( # pragma: no cover
@@ -515,6 +524,8 @@ async def send_progress_notification(
515524
"""Sends a progress notification for a request that is currently being processed."""
516525

517526
async def _handle_incoming(
518-
self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception
527+
self,
528+
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
529+
message_metadata: MessageMetadata = None,
519530
) -> None:
520531
"""A generic handler for incoming messages. Overridden by subclasses."""

tests/server/auth/middleware/test_auth_context_streamable_http.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
"""Regression tests for auth context in StreamableHTTP servers."""
22

3+
from __future__ import annotations
4+
35
import multiprocessing
6+
import queue
47
import socket
58
import time
69
from collections.abc import Generator
10+
from multiprocessing.queues import Queue
711

12+
import anyio
813
import httpx
914
import pytest
1015
import uvicorn
@@ -58,11 +63,19 @@ def auth_flow(self, request: httpx.Request):
5863
yield request
5964

6065

61-
def _create_stateful_auth_app() -> Starlette:
66+
def _create_stateful_auth_app(progress_tokens: Queue[str] | None = None) -> Starlette:
67+
async def _handle_progress(ctx: ServerRequestContext, params: object) -> None:
68+
if progress_tokens is None:
69+
return
70+
71+
access = get_access_token()
72+
progress_tokens.put(access.token if access else "<none>")
73+
6274
server = Server(
6375
"auth-test-server",
6476
on_call_tool=_handle_whoami,
6577
on_list_tools=_handle_list_tools,
78+
on_progress=_handle_progress,
6679
)
6780
session_manager = StreamableHTTPSessionManager(app=server, stateless=False)
6881
return Starlette(
@@ -75,9 +88,12 @@ def _create_stateful_auth_app() -> Starlette:
7588
)
7689

7790

78-
def run_stateful_auth_server(port: int) -> None: # pragma: no cover
91+
def run_stateful_auth_server(
92+
port: int,
93+
progress_tokens: Queue[str] | None = None,
94+
) -> None: # pragma: no cover
7995
config = uvicorn.Config(
80-
app=_create_stateful_auth_app(),
96+
app=_create_stateful_auth_app(progress_tokens),
8197
host="127.0.0.1",
8298
port=port,
8399
log_level="error",
@@ -94,34 +110,45 @@ def stateful_auth_server_port() -> int:
94110

95111

96112
@pytest.fixture
97-
def stateful_auth_server(stateful_auth_server_port: int) -> Generator[str, None, None]:
113+
def stateful_auth_server(
114+
stateful_auth_server_port: int,
115+
) -> Generator[tuple[str, Queue[str]], None, None]:
116+
progress_tokens: Queue[str] = multiprocessing.Queue()
98117
proc = multiprocessing.Process(
99118
target=run_stateful_auth_server,
100-
kwargs={"port": stateful_auth_server_port},
119+
kwargs={
120+
"port": stateful_auth_server_port,
121+
"progress_tokens": progress_tokens,
122+
},
101123
daemon=True,
102124
)
103125
proc.start()
104126
wait_for_server(stateful_auth_server_port)
105127

106128
try:
107-
yield f"http://127.0.0.1:{stateful_auth_server_port}/mcp"
129+
yield f"http://127.0.0.1:{stateful_auth_server_port}/mcp", progress_tokens
108130
finally:
109131
proc.terminate()
110132
proc.join(timeout=2)
111133
if proc.is_alive(): # pragma: no cover
112134
proc.kill()
113135
proc.join(timeout=1)
136+
progress_tokens.close()
137+
progress_tokens.join_thread()
114138

115139

116140
@pytest.mark.anyio
117-
async def test_get_access_token_reflects_current_request_in_stateful_session(stateful_auth_server: str) -> None:
141+
async def test_get_access_token_reflects_current_request_in_stateful_session(
142+
stateful_auth_server: tuple[str, Queue[str]],
143+
) -> None:
144+
server_url, _ = stateful_auth_server
118145
auth = _MutableBearerAuth("token-A")
119146
async with httpx.AsyncClient(
120147
auth=auth,
121148
timeout=httpx.Timeout(30, read=30),
122149
follow_redirects=True,
123150
) as http_client:
124-
async with streamable_http_client(stateful_auth_server, http_client=http_client) as (
151+
async with streamable_http_client(server_url, http_client=http_client) as (
125152
read_stream,
126153
write_stream,
127154
):
@@ -139,3 +166,33 @@ async def test_get_access_token_reflects_current_request_in_stateful_session(sta
139166
assert len(second_response.content) == 1
140167
assert isinstance(second_response.content[0], TextContent)
141168
assert second_response.content[0].text == "token-B"
169+
170+
171+
@pytest.mark.anyio
172+
async def test_get_access_token_reflects_current_notification_in_stateful_session(
173+
stateful_auth_server: tuple[str, Queue[str]],
174+
) -> None:
175+
server_url, progress_tokens = stateful_auth_server
176+
auth = _MutableBearerAuth("token-A")
177+
async with httpx.AsyncClient(
178+
auth=auth,
179+
timeout=httpx.Timeout(30, read=30),
180+
follow_redirects=True,
181+
) as http_client:
182+
async with streamable_http_client(server_url, http_client=http_client) as (
183+
read_stream,
184+
write_stream,
185+
):
186+
async with ClientSession(read_stream, write_stream) as session:
187+
await session.initialize()
188+
189+
auth.token = "token-B"
190+
await session.send_progress_notification(progress_token="progress-1", progress=1)
191+
192+
with anyio.fail_after(5):
193+
while True:
194+
try:
195+
assert progress_tokens.get_nowait() == "token-B"
196+
break
197+
except queue.Empty:
198+
await anyio.sleep(0.01)

0 commit comments

Comments
 (0)