diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 957082a85..28a00e925 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Generic, Literal, Protocol, TypeVar +from typing import Any, Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyUrl, BaseModel @@ -40,6 +40,8 @@ class AccessToken(BaseModel): scopes: list[str] expires_at: int | None = None resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # JWT "sub" claim (user/resource owner ID) + claims: dict[str, Any] | None = None # Additional token claims RegistrationErrorCode = Literal[ diff --git a/tests/server/auth/test_provider.py b/tests/server/auth/test_provider.py index aaaeb413a..dea553a0b 100644 --- a/tests/server/auth/test_provider.py +++ b/tests/server/auth/test_provider.py @@ -1,6 +1,111 @@ """Tests for mcp.server.auth.provider module.""" -from mcp.server.auth.provider import construct_redirect_uri +from mcp.server.auth.provider import AccessToken, construct_redirect_uri + +# --- AccessToken tests --- + + +def test_access_token_basic_fields(): + """Test AccessToken with only required fields.""" + token = AccessToken( + token="tok_123", + client_id="client_1", + scopes=["read"], + ) + assert token.token == "tok_123" + assert token.client_id == "client_1" + assert token.scopes == ["read"] + assert token.expires_at is None + assert token.resource is None + assert token.subject is None + assert token.claims is None + + +def test_access_token_with_subject(): + """Test AccessToken with subject field for JWT sub claim.""" + token = AccessToken( + token="tok_123", + client_id="client_1", + scopes=["read"], + subject="user_42", + ) + assert token.subject == "user_42" + + +def test_access_token_with_claims(): + """Test AccessToken with custom claims dict.""" + custom_claims = { + "sub": "user_42", + "iss": "https://auth.example.com", + "org_id": "org_7", + "roles": ["admin", "editor"], + } + token = AccessToken( + token="tok_123", + client_id="client_1", + scopes=["read"], + claims=custom_claims, + ) + assert token.claims is not None + assert token.claims == custom_claims + assert token.claims["org_id"] == "org_7" + assert token.claims["roles"] == ["admin", "editor"] + + +def test_access_token_with_subject_and_claims(): + """Test AccessToken with both subject and claims for convenience.""" + token = AccessToken( + token="tok_123", + client_id="client_1", + scopes=["read", "write"], + subject="user_42", + claims={"sub": "user_42", "iss": "https://auth.example.com"}, + ) + assert token.subject == "user_42" + assert token.claims is not None + assert token.claims["sub"] == token.subject + + +def test_access_token_backward_compatible(): + """Test that existing code creating AccessToken without new fields still works.""" + token = AccessToken( + token="tok_123", + client_id="client_1", + scopes=["read"], + expires_at=1700000000, + resource="https://api.example.com", + ) + assert token.expires_at == 1700000000 + assert token.resource == "https://api.example.com" + # New fields default to None + assert token.subject is None + assert token.claims is None + + +def test_access_token_serialization_roundtrip(): + """Test that AccessToken with new fields survives JSON serialization.""" + token = AccessToken( + token="tok_123", + client_id="client_1", + scopes=["read"], + subject="user_42", + claims={"org_id": "org_7", "custom": True}, + ) + data = token.model_dump() + restored = AccessToken.model_validate(data) + assert restored.subject == "user_42" + assert restored.claims == {"org_id": "org_7", "custom": True} + + +def test_access_token_empty_claims(): + """Test AccessToken with empty claims dict.""" + token = AccessToken( + token="tok_123", + client_id="client_1", + scopes=["read"], + claims={}, + ) + assert token.claims == {} def test_construct_redirect_uri_no_existing_params():