Skip to content

Commit a4ffc4b

Browse files
committed
Fix stateful StreamableHTTP auth context rebinding
1 parent fb2276b commit a4ffc4b

File tree

2 files changed

+151
-26
lines changed

2 files changed

+151
-26
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ async def main():
3838

3939
import logging
4040
import warnings
41-
from collections.abc import AsyncIterator, Awaitable, Callable
42-
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
41+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator
42+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
4343
from importlib.metadata import version as importlib_version
4444
from typing import Any, Generic
4545

@@ -52,8 +52,8 @@ async def main():
5252
from typing_extensions import TypeVar
5353

5454
from mcp import types
55-
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
56-
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
55+
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, auth_context_var
56+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, BearerAuthBackend, RequireAuthMiddleware
5757
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
5858
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
5959
from mcp.server.auth.settings import AuthSettings
@@ -74,6 +74,23 @@ async def main():
7474
LifespanResultT = TypeVar("LifespanResultT", default=Any)
7575

7676

77+
@contextmanager
78+
def _bind_request_auth_context(request_context: Any) -> Iterator[None]:
79+
"""Rebind auth context from the current transport request while handling a message."""
80+
authenticated_user = None
81+
scope = getattr(request_context, "scope", None)
82+
if isinstance(scope, dict):
83+
scope_user = scope.get("user")
84+
if isinstance(scope_user, AuthenticatedUser):
85+
authenticated_user = scope_user
86+
87+
token = auth_context_var.set(authenticated_user)
88+
try:
89+
yield
90+
finally:
91+
auth_context_var.reset(token)
92+
93+
7794
class NotificationOptions:
7895
def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False):
7996
self.prompts_changed = prompts_changed
@@ -452,28 +469,32 @@ async def _handle_request(
452469
close_sse_stream_cb = message.message_metadata.close_sse_stream
453470
close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream
454471

455-
client_capabilities = session.client_params.capabilities if session.client_params else None
456-
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
457-
# Get task metadata from request params if present
458-
task_metadata = None
459-
if hasattr(req, "params") and req.params is not None:
460-
task_metadata = getattr(req.params, "task", None)
461-
ctx = ServerRequestContext(
462-
request_id=message.request_id,
463-
meta=message.request_meta,
464-
session=session,
465-
lifespan_context=lifespan_context,
466-
experimental=Experimental(
467-
task_metadata=task_metadata,
468-
_client_capabilities=client_capabilities,
469-
_session=session,
470-
_task_support=task_support,
471-
),
472-
request=request_data,
473-
close_sse_stream=close_sse_stream_cb,
474-
close_standalone_sse_stream=close_standalone_sse_stream_cb,
475-
)
476-
response = await handler(ctx, req.params)
472+
# Stateful HTTP sessions process later requests on tasks that were
473+
# created during session setup, so ContextVar snapshots can lag
474+
# behind the current request unless we rebind them here.
475+
with _bind_request_auth_context(request_data):
476+
client_capabilities = session.client_params.capabilities if session.client_params else None
477+
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
478+
# Get task metadata from request params if present
479+
task_metadata = None
480+
if hasattr(req, "params") and req.params is not None:
481+
task_metadata = getattr(req.params, "task", None)
482+
ctx = ServerRequestContext(
483+
request_id=message.request_id,
484+
meta=message.request_meta,
485+
session=session,
486+
lifespan_context=lifespan_context,
487+
experimental=Experimental(
488+
task_metadata=task_metadata,
489+
_client_capabilities=client_capabilities,
490+
_session=session,
491+
_task_support=task_support,
492+
),
493+
request=request_data,
494+
close_sse_stream=close_sse_stream_cb,
495+
close_standalone_sse_stream=close_standalone_sse_stream_cb,
496+
)
497+
response = await handler(ctx, req.params)
477498
except MCPError as err:
478499
response = err.error
479500
except anyio.get_cancelled_exc_class():
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""Regression tests for auth context in StreamableHTTP servers."""
2+
3+
import time
4+
from collections.abc import Generator
5+
6+
import httpx
7+
import pytest
8+
from starlette.applications import Starlette
9+
from starlette.middleware import Middleware
10+
from starlette.middleware.authentication import AuthenticationMiddleware
11+
from starlette.routing import Mount
12+
13+
from mcp.client.session import ClientSession
14+
from mcp.client.streamable_http import streamable_http_client
15+
from mcp.server import Server, ServerRequestContext
16+
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token
17+
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
18+
from mcp.server.auth.provider import AccessToken
19+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
20+
from mcp.types import (
21+
CallToolRequestParams,
22+
CallToolResult,
23+
ListToolsResult,
24+
PaginatedRequestParams,
25+
TextContent,
26+
Tool,
27+
)
28+
from tests.test_helpers import run_uvicorn_in_thread
29+
30+
31+
class _EchoTokenVerifier:
32+
async def verify_token(self, token: str) -> AccessToken | None:
33+
return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600)
34+
35+
36+
async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
37+
access = get_access_token()
38+
text = access.token if access else "<none>"
39+
return CallToolResult(content=[TextContent(type="text", text=text)])
40+
41+
42+
async def _handle_list_tools(
43+
ctx: ServerRequestContext,
44+
params: PaginatedRequestParams | None,
45+
) -> ListToolsResult:
46+
return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})])
47+
48+
49+
class _MutableBearerAuth(httpx.Auth):
50+
def __init__(self, token: str) -> None:
51+
self.token = token
52+
53+
def auth_flow(self, request: httpx.Request):
54+
request.headers["Authorization"] = f"Bearer {self.token}"
55+
yield request
56+
57+
58+
@pytest.fixture
59+
def stateful_auth_server() -> Generator[str, None, None]:
60+
server = Server(
61+
"auth-test-server",
62+
on_call_tool=_handle_whoami,
63+
on_list_tools=_handle_list_tools,
64+
)
65+
session_manager = StreamableHTTPSessionManager(app=server, stateless=False)
66+
app = Starlette(
67+
routes=[Mount("/mcp", app=session_manager.handle_request)],
68+
middleware=[
69+
Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())),
70+
Middleware(AuthContextMiddleware),
71+
],
72+
lifespan=lambda app: session_manager.run(),
73+
)
74+
75+
with run_uvicorn_in_thread(app, host="127.0.0.1", log_level="error") as base_url:
76+
yield f"{base_url}/mcp"
77+
78+
79+
@pytest.mark.anyio
80+
async def test_get_access_token_reflects_current_request_in_stateful_session(stateful_auth_server: str) -> None:
81+
auth = _MutableBearerAuth("token-A")
82+
async with httpx.AsyncClient(
83+
auth=auth,
84+
timeout=httpx.Timeout(30, read=30),
85+
follow_redirects=True,
86+
) as http_client:
87+
async with streamable_http_client(stateful_auth_server, http_client=http_client) as (
88+
read_stream,
89+
write_stream,
90+
):
91+
async with ClientSession(read_stream, write_stream) as session:
92+
await session.initialize()
93+
94+
first_response = await session.call_tool("whoami", {})
95+
assert len(first_response.content) == 1
96+
assert isinstance(first_response.content[0], TextContent)
97+
assert first_response.content[0].text == "token-A"
98+
99+
auth.token = "token-B"
100+
101+
second_response = await session.call_tool("whoami", {})
102+
assert len(second_response.content) == 1
103+
assert isinstance(second_response.content[0], TextContent)
104+
assert second_response.content[0].text == "token-B"

0 commit comments

Comments
 (0)