Skip to content

Commit 2054cb6

Browse files
committed
feat: add subject and claims fields to AccessToken
Add optional `subject` and `claims` fields to `AccessToken` to support JWT claim access without requiring token re-decoding. - `subject: str | None` stores the JWT "sub" claim (user/resource owner ID) - `claims: dict[str, Any] | None` stores additional decoded token claims Both fields are optional with None defaults, maintaining full backward compatibility. Resolves #1038
1 parent 0fe16dd commit 2054cb6

File tree

2 files changed

+108
-2
lines changed

2 files changed

+108
-2
lines changed

src/mcp/server/auth/provider.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Generic, Literal, Protocol, TypeVar
2+
from typing import Any, Generic, Literal, Protocol, TypeVar
33
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
44

55
from pydantic import AnyUrl, BaseModel
@@ -40,6 +40,8 @@ class AccessToken(BaseModel):
4040
scopes: list[str]
4141
expires_at: int | None = None
4242
resource: str | None = None # RFC 8707 resource indicator
43+
subject: str | None = None # JWT "sub" claim (user/resource owner ID)
44+
claims: dict[str, Any] | None = None # Additional token claims
4345

4446

4547
RegistrationErrorCode = Literal[

tests/server/auth/test_provider.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,110 @@
11
"""Tests for mcp.server.auth.provider module."""
22

3-
from mcp.server.auth.provider import construct_redirect_uri
3+
from mcp.server.auth.provider import AccessToken, construct_redirect_uri
4+
5+
# --- AccessToken tests ---
6+
7+
8+
def test_access_token_basic_fields():
9+
"""Test AccessToken with only required fields."""
10+
token = AccessToken(
11+
token="tok_123",
12+
client_id="client_1",
13+
scopes=["read"],
14+
)
15+
assert token.token == "tok_123"
16+
assert token.client_id == "client_1"
17+
assert token.scopes == ["read"]
18+
assert token.expires_at is None
19+
assert token.resource is None
20+
assert token.subject is None
21+
assert token.claims is None
22+
23+
24+
def test_access_token_with_subject():
25+
"""Test AccessToken with subject field for JWT sub claim."""
26+
token = AccessToken(
27+
token="tok_123",
28+
client_id="client_1",
29+
scopes=["read"],
30+
subject="user_42",
31+
)
32+
assert token.subject == "user_42"
33+
34+
35+
def test_access_token_with_claims():
36+
"""Test AccessToken with custom claims dict."""
37+
custom_claims = {
38+
"sub": "user_42",
39+
"iss": "https://auth.example.com",
40+
"org_id": "org_7",
41+
"roles": ["admin", "editor"],
42+
}
43+
token = AccessToken(
44+
token="tok_123",
45+
client_id="client_1",
46+
scopes=["read"],
47+
claims=custom_claims,
48+
)
49+
assert token.claims == custom_claims
50+
assert token.claims["org_id"] == "org_7"
51+
assert token.claims["roles"] == ["admin", "editor"]
52+
53+
54+
def test_access_token_with_subject_and_claims():
55+
"""Test AccessToken with both subject and claims for convenience."""
56+
token = AccessToken(
57+
token="tok_123",
58+
client_id="client_1",
59+
scopes=["read", "write"],
60+
subject="user_42",
61+
claims={"sub": "user_42", "iss": "https://auth.example.com"},
62+
)
63+
assert token.subject == "user_42"
64+
assert token.claims is not None
65+
assert token.claims["sub"] == token.subject
66+
67+
68+
def test_access_token_backward_compatible():
69+
"""Test that existing code creating AccessToken without new fields still works."""
70+
token = AccessToken(
71+
token="tok_123",
72+
client_id="client_1",
73+
scopes=["read"],
74+
expires_at=1700000000,
75+
resource="https://api.example.com",
76+
)
77+
assert token.expires_at == 1700000000
78+
assert token.resource == "https://api.example.com"
79+
# New fields default to None
80+
assert token.subject is None
81+
assert token.claims is None
82+
83+
84+
def test_access_token_serialization_roundtrip():
85+
"""Test that AccessToken with new fields survives JSON serialization."""
86+
token = AccessToken(
87+
token="tok_123",
88+
client_id="client_1",
89+
scopes=["read"],
90+
subject="user_42",
91+
claims={"org_id": "org_7", "custom": True},
92+
)
93+
data = token.model_dump()
94+
restored = AccessToken.model_validate(data)
95+
assert restored.subject == "user_42"
96+
assert restored.claims == {"org_id": "org_7", "custom": True}
97+
98+
99+
def test_access_token_empty_claims():
100+
"""Test AccessToken with empty claims dict."""
101+
token = AccessToken(
102+
token="tok_123",
103+
client_id="client_1",
104+
scopes=["read"],
105+
claims={},
106+
)
107+
assert token.claims == {}
4108

5109

6110
def test_construct_redirect_uri_no_existing_params():

0 commit comments

Comments
 (0)