44Implements authorization code flow with PKCE and automatic token refresh.
55"""
66
7+ from __future__ import annotations as _annotations
8+
79import base64
810import hashlib
911import logging
1315from collections .abc import AsyncGenerator , Awaitable , Callable
1416from dataclasses import dataclass , field
1517from typing import Any , Protocol
16- from urllib .parse import quote , urlencode , urljoin , urlparse
18+ from urllib .parse import quote , urlencode , urljoin , urlparse , urlsplit , urlunsplit
1719
1820import anyio
1921import httpx
20- from pydantic import BaseModel , Field , ValidationError
22+ from pydantic import AnyUrl , BaseModel , Field , HttpUrl , ValidationError
2123
2224from mcp .client .auth .exceptions import OAuthFlowError , OAuthTokenError
2325from mcp .client .auth .utils import (
4547 OAuthToken ,
4648 ProtectedResourceMetadata ,
4749)
48- from mcp .shared .auth_utils import (
49- calculate_token_expiry ,
50- check_resource_allowed ,
51- resource_url_from_server_url ,
52- )
5350
5451logger = logging .getLogger (__name__ )
5552
@@ -61,7 +58,7 @@ class PKCEParameters(BaseModel):
6158 code_challenge : str = Field (..., min_length = 43 , max_length = 128 )
6259
6360 @classmethod
64- def generate (cls ) -> " PKCEParameters" :
61+ def generate (cls ) -> PKCEParameters :
6562 """Generate new PKCE parameters."""
6663 code_verifier = "" .join (secrets .choice (string .ascii_letters + string .digits + "-._~" ) for _ in range (128 ))
6764 digest = hashlib .sha256 (code_verifier .encode ()).digest ()
@@ -74,19 +71,15 @@ class TokenStorage(Protocol):
7471
7572 async def get_tokens (self ) -> OAuthToken | None :
7673 """Get stored tokens."""
77- ...
7874
7975 async def set_tokens (self , tokens : OAuthToken ) -> None :
8076 """Store tokens."""
81- ...
8277
8378 async def get_client_info (self ) -> OAuthClientInformationFull | None :
8479 """Get stored client information."""
85- ...
8680
8781 async def set_client_info (self , client_info : OAuthClientInformationFull ) -> None :
8882 """Store client information."""
89- ...
9083
9184
9285@dataclass
@@ -124,7 +117,7 @@ def get_authorization_base_url(self, server_url: str) -> str:
124117
125118 def update_token_expiry (self , token : OAuthToken ) -> None :
126119 """Update token expiry time using shared util function."""
127- self .token_expiry_time = calculate_token_expiry (token .expires_in )
120+ self .token_expiry_time = _calculate_token_expiry (token .expires_in )
128121
129122 def is_token_valid (self ) -> bool :
130123 """Check if current token is valid."""
@@ -148,12 +141,12 @@ def get_resource_url(self) -> str:
148141
149142 Uses PRM resource if it's a valid parent, otherwise uses canonical server URL.
150143 """
151- resource = resource_url_from_server_url (self .server_url )
144+ resource = _resource_url_from_server_url (self .server_url )
152145
153146 # If PRM provides a resource that's a valid parent, use it
154147 if self .protected_resource_metadata and self .protected_resource_metadata .resource :
155148 prm_resource = str (self .protected_resource_metadata .resource )
156- if check_resource_allowed (requested_resource = resource , configured_resource = prm_resource ):
149+ if _check_resource_allowed (requested_resource = resource , configured_resource = prm_resource ):
157150 resource = prm_resource
158151
159152 return resource
@@ -614,3 +607,82 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
614607 # Retry with new tokens
615608 self ._add_auth_header (request )
616609 yield request
610+
611+
612+ def _resource_url_from_server_url (url : str | HttpUrl | AnyUrl ) -> str :
613+ """Convert server URL to canonical resource URL per RFC 8707.
614+
615+ RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component".
616+ Returns absolute URI with lowercase scheme/host for canonical form.
617+
618+ Args:
619+ url: Server URL to convert
620+
621+ Returns:
622+ Canonical resource URL string
623+ """
624+ # Convert to string if needed
625+ url_str = str (url )
626+
627+ # Parse the URL and remove fragment, create canonical form
628+ parsed = urlsplit (url_str )
629+ canonical = urlunsplit (parsed ._replace (scheme = parsed .scheme .lower (), netloc = parsed .netloc .lower (), fragment = "" ))
630+
631+ return canonical
632+
633+
634+ def _check_resource_allowed (requested_resource : str , configured_resource : str ) -> bool :
635+ """Check if a requested resource URL matches a configured resource URL.
636+
637+ A requested resource matches if it has the same scheme, domain, port,
638+ and its path starts with the configured resource's path. This allows
639+ hierarchical matching where a token for a parent resource can be used
640+ for child resources.
641+
642+ Args:
643+ requested_resource: The resource URL being requested
644+ configured_resource: The resource URL that has been configured
645+
646+ Returns:
647+ True if the requested resource matches the configured resource
648+ """
649+ # Parse both URLs
650+ requested = urlparse (requested_resource )
651+ configured = urlparse (configured_resource )
652+
653+ # Compare scheme, host, and port (origin)
654+ if requested .scheme .lower () != configured .scheme .lower () or requested .netloc .lower () != configured .netloc .lower ():
655+ return False
656+
657+ # Handle cases like requested=/foo and configured=/foo/
658+ requested_path = requested .path
659+ configured_path = configured .path
660+
661+ # If requested path is shorter, it cannot be a child
662+ if len (requested_path ) < len (configured_path ):
663+ return False
664+
665+ # Check if the requested path starts with the configured path
666+ # Ensure both paths end with / for proper comparison
667+ # This ensures that paths like "/api123" don't incorrectly match "/api"
668+ if not requested_path .endswith ("/" ):
669+ requested_path += "/"
670+ if not configured_path .endswith ("/" ):
671+ configured_path += "/"
672+
673+ return requested_path .startswith (configured_path )
674+
675+
676+ def _calculate_token_expiry (expires_in : int | str | None ) -> float | None :
677+ """Calculate token expiry timestamp from expires_in seconds.
678+
679+ Args:
680+ expires_in: Seconds until token expiration (may be string from some servers)
681+
682+ Returns:
683+ Unix timestamp when token expires, or None if no expiry specified
684+ """
685+ if expires_in is None :
686+ return None # pragma: no cover
687+ # Defensive: handle servers that return expires_in as string
688+ return time .time () + int (expires_in )
0 commit comments