diff --git a/packages/google-auth/google/auth/_credentials_async.py b/packages/google-auth/google/auth/_credentials_async.py index 760758d851b0..937f6e8fb6df 100644 --- a/packages/google-auth/google/auth/_credentials_async.py +++ b/packages/google-auth/google/auth/_credentials_async.py @@ -18,6 +18,7 @@ import abc import inspect +from google.auth import _regional_access_boundary_utils from google.auth import credentials @@ -64,8 +65,28 @@ async def before_request(self, request, method, url, headers): await self.refresh(request) else: self.refresh(request) + + if inspect.iscoroutinefunction(self._after_refresh): + await self._after_refresh(request, method, url, headers) + else: + self._after_refresh(request, method, url, headers) + self.apply(headers) + def _after_refresh(self, request, method, url, headers): + """Hook for subclasses to perform actions after refresh but before + applying credentials to headers. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping[str, str]): The request's headers. + """ + pass + class CredentialsWithQuotaProject(credentials.CredentialsWithQuotaProject): """Abstract base for credentials supporting ``with_quota_project`` factory""" @@ -169,3 +190,74 @@ def with_scopes_if_required(credentials, scopes): class Signing(credentials.Signing, metaclass=abc.ABCMeta): """Interface for credentials that can cryptographically sign messages.""" + + +class CredentialsWithRegionalAccessBoundary( + Credentials, credentials.CredentialsWithRegionalAccessBoundary +): + """Async base for credentials supporting regional access boundary configuration.""" + + def __init__(self): + super().__init__() + self._rab_manager.refresh_manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + def __setstate__(self, state): + super().__setstate__(state) + self._rab_manager.refresh_manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + async def _after_refresh(self, request, method, url, headers): + """Triggers the Regional Access Boundary lookup asynchronously if necessary.""" + await self._maybe_start_regional_access_boundary_refresh_async(request, url) + + async def _maybe_start_regional_access_boundary_refresh_async(self, request, url): + """Starts a background refresh or performs a blocking refresh asynchronously. + + Args: + request (google.auth.aio.transport.Request): The object used to make + HTTP requests. + url (str): The URL of the request. + """ + # Do not perform a lookup if the request is for a regional endpoint. + if self._is_regional_endpoint(url): + return + + # A refresh is only needed if the feature is enabled. + if not self._is_regional_access_boundary_lookup_required(): + return + + # Trigger background or blocking refresh if needed. + await self._rab_manager.maybe_start_refresh_async(self, request) + + async def _lookup_regional_access_boundary(self, request, fail_fast=False): + """Calls the Regional Access Boundary lookup API asynchronously. + + Args: + request (google.auth.aio.transport.Request): The object used to make + HTTP requests. + fail_fast (bool): Whether the lookup should fail fast (short timeout, no retries). + + Returns: + Optional[Dict[str, str]]: The Regional Access Boundary information + returned by the lookup API, or None if the lookup failed. + """ + url_builder = self._build_regional_access_boundary_lookup_url + if inspect.iscoroutinefunction(url_builder): + url = await url_builder(request=request) + else: + url = url_builder(request=request) + + if not url: + return None + + headers = {} + self._apply(headers) + + from google.oauth2 import _client_async + + return await _client_async._lookup_regional_access_boundary( + request, url, headers=headers, fail_fast=fail_fast + ) diff --git a/packages/google-auth/google/auth/_helpers.py b/packages/google-auth/google/auth/_helpers.py index 08146221503e..86c48c1e525c 100644 --- a/packages/google-auth/google/auth/_helpers.py +++ b/packages/google-auth/google/auth/_helpers.py @@ -28,6 +28,8 @@ from google.auth import exceptions +DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" + # _BASE_LOGGER_NAME is the base logger for all google-based loggers. _BASE_LOGGER_NAME = "google" diff --git a/packages/google-auth/google/auth/_jwt_async.py b/packages/google-auth/google/auth/_jwt_async.py index 3a1abc5b85c9..ce3bfe4eba57 100644 --- a/packages/google-auth/google/auth/_jwt_async.py +++ b/packages/google-auth/google/auth/_jwt_async.py @@ -44,6 +44,8 @@ """ from google.auth import _credentials_async +from google.auth import _helpers +from google.auth import _regional_access_boundary_utils from google.auth import jwt @@ -91,7 +93,9 @@ def decode(token, certs=None, verify=True, audience=None): class Credentials( - jwt.Credentials, _credentials_async.Signing, _credentials_async.Credentials + jwt.Credentials, + _credentials_async.Signing, + _credentials_async.CredentialsWithRegionalAccessBoundary, ): """Credentials that use a JWT as the bearer token. @@ -142,6 +146,14 @@ class Credentials( new_credentials = credentials.with_claims(audience=new_audience) """ + def __setstate__(self, state): + """Restores the credential state and ensures the async refresh manager is attached.""" + super().__setstate__(state) + + self._rab_manager.refresh_manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + class OnDemandCredentials( jwt.OnDemandCredentials, _credentials_async.Signing, _credentials_async.Credentials @@ -162,3 +174,7 @@ class OnDemandCredentials( .. _grpc: http://www.grpc.io/ """ + + @_helpers.copy_docstring(jwt.OnDemandCredentials) + async def before_request(self, request, method, url, headers): + super(OnDemandCredentials, self).before_request(request, method, url, headers) diff --git a/packages/google-auth/google/auth/_regional_access_boundary_utils.py b/packages/google-auth/google/auth/_regional_access_boundary_utils.py index 81011911df3d..1d6f24d6d9be 100644 --- a/packages/google-auth/google/auth/_regional_access_boundary_utils.py +++ b/packages/google-auth/google/auth/_regional_access_boundary_utils.py @@ -14,9 +14,11 @@ """Utilities for Regional Access Boundary management.""" +import asyncio import copy import datetime import functools +import inspect import logging import os import threading @@ -170,12 +172,11 @@ def apply_headers(self, headers): else: headers.pop(_REGIONAL_ACCESS_BOUNDARY_HEADER, None) - def maybe_start_refresh(self, credentials, request): - """Starts a background thread to refresh the Regional Access Boundary if needed. + def _should_refresh(self): + """Checks if the Regional Access Boundary data needs a refresh and is not in cooldown. - Args: - credentials (google.auth.credentials.Credentials): The credentials to refresh. - request (google.auth.transport.Request): The object used to make HTTP requests. + Returns: + bool: True if a refresh is required, False otherwise. """ rab_data = self._data @@ -186,10 +187,22 @@ def maybe_start_refresh(self, credentials, request): and _helpers.utcnow() < (rab_data.expiry - REGIONAL_ACCESS_BOUNDARY_REFRESH_THRESHOLD) ): - return + return False # Don't start a new refresh if the cooldown is still in effect. if rab_data.cooldown_expiry and _helpers.utcnow() < rab_data.cooldown_expiry: + return False + + return True + + def maybe_start_refresh(self, credentials, request): + """Starts a background thread to refresh the Regional Access Boundary if needed. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to refresh. + request (google.auth.transport.Request): The object used to make HTTP requests. + """ + if not self._should_refresh(): return # If all checks pass, start the background refresh. @@ -198,6 +211,22 @@ def maybe_start_refresh(self, credentials, request): else: self.refresh_manager.start_refresh(credentials, request, self) + async def maybe_start_refresh_async(self, credentials, request): + """Starts a background refresh or performs a blocking refresh asynchronously. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to refresh. + request (google.auth.aio.transport.Request): The object used to make HTTP requests. + """ + if not self._should_refresh(): + return + + # If all checks pass, start the refresh. + if self._use_blocking_regional_access_boundary_lookup: + await self.start_blocking_refresh_async(credentials, request) + else: + self.refresh_manager.start_refresh(credentials, request, self) + def start_blocking_refresh(self, credentials, request): """Initiates a blocking lookup of the Regional Access Boundary. @@ -209,6 +238,15 @@ def start_blocking_refresh(self, credentials, request): credentials (google.auth.credentials.Credentials): The credentials to refresh. request (google.auth.transport.Request): The object used to make HTTP requests. """ + # Async credentials do not support blocking lookups. + if inspect.iscoroutinefunction(credentials._lookup_regional_access_boundary): + if _helpers.is_logging_enabled(_LOGGER): + _LOGGER.warning( + "Blocking Regional Access Boundary lookup is not supported for async credentials." + ) + self.process_regional_access_boundary_info(None) + return + try: # The fail_fast parameter is set to True to ensure we don't block the calling # thread for too long. This will do two things: 1) set a timeout to 3s @@ -227,6 +265,37 @@ def start_blocking_refresh(self, credentials, request): self.process_regional_access_boundary_info(regional_access_boundary_info) + async def start_blocking_refresh_async(self, credentials, request): + """Initiates a blocking lookup of the Regional Access Boundary asynchronously. + + If the lookup raises an exception, it is caught and logged as a warning, + and the lookup is treated as a failure (entering cooldown). Exceptions + are not propagated to the caller. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to refresh. + request (google.auth.aio.transport.Request): The object used to make HTTP requests. + """ + try: + # The fail_fast parameter is set to True to ensure we don't block the calling + # thread for too long. This will do two things: 1) set a timeout to 3s + # instead of the default 120s and 2) ensure we do not retry at all + regional_access_boundary_info = ( + await credentials._lookup_regional_access_boundary( + request, fail_fast=True + ) + ) + except Exception as e: + if _helpers.is_logging_enabled(_LOGGER): + _LOGGER.warning( + "Regional Access Boundary lookup raised an exception: %s", + e, + exc_info=True, + ) + regional_access_boundary_info = None + + self.process_regional_access_boundary_info(regional_access_boundary_info) + def process_regional_access_boundary_info(self, regional_access_boundary_info): """Processes the regional access boundary info and updates the state. @@ -384,3 +453,120 @@ def start_refresh(self, credentials, request, rab_manager): credentials, copied_request, rab_manager ) self._worker.start() + + +class _AsyncRegionalAccessBoundaryRefreshManager(object): + """Manages a task for background refreshing of the Regional Access Boundary in async flows.""" + + def __init__(self): + self._lock = threading.Lock() + self._worker_task = None + + def __getstate__(self): + """Pickle helper that excludes the un-picklable _lock and _worker_task attributes from serialization.""" + state = self.__dict__.copy() + state["_lock"] = None + state["_worker_task"] = None + return state + + def __setstate__(self, state): + """Pickle helper that restores state and re-initializes the _lock and _worker_task attributes.""" + self.__dict__.update(state) + self._lock = threading.Lock() + self._worker_task = None + + def start_refresh(self, credentials, request, rab_manager): + """ + Starts a background task to refresh the Regional Access Boundary if one is not already running. + + Args: + credentials (CredentialsWithRegionalAccessBoundary): The credentials + to refresh. + request (google.auth.aio.transport.Request): The object used to make + HTTP requests. + rab_manager (_RegionalAccessBoundaryManager): The manager container to update. + """ + with self._lock: + if self._worker_task and not self._worker_task.done(): + # A refresh is already in progress. + return + + async def _worker(): + try: + # credentials._lookup_regional_access_boundary should be async in the async creds class + regional_access_boundary_info = ( + await credentials._lookup_regional_access_boundary(request) + ) + except Exception as e: + if _helpers.is_logging_enabled(_LOGGER): + _LOGGER.warning( + "Asynchronous Regional Access Boundary lookup raised an exception: %s", + e, + exc_info=True, + ) + regional_access_boundary_info = None + + rab_manager.process_regional_access_boundary_info( + regional_access_boundary_info + ) + + coro = _worker() + try: + self._worker_task = asyncio.create_task(coro) + except Exception: + coro.close() + raise + + +def _get_domain() -> str: + """Dynamically determines the domain for IAM credentials based on active mTLS configuration. + + Returns: + str: The dynamic domain string. + """ + from google.auth.transport import _mtls_helper + + if ( + hasattr(_mtls_helper, "check_use_client_cert") + and _mtls_helper.check_use_client_cert() + ): + return f"iamcredentials.mtls.{_helpers.DEFAULT_UNIVERSE_DOMAIN}" + else: + return f"iamcredentials.{_helpers.DEFAULT_UNIVERSE_DOMAIN}" + + +def get_service_account_rab_endpoint(service_account_email: str) -> str: + """Builds the Regional Access Boundary lookup URL for service accounts. + + Args: + service_account_email: The service account email. + + Returns: + str: The complete lookup URL. + """ + return f"https://{_get_domain()}/v1/projects/-/serviceAccounts/{service_account_email}/allowedLocations" + + +def get_workforce_pool_rab_endpoint(pool_id: str) -> str: + """Builds the Regional Access Boundary lookup URL for workforce pools. + + Args: + pool_id: The workforce pool ID. + + Returns: + str: The complete lookup URL. + """ + return f"https://{_get_domain()}/v1/locations/global/workforcePools/{pool_id}/allowedLocations" + + +def get_workload_identity_pool_rab_endpoint(project_number: str, pool_id: str) -> str: + """Builds the Regional Access Boundary lookup URL for workload identity pools. + + Args: + project_number: The Google Cloud project number. + pool_id: The workload identity pool ID. + + Returns: + str: The complete lookup URL. + """ + return f"https://{_get_domain()}/v1/projects/{project_number}/locations/global/workloadIdentityPools/{pool_id}/allowedLocations" diff --git a/packages/google-auth/google/auth/compute_engine/_metadata.py b/packages/google-auth/google/auth/compute_engine/_metadata.py index aae724ab18ee..f8e1769334d2 100644 --- a/packages/google-auth/google/auth/compute_engine/_metadata.py +++ b/packages/google-auth/google/auth/compute_engine/_metadata.py @@ -22,6 +22,7 @@ import json import logging import os +import re from urllib.parse import urljoin import requests @@ -37,6 +38,8 @@ _LOGGER = logging.getLogger(__name__) +_SERVICE_ACCOUNT_EMAIL_PATTERN = re.compile(r"^[^@]+@[^@]+\.[^@]+$") + _GCE_DEFAULT_MDS_IP = "169.254.169.254" _GCE_DEFAULT_HOST = "metadata.google.internal" _GCE_DEFAULT_MDS_HOSTS = [_GCE_DEFAULT_HOST, _GCE_DEFAULT_MDS_IP] @@ -502,3 +505,20 @@ def get_service_account_token(request, service_account="default", scopes=None): seconds=token_json["expires_in"] ) return token_json["access_token"], token_expiry + + +def _is_service_account_email(email): + """Checks if the provided string is a service account email. + + This is a check that ensures the candidate string is non-empty + and matches a standard email format. + + Args: + email (str): The candidate string to check. + + Returns: + bool: True if the string is non-empty and matches email format, False otherwise. + """ + if not email: + return False + return bool(_SERVICE_ACCOUNT_EMAIL_PATTERN.match(email)) diff --git a/packages/google-auth/google/auth/compute_engine/credentials.py b/packages/google-auth/google/auth/compute_engine/credentials.py index b91e06cf5407..ffe62e0ba9af 100644 --- a/packages/google-auth/google/auth/compute_engine/credentials.py +++ b/packages/google-auth/google/auth/compute_engine/credentials.py @@ -25,6 +25,7 @@ from google.auth import _helpers +from google.auth import _regional_access_boundary_utils from google.auth import credentials from google.auth import exceptions from google.auth import iam @@ -99,6 +100,7 @@ def __init__( self._universe_domain_cached = True self._trust_boundary = trust_boundary + self._rab_disabled = False def _retrieve_info(self, request): """Retrieve information about the service account. @@ -151,6 +153,26 @@ def _perform_refresh_token(self, request): new_exc = exceptions.RefreshError(caught_exc) raise new_exc from caught_exc + def _is_regional_access_boundary_lookup_required(self): + """Checks if a Regional Access Boundary lookup is required. + + Returns: + bool: True if a Regional Access Boundary lookup is required, False otherwise. + """ + if not super()._is_regional_access_boundary_lookup_required(): + return False + + if getattr(self, "_rab_disabled", False): + return False + + # If the field is 'default', the actual value hasn't been fetched from the metadata + # server yet. Allow it to proceed so the actual value can be retrieved and checked + # during the URL construction. + if self.service_account_email == "default": + return True + + return _metadata._is_service_account_email(self.service_account_email) + def _build_regional_access_boundary_lookup_url( self, request: "Optional[google.auth.transport.Request]" = None # noqa: F821 ): @@ -196,8 +218,16 @@ def _build_regional_access_boundary_lookup_url( ) return None - return iam._SERVICE_ACCOUNT_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT.format( - service_account_email=self.service_account_email + if not _metadata._is_service_account_email(self.service_account_email): + _LOGGER.info( + "Service account email '%s' is not a valid email. Skipping Regional Access Boundary lookup.", + self.service_account_email, + ) + self._rab_disabled = True + return None + + return _regional_access_boundary_utils.get_service_account_rab_endpoint( + self.service_account_email ) @property diff --git a/packages/google-auth/google/auth/credentials.py b/packages/google-auth/google/auth/credentials.py index 4a686cb01907..ac619309b6d1 100644 --- a/packages/google-auth/google/auth/credentials.py +++ b/packages/google-auth/google/auth/credentials.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: # pragma: NO COVER import google.auth.transport -DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" +DEFAULT_UNIVERSE_DOMAIN = _helpers.DEFAULT_UNIVERSE_DOMAIN # These constants are deprecated and no longer used. # They are kept solely for backward compatibility with older implementations. @@ -239,9 +239,25 @@ def before_request(self, request, method, url, headers): else: self._blocking_refresh(request) + self._after_refresh(request, method, url, headers) + metrics.add_metric_header(headers, self._metric_header_for_usage()) self.apply(headers) + def _after_refresh(self, request, method, url, headers): + """Hook for subclasses to perform actions after refresh but before + applying credentials to headers. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping): The request's headers. + """ + pass + def with_non_blocking_refresh(self): self._use_non_blocking_refresh = True @@ -309,6 +325,22 @@ def __init__(self): _regional_access_boundary_utils._RegionalAccessBoundaryManager() ) + def __setstate__(self, state): + """Pickle helper that restores state, safely reconstructing RAB fields if missing.""" + self.__dict__.update(state) + if "_rab_manager" not in self.__dict__: + from google.auth import _regional_access_boundary_utils + + self._rab_manager = ( + _regional_access_boundary_utils._RegionalAccessBoundaryManager() + ) + if "_use_non_blocking_refresh" not in self.__dict__: + self._use_non_blocking_refresh = False + if "_refresh_worker" not in self.__dict__: + from google.auth._refresh_worker import RefreshThreadManager + + self._refresh_worker = RefreshThreadManager() + @property def regional_access_boundary(self): """Optional[str]: The encoded Regional Access Boundary locations.""" @@ -364,12 +396,11 @@ def with_trust_boundary(self, trust_boundary): ) def _copy_regional_access_boundary_manager(self, target): - """Copies the regional access boundary manager to another instance.""" - # Create a new manager for the clone to isolate background refresh locks and threads, - # but share the immutable data reference to avoid unnecessary initial lookups. - new_manager = _regional_access_boundary_utils._RegionalAccessBoundaryManager() - new_manager._data = self._rab_manager._data - target._rab_manager = new_manager + """Copies the regional access boundary manager state to another instance.""" + target._rab_manager._data = self._rab_manager._data + target._rab_manager._use_blocking_regional_access_boundary_lookup = ( + self._rab_manager._use_blocking_regional_access_boundary_lookup + ) def _set_regional_access_boundary(self, seed): """Applies the regional_access_boundary provided via the seed on these @@ -403,18 +434,14 @@ def _set_blocking_regional_access_boundary_lookup(self): self._rab_manager.enable_blocking_lookup() return self - def _maybe_start_regional_access_boundary_refresh(self, request, url): - """ - Starts a background thread to refresh the Regional Access Boundary if needed. - - This method checks if a refresh is necessary and if one is not already - in progress or in a cooldown period. If so, it starts a background - thread to perform the lookup. + def _is_regional_endpoint(self, url): + """Checks if the request URL is for a regional endpoint. Args: - request (google.auth.transport.Request): The object used to make - HTTP requests. url (str): The URL of the request. + + Returns: + bool: True if the URL is a regional endpoint, False otherwise. """ try: # Do not perform a lookup if the request is for a regional endpoint. @@ -423,16 +450,35 @@ def _maybe_start_regional_access_boundary_refresh(self, request, url): hostname.endswith(".rep.googleapis.com") or hostname.endswith(".rep.sandbox.googleapis.com") ): - return - except (ValueError, TypeError): + return True + except (ValueError, TypeError, AttributeError): # If the URL is malformed, proceed with the default lookup behavior. pass + return False + + def _maybe_start_regional_access_boundary_refresh(self, request, url): + """ + Starts a background thread to refresh the Regional Access Boundary if needed. + + This method checks if a refresh is necessary and if one is not already + in progress or in a cooldown period. If so, it starts a background + thread to perform the lookup. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + url (str): The URL of the request. + """ + # Do not perform a lookup if the request is for a regional endpoint. + if self._is_regional_endpoint(url): + return + # A refresh is only needed if the feature is enabled. if not self._is_regional_access_boundary_lookup_required(): return - # Start the background refresh if needed. + # Trigger background or blocking refresh if needed self._rab_manager.maybe_start_refresh(self, request) def _is_regional_access_boundary_lookup_required(self): @@ -444,11 +490,11 @@ def _is_regional_access_boundary_lookup_required(self): Returns: bool: True if a Regional Access Boundary lookup is required, False otherwise. """ - # 1. Check if the feature is enabled. + # Check if the feature is enabled. if not _regional_access_boundary_utils.is_regional_access_boundary_enabled(): return False - # 2. Skip for non-default universe domains. + # Skip for non-default universe domains. if self.universe_domain != DEFAULT_UNIVERSE_DOMAIN: return False @@ -459,20 +505,10 @@ def apply(self, headers, token=None): super().apply(headers, token) self._rab_manager.apply_headers(headers) - def before_request(self, request, method, url, headers): - """Refreshes the access token and triggers the Regional Access Boundary - lookup if necessary. - """ - if self._use_non_blocking_refresh: - self._non_blocking_refresh(request) - else: - self._blocking_refresh(request) - + def _after_refresh(self, request, method, url, headers): + """Triggers the Regional Access Boundary lookup if necessary.""" self._maybe_start_regional_access_boundary_refresh(request, url) - metrics.add_metric_header(headers, self._metric_header_for_usage()) - self.apply(headers) - def refresh(self, request): """Refreshes the access token. @@ -500,12 +536,11 @@ def _lookup_regional_access_boundary( url = self._build_regional_access_boundary_lookup_url(request=request) if not url: - _LOGGER.error("Failed to build Regional Access Boundary lookup URL.") + _LOGGER.warning("Failed to build Regional Access Boundary lookup URL.") return None headers: Dict[str, str] = {} self._apply(headers) - self._rab_manager.apply_headers(headers) return _client._lookup_regional_access_boundary( request, url, headers=headers, fail_fast=fail_fast ) diff --git a/packages/google-auth/google/auth/external_account.py b/packages/google-auth/google/auth/external_account.py index b490f368ea45..eee6d1194031 100644 --- a/packages/google-auth/google/auth/external_account.py +++ b/packages/google-auth/google/auth/external_account.py @@ -40,9 +40,9 @@ from google.auth import _helpers +from google.auth import _regional_access_boundary_utils from google.auth import credentials from google.auth import exceptions -from google.auth import iam from google.auth import impersonated_credentials from google.auth import metrics from google.oauth2 import sts @@ -526,9 +526,10 @@ def _build_regional_access_boundary_lookup_url( ) if workload_match: project_number, pool_id = workload_match.groups() - url = iam._WORKLOAD_IDENTITY_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT.format( - project_number=project_number, - pool_id=pool_id, + url = ( + _regional_access_boundary_utils.get_workload_identity_pool_rab_endpoint( + project_number, pool_id + ) ) else: # If that fails, try to parse as a workforce pool. @@ -538,10 +539,8 @@ def _build_regional_access_boundary_lookup_url( ) if workforce_match: pool_id = workforce_match.groups()[0] - url = ( - iam._WORKFORCE_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT.format( - pool_id=pool_id - ) + url = _regional_access_boundary_utils.get_workforce_pool_rab_endpoint( + pool_id ) if url: @@ -620,7 +619,7 @@ def _initialize_impersonated_credentials(self): scopes = self._scopes if self._scopes is not None else self._default_scopes # Initialize and return impersonated credentials. - return impersonated_credentials.Credentials( + impersonated_creds = impersonated_credentials.Credentials( source_credentials=source_credentials, target_principal=target_principal, target_scopes=scopes, @@ -631,6 +630,9 @@ def _initialize_impersonated_credentials(self): ), trust_boundary=self._trust_boundary, ) + if self._rab_manager._use_blocking_regional_access_boundary_lookup: + impersonated_creds._set_blocking_regional_access_boundary_lookup() + return impersonated_creds def _create_default_metrics_options(self): metrics_options = {} diff --git a/packages/google-auth/google/auth/external_account_authorized_user.py b/packages/google-auth/google/auth/external_account_authorized_user.py index d292589b6010..35144f15d69e 100644 --- a/packages/google-auth/google/auth/external_account_authorized_user.py +++ b/packages/google-auth/google/auth/external_account_authorized_user.py @@ -42,9 +42,9 @@ from google.auth import _helpers +from google.auth import _regional_access_boundary_utils from google.auth import credentials from google.auth import exceptions -from google.auth import iam from google.oauth2 import sts from google.oauth2 import utils @@ -337,9 +337,7 @@ def _build_regional_access_boundary_lookup_url( pool_id = match.groups()[0] - return iam._WORKFORCE_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT.format( - pool_id=pool_id - ) + return _regional_access_boundary_utils.get_workforce_pool_rab_endpoint(pool_id) def revoke(self, request): """Revokes the refresh token. diff --git a/packages/google-auth/google/auth/iam.py b/packages/google-auth/google/auth/iam.py index 00b6e06a2c4f..2ecb1b0014b8 100644 --- a/packages/google-auth/google/auth/iam.py +++ b/packages/google-auth/google/auth/iam.py @@ -49,23 +49,17 @@ else: _IAM_DOMAIN = f"iamcredentials.{credentials.DEFAULT_UNIVERSE_DOMAIN}" -# 3. Create the common base URL template +# Create the common base URL template # We use double brackets {{}} so .format() can be called later for the email. _IAM_BASE_URL = f"https://{_IAM_DOMAIN}/v1/projects/-/serviceAccounts/{{}}" -# 4. Define the endpoints as templates +# Define the endpoints as static templates _IAM_ENDPOINT = _IAM_BASE_URL + ":generateAccessToken" _IAM_SIGN_ENDPOINT = _IAM_BASE_URL + ":signBlob" _IAM_SIGNJWT_ENDPOINT = _IAM_BASE_URL + ":signJwt" _IAM_IDTOKEN_ENDPOINT = _IAM_BASE_URL + ":generateIdToken" -# Regional Access Boundary (RAB) Lookup Endpoints -_SERVICE_ACCOUNT_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT = f"https://{_IAM_DOMAIN}/v1/projects/-/serviceAccounts/{{service_account_email}}/allowedLocations" -_WORKFORCE_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT = f"https://{_IAM_DOMAIN}/v1/locations/global/workforcePools/{{pool_id}}/allowedLocations" -_WORKLOAD_IDENTITY_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT = f"https://{_IAM_DOMAIN}/v1/projects/{{project_number}}/locations/global/workloadIdentityPools/{{pool_id}}/allowedLocations" - - class Signer(crypt.Signer): """Signs messages using the IAM `signBlob API`_. diff --git a/packages/google-auth/google/auth/impersonated_credentials.py b/packages/google-auth/google/auth/impersonated_credentials.py index 45db79daa42e..2f14d809319e 100644 --- a/packages/google-auth/google/auth/impersonated_credentials.py +++ b/packages/google-auth/google/auth/impersonated_credentials.py @@ -36,6 +36,7 @@ from google.auth import _exponential_backoff from google.auth import _helpers +from google.auth import _regional_access_boundary_utils from google.auth import credentials from google.auth import exceptions from google.auth import iam @@ -368,8 +369,8 @@ def _build_regional_access_boundary_lookup_url( "Service account email is required to build the Regional Access Boundary lookup URL for impersonated credentials." ) return None - return iam._SERVICE_ACCOUNT_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT.format( - service_account_email=self.service_account_email + return _regional_access_boundary_utils.get_service_account_rab_endpoint( + self.service_account_email ) def sign_bytes(self, message): diff --git a/packages/google-auth/google/auth/jwt.py b/packages/google-auth/google/auth/jwt.py index b6fe60736fa1..1241aee70121 100644 --- a/packages/google-auth/google/auth/jwt.py +++ b/packages/google-auth/google/auth/jwt.py @@ -52,6 +52,7 @@ from google.auth import _cache from google.auth import _helpers +from google.auth import _regional_access_boundary_utils from google.auth import _service_account_info from google.auth import crypt from google.auth import exceptions @@ -317,7 +318,9 @@ def decode(token, certs=None, verify=True, audience=None, clock_skew_in_seconds= class Credentials( - google.auth.credentials.Signing, google.auth.credentials.CredentialsWithQuotaProject + google.auth.credentials.Signing, + google.auth.credentials.CredentialsWithQuotaProject, + google.auth.credentials.CredentialsWithRegionalAccessBoundary, ): """Credentials that use a JWT as the bearer token. @@ -490,7 +493,15 @@ def from_signing_credentials(cls, credentials, audience, **kwargs): """ kwargs.setdefault("issuer", credentials.signer_email) kwargs.setdefault("subject", credentials.signer_email) - return cls(credentials.signer, audience=audience, **kwargs) + jwt_creds = cls(credentials.signer, audience=audience, **kwargs) + + if isinstance( + credentials, + google.auth.credentials.CredentialsWithRegionalAccessBoundary, + ): + credentials._copy_regional_access_boundary_manager(jwt_creds) + + return jwt_creds def with_claims( self, issuer=None, subject=None, audience=None, additional_claims=None @@ -514,7 +525,7 @@ def with_claims( new_additional_claims = copy.deepcopy(self._additional_claims) new_additional_claims.update(additional_claims or {}) - return self.__class__( + cred = self.__class__( self._signer, issuer=issuer if issuer is not None else self._issuer, subject=subject if subject is not None else self._subject, @@ -522,10 +533,12 @@ def with_claims( additional_claims=new_additional_claims, quota_project_id=self._quota_project_id, ) + self._copy_regional_access_boundary_manager(cred) + return cred @_helpers.copy_docstring(google.auth.credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): - return self.__class__( + cred = self.__class__( self._signer, issuer=self._issuer, subject=self._subject, @@ -533,6 +546,8 @@ def with_quota_project(self, quota_project_id): additional_claims=self._additional_claims, quota_project_id=quota_project_id, ) + self._copy_regional_access_boundary_manager(cred) + return cred def _make_jwt(self): """Make a signed JWT. @@ -559,7 +574,7 @@ def _make_jwt(self): return jwt, expiry - def refresh(self, request): + def _perform_refresh_token(self, request): """Refreshes the access token. Args: @@ -569,6 +584,15 @@ def refresh(self, request): # (pylint doesn't correctly recognize overridden methods.) self.token, self.expiry = self._make_jwt() + def _build_regional_access_boundary_lookup_url(self, request=None): + """Builds the lookup URL using the service account's email address.""" + if not self.signer_email: + return None + + return _regional_access_boundary_utils.get_service_account_rab_endpoint( + self.signer_email + ) + @_helpers.copy_docstring(google.auth.credentials.Signing) def sign_bytes(self, message): return self._signer.sign(message) diff --git a/packages/google-auth/google/oauth2/_client.py b/packages/google-auth/google/oauth2/_client.py index 1c7ba46b72e1..88083d022986 100644 --- a/packages/google-auth/google/oauth2/_client.py +++ b/packages/google-auth/google/oauth2/_client.py @@ -549,7 +549,7 @@ def _lookup_regional_access_boundary(request, url, headers=None, fail_fast=False # Error was already logged by _lookup_regional_access_boundary_request return None - if "encodedLocations" not in response_data: + if not isinstance(response_data, dict) or "encodedLocations" not in response_data: _LOGGER.error( "Regional Access Boundary response malformed: missing 'encodedLocations' key in %s", response_data, diff --git a/packages/google-auth/google/oauth2/_client_async.py b/packages/google-auth/google/oauth2/_client_async.py index a6201fbdcb94..ce94284ea7c9 100644 --- a/packages/google-auth/google/oauth2/_client_async.py +++ b/packages/google-auth/google/oauth2/_client_async.py @@ -23,6 +23,7 @@ .. _Section 3.1 of rfc6749: https://tools.ietf.org/html/rfc6749#section-3.2 """ +import asyncio import http.client as http_client import json import urllib @@ -288,3 +289,166 @@ async def refresh_grant( request, token_uri, body, can_retry=can_retry ) return client._handle_refresh_grant_response(response_data, refresh_token) + + +async def _lookup_regional_access_boundary(request, url, headers=None, fail_fast=False): + """Implements the global lookup of a credential Regional Access Boundary. + For the lookup, we send a request to the global lookup endpoint and then + parse the response. Service account credentials, workload identity + pools and workforce pools implementation may have Regional Access Boundaries configured. + Args: + request (google.auth.aio.transport.Request): A callable used to make + HTTP requests. The returned response must support `await response.read()` + (standard async transport) or `await response.content()` (legacy/custom transport). + url (str): The Regional Access Boundary lookup url. + headers (Optional[Mapping[str, str]]): The headers for the request. + fail_fast (bool): Whether the lookup should fail fast (uses a short timeout and no retries). + Returns: + Optional[Mapping[str,list|str]]: A dictionary containing + "locations" as a list of allowed locations as strings and + "encodedLocations" as a hex string. + e.g: + { + "locations": [ + "us-central1", "us-east1", "europe-west1", "asia-east1" + ], + "encodedLocations": "0xA30" + } + """ + response_data = await _lookup_regional_access_boundary_request( + request, url, headers=headers, fail_fast=fail_fast + ) + if response_data is None: + # Error was already logged by _lookup_regional_access_boundary_request + return None + + if not isinstance(response_data, dict) or "encodedLocations" not in response_data: + client._LOGGER.error( + "Regional Access Boundary response malformed: missing 'encodedLocations' key in %s", + response_data, + ) + return None + return response_data + + +async def _lookup_regional_access_boundary_request( + request, url, can_retry=True, headers=None, fail_fast=False +): + """Makes a request to the Regional Access Boundary lookup endpoint. + + Args: + request (google.auth.aio.transport.Request): A callable used to make + HTTP requests. The returned response must support `await response.read()` + (standard async transport) or `await response.content()` (legacy/custom transport). + url (str): The Regional Access Boundary lookup url. + can_retry (bool): Enable or disable request retry behavior. Defaults to true. + headers (Optional[Mapping[str, str]]): The headers for the request. + fail_fast (bool): Whether the lookup should fail fast (uses a short timeout and no retries). + + Returns: + Optional[Mapping[str, str]]: The JSON-decoded response data on success, or None on failure. + """ + ( + response_status_ok, + response_data, + retryable_error, + ) = await _lookup_regional_access_boundary_request_no_throw( + request, url, can_retry=can_retry, headers=headers, fail_fast=fail_fast + ) + if not response_status_ok: + client._LOGGER.warning( + "Regional Access Boundary HTTP request failed after retries: response_data=%s, retryable_error=%s", + response_data, + retryable_error, + ) + return None + return response_data + + +async def _lookup_regional_access_boundary_request_no_throw( + request, url, can_retry=True, headers=None, fail_fast=False +): + """Makes a request to the Regional Access Boundary lookup endpoint. This + function doesn't throw on response errors. + + Args: + request (google.auth.aio.transport.Request): A callable used to make + HTTP requests. The returned response must support `await response.read()` + (standard async transport) or `await response.content()` (legacy/custom transport). + url (str): The Regional Access Boundary lookup url. + can_retry (bool): Enable or disable request retry behavior. Defaults to true. + headers (Optional[Mapping[str, str]]): The headers for the request. + fail_fast (bool): Whether the lookup should fail fast (uses a short timeout and no retries). + + Returns: + Tuple(bool, Mapping[str, str], Optional[bool]): A boolean indicating + if the request is successful, a mapping for the JSON-decoded response + data and in the case of an error a boolean indicating if the error + is retryable. + """ + + response_data = {} + retryable_error = False + + timeout = ( + client._BLOCKING_REGIONAL_ACCESS_BOUNDARY_LOOKUP_TIMEOUT if fail_fast else None + ) + total_attempts = 1 if fail_fast else 6 + retries = _exponential_backoff.AsyncExponentialBackoff( + total_attempts=total_attempts + ) + + async for _ in retries: + try: + if timeout: + response = await asyncio.wait_for( + request(method="GET", url=url, headers=headers, timeout=timeout), + timeout=timeout, + ) + else: + response = await request(method="GET", url=url, headers=headers) + + # Supports both modern google.auth.aio (exposing read()) and legacy transports (exposing content()) + if hasattr(response, "read"): + response_bytes = await response.read() + else: + response_bytes = await response.content() + except (asyncio.TimeoutError, exceptions.TransportError): + retryable_error = True + if not can_retry: + return False, {}, retryable_error + continue + except Exception: + # Catch raw transport/socket exceptions raised during body streaming. + return False, {}, False + + try: + response_body = ( + response_bytes.decode("utf-8") + if hasattr(response_bytes, "decode") + else response_bytes + ) + response_data = json.loads(response_body) + except (UnicodeDecodeError, ValueError): + # Keep types safe and allow status-code checks below to determine retryability + response_data = {} + + status_code = ( + response.status_code + if hasattr(response, "status_code") + else response.status + ) + + if status_code == http_client.OK: + return True, response_data, None + + retryable_error = client._can_retry( + status_code=status_code, response_data=response_data + ) + if status_code == http_client.BAD_GATEWAY: + retryable_error = True + + if not can_retry or not retryable_error: + return False, response_data, retryable_error + + return False, response_data, retryable_error diff --git a/packages/google-auth/google/oauth2/_service_account_async.py b/packages/google-auth/google/oauth2/_service_account_async.py index fa6cfb7b7d7a..69b80a2531d2 100644 --- a/packages/google-auth/google/oauth2/_service_account_async.py +++ b/packages/google-auth/google/oauth2/_service_account_async.py @@ -24,12 +24,15 @@ from google.auth import _credentials_async as credentials_async from google.auth import _helpers +from google.auth import _regional_access_boundary_utils from google.oauth2 import _client_async from google.oauth2 import service_account class Credentials( - service_account.Credentials, credentials_async.Scoped, credentials_async.Credentials + service_account.Credentials, + credentials_async.Scoped, + credentials_async.CredentialsWithRegionalAccessBoundary, ): """Service account credentials @@ -66,6 +69,14 @@ class Credentials( credentials = credentials.with_quota_project('myproject-123') """ + def __setstate__(self, state): + """Restores the credential state and ensures the async refresh manager is attached.""" + super().__setstate__(state) + + self._rab_manager.refresh_manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + @_helpers.copy_docstring(credentials_async.Credentials) async def refresh(self, request): assertion = self._make_authorization_grant_assertion() @@ -75,13 +86,6 @@ async def refresh(self, request): self.token = access_token self.expiry = expiry - @_helpers.copy_docstring(credentials_async.Credentials) - async def before_request(self, request, method, url, headers): - # Explicit override to bypass synchronous CredentialsWithRegionalAccessBoundary. - await credentials_async.Credentials.before_request( - self, request, method, url, headers - ) - class IDTokenCredentials( service_account.IDTokenCredentials, @@ -137,11 +141,3 @@ async def refresh(self, request): ) self.token = access_token self.expiry = expiry - - @_helpers.copy_docstring(credentials_async.Credentials) - async def before_request(self, request, method, url, headers): - # Explicit override to bypass synchronous CredentialsWithRegionalAccessBoundary - # and disable Regional Access Boundary refresh for async credentials. - await credentials_async.Credentials.before_request( - self, request, method, url, headers - ) diff --git a/packages/google-auth/google/oauth2/service_account.py b/packages/google-auth/google/oauth2/service_account.py index 5c19b8fe01ae..7f719ade2cdb 100644 --- a/packages/google-auth/google/oauth2/service_account.py +++ b/packages/google-auth/google/oauth2/service_account.py @@ -77,6 +77,7 @@ from google.auth import _helpers +from google.auth import _regional_access_boundary_utils from google.auth import _service_account_info from google.auth import credentials from google.auth import exceptions @@ -520,8 +521,8 @@ def _build_regional_access_boundary_lookup_url( "Service account email is required to build the Regional Access Boundary lookup URL for service account credentials." ) return None - return iam._SERVICE_ACCOUNT_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT.format( - service_account_email=self._service_account_email, + return _regional_access_boundary_utils.get_service_account_rab_endpoint( + self._service_account_email ) @_helpers.copy_docstring(credentials.Signing) diff --git a/packages/google-auth/tests/compute_engine/test__metadata.py b/packages/google-auth/tests/compute_engine/test__metadata.py index e2cbf425a1ec..b27e7f7f4fb5 100644 --- a/packages/google-auth/tests/compute_engine/test__metadata.py +++ b/packages/google-auth/tests/compute_engine/test__metadata.py @@ -985,3 +985,28 @@ def test__prepare_request_for_mds_mtls_http_request(mock_mds_mtls_adapter): _metadata._prepare_request_for_mds(request, use_mtls=True) assert mock_mds_mtls_adapter.call_count == 0 + + +def test__is_service_account_email(): + # Valid email formats + assert ( + _metadata._is_service_account_email("my-sa@my-project.iam.gserviceaccount.com") + is True + ) + assert _metadata._is_service_account_email("test@example.com") is True + + # Empty inputs and standard string placeholders + assert _metadata._is_service_account_email("default") is False + assert _metadata._is_service_account_email("") is False + assert _metadata._is_service_account_email(None) is False + + # Workload identity principal URI formats + assert ( + _metadata._is_service_account_email( + "principal://iam.googleapis.com/projects/1234567890/locations/global/workloadIdentityPools/my-project.svc.id.goog/subject/ns/my-namespace/sa/my-kubernetes-sa" + ) + is False + ) + + # Workforce or workload pool identifier paths + assert _metadata._is_service_account_email("my-gcp-project.svc.id.goog") is False diff --git a/packages/google-auth/tests/compute_engine/test_credentials.py b/packages/google-auth/tests/compute_engine/test_credentials.py index 5a60ffd44145..7fb2b8b504fc 100644 --- a/packages/google-auth/tests/compute_engine/test_credentials.py +++ b/packages/google-auth/tests/compute_engine/test_credentials.py @@ -306,8 +306,9 @@ def test_build_regional_access_boundary_lookup_url_default_email( url = creds._build_regional_access_boundary_lookup_url(request=mock_request) mock_get_service_account_info.assert_called_once_with(mock_request, "default") - expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" - assert url == expected_url + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + assert url in (expected_url_standard, expected_url_mtls) @mock.patch("google.auth.compute_engine._metadata.get", autospec=True) def test_build_regional_access_boundary_lookup_url_http_client_request( @@ -323,7 +324,33 @@ def test_build_regional_access_boundary_lookup_url_http_client_request( url = creds._build_regional_access_boundary_lookup_url(request=req) - expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + assert url in (expected_url_standard, expected_url_mtls) + + @mock.patch( + "google.auth.compute_engine._metadata.get_service_account_info", autospec=True + ) + @mock.patch( + "google.auth.compute_engine._metadata.get_universe_domain", autospec=True + ) + def test_build_regional_access_boundary_lookup_url_explicit_email_standard( + self, mock_get_universe_domain, mock_get_service_account_info, monkeypatch + ): + from google.auth.transport import _mtls_helper + + # Mock check_use_client_cert to return False + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) + + # Test with an explicit service account email, no resolution needed + creds = self.credentials + creds._service_account_email = FAKE_SERVICE_ACCOUNT_EMAIL + mock_get_universe_domain.return_value = "googleapis.com" + + url = creds._build_regional_access_boundary_lookup_url() + + mock_get_service_account_info.assert_not_called() + expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/foo@bar.com/allowedLocations" assert url == expected_url @mock.patch( @@ -332,9 +359,14 @@ def test_build_regional_access_boundary_lookup_url_http_client_request( @mock.patch( "google.auth.compute_engine._metadata.get_universe_domain", autospec=True ) - def test_build_regional_access_boundary_lookup_url_explicit_email( - self, mock_get_universe_domain, mock_get_service_account_info + def test_build_regional_access_boundary_lookup_url_explicit_email_mtls( + self, mock_get_universe_domain, mock_get_service_account_info, monkeypatch ): + from google.auth.transport import _mtls_helper + + # Mock check_use_client_cert to return True + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + # Test with an explicit service account email, no resolution needed creds = self.credentials creds._service_account_email = FAKE_SERVICE_ACCOUNT_EMAIL @@ -343,9 +375,8 @@ def test_build_regional_access_boundary_lookup_url_explicit_email( url = creds._build_regional_access_boundary_lookup_url() mock_get_service_account_info.assert_not_called() - assert url == ( - "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/foo@bar.com/allowedLocations" - ) + expected_url = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/foo@bar.com/allowedLocations" + assert url == expected_url @mock.patch( "google.auth.compute_engine._metadata.get_universe_domain", autospec=True @@ -379,6 +410,68 @@ def test_build_regional_access_boundary_lookup_url_no_email( url = creds._build_regional_access_boundary_lookup_url() assert url is None + @mock.patch( + "google.auth._regional_access_boundary_utils.is_regional_access_boundary_enabled", + return_value=True, + ) + def test_is_regional_access_boundary_lookup_required(self, mock_enabled): + creds = self.credentials + creds._universe_domain_cached = True + + # Valid email formats should pass. + creds._service_account_email = "my-sa@my-project.iam.gserviceaccount.com" + assert creds._is_regional_access_boundary_lookup_required() is True + + # GCE default email placeholder should pass to allow dynamic resolution. + creds._service_account_email = "default" + assert creds._is_regional_access_boundary_lookup_required() is True + + # Lookup for non-email based identities should be skipped. + creds._service_account_email = "my-gcp-project.svc.id.goog" + assert creds._is_regional_access_boundary_lookup_required() is False + + creds._service_account_email = "principal://iam.googleapis.com/projects/1234567890/locations/global/workloadIdentityPools/my-project.svc.id.goog/subject/ns/my-namespace/sa/my-kubernetes-sa" + assert creds._is_regional_access_boundary_lookup_required() is False + + def test_build_regional_access_boundary_lookup_url_with_invalid_email(self): + creds = self.credentials + creds._universe_domain_cached = True + + # Set a non-email identity. + creds._service_account_email = "my-gcp-project.svc.id.goog" + url = creds._build_regional_access_boundary_lookup_url() + assert url is None + + @mock.patch( + "google.auth._regional_access_boundary_utils.is_regional_access_boundary_enabled", + return_value=True, + ) + @mock.patch( + "google.auth.compute_engine._metadata.get_service_account_info", autospec=True + ) + def test_regional_access_boundary_disabled_state_transitions( + self, mock_get_service_account_info, mock_enabled + ): + mock_get_service_account_info.return_value = { + "email": "spiffe://trust-domain/ns/ns/sa/sa", + "scopes": ["one", "two"], + } + creds = self.credentials + creds._universe_domain_cached = True + creds._service_account_email = "default" + + # Initially, GCE 'default' placeholder passes the pre-check + assert not creds._rab_disabled + assert creds._is_regional_access_boundary_lookup_required() is True + + # Resolving a non-email identity should disable RAB lookup + url = creds._build_regional_access_boundary_lookup_url() + assert url is None + assert creds._rab_disabled is True + + # Subsequent check calls should return False early + assert creds._is_regional_access_boundary_lookup_required() is False + @mock.patch("google.auth.compute_engine._metadata.get") @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") @mock.patch("google.auth._agent_identity_utils.parse_certificate") diff --git a/packages/google-auth/tests/oauth2/test__client.py b/packages/google-auth/tests/oauth2/test__client.py index 173ddbd27948..b20a8042d5f5 100644 --- a/packages/google-auth/tests/oauth2/test__client.py +++ b/packages/google-auth/tests/oauth2/test__client.py @@ -185,7 +185,8 @@ def test__token_endpoint_request_error(): _client._token_endpoint_request(request, "http://example.com", {}) -def test__token_endpoint_request_internal_failure_error(): +@mock.patch("time.sleep", return_value=None) +def test__token_endpoint_request_internal_failure_error(mock_sleep): request = make_request( {"error_description": "internal_failure"}, status=http_client.BAD_REQUEST ) @@ -207,9 +208,11 @@ def test__token_endpoint_request_internal_failure_error(): ) # request with 2 retries assert request.call_count == 3 + assert mock_sleep.call_count == 4 -def test__token_endpoint_request_internal_failure_and_retry_failure_error(): +@mock.patch("time.sleep", return_value=None) +def test__token_endpoint_request_internal_failure_and_retry_failure_error(mock_sleep): retryable_error = mock.create_autospec(transport.Response, instance=True) retryable_error.status = http_client.BAD_REQUEST retryable_error.data = json.dumps({"error_description": "internal_failure"}).encode( @@ -233,9 +236,11 @@ def test__token_endpoint_request_internal_failure_and_retry_failure_error(): # request should be called three times. Two retryable errors and one # unretryable error to break the retry loop. assert request.call_count == 3 + assert mock_sleep.call_count == 2 -def test__token_endpoint_request_internal_failure_and_retry_succeeds(): +@mock.patch("time.sleep", return_value=None) +def test__token_endpoint_request_internal_failure_and_retry_succeeds(mock_sleep): retryable_error = mock.create_autospec(transport.Response, instance=True) retryable_error.status = http_client.BAD_REQUEST retryable_error.data = json.dumps({"error_description": "internal_failure"}).encode( @@ -255,6 +260,7 @@ def test__token_endpoint_request_internal_failure_and_retry_succeeds(): ) assert request.call_count == 2 + assert mock_sleep.call_count == 1 def test__token_endpoint_request_string_error(): @@ -611,7 +617,8 @@ def test_refresh_grant_retry_with_retry( @pytest.mark.parametrize("can_retry", [True, False]) -def test__token_endpoint_request_no_throw_with_retry(can_retry): +@mock.patch("time.sleep", return_value=None) +def test__token_endpoint_request_no_throw_with_retry(mock_sleep, can_retry): response_data = {"error": "help", "error_description": "I'm alive"} body = "dummy body" @@ -628,8 +635,10 @@ def test__token_endpoint_request_no_throw_with_retry(can_retry): if can_retry: assert mock_request.call_count == 3 + assert mock_sleep.call_count == 2 else: assert mock_request.call_count == 1 + mock_sleep.assert_not_called() def test_lookup_regional_access_boundary(): @@ -706,7 +715,10 @@ def test_lookup_regional_access_boundary_non_retryable_error(status_code): ) -def test_lookup_regional_access_boundary_internal_failure_and_retry_failure_error(): +@mock.patch("time.sleep", return_value=None) +def test_lookup_regional_access_boundary_internal_failure_and_retry_failure_error( + mock_sleep, +): retryable_error = mock.create_autospec(transport.Response, instance=True) retryable_error.status = http_client.BAD_REQUEST retryable_error.data = json.dumps({"error_description": "internal_failure"}).encode( @@ -731,11 +743,15 @@ def test_lookup_regional_access_boundary_internal_failure_and_retry_failure_erro # request should be called three times. Two retryable errors and one # unretryable error to break the retry loop. assert request.call_count == 3 + assert mock_sleep.call_count == 2 for call in request.call_args_list: assert call[1]["headers"] == headers -def test_lookup_regional_access_boundary_internal_failure_and_retry_succeeds(): +@mock.patch("time.sleep", return_value=None) +def test_lookup_regional_access_boundary_internal_failure_and_retry_succeeds( + mock_sleep, +): retryable_error = mock.create_autospec(transport.Response, instance=True) retryable_error.status = http_client.BAD_REQUEST retryable_error.data = json.dumps({"error_description": "internal_failure"}).encode( @@ -760,6 +776,7 @@ def test_lookup_regional_access_boundary_internal_failure_and_retry_succeeds(): ) assert request.call_count == 2 + assert mock_sleep.call_count == 1 for call in request.call_args_list: assert call[1]["headers"] == headers diff --git a/packages/google-auth/tests/oauth2/test_service_account.py b/packages/google-auth/tests/oauth2/test_service_account.py index f0d8f0759e50..958eace2dd22 100644 --- a/packages/google-auth/tests/oauth2/test_service_account.py +++ b/packages/google-auth/tests/oauth2/test_service_account.py @@ -228,15 +228,65 @@ def test_with_quota_project(self): new_credentials.apply(hdrs, token="tok") assert "x-goog-user-project" in hdrs - def test_build_regional_access_boundary_lookup_url(self): + def test_copy_regional_access_boundary_manager_state_and_config_with_scopes(self): credentials = self.make_credentials() - expected_url = ( - "https://iamcredentials.googleapis.com/v1/projects/-/" - "serviceAccounts/{}/allowedLocations".format( - credentials.service_account_email - ) + credentials._rab_manager._data = mock.sentinel.rab_data + credentials._rab_manager._use_blocking_regional_access_boundary_lookup = True + + new_credentials = credentials.with_scopes(["scope-foo"]) + + # Verify references to boundary data are shared + assert new_credentials._rab_manager._data == mock.sentinel.rab_data + # Verify blocking config flag is preserved + assert ( + new_credentials._rab_manager._use_blocking_regional_access_boundary_lookup + is True + ) + # Verify target manager object is not replaced + assert new_credentials._rab_manager is not credentials._rab_manager + + def test_copy_regional_access_boundary_manager_state_and_config_with_quota_project( + self, + ): + credentials = self.make_credentials() + credentials._rab_manager._data = mock.sentinel.rab_data + credentials._rab_manager._use_blocking_regional_access_boundary_lookup = True + + new_credentials = credentials.with_quota_project("new-project-foo") + + # Verify references to boundary data are shared + assert new_credentials._rab_manager._data == mock.sentinel.rab_data + # Verify blocking config flag is preserved + assert ( + new_credentials._rab_manager._use_blocking_regional_access_boundary_lookup + is True + ) + # Verify target manager object is not replaced + assert new_credentials._rab_manager is not credentials._rab_manager + + def test_build_regional_access_boundary_lookup_url_standard(self, monkeypatch): + from google.auth.transport import _mtls_helper + + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) + + credentials = self.make_credentials() + url = credentials._build_regional_access_boundary_lookup_url() + expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + credentials.service_account_email + ) + assert url == expected_url + + def test_build_regional_access_boundary_lookup_url_mtls(self, monkeypatch): + from google.auth.transport import _mtls_helper + + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + + credentials = self.make_credentials() + url = credentials._build_regional_access_boundary_lookup_url() + expected_url = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + credentials.service_account_email ) - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + assert url == expected_url def test_with_token_uri(self): credentials = self.make_credentials() diff --git a/packages/google-auth/tests/test__regional_access_boundary_utils.py b/packages/google-auth/tests/test__regional_access_boundary_utils.py index ab6ec75fd9b8..c612b60b8ed2 100644 --- a/packages/google-auth/tests/test__regional_access_boundary_utils.py +++ b/packages/google-auth/tests/test__regional_access_boundary_utils.py @@ -1,4 +1,4 @@ -# Copyright 2026 Google Inc. +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ import pytest # type: ignore +from google.auth import _credentials_async from google.auth import _helpers from google.auth import _regional_access_boundary_utils from google.auth import credentials @@ -301,6 +302,24 @@ def test_serialization(self): assert unpickled.refresh_manager._lock is not None assert unpickled.refresh_manager._worker is None + def test_unpickle_old_credentials_without_rab(self): + creds = CredentialsImpl() + old_state = creds.__dict__.copy() + if "_rab_manager" in old_state: + del old_state["_rab_manager"] + if "_use_non_blocking_refresh" in old_state: + del old_state["_use_non_blocking_refresh"] + if "_refresh_worker" in old_state: + del old_state["_refresh_worker"] + + new_instance = CredentialsImpl.__new__(CredentialsImpl) + new_instance.__setstate__(old_state) + + assert hasattr(new_instance, "_rab_manager") + assert new_instance._rab_manager is not None + assert new_instance._use_non_blocking_refresh is False + assert new_instance._refresh_worker is not None + @mock.patch( "google.auth._regional_access_boundary_utils._RegionalAccessBoundaryRefreshManager.start_refresh" ) @@ -379,6 +398,21 @@ def test_start_blocking_refresh_failure(self): assert creds._rab_manager._data.encoded_locations is None assert creds._rab_manager._data.cooldown_expiry is not None + def test_start_blocking_refresh_with_async_credentials(self): + creds = CredentialsImpl() + request = mock.Mock() + + with mock.patch.object( + creds, + "_lookup_regional_access_boundary", + new_callable=mock.AsyncMock, + ) as mock_lookup: + creds._rab_manager.start_blocking_refresh(creds, request) + + mock_lookup.assert_not_called() + assert creds._rab_manager._data.encoded_locations is None + assert creds._rab_manager._data.cooldown_expiry is not None + @mock.patch("copy.deepcopy") def test_start_refresh_deepcopy_failure(self, mock_deepcopy): mock_deepcopy.side_effect = Exception("deepcopy error") @@ -552,3 +586,178 @@ def test_regional_access_boundary_refresh_manager_start_refresh_safety_lock(self mock_thread_class.assert_not_called() assert manager._worker == mock_worker + + +class AsyncCredentialsImpl(_credentials_async.CredentialsWithRegionalAccessBoundary): + def __init__(self, universe_domain=None): + super().__init__() + if universe_domain: + self._universe_domain = universe_domain + + async def _perform_refresh_token(self, request): + self.token = "refreshed-token" + self.expiry = ( + _helpers.utcnow() + + _helpers.REFRESH_THRESHOLD + + datetime.timedelta(seconds=5) + ) + + def with_quota_project(self, quota_project_id): + raise NotImplementedError() + + def _build_regional_access_boundary_lookup_url(self, request=None): + # Using self.token here to make the URL dynamic for testing purposes + return "http://mock.url/lookup_for_{}".format(self.token) + + def _make_copy(self): + new_credentials = self.__class__() + self._copy_regional_access_boundary_manager(new_credentials) + return new_credentials + + +class TestAsyncCredentialsWithRegionalAccessBoundary(object): + @pytest.mark.asyncio + async def test_maybe_start_refresh_async_blocking(self): + creds = AsyncCredentialsImpl() + creds._rab_manager._use_blocking_regional_access_boundary_lookup = True + request = mock.Mock() + + with mock.patch.dict( + os.environ, + {environment_vars.GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED: "true"}, + ): + with mock.patch.object( + creds._rab_manager, + "start_blocking_refresh_async", + new_callable=mock.AsyncMock, + ) as mock_start_blocking: + await creds._maybe_start_regional_access_boundary_refresh_async( + request, "http://example.com" + ) + mock_start_blocking.assert_called_once_with(creds, request) + + @pytest.mark.asyncio + async def test_start_blocking_refresh_async_success(self): + creds = AsyncCredentialsImpl() + request = mock.Mock() + + with mock.patch.object( + creds, + "_lookup_regional_access_boundary", + new_callable=mock.AsyncMock, + return_value={"encodedLocations": "0xABC"}, + ) as mock_lookup: + await creds._rab_manager.start_blocking_refresh_async(creds, request) + + mock_lookup.assert_called_once_with(request, fail_fast=True) + assert creds._rab_manager._data.encoded_locations == "0xABC" + + @pytest.mark.asyncio + async def test_start_blocking_refresh_async_failure(self): + creds = AsyncCredentialsImpl() + request = mock.Mock() + + with mock.patch.object( + creds, + "_lookup_regional_access_boundary", + new_callable=mock.AsyncMock, + side_effect=Exception("error"), + ) as mock_lookup: + await creds._rab_manager.start_blocking_refresh_async(creds, request) + + mock_lookup.assert_called_once_with(request, fail_fast=True) + assert creds._rab_manager._data.encoded_locations is None + assert creds._rab_manager._data.cooldown_expiry is not None + + @pytest.mark.asyncio + async def test_async_refresh_manager_session_closed_ignored(self): + credentials = mock.AsyncMock() + # Simulate a closed session RuntimeError when invoking the boundary lookup + credentials._lookup_regional_access_boundary.side_effect = RuntimeError( + "Session is closed" + ) + + request = mock.Mock() + rab_manager = mock.Mock() + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + # Trigger refresh, which starts a background task that should swallow the error + manager.start_refresh(credentials, request, rab_manager) + + # Wait for the background worker task to terminate + await manager._worker_task + + # Verify that the lookup was still triggered but failed open cleanly + credentials._lookup_regional_access_boundary.assert_called_once_with(request) + rab_manager.process_regional_access_boundary_info.assert_called_once_with(None) + + +def test_get_service_account_rab_endpoint(monkeypatch): + from google.auth.transport import _mtls_helper + + # Test Standard TLS + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) + url = _regional_access_boundary_utils.get_service_account_rab_endpoint( + "test@example.com" + ) + assert ( + url + == "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/test@example.com/allowedLocations" + ) + + # Test mTLS + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + url = _regional_access_boundary_utils.get_service_account_rab_endpoint( + "test@example.com" + ) + assert ( + url + == "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/test@example.com/allowedLocations" + ) + + +def test_get_workforce_pool_rab_endpoint(monkeypatch): + from google.auth.transport import _mtls_helper + + # Test Standard TLS + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) + url = _regional_access_boundary_utils.get_workforce_pool_rab_endpoint("POOL_ID") + assert ( + url + == "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + ) + + # Test mTLS + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + url = _regional_access_boundary_utils.get_workforce_pool_rab_endpoint("POOL_ID") + assert ( + url + == "https://iamcredentials.mtls.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + ) + + +def test_get_workload_identity_pool_rab_endpoint(monkeypatch): + from google.auth.transport import _mtls_helper + + # Test Standard TLS + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) + url = _regional_access_boundary_utils.get_workload_identity_pool_rab_endpoint( + "PROJECT_NUM", "POOL_ID" + ) + assert ( + url + == "https://iamcredentials.googleapis.com/v1/projects/PROJECT_NUM/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" + ) + + # Test mTLS + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + url = _regional_access_boundary_utils.get_workload_identity_pool_rab_endpoint( + "PROJECT_NUM", "POOL_ID" + ) + assert ( + url + == "https://iamcredentials.mtls.googleapis.com/v1/projects/PROJECT_NUM/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" + ) diff --git a/packages/google-auth/tests/test_credentials.py b/packages/google-auth/tests/test_credentials.py index e1528a3ce365..5c7e39d59e84 100644 --- a/packages/google-auth/tests/test_credentials.py +++ b/packages/google-auth/tests/test_credentials.py @@ -154,6 +154,21 @@ def test_before_request_with_regional_access_boundary(): assert headers["x-allowed-locations"] == DUMMY_BOUNDARY +def test_copy_regional_access_boundary_manager_state_and_config(): + creds = CredentialsImpl() + creds._rab_manager._data = mock.sentinel.rab_data + creds._rab_manager._use_blocking_regional_access_boundary_lookup = True + + new_creds = creds._make_copy() + + # Verify references to immutable boundary data are shared + assert new_creds._rab_manager._data == mock.sentinel.rab_data + # Verify blocking config flag is preserved + assert new_creds._rab_manager._use_blocking_regional_access_boundary_lookup is True + # Verify target manager object is isolated (kept from constructor, not replaced) + assert new_creds._rab_manager is not creds._rab_manager + + def test_before_request_metrics(): credentials = CredentialsImplWithMetrics() request = "token" @@ -424,3 +439,17 @@ def test_before_request_triggers_rab_refresh(): lookup.assert_called_once() args, kwargs = lookup.call_args assert args[1] == "http://mock.url/lookup_for_refreshed-token" + + +def test_maybe_start_regional_access_boundary_refresh_invalid_url(): + credentials_instance = CredentialsImpl() + request = mock.Mock() + + # Verifies that passing invalid/non-string URLs synchronously fails safe without crashing. + credentials_instance._maybe_start_regional_access_boundary_refresh( + request, url=None + ) + credentials_instance._maybe_start_regional_access_boundary_refresh(request, url=123) + credentials_instance._maybe_start_regional_access_boundary_refresh( + request, url=object() + ) diff --git a/packages/google-auth/tests/test_external_account.py b/packages/google-auth/tests/test_external_account.py index dc296f7a52ae..870b07d47b6e 100644 --- a/packages/google-auth/tests/test_external_account.py +++ b/packages/google-auth/tests/test_external_account.py @@ -403,29 +403,22 @@ def test_with_scopes_full_options_propagated(self): service_account_impersonation_options={"token_lifetime_seconds": 2800}, ) - with mock.patch.object( - external_account.Credentials, "__init__", return_value=None - ) as mock_init: - credentials.with_scopes(["email"], ["default2"]) - - # Confirm with_scopes initialized the credential with the expected - # parameters and scopes. - mock_init.assert_called_once_with( - audience=self.AUDIENCE, - subject_token_type=self.SUBJECT_TOKEN_TYPE, - token_url=self.TOKEN_URL, - token_info_url=self.TOKEN_INFO_URL, - credential_source=self.CREDENTIAL_SOURCE, - service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, - service_account_impersonation_options={"token_lifetime_seconds": 2800}, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - quota_project_id=self.QUOTA_PROJECT_ID, - scopes=["email"], - default_scopes=["default2"], - universe_domain=DEFAULT_UNIVERSE_DOMAIN, - trust_boundary=None, + cloned = credentials.with_scopes(["email"], ["default2"]) + + assert cloned.scopes == ["email"] + assert cloned.default_scopes == ["default2"] + assert cloned.quota_project_id == self.QUOTA_PROJECT_ID + assert cloned._client_id == CLIENT_ID + assert cloned._client_secret == CLIENT_SECRET + assert cloned._token_info_url == self.TOKEN_INFO_URL + assert ( + cloned._service_account_impersonation_url + == self.SERVICE_ACCOUNT_IMPERSONATION_URL ) + assert cloned._service_account_impersonation_options == { + "token_lifetime_seconds": 2800 + } + assert cloned.universe_domain == DEFAULT_UNIVERSE_DOMAIN def test_with_token_uri(self): credentials = self.make_credentials() @@ -492,33 +485,21 @@ def test_with_quota_project_full_options_propagated(self): service_account_impersonation_options={"token_lifetime_seconds": 2800}, ) - with mock.patch.object( - external_account.Credentials, "__init__", return_value=None - ) as mock_init: - new_cred = credentials.with_quota_project("project-foo") - - # Confirm with_quota_project initialized the credential with the - # expected parameters. - mock_init.assert_called_once_with( - audience=self.AUDIENCE, - subject_token_type=self.SUBJECT_TOKEN_TYPE, - token_url=self.TOKEN_URL, - token_info_url=self.TOKEN_INFO_URL, - credential_source=self.CREDENTIAL_SOURCE, - service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, - service_account_impersonation_options={"token_lifetime_seconds": 2800}, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - quota_project_id=self.QUOTA_PROJECT_ID, - scopes=self.SCOPES, - default_scopes=["default1"], - universe_domain=DEFAULT_UNIVERSE_DOMAIN, - trust_boundary=None, - ) + new_cred = credentials.with_quota_project("project-foo") - # Confirm with_quota_project sets the correct quota project after - # initialization. - assert new_cred.quota_project_id == "project-foo" + assert new_cred.quota_project_id == "project-foo" + assert new_cred.scopes == self.SCOPES + assert new_cred.default_scopes == ["default1"] + assert new_cred._client_id == CLIENT_ID + assert new_cred._client_secret == CLIENT_SECRET + assert new_cred._token_info_url == self.TOKEN_INFO_URL + assert ( + new_cred._service_account_impersonation_url + == self.SERVICE_ACCOUNT_IMPERSONATION_URL + ) + assert new_cred._service_account_impersonation_options == { + "token_lifetime_seconds": 2800 + } def test_info(self): credentials = self.make_credentials(universe_domain="dummy_universe.com") @@ -544,6 +525,23 @@ def test_with_universe_domain(self): new_credentials = credentials.with_universe_domain("dummy_universe.com") assert new_credentials.universe_domain == "dummy_universe.com" + def test_copy_regional_access_boundary_manager_state_and_config(self): + credentials = self.make_credentials() + credentials._rab_manager._data = mock.sentinel.rab_data + credentials._rab_manager._use_blocking_regional_access_boundary_lookup = True + + new_credentials = credentials.with_universe_domain("dummy_universe.com") + + # Verify references to boundary data are shared + assert new_credentials._rab_manager._data == mock.sentinel.rab_data + # Verify blocking config flag is preserved + assert ( + new_credentials._rab_manager._use_blocking_regional_access_boundary_lookup + is True + ) + # Verify target manager object is not replaced + assert new_credentials._rab_manager is not credentials._rab_manager + def test_info_workforce_pool(self): credentials = self.make_workforce_pool_credentials( workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT @@ -979,6 +977,57 @@ def test_refresh_impersonation_without_client_auth_success( assert not credentials.expired assert credentials.token == impersonation_response["accessToken"] + @mock.patch( + "google.auth.metrics.token_request_access_token_impersonate", + return_value=IMPERSONATE_ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, + ) + @mock.patch( + "google.auth.metrics.python_and_auth_lib_version", + return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, + ) + def test_refresh_impersonation_propagates_rab_config( + self, mock_metrics_header_value, mock_auth_lib_value + ): + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800) + ).isoformat("T") + "Z" + token_response = self.SUCCESS_RESPONSE.copy() + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + request = self.make_mock_request( + status=http_client.OK, + data=token_response, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + ) + credentials._set_blocking_regional_access_boundary_lookup() + assert ( + credentials._rab_manager._use_blocking_regional_access_boundary_lookup + is True + ) + + credentials.refresh(request) + + assert credentials._impersonated_credentials is not None + assert ( + credentials._impersonated_credentials._rab_manager._use_blocking_regional_access_boundary_lookup + is True + ) + assert ( + credentials._rab_manager._use_blocking_regional_access_boundary_lookup + is True + ) + assert ( + credentials._rab_manager + is credentials._impersonated_credentials._rab_manager + ) + @mock.patch( "google.auth.metrics.token_request_access_token_impersonate", return_value=IMPERSONATE_ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, @@ -1727,15 +1776,51 @@ def test_before_request_expired(self, utcnow): "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]) } - def test_build_regional_access_boundary_lookup_url_workload(self): + def test_build_regional_access_boundary_lookup_url_workload_standard( + self, monkeypatch + ): + from google.auth.transport import _mtls_helper + + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) + credentials = self.make_credentials() + url = credentials._build_regional_access_boundary_lookup_url() expected_url = "https://iamcredentials.googleapis.com/v1/projects/123456/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + assert url == expected_url + + def test_build_regional_access_boundary_lookup_url_workload_mtls(self, monkeypatch): + from google.auth.transport import _mtls_helper + + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + + credentials = self.make_credentials() + url = credentials._build_regional_access_boundary_lookup_url() + expected_url = "https://iamcredentials.mtls.googleapis.com/v1/projects/123456/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" + assert url == expected_url + + def test_build_regional_access_boundary_lookup_url_workforce_standard( + self, monkeypatch + ): + from google.auth.transport import _mtls_helper + + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) - def test_build_regional_access_boundary_lookup_url_workforce(self): credentials = self.make_workforce_pool_credentials() + url = credentials._build_regional_access_boundary_lookup_url() expected_url = "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + assert url == expected_url + + def test_build_regional_access_boundary_lookup_url_workforce_mtls( + self, monkeypatch + ): + from google.auth.transport import _mtls_helper + + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + + credentials = self.make_workforce_pool_credentials() + url = credentials._build_regional_access_boundary_lookup_url() + expected_url = "https://iamcredentials.mtls.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + assert url == expected_url @pytest.mark.parametrize( "audience", diff --git a/packages/google-auth/tests/test_external_account_authorized_user.py b/packages/google-auth/tests/test_external_account_authorized_user.py index 648966d924bf..69a085e65df5 100644 --- a/packages/google-auth/tests/test_external_account_authorized_user.py +++ b/packages/google-auth/tests/test_external_account_authorized_user.py @@ -601,10 +601,25 @@ def test_from_file_full_options(self, tmpdir): assert creds._revoke_url == REVOKE_URL assert creds._quota_project_id == QUOTA_PROJECT_ID - def test_build_regional_access_boundary_lookup_url(self): + def test_build_regional_access_boundary_lookup_url_standard(self, monkeypatch): + from google.auth.transport import _mtls_helper + + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) + credentials = self.make_credentials() + url = credentials._build_regional_access_boundary_lookup_url() expected_url = "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + assert url == expected_url + + def test_build_regional_access_boundary_lookup_url_mtls(self, monkeypatch): + from google.auth.transport import _mtls_helper + + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + + credentials = self.make_credentials() + url = credentials._build_regional_access_boundary_lookup_url() + expected_url = "https://iamcredentials.mtls.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + assert url == expected_url @pytest.mark.parametrize( "audience", diff --git a/packages/google-auth/tests/test_impersonated_credentials.py b/packages/google-auth/tests/test_impersonated_credentials.py index 500209f663d7..c286e3010f38 100644 --- a/packages/google-auth/tests/test_impersonated_credentials.py +++ b/packages/google-auth/tests/test_impersonated_credentials.py @@ -717,13 +717,31 @@ def test_build_regional_access_boundary_lookup_url_no_email(self): assert credentials._build_regional_access_boundary_lookup_url() is None - def test_build_regional_access_boundary_lookup_url_success(self): + def test_build_regional_access_boundary_lookup_url_success_standard( + self, monkeypatch + ): + from google.auth.transport import _mtls_helper + + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) + credentials = self.make_credentials() - # Ensure service_account_email is properly set by default mock + url = credentials._build_regional_access_boundary_lookup_url() expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( credentials.service_account_email ) - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + assert url == expected_url + + def test_build_regional_access_boundary_lookup_url_success_mtls(self, monkeypatch): + from google.auth.transport import _mtls_helper + + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + + credentials = self.make_credentials() + url = credentials._build_regional_access_boundary_lookup_url() + expected_url = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + credentials.service_account_email + ) + assert url == expected_url def test_with_scopes_provide_default_scopes(self): credentials = self.make_credentials() diff --git a/packages/google-auth/tests/test_jwt.py b/packages/google-auth/tests/test_jwt.py index 4c5988469494..27b951b8b7bc 100644 --- a/packages/google-auth/tests/test_jwt.py +++ b/packages/google-auth/tests/test_jwt.py @@ -553,6 +553,57 @@ def test_before_request_refreshes(self): self.credentials.before_request(None, "GET", "http://example.com?a=1#3", {}) assert self.credentials.valid + def test_build_regional_access_boundary_lookup_url_standard(self, monkeypatch): + from google.auth.transport import _mtls_helper + + # Mock check_use_client_cert to return False to simulate standard TLS + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) + + url = self.credentials._build_regional_access_boundary_lookup_url() + expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + self.SERVICE_ACCOUNT_EMAIL + ) + assert url == expected_url + + def test_build_regional_access_boundary_lookup_url_mtls(self, monkeypatch): + from google.auth.transport import _mtls_helper + + # Mock check_use_client_cert to return True to simulate mTLS + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + + url = self.credentials._build_regional_access_boundary_lookup_url() + expected_url = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + self.SERVICE_ACCOUNT_EMAIL + ) + assert url == expected_url + + def test_cloning_retains_rab_manager_data(self): + self.credentials._rab_manager._data = mock.sentinel.rab_data + + cloned_claims = self.credentials.with_claims(audience="new-audience") + cloned_quota = self.credentials.with_quota_project("new-quota") + + # Verify references to immutable boundary data are shared + assert cloned_claims._rab_manager._data == mock.sentinel.rab_data + assert cloned_quota._rab_manager._data == mock.sentinel.rab_data + + # Verify manager objects and lock properties are isolated to prevent race conditions + assert cloned_claims._rab_manager is not self.credentials._rab_manager + assert cloned_quota._rab_manager is not self.credentials._rab_manager + + def test_from_signing_credentials_copies_rab_state(self): + from google.oauth2 import service_account + + sa_creds = service_account.Credentials.from_service_account_info( + SERVICE_ACCOUNT_INFO + ) + sa_creds._rab_manager._data = mock.sentinel.rab_data + + jwt_creds = jwt.Credentials.from_signing_credentials(sa_creds, audience="aud") + + assert jwt_creds._rab_manager._data == mock.sentinel.rab_data + assert jwt_creds._rab_manager is not sa_creds._rab_manager + class TestOnDemandCredentials(object): SERVICE_ACCOUNT_EMAIL = "service-account@example.com" diff --git a/packages/google-auth/tests_async/oauth2/test__client_async.py b/packages/google-auth/tests_async/oauth2/test__client_async.py index 5ad9596cf85c..a3abd9067186 100644 --- a/packages/google-auth/tests_async/oauth2/test__client_async.py +++ b/packages/google-auth/tests_async/oauth2/test__client_async.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import datetime import http.client as http_client import json @@ -23,6 +24,7 @@ from google.auth import _helpers from google.auth import _jwt_async as jwt from google.auth import exceptions +from google.auth.aio import transport as aio_transport from google.oauth2 import _client as sync_client from google.oauth2 import _client_async as _client from tests.oauth2 import test__client as test_client @@ -40,6 +42,17 @@ def make_request(response_data, status=http_client.OK, text=False): return request +def make_aio_request(response_data, status_code=http_client.OK, text=False): + """Creates a mock request/response conforming to the google.auth.aio.transport interface (exposing .status_code and .read()).""" + response = mock.AsyncMock(spec=aio_transport.Response) + response.status_code = status_code + data = response_data if text else json.dumps(response_data).encode("utf-8") + response.read = mock.AsyncMock(return_value=data) + request = mock.AsyncMock(spec=aio_transport.Request) + request.return_value = response + return request + + @pytest.mark.asyncio async def test__token_endpoint_request(): request = make_request({"test": "response"}) @@ -473,7 +486,8 @@ async def test_refresh_grant_retry_with_retry( @pytest.mark.asyncio @pytest.mark.parametrize("can_retry", [True, False]) -async def test__token_endpoint_request_no_throw_with_retry(can_retry): +@mock.patch("time.sleep", return_value=None) +async def test__token_endpoint_request_no_throw_with_retry(mock_sleep, can_retry): mock_request = make_request( {"error": "help", "error_description": "I'm alive"}, http_client.INTERNAL_SERVER_ERROR, @@ -490,5 +504,168 @@ async def test__token_endpoint_request_no_throw_with_retry(can_retry): if can_retry: assert mock_request.call_count == 3 + assert mock_sleep.call_count == 2 else: assert mock_request.call_count == 1 + mock_sleep.assert_not_called() + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_success(): + request = make_aio_request( + {"encodedLocations": "0xA30", "locations": ["us-central1"]} + ) + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result == {"encodedLocations": "0xA30", "locations": ["us-central1"]} + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_legacy_transport(): + # Create a legacy mock response that has .status and .content() + response = mock.AsyncMock(spec=["transport.Response"]) + response.status = http_client.OK + + data = json.dumps( + {"encodedLocations": "0xA30", "locations": ["us-central1"]} + ).encode("utf-8") + response.content = mock.AsyncMock(return_value=data) + + request = mock.AsyncMock(spec=["transport.Request"]) + request.return_value = response + + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result == {"encodedLocations": "0xA30", "locations": ["us-central1"]} + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_malformed(): + request = make_aio_request({"locations": ["us-central1"]}) + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result is None + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_invalid_json(): + request = make_aio_request("Service Unavailable", text=True) + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result is None + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_non_dict_response(): + request = make_aio_request(123) + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result is None + + +@pytest.mark.asyncio +@mock.patch("asyncio.wait_for", side_effect=asyncio.TimeoutError) +async def test__lookup_regional_access_boundary_request_no_throw_timeout(mock_wait_for): + request = mock.AsyncMock(spec=["transport.Request"]) + + ( + success, + data, + retryable, + ) = await _client._lookup_regional_access_boundary_request_no_throw( + request, "http://example.com", fail_fast=True + ) + + assert success is False + assert data == {} + assert retryable is True + + +@pytest.mark.asyncio +@mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) +async def test__lookup_regional_access_boundary_request_no_throw_bad_gateway_retry( + mock_sleep, +): + bad_gateway_response = mock.AsyncMock(spec=["transport.Response"]) + bad_gateway_response.status = http_client.BAD_GATEWAY + bad_gateway_response.content = mock.AsyncMock(return_value=b"{}") + + ok_response = mock.AsyncMock(spec=["transport.Response"]) + ok_response.status = http_client.OK + ok_response.content = mock.AsyncMock(return_value=b'{"encodedLocations": "0xA30"}') + + request = mock.AsyncMock(spec=["transport.Request"]) + request.side_effect = [bad_gateway_response, ok_response] + + ( + success, + data, + retryable, + ) = await _client._lookup_regional_access_boundary_request_no_throw( + request, "http://example.com" + ) + + assert success is True + assert data == {"encodedLocations": "0xA30"} + assert request.call_count == 2 + + +@pytest.mark.asyncio +@mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) +async def test__lookup_regional_access_boundary_request_no_throw_transport_error( + mock_sleep, +): + request = mock.AsyncMock(spec=["transport.Request"]) + request.side_effect = exceptions.TransportError("Socket connection failed") + + ( + success, + data, + retryable, + ) = await _client._lookup_regional_access_boundary_request_no_throw( + request, "http://example.com" + ) + + assert success is False + assert data == {} + assert retryable is True + assert request.call_count == 6 + assert mock_sleep.call_count == 5 + + +@pytest.mark.asyncio +@mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) +async def test__lookup_regional_access_boundary_request_no_throw_non_json_bad_gateway_retry( + mock_sleep, +): + bad_gateway_response = mock.AsyncMock(spec=["status", "content"]) + bad_gateway_response.status = http_client.BAD_GATEWAY + bad_gateway_response.content = mock.AsyncMock( + return_value=b"Bad Gateway" + ) + + ok_response = mock.AsyncMock(spec=["status", "content"]) + ok_response.status = http_client.OK + ok_response.content = mock.AsyncMock(return_value=b'{"encodedLocations": "0xA30"}') + + request = mock.AsyncMock(spec=["__call__"]) + request.side_effect = [bad_gateway_response, ok_response] + + ( + success, + data, + retryable, + ) = await _client._lookup_regional_access_boundary_request_no_throw( + request, "http://example.com" + ) + + assert success is True + assert data == {"encodedLocations": "0xA30"} + assert retryable is None + assert request.call_count == 2 + mock_sleep.assert_called_once() diff --git a/packages/google-auth/tests_async/oauth2/test_service_account_async.py b/packages/google-auth/tests_async/oauth2/test_service_account_async.py index 5a9a89fcaac2..e0c2e0d60a60 100644 --- a/packages/google-auth/tests_async/oauth2/test_service_account_async.py +++ b/packages/google-auth/tests_async/oauth2/test_service_account_async.py @@ -229,6 +229,143 @@ async def test_before_request_refreshes(self, jwt_grant): # Credentials should now be valid. assert credentials.valid + @pytest.mark.asyncio + async def test_before_request_triggers_rab_refresh(self): + credentials = self.make_credentials() + credentials.token = "tok" + + request = mock.AsyncMock(spec=["transport.Request"]) + headers1 = {} + + with mock.patch.object( + credentials, + "_lookup_regional_access_boundary", + new_callable=mock.AsyncMock, + ) as mock_lookup, mock.patch.object( + credentials, + "_is_regional_access_boundary_lookup_required", + return_value=True, + ): + mock_lookup.return_value = { + "locations": ["us-central1", "europe-west1"], + "encodedLocations": "0xA30", + } + + # The first request triggers a background refresh and returns immediately. + await credentials.before_request( + request, "GET", "https://storage.googleapis.com/bucket", headers1 + ) + assert "x-allowed-locations" not in headers1 + + # Wait for the background task to finish and update the cache. + await credentials._rab_manager.refresh_manager._worker_task + mock_lookup.assert_called_once_with(request) + + # The second request retrieves the locations from the cache. + headers2 = {} + await credentials.before_request( + request, "GET", "https://storage.googleapis.com/bucket", headers2 + ) + assert headers2["x-allowed-locations"] == "0xA30" + + @pytest.mark.asyncio + async def test_before_request_rab_refresh_failure_ignored(self): + credentials = self.make_credentials() + credentials.token = "tok" + + request = mock.AsyncMock(spec=["transport.Request"]) + headers = {} + + with mock.patch.object( + credentials, + "_lookup_regional_access_boundary", + new_callable=mock.AsyncMock, + side_effect=Exception("Transport failed"), + ) as mock_lookup, mock.patch.object( + credentials, + "_is_regional_access_boundary_lookup_required", + return_value=True, + ): + # Any transport/lookup failure must be caught gracefully during refresh. + await credentials.before_request( + request, "GET", "https://storage.googleapis.com/bucket", headers + ) + + # Wait for the background task to finish. + await credentials._rab_manager.refresh_manager._worker_task + + mock_lookup.assert_called_once_with(request) + assert "x-allowed-locations" not in headers + + @pytest.mark.asyncio + async def test_before_request_triggers_blocking_rab_refresh(self): + credentials = self.make_credentials() + credentials.token = "tok" + credentials._set_blocking_regional_access_boundary_lookup() + + request = mock.AsyncMock(spec=["transport.Request"]) + headers = {} + + with mock.patch.object( + credentials, + "_lookup_regional_access_boundary", + new_callable=mock.AsyncMock, + ) as mock_lookup, mock.patch.object( + credentials, + "_is_regional_access_boundary_lookup_required", + return_value=True, + ): + mock_lookup.return_value = { + "locations": ["us-central1", "europe-west1"], + "encodedLocations": "0xA30", + } + + # When blocking lookup is enabled, the first request awaits the lookup sequentially. + await credentials.before_request( + request, "GET", "https://storage.googleapis.com/bucket", headers + ) + + mock_lookup.assert_called_once_with(request, fail_fast=True) + assert headers["x-allowed-locations"] == "0xA30" + + @pytest.mark.asyncio + async def test_maybe_start_regional_access_boundary_refresh_async_invalid_url(self): + credentials = self.make_credentials() + request = mock.create_autospec(transport.Request) + + # Verifies that passing invalid/non-string URLs asynchronously fails safe without crashing. + await credentials._maybe_start_regional_access_boundary_refresh_async( + request, url=None + ) + await credentials._maybe_start_regional_access_boundary_refresh_async( + request, url=123 + ) + await credentials._maybe_start_regional_access_boundary_refresh_async( + request, url=object() + ) + + def test_unpickle_old_credentials_without_rab(self): + from google.auth import _regional_access_boundary_utils + + credentials = self.make_credentials() + old_state = credentials.__dict__.copy() + if "_rab_manager" in old_state: + del old_state["_rab_manager"] + if "_use_non_blocking_refresh" in old_state: + del old_state["_use_non_blocking_refresh"] + if "_refresh_worker" in old_state: + del old_state["_refresh_worker"] + + new_instance = type(credentials).__new__(type(credentials)) + new_instance.__setstate__(old_state) + + # Verify the manager was correctly restored with the async refresh manager! + assert hasattr(new_instance, "_rab_manager") + assert isinstance( + new_instance._rab_manager.refresh_manager, + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager, + ) + class TestIDTokenCredentials(object): SERVICE_ACCOUNT_EMAIL = "service-account@example.com" diff --git a/packages/google-auth/tests_async/test__regional_access_boundary_utils.py b/packages/google-auth/tests_async/test__regional_access_boundary_utils.py new file mode 100644 index 000000000000..268ee37261c8 --- /dev/null +++ b/packages/google-auth/tests_async/test__regional_access_boundary_utils.py @@ -0,0 +1,84 @@ +# Copyright 2026 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest import mock + +import pytest # type: ignore + +from google.auth import _regional_access_boundary_utils + + +@pytest.mark.asyncio +async def test_async_refresh_manager_start_refresh(): + credentials = mock.AsyncMock() + credentials._lookup_regional_access_boundary.return_value = { + "encodedLocations": "0xA30" + } + + request = mock.Mock() + rab_manager = mock.Mock() + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + manager.start_refresh(credentials, request, rab_manager) + + # Wait for the background task to finish + await manager._worker_task + + credentials._lookup_regional_access_boundary.assert_called_once_with(request) + rab_manager.process_regional_access_boundary_info.assert_called_once_with( + {"encodedLocations": "0xA30"} + ) + + +@pytest.mark.asyncio +async def test_async_refresh_manager_duplicate_refresh_prevented(): + credentials = mock.AsyncMock() + + # Use events to control the concurrency timing + lookup_started = asyncio.Event() + lookup_finish = asyncio.Event() + + async def controlled_lookup(*args, **kwargs): + lookup_started.set() # Signal that the background lookup has started. + await lookup_finish.wait() # Block until the test allows the lookup to complete. + return {"encodedLocations": "0xA30"} + + credentials._lookup_regional_access_boundary.side_effect = controlled_lookup + + request = mock.Mock() + rab_manager = mock.Mock() + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + # Start the initial refresh task in the background. + manager.start_refresh(credentials, request, rab_manager) + + # Wait until the background task has begun executing the lookup. + await lookup_started.wait() + + # Attempt a second refresh while the initial task is still in progress. + manager.start_refresh(credentials, request, rab_manager) + + # Unblock the initial task and wait for it to complete. + lookup_finish.set() + await manager._worker_task + + # Verify that the second refresh request was ignored and only one lookup occurred. + assert credentials._lookup_regional_access_boundary.call_count == 1 diff --git a/packages/google-auth/tests_async/test_jwt_async.py b/packages/google-auth/tests_async/test_jwt_async.py index 9d9eca4e2852..9e6054fa93ef 100644 --- a/packages/google-auth/tests_async/test_jwt_async.py +++ b/packages/google-auth/tests_async/test_jwt_async.py @@ -143,6 +143,47 @@ def test_with_quota_project(self): assert new_credentials._additional_claims == self.credentials._additional_claims assert new_credentials._quota_project_id == quota_project_id + def test_build_regional_access_boundary_lookup_url_standard(self, monkeypatch): + from google.auth.transport import _mtls_helper + + # Mock check_use_client_cert to return False to simulate standard TLS + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: False) + + url = self.credentials._build_regional_access_boundary_lookup_url() + expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + self.SERVICE_ACCOUNT_EMAIL + ) + assert url == expected_url + + def test_build_regional_access_boundary_lookup_url_mtls(self, monkeypatch): + from google.auth.transport import _mtls_helper + + # Mock check_use_client_cert to return True to simulate mTLS + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + + url = self.credentials._build_regional_access_boundary_lookup_url() + expected_url = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + self.SERVICE_ACCOUNT_EMAIL + ) + assert url == expected_url + + def test_unpickle_old_credentials_without_rab(self): + from google.auth import _regional_access_boundary_utils + + credentials = self.credentials + old_state = credentials.__dict__.copy() + if "_rab_manager" in old_state: + del old_state["_rab_manager"] + + new_instance = type(credentials).__new__(type(credentials)) + new_instance.__setstate__(old_state) + + assert hasattr(new_instance, "_rab_manager") + assert isinstance( + new_instance._rab_manager.refresh_manager, + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager, + ) + def test_sign_bytes(self): to_sign = b"123" signature = self.credentials.sign_bytes(to_sign) @@ -326,10 +367,11 @@ def test_refresh(self): with pytest.raises(exceptions.RefreshError): self.credentials.refresh(None) - def test_before_request(self): + @pytest.mark.asyncio + async def test_before_request(self): headers = {} - self.credentials.before_request( + await self.credentials.before_request( None, "GET", "http://example.com?a=1#3", headers ) @@ -339,7 +381,9 @@ def test_before_request(self): assert payload["aud"] == "http://example.com" # Making another request should re-use the same token. - self.credentials.before_request(None, "GET", "http://example.com?b=2", headers) + await self.credentials.before_request( + None, "GET", "http://example.com?b=2", headers + ) _, new_token = headers["authorization"].split(" ")