Skip to content

Commit 7d9881d

Browse files
committed
backport: harness H-auth-1 — connect_with_oauth body (hand-assembled AS+RS app, ClientSession yield); auth smoke passes
1 parent e55b40e commit 7d9881d

2 files changed

Lines changed: 90 additions & 31 deletions

File tree

tests/interaction/auth/_harness.py

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""In-process harness for the auth interaction tests.
22
33
Co-hosts the SDK's authorization-server routes, protected-resource metadata route, and the
4-
bearer-gated MCP endpoint on one Starlette app via `Server.streamable_http_app(auth=...,
5-
token_verifier=..., auth_server_provider=...)`, drives that app through the streaming bridge
6-
on a single `httpx.AsyncClient` carrying `auth=OAuthClientProvider(...)`, and completes the
7-
authorize redirect headlessly by GETing the URL through the same bridge and parsing the code
8-
from the 302 `Location`. The whole authorization-code flow runs in one event loop with no
9-
sockets, no threads, and no real time.
4+
bearer-gated MCP endpoint on one Starlette app assembled from the same public pieces
5+
`FastMCP.streamable_http_app()` uses (`StreamableHTTPSessionManager`, `create_auth_routes`,
6+
`BearerAuthBackend`, `RequireAuthMiddleware`, `create_protected_resource_routes`), drives
7+
that app through the streaming bridge on a single `httpx.AsyncClient` carrying
8+
`auth=OAuthClientProvider(...)`, and completes the authorize redirect headlessly by GETing the
9+
URL through the same bridge and parsing the code from the 302 `Location`. The whole
10+
authorization-code flow runs in one event loop with no sockets, no threads, and no real time.
1011
"""
1112

1213
import json
@@ -18,14 +19,23 @@
1819

1920
import httpx
2021
from pydantic import AnyHttpUrl, AnyUrl, BaseModel
22+
from starlette.applications import Starlette
23+
from starlette.middleware import Middleware
24+
from starlette.middleware.authentication import AuthenticationMiddleware
25+
from starlette.routing import Route
2126
from starlette.types import ASGIApp, Receive, Scope, Send
2227

2328
from mcp.client.auth import OAuthClientProvider
24-
from mcp.client.client import Client
29+
from mcp.client.session import ClientSession
2530
from mcp.client.streamable_http import streamable_http_client
2631
from mcp.server import Server
32+
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
33+
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
2734
from mcp.server.auth.provider import AccessToken, ProviderTokenVerifier
35+
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
2836
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions, RevocationOptions
37+
from mcp.server.fastmcp.server import StreamableHTTPASGIApp
38+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
2939
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
3040
from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION
3141
from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider
@@ -385,7 +395,7 @@ async def wrapped(scope: Scope, receive: Receive, send: Send) -> None:
385395

386396
@asynccontextmanager
387397
async def connect_with_oauth(
388-
server: Server,
398+
server: Server[Any],
389399
*,
390400
provider: InMemoryAuthorizationServerProvider,
391401
settings: AuthSettings | None = None,
@@ -397,12 +407,19 @@ async def connect_with_oauth(
397407
verify_tokens: bool = True,
398408
app_shim: Callable[[ASGIApp], ASGIApp] | None = None,
399409
on_request: Callable[[httpx.Request], None] | None = None,
400-
) -> AsyncIterator[tuple[Client, HeadlessOAuth]]:
401-
"""Connect a `Client` to a server's bearer-gated streamable-HTTP app, completing OAuth in process.
410+
) -> AsyncIterator[tuple[ClientSession, HeadlessOAuth]]:
411+
"""Connect a `ClientSession` to a server's bearer-gated streamable-HTTP app, completing OAuth in process.
402412
403-
Yields the connected `Client` and the `HeadlessOAuth` whose `authorize_url` records what the
404-
SDK put on the authorize request. `on_request` records every HTTP request the underlying
405-
`httpx.AsyncClient` issues, including those yielded from inside the auth flow.
413+
Yields the connected, initialized `ClientSession` and the `HeadlessOAuth` whose
414+
`authorize_url` records what the SDK put on the authorize request. `on_request` records
415+
every HTTP request the underlying `httpx.AsyncClient` issues, including those yielded from
416+
inside the auth flow.
417+
418+
The Starlette app is assembled from the same public pieces `FastMCP.streamable_http_app()`
419+
uses, so behaviour matches what a v1 user would get from a `FastMCP` configured with
420+
`auth_server_provider=` — except that hand-assembly lets `verify_tokens=False` mount `/mcp`
421+
ungated while still mounting the authorization-server and PRM routes (FastMCP's constructor
422+
auto-derives a token verifier from the provider, so it has no ungated combination).
406423
407424
`headless`: supply a pre-configured `HeadlessOAuth` to override the callback behaviour
408425
(state mismatch, error redirects). `verify_tokens=False` mounts the MCP endpoint without
@@ -433,12 +450,44 @@ async def connect_with_oauth(
433450
)
434451
)
435452

436-
app: ASGIApp = server.streamable_http_app(
437-
auth=settings,
438-
token_verifier=ProviderTokenVerifier(provider) if verify_tokens else None,
439-
auth_server_provider=provider,
440-
transport_security=NO_DNS_REBINDING_PROTECTION,
453+
manager = StreamableHTTPSessionManager(app=server, security_settings=NO_DNS_REBINDING_PROTECTION)
454+
asgi = StreamableHTTPASGIApp(manager)
455+
456+
routes: list[Route] = list(
457+
create_auth_routes(
458+
provider=provider,
459+
issuer_url=settings.issuer_url,
460+
service_documentation_url=settings.service_documentation_url,
461+
client_registration_options=settings.client_registration_options,
462+
revocation_options=settings.revocation_options,
463+
)
464+
)
465+
middleware: list[Middleware] = []
466+
required_scopes = settings.required_scopes or []
467+
resource_metadata_url = (
468+
build_resource_metadata_url(settings.resource_server_url) if settings.resource_server_url else None
441469
)
470+
471+
if verify_tokens:
472+
token_verifier = ProviderTokenVerifier(provider)
473+
middleware = [
474+
Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(token_verifier)),
475+
Middleware(AuthContextMiddleware),
476+
]
477+
routes.append(Route("/mcp", endpoint=RequireAuthMiddleware(asgi, required_scopes, resource_metadata_url)))
478+
else:
479+
routes.append(Route("/mcp", endpoint=asgi))
480+
481+
if settings.resource_server_url:
482+
routes.extend(
483+
create_protected_resource_routes(
484+
resource_url=settings.resource_server_url,
485+
authorization_servers=[settings.issuer_url],
486+
scopes_supported=required_scopes,
487+
)
488+
)
489+
490+
app: ASGIApp = Starlette(routes=routes, middleware=middleware)
442491
if app_shim is not None:
443492
app = app_shim(app)
444493

@@ -452,14 +501,16 @@ async def hook(request: httpx.Request) -> None:
452501
event_hooks = {"request": [hook]}
453502

454503
async with AsyncExitStack() as stack:
455-
await stack.enter_async_context(server.session_manager.run())
504+
await stack.enter_async_context(manager.run())
456505
http_client = await stack.enter_async_context(
457506
httpx.AsyncClient(
458507
transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks
459508
)
460509
)
461510
headless.bind(http_client)
462-
client = await stack.enter_async_context(
463-
Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client))
511+
read, write, _get_session_id = await stack.enter_async_context(
512+
streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)
464513
)
465-
yield client, headless
514+
session = await stack.enter_async_context(ClientSession(read, write))
515+
await session.initialize()
516+
yield session, headless

tests/interaction/auth/test_flow.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pydantic import AnyUrl
2020

2121
from mcp import types
22-
from mcp.server import Server, ServerRequestContext
22+
from mcp.server import Server
2323
from mcp.server.auth.middleware.auth_context import get_access_token
2424
from mcp.shared.auth import OAuthClientInformationFull
2525
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
@@ -39,8 +39,15 @@
3939
pytestmark = pytest.mark.anyio
4040

4141

42-
async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult:
43-
return ListToolsResult(tools=[Tool(name="whoami", inputSchema={"type": "object"})])
42+
def _guarded_server() -> Server[object]:
43+
"""Build a lowlevel server exposing a single `whoami` tool, in the v1 decorator style."""
44+
server: Server[object] = Server("guarded")
45+
46+
@server.list_tools()
47+
async def _list_tools() -> list[types.Tool]:
48+
return [Tool(name="whoami", inputSchema={"type": "object"})]
49+
50+
return server
4451

4552

4653
@requirement("flow:oauth:authorization-code-roundtrip")
@@ -67,7 +74,7 @@ async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow
6774
requests: list[httpx.Request] = []
6875
provider = InMemoryAuthorizationServerProvider()
6976
storage = InMemoryTokenStorage()
70-
server = Server("guarded", on_list_tools=list_tools)
77+
server = _guarded_server()
7178

7279
with anyio.fail_after(5):
7380
async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as (
@@ -121,14 +128,15 @@ async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow
121128
@requirement("hosting:auth:authinfo-propagates")
122129
async def test_the_access_token_reaches_the_tool_handler_via_get_access_token() -> None:
123130
"""A tool handler reads the request's access token through `get_access_token()`."""
131+
server = _guarded_server()
124132

