diff --git a/packages/toolbox-adk/src/toolbox_adk/tool.py b/packages/toolbox-adk/src/toolbox_adk/tool.py index 1b0476f0..dde37ce4 100644 --- a/packages/toolbox-adk/src/toolbox_adk/tool.py +++ b/packages/toolbox-adk/src/toolbox_adk/tool.py @@ -36,6 +36,21 @@ from .client import USER_TOKEN_CONTEXT_VAR from .credentials import CredentialConfig, CredentialType +# --- Monkey Patch ADK OAuth2 Exchange to Retain ID Tokens --- +# Google's ID Token is required by MCP Toolbox but ADK's `update_credential_with_tokens` natively drops the `id_token`. +# TODO(id_token): Remove this monkey patch once the PR https://github.com/google/adk-python/pull/4402 is merged. +import google.adk.auth.oauth2_credential_util as oauth2_credential_util +import google.adk.auth.exchanger.oauth2_credential_exchanger as oauth2_credential_exchanger +_orig_update_cred = oauth2_credential_util.update_credential_with_tokens + +def _patched_update_credential_with_tokens(auth_credential, tokens): + _orig_update_cred(auth_credential, tokens) + if tokens and "id_token" in tokens and auth_credential and auth_credential.oauth2: + setattr(auth_credential.oauth2, "id_token", tokens["id_token"]) + +oauth2_credential_util.update_credential_with_tokens = _patched_update_credential_with_tokens +oauth2_credential_exchanger.update_credential_with_tokens = _patched_update_credential_with_tokens +# ------------------------------------------------------------- class ToolboxTool(BaseTool): """ @@ -135,56 +150,98 @@ async def run_async( reset_token = None if self._auth_config and self._auth_config.type == CredentialType.USER_IDENTITY: - if not self._auth_config.client_id or not self._auth_config.client_secret: - raise ValueError("USER_IDENTITY requires client_id and client_secret") - - # Construct ADK AuthConfig - scopes = self._auth_config.scopes or [ - "https://www.googleapis.com/auth/cloud-platform" - ] - scope_dict = {s: "" for s in scopes} - - auth_config_adk = AuthConfig( - auth_scheme=OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://accounts.google.com/o/oauth2/auth", - tokenUrl="https://oauth2.googleapis.com/token", - scopes=scope_dict, + requires_auth = ( + len(self._core_tool._required_authn_params) > 0 + or len(self._core_tool._required_authz_tokens) > 0 + ) + + if requires_auth: + if not self._auth_config.client_id or not self._auth_config.client_secret: + raise ValueError("USER_IDENTITY requires client_id and client_secret") + + # Construct ADK AuthConfig + scopes = self._auth_config.scopes or ["openid", "profile", "email"] + scope_dict = {s: "" for s in scopes} + + auth_config_adk = AuthConfig( + auth_scheme=OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://accounts.google.com/o/oauth2/auth", + tokenUrl="https://oauth2.googleapis.com/token", + scopes=scope_dict, + ) ) - ) - ), - raw_auth_credential=AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id=self._auth_config.client_id, - client_secret=self._auth_config.client_secret, ), - ), - ) + raw_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=self._auth_config.client_id, + client_secret=self._auth_config.client_secret, + ), + ), + ) - # Check if we already have credentials from a previous exchange - try: - # get_auth_response returns AuthCredential if found - creds = tool_context.get_auth_response(auth_config_adk) - if creds and creds.oauth2 and creds.oauth2.access_token: - reset_token = USER_TOKEN_CONTEXT_VAR.set(creds.oauth2.access_token) - else: - # Request credentials and pause execution + # Check if we already have credentials from a previous exchange + try: + # Try to load credential from credential service first (persists across sessions) + creds = None + try: + if tool_context._invocation_context.credential_service: + creds = await tool_context._invocation_context.credential_service.load_credential( + auth_config=auth_config_adk, + callback_context=tool_context + ) + except ValueError: + # Credential service might not be initialized + pass + + if not creds: + # Fallback to session state (get_auth_response returns AuthCredential if found) + creds = tool_context.get_auth_response(auth_config_adk) + + if creds and creds.oauth2 and creds.oauth2.access_token: + reset_token = USER_TOKEN_CONTEXT_VAR.set(creds.oauth2.access_token) + + # Bind the token to the underlying core_tool so it constructs headers properly + needed_services = set() + for requested_service in (list(self._core_tool._required_authn_params.values()) + list(self._core_tool._required_authz_tokens)): + if isinstance(requested_service, list): + needed_services.update(requested_service) + else: + needed_services.add(requested_service) + + for s in needed_services: + # Only add if not already registered (prevents ValueError on duplicate params or subsequent runs) + if not hasattr(self._core_tool, '_auth_token_getters') or s not in self._core_tool._auth_token_getters: + # TODO(id_token): Uncomment this line and remove the `getattr` fallback below once PR https://github.com/google/adk-python/pull/4402 is merged. + # self._core_tool = self._core_tool.add_auth_token_getter(s, lambda t=creds.oauth2.id_token or creds.oauth2.access_token: t) + self._core_tool = self._core_tool.add_auth_token_getter(s, lambda t=getattr(creds.oauth2, "id_token", creds.oauth2.access_token): t) + # Once we use it from get_auth_response, save it to the auth service for future use + try: + if tool_context._invocation_context.credential_service: + auth_config_adk.exchanged_auth_credential = creds + await tool_context._invocation_context.credential_service.save_credential( + auth_config=auth_config_adk, + callback_context=tool_context + ) + except Exception as e: + logging.debug(f"Failed to save credential to service: {e}") + else: + tool_context.request_credential(auth_config_adk) + return {"error": f"OAuth2 Credentials required for {self.name}. A consent link has been generated for the user. Do NOT attempt to run this tool again until the user confirms they have logged in."} + except Exception as e: + if "credential" in str(e).lower() or isinstance(e, ValueError): + raise e + + logging.warning( + f"Unexpected error in get_auth_response during User Identity (OAuth2) retrieval: {e}. " + "Falling back to request_credential.", + exc_info=True + ) + # Fallback to request logic tool_context.request_credential(auth_config_adk) - return None - except Exception as e: - if "credential" in str(e).lower() or isinstance(e, ValueError): - raise e - - logging.warning( - f"Unexpected error in get_auth_response during User Identity (OAuth2) retrieval: {e}. " - "Falling back to request_credential.", - exc_info=True - ) - # Fallback to request logic - tool_context.request_credential(auth_config_adk) - return None + return {"error": f"OAuth2 Credentials required for {self.name}. A consent link has been generated for the user. Do NOT attempt to run this tool again until the user confirms they have logged in."} result: Optional[Any] = None error: Optional[Exception] = None diff --git a/packages/toolbox-adk/tests/integration/test_integration.py b/packages/toolbox-adk/tests/integration/test_integration.py index c7088e08..fa4fdc1a 100644 --- a/packages/toolbox-adk/tests/integration/test_integration.py +++ b/packages/toolbox-adk/tests/integration/test_integration.py @@ -188,6 +188,10 @@ async def test_3lo_flow_simulation(self): assert declaration.name == "get-n-rows" assert "num_rows" in declaration.parameters.properties + # Force the proxy tool to require auth to properly simulate the 3LO flow branches + tool._core_tool._ToolboxTool__required_authn_params = {} + tool._core_tool._ToolboxTool__required_authz_tokens = ["mock_service"] + # Create a mock context that behaves like ADK's ReadonlyContext mock_ctx_first = MagicMock() # Simulate "No Auth Response Found" @@ -201,32 +205,44 @@ async def test_3lo_flow_simulation(self): result_first = await tool.run_async({"num_rows": "1"}, mock_ctx_first) # The wrapper should catch the missing creds and request them. - assert result_first is None, "Tool should return None sig for auth requirement" + assert isinstance(result_first, dict) and "error" in result_first, "Tool should return error sig for auth requirement" mock_ctx_first.request_credential.assert_called_once() - + # Inspect the requested config auth_config = mock_ctx_first.request_credential.call_args[0][0] assert auth_config.raw_auth_credential.oauth2.client_id == "test-client-id" + # Verify the default fallback scopes were assigned correctly to avoid upstream crashes + assert auth_config.auth_scheme.flows.authorizationCode.scopes == {"openid": "", "profile": "", "email": ""} mock_ctx_second = MagicMock() # Simulate "Auth Response Found" mock_creds = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(access_token="fake-access-token"), + oauth2=OAuth2Auth(access_token="fake-access-token", id_token="fake-id-token"), ) mock_ctx_second.get_auth_response.return_value = mock_creds + # Setup the credential service mock to verify credential persistence across sessions + mock_cred_service = AsyncMock() + mock_cred_service.load_credential.return_value = None + mock_ctx_second._invocation_context = MagicMock() + mock_ctx_second._invocation_context.credential_service = mock_cred_service + print("Running tool second time (expecting success or server error)...") try: result_second = await tool.run_async({"num_rows": "1"}, mock_ctx_second) assert result_second is not None + # Verify that the tool saved the credentials to the storage service backends locally + mock_cred_service.save_credential.assert_called_once() except Exception as e: mock_ctx_second.request_credential.assert_not_called() err_msg = str(e).lower() assert any(x in err_msg for x in ["401", "403", "unauthorized", "forbidden"]), f"Caught UNEXPECTED exception: {type(e).__name__}: {e}" print(f"Caught expected server exception with fake token: {e}") + # Verify that the tool AT LEAST triggered save_credential before failing via core_tool inner HTTP req + mock_cred_service.save_credential.assert_called_once() finally: await toolset.close() diff --git a/packages/toolbox-adk/tests/unit/test_tool.py b/packages/toolbox-adk/tests/unit/test_tool.py index 6c7c7a14..3b3ce19a 100644 --- a/packages/toolbox-adk/tests/unit/test_tool.py +++ b/packages/toolbox-adk/tests/unit/test_tool.py @@ -155,6 +155,8 @@ async def test_3lo_missing_client_secret(self): core_tool = AsyncMock() core_tool.__name__ = "mock_tool" core_tool.__doc__ = "mock doc" + core_tool._required_authn_params = {"mock_param": "mock_service"} + core_tool._required_authz_tokens = [] auth_config = CredentialConfig(type=CredentialType.USER_IDENTITY) # Missing client_id/secret @@ -174,6 +176,8 @@ async def test_3lo_request_credential_when_missing(self): core_tool.__doc__ = "mock" core_tool.__name__ = "mock_tool" core_tool.__doc__ = "mock doc" + core_tool._required_authn_params = {"mock_param": "mock_service"} + core_tool._required_authz_tokens = [] auth_config = CredentialConfig( type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec" @@ -187,8 +191,8 @@ async def test_3lo_request_credential_when_missing(self): result = await tool.run_async({}, ctx) - # Verify result is None (signal pause) - assert result is None + # Verify result is error/stop + assert isinstance(result, dict) and "error" in result # Verify request_credential was called ctx.request_credential.assert_called_once() # Verify core tool was NOT called @@ -198,10 +202,12 @@ async def test_3lo_request_credential_when_missing(self): async def test_3lo_uses_existing_credential(self): # Test that if creds exist, they are used and injected core_tool = AsyncMock(return_value="success") - core_tool.__name__ = "mock" - core_tool.__doc__ = "mock" core_tool.__name__ = "mock_tool" core_tool.__doc__ = "mock doc" + # Setup overlapping needed services to test deduplication + core_tool._required_authn_params = {"mock_param": "mock_service", "another_param": "mock_service"} + core_tool._required_authz_tokens = ["mock_service"] + core_tool.add_auth_token_getter = MagicMock(return_value=core_tool) auth_config = CredentialConfig( type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec" @@ -210,11 +216,19 @@ async def test_3lo_uses_existing_credential(self): tool = ToolboxTool(core_tool, auth_config=auth_config) ctx = MagicMock() - # Mock get_auth_response returning valid creds + # Mock get_auth_response returning valid creds with both access & id tokens mock_creds = MagicMock() - mock_creds.oauth2.access_token = "valid_token" + mock_creds.oauth2.access_token = "valid_access_token" + mock_creds.oauth2.id_token = "valid_id_token" ctx.get_auth_response.return_value = mock_creds + # Set up invocation context and credential service mock to verify saving and avoid await errors + mock_cred_service = MagicMock() + mock_cred_service.load_credential = AsyncMock(return_value=None) + mock_cred_service.save_credential = AsyncMock(return_value=None) + ctx._invocation_context = MagicMock() + ctx._invocation_context.credential_service = mock_cred_service + result = await tool.run_async({}, ctx) # Verify result is success @@ -223,15 +237,31 @@ async def test_3lo_uses_existing_credential(self): ctx.request_credential.assert_not_called() # Verify core tool WAS called core_tool.assert_called_once() + + # Verify deduplication: add_auth_token_getter should only be called ONCE for "mock_service" + core_tool.add_auth_token_getter.assert_called_once() + call_args_getter = core_tool.add_auth_token_getter.call_args[0] + assert call_args_getter[0] == "mock_service" + # Evaluate the getter lambda to ensure it prefers id_token + token_getter_lambda = call_args_getter[1] + assert token_getter_lambda() == "valid_id_token" + + # Verify save_credential was called with the exchanged credential + mock_cred_service.save_credential.assert_called_once() + call_args = mock_cred_service.save_credential.call_args[1] + assert call_args["auth_config"].exchanged_auth_credential == mock_creds + + # Verify safe scope fallback to ["openid", "profile", "email"] when scopes is None + assert call_args["auth_config"].auth_scheme.flows.authorizationCode.scopes == {"openid": "", "profile": "", "email": ""} @pytest.mark.asyncio async def test_3lo_exception_reraise(self): # Test that specific credential errors are re-raised core_tool = AsyncMock() - core_tool.__name__ = "mock" - core_tool.__doc__ = "mock" core_tool.__name__ = "mock_tool" core_tool.__doc__ = "mock doc" + core_tool._required_authn_params = {"mock_param": "mock_service"} + core_tool._required_authz_tokens = [] auth_config = CredentialConfig( type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec" @@ -239,6 +269,11 @@ async def test_3lo_exception_reraise(self): tool = ToolboxTool(core_tool, auth_config=auth_config) ctx = MagicMock() + mock_cred_service = MagicMock() + mock_cred_service.load_credential = AsyncMock(return_value=None) + ctx._invocation_context = MagicMock() + ctx._invocation_context.credential_service = mock_cred_service + # Mock get_auth_response raising ValueError ctx.get_auth_response.side_effect = ValueError("Invalid Credential") @@ -253,6 +288,8 @@ async def test_3lo_exception_fallback(self): core_tool.__doc__ = "mock" core_tool.__name__ = "mock_tool" core_tool.__doc__ = "mock doc" + core_tool._required_authn_params = {"mock_param": "mock_service"} + core_tool._required_authz_tokens = [] auth_config = CredentialConfig( type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec" @@ -265,8 +302,8 @@ async def test_3lo_exception_fallback(self): result = await tool.run_async({}, ctx) - # Should catch RuntimeError, call request_credential, and return None - assert result is None + # Should catch RuntimeError, call request_credential, and return error map + assert isinstance(result, dict) and "error" in result ctx.request_credential.assert_called_once() def test_param_type_to_schema_type(self):