Skip to content
149 changes: 103 additions & 46 deletions packages/toolbox-adk/src/toolbox_adk/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@
from .client import USER_TOKEN_CONTEXT_VAR
from .credentials import CredentialConfig, CredentialType

# --- Monkey Patch ADK OAuth2 Exchange to Retain ID Tokens ---
# Google's ID Token is required by MCP Toolbox but ADK's `update_credential_with_tokens` natively drops the `id_token`.
# TODO(id_token): Remove this monkey patch once the PR https://github.com/google/adk-python/pull/4402 is merged.
import google.adk.auth.oauth2_credential_util as oauth2_credential_util
import google.adk.auth.exchanger.oauth2_credential_exchanger as oauth2_credential_exchanger
_orig_update_cred = oauth2_credential_util.update_credential_with_tokens

def _patched_update_credential_with_tokens(auth_credential, tokens):
_orig_update_cred(auth_credential, tokens)
if tokens and "id_token" in tokens and auth_credential and auth_credential.oauth2:
setattr(auth_credential.oauth2, "id_token", tokens["id_token"])

oauth2_credential_util.update_credential_with_tokens = _patched_update_credential_with_tokens
oauth2_credential_exchanger.update_credential_with_tokens = _patched_update_credential_with_tokens
# -------------------------------------------------------------

class ToolboxTool(BaseTool):
"""
Expand Down Expand Up @@ -135,56 +150,98 @@ async def run_async(
reset_token = None

if self._auth_config and self._auth_config.type == CredentialType.USER_IDENTITY:
if not self._auth_config.client_id or not self._auth_config.client_secret:
raise ValueError("USER_IDENTITY requires client_id and client_secret")

# Construct ADK AuthConfig
scopes = self._auth_config.scopes or [
"https://www.googleapis.com/auth/cloud-platform"
]
scope_dict = {s: "" for s in scopes}

auth_config_adk = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
tokenUrl="https://oauth2.googleapis.com/token",
scopes=scope_dict,
requires_auth = (
len(self._core_tool._required_authn_params) > 0
or len(self._core_tool._required_authz_tokens) > 0
)

if requires_auth:
if not self._auth_config.client_id or not self._auth_config.client_secret:
raise ValueError("USER_IDENTITY requires client_id and client_secret")

# Construct ADK AuthConfig
scopes = self._auth_config.scopes or ["openid", "profile", "email"]
scope_dict = {s: "" for s in scopes}

auth_config_adk = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
tokenUrl="https://oauth2.googleapis.com/token",
scopes=scope_dict,
)
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id=self._auth_config.client_id,
client_secret=self._auth_config.client_secret,
),
),
)
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id=self._auth_config.client_id,
client_secret=self._auth_config.client_secret,
),
),
)

# Check if we already have credentials from a previous exchange
try:
# get_auth_response returns AuthCredential if found
creds = tool_context.get_auth_response(auth_config_adk)
if creds and creds.oauth2 and creds.oauth2.access_token:
reset_token = USER_TOKEN_CONTEXT_VAR.set(creds.oauth2.access_token)
else:
# Request credentials and pause execution
# Check if we already have credentials from a previous exchange
try:
# Try to load credential from credential service first (persists across sessions)
creds = None
try:
if tool_context._invocation_context.credential_service:
creds = await tool_context._invocation_context.credential_service.load_credential(
auth_config=auth_config_adk,
callback_context=tool_context
)
except ValueError:
# Credential service might not be initialized
pass

if not creds:
# Fallback to session state (get_auth_response returns AuthCredential if found)
creds = tool_context.get_auth_response(auth_config_adk)

if creds and creds.oauth2 and creds.oauth2.access_token:
reset_token = USER_TOKEN_CONTEXT_VAR.set(creds.oauth2.access_token)

# Bind the token to the underlying core_tool so it constructs headers properly
needed_services = set()
for requested_service in (list(self._core_tool._required_authn_params.values()) + list(self._core_tool._required_authz_tokens)):
if isinstance(requested_service, list):
needed_services.update(requested_service)
else:
needed_services.add(requested_service)

for s in needed_services:
# Only add if not already registered (prevents ValueError on duplicate params or subsequent runs)
if not hasattr(self._core_tool, '_auth_token_getters') or s not in self._core_tool._auth_token_getters:
# TODO(id_token): Uncomment this line and remove the `getattr` fallback below once PR https://github.com/google/adk-python/pull/4402 is merged.
# self._core_tool = self._core_tool.add_auth_token_getter(s, lambda t=creds.oauth2.id_token or creds.oauth2.access_token: t)
self._core_tool = self._core_tool.add_auth_token_getter(s, lambda t=getattr(creds.oauth2, "id_token", creds.oauth2.access_token): t)
# Once we use it from get_auth_response, save it to the auth service for future use
try:
if tool_context._invocation_context.credential_service:
auth_config_adk.exchanged_auth_credential = creds
await tool_context._invocation_context.credential_service.save_credential(
auth_config=auth_config_adk,
callback_context=tool_context
)
except Exception as e:
logging.debug(f"Failed to save credential to service: {e}")
else:
tool_context.request_credential(auth_config_adk)
return {"error": f"OAuth2 Credentials required for {self.name}. A consent link has been generated for the user. Do NOT attempt to run this tool again until the user confirms they have logged in."}
except Exception as e:
if "credential" in str(e).lower() or isinstance(e, ValueError):
raise e

logging.warning(
f"Unexpected error in get_auth_response during User Identity (OAuth2) retrieval: {e}. "
"Falling back to request_credential.",
exc_info=True
)
# Fallback to request logic
tool_context.request_credential(auth_config_adk)
return None
except Exception as e:
if "credential" in str(e).lower() or isinstance(e, ValueError):
raise e

logging.warning(
f"Unexpected error in get_auth_response during User Identity (OAuth2) retrieval: {e}. "
"Falling back to request_credential.",
exc_info=True
)
# Fallback to request logic
tool_context.request_credential(auth_config_adk)
return None
return {"error": f"OAuth2 Credentials required for {self.name}. A consent link has been generated for the user. Do NOT attempt to run this tool again until the user confirms they have logged in."}

result: Optional[Any] = None
error: Optional[Exception] = None
Expand Down
22 changes: 19 additions & 3 deletions packages/toolbox-adk/tests/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ async def test_3lo_flow_simulation(self):
assert declaration.name == "get-n-rows"
assert "num_rows" in declaration.parameters.properties

# Force the proxy tool to require auth to properly simulate the 3LO flow branches
tool._core_tool._ToolboxTool__required_authn_params = {}
tool._core_tool._ToolboxTool__required_authz_tokens = ["mock_service"]

# Create a mock context that behaves like ADK's ReadonlyContext
mock_ctx_first = MagicMock()
# Simulate "No Auth Response Found"
Expand All @@ -201,32 +205,44 @@ async def test_3lo_flow_simulation(self):
result_first = await tool.run_async({"num_rows": "1"}, mock_ctx_first)

# The wrapper should catch the missing creds and request them.
assert result_first is None, "Tool should return None sig for auth requirement"
assert isinstance(result_first, dict) and "error" in result_first, "Tool should return error sig for auth requirement"
mock_ctx_first.request_credential.assert_called_once()

# Inspect the requested config
auth_config = mock_ctx_first.request_credential.call_args[0][0]
assert auth_config.raw_auth_credential.oauth2.client_id == "test-client-id"
# Verify the default fallback scopes were assigned correctly to avoid upstream crashes
assert auth_config.auth_scheme.flows.authorizationCode.scopes == {"openid": "", "profile": "", "email": ""}

mock_ctx_second = MagicMock()

# Simulate "Auth Response Found"
mock_creds = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="fake-access-token"),
oauth2=OAuth2Auth(access_token="fake-access-token", id_token="fake-id-token"),
)
mock_ctx_second.get_auth_response.return_value = mock_creds

# Setup the credential service mock to verify credential persistence across sessions
mock_cred_service = AsyncMock()
mock_cred_service.load_credential.return_value = None
mock_ctx_second._invocation_context = MagicMock()
mock_ctx_second._invocation_context.credential_service = mock_cred_service

print("Running tool second time (expecting success or server error)...")

try:
result_second = await tool.run_async({"num_rows": "1"}, mock_ctx_second)
assert result_second is not None
# Verify that the tool saved the credentials to the storage service backends locally
mock_cred_service.save_credential.assert_called_once()
except Exception as e:
mock_ctx_second.request_credential.assert_not_called()
err_msg = str(e).lower()
assert any(x in err_msg for x in ["401", "403", "unauthorized", "forbidden"]), f"Caught UNEXPECTED exception: {type(e).__name__}: {e}"
print(f"Caught expected server exception with fake token: {e}")
# Verify that the tool AT LEAST triggered save_credential before failing via core_tool inner HTTP req
mock_cred_service.save_credential.assert_called_once()

finally:
await toolset.close()
Expand Down
57 changes: 47 additions & 10 deletions packages/toolbox-adk/tests/unit/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ async def test_3lo_missing_client_secret(self):
core_tool = AsyncMock()
core_tool.__name__ = "mock_tool"
core_tool.__doc__ = "mock doc"
core_tool._required_authn_params = {"mock_param": "mock_service"}
core_tool._required_authz_tokens = []
auth_config = CredentialConfig(type=CredentialType.USER_IDENTITY)
# Missing client_id/secret

Expand All @@ -174,6 +176,8 @@ async def test_3lo_request_credential_when_missing(self):
core_tool.__doc__ = "mock"
core_tool.__name__ = "mock_tool"
core_tool.__doc__ = "mock doc"
core_tool._required_authn_params = {"mock_param": "mock_service"}
core_tool._required_authz_tokens = []

auth_config = CredentialConfig(
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
Expand All @@ -187,8 +191,8 @@ async def test_3lo_request_credential_when_missing(self):

result = await tool.run_async({}, ctx)

# Verify result is None (signal pause)
assert result is None
# Verify result is error/stop
assert isinstance(result, dict) and "error" in result
# Verify request_credential was called
ctx.request_credential.assert_called_once()
# Verify core tool was NOT called
Expand All @@ -198,10 +202,12 @@ async def test_3lo_request_credential_when_missing(self):
async def test_3lo_uses_existing_credential(self):
# Test that if creds exist, they are used and injected
core_tool = AsyncMock(return_value="success")
core_tool.__name__ = "mock"
core_tool.__doc__ = "mock"
core_tool.__name__ = "mock_tool"
core_tool.__doc__ = "mock doc"
# Setup overlapping needed services to test deduplication
core_tool._required_authn_params = {"mock_param": "mock_service", "another_param": "mock_service"}
core_tool._required_authz_tokens = ["mock_service"]
core_tool.add_auth_token_getter = MagicMock(return_value=core_tool)

auth_config = CredentialConfig(
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
Expand All @@ -210,11 +216,19 @@ async def test_3lo_uses_existing_credential(self):
tool = ToolboxTool(core_tool, auth_config=auth_config)

ctx = MagicMock()
# Mock get_auth_response returning valid creds
# Mock get_auth_response returning valid creds with both access & id tokens
mock_creds = MagicMock()
mock_creds.oauth2.access_token = "valid_token"
mock_creds.oauth2.access_token = "valid_access_token"
mock_creds.oauth2.id_token = "valid_id_token"
ctx.get_auth_response.return_value = mock_creds

# Set up invocation context and credential service mock to verify saving and avoid await errors
mock_cred_service = MagicMock()
mock_cred_service.load_credential = AsyncMock(return_value=None)
mock_cred_service.save_credential = AsyncMock(return_value=None)
ctx._invocation_context = MagicMock()
ctx._invocation_context.credential_service = mock_cred_service

result = await tool.run_async({}, ctx)

# Verify result is success
Expand All @@ -223,22 +237,43 @@ async def test_3lo_uses_existing_credential(self):
ctx.request_credential.assert_not_called()
# Verify core tool WAS called
core_tool.assert_called_once()

# Verify deduplication: add_auth_token_getter should only be called ONCE for "mock_service"
core_tool.add_auth_token_getter.assert_called_once()
call_args_getter = core_tool.add_auth_token_getter.call_args[0]
assert call_args_getter[0] == "mock_service"
# Evaluate the getter lambda to ensure it prefers id_token
token_getter_lambda = call_args_getter[1]
assert token_getter_lambda() == "valid_id_token"

# Verify save_credential was called with the exchanged credential
mock_cred_service.save_credential.assert_called_once()
call_args = mock_cred_service.save_credential.call_args[1]
assert call_args["auth_config"].exchanged_auth_credential == mock_creds

# Verify safe scope fallback to ["openid", "profile", "email"] when scopes is None
assert call_args["auth_config"].auth_scheme.flows.authorizationCode.scopes == {"openid": "", "profile": "", "email": ""}

@pytest.mark.asyncio
async def test_3lo_exception_reraise(self):
# Test that specific credential errors are re-raised
core_tool = AsyncMock()
core_tool.__name__ = "mock"
core_tool.__doc__ = "mock"
core_tool.__name__ = "mock_tool"
core_tool.__doc__ = "mock doc"
core_tool._required_authn_params = {"mock_param": "mock_service"}
core_tool._required_authz_tokens = []

auth_config = CredentialConfig(
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
)
tool = ToolboxTool(core_tool, auth_config=auth_config)
ctx = MagicMock()

mock_cred_service = MagicMock()
mock_cred_service.load_credential = AsyncMock(return_value=None)
ctx._invocation_context = MagicMock()
ctx._invocation_context.credential_service = mock_cred_service

# Mock get_auth_response raising ValueError
ctx.get_auth_response.side_effect = ValueError("Invalid Credential")

Expand All @@ -253,6 +288,8 @@ async def test_3lo_exception_fallback(self):
core_tool.__doc__ = "mock"
core_tool.__name__ = "mock_tool"
core_tool.__doc__ = "mock doc"
core_tool._required_authn_params = {"mock_param": "mock_service"}
core_tool._required_authz_tokens = []

auth_config = CredentialConfig(
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
Expand All @@ -265,8 +302,8 @@ async def test_3lo_exception_fallback(self):

result = await tool.run_async({}, ctx)

# Should catch RuntimeError, call request_credential, and return None
assert result is None
# Should catch RuntimeError, call request_credential, and return error map
assert isinstance(result, dict) and "error" in result
ctx.request_credential.assert_called_once()

def test_param_type_to_schema_type(self):
Expand Down
Loading