125-
async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
126-
assert params.name == "whoami"
133+
@server.call_tool()
134+
async def _call_tool(name: str, arguments: dict[str, object]) -> CallToolResult:
135+
assert name == "whoami"
127136
token = get_access_token()
128137
assert token is not None
129138
return CallToolResult(content=[TextContent(type="text", text=" ".join(token.scopes))])
130139

131-
server = Server("guarded", on_list_tools=list_tools, on_call_tool=call_tool)
132140
provider = InMemoryAuthorizationServerProvider()
133141

134142
with anyio.fail_after(5):
@@ -148,7 +156,7 @@ async def test_a_preregistered_client_skips_registration() -> None:
148156
requests: list[httpx.Request] = []
149157
provider = InMemoryAuthorizationServerProvider()
150158
storage = InMemoryTokenStorage()
151-
server = Server("guarded", on_list_tools=list_tools)
159+
server = _guarded_server()
152160

153161
client_info = OAuthClientInformationFull(
154162
client_id="preregistered",
@@ -183,7 +191,7 @@ async def test_the_dcr_request_carries_the_client_metadata() -> None:
183191
requests: list[httpx.Request] = []
184192
provider = InMemoryAuthorizationServerProvider()
185193
storage = InMemoryTokenStorage()
186-
server = Server("guarded", on_list_tools=list_tools)
194+
server = _guarded_server()
187195

188196
client_metadata = oauth_client_metadata()
189197
client_metadata.software_id = "interaction-test-suite"

0 commit comments

Comments
 (0)