From 089078edf2c0c0cbfdc6324e336c1cfa72db0949 Mon Sep 17 00:00:00 2001 From: David Sullivan Date: Sun, 7 Sep 2025 13:46:16 +0100 Subject: [PATCH 1/7] feat: add token_endpoint_auth_method support to OAuth2 credentials Add token_endpoint_auth_method field to OAuth2Auth class to allow configuring OAuth2 token endpoint authentication methods. Supports standard methods: client_secret_basic (default), client_secret_post, client_secret_jwt, and private_key_jwt. Changes: - Add token_endpoint_auth_method field to OAuth2Auth with default 'client_secret_basic' - Update create_oauth2_session to pass auth method to OAuth2Session - Add comprehensive test coverage for all authentication methods #non-breaking --- src/google/adk/auth/auth_credential.py | 1 + src/google/adk/auth/oauth2_credential_util.py | 1 + .../auth/test_oauth2_credential_util.py | 87 +++++++++++++++++++ 3 files changed, 89 insertions(+) diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index bc91d48f79..9694e8b1ad 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -80,6 +80,7 @@ class OAuth2Auth(BaseModelWithConfig): expires_at: Optional[int] = None expires_in: Optional[int] = None audience: Optional[str] = None + token_endpoint_auth_method: Optional[str] = "client_secret_basic" class ServiceAccountCredential(BaseModelWithConfig): diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py index cc315bd29e..843f1152b6 100644 --- a/src/google/adk/auth/oauth2_credential_util.py +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -82,6 +82,7 @@ def create_oauth2_session( scope=" ".join(scopes), redirect_uri=auth_credential.oauth2.redirect_uri, state=auth_credential.oauth2.state, + token_endpoint_auth_method=auth_credential.oauth2.token_endpoint_auth_method, ), token_endpoint, ) diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py index f1fd607ff5..b78083ede7 100644 --- a/tests/unittests/auth/test_oauth2_credential_util.py +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -122,6 +122,93 @@ def test_create_oauth2_session_missing_credentials(self): assert client is None assert token_endpoint is None + def test_create_oauth2_session_with_token_endpoint_auth_method(self): + """Test create_oauth2_session with token_endpoint_auth_method specified.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + token_endpoint_auth_method="client_secret_post", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.client_id == "test_client_id" + assert client.client_secret == "test_client_secret" + assert client.token_endpoint_auth_method == "client_secret_post" + + def test_create_oauth2_session_with_default_token_endpoint_auth_method(self): + """Test create_oauth2_session with default token_endpoint_auth_method (None).""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.client_id == "test_client_id" + assert client.client_secret == "test_client_secret" + assert client.token_endpoint_auth_method == "client_secret_basic" + + def test_create_oauth2_session_oauth2_scheme_with_token_endpoint_auth_method( + self, + ): + """Test create_oauth2_session with OAuth2 scheme and token_endpoint_auth_method.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + scheme = OAuth2(type_="oauth2", flows=flows) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + token_endpoint_auth_method="client_secret_jwt", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.token_endpoint_auth_method == "client_secret_jwt" + def test_update_credential_with_tokens(self): """Test update_credential_with_tokens function.""" credential = AuthCredential( From d7ae7732b1a0831e2f39bec11b57a5d4ad8d3059 Mon Sep 17 00:00:00 2001 From: David Sullivan Date: Sun, 7 Sep 2025 14:26:20 +0100 Subject: [PATCH 2/7] feat: use Literal type for token_endpoint_auth_method Improve type safety by using Literal type instead of Optional[str] for token_endpoint_auth_method field. This provides compile-time validation of allowed authentication methods and better IDE support. Supported methods: - client_secret_basic (default) - client_secret_post - client_secret_jwt - private_key_jwt --- src/google/adk/auth/auth_credential.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index 9694e8b1ad..f707d6a0bc 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -18,6 +18,7 @@ from typing import Any from typing import Dict from typing import List +from typing import Literal from typing import Optional from pydantic import alias_generators @@ -80,7 +81,14 @@ class OAuth2Auth(BaseModelWithConfig): expires_at: Optional[int] = None expires_in: Optional[int] = None audience: Optional[str] = None - token_endpoint_auth_method: Optional[str] = "client_secret_basic" + token_endpoint_auth_method: Optional[ + Literal[ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "private_key_jwt", + ] + ] = "client_secret_basic" class ServiceAccountCredential(BaseModelWithConfig): From ce7129d09aee15c8cc06c770238e8c574323cd14 Mon Sep 17 00:00:00 2001 From: David Sullivan Date: Sun, 7 Sep 2025 14:27:32 +0100 Subject: [PATCH 3/7] Update tests/unittests/auth/test_oauth2_credential_util.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/unittests/auth/test_oauth2_credential_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py index b78083ede7..2ec41467fc 100644 --- a/tests/unittests/auth/test_oauth2_credential_util.py +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -153,7 +153,7 @@ def test_create_oauth2_session_with_token_endpoint_auth_method(self): assert client.token_endpoint_auth_method == "client_secret_post" def test_create_oauth2_session_with_default_token_endpoint_auth_method(self): - """Test create_oauth2_session with default token_endpoint_auth_method (None).""" + """Test create_oauth2_session with default token_endpoint_auth_method.""" scheme = OpenIdConnectWithConfig( type_="openIdConnect", openId_connect_url=( From c0e99c66c10f430c3dea09325c546754939a1afa Mon Sep 17 00:00:00 2001 From: David Sullivan Date: Sun, 7 Sep 2025 16:45:16 +0100 Subject: [PATCH 4/7] refactor: eliminate code duplication in OAuth2 credential tests Create pytest fixture and helper function to reduce redundancy between test_create_oauth2_session_with_token_endpoint_auth_method and test_create_oauth2_session_with_default_token_endpoint_auth_method tests. --- .../auth/test_oauth2_credential_util.py | 79 +++++++++---------- 1 file changed, 38 insertions(+), 41 deletions(-) diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py index 2ec41467fc..92d125fc52 100644 --- a/tests/unittests/auth/test_oauth2_credential_util.py +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -15,6 +15,7 @@ import time from unittest.mock import Mock +import pytest from authlib.oauth2.rfc6749 import OAuth2Token from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import OAuthFlowAuthorizationCode @@ -27,6 +28,37 @@ from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +@pytest.fixture +def openid_connect_scheme(): + """Fixture providing a standard OpenIdConnectWithConfig scheme.""" + return OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + + +def create_oauth2_auth_credential(token_endpoint_auth_method=None): + """Helper function to create OAuth2Auth credential with optional token_endpoint_auth_method.""" + oauth2_auth = OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + ) + if token_endpoint_auth_method is not None: + oauth2_auth.token_endpoint_auth_method = token_endpoint_auth_method + + return AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=oauth2_auth, + ) + + class TestOAuth2CredentialUtil: """Test suite for OAuth2 credential utility functions.""" @@ -122,29 +154,11 @@ def test_create_oauth2_session_missing_credentials(self): assert client is None assert token_endpoint is None - def test_create_oauth2_session_with_token_endpoint_auth_method(self): + def test_create_oauth2_session_with_token_endpoint_auth_method(self, openid_connect_scheme): """Test create_oauth2_session with token_endpoint_auth_method specified.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid", "profile"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - state="test_state", - token_endpoint_auth_method="client_secret_post", - ), - ) + credential = create_oauth2_auth_credential(token_endpoint_auth_method="client_secret_post") - client, token_endpoint = create_oauth2_session(scheme, credential) + client, token_endpoint = create_oauth2_session(openid_connect_scheme, credential) assert client is not None assert token_endpoint == "https://example.com/token" @@ -152,28 +166,11 @@ def test_create_oauth2_session_with_token_endpoint_auth_method(self): assert client.client_secret == "test_client_secret" assert client.token_endpoint_auth_method == "client_secret_post" - def test_create_oauth2_session_with_default_token_endpoint_auth_method(self): + def test_create_oauth2_session_with_default_token_endpoint_auth_method(self, openid_connect_scheme): """Test create_oauth2_session with default token_endpoint_auth_method.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid", "profile"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - state="test_state", - ), - ) + credential = create_oauth2_auth_credential() - client, token_endpoint = create_oauth2_session(scheme, credential) + client, token_endpoint = create_oauth2_session(openid_connect_scheme, credential) assert client is not None assert token_endpoint == "https://example.com/token" From 4c2ff81fcfcdd22e345dbbcb0eb255223962ed85 Mon Sep 17 00:00:00 2001 From: David Sullivan Date: Mon, 15 Sep 2025 21:52:02 +0100 Subject: [PATCH 5/7] run pyink and isort --- .../auth/test_oauth2_credential_util.py | 68 +++++++++++-------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py index 92d125fc52..952ede61da 100644 --- a/tests/unittests/auth/test_oauth2_credential_util.py +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -15,7 +15,6 @@ import time from unittest.mock import Mock -import pytest from authlib.oauth2.rfc6749 import OAuth2Token from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import OAuthFlowAuthorizationCode @@ -26,37 +25,36 @@ from google.adk.auth.auth_schemes import OpenIdConnectWithConfig from google.adk.auth.oauth2_credential_util import create_oauth2_session from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +import pytest @pytest.fixture def openid_connect_scheme(): - """Fixture providing a standard OpenIdConnectWithConfig scheme.""" - return OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid", "profile"], - ) + """Fixture providing a standard OpenIdConnectWithConfig scheme.""" + return OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url="https://example.com/.well-known/openid_configuration", + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) def create_oauth2_auth_credential(token_endpoint_auth_method=None): - """Helper function to create OAuth2Auth credential with optional token_endpoint_auth_method.""" - oauth2_auth = OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - state="test_state", - ) - if token_endpoint_auth_method is not None: - oauth2_auth.token_endpoint_auth_method = token_endpoint_auth_method - - return AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=oauth2_auth, - ) + """Helper function to create OAuth2Auth credential with optional token_endpoint_auth_method.""" + oauth2_auth = OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + ) + if token_endpoint_auth_method is not None: + oauth2_auth.token_endpoint_auth_method = token_endpoint_auth_method + + return AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=oauth2_auth, + ) class TestOAuth2CredentialUtil: @@ -154,11 +152,17 @@ def test_create_oauth2_session_missing_credentials(self): assert client is None assert token_endpoint is None - def test_create_oauth2_session_with_token_endpoint_auth_method(self, openid_connect_scheme): + def test_create_oauth2_session_with_token_endpoint_auth_method( + self, openid_connect_scheme + ): """Test create_oauth2_session with token_endpoint_auth_method specified.""" - credential = create_oauth2_auth_credential(token_endpoint_auth_method="client_secret_post") + credential = create_oauth2_auth_credential( + token_endpoint_auth_method="client_secret_post" + ) - client, token_endpoint = create_oauth2_session(openid_connect_scheme, credential) + client, token_endpoint = create_oauth2_session( + openid_connect_scheme, credential + ) assert client is not None assert token_endpoint == "https://example.com/token" @@ -166,11 +170,15 @@ def test_create_oauth2_session_with_token_endpoint_auth_method(self, openid_conn assert client.client_secret == "test_client_secret" assert client.token_endpoint_auth_method == "client_secret_post" - def test_create_oauth2_session_with_default_token_endpoint_auth_method(self, openid_connect_scheme): + def test_create_oauth2_session_with_default_token_endpoint_auth_method( + self, openid_connect_scheme + ): """Test create_oauth2_session with default token_endpoint_auth_method.""" credential = create_oauth2_auth_credential() - client, token_endpoint = create_oauth2_session(openid_connect_scheme, credential) + client, token_endpoint = create_oauth2_session( + openid_connect_scheme, credential + ) assert client is not None assert token_endpoint == "https://example.com/token" From 3b99b76c3df0b61d730ea0464b1ebcec43531812 Mon Sep 17 00:00:00 2001 From: Mark Scannell Date: Tue, 9 Dec 2025 17:52:33 +0000 Subject: [PATCH 6/7] - Added typing to some of the tests - Consolidated two tests into one --- .../auth/test_oauth2_credential_util.py | 57 +++++++++---------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py index 952ede61da..0882cad80e 100644 --- a/tests/unittests/auth/test_oauth2_credential_util.py +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -13,6 +13,7 @@ # limitations under the License. import time +from typing import Optional from unittest.mock import Mock from authlib.oauth2.rfc6749 import OAuth2Token @@ -29,7 +30,7 @@ @pytest.fixture -def openid_connect_scheme(): +def openid_connect_scheme() -> OpenIdConnectWithConfig: """Fixture providing a standard OpenIdConnectWithConfig scheme.""" return OpenIdConnectWithConfig( type_="openIdConnect", @@ -40,7 +41,10 @@ def openid_connect_scheme(): ) -def create_oauth2_auth_credential(token_endpoint_auth_method=None): +def create_oauth2_auth_credential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + token_endpoint_auth_method: Optional[str] = None, +): """Helper function to create OAuth2Auth credential with optional token_endpoint_auth_method.""" oauth2_auth = OAuth2Auth( client_id="test_client_id", @@ -52,7 +56,7 @@ def create_oauth2_auth_credential(token_endpoint_auth_method=None): oauth2_auth.token_endpoint_auth_method = token_endpoint_auth_method return AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + auth_type=auth_type, oauth2=oauth2_auth, ) @@ -71,14 +75,9 @@ def test_create_oauth2_session_openid_connect(self): token_endpoint="https://example.com/token", scopes=["openid", "profile"], ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - state="test_state", - ), + credential = create_oauth2_auth_credential( + auth_type=AuthCredentialTypes.OAUTH2, + token_endpoint_auth_method="client_secret_jwt", ) client, token_endpoint = create_oauth2_session(scheme, credential) @@ -152,12 +151,24 @@ def test_create_oauth2_session_missing_credentials(self): assert client is None assert token_endpoint is None + # def test_create_oauth2_session_with_token_endpoint_auth_method( + # self, openid_connect_scheme + @pytest.mark.parametrize( + "token_endpoint_auth_method, expected_auth_method", + [ + ("client_secret_post", "client_secret_post"), + (None, "client_secret_basic"), + ], + ) def test_create_oauth2_session_with_token_endpoint_auth_method( - self, openid_connect_scheme + self, + openid_connect_scheme, + token_endpoint_auth_method, + expected_auth_method, ): - """Test create_oauth2_session with token_endpoint_auth_method specified.""" + """Test create_oauth2_session with various token_endpoint_auth_method settings.""" credential = create_oauth2_auth_credential( - token_endpoint_auth_method="client_secret_post" + token_endpoint_auth_method=token_endpoint_auth_method ) client, token_endpoint = create_oauth2_session( @@ -168,23 +179,7 @@ def test_create_oauth2_session_with_token_endpoint_auth_method( assert token_endpoint == "https://example.com/token" assert client.client_id == "test_client_id" assert client.client_secret == "test_client_secret" - assert client.token_endpoint_auth_method == "client_secret_post" - - def test_create_oauth2_session_with_default_token_endpoint_auth_method( - self, openid_connect_scheme - ): - """Test create_oauth2_session with default token_endpoint_auth_method.""" - credential = create_oauth2_auth_credential() - - client, token_endpoint = create_oauth2_session( - openid_connect_scheme, credential - ) - - assert client is not None - assert token_endpoint == "https://example.com/token" - assert client.client_id == "test_client_id" - assert client.client_secret == "test_client_secret" - assert client.token_endpoint_auth_method == "client_secret_basic" + assert client.token_endpoint_auth_method == expected_auth_method def test_create_oauth2_session_oauth2_scheme_with_token_endpoint_auth_method( self, From 00142d3286ca0c7997998e4ac960d182f3cf848d Mon Sep 17 00:00:00 2001 From: Mark Scannell Date: Thu, 11 Dec 2025 11:39:29 +0000 Subject: [PATCH 7/7] Removed extraneous comments --- tests/unittests/auth/test_oauth2_credential_util.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py index 0882cad80e..cab4c49374 100644 --- a/tests/unittests/auth/test_oauth2_credential_util.py +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -151,8 +151,6 @@ def test_create_oauth2_session_missing_credentials(self): assert client is None assert token_endpoint is None - # def test_create_oauth2_session_with_token_endpoint_auth_method( - # self, openid_connect_scheme @pytest.mark.parametrize( "token_endpoint_auth_method, expected_auth_method", [