diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index 34d04dde93..bc91d48f79 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -79,6 +79,7 @@ class OAuth2Auth(BaseModelWithConfig): refresh_token: Optional[str] = None expires_at: Optional[int] = None expires_in: Optional[int] = None + audience: Optional[str] = None class ServiceAccountCredential(BaseModelWithConfig): diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 2e2a9a074f..7a51a71e29 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -188,9 +188,16 @@ def generate_auth_uri( scope=" ".join(scopes), redirect_uri=auth_credential.oauth2.redirect_uri, ) + params = { + "access_type": "offline", + "prompt": "consent", + } + if auth_credential.oauth2.audience: + params["audience"] = auth_credential.oauth2.audience uri, state = client.create_authorization_url( - url=authorization_endpoint, access_type="offline", prompt="consent" + url=authorization_endpoint, **params ) + exchanged_auth_credential = auth_credential.model_copy(deep=True) exchanged_auth_credential.oauth2.auth_uri = uri exchanged_auth_credential.oauth2.state = state diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index 2a65f7795f..0a2a2f7802 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -61,7 +61,10 @@ def __init__( self.state = state def create_authorization_url(self, url, **kwargs): - return f"{url}?client_id={self.client_id}&scope={self.scope}", "mock_state" + params = f"client_id={self.client_id}&scope={self.scope}" + if kwargs.get("audience"): + params += f"&audience={kwargs.get('audience')}" + return f"{url}?{params}", "mock_state" def fetch_token( self, @@ -225,8 +228,27 @@ def test_generate_auth_uri_oauth2(self, auth_config): "https://example.com/oauth2/authorize" ) assert "client_id=mock_client_id" in result.oauth2.auth_uri + assert "audience" not in result.oauth2.auth_uri assert result.oauth2.state == "mock_state" + @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session) + def test_generate_auth_uri_with_audience_and_prompt( + self, openid_auth_scheme, oauth2_credentials + ): + """Test generating an auth URI with audience and prompt.""" + oauth2_credentials.oauth2.audience = "test_audience" + exchanged = oauth2_credentials.model_copy(deep=True) + + config = AuthConfig( + auth_scheme=openid_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged, + ) + handler = AuthHandler(config) + result = handler.generate_auth_uri() + + assert "audience=test_audience" in result.oauth2.auth_uri + @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session) def test_generate_auth_uri_openid( self, openid_auth_scheme, oauth2_credentials