-
Notifications
You must be signed in to change notification settings - Fork 22
feat: add configurable token cache for customer flow #133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,8 @@ | |
| IntegrationDependency, | ||
| MCPTool, | ||
| ) | ||
| from sap_cloud_sdk.agentgateway._token_cache import _TokenCache, compute_expires_at | ||
| from sap_cloud_sdk.agentgateway.config import ClientConfig | ||
| from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -211,19 +213,24 @@ def _request_token_mtls( | |
| credentials: CustomerCredentials, | ||
| grant_type: str, | ||
| timeout: float, | ||
| config: ClientConfig, | ||
| app_tid: str | None = None, | ||
| extra_data: dict | None = None, | ||
| ) -> str: | ||
| ) -> tuple[str, float]: | ||
| """Make mTLS token request to IAS. | ||
|
|
||
| Args: | ||
| credentials: Customer credentials with certificate and private key. | ||
| grant_type: OAuth2 grant type. | ||
| timeout: HTTP timeout in seconds. | ||
| config: Client configuration (used to compute cache expiry). | ||
| app_tid: BTP Application Tenant ID of subscriber (optional). | ||
| extra_data: Additional form data for the token request. | ||
|
|
||
| Returns: | ||
| Access token string. | ||
| Tuple of (access_token, expires_at) where expires_at is a | ||
| time.monotonic() value indicating when the cached token should | ||
| be refreshed (already includes the configured buffer). | ||
|
|
||
| Raises: | ||
| AgentGatewaySDKError: If token request fails. | ||
|
|
@@ -282,8 +289,10 @@ def _request_token_mtls( | |
| f"Token response missing 'access_token'. Keys: {list(token_data.keys())}" | ||
| ) | ||
|
|
||
| expires_at = compute_expires_at(token_data, config) | ||
|
|
||
| logger.debug("Token acquired successfully (length: %d)", len(access_token)) | ||
| return access_token | ||
| return access_token, expires_at | ||
|
|
||
| except httpx.RequestError as e: | ||
| raise AgentGatewaySDKError(f"Token request failed: {e}") | ||
|
|
@@ -292,61 +301,87 @@ def _request_token_mtls( | |
| def get_system_token_mtls( | ||
| credentials: CustomerCredentials, | ||
| timeout: float, | ||
| config: ClientConfig, | ||
| cache: _TokenCache, | ||
| app_tid: str | None = None, | ||
| ) -> str: | ||
| """Get system-scoped token using mTLS client credentials flow. | ||
|
|
||
| Used for tool discovery where user identity is not needed. | ||
| Used for tool discovery where user identity is not needed. Returns | ||
| a cached token if still valid; otherwise acquires a fresh one. | ||
|
|
||
| Args: | ||
| credentials: Customer credentials. | ||
| timeout: HTTP timeout in seconds. | ||
| config: Client configuration. | ||
| cache: Token cache to consult and update. | ||
| app_tid: BTP Application Tenant ID of subscriber (optional). | ||
|
|
||
| Returns: | ||
| System-scoped access token. | ||
| """ | ||
| cached = cache.get_system_token(app_tid) | ||
| if cached: | ||
| logger.debug("Using cached system token (app_tid=%s)", app_tid) | ||
| return cached | ||
|
|
||
| logger.info("Acquiring system token via mTLS client credentials") | ||
| return _request_token_mtls( | ||
| token, expires_at = _request_token_mtls( | ||
| credentials, | ||
| grant_type=_GRANT_TYPE_CLIENT_CREDENTIALS, | ||
| timeout=timeout, | ||
| config=config, | ||
| app_tid=app_tid, | ||
| extra_data={"response_type": "token"}, | ||
| ) | ||
| cache.set_system_token(token, expires_at, app_tid) | ||
| return token | ||
|
|
||
|
|
||
| def exchange_user_token( | ||
| credentials: CustomerCredentials, | ||
| user_token: str, | ||
| timeout: float, | ||
| config: ClientConfig, | ||
| cache: _TokenCache, | ||
| app_tid: str | None = None, | ||
| ) -> str: | ||
| """Exchange user token for AGW-scoped token using jwt-bearer grant. | ||
|
|
||
| Used for tool invocation where user identity must be preserved | ||
| for principal propagation. | ||
| for principal propagation. Returns a cached exchanged token if | ||
| still valid; otherwise acquires a fresh one. | ||
|
|
||
| Args: | ||
| credentials: Customer credentials. | ||
| user_token: User's JWT token to exchange. | ||
| timeout: HTTP timeout in seconds. | ||
| config: Client configuration. | ||
| cache: Token cache to consult and update. | ||
| app_tid: BTP Application Tenant ID of subscriber (optional). | ||
|
|
||
| Returns: | ||
| AGW-scoped access token with user identity. | ||
| """ | ||
| cached = cache.get_user_token(user_token, app_tid) | ||
| if cached: | ||
| logger.debug("Using cached user token (app_tid=%s)", app_tid) | ||
| return cached | ||
|
|
||
| logger.info("Exchanging user token for AGW-scoped token via jwt-bearer grant") | ||
| return _request_token_mtls( | ||
| token, expires_at = _request_token_mtls( | ||
| credentials, | ||
| grant_type=_GRANT_TYPE_JWT_BEARER, | ||
| timeout=timeout, | ||
| config=config, | ||
| app_tid=app_tid, | ||
| extra_data={ | ||
| "assertion": user_token, | ||
| "token_format": "jwt", | ||
| }, | ||
| ) | ||
| cache.set_user_token(user_token, token, expires_at, app_tid) | ||
| return token | ||
|
|
||
|
|
||
| def _build_mcp_url(gateway_url: str, ord_id: str, gt_id: str) -> str: | ||
|
|
@@ -433,6 +468,8 @@ async def _list_server_tools( | |
| async def get_mcp_tools_customer( | ||
| credentials: CustomerCredentials, | ||
| timeout: float, | ||
| config: ClientConfig, | ||
| cache: _TokenCache, | ||
| app_tid: str | None = None, | ||
| ) -> list[MCPTool]: | ||
| """List all MCP tools from servers defined in credentials. | ||
|
|
@@ -442,6 +479,9 @@ async def get_mcp_tools_customer( | |
|
|
||
| Args: | ||
| credentials: Customer credentials with integrationDependencies. | ||
| timeout: HTTP timeout in seconds. | ||
| config: Client configuration. | ||
| cache: Token cache shared across calls. | ||
| app_tid: BTP Application Tenant ID of subscriber (optional). | ||
|
|
||
| Returns: | ||
|
|
@@ -462,7 +502,7 @@ async def get_mcp_tools_customer( | |
| # Get system token for discovery | ||
| loop = asyncio.get_running_loop() | ||
| system_token = await loop.run_in_executor( | ||
| None, get_system_token_mtls, credentials, timeout, app_tid | ||
| None, get_system_token_mtls, credentials, timeout, config, cache, app_tid | ||
| ) | ||
|
|
||
| tools: list[MCPTool] = [] | ||
|
|
@@ -480,7 +520,42 @@ async def get_mcp_tools_customer( | |
| server_tools = await _list_server_tools(url, system_token, dep, timeout) | ||
| tools.extend(server_tools) | ||
| logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id) | ||
| except Exception: | ||
| except Exception as exc: | ||
| unwrapped = _unwrap_exception_group(exc) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do u feel the exception block is unusually big? can we put this stuff inside a method? |
||
| if _is_unauthorized(unwrapped): | ||
| logger.info( | ||
| "401 from %s — invalidating cached system token and retrying", | ||
| dep.ord_id, | ||
| ) | ||
| cache.invalidate_system_token(app_tid) | ||
| try: | ||
| fresh_token = await loop.run_in_executor( | ||
| None, | ||
| get_system_token_mtls, | ||
| credentials, | ||
| timeout, | ||
| config, | ||
| cache, | ||
| app_tid, | ||
| ) | ||
| server_tools = await _list_server_tools( | ||
| url, fresh_token, dep, timeout | ||
| ) | ||
| tools.extend(server_tools) | ||
| # Replace stale token for remaining iterations | ||
| system_token = fresh_token | ||
| logger.debug( | ||
| "Loaded %d tool(s) from %s after retry", | ||
| len(server_tools), | ||
| dep.ord_id, | ||
| ) | ||
| continue | ||
| except Exception: | ||
| logger.exception( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i feel we should not swallow the exception here, if there is a failure we should fail-fast, what's ur take? |
||
| "Failed to load tools from %s after retry — skipping", | ||
| dep.ord_id, | ||
| ) | ||
| continue | ||
| logger.exception("Failed to load tools from %s — skipping", dep.ord_id) | ||
|
|
||
| logger.info( | ||
|
|
@@ -494,6 +569,8 @@ async def call_mcp_tool_customer( | |
| tool: MCPTool, | ||
| user_token: str | None, | ||
| timeout: float, | ||
| config: ClientConfig, | ||
| cache: _TokenCache, | ||
| app_tid: str | None = None, | ||
| **kwargs, | ||
| ) -> str: | ||
|
|
@@ -502,11 +579,16 @@ async def call_mcp_tool_customer( | |
| If user_token is provided, exchanges it for an AGW-scoped token to preserve | ||
| user identity for principal propagation. Otherwise, falls back to system token. | ||
|
|
||
| On a 401 from the MCP server, drops the cached token and retries once. | ||
|
|
||
| Args: | ||
| credentials: Customer credentials. | ||
| tool: MCPTool to invoke. | ||
| user_token: User's JWT token for principal propagation (optional). | ||
| If None, system token is used instead (no principal propagation). | ||
| timeout: HTTP timeout in seconds. | ||
| config: Client configuration. | ||
| cache: Token cache shared across calls. | ||
| app_tid: BTP Application Tenant ID of subscriber (optional). | ||
| **kwargs: Tool input parameters. | ||
|
|
||
|
|
@@ -517,26 +599,74 @@ async def call_mcp_tool_customer( | |
|
|
||
| loop = asyncio.get_running_loop() | ||
|
|
||
| if user_token: | ||
| # Exchange user token for AGW-scoped token (with principal propagation) | ||
| agw_token = await loop.run_in_executor( | ||
| None, exchange_user_token, credentials, user_token, timeout, app_tid | ||
| ) | ||
| else: | ||
| async def _acquire_token() -> str: | ||
| if user_token: | ||
| return await loop.run_in_executor( | ||
| None, | ||
| exchange_user_token, | ||
| credentials, | ||
| user_token, | ||
| timeout, | ||
| config, | ||
| cache, | ||
| app_tid, | ||
| ) | ||
| # TODO: IBD workaround - use system token when user_token is not available. | ||
| # This bypasses principal propagation. Remove this fallback once IBD | ||
| # supports proper user token flow. | ||
| logger.warning( | ||
| "No user_token provided - using system token for tool invocation. " | ||
| "Principal propagation will NOT work." | ||
| ) | ||
| agw_token = await loop.run_in_executor( | ||
| None, get_system_token_mtls, credentials, timeout, app_tid | ||
| return await loop.run_in_executor( | ||
| None, get_system_token_mtls, credentials, timeout, config, cache, app_tid | ||
| ) | ||
|
|
||
| def _invalidate_token() -> None: | ||
| if user_token: | ||
| cache.invalidate_user_token(user_token, app_tid) | ||
| else: | ||
| cache.invalidate_system_token(app_tid) | ||
|
|
||
| last_exc: Exception | None = None | ||
| for attempt in (1, 2): | ||
| agw_token = await _acquire_token() | ||
| try: | ||
| return await _invoke_tool(tool, agw_token, timeout, **kwargs) | ||
| except Exception as exc: | ||
| unwrapped = _unwrap_exception_group(exc) | ||
| if _is_unauthorized(unwrapped) and attempt == 1: | ||
| logger.info( | ||
| "401 from MCP server for tool '%s' — invalidating cached token and retrying", | ||
| tool.name, | ||
| ) | ||
| _invalidate_token() | ||
| last_exc = exc | ||
| continue | ||
| raise | ||
|
|
||
| # Defensive — should not be reachable; second attempt either returns or raises. | ||
| raise AgentGatewaySDKError( | ||
| f"Tool invocation for '{tool.name}' failed after 401 retry: {last_exc}" | ||
| ) | ||
|
|
||
|
|
||
| async def _invoke_tool( | ||
| tool: MCPTool, | ||
| auth_token: str, | ||
| timeout: float, | ||
| **kwargs, | ||
| ) -> str: | ||
| """Open an MCP session to `tool.url` and invoke `tool.name` with `kwargs`. | ||
|
|
||
| Returns the first content block's text, or empty string when content is | ||
| empty. Raises whatever the MCP transport / session raises (notably | ||
| `httpx.HTTPStatusError` on 401, which the caller uses to drive cache | ||
| invalidation and retry). | ||
| """ | ||
| async with httpx.AsyncClient( | ||
| headers={ | ||
| "Authorization": f"Bearer {agw_token}", | ||
| "Authorization": f"Bearer {auth_token}", | ||
| "x-correlation-id": str(uuid.uuid4()), | ||
| }, | ||
| timeout=timeout, | ||
|
|
@@ -556,3 +686,17 @@ async def call_mcp_tool_customer( | |
|
|
||
| first = result.content[0] | ||
| return str(getattr(first, "text", "")) | ||
|
|
||
|
|
||
| def _unwrap_exception_group(exc: BaseException) -> BaseException: | ||
| """Unwrap nested ExceptionGroups to find the underlying cause.""" | ||
| while isinstance(exc, BaseExceptionGroup) and exc.exceptions: | ||
| exc = exc.exceptions[0] | ||
| return exc | ||
|
|
||
|
|
||
| def _is_unauthorized(exc: BaseException) -> bool: | ||
| """Detect a 401 response from the MCP server (httpx-based).""" | ||
| if isinstance(exc, httpx.HTTPStatusError): | ||
| return exc.response is not None and exc.response.status_code == 401 | ||
| return False | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i am not sure if app_tid would be needed for fetching the tokens, is there any other way to do this?