From ac24ef71e6976d25eefcad3d80c15bd477ba65a3 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Mon, 2 Feb 2026 23:35:32 +0530 Subject: [PATCH 1/6] feat: add Multiple Custom Domains (MCD) support and fix JWT verification --- .gitignore | 4 +- examples/MCD.md | 139 +++ .../auth_server/server_client.py | 599 ++++++++-- .../auth_types/__init__.py | 23 + src/auth0_server_python/error/__init__.py | 27 + .../tests/test_server_client.py | 1028 ++++++++++++++++- src/auth0_server_python/utils/helpers.py | 68 ++ 7 files changed, 1778 insertions(+), 110 deletions(-) create mode 100644 examples/MCD.md diff --git a/.gitignore b/.gitignore index fe90143..3d5c66a 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,6 @@ test.py test-script.py .coverage coverage.xml - +examples/mcd-poc +IMPLEMENTATION_NOTES.md +examples/MCD_DEVELOPER_GUIDE.md \ No newline at end of file diff --git a/examples/MCD.md b/examples/MCD.md new file mode 100644 index 0000000..85d188d --- /dev/null +++ b/examples/MCD.md @@ -0,0 +1,139 @@ +# Multiple Custom Domains (MCD) Guide + +This guide explains how to implement Multiple Custom Domain (MCD) support using the Auth0 Python SDKs. + +## What is MCD? + +Multiple Custom Domains (MCD) allows your application to serve different organizations or tenants from different hostnames, each mapping to a different Auth0 tenant/domain. + +**Example:** +- `https://acme.yourapp.com` → Auth0 tenant: `acme.auth0.com` +- `https://globex.yourapp.com` → Auth0 tenant: `globex.auth0.com` + +Each tenant gets its own branded login experience while using a single application codebase. + +## Configuration Methods + +### Method 1: Static Domain (Single Tenant) + +For applications with a single Auth0 domain: + +```python +from auth0_server_python import ServerClient + +client = ServerClient( + domain="your-tenant.auth0.com", # Static string + client_id="your_client_id", + client_secret="your_client_secret", + secret="your_encryption_secret" +) +``` + +### Method 2: Dynamic Domain Resolver (MCD) + +For MCD support, provide a domain resolver function that receives a `DomainResolverContext`: + +```python +from auth0_server_python import ServerClient +from auth0_server_python.auth_types import DomainResolverContext + +# Map your app hostnames to Auth0 domains +DOMAIN_MAP = { + "acme.yourapp.com": "acme.auth0.com", + "globex.yourapp.com": "globex.auth0.com", +} +DEFAULT_DOMAIN = "default.auth0.com" + +async def domain_resolver(context: DomainResolverContext) -> str: + """ + Resolve Auth0 domain based on request hostname. + + Args: + context: Contains request_url and request_headers + + Returns: + Auth0 domain string (e.g., "acme.auth0.com") + """ + # Extract hostname from request headers + if not context.request_headers: + return DEFAULT_DOMAIN + + host = context.request_headers.get('host', DEFAULT_DOMAIN) + host_without_port = host.split(':')[0] + + # Look up Auth0 domain + return DOMAIN_MAP.get(host_without_port, DEFAULT_DOMAIN) + +client = ServerClient( + domain=domain_resolver, # Callable function + client_id="your_client_id", + client_secret="your_client_secret", + secret="your_encryption_secret" +) +``` + +## DomainResolverContext + +The `DomainResolverContext` object provides request information to your resolver: + +| Property | Type | Description | +|----------|------|-------------| +| `request_url` | `Optional[str]` | Full request URL (e.g., "https://acme.yourapp.com/auth/login") | +| `request_headers` | `Optional[dict[str, str]]` | Request headers dictionary | + +**Common headers:** +- `host`: Request hostname (e.g., "acme.yourapp.com") +- `x-forwarded-host`: Original host when behind proxy/load balancer + +**Example usage:** + +```python +async def domain_resolver(context: DomainResolverContext) -> str: + # Check if we have request headers + if not context.request_headers: + return DEFAULT_DOMAIN + + # Use x-forwarded-host if behind proxy, otherwise use host + host = (context.request_headers.get('x-forwarded-host') or + context.request_headers.get('host', '')) + + # Remove port number if present + hostname = host.split(':')[0].lower() + + # Look up in mapping + return DOMAIN_MAP.get(hostname, DEFAULT_DOMAIN) +``` + +## Error Handling + +### DomainResolverError + +The domain resolver should return a valid Auth0 domain string. Invalid returns will raise `DomainResolverError`: + +```python +from auth0_server_python.error import DomainResolverError + +async def domain_resolver(context: DomainResolverContext) -> str: + try: + domain = lookup_domain_from_db(context) + + if not domain: + # Return default instead of None + return DEFAULT_DOMAIN + + return domain # Must be a non-empty string + + except Exception as e: + # Log error and return default + logger.error(f"Domain resolution failed: {e}") + return DEFAULT_DOMAIN +``` + +**Invalid return values that raise `DomainResolverError`:** +- `None` +- Empty string `""` +- Non-string types (int, list, dict, etc.) + +**Exceptions raised by your resolver:** +- Automatically wrapped in `DomainResolverError` +- Original exception accessible via `.original_error` \ No newline at end of file diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index c968120..bee5541 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -6,7 +6,7 @@ import asyncio import json import time -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Callable, Generic, Optional, TypeVar, Union from urllib.parse import parse_qs, urlencode, urlparse, urlunparse import httpx @@ -32,19 +32,25 @@ AccessTokenForConnectionErrorCode, ApiError, BackchannelLogoutError, + ConfigurationError, + DomainResolverError, MissingRequiredArgumentError, MissingTransactionError, PollingApiError, StartLinkUserError, ) from auth0_server_python.utils import PKCE, URL, State +from auth0_server_python.utils.helpers import ( + build_domain_resolver_context, + validate_resolved_domain_value, +) from authlib.integrations.base_client.errors import OAuthError from authlib.integrations.httpx_client import AsyncOAuth2Client from pydantic import ValidationError # Generic type for store options TStoreOptions = TypeVar('TStoreOptions') -INTERNAL_AUTHORIZE_PARAMS = ["client_id", "redirect_uri", "response_type", +INTERNAL_AUTHORIZE_PARAMS = ["client_id", "response_type", "code_challenge", "code_challenge_method", "state", "nonce", "scope"] @@ -55,11 +61,15 @@ class ServerClient(Generic[TStoreOptions]): """ DEFAULT_AUDIENCE_STATE_KEY = "default" + # ========================================== + # Initialization + # ========================================== + def __init__( self, - domain: str, - client_id: str, - client_secret: str, + domain: Union[str, Callable[[Optional[dict[str, Any]]], str]] = None, + client_id: str = None, + client_secret: str = None, redirect_uri: Optional[str] = None, secret: str = None, transaction_store=None, @@ -67,13 +77,13 @@ def __init__( transaction_identifier: str = "_a0_tx", state_identifier: str = "_a0_session", authorization_params: Optional[dict[str, Any]] = None, - pushed_authorization_requests: bool = False + pushed_authorization_requests: bool = False, ): """ Initialize the Auth0 server client. Args: - domain: Auth0 domain (e.g., 'your-tenant.auth0.com') + domain: Auth0 domain - either a static string (e.g., 'tenant.auth0.com') or a callable that resolves domain dynamically. client_id: Auth0 client ID client_secret: Auth0 client secret redirect_uri: Default redirect URI for authentication @@ -83,12 +93,34 @@ def __init__( transaction_identifier: Identifier for transaction data state_identifier: Identifier for state data authorization_params: Default parameters for authorization requests + pushed_authorization_requests: Whether to use Pushed Authorization Requests """ if not secret: raise MissingRequiredArgumentError("secret") - # Store configuration - self._domain = domain + if domain is None: + raise ConfigurationError( + "Domain is required" + ) + + # Validate domain type + if not isinstance(domain, str) and not callable(domain): + raise ConfigurationError( + f"Domain must be either a string or a callable function. " + f"Got {type(domain).__name__} instead." + ) + + # Determine if domain is static string or dynamic callable + if callable(domain): + self._domain = None + self._domain_resolver = domain + else: + # Validate static domain string + domain_str = str(domain) + if not domain_str or domain_str.strip() == "": + raise ConfigurationError("Domain cannot be empty.") + self._domain = domain_str + self._domain_resolver = None self._client_id = client_id self._client_secret = client_secret self._redirect_uri = redirect_uri @@ -109,12 +141,15 @@ def __init__( self._my_account_client = MyAccountClient(domain=domain) - async def _fetch_oidc_metadata(self, domain: str) -> dict: - metadata_url = f"https://{domain}/.well-known/openid-configuration" - async with httpx.AsyncClient() as client: - response = await client.get(metadata_url) - response.raise_for_status() - return response.json() + # Cache for OIDC metadata and JWKS (Requirement 3: MCD Support) + self._metadata_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} + self._jwks_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} + self._cache_ttl = 3600 # 1 hour TTL + self._cache_max_size = 100 # Max 100 domains to prevent memory bloat + + # ========================================== + # Interactive Login Flow + # ========================================== async def start_interactive_login( self, @@ -126,12 +161,38 @@ async def start_interactive_login( Args: options: Configuration options for the login process + store_options: Store options containing request/response Returns: Authorization URL to redirect the user to """ options = options or StartInteractiveLoginOptions() + # Resolve domain (static or dynamic) + if self._domain_resolver: + # Build context and call developer's resolver + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + origin_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + else: + origin_domain = self._domain + + # Fetch OIDC metadata from resolved domain + try: + metadata = await self._get_oidc_metadata_cached(origin_domain) + origin_issuer = metadata.get('issuer') + except Exception as e: + raise ApiError("metadata_error", + "Failed to fetch OIDC metadata", e) + # Get effective authorization params (merge defaults with provided ones) auth_params = dict(self._default_authorization_params) if options.authorization_params: @@ -160,17 +221,20 @@ async def start_interactive_login( state = PKCE.generate_random_string(32) auth_params["state"] = state - #merge any requested scope with defaults + # Merge any requested scope with defaults requested_scope = options.authorization_params.get("scope", None) if options.authorization_params else None audience = auth_params.get("audience", None) merged_scope = self._merge_scope_with_defaults(requested_scope, audience) auth_params["scope"] = merged_scope - # Build the transaction data to store + # Build the transaction data to store with origin domain and issuer transaction_data = TransactionData( code_verifier=code_verifier, app_state=options.app_state, audience=audience, + origin_domain=origin_domain, + origin_issuer=origin_issuer, + redirect_uri=auth_params.get("redirect_uri"), ) # Store the transaction data @@ -179,11 +243,9 @@ async def start_interactive_login( transaction_data, options=store_options ) - try: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) - except Exception as e: - raise ApiError("metadata_error", - "Failed to fetch OIDC metadata", e) + + # Set metadata for OAuth client + self._oauth.metadata = metadata # If PAR is enabled, use the PAR endpoint if self._pushed_authorization_requests: par_endpoint = self._oauth.metadata.get( @@ -274,34 +336,101 @@ async def complete_interactive_login( if not code: raise MissingRequiredArgumentError("code") - if not self._oauth.metadata or "token_endpoint" not in self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + # Get origin domain and issuer from transaction + origin_domain = transaction_data.origin_domain + origin_issuer = transaction_data.origin_issuer + + # Fetch metadata from the origin domain + metadata = await self._get_oidc_metadata_cached(origin_domain) + self._oauth.metadata = metadata # Exchange the code for tokens + # Use redirect_uri from transaction if available, otherwise fall back to default + token_redirect_uri = transaction_data.redirect_uri or self._redirect_uri try: token_endpoint = self._oauth.metadata["token_endpoint"] token_response = await self._oauth.fetch_token( token_endpoint, code=code, code_verifier=transaction_data.code_verifier, - redirect_uri=self._redirect_uri, + redirect_uri=token_redirect_uri, ) except OAuthError as e: # Raise a custom error (or handle it as appropriate) raise ApiError( "token_error", f"Token exchange failed: {str(e)}", e) + print(f"Token Response : {token_response}") - # Use the userinfo field from the token_response for user claims + # Use the userinfo field from the token_response for user claims user_info = token_response.get("userinfo") user_claims = None + id_token = token_response.get("id_token") + if user_info: user_claims = UserClaims.parse_obj(user_info) - else: - id_token = token_response.get("id_token") - if id_token: - claims = jwt.decode(id_token, options={ - "verify_signature": False}) + elif id_token: + # Fetch JWKS for signature verification (Requirement 3) + jwks = await self._get_jwks_cached(origin_domain, metadata) + + # Decode and verify ID token with signature verification enabled + try: + # Get the signing key from JWKS + unverified_header = jwt.get_unverified_header(id_token) + kid = unverified_header.get("kid") + + # Find the key with matching kid + signing_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + signing_key = jwt.PyJWK.from_dict(key) + break + + if not signing_key: + raise ApiError( + "jwks_key_not_found", + f"No matching key found in JWKS for kid: {kid}" + ) + + claims = jwt.decode( + id_token, + signing_key.key, + algorithms=["RS256"], + audience=self._client_id, + issuer=origin_issuer, + options={"verify_signature": True} + ) user_claims = UserClaims.parse_obj(claims) + except jwt.InvalidSignatureError as e: + raise ApiError( + "invalid_signature", + f"ID token signature verification failed. The token may have been tampered with or is from an untrusted source: {str(e)}", + e + ) + except jwt.InvalidAudienceError as e: + raise ApiError( + "invalid_audience", + f"ID token audience mismatch. Expected: {self._client_id}. Ensure your client_id is configured correctly: {str(e)}", + e + ) + except jwt.InvalidIssuerError as e: + raise ApiError( + "invalid_issuer", + f"ID token issuer mismatch. Expected: {origin_issuer}. Ensure your Auth0 domain is configured correctly: {str(e)}", + e + ) + except jwt.ExpiredSignatureError as e: + raise ApiError( + "token_expired", + f"ID token has expired: {str(e)}", + e + ) + except jwt.InvalidTokenError as e: + raise ApiError( + "invalid_token", + f"ID token verification failed: {str(e)}", + e + ) + # Build a token set using the token response data token_set = TokenSet( @@ -323,6 +452,7 @@ async def complete_interactive_login( # might be None if not provided refresh_token=token_response.get("refresh_token"), token_sets=[token_set], + domain=origin_domain, internal={ "sid": sid, "created_at": int(time.time()) @@ -346,6 +476,10 @@ async def complete_interactive_login( return result + # ========================================== + # User Account Linking + # ========================================== + async def start_link_user( self, options, @@ -493,6 +627,10 @@ async def complete_unlink_user( "app_state": result.get("app_state") } + # ========================================== + # Backchannel Authentication (CIBA) + # ========================================== + async def login_backchannel( self, options: dict[str, Any], @@ -539,6 +677,10 @@ async def login_backchannel( } return result + # ========================================== + # Session & Token Management + # ========================================== + async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: """ Retrieves the user from the store, or None if no user found. @@ -552,6 +694,25 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti state_data = await self._state_store.get(self._state_identifier, store_options) if state_data: + # Validate session domain matches current request domain + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + current_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + session_domain = getattr(state_data, 'domain', None) + + if session_domain and session_domain != current_domain: + # Session created with different domain - reject for security + return None + if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() return state_data.get("user") @@ -570,6 +731,25 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O state_data = await self._state_store.get(self._state_identifier, store_options) if state_data: + # Validate session domain matches current request domain + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + current_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + session_domain = getattr(state_data, 'domain', None) + + if session_domain and session_domain != current_domain: + # Session created with different domain - reject for security + return None + if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() session_data = {k: v for k, v in state_data.items() @@ -599,6 +779,28 @@ async def get_access_token( """ state_data = await self._state_store.get(self._state_identifier, store_options) + # Validate session domain matches current request domain + if state_data and self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + current_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + session_domain = getattr(state_data, 'domain', None) + + if session_domain and session_domain != current_domain: + # Session created with different domain - reject for security + raise AccessTokenError( + AccessTokenErrorCode.MISSING_REFRESH_TOKEN, + "Session domain mismatch. User needs to re-authenticate with the current domain." + ) + auth_params = self._default_authorization_params or {} # Get audience passed in on options or use defaults @@ -630,7 +832,12 @@ async def get_access_token( # Get new token with refresh token try: - get_refresh_token_options = {"refresh_token": state_data_dict["refresh_token"]} + # Use session's domain for token refresh + session_domain = state_data_dict.get("domain") or self._domain + get_refresh_token_options = { + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain + } if audience: get_refresh_token_options["audience"] = audience @@ -656,50 +863,7 @@ async def get_access_token( f"Failed to get token with refresh token: {str(e)}" ) - def _merge_scope_with_defaults( - self, - request_scope: Optional[str], - audience: Optional[str] - ) -> Optional[str]: - audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY - default_scopes = "" - if self._default_authorization_params and "scope" in self._default_authorization_params: - auth_param_scope = self._default_authorization_params.get("scope") - # For backwards compatibility, allow scope to be a single string - # or dictionary by audience for MRRT - if isinstance(auth_param_scope, dict) and audience in auth_param_scope: - default_scopes = auth_param_scope[audience] - elif isinstance(auth_param_scope, str): - default_scopes = auth_param_scope - - default_scopes_list = default_scopes.split() - request_scopes_list = (request_scope or "").split() - - merged_scopes = list(dict.fromkeys(default_scopes_list + request_scopes_list)) - return " ".join(merged_scopes) if merged_scopes else None - - - def _find_matching_token_set( - self, - token_sets: list[dict[str, Any]], - audience: Optional[str], - scope: Optional[str] - ) -> Optional[dict[str, Any]]: - audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY - requested_scopes = set(scope.split()) if scope else set() - matches: list[tuple[int, dict]] = [] - for token_set in token_sets: - token_set_audience = token_set.get("audience") - token_set_scopes = set(token_set.get("scope", "").split()) - if token_set_audience == audience and token_set_scopes == requested_scopes: - # short-circuit if exact match - return token_set - if token_set_audience == audience and token_set_scopes.issuperset(requested_scopes): - # consider stored tokens with more scopes than requested by number of scopes - matches.append((len(token_set_scopes), token_set)) - - # Return the token set with the smallest superset of scopes that matches the requested audience and scopes - return min(matches, key=lambda t: t[0])[1] if matches else None + async def get_access_token_for_connection( self, @@ -751,10 +915,13 @@ async def get_access_token_for_connection( "A refresh token was not found but is required to be able to retrieve an access token for a connection." ) # Get new token for connection + # Use session's domain for token exchange + session_domain = state_data_dict.get("domain") or self._domain token_endpoint_response = await self.get_token_for_connection({ "connection": options.get("connection"), "login_hint": options.get("login_hint"), - "refresh_token": state_data_dict["refresh_token"] + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain }) # Update state data with new token @@ -766,6 +933,10 @@ async def get_access_token_for_connection( return token_endpoint_response["access_token"] + # ========================================== + # Logout + # ========================================== + async def logout( self, options: Optional[LogoutOptions] = None, @@ -776,9 +947,25 @@ async def logout( # Delete the session from the state store await self._state_store.delete(self._state_identifier, store_options) + # Resolve domain dynamically for MCD support + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + else: + domain = self._domain + # Use the URL helper to create the logout URL. logout_url = URL.create_logout_url( - self._domain, self._client_id, options.return_to) + domain, self._client_id, options.return_to) return logout_url @@ -798,9 +985,41 @@ async def handle_backchannel_logout( raise BackchannelLogoutError("Missing logout token") try: - # Decode the token without verification - claims = jwt.decode(logout_token, options={ - "verify_signature": False}) + # Fetch JWKS for signature verification (Requirement 3) + jwks = await self._get_jwks_cached(self._domain) + + # Decode and verify logout token with signature verification enabled + try: + # Get the signing key from JWKS + unverified_header = jwt.get_unverified_header(logout_token) + kid = unverified_header.get("kid") + + # Find the key with matching kid + signing_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + signing_key = jwt.PyJWK.from_dict(key) + break + + if not signing_key: + raise BackchannelLogoutError( + f"No matching key found in JWKS for kid: {kid}" + ) + + claims = jwt.decode( + logout_token, + signing_key.key, + algorithms=["RS256"], + options={"verify_signature": True} + ) + except jwt.InvalidSignatureError as e: + raise BackchannelLogoutError( + f"Logout token signature verification failed: {str(e)}" + ) + except jwt.InvalidTokenError as e: + raise BackchannelLogoutError( + f"Logout token verification failed: {str(e)}" + ) # Validate the token is a logout token events = claims.get("events", {}) @@ -816,11 +1035,195 @@ async def handle_backchannel_logout( await self._state_store.delete_by_logout_token(logout_claims.dict(), store_options) - except (jwt.JoseError, ValidationError) as e: + except (jwt.PyJWTError, ValidationError) as e: raise BackchannelLogoutError( f"Error processing logout token: {str(e)}") - # Authlib Helpers + # ========================================== + # Internal Helpers + # ========================================== + + # ------------------------------------------ + # OIDC Discovery & Metadata + # ------------------------------------------ + + def _normalize_domain(self, domain: str) -> str: + """ + Normalize domain for comparison and URL construction. + Handles cases with/without https:// scheme. + """ + if domain.startswith('https://'): + return domain + elif domain.startswith('http://'): + return domain.replace('http://', 'https://') + else: + return f'https://{domain}' + + async def _fetch_oidc_metadata(self, domain: str) -> dict: + """Fetch OIDC metadata from domain.""" + normalized_domain = self._normalize_domain(domain) + metadata_url = f"{normalized_domain}/.well-known/openid-configuration" + async with httpx.AsyncClient() as client: + response = await client.get(metadata_url) + response.raise_for_status() + return response.json() + + async def _get_oidc_metadata_cached(self, domain: str) -> dict: + """ + Get OIDC metadata with caching. + + Args: + domain: Auth0 domain + + Returns: + OIDC metadata document + """ + now = time.time() + + # Check cache + if domain in self._metadata_cache: + cached = self._metadata_cache[domain] + if cached["expires_at"] > now: + return cached["data"] + + # Cache miss/expired - fetch fresh + metadata = await self._fetch_oidc_metadata(domain) + + # Enforce cache size limit (FIFO eviction) + if len(self._metadata_cache) >= self._cache_max_size: + oldest_key = next(iter(self._metadata_cache)) + del self._metadata_cache[oldest_key] + + # Store in cache + self._metadata_cache[domain] = { + "data": metadata, + "expires_at": now + self._cache_ttl + } + + return metadata + + async def _fetch_jwks(self, jwks_uri: str) -> dict: + """ + Fetch JWKS (JSON Web Key Set) from jwks_uri. + + Args: + jwks_uri: The JWKS endpoint URL + + Returns: + JWKS document containing public keys + + Raises: + ApiError: If JWKS fetch fails + """ + try: + async with httpx.AsyncClient() as client: + response = await client.get(jwks_uri) + response.raise_for_status() + return response.json() + except Exception as e: + raise ApiError("jwks_fetch_error", f"Failed to fetch JWKS from {jwks_uri}", e) + + async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: + """ + Get JWKS with caching usingOIDC discovery. + + Args: + domain: Auth0 domain + metadata: Optional OIDC metadata (if already fetched) + + Returns: + JWKS document + + Raises: + ApiError: If JWKS fetch fails or jwks_uri missing from metadata + """ + now = time.time() + + # Check cache + if domain in self._jwks_cache: + cached = self._jwks_cache[domain] + if cached["expires_at"] > now: + return cached["data"] + + # Get jwks_uri from OIDC metadata + if not metadata: + metadata = await self._get_oidc_metadata_cached(domain) + + jwks_uri = metadata.get('jwks_uri') + if not jwks_uri: + raise ApiError( + "missing_jwks_uri", + f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant." + ) + + # Fetch JWKS + jwks = await self._fetch_jwks(jwks_uri) + + # Enforce cache size limit (FIFO eviction) + if len(self._jwks_cache) >= self._cache_max_size: + oldest_key = next(iter(self._jwks_cache)) + del self._jwks_cache[oldest_key] + + # Store in cache + self._jwks_cache[domain] = { + "data": jwks, + "expires_at": now + self._cache_ttl + } + + return jwks + + # ------------------------------------------ + # Token & Scope Management - MRRT + # ------------------------------------------ + + def _merge_scope_with_defaults( + self, + request_scope: Optional[str], + audience: Optional[str] + ) -> Optional[str]: + audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY + default_scopes = "" + if self._default_authorization_params and "scope" in self._default_authorization_params: + auth_param_scope = self._default_authorization_params.get("scope") + # For backwards compatibility, allow scope to be a single string + # or dictionary by audience for MRRT + if isinstance(auth_param_scope, dict) and audience in auth_param_scope: + default_scopes = auth_param_scope[audience] + elif isinstance(auth_param_scope, str): + default_scopes = auth_param_scope + + default_scopes_list = default_scopes.split() + request_scopes_list = (request_scope or "").split() + + merged_scopes = list(dict.fromkeys(default_scopes_list + request_scopes_list)) + return " ".join(merged_scopes) if merged_scopes else None + + + def _find_matching_token_set( + self, + token_sets: list[dict[str, Any]], + audience: Optional[str], + scope: Optional[str] + ) -> Optional[dict[str, Any]]: + audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY + requested_scopes = set(scope.split()) if scope else set() + matches: list[tuple[int, dict]] = [] + for token_set in token_sets: + token_set_audience = token_set.get("audience") + token_set_scopes = set(token_set.get("scope", "").split()) + if token_set_audience == audience and token_set_scopes == requested_scopes: + # short-circuit if exact match + return token_set + if token_set_audience == audience and token_set_scopes.issuperset(requested_scopes): + # consider stored tokens with more scopes than requested by number of scopes + matches.append((len(token_set_scopes), token_set)) + + # Return the token set with the smallest superset of scopes that matches the requested audience and scopes + return min(matches, key=lambda t: t[0])[1] if matches else None + + # ------------------------------------------ + # URL Builders + # ------------------------------------------ async def _build_link_user_url( self, @@ -837,7 +1240,7 @@ async def _build_link_user_url( # Get metadata if not already fetched if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) # Get authorization endpoint auth_endpoint = self._oauth_metadata.get("authorization_endpoint", @@ -880,7 +1283,7 @@ async def _build_unlink_user_url( # Get metadata if not already fetched if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) # Get authorization endpoint auth_endpoint = self._oauth_metadata.get("authorization_endpoint", @@ -1025,7 +1428,7 @@ async def initiate_backchannel_authentication( try: # Fetch OpenID Connect metadata if not already fetched if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) # Get the issuer from metadata issuer = self._oauth_metadata.get( @@ -1120,7 +1523,7 @@ async def backchannel_authentication_grant(self, auth_req_id: str) -> dict[str, try: # Ensure we have the OIDC metadata if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth.metadata = await self._get_oidc_metadata_cached(self._domain) token_endpoint = self._oauth.metadata.get("token_endpoint") if not token_endpoint: @@ -1178,6 +1581,10 @@ async def backchannel_authentication_grant(self, auth_req_id: str) -> dict[str, e ) + # ========================================== + # Token Exchange Operations + # ========================================== + async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, Any]: """ Retrieves a token by exchanging a refresh token. @@ -1197,9 +1604,12 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise MissingRequiredArgumentError("refresh_token") try: - # Ensure we have the OIDC metadata + # Use session domain if provided, otherwise fallback to static domain + domain = options.get("domain") or self._domain + + # Ensure we have the OIDC metadata from the correct domain if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth.metadata = await self._get_oidc_metadata_cached(domain) token_endpoint = self._oauth.metadata.get("token_endpoint") if not token_endpoint: @@ -1280,9 +1690,12 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "http://auth0.com/oauth/token-type/federated-connection-access-token" GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" try: - # Ensure we have OIDC metadata + # Use session domain if provided, otherwise fallback to static domain + domain = options.get("domain") or self._domain + + # Ensure we have OIDC metadata from the correct domain if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth.metadata = await self._get_oidc_metadata_cached(domain) token_endpoint = self._oauth.metadata.get("token_endpoint") if not token_endpoint: @@ -1340,6 +1753,10 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A e ) + # ========================================== + # Account Connection + # ========================================== + async def start_connect_account( self, options: ConnectAccountOptions, diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 677a7da..7c27d36 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -66,6 +66,7 @@ class SessionData(BaseModel): refresh_token: Optional[str] = None token_sets: list[TokenSet] = Field(default_factory=list) connection_token_sets: list[ConnectionTokenSet] = Field(default_factory=list) + domain: Optional[str] = None class Config: extra = "allow" # Allow additional fields not defined in the model @@ -89,6 +90,8 @@ class TransactionData(BaseModel): app_state: Optional[Any] = None auth_session: Optional[str] = None redirect_uri: Optional[str] = None + origin_domain: Optional[str] = None + origin_issuer: Optional[str] = None class Config: extra = "allow" # Allow additional fields not defined in the model @@ -252,3 +255,23 @@ class CompleteConnectAccountResponse(BaseModel): created_at: str expires_at: Optional[str] = None app_state: Optional[Any] = None + + +class DomainResolverContext(BaseModel): + """ + Context passed to domain resolver function for MCD support. + + Contains request information needed to determine the correct Auth0 domain + based on the incoming request's hostname or headers. + + Attributes: + request_url: The full request URL (e.g., "https://a.my-app.com/auth/login") + request_headers: Dictionary of request headers (e.g., {"host": "a.my-app.com", "x-forwarded-host": "..."}) + + Example: + async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get('host', '').split(':')[0] + return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) + """ + request_url: Optional[str] = None + request_headers: Optional[dict[str, str]] = None diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index ef181ce..93fcba2 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -101,6 +101,18 @@ def __init__(self, argument: str): self.argument = argument +class ConfigurationError(Auth0Error): + """ + Error raised when SDK configuration is invalid. + This includes invalid combinations of parameters or incorrect configuration values. + """ + code = "configuration_error" + + def __init__(self, message: str): + super().__init__(message) + self.name = "ConfigurationError" + + class BackchannelLogoutError(Auth0Error): """ Error raised during backchannel logout processing. @@ -113,6 +125,21 @@ def __init__(self, message: str): self.name = "BackchannelLogoutError" +class DomainResolverError(Auth0Error): + """ + Error raised when domain resolver function fails or returns invalid value. + + This error indicates an issue with the custom domain resolver function + provided for MCD (Multiple Custom Domains) support. + """ + code = "domain_resolver_error" + + def __init__(self, message: str, original_error: Exception = None): + super().__init__(message) + self.name = "DomainResolverError" + self.original_error = original_error + + class AccessTokenForConnectionError(Auth0Error): """Error when retrieving access tokens for a specific connection fails.""" diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 9f4f2cd..97fcc77 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -1,8 +1,9 @@ import json import time -from unittest.mock import ANY, AsyncMock, MagicMock +from unittest.mock import ANY, AsyncMock, MagicMock, patch from urllib.parse import parse_qs, urlparse +import jwt import pytest from auth0_server_python.auth_server.my_account_client import MyAccountClient from auth0_server_python.auth_server.server_client import ServerClient @@ -12,13 +13,17 @@ ConnectAccountRequest, ConnectAccountResponse, ConnectParams, + DomainResolverContext, LogoutOptions, + StateData, TransactionData, ) from auth0_server_python.error import ( AccessTokenForConnectionError, ApiError, BackchannelLogoutError, + ConfigurationError, + DomainResolverError, MissingRequiredArgumentError, MissingTransactionError, PollingApiError, @@ -42,7 +47,7 @@ async def test_init_no_secret_raises(): @pytest.mark.asyncio -async def test_start_interactive_login_no_redirect_uri(): +async def test_start_interactive_login_no_redirect_uri(mocker): client = ServerClient( domain="auth0.local", client_id="", @@ -51,6 +56,14 @@ async def test_start_interactive_login_no_redirect_uri(): transaction_store=AsyncMock(), secret="some-secret" ) + + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://auth0.local/", "authorization_endpoint": "https://auth0.local/authorize"} + ) + with pytest.raises(MissingRequiredArgumentError) as exc: await client.start_interactive_login() # Check the error message @@ -74,7 +87,7 @@ async def test_start_interactive_login_builds_auth_url(mocker): # Mock out HTTP calls or the internal methods that create the auth URL mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) mock_oauth = mocker.patch.object( @@ -115,8 +128,13 @@ async def test_complete_interactive_login_no_transaction(): @pytest.mark.asyncio async def test_complete_interactive_login_returns_app_state(mocker): mock_tx_store = AsyncMock() - # The stored transaction includes an appState - mock_tx_store.get.return_value = TransactionData(code_verifier="123", app_state={"foo": "bar"}) + # The stored transaction includes an appState with origin_domain and origin_issuer + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + app_state={"foo": "bar"}, + origin_domain="auth0.local", + origin_issuer="https://auth0.local/" + ) mock_state_store = AsyncMock() @@ -129,6 +147,13 @@ async def test_complete_interactive_login_returns_app_state(mocker): secret="some-secret", ) + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://auth0.local/", "token_endpoint": "https://auth0.local/token"} + ) + # Patch token exchange mocker.patch.object(client._oauth, "metadata", {"token_endpoint": "https://auth0.local/token"}) @@ -204,7 +229,7 @@ async def test_complete_link_user_returns_app_state(mocker): ) # Patch token exchange - mocker.patch.object(client, "_fetch_oidc_metadata", return_value={"token_endpoint": "https://auth0.local/token"}) + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={"token_endpoint": "https://auth0.local/token"}) async_fetch_token = AsyncMock() async_fetch_token.return_value = { "access_token": "token123", @@ -400,7 +425,8 @@ async def test_get_access_token_refresh_expired(mocker): assert token == "new_token" mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ - "refresh_token": "refresh_xyz" + "refresh_token": "refresh_xyz", + "domain": "auth0.local" }) @pytest.mark.asyncio @@ -441,6 +467,7 @@ async def test_get_access_token_refresh_merging_default_scope(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "default", "scope": "openid profile email foo:bar" }) @@ -482,6 +509,7 @@ async def test_get_access_token_refresh_with_auth_params_scope(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "scope": "openid profile email" }) @@ -522,6 +550,7 @@ async def test_get_access_token_refresh_with_auth_params_audience(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "my_audience" }) @@ -568,6 +597,7 @@ async def test_get_access_token_mrrt(mocker): assert len(stored_state["token_sets"]) == 2 get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "some_audience", "scope": "foo:bar", }) @@ -621,6 +651,7 @@ async def test_get_access_token_mrrt_with_auth_params_scope(mocker): assert len(stored_state["token_sets"]) == 2 get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "some_audience", "scope": "foo:bar", }) @@ -848,6 +879,18 @@ async def test_handle_backchannel_logout_ok(mocker): secret="some-secret" ) + # Mock JWKS fetch to prevent network call + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) mocker.patch("jwt.decode", return_value={ "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, "sub": "user_sub", @@ -874,7 +917,7 @@ async def test_build_link_user_url_success(mocker): # Patch _fetch_oidc_metadata to return an authorization_endpoint mock_fetch = mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -932,7 +975,7 @@ async def test_build_link_user_url_fallback_authorize(mocker): # Patch _fetch_oidc_metadata to NOT have an authorization_endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={} # empty dict, triggers fallback ) @@ -969,7 +1012,7 @@ async def test_build_unlink_user_url_success(mocker): # Patch out metadata mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -1002,7 +1045,7 @@ async def test_build_unlink_user_url_fallback_authorize(mocker): ) # No 'authorization_endpoint' - mocker.patch.object(client, "_fetch_oidc_metadata", return_value={}) + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={}) result_url = await client._build_unlink_user_url( connection="", @@ -1033,7 +1076,7 @@ async def test_build_unlink_user_url_with_metadata(mocker): # Patch the metadata fetch to include a valid authorization endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -1086,7 +1129,7 @@ async def test_build_unlink_user_url_no_authorization_endpoint(mocker): # Patch _fetch_oidc_metadata to return no authorization_endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={} ) result_url = await client._build_unlink_user_url( @@ -1117,7 +1160,7 @@ async def test_backchannel_auth_with_audience_and_binding_message(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1166,7 +1209,7 @@ async def test_backchannel_auth_rar(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1217,7 +1260,7 @@ async def test_backchannel_auth_token_exchange_failed(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1267,7 +1310,7 @@ async def test_initiate_backchannel_authentication_success(mocker): # Mock OIDC metadata mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/backchannel" @@ -1315,7 +1358,7 @@ async def test_initiate_backchannel_authentication_error_response(mocker): ) mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/backchannel" @@ -1932,3 +1975,952 @@ async def test_complete_connect_account_no_transactions(mocker): # Assert assert "transaction" in str(exc.value) mock_my_account_client.complete_connect_account.assert_not_awaited() + + +# ============================================================================= +# Requirement 1: Multiple Issuer Configuration Methods Tests +# ============================================================================= + +@pytest.mark.asyncio +async def test_domain_as_static_string(): + """Test Method 1: Static domain string configuration.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + assert client._domain == "tenant.auth0.com" + assert client._domain_resolver is None + + +@pytest.mark.asyncio +async def test_domain_as_callable_function(): + """Test Method 2: Domain resolver function configuration.""" + async def domain_resolver(store_options): + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + assert client._domain is None + assert client._domain_resolver == domain_resolver + + +@pytest.mark.asyncio +async def test_missing_domain_raises_configuration_error(): + """Test that missing domain parameter raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="Domain is required"): + ServerClient( + domain=None, + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +@pytest.mark.asyncio +async def test_invalid_domain_type_list(): + """Test that list domain raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="must be either a string or a callable"): + ServerClient( + domain=["tenant.auth0.com"], + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +@pytest.mark.asyncio +async def test_empty_domain_string(): + """Test that empty domain string raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="Domain cannot be empty"): + ServerClient( + domain="", + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +# ============================================================================= +# Requirement 2: Domain Resolver Context Tests +# ============================================================================= + +@pytest.mark.asyncio +async def test_domain_resolver_receives_context(mocker): + """Test that domain resolver receives DomainResolverContext with request data.""" + received_context = None + + async def domain_resolver(context): + nonlocal received_context + received_context = context + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Mock request with headers + mock_request = MagicMock() + mock_request.url = "https://a.my-app.com/auth/login" + mock_request.headers = {"host": "a.my-app.com", "x-forwarded-host": "a.my-app.com"} + + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} + ) + + try: + await client.start_interactive_login(store_options={"request": mock_request}) + except: + pass # We only care about context being passed + + assert received_context is not None + assert isinstance(received_context, DomainResolverContext) + assert received_context.request_url == "https://a.my-app.com/auth/login" + assert received_context.request_headers.get("host") == "a.my-app.com" + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_none(): + """Test that domain resolver returning None raises DomainResolverError.""" + async def bad_resolver(context): + return None + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="returned None"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_empty_string(): + """Test that domain resolver returning empty string raises DomainResolverError.""" + async def bad_resolver(context): + return "" + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="empty string"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_exception(): + """Test that domain resolver exceptions are wrapped in DomainResolverError.""" + async def bad_resolver(context): + raise ValueError("Something went wrong") + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="raised an exception"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_with_no_request(mocker): + """Test that domain resolver works with empty context when no request.""" + received_context = None + + async def domain_resolver(context): + nonlocal received_context + received_context = context + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} + ) + + try: + await client.start_interactive_login(store_options=None) + except: + pass + + assert received_context is not None + assert received_context.request_url is None + assert received_context.request_headers is None + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_non_string_type(): + """Test that domain resolver returning non-string raises DomainResolverError.""" + async def bad_resolver(context): + return 12345 + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="must return a string"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +# ============================================================================= +# Requirement 3: OIDC Metadata and JWKS Fetching Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_fetch_jwks_success(): + """Test successful JWKS fetch from URI.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_jwks = { + "keys": [ + { + "kty": "RSA", + "use": "sig", + "kid": "test-key-id", + "n": "test-modulus", + "e": "AQAB" + } + ] + } + + # Mock httpx client + mock_response = MagicMock() + mock_response.json.return_value = mock_jwks + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = mock_response + + with patch('httpx.AsyncClient', return_value=mock_client): + jwks = await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") + + assert jwks == mock_jwks + assert "keys" in jwks + mock_client.get.assert_awaited_once_with("https://tenant.auth0.com/.well-known/jwks.json") + + +@pytest.mark.asyncio +async def test_fetch_jwks_failure(): + """Test JWKS fetch failure raises ApiError.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Mock httpx client to raise exception + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.side_effect = Exception("Network error") + + with patch('httpx.AsyncClient', return_value=mock_client): + with pytest.raises(ApiError, match="Failed to fetch JWKS"): + await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") + + +@pytest.mark.asyncio +async def test_oidc_metadata_caching(): + """Test OIDC metadata is cached and reused.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "authorization_endpoint": "https://tenant.auth0.com/authorize", + "token_endpoint": "https://tenant.auth0.com/oauth/token", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + # Mock _fetch_oidc_metadata to track calls + fetch_count = 0 + async def mock_fetch(domain): + nonlocal fetch_count + fetch_count += 1 + return mock_metadata + + client._fetch_oidc_metadata = mock_fetch + + # First call - should fetch + result1 = await client._get_oidc_metadata_cached("tenant.auth0.com") + assert result1 == mock_metadata + assert fetch_count == 1 + + # Second call - should use cache + result2 = await client._get_oidc_metadata_cached("tenant.auth0.com") + assert result2 == mock_metadata + assert fetch_count == 1 # Should NOT increment + + # Verify cache contains data + assert "tenant.auth0.com" in client._metadata_cache + assert client._metadata_cache["tenant.auth0.com"]["data"] == mock_metadata + + +@pytest.mark.asyncio +async def test_oidc_metadata_cache_expiration(): + """Test OIDC metadata cache expires after TTL.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Set short TTL for testing + client._cache_ttl = 1 # 1 second + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + fetch_count = 0 + async def mock_fetch(domain): + nonlocal fetch_count + fetch_count += 1 + return mock_metadata + + client._fetch_oidc_metadata = mock_fetch + + # First call + await client._get_oidc_metadata_cached("tenant.auth0.com") + assert fetch_count == 1 + + # Wait for cache to expire + time.sleep(1.1) + + # Second call after expiration - should fetch again + await client._get_oidc_metadata_cached("tenant.auth0.com") + assert fetch_count == 2 + + +@pytest.mark.asyncio +async def test_jwks_caching(): + """Test JWKS is cached and reused.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + mock_jwks = { + "keys": [{"kty": "RSA", "kid": "key1"}] + } + + # Mock the fetch methods + client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata) + + fetch_count = 0 + async def mock_fetch_jwks(uri): + nonlocal fetch_count + fetch_count += 1 + return mock_jwks + + client._fetch_jwks = mock_fetch_jwks + + # First call - should fetch + result1 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) + assert result1 == mock_jwks + assert fetch_count == 1 + + # Second call - should use cache + result2 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) + assert result2 == mock_jwks + assert fetch_count == 1 # Should NOT increment + + +@pytest.mark.asyncio +async def test_jwks_cache_size_limit(): + """Test JWKS cache enforces max size limit with FIFO eviction.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Set small cache size for testing + client._cache_max_size = 3 + + mock_jwks = {"keys": [{"kty": "RSA"}]} + + # Mock methods + async def mock_fetch_metadata(domain): + return {"jwks_uri": f"https://{domain}/.well-known/jwks.json"} + + async def mock_fetch_jwks(uri): + return mock_jwks + + client._fetch_oidc_metadata = mock_fetch_metadata + client._fetch_jwks = mock_fetch_jwks + + # Fill cache to limit + await client._get_jwks_cached("domain1.auth0.com") + await client._get_jwks_cached("domain2.auth0.com") + await client._get_jwks_cached("domain3.auth0.com") + + assert len(client._jwks_cache) == 3 + assert "domain1.auth0.com" in client._jwks_cache + + # Add one more - should evict oldest (domain1) + await client._get_jwks_cached("domain4.auth0.com") + + assert len(client._jwks_cache) == 3 + assert "domain1.auth0.com" not in client._jwks_cache # Evicted + assert "domain4.auth0.com" in client._jwks_cache + + +@pytest.mark.asyncio +async def test_jwks_missing_uri_raises_error(): + """Test that missing jwks_uri in metadata raises ApiError.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Metadata WITHOUT jwks_uri + mock_metadata_no_jwks_uri = { + "issuer": "https://tenant.auth0.com/", + "authorization_endpoint": "https://tenant.auth0.com/authorize" + # No jwks_uri + } + + client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata_no_jwks_uri) + + # Should raise ApiError when jwks_uri is missing + with pytest.raises(ApiError) as exc_info: + await client._get_jwks_cached("tenant.auth0.com") + + assert exc_info.value.code == "missing_jwks_uri" + assert "non-RFC-compliant" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_metadata_cache_size_limit(): + """Test OIDC metadata cache enforces max size limit.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + client._cache_max_size = 2 + + async def mock_fetch(domain): + return {"issuer": f"https://{domain}/"} + + client._fetch_oidc_metadata = mock_fetch + + # Fill cache + await client._get_oidc_metadata_cached("domain1.auth0.com") + await client._get_oidc_metadata_cached("domain2.auth0.com") + + assert len(client._metadata_cache) == 2 + + # Add third - should evict first + await client._get_oidc_metadata_cached("domain3.auth0.com") + + assert len(client._metadata_cache) == 2 + assert "domain1.auth0.com" not in client._metadata_cache + assert "domain3.auth0.com" in client._metadata_cache + + +# ============================================================================= +# Requirement 4: Issuer Validation Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_complete_login_issuer_validation_success(mocker): + """Test complete login with valid issuer in ID token.""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant.auth0.com", + origin_issuer="https://tenant.auth0.com/" + ) + + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode with valid issuer + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://tenant.auth0.com/", # Matches origin_issuer + "aud": "test_client" + }) + + # Should succeed without raising error + result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert result is not None + assert "state_data" in result + + +@pytest.mark.asyncio +async def test_complete_login_issuer_mismatch_raises_error(mocker): + """Test that issuer mismatch in ID token raises ApiError.""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant.auth0.com", + origin_issuer="https://tenant.auth0.com/" + ) + + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode to raise InvalidIssuerError + mocker.patch("jwt.decode", side_effect=jwt.InvalidIssuerError("Invalid issuer")) + + # Should raise ApiError with invalid_issuer code + with pytest.raises(ApiError) as exc_info: + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert exc_info.value.code == "invalid_issuer" + assert "issuer mismatch" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_normalize_domain_handles_different_schemes(): + """Test that _normalize_domain handles various URL schemes correctly.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Test domain without scheme + assert client._normalize_domain("auth0.com") == "https://auth0.com" + + # Test domain with https scheme (should remain unchanged) + assert client._normalize_domain("https://auth0.com") == "https://auth0.com" + + # Test domain with http scheme (should convert to https) + assert client._normalize_domain("http://auth0.com") == "https://auth0.com" + + # Test domain with trailing slash + assert client._normalize_domain("https://auth0.com/") == "https://auth0.com/" + + +# ============================================================================= +# Requirements 5-8: Domain-specific Session Management Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_session_stores_origin_domain(mocker): + """Test that session stores origin domain from login (Requirement 5).""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant1.auth0.com", + origin_issuer="https://tenant1.auth0.com/" + ) + + captured_state = None + async def capture_state(identifier, state_data, options=None): + nonlocal captured_state + captured_state = state_data + + mock_state_store = AsyncMock() + mock_state_store.set = AsyncMock(side_effect=capture_state) + + client = ServerClient( + domain="tenant1.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": "https://tenant1.auth0.com/", + "token_endpoint": "https://tenant1.auth0.com/token" + }) + mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) + + async_fetch_token = AsyncMock(return_value={ + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid" + }) + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": "https://tenant1.auth0.com/"}) + + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + # Verify session has domain field set + assert captured_state.domain == "tenant1.auth0.com" + + +@pytest.mark.asyncio +async def test_cross_domain_session_rejected(): + """Test that session from domain1 cannot be used with domain2 (Requirement 5).""" + # Create session with domain1 + session_data = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns domain2 (different from session) + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # get_user should return None (session rejected) + user = await client.get_user(store_options={"request": {}}) + assert user is None + + +@pytest.mark.asyncio +async def test_logout_uses_current_domain(mocker): + """Test that logout uses current resolved domain (Requirement 7).""" + current_domain = "tenant2.auth0.com" + + async def domain_resolver(context): + return current_domain + + mock_state_store = AsyncMock() + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + logout_url = await client.logout(store_options={"request": {}}) + + # Verify logout URL uses current domain + assert current_domain in logout_url + assert logout_url.startswith(f"https://{current_domain}") + + +@pytest.mark.asyncio +async def test_logout_clears_session_for_current_domain(): + """Test that logout clears session (Requirement 7).""" + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + await client.logout() + + # Verify session was deleted + mock_state_store.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_domain_migration_old_sessions_remain_valid(): + """Test that old sessions remain valid with old domain requests (Requirement 8).""" + old_domain = "old-tenant.auth0.com" + + # Session from old domain + session_data = StateData( + user={"sub": "user123"}, + domain=old_domain, + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns old domain + async def domain_resolver(context): + return old_domain + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Should successfully retrieve user + user = await client.get_user(store_options={"request": {}}) + assert user is not None + assert user["sub"] == "user123" + + +@pytest.mark.asyncio +async def test_domain_migration_new_sessions_use_new_domain(mocker): + """Test that new logins create sessions with new domain (Requirement 8).""" + new_domain = "new-tenant.auth0.com" + + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain=new_domain, + origin_issuer=f"https://{new_domain}/" + ) + + captured_state = None + async def capture_state(identifier, state_data, options=None): + nonlocal captured_state + captured_state = state_data + + mock_state_store = AsyncMock() + mock_state_store.set = AsyncMock(side_effect=capture_state) + + client = ServerClient( + domain=new_domain, + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": f"https://{new_domain}/", + "token_endpoint": f"https://{new_domain}/token" + }) + mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) + + async_fetch_token = AsyncMock(return_value={ + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid" + }) + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": f"https://{new_domain}/"}) + + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + # Verify new session has new domain + assert captured_state.domain == new_domain + + +@pytest.mark.asyncio +async def test_domain_migration_sessions_isolated(): + """Test that old domain sessions cannot be used with new domain (Requirement 8).""" + old_domain = "old-tenant.auth0.com" + new_domain = "new-tenant.auth0.com" + + # Session from old domain + session_data = StateData( + user={"sub": "user123"}, + domain=old_domain, + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns NEW domain (migration happened) + async def domain_resolver(context): + return new_domain + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Should reject old session + user = await client.get_user(store_options={"request": {}}) + assert user is None \ No newline at end of file diff --git a/src/auth0_server_python/utils/helpers.py b/src/auth0_server_python/utils/helpers.py index c57ab18..1a49835 100644 --- a/src/auth0_server_python/utils/helpers.py +++ b/src/auth0_server_python/utils/helpers.py @@ -5,6 +5,8 @@ import time from typing import Any, Optional from urllib.parse import parse_qs, urlencode, urlparse +from auth0_server_python.auth_types import DomainResolverContext +from auth0_server_python.error import DomainResolverError class PKCE: @@ -224,3 +226,69 @@ def create_logout_url(domain: str, client_id: str, return_to: Optional[str] = No if return_to: params["returnTo"] = return_to return URL.build_url(base_url, params) + + +# ============================================================================= +# Domain Resolver Utilities +# ============================================================================= + +def build_domain_resolver_context(store_options: Optional[dict[str, Any]]) -> 'DomainResolverContext': + """ + Build DomainResolverContext from store_options. + + Extracts request information in a framework-agnostic way using duck typing. + + Args: + store_options: Dictionary containing 'request' and 'response' objects + + Returns: + DomainResolverContext with extracted request data + """ + + if not store_options: + return DomainResolverContext() + + request = store_options.get('request') + if not request: + return DomainResolverContext() + + # Framework-agnostic extraction using duck typing + request_url = str(request.url) if hasattr(request, 'url') else None + request_headers = dict(request.headers) if hasattr(request, 'headers') else None + + return DomainResolverContext( + request_url=request_url, + request_headers=request_headers + ) + + +def validate_resolved_domain_value(domain_value: Any) -> str: + """ + Validate the value returned by domain resolver. + + Args: + domain_value: The value returned by the domain resolver + + Returns: + The validated domain string + + Raises: + DomainResolverError: If the returned value is invalid + """ + + if domain_value is None: + raise DomainResolverError( + "Domain resolver returned None. Must return a valid domain string." + ) + + if not isinstance(domain_value, str): + raise DomainResolverError( + f"Domain resolver must return a string. Got {type(domain_value).__name__} instead." + ) + + if not domain_value.strip(): + raise DomainResolverError( + "Domain resolver returned an empty string. Must return a valid domain." + ) + + return domain_value From 116484868b468235dbe219f996e76c625a0538d1 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Mon, 2 Feb 2026 23:45:32 +0530 Subject: [PATCH 2/6] Bump poetry version from latest to 2.2.1 in test workflow --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d6e025f..9c87ea2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,7 +36,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: - version: latest + version: 2.2.1 virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true From 26466821e66c76b9a09f3648b532e848807f0b54 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Mon, 2 Feb 2026 23:51:56 +0530 Subject: [PATCH 3/6] Fix linting errors --- .../auth_server/server_client.py | 86 +++---- .../auth_types/__init__.py | 6 +- src/auth0_server_python/error/__init__.py | 2 +- .../tests/test_server_client.py | 222 +++++++++--------- src/auth0_server_python/utils/helpers.py | 29 +-- 5 files changed, 173 insertions(+), 172 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index bee5541..56cd636 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -184,7 +184,7 @@ async def start_interactive_login( ) else: origin_domain = self._domain - + # Fetch OIDC metadata from resolved domain try: metadata = await self._get_oidc_metadata_cached(origin_domain) @@ -243,7 +243,7 @@ async def start_interactive_login( transaction_data, options=store_options ) - + # Set metadata for OAuth client self._oauth.metadata = metadata # If PAR is enabled, use the PAR endpoint @@ -339,7 +339,7 @@ async def complete_interactive_login( # Get origin domain and issuer from transaction origin_domain = transaction_data.origin_domain origin_issuer = transaction_data.origin_issuer - + # Fetch metadata from the origin domain metadata = await self._get_oidc_metadata_cached(origin_domain) self._oauth.metadata = metadata @@ -365,32 +365,32 @@ async def complete_interactive_login( user_info = token_response.get("userinfo") user_claims = None id_token = token_response.get("id_token") - + if user_info: user_claims = UserClaims.parse_obj(user_info) elif id_token: # Fetch JWKS for signature verification (Requirement 3) jwks = await self._get_jwks_cached(origin_domain, metadata) - + # Decode and verify ID token with signature verification enabled try: # Get the signing key from JWKS unverified_header = jwt.get_unverified_header(id_token) kid = unverified_header.get("kid") - + # Find the key with matching kid signing_key = None for key in jwks.get("keys", []): if key.get("kid") == kid: signing_key = jwt.PyJWK.from_dict(key) break - + if not signing_key: raise ApiError( "jwks_key_not_found", f"No matching key found in JWKS for kid: {kid}" ) - + claims = jwt.decode( id_token, signing_key.key, @@ -430,7 +430,7 @@ async def complete_interactive_login( f"ID token verification failed: {str(e)}", e ) - + # Build a token set using the token response data token_set = TokenSet( @@ -708,11 +708,11 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti original_error=e ) session_domain = getattr(state_data, 'domain', None) - + if session_domain and session_domain != current_domain: # Session created with different domain - reject for security return None - + if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() return state_data.get("user") @@ -745,11 +745,11 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O original_error=e ) session_domain = getattr(state_data, 'domain', None) - + if session_domain and session_domain != current_domain: # Session created with different domain - reject for security return None - + if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() session_data = {k: v for k, v in state_data.items() @@ -793,7 +793,7 @@ async def get_access_token( original_error=e ) session_domain = getattr(state_data, 'domain', None) - + if session_domain and session_domain != current_domain: # Session created with different domain - reject for security raise AccessTokenError( @@ -832,7 +832,7 @@ async def get_access_token( # Get new token with refresh token try: - # Use session's domain for token refresh + # Use session's domain for token refresh session_domain = state_data_dict.get("domain") or self._domain get_refresh_token_options = { "refresh_token": state_data_dict["refresh_token"], @@ -863,7 +863,7 @@ async def get_access_token( f"Failed to get token with refresh token: {str(e)}" ) - + async def get_access_token_for_connection( self, @@ -987,25 +987,25 @@ async def handle_backchannel_logout( try: # Fetch JWKS for signature verification (Requirement 3) jwks = await self._get_jwks_cached(self._domain) - + # Decode and verify logout token with signature verification enabled try: # Get the signing key from JWKS unverified_header = jwt.get_unverified_header(logout_token) kid = unverified_header.get("kid") - + # Find the key with matching kid signing_key = None for key in jwks.get("keys", []): if key.get("kid") == kid: signing_key = jwt.PyJWK.from_dict(key) break - + if not signing_key: raise BackchannelLogoutError( f"No matching key found in JWKS for kid: {kid}" ) - + claims = jwt.decode( logout_token, signing_key.key, @@ -1071,47 +1071,47 @@ async def _fetch_oidc_metadata(self, domain: str) -> dict: async def _get_oidc_metadata_cached(self, domain: str) -> dict: """ Get OIDC metadata with caching. - + Args: domain: Auth0 domain - + Returns: OIDC metadata document """ now = time.time() - + # Check cache if domain in self._metadata_cache: cached = self._metadata_cache[domain] if cached["expires_at"] > now: return cached["data"] - + # Cache miss/expired - fetch fresh metadata = await self._fetch_oidc_metadata(domain) - + # Enforce cache size limit (FIFO eviction) if len(self._metadata_cache) >= self._cache_max_size: oldest_key = next(iter(self._metadata_cache)) del self._metadata_cache[oldest_key] - + # Store in cache self._metadata_cache[domain] = { "data": metadata, "expires_at": now + self._cache_ttl } - + return metadata async def _fetch_jwks(self, jwks_uri: str) -> dict: """ Fetch JWKS (JSON Web Key Set) from jwks_uri. - + Args: jwks_uri: The JWKS endpoint URL - + Returns: JWKS document containing public keys - + Raises: ApiError: If JWKS fetch fails """ @@ -1126,54 +1126,54 @@ async def _fetch_jwks(self, jwks_uri: str) -> dict: async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: """ Get JWKS with caching usingOIDC discovery. - + Args: domain: Auth0 domain metadata: Optional OIDC metadata (if already fetched) - + Returns: JWKS document - + Raises: ApiError: If JWKS fetch fails or jwks_uri missing from metadata """ now = time.time() - + # Check cache if domain in self._jwks_cache: cached = self._jwks_cache[domain] if cached["expires_at"] > now: return cached["data"] - + # Get jwks_uri from OIDC metadata if not metadata: metadata = await self._get_oidc_metadata_cached(domain) - + jwks_uri = metadata.get('jwks_uri') if not jwks_uri: raise ApiError( "missing_jwks_uri", f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant." ) - + # Fetch JWKS jwks = await self._fetch_jwks(jwks_uri) - + # Enforce cache size limit (FIFO eviction) if len(self._jwks_cache) >= self._cache_max_size: oldest_key = next(iter(self._jwks_cache)) del self._jwks_cache[oldest_key] - + # Store in cache self._jwks_cache[domain] = { "data": jwks, "expires_at": now + self._cache_ttl } - + return jwks # ------------------------------------------ - # Token & Scope Management - MRRT + # Token & Scope Management - MRRT # ------------------------------------------ def _merge_scope_with_defaults( @@ -1606,7 +1606,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, try: # Use session domain if provided, otherwise fallback to static domain domain = options.get("domain") or self._domain - + # Ensure we have the OIDC metadata from the correct domain if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: self._oauth.metadata = await self._get_oidc_metadata_cached(domain) @@ -1692,7 +1692,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A try: # Use session domain if provided, otherwise fallback to static domain domain = options.get("domain") or self._domain - + # Ensure we have OIDC metadata from the correct domain if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: self._oauth.metadata = await self._get_oidc_metadata_cached(domain) diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 7c27d36..6686a3f 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -260,14 +260,14 @@ class CompleteConnectAccountResponse(BaseModel): class DomainResolverContext(BaseModel): """ Context passed to domain resolver function for MCD support. - + Contains request information needed to determine the correct Auth0 domain based on the incoming request's hostname or headers. - + Attributes: request_url: The full request URL (e.g., "https://a.my-app.com/auth/login") request_headers: Dictionary of request headers (e.g., {"host": "a.my-app.com", "x-forwarded-host": "..."}) - + Example: async def domain_resolver(context: DomainResolverContext) -> str: host = context.request_headers.get('host', '').split(':')[0] diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index 93fcba2..6b863e1 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -128,7 +128,7 @@ def __init__(self, message: str): class DomainResolverError(Auth0Error): """ Error raised when domain resolver function fails or returns invalid value. - + This error indicates an issue with the custom domain resolver function provided for MCD (Multiple Custom Domains) support. """ diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 97fcc77..7a9665c 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -56,14 +56,14 @@ async def test_start_interactive_login_no_redirect_uri(mocker): transaction_store=AsyncMock(), secret="some-secret" ) - + # Mock OIDC metadata fetch mocker.patch.object( client, "_get_oidc_metadata_cached", return_value={"issuer": "https://auth0.local/", "authorization_endpoint": "https://auth0.local/authorize"} ) - + with pytest.raises(MissingRequiredArgumentError) as exc: await client.start_interactive_login() # Check the error message @@ -130,7 +130,7 @@ async def test_complete_interactive_login_returns_app_state(mocker): mock_tx_store = AsyncMock() # The stored transaction includes an appState with origin_domain and origin_issuer mock_tx_store.get.return_value = TransactionData( - code_verifier="123", + code_verifier="123", app_state={"foo": "bar"}, origin_domain="auth0.local", origin_issuer="https://auth0.local/" @@ -1990,7 +1990,7 @@ async def test_domain_as_static_string(): client_secret="test_client_secret", secret="test_secret_key_32_chars_long!!" ) - + assert client._domain == "tenant.auth0.com" assert client._domain_resolver is None @@ -2000,14 +2000,14 @@ async def test_domain_as_callable_function(): """Test Method 2: Domain resolver function configuration.""" async def domain_resolver(store_options): return "tenant.auth0.com" - + client = ServerClient( domain=domain_resolver, client_id="test_client_id", client_secret="test_client_secret", secret="test_secret_key_32_chars_long!!" ) - + assert client._domain is None assert client._domain_resolver == domain_resolver @@ -2056,12 +2056,12 @@ async def test_empty_domain_string(): async def test_domain_resolver_receives_context(mocker): """Test that domain resolver receives DomainResolverContext with request data.""" received_context = None - + async def domain_resolver(context): nonlocal received_context received_context = context return "tenant.auth0.com" - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2070,24 +2070,24 @@ async def domain_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + # Mock request with headers mock_request = MagicMock() mock_request.url = "https://a.my-app.com/auth/login" mock_request.headers = {"host": "a.my-app.com", "x-forwarded-host": "a.my-app.com"} - + # Mock OIDC metadata fetch mocker.patch.object( client, "_get_oidc_metadata_cached", return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} ) - + try: await client.start_interactive_login(store_options={"request": mock_request}) - except: + except Exception: # noqa: S110 pass # We only care about context being passed - + assert received_context is not None assert isinstance(received_context, DomainResolverContext) assert received_context.request_url == "https://a.my-app.com/auth/login" @@ -2099,7 +2099,7 @@ async def test_domain_resolver_error_on_none(): """Test that domain resolver returning None raises DomainResolverError.""" async def bad_resolver(context): return None - + client = ServerClient( domain=bad_resolver, client_id="test_client", @@ -2108,7 +2108,7 @@ async def bad_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + with pytest.raises(DomainResolverError, match="returned None"): await client.start_interactive_login(store_options={"request": MagicMock()}) @@ -2118,7 +2118,7 @@ async def test_domain_resolver_error_on_empty_string(): """Test that domain resolver returning empty string raises DomainResolverError.""" async def bad_resolver(context): return "" - + client = ServerClient( domain=bad_resolver, client_id="test_client", @@ -2127,7 +2127,7 @@ async def bad_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + with pytest.raises(DomainResolverError, match="empty string"): await client.start_interactive_login(store_options={"request": MagicMock()}) @@ -2137,7 +2137,7 @@ async def test_domain_resolver_error_on_exception(): """Test that domain resolver exceptions are wrapped in DomainResolverError.""" async def bad_resolver(context): raise ValueError("Something went wrong") - + client = ServerClient( domain=bad_resolver, client_id="test_client", @@ -2146,7 +2146,7 @@ async def bad_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + with pytest.raises(DomainResolverError, match="raised an exception"): await client.start_interactive_login(store_options={"request": MagicMock()}) @@ -2155,12 +2155,12 @@ async def bad_resolver(context): async def test_domain_resolver_with_no_request(mocker): """Test that domain resolver works with empty context when no request.""" received_context = None - + async def domain_resolver(context): nonlocal received_context received_context = context return "tenant.auth0.com" - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2169,18 +2169,18 @@ async def domain_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + mocker.patch.object( client, "_get_oidc_metadata_cached", return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} ) - + try: await client.start_interactive_login(store_options=None) - except: - pass - + except Exception: # noqa: S110 + pass # Intentionally ignore - testing context only + assert received_context is not None assert received_context.request_url is None assert received_context.request_headers is None @@ -2191,7 +2191,7 @@ async def test_domain_resolver_error_on_non_string_type(): """Test that domain resolver returning non-string raises DomainResolverError.""" async def bad_resolver(context): return 12345 - + client = ServerClient( domain=bad_resolver, client_id="test_client", @@ -2200,7 +2200,7 @@ async def bad_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + with pytest.raises(DomainResolverError, match="must return a string"): await client.start_interactive_login(store_options={"request": MagicMock()}) @@ -2221,7 +2221,7 @@ async def test_fetch_jwks_success(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + mock_jwks = { "keys": [ { @@ -2233,20 +2233,20 @@ async def test_fetch_jwks_success(): } ] } - + # Mock httpx client mock_response = MagicMock() mock_response.json.return_value = mock_jwks mock_response.raise_for_status = MagicMock() - + mock_client = AsyncMock() mock_client.__aenter__.return_value = mock_client mock_client.__aexit__.return_value = None mock_client.get.return_value = mock_response - + with patch('httpx.AsyncClient', return_value=mock_client): jwks = await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") - + assert jwks == mock_jwks assert "keys" in jwks mock_client.get.assert_awaited_once_with("https://tenant.auth0.com/.well-known/jwks.json") @@ -2263,13 +2263,13 @@ async def test_fetch_jwks_failure(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + # Mock httpx client to raise exception mock_client = AsyncMock() mock_client.__aenter__.return_value = mock_client mock_client.__aexit__.return_value = None mock_client.get.side_effect = Exception("Network error") - + with patch('httpx.AsyncClient', return_value=mock_client): with pytest.raises(ApiError, match="Failed to fetch JWKS"): await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") @@ -2286,33 +2286,33 @@ async def test_oidc_metadata_caching(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + mock_metadata = { "issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize", "token_endpoint": "https://tenant.auth0.com/oauth/token", "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" } - + # Mock _fetch_oidc_metadata to track calls fetch_count = 0 async def mock_fetch(domain): nonlocal fetch_count fetch_count += 1 return mock_metadata - + client._fetch_oidc_metadata = mock_fetch - + # First call - should fetch result1 = await client._get_oidc_metadata_cached("tenant.auth0.com") assert result1 == mock_metadata assert fetch_count == 1 - + # Second call - should use cache result2 = await client._get_oidc_metadata_cached("tenant.auth0.com") assert result2 == mock_metadata assert fetch_count == 1 # Should NOT increment - + # Verify cache contains data assert "tenant.auth0.com" in client._metadata_cache assert client._metadata_cache["tenant.auth0.com"]["data"] == mock_metadata @@ -2329,30 +2329,30 @@ async def test_oidc_metadata_cache_expiration(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + # Set short TTL for testing client._cache_ttl = 1 # 1 second - + mock_metadata = { "issuer": "https://tenant.auth0.com/", "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" } - + fetch_count = 0 async def mock_fetch(domain): nonlocal fetch_count fetch_count += 1 return mock_metadata - + client._fetch_oidc_metadata = mock_fetch - + # First call await client._get_oidc_metadata_cached("tenant.auth0.com") assert fetch_count == 1 - + # Wait for cache to expire time.sleep(1.1) - + # Second call after expiration - should fetch again await client._get_oidc_metadata_cached("tenant.auth0.com") assert fetch_count == 2 @@ -2369,32 +2369,32 @@ async def test_jwks_caching(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + mock_metadata = { "issuer": "https://tenant.auth0.com/", "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" } - + mock_jwks = { "keys": [{"kty": "RSA", "kid": "key1"}] } - + # Mock the fetch methods client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata) - + fetch_count = 0 async def mock_fetch_jwks(uri): nonlocal fetch_count fetch_count += 1 return mock_jwks - + client._fetch_jwks = mock_fetch_jwks - + # First call - should fetch result1 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) assert result1 == mock_jwks assert fetch_count == 1 - + # Second call - should use cache result2 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) assert result2 == mock_jwks @@ -2412,33 +2412,33 @@ async def test_jwks_cache_size_limit(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + # Set small cache size for testing client._cache_max_size = 3 - + mock_jwks = {"keys": [{"kty": "RSA"}]} - + # Mock methods async def mock_fetch_metadata(domain): return {"jwks_uri": f"https://{domain}/.well-known/jwks.json"} - + async def mock_fetch_jwks(uri): return mock_jwks - + client._fetch_oidc_metadata = mock_fetch_metadata client._fetch_jwks = mock_fetch_jwks - + # Fill cache to limit await client._get_jwks_cached("domain1.auth0.com") await client._get_jwks_cached("domain2.auth0.com") await client._get_jwks_cached("domain3.auth0.com") - + assert len(client._jwks_cache) == 3 assert "domain1.auth0.com" in client._jwks_cache - + # Add one more - should evict oldest (domain1) await client._get_jwks_cached("domain4.auth0.com") - + assert len(client._jwks_cache) == 3 assert "domain1.auth0.com" not in client._jwks_cache # Evicted assert "domain4.auth0.com" in client._jwks_cache @@ -2455,20 +2455,20 @@ async def test_jwks_missing_uri_raises_error(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + # Metadata WITHOUT jwks_uri mock_metadata_no_jwks_uri = { "issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize" # No jwks_uri } - + client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata_no_jwks_uri) - + # Should raise ApiError when jwks_uri is missing with pytest.raises(ApiError) as exc_info: await client._get_jwks_cached("tenant.auth0.com") - + assert exc_info.value.code == "missing_jwks_uri" assert "non-RFC-compliant" in str(exc_info.value) @@ -2484,23 +2484,23 @@ async def test_metadata_cache_size_limit(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + client._cache_max_size = 2 - + async def mock_fetch(domain): return {"issuer": f"https://{domain}/"} - + client._fetch_oidc_metadata = mock_fetch - + # Fill cache await client._get_oidc_metadata_cached("domain1.auth0.com") await client._get_oidc_metadata_cached("domain2.auth0.com") - + assert len(client._metadata_cache) == 2 - + # Add third - should evict first await client._get_oidc_metadata_cached("domain3.auth0.com") - + assert len(client._metadata_cache) == 2 assert "domain1.auth0.com" not in client._metadata_cache assert "domain3.auth0.com" in client._metadata_cache @@ -2557,12 +2557,12 @@ async def test_complete_login_issuer_validation_success(mocker): # Mock jwt.get_unverified_header mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) - + # Mock PyJWK.from_dict mock_signing_key = mocker.MagicMock() mock_signing_key.key = "mock_pem_key" mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) - + # Mock jwt.decode with valid issuer mocker.patch("jwt.decode", return_value={ "sub": "user123", @@ -2572,7 +2572,7 @@ async def test_complete_login_issuer_validation_success(mocker): # Should succeed without raising error result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") - + assert result is not None assert "state_data" in result @@ -2623,19 +2623,19 @@ async def test_complete_login_issuer_mismatch_raises_error(mocker): # Mock jwt.get_unverified_header mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) - + # Mock PyJWK.from_dict mock_signing_key = mocker.MagicMock() mock_signing_key.key = "mock_pem_key" mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) - + # Mock jwt.decode to raise InvalidIssuerError mocker.patch("jwt.decode", side_effect=jwt.InvalidIssuerError("Invalid issuer")) # Should raise ApiError with invalid_issuer code with pytest.raises(ApiError) as exc_info: await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") - + assert exc_info.value.code == "invalid_issuer" assert "issuer mismatch" in str(exc_info.value).lower() @@ -2654,13 +2654,13 @@ async def test_normalize_domain_handles_different_schemes(): # Test domain without scheme assert client._normalize_domain("auth0.com") == "https://auth0.com" - + # Test domain with https scheme (should remain unchanged) assert client._normalize_domain("https://auth0.com") == "https://auth0.com" - + # Test domain with http scheme (should convert to https) assert client._normalize_domain("http://auth0.com") == "https://auth0.com" - + # Test domain with trailing slash assert client._normalize_domain("https://auth0.com/") == "https://auth0.com/" @@ -2684,7 +2684,7 @@ async def test_session_stores_origin_domain(mocker): async def capture_state(identifier, state_data, options=None): nonlocal captured_state captured_state = state_data - + mock_state_store = AsyncMock() mock_state_store.set = AsyncMock(side_effect=capture_state) @@ -2702,14 +2702,14 @@ async def capture_state(identifier, state_data, options=None): "token_endpoint": "https://tenant1.auth0.com/token" }) mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) - + async_fetch_token = AsyncMock(return_value={ "access_token": "token123", "id_token": "id_token_jwt", "scope": "openid" }) mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) - + # Mock JWT verification mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) mock_signing_key = mocker.MagicMock() @@ -2718,7 +2718,7 @@ async def capture_state(identifier, state_data, options=None): mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": "https://tenant1.auth0.com/"}) await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") - + # Verify session has domain field set assert captured_state.domain == "tenant1.auth0.com" @@ -2733,14 +2733,14 @@ async def test_cross_domain_session_rejected(): token_sets=[], internal={"sid": "123", "created_at": int(time.time())} ) - + mock_state_store = AsyncMock() mock_state_store.get.return_value = session_data - + # Domain resolver returns domain2 (different from session) async def domain_resolver(context): return "tenant2.auth0.com" - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2759,12 +2759,12 @@ async def domain_resolver(context): async def test_logout_uses_current_domain(mocker): """Test that logout uses current resolved domain (Requirement 7).""" current_domain = "tenant2.auth0.com" - + async def domain_resolver(context): return current_domain - + mock_state_store = AsyncMock() - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2775,7 +2775,7 @@ async def domain_resolver(context): ) logout_url = await client.logout(store_options={"request": {}}) - + # Verify logout URL uses current domain assert current_domain in logout_url assert logout_url.startswith(f"https://{current_domain}") @@ -2785,7 +2785,7 @@ async def domain_resolver(context): async def test_logout_clears_session_for_current_domain(): """Test that logout clears session (Requirement 7).""" mock_state_store = AsyncMock() - + client = ServerClient( domain="tenant.auth0.com", client_id="test_client", @@ -2796,7 +2796,7 @@ async def test_logout_clears_session_for_current_domain(): ) await client.logout() - + # Verify session was deleted mock_state_store.delete.assert_called_once() @@ -2805,7 +2805,7 @@ async def test_logout_clears_session_for_current_domain(): async def test_domain_migration_old_sessions_remain_valid(): """Test that old sessions remain valid with old domain requests (Requirement 8).""" old_domain = "old-tenant.auth0.com" - + # Session from old domain session_data = StateData( user={"sub": "user123"}, @@ -2813,14 +2813,14 @@ async def test_domain_migration_old_sessions_remain_valid(): token_sets=[], internal={"sid": "123", "created_at": int(time.time())} ) - + mock_state_store = AsyncMock() mock_state_store.get.return_value = session_data - + # Domain resolver returns old domain async def domain_resolver(context): return old_domain - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2840,7 +2840,7 @@ async def domain_resolver(context): async def test_domain_migration_new_sessions_use_new_domain(mocker): """Test that new logins create sessions with new domain (Requirement 8).""" new_domain = "new-tenant.auth0.com" - + mock_tx_store = AsyncMock() mock_tx_store.get.return_value = TransactionData( code_verifier="123", @@ -2852,7 +2852,7 @@ async def test_domain_migration_new_sessions_use_new_domain(mocker): async def capture_state(identifier, state_data, options=None): nonlocal captured_state captured_state = state_data - + mock_state_store = AsyncMock() mock_state_store.set = AsyncMock(side_effect=capture_state) @@ -2870,14 +2870,14 @@ async def capture_state(identifier, state_data, options=None): "token_endpoint": f"https://{new_domain}/token" }) mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) - + async_fetch_token = AsyncMock(return_value={ "access_token": "token123", "id_token": "id_token_jwt", "scope": "openid" }) mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) - + # Mock JWT verification mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) mock_signing_key = mocker.MagicMock() @@ -2886,7 +2886,7 @@ async def capture_state(identifier, state_data, options=None): mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": f"https://{new_domain}/"}) await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") - + # Verify new session has new domain assert captured_state.domain == new_domain @@ -2896,7 +2896,7 @@ async def test_domain_migration_sessions_isolated(): """Test that old domain sessions cannot be used with new domain (Requirement 8).""" old_domain = "old-tenant.auth0.com" new_domain = "new-tenant.auth0.com" - + # Session from old domain session_data = StateData( user={"sub": "user123"}, @@ -2904,14 +2904,14 @@ async def test_domain_migration_sessions_isolated(): token_sets=[], internal={"sid": "123", "created_at": int(time.time())} ) - + mock_state_store = AsyncMock() mock_state_store.get.return_value = session_data - + # Domain resolver returns NEW domain (migration happened) async def domain_resolver(context): return new_domain - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2923,4 +2923,4 @@ async def domain_resolver(context): # Should reject old session user = await client.get_user(store_options={"request": {}}) - assert user is None \ No newline at end of file + assert user is None diff --git a/src/auth0_server_python/utils/helpers.py b/src/auth0_server_python/utils/helpers.py index 1a49835..05cb0f8 100644 --- a/src/auth0_server_python/utils/helpers.py +++ b/src/auth0_server_python/utils/helpers.py @@ -5,6 +5,7 @@ import time from typing import Any, Optional from urllib.parse import parse_qs, urlencode, urlparse + from auth0_server_python.auth_types import DomainResolverContext from auth0_server_python.error import DomainResolverError @@ -235,27 +236,27 @@ def create_logout_url(domain: str, client_id: str, return_to: Optional[str] = No def build_domain_resolver_context(store_options: Optional[dict[str, Any]]) -> 'DomainResolverContext': """ Build DomainResolverContext from store_options. - + Extracts request information in a framework-agnostic way using duck typing. - + Args: store_options: Dictionary containing 'request' and 'response' objects - + Returns: DomainResolverContext with extracted request data """ - + if not store_options: return DomainResolverContext() - + request = store_options.get('request') if not request: return DomainResolverContext() - + # Framework-agnostic extraction using duck typing request_url = str(request.url) if hasattr(request, 'url') else None request_headers = dict(request.headers) if hasattr(request, 'headers') else None - + return DomainResolverContext( request_url=request_url, request_headers=request_headers @@ -265,30 +266,30 @@ def build_domain_resolver_context(store_options: Optional[dict[str, Any]]) -> 'D def validate_resolved_domain_value(domain_value: Any) -> str: """ Validate the value returned by domain resolver. - + Args: domain_value: The value returned by the domain resolver - + Returns: The validated domain string - + Raises: DomainResolverError: If the returned value is invalid """ - + if domain_value is None: raise DomainResolverError( "Domain resolver returned None. Must return a valid domain string." ) - + if not isinstance(domain_value, str): raise DomainResolverError( f"Domain resolver must return a string. Got {type(domain_value).__name__} instead." ) - + if not domain_value.strip(): raise DomainResolverError( "Domain resolver returned an empty string. Must return a valid domain." ) - + return domain_value From 307000b8ae54220190f9b8df3f24ab15e961f1c7 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Tue, 3 Feb 2026 00:03:18 +0530 Subject: [PATCH 4/6] test: improve cache verification in OIDC metadata and JWKS tests --- src/auth0_server_python/tests/test_server_client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 7a9665c..62532fb 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -2307,11 +2307,12 @@ async def mock_fetch(domain): result1 = await client._get_oidc_metadata_cached("tenant.auth0.com") assert result1 == mock_metadata assert fetch_count == 1 + first_fetch_count = fetch_count # Second call - should use cache result2 = await client._get_oidc_metadata_cached("tenant.auth0.com") assert result2 == mock_metadata - assert fetch_count == 1 # Should NOT increment + assert fetch_count == first_fetch_count # Should NOT increment # Verify cache contains data assert "tenant.auth0.com" in client._metadata_cache @@ -2394,11 +2395,12 @@ async def mock_fetch_jwks(uri): result1 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) assert result1 == mock_jwks assert fetch_count == 1 + first_fetch_count = fetch_count # Second call - should use cache result2 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) assert result2 == mock_jwks - assert fetch_count == 1 # Should NOT increment + assert fetch_count == first_fetch_count # Should NOT increment @pytest.mark.asyncio From 3f68c3c6cfdb1176c295ce6fe89f4e841f885494 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Tue, 3 Feb 2026 23:24:48 +0530 Subject: [PATCH 5/6] refactor: rename cache size variable and reorganize test comments --- examples/{MCD.md => MultipleCustomDomains.md} | 0 .../auth_server/server_client.py | 6 +++--- .../tests/test_server_client.py | 17 ++++++++--------- 3 files changed, 11 insertions(+), 12 deletions(-) rename examples/{MCD.md => MultipleCustomDomains.md} (100%) diff --git a/examples/MCD.md b/examples/MultipleCustomDomains.md similarity index 100% rename from examples/MCD.md rename to examples/MultipleCustomDomains.md diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 56cd636..f3940ab 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -145,7 +145,7 @@ def __init__( self._metadata_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} self._jwks_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} self._cache_ttl = 3600 # 1 hour TTL - self._cache_max_size = 100 # Max 100 domains to prevent memory bloat + self._cache_max_entries = 100 # Max 100 domains to prevent memory bloat # ========================================== # Interactive Login Flow @@ -1090,7 +1090,7 @@ async def _get_oidc_metadata_cached(self, domain: str) -> dict: metadata = await self._fetch_oidc_metadata(domain) # Enforce cache size limit (FIFO eviction) - if len(self._metadata_cache) >= self._cache_max_size: + if len(self._metadata_cache) >= self._cache_max_entries: oldest_key = next(iter(self._metadata_cache)) del self._metadata_cache[oldest_key] @@ -1160,7 +1160,7 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: jwks = await self._fetch_jwks(jwks_uri) # Enforce cache size limit (FIFO eviction) - if len(self._jwks_cache) >= self._cache_max_size: + if len(self._jwks_cache) >= self._cache_max_entries: oldest_key = next(iter(self._jwks_cache)) del self._jwks_cache[oldest_key] diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 62532fb..6f14b52 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -1978,7 +1978,7 @@ async def test_complete_connect_account_no_transactions(mocker): # ============================================================================= -# Requirement 1: Multiple Issuer Configuration Methods Tests +# MCD Tests : Multiple Issuer Configuration Methods Tests # ============================================================================= @pytest.mark.asyncio @@ -2049,7 +2049,7 @@ async def test_empty_domain_string(): # ============================================================================= -# Requirement 2: Domain Resolver Context Tests +# MCD Tests : Domain Resolver Context Tests # ============================================================================= @pytest.mark.asyncio @@ -2179,8 +2179,7 @@ async def domain_resolver(context): try: await client.start_interactive_login(store_options=None) except Exception: # noqa: S110 - pass # Intentionally ignore - testing context only - + pass # We only care about context being passed assert received_context is not None assert received_context.request_url is None assert received_context.request_headers is None @@ -2206,7 +2205,7 @@ async def bad_resolver(context): # ============================================================================= -# Requirement 3: OIDC Metadata and JWKS Fetching Tests +# OIDC Metadata and JWKS Fetching Tests # ============================================================================= @@ -2416,7 +2415,7 @@ async def test_jwks_cache_size_limit(): ) # Set small cache size for testing - client._cache_max_size = 3 + client._cache_max_entries = 3 mock_jwks = {"keys": [{"kty": "RSA"}]} @@ -2487,7 +2486,7 @@ async def test_metadata_cache_size_limit(): state_store=AsyncMock() ) - client._cache_max_size = 2 + client._cache_max_entries = 2 async def mock_fetch(domain): return {"issuer": f"https://{domain}/"} @@ -2509,7 +2508,7 @@ async def mock_fetch(domain): # ============================================================================= -# Requirement 4: Issuer Validation Tests +# Issuer Validation Tests # ============================================================================= @@ -2668,7 +2667,7 @@ async def test_normalize_domain_handles_different_schemes(): # ============================================================================= -# Requirements 5-8: Domain-specific Session Management Tests +# MCD Tests : Domain-specific Session Management Tests # ============================================================================= From 2e0ae17b711fb7dec48f773a51ad69639148cf95 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Tue, 3 Feb 2026 23:27:58 +0530 Subject: [PATCH 6/6] chore: add cryptography package to Snyk license ignore list --- .snyk | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.snyk b/.snyk index 4eaa56f..7d0fc1c 100644 --- a/.snyk +++ b/.snyk @@ -21,4 +21,8 @@ ignore: - '*': reason: "Accepting the Unknown license for now" expires: "2030-12-31T23:59:59Z" + "snyk:lic:pip:cryptography:Unknown": + - '*': + reason: "Accepting the Unknown license for now" + expires: "2030-12-31T23:59:59Z" patch: {}