From d1faf98d3f67fb14cb55a1a7a66fb8871fad0761 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 5 May 2026 20:10:36 +0000 Subject: [PATCH] fix(server): propagate auth ContextVars through A2A dispatch boundary A2ABearerAuthMiddleware validated bearer tokens but only wrote the principal to scope["user"]/scope["auth"], never setting the current_principal/current_tenant/current_principal_metadata ContextVars. auth_context_factory and adopter code reading current_principal.get() directly got None on every A2A call even after successful auth, causing AdCPAuthenticationError on policies that require a bound principal. Mirrors the existing BearerTokenAuthMiddleware try/finally pattern: set all three vars on auth success; set to None on OPTIONS/discovery pass-throughs; reset unconditionally in finally. Fixes #590 https://claude.ai/code/session_01ETystgT4HH4ZHaAziszjBF --- src/adcp/server/auth.py | 77 ++++++++++++++++++++++------------- tests/test_serve_auth_both.py | 73 +++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 29 deletions(-) diff --git a/src/adcp/server/auth.py b/src/adcp/server/auth.py index 4aa0c013..edf7359c 100644 --- a/src/adcp/server/auth.py +++ b/src/adcp/server/auth.py @@ -637,41 +637,60 @@ def __init__(self, app: Any, config: BearerTokenAuth) -> None: async def __call__(self, scope: Any, receive: Any, send: Any) -> None: # Lifespan + websocket pass through unchanged. Auth applies to - # HTTP requests only. + # HTTP requests only. No contextvar changes — these scopes are + # not dispatched to skill handlers that read auth vars. if scope.get("type") != "http": await self._app(scope, receive, send) return - # CORS preflight is part of the public surface — browser-origin - # clients send ``OPTIONS`` before any auth'd POST. Returning 401 - # here breaks the preflight and the buyer never gets a chance to - # retry with a token. Pass through; let the inner app's CORS - # handler (or operator-supplied ``asgi_middleware``) respond. - if scope.get("method") == "OPTIONS": - await self._app(scope, receive, send) - return + principal_token = None + tenant_token = None + metadata_token = None + try: + # CORS preflight and A2A discovery are part of the public surface. + # Set contextvars to None to prevent stale values from an enclosing + # task context from leaking into downstream code on these paths — + # mirrors BearerTokenAuthMiddleware's discovery branch. + if scope.get("method") == "OPTIONS" or scope.get("path", "") in _A2A_DISCOVERY_PATHS: + principal_token = current_principal.set(None) + tenant_token = current_tenant.set(None) + metadata_token = current_principal_metadata.set(None) + await self._app(scope, receive, send) + return - path = scope.get("path", "") - if path in _A2A_DISCOVERY_PATHS: + principal = self._authenticate_scope(scope) + if principal is None: + await self._send_unauthenticated(send) + return + + # Stash both the duck-typed user (for DefaultServerCallContextBuilder) + # and the raw Principal (for downstream code reading scope['auth']). + # Mutating the scope dict before delegating propagates state to + # nested apps without copying. + scope["user"] = _A2AAuthenticatedUser( + display_name=principal.caller_identity, + tenant_id=principal.tenant_id, + principal_metadata=dict(principal.metadata) if principal.metadata else None, + ) + scope["auth"] = principal + # Populate the same module-level ContextVars that BearerTokenAuthMiddleware + # sets on the MCP path. auth_context_factory and adopter code that reads + # current_principal.get() directly see the authenticated identity on A2A + # exactly as they do on MCP. Reset unconditionally in finally so a later + # task sharing this context can't read a stale principal. + principal_token = current_principal.set(principal.caller_identity) + tenant_token = current_tenant.set(principal.tenant_id) + metadata_token = current_principal_metadata.set( + dict(principal.metadata) if principal.metadata else None + ) await self._app(scope, receive, send) - return - - principal = self._authenticate_scope(scope) - if principal is None: - await self._send_unauthenticated(send) - return - - # Stash both the duck-typed user (for DefaultServerCallContextBuilder) - # and the raw Principal (for downstream code reading scope['auth']). - # Mutating the scope dict before delegating propagates state to - # nested apps without copying. - scope["user"] = _A2AAuthenticatedUser( - display_name=principal.caller_identity, - tenant_id=principal.tenant_id, - principal_metadata=dict(principal.metadata) if principal.metadata else None, - ) - scope["auth"] = principal - await self._app(scope, receive, send) + finally: + if principal_token is not None: + current_principal.reset(principal_token) + if tenant_token is not None: + current_tenant.reset(tenant_token) + if metadata_token is not None: + current_principal_metadata.reset(metadata_token) def _authenticate_scope(self, scope: Any) -> Principal | None: """Read + validate the bearer header off raw ASGI scope. diff --git a/tests/test_serve_auth_both.py b/tests/test_serve_auth_both.py index ceb91d54..d390d0c2 100644 --- a/tests/test_serve_auth_both.py +++ b/tests/test_serve_auth_both.py @@ -94,6 +94,28 @@ async def inner(scope: Any, _receive: Any, _send: Any) -> None: assert "auth" in passed_scope assert passed_scope["auth"].caller_identity == "p-acme" + @pytest.mark.asyncio + async def test_valid_token_sets_current_principal_contextvar(self): + """On auth success, current_principal must be populated inside the + inner app and reset to None after __call__ returns (#590 regression).""" + from adcp.server.auth import current_principal, current_tenant + + captured: dict[str, str | None] = {} + + async def inner(scope: Any, _receive: Any, _send: Any) -> None: + captured["principal"] = current_principal.get() + captured["tenant"] = current_tenant.get() + + mw = A2ABearerAuthMiddleware(inner, _auth()) + scope = self._scope(headers=[(b"authorization", b"Bearer good-token")]) + await mw(scope, lambda: None, lambda _: None) + + assert captured["principal"] == "p-acme" + assert captured["tenant"] == "acme" + # Verify reset-in-finally: contextvar must be cleared after __call__ returns. + assert current_principal.get() is None + assert current_tenant.get() is None + @pytest.mark.asyncio async def test_missing_header_returns_401(self): sent: list[dict] = [] @@ -302,6 +324,57 @@ async def test_a2a_jsonrpc_authenticated_passes_through() -> None: assert response.status_code == 200 +@pytest.mark.asyncio +async def test_a2a_auth_populates_current_principal_contextvar() -> None: + """A2ABearerAuthMiddleware must set current_principal contextvar so + auth_context_factory and adopter code reading it directly see the + authenticated identity on A2A — same as MCP (regression for #590). + + Verifies both that the var is populated inside the handler AND that it is + reset to None after the request completes (try/finally contract).""" + from adcp.server.a2a_server import create_a2a_server + from adcp.server.auth import current_principal, current_tenant + + observed: dict[str, str | None] = {} + + class _ContextCaptureHandler(ADCPHandler): + async def get_adcp_capabilities(self, params: Any, context: Any = None) -> dict[str, Any]: + return {"adcp": {"major_versions": [3]}, "supported_protocols": ["media_buy"]} + + async def get_products(self, params: Any, context: Any = None) -> dict[str, Any]: + observed["principal"] = current_principal.get() + observed["tenant"] = current_tenant.get() + return {"products": []} + + inner = create_a2a_server(_ContextCaptureHandler(), name="ctx-test", validation=None) + app = A2ABearerAuthMiddleware(inner, _auth()) + body = { + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "messageId": "m1", + "role": "user", + "parts": [{"kind": "data", "data": {"skill": "get_products", "parameters": {}}}], + } + }, + } + async with LifespanManager(inner): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/", json=body, headers={"Authorization": "Bearer good-token"} + ) + assert response.status_code == 200 + assert observed.get("principal") == "p-acme", "current_principal not set on A2A path" + assert observed.get("tenant") == "acme", "current_tenant not set on A2A path" + # Verify reset-in-finally: contextvar must be None after the request. + assert current_principal.get() is None + assert current_tenant.get() is None + + # =========================================================================== # transport="both": the regression case from #558 # ===========================================================================