diff --git a/.env_integration_tests.example b/.env_integration_tests.example index 1089a5e..2cdac6f 100644 --- a/.env_integration_tests.example +++ b/.env_integration_tests.example @@ -19,3 +19,7 @@ CLOUD_SDK_CFG_SDM_DEFAULT_UAA='{"url":"https://your-auth-url","clientid":"your-c CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_APPLICATION_URL=https://your-agent-memory-api-url-here CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA='{"url":"https://your-auth-url","clientid":"your-client-id","clientsecret":"your-client-secret"}' + +APPFND_CONHOS_LANDSCAPE=your-landscape-here +TENANT_SUBDOMAIN=your-tenant-subdomain-here +AGW_USER_TOKEN=your-user-jwt-here diff --git a/docs/INTEGRATION_TESTS.md b/docs/INTEGRATION_TESTS.md index a056627..37fac02 100644 --- a/docs/INTEGRATION_TESTS.md +++ b/docs/INTEGRATION_TESTS.md @@ -74,6 +74,29 @@ CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_APPLICATION_URL=https://your-agent-memor CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA='{"url":"https://your-auth-url","clientid":"your-client-id","clientsecret":"your-client-secret"}' ``` +### Agent Gateway Integration Tests + +Agent Gateway integration tests use the LoB agent flow via the Destination Service. Configure the following variables in `.env_integration_tests`: + +```bash +# Destination Service (required by the LoB agent flow) +CLOUD_SDK_CFG_DESTINATION_DEFAULT_CLIENTID=your-destination-client-id-here +CLOUD_SDK_CFG_DESTINATION_DEFAULT_CLIENTSECRET=your-destination-client-secret-here +CLOUD_SDK_CFG_DESTINATION_DEFAULT_URL=https://your-destination-auth-url-here +CLOUD_SDK_CFG_DESTINATION_DEFAULT_URI=https://your-destination-configuration-uri-here +CLOUD_SDK_CFG_DESTINATION_DEFAULT_IDENTITYZONE=your-identity-zone-here + +# Landscape suffix used to resolve the IAS destination name +APPFND_CONHOS_LANDSCAPE=your-landscape-here + +# Tenant subdomain for multi-tenant lookup +TENANT_SUBDOMAIN=your-tenant-subdomain-here + +# User JWT for token exchange scenarios (get_user_auth) +# If not set, user auth scenarios are automatically skipped +AGW_USER_TOKEN=your-user-jwt-here +``` + ## Running Integration Tests ```bash @@ -85,6 +108,7 @@ uv run pytest tests/core/integration/auditlog -v uv run pytest tests/objectstore/integration/ -v uv run pytest tests/destination/integration/ -v uv run pytest tests/agent_memory/integration/ -v +uv run pytest tests/agentgateway/integration/ -v ``` ### BDD Scenarios diff --git a/pyproject.toml b/pyproject.toml index e6cae92..c57ff76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sap-cloud-sdk" -version = "0.19.3" +version = "0.20.0" description = "SAP Cloud SDK for Python" readme = "README.md" license = "Apache-2.0" diff --git a/src/sap_cloud_sdk/agentgateway/__init__.py b/src/sap_cloud_sdk/agentgateway/__init__.py index 0a7b735..216f301 100644 --- a/src/sap_cloud_sdk/agentgateway/__init__.py +++ b/src/sap_cloud_sdk/agentgateway/__init__.py @@ -52,7 +52,7 @@ ] """ -from sap_cloud_sdk.agentgateway._models import MCPTool +from sap_cloud_sdk.agentgateway._models import AuthResult, MCPTool from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway.agw_client import create_client, AgentGatewayClient from sap_cloud_sdk.agentgateway.exceptions import ( @@ -69,6 +69,7 @@ # Configuration "ClientConfig", # Data models + "AuthResult", "MCPTool", # Exceptions "AgentGatewaySDKError", diff --git a/src/sap_cloud_sdk/agentgateway/_customer.py b/src/sap_cloud_sdk/agentgateway/_customer.py index 0f6ffb4..eb8a364 100644 --- a/src/sap_cloud_sdk/agentgateway/_customer.py +++ b/src/sap_cloud_sdk/agentgateway/_customer.py @@ -8,7 +8,6 @@ - Tool invocation: mTLS + jwt-bearer grant → user-scoped token (principal propagation) """ -import asyncio import json import logging import os @@ -432,17 +431,18 @@ async def _list_server_tools( async def get_mcp_tools_customer( credentials: CustomerCredentials, + system_token: str, timeout: float, - app_tid: str | None = None, ) -> list[MCPTool]: """List all MCP tools from servers defined in credentials. Iterates over all integrationDependencies in the credentials file and - discovers tools from each MCP server using mTLS client credentials. + discovers tools from each MCP server using a pre-fetched system token. Args: credentials: Customer credentials with integrationDependencies. - app_tid: BTP Application Tenant ID of subscriber (optional). + system_token: Pre-fetched raw system token for authentication. + timeout: HTTP timeout in seconds for MCP server calls. Returns: List of MCPTool objects from all servers. @@ -459,12 +459,6 @@ async def get_mcp_tools_customer( logger.info("Discovering tools from %d MCP server(s)", len(dependencies)) - # Get system token for discovery - loop = asyncio.get_running_loop() - system_token = await loop.run_in_executor( - None, get_system_token_mtls, credentials, timeout, app_tid - ) - tools: list[MCPTool] = [] for dep in dependencies: @@ -490,24 +484,20 @@ async def get_mcp_tools_customer( async def call_mcp_tool_customer( - credentials: CustomerCredentials, tool: MCPTool, - user_token: str | None, + auth_token: str, timeout: float, - app_tid: str | None = None, **kwargs, ) -> str: """Invoke an MCP tool using customer flow. - If user_token is provided, exchanges it for an AGW-scoped token to preserve - user identity for principal propagation. Otherwise, falls back to system token. + Uses a pre-fetched token (either user-scoped or system-scoped) for + authentication against the MCP server. Args: - credentials: Customer credentials. tool: MCPTool to invoke. - user_token: User's JWT token for principal propagation (optional). - If None, system token is used instead (no principal propagation). - app_tid: BTP Application Tenant ID of subscriber (optional). + auth_token: Pre-fetched raw access token for authentication. + timeout: HTTP timeout in seconds for the MCP server call. **kwargs: Tool input parameters. Returns: @@ -515,28 +505,9 @@ async def call_mcp_tool_customer( """ logger.info("Calling tool '%s' on server '%s'", tool.name, tool.server_name) - loop = asyncio.get_running_loop() - - if user_token: - # Exchange user token for AGW-scoped token (with principal propagation) - agw_token = await loop.run_in_executor( - None, exchange_user_token, credentials, user_token, timeout, app_tid - ) - else: - # TODO: IBD workaround - use system token when user_token is not available. - # This bypasses principal propagation. Remove this fallback once IBD - # supports proper user token flow. - logger.warning( - "No user_token provided - using system token for tool invocation. " - "Principal propagation will NOT work." - ) - agw_token = await loop.run_in_executor( - None, get_system_token_mtls, credentials, timeout, app_tid - ) - async with httpx.AsyncClient( headers={ - "Authorization": f"Bearer {agw_token}", + "Authorization": f"Bearer {auth_token}", "x-correlation-id": str(uuid.uuid4()), }, timeout=timeout, diff --git a/src/sap_cloud_sdk/agentgateway/_lob.py b/src/sap_cloud_sdk/agentgateway/_lob.py index 15c9707..c3cbba0 100644 --- a/src/sap_cloud_sdk/agentgateway/_lob.py +++ b/src/sap_cloud_sdk/agentgateway/_lob.py @@ -1,8 +1,8 @@ """LoB agent flow - BTP Destination Service based. LoB agents use BTP Destination Service for credential management: -- Phase 1 (discovery): Client credentials from destination -- Phase 2 (execution): Token exchange with user_token for principal propagation +- Phase 1 (discovery): Client credentials from destination (subscriber.ias fragment) +- Phase 2 (execution): Token exchange with user_token (subscriber.ias.user fragment) """ import asyncio @@ -33,6 +33,7 @@ # Label values for fragment discovery _MCP_LABEL_VALUE = "agw.mcp.server" _IAS_LABEL_VALUE = "subscriber.ias" +_IAS_USER_LABEL_VALUE = "subscriber.ias.user" _DESTINATION_INSTANCE = "default" @@ -58,8 +59,12 @@ def _fetch_auth_token( dest_name: str, tenant_subdomain: str, options: ConsumptionOptions | None = None, -) -> str: - """Fetch auth token from destination service. +) -> tuple[str, str]: + """Fetch auth token and gateway URL from destination service. + + Extracts the raw JWT from the Authorization header value returned by the + destination service (e.g. strips the "Bearer " prefix from "Bearer "), + and the gateway URL from the destination's URL property. Args: dest_name: Destination name. @@ -67,7 +72,7 @@ def _fetch_auth_token( options: Consumption options (fragment_name, user_token). Returns: - Authorization header value. + Tuple of (raw_jwt, gateway_url). Raises: MCPServerNotFoundError: If no auth token is returned. @@ -85,13 +90,17 @@ def _fetch_auth_token( f"No auth token returned for destination '{dest_name}'" ) - auth = dest.auth_tokens[0].http_header.get("value", "") - if not auth: - raise MCPServerNotFoundError( - f"Empty Authorization header for destination '{dest_name}'" - ) + auth_token = dest.auth_tokens[0] + header_value = auth_token.http_header.get("value") or "" + if not header_value: + raise MCPServerNotFoundError(f"Empty auth header for destination '{dest_name}'") + + # Strip "Bearer " prefix — AuthResult.access_token is always a raw JWT + raw_token = header_value.removeprefix("Bearer ").strip() - return auth + gateway_url = (dest.url or "").rstrip("/") + + return raw_token, gateway_url def list_mcp_fragments(tenant_subdomain: str) -> list: @@ -143,10 +152,40 @@ def get_ias_fragment_name(tenant_subdomain: str) -> str: return fragments[0].name -async def get_system_auth( +def get_ias_user_fragment_name(tenant_subdomain: str) -> str: + """Get the IAS user fragment name for token exchange (principal propagation). + + Looks up the IAS user fragment created during subscription by the + sap-managed-runtime-type=subscriber.ias.user label. + + Args: + tenant_subdomain: Tenant subdomain for multi-tenant lookup. + + Returns: + IAS user fragment name. + + Raises: + MCPServerNotFoundError: If no IAS user fragment is found. + """ + client = create_fragment_client(instance=_DESTINATION_INSTANCE) + fragments = client.list_instance_fragments( + filter=ListOptions( + filter_labels=[Label(key=_LABEL_KEY, values=[_IAS_USER_LABEL_VALUE])] + ), + tenant=tenant_subdomain, + ) + if not fragments: + raise MCPServerNotFoundError( + f"No IAS user fragment found (label {_LABEL_KEY}={_IAS_USER_LABEL_VALUE}) " + f"for tenant '{tenant_subdomain}'" + ) + return fragments[0].name + + +async def fetch_system_auth( tenant_subdomain: str, -) -> str: - """Get system-scoped auth (Phase 1 - client credentials). +) -> tuple[str, str]: + """Fetch system-scoped auth (Phase 1 - client credentials). Looks up the IAS fragment (subscriber.ias label) and uses it to acquire a client-credentials token via BTP Destination Service. @@ -155,7 +194,7 @@ async def get_system_auth( tenant_subdomain: Tenant subdomain for multi-tenant lookup. Returns: - Authorization header value (e.g., "Bearer xxx"). + Tuple of (raw_access_token, gateway_url). Raises: MCPServerNotFoundError: If no IAS fragment or auth token is found. @@ -182,39 +221,42 @@ def _fetch_system_auth_sync(): return await loop.run_in_executor(None, _fetch_system_auth_sync) -async def get_user_auth( - mcp_fragment_name: str, +async def fetch_user_auth( user_token: str, tenant_subdomain: str, -) -> str: - """Get user-scoped auth (Phase 2 - token exchange). +) -> tuple[str, str]: + """Fetch user-scoped auth (Phase 2 - token exchange). + + Looks up the IAS user fragment (subscriber.ias.user label) and uses it + together with the user_token to perform a token exchange via BTP + Destination Service. Args: - mcp_fragment_name: MCP fragment name for token exchange. user_token: User's JWT for principal propagation. tenant_subdomain: Tenant subdomain for multi-tenant lookup. Returns: - Authorization header value with user identity embedded. + Tuple of (raw_access_token, gateway_url). Raises: - MCPServerNotFoundError: If no auth token is returned. + MCPServerNotFoundError: If no IAS user fragment or auth token is found. """ loop = asyncio.get_running_loop() def _fetch_user_auth_sync(): + ias_user_fragment_name = get_ias_user_fragment_name(tenant_subdomain) dest_name = _ias_dest_name() logger.info( "Exchanging user auth — destination: '%s', fragment: '%s', tenant: '%s'", dest_name, - mcp_fragment_name, + ias_user_fragment_name, tenant_subdomain, ) options = ConsumptionOptions( user_token=user_token, - fragment_name=mcp_fragment_name, + fragment_name=ias_user_fragment_name, fragment_level=ConsumptionLevel.INSTANCE, ) @@ -224,20 +266,23 @@ def _fetch_user_auth_sync(): async def list_server_tools( - dest_url: str, system_auth: str, fragment_name: str, timeout: float + dest_url: str, auth_token: str, fragment_name: str, timeout: float ) -> list[MCPTool]: """List tools from a single MCP server. Args: dest_url: MCP endpoint URL. - system_auth: Authorization header for the request. + auth_token: Raw access token for the request. fragment_name: Fragment name for reference. Returns: List of MCPTool objects from this server. """ async with httpx.AsyncClient( - headers={"Authorization": system_auth, "x-correlation-id": str(uuid.uuid4())}, + headers={ + "Authorization": f"Bearer {auth_token}", + "x-correlation-id": str(uuid.uuid4()), + }, timeout=timeout, ) as http_client: async with streamable_http_client(dest_url, http_client=http_client) as ( @@ -270,14 +315,17 @@ async def list_server_tools( async def get_mcp_tools_lob( tenant_subdomain: str, + system_token: str, timeout: float, ) -> list[MCPTool]: """List all MCP tools using LoB flow (destination-based). - Uses Phase 1 auth (client-scoped) via BTP Destination Service. + Uses a pre-fetched system token for authentication against MCP servers. Args: tenant_subdomain: Tenant subdomain for multi-tenant lookup. + system_token: Pre-fetched raw system token (from get_system_auth). + timeout: HTTP timeout in seconds for MCP server calls. Returns: List of MCPTool objects from all MCP servers. @@ -306,9 +354,8 @@ async def get_mcp_tools_lob( continue try: - system_auth = await get_system_auth(tenant_subdomain) server_tools = await list_server_tools( - mcp_url, system_auth, fragment_name, timeout + mcp_url, system_token, fragment_name, timeout ) tools.extend(server_tools) logger.debug( @@ -328,36 +375,33 @@ async def get_mcp_tools_lob( async def call_mcp_tool_lob( tool: MCPTool, - user_token: str, - tenant_subdomain: str, + user_auth_token: str, timeout: float, **kwargs, ) -> str: """Invoke an MCP tool using LoB flow (destination-based). - Uses Phase 2 auth (user-scoped) via token exchange. - Principal propagation ensures LoB systems see user identity. + Uses a pre-fetched user token for principal propagation. Args: tool: MCPTool object (from list_mcp_tools). - user_token: User's JWT for principal propagation. - tenant_subdomain: Tenant subdomain for token exchange. + user_auth_token: Pre-fetched raw user token (from get_user_auth). + timeout: HTTP timeout in seconds for the MCP server call. **kwargs: Tool input parameters. Returns: Tool execution result as string. - - Raises: - MCPServerNotFoundError: If destination/auth fails. """ if not tool.fragment_name: raise MCPServerNotFoundError( f"Tool '{tool.name}' missing fragment_name for LoB invocation" ) - user_auth = await get_user_auth(tool.fragment_name, user_token, tenant_subdomain) async with httpx.AsyncClient( - headers={"Authorization": user_auth, "x-correlation-id": str(uuid.uuid4())}, + headers={ + "Authorization": f"Bearer {user_auth_token}", + "x-correlation-id": str(uuid.uuid4()), + }, timeout=timeout, ) as http_client: async with streamable_http_client(tool.url, http_client=http_client) as ( diff --git a/src/sap_cloud_sdk/agentgateway/_models.py b/src/sap_cloud_sdk/agentgateway/_models.py index c48ee7a..e680bc2 100644 --- a/src/sap_cloud_sdk/agentgateway/_models.py +++ b/src/sap_cloud_sdk/agentgateway/_models.py @@ -4,6 +4,32 @@ from typing import Any +@dataclass +class AuthResult: + """Authentication result from Agent Gateway. + + Contains the access token and the Agent Gateway URL. + + Attributes: + access_token: Raw JWT access token (no "Bearer " prefix). + gateway_url: Agent Gateway base URL (no trailing slash). + + Example: + ```python + from sap_cloud_sdk.agentgateway import create_client + + agw_client = create_client(tenant_subdomain="my-tenant") + + auth = await agw_client.get_system_auth() + print(auth.access_token) # raw JWT + print(auth.gateway_url) # "https://agw.example.com" + ``` + """ + + access_token: str + gateway_url: str + + @dataclass class MCPTool: """MCP tool discovered from Agent Gateway. diff --git a/src/sap_cloud_sdk/agentgateway/agw_client.py b/src/sap_cloud_sdk/agentgateway/agw_client.py index a601d88..f385118 100644 --- a/src/sap_cloud_sdk/agentgateway/agw_client.py +++ b/src/sap_cloud_sdk/agentgateway/agw_client.py @@ -7,18 +7,26 @@ - Customer agents: Use file-based credentials mounted on pod with mTLS auth """ +import asyncio import logging from typing import Callable -from sap_cloud_sdk.agentgateway._models import MCPTool from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway._customer import ( + call_mcp_tool_customer, detect_customer_agent_credentials, - load_customer_credentials, + exchange_user_token, get_mcp_tools_customer, - call_mcp_tool_customer, + get_system_token_mtls, + load_customer_credentials, ) -from sap_cloud_sdk.agentgateway._lob import get_mcp_tools_lob, call_mcp_tool_lob +from sap_cloud_sdk.agentgateway._lob import ( + call_mcp_tool_lob, + fetch_system_auth, + fetch_user_auth, + get_mcp_tools_lob, +) +from sap_cloud_sdk.agentgateway._models import AuthResult, MCPTool from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError from sap_cloud_sdk.core.telemetry import Module, Operation, record_metrics @@ -68,6 +76,23 @@ class AgentGatewayClient: cost_center="1000", ) ``` + + Example (auth for external use): + ```python + from sap_cloud_sdk.agentgateway import create_client + + agw_client = create_client(tenant_subdomain="my-tenant") + + # Get system-scoped auth (token + gateway URL) + auth = await agw_client.get_system_auth() + print(auth.access_token) # raw JWT + print(auth.gateway_url) # "https://agw.example.com" + + # Get user-scoped auth (token exchange + gateway URL) + auth = await agw_client.get_user_auth(user_token="user-jwt") + print(auth.access_token) # exchanged JWT with user identity + print(auth.gateway_url) # "https://agw.example.com" + ``` """ def __init__( @@ -117,6 +142,139 @@ def _resolve_tenant_subdomain(self) -> str: "tenant_subdomain is required for LoB agent flow.", ) + @record_metrics(Module.AGENTGATEWAY, Operation.AGENTGATEWAY_GET_SYSTEM_AUTH) + async def get_system_auth(self, app_tid: str | None = None) -> AuthResult: + """Get system-scoped authentication (client_credentials flow). + + Automatically detects agent type (LoB vs Customer) based on + credential file presence. + + Args: + app_tid: BTP Application Tenant ID of the subscriber. + Only used for customer agents. This is passed to the token + service for tenant-scoped token requests. + + Returns: + AuthResult with raw access token (JWT) and Agent Gateway URL. + + Raises: + AgentGatewaySDKError: If tenant_subdomain is required but not + provided (LoB), or if token acquisition fails. + + Example: + ```python + auth = await agw_client.get_system_auth() + headers = {"Authorization": f"Bearer {auth.access_token}"} + # auth.gateway_url is the Agent Gateway base URL + ``` + """ + try: + credentials_path = detect_customer_agent_credentials() + if credentials_path: + logger.info( + "Customer agent credentials detected at '%s'", credentials_path + ) + credentials = load_customer_credentials(credentials_path) + loop = asyncio.get_running_loop() + token = await loop.run_in_executor( + None, + get_system_token_mtls, + credentials, + self._config.timeout, + app_tid, + ) + return AuthResult( + access_token=token, + gateway_url=credentials.gateway_url, + ) + + # LoB flow + if app_tid: + logger.warning("app_tid parameter ignored for LoB agent flow") + + tenant = self._resolve_tenant_subdomain() + token, gateway_url = await fetch_system_auth(tenant) + return AuthResult(access_token=token, gateway_url=gateway_url) + + except AgentGatewaySDKError: + raise + except Exception as e: + logger.exception("Unexpected error during system auth acquisition") + raise AgentGatewaySDKError(f"System auth acquisition failed: {e}") from e + + @record_metrics(Module.AGENTGATEWAY, Operation.AGENTGATEWAY_GET_USER_AUTH) + async def get_user_auth( + self, + user_token: str | Callable[[], str] | None, + app_tid: str | None = None, + ) -> AuthResult: + """Exchange a user token for AGW-scoped authentication (token exchange). + + Automatically detects agent type (LoB vs Customer) based on + credential file presence. + + Args: + user_token: User's JWT for principal propagation. + Can be a string or a callable returning a string. + app_tid: BTP Application Tenant ID of the subscriber. + Only used for customer agents. This is passed to the token + service for tenant-scoped token exchange. + + Returns: + AuthResult with raw access token (JWT, user identity embedded) + and Agent Gateway URL. + + Raises: + AgentGatewaySDKError: If user_token is empty, or tenant_subdomain + is required but not provided (LoB), or if token exchange fails. + + Example: + ```python + auth = await agw_client.get_user_auth(user_token="user-jwt") + headers = {"Authorization": f"Bearer {auth.access_token}"} + # auth.gateway_url is the Agent Gateway base URL + ``` + """ + try: + resolved_user_token = self._resolve_value( + user_token, + "user_token is required for token exchange.", + ) + + credentials_path = detect_customer_agent_credentials() + if credentials_path: + logger.info( + "Customer agent credentials detected at '%s'", credentials_path + ) + credentials = load_customer_credentials(credentials_path) + loop = asyncio.get_running_loop() + token = await loop.run_in_executor( + None, + exchange_user_token, + credentials, + resolved_user_token, + self._config.timeout, + app_tid, + ) + return AuthResult( + access_token=token, + gateway_url=credentials.gateway_url, + ) + + # LoB flow + if app_tid: + logger.warning("app_tid parameter ignored for LoB agent flow") + + tenant = self._resolve_tenant_subdomain() + token, gateway_url = await fetch_user_auth(resolved_user_token, tenant) + return AuthResult(access_token=token, gateway_url=gateway_url) + + except AgentGatewaySDKError: + raise + except Exception as e: + logger.exception("Unexpected error during user auth exchange") + raise AgentGatewaySDKError(f"User auth exchange failed: {e}") from e + @record_metrics(Module.AGENTGATEWAY, Operation.AGENTGATEWAY_LIST_MCP_TOOLS) async def list_mcp_tools( self, @@ -157,8 +315,9 @@ async def list_mcp_tools( "Customer agent credentials detected at '%s'", credentials_path ) credentials = load_customer_credentials(credentials_path) + auth = await self.get_system_auth(app_tid=app_tid) return await get_mcp_tools_customer( - credentials, self._config.timeout, app_tid + credentials, auth.access_token, self._config.timeout ) # LoB flow - requires tenant_subdomain @@ -166,7 +325,10 @@ async def list_mcp_tools( logger.warning("app_tid parameter ignored for LoB agent flow") tenant = self._resolve_tenant_subdomain() - return await get_mcp_tools_lob(tenant, self._config.timeout) + auth = await self.get_system_auth() + return await get_mcp_tools_lob( + tenant, auth.access_token, self._config.timeout + ) except AgentGatewaySDKError: # Re-raise SDK errors as-is @@ -235,38 +397,29 @@ async def call_mcp_tool( ) # Resolve user_token if provided (optional for customer flow) - resolved_user_token = None if user_token: - resolved_user_token = ( - user_token() - if not isinstance(user_token, str) and callable(user_token) - else user_token + auth = await self.get_user_auth(user_token, app_tid) + else: + # TODO: IBD workaround - use system token when user_token + # is not available. This bypasses principal propagation. + # Remove this fallback once IBD supports proper user token flow. + logger.warning( + "No user_token provided - using system token for tool " + "invocation. Principal propagation will NOT work." ) - if resolved_user_token: - resolved_user_token = resolved_user_token.strip() or None + auth = await self.get_system_auth(app_tid) - credentials = load_customer_credentials(credentials_path) return await call_mcp_tool_customer( - credentials, - tool, - resolved_user_token, - self._config.timeout, - app_tid, - **kwargs, + tool, auth.access_token, self._config.timeout, **kwargs ) # LoB flow - requires user_token and tenant_subdomain - resolved_user_token = self._resolve_value( - user_token, - "user_token is required for LoB agent tool invocation.", - ) - if app_tid: logger.warning("app_tid parameter ignored for LoB agent flow") - tenant = self._resolve_tenant_subdomain() + auth = await self.get_user_auth(user_token, app_tid) return await call_mcp_tool_lob( - tool, resolved_user_token, tenant, self._config.timeout, **kwargs + tool, auth.access_token, self._config.timeout, **kwargs ) except AgentGatewaySDKError: @@ -342,5 +495,16 @@ def create_client( cost_center="1000", # example tool-specific parameter ) ``` + + Example (auth fetching): + ```python + from sap_cloud_sdk.agentgateway import create_client + + agw_client = create_client(tenant_subdomain="my-tenant") + + # Get auth for external use + auth = await agw_client.get_system_auth() + user_auth = await agw_client.get_user_auth(user_token="user-jwt") + ``` """ return AgentGatewayClient(tenant_subdomain=tenant_subdomain, config=config) diff --git a/src/sap_cloud_sdk/core/telemetry/operation.py b/src/sap_cloud_sdk/core/telemetry/operation.py index 8619145..f56e702 100644 --- a/src/sap_cloud_sdk/core/telemetry/operation.py +++ b/src/sap_cloud_sdk/core/telemetry/operation.py @@ -107,6 +107,8 @@ class Operation(str, Enum): # Agent Gateway Operations AGENTGATEWAY_LIST_MCP_TOOLS = "list_mcp_tools" AGENTGATEWAY_CALL_MCP_TOOL = "call_mcp_tool" + AGENTGATEWAY_GET_SYSTEM_AUTH = "get_system_auth" + AGENTGATEWAY_GET_USER_AUTH = "get_user_auth" # Agent Memory Operations AGENT_MEMORY_ADD_MEMORY = "add_memory" diff --git a/src/sap_cloud_sdk/extensibility/_ums_transport.py b/src/sap_cloud_sdk/extensibility/_ums_transport.py index 76f69cf..49bcfdf 100644 --- a/src/sap_cloud_sdk/extensibility/_ums_transport.py +++ b/src/sap_cloud_sdk/extensibility/_ums_transport.py @@ -436,7 +436,8 @@ class UmsTransport: 1. ``config.destination_name`` (explicit config override). 2. ``APPFND_UMS_DESTINATION_NAME`` environment variable. 3. ``sap-managed-runtime-ums-{APPFND_CONHOS_LANDSCAPE}`` (constructed). - 4. ``EXTENSIBILITY_SERVICE`` (fallback with warning). + + If none of the above are available, resolution fails with a warning. Args: agent_ord_id: ORD ID of the agent. diff --git a/src/sap_cloud_sdk/extensibility/config.py b/src/sap_cloud_sdk/extensibility/config.py index 649a700..c881492 100644 --- a/src/sap_cloud_sdk/extensibility/config.py +++ b/src/sap_cloud_sdk/extensibility/config.py @@ -23,8 +23,8 @@ class ExtensibilityConfig: When ``None`` (the default), the destination name is resolved automatically in order: (1) ``APPFND_UMS_DESTINATION_NAME`` environment variable, - (2) ``sap-managed-runtime-ums-{APPFND_CONHOS_LANDSCAPE}``, - (3) fallback to ``"EXTENSIBILITY_SERVICE"`` with a warning. + (2) ``sap-managed-runtime-ums-{APPFND_CONHOS_LANDSCAPE}``. + If neither is available, resolution fails with a warning. Set this only when the destination follows a non-standard naming convention that cannot be expressed via environment variables. diff --git a/src/sap_cloud_sdk/extensibility/user-guide.md b/src/sap_cloud_sdk/extensibility/user-guide.md index 043e539..ed7590d 100644 --- a/src/sap_cloud_sdk/extensibility/user-guide.md +++ b/src/sap_cloud_sdk/extensibility/user-guide.md @@ -714,7 +714,7 @@ Validation issues produce log warnings but never prevent output generation. The module resolves the extensibility service URL and credentials through the SAP BTP Destination Service. The destination is looked up at the subaccount level. -- **Default destination name resolution**: (1) `APPFND_UMS_DESTINATION_NAME` env var, (2) `sap-managed-runtime-ums-{APPFND_CONHOS_LANDSCAPE}`, (3) `EXTENSIBILITY_SERVICE` fallback. +- **Default destination name resolution**: (1) `APPFND_UMS_DESTINATION_NAME` env var, (2) `sap-managed-runtime-ums-{APPFND_CONHOS_LANDSCAPE}`. If neither is available, resolution fails with a warning. - **Default destination instance**: `default` - Override via `ExtensibilityConfig(destination_name=...)` when the destination uses a non-standard name. diff --git a/tests/agentgateway/integration/__init__.py b/tests/agentgateway/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agentgateway/integration/agw_auth.feature b/tests/agentgateway/integration/agw_auth.feature new file mode 100644 index 0000000..aae2a03 --- /dev/null +++ b/tests/agentgateway/integration/agw_auth.feature @@ -0,0 +1,54 @@ +Feature: Agent Gateway Auth Integration + As a developer using the SDK + I want to fetch auth credentials from the Agent Gateway + So that I can make authenticated requests to MCP servers + + Background: + Given the Agent Gateway client is available + + Scenario: Get system auth returns a valid AuthResult + When I call get_system_auth + Then the result should be an AuthResult + And the access_token should be a non-empty string + And the gateway_url should be a non-empty string + And the gateway_url should have no trailing slash + And the access_token should not start with "Bearer " + + Scenario: Get user auth returns a valid AuthResult + Given I have a valid user token + When I call get_user_auth with the user token + Then the result should be an AuthResult + And the access_token should be a non-empty string + And the gateway_url should be a non-empty string + And the gateway_url should have no trailing slash + And the access_token should not start with "Bearer " + + Scenario: Get user auth accepts a callable user token + Given I have a valid user token + When I call get_user_auth with a callable returning the user token + Then the result should be an AuthResult + And the access_token should be a non-empty string + + Scenario: System auth and user auth return the same gateway URL + Given I have a valid user token + When I call get_system_auth + And I call get_user_auth with the user token + Then both gateway URLs should match + + Scenario: Get user auth fails when user token is empty + When I call get_user_auth with an empty user token + Then the operation should fail with AgentGatewaySDKError + And the error message should mention "user_token is required" + + Scenario: List MCP tools returns a non-empty list of tools + When I call list_mcp_tools + Then the result should be a list of MCPTool + And the list should be non-empty + And each tool should have a non-empty name + And each tool should have a non-empty url + + Scenario: Call search_workflows tool returns a non-empty result + Given I have a valid user token + When I call list_mcp_tools + And I call call_mcp_tool with "search_workflows" and the user token + Then the tool result should be a non-empty string diff --git a/tests/agentgateway/integration/conftest.py b/tests/agentgateway/integration/conftest.py new file mode 100644 index 0000000..aa96899 --- /dev/null +++ b/tests/agentgateway/integration/conftest.py @@ -0,0 +1,47 @@ +"""Pytest configuration and fixtures for Agent Gateway integration tests.""" + +import os +from pathlib import Path + +import pytest +from dotenv import load_dotenv + +from sap_cloud_sdk.agentgateway import create_client, AgentGatewayClient + + +def _setup_cloud_mode(): + """Load environment variables from .env_integration_tests if present.""" + env_file = Path(__file__).parents[3] / ".env_integration_tests" + if env_file.exists(): + load_dotenv(env_file) + + +@pytest.fixture(scope="session") +def agw_client() -> AgentGatewayClient: + """Create an AgentGatewayClient from environment variables.""" + _setup_cloud_mode() + + tenant_subdomain = os.environ.get("TENANT_SUBDOMAIN") + if not tenant_subdomain: + pytest.fail("TENANT_SUBDOMAIN environment variable is not set") + + try: + return create_client(tenant_subdomain=tenant_subdomain) + except Exception as e: + pytest.fail(f"Failed to create Agent Gateway client for integration tests: {e}") + + +# Configure pytest markers for integration tests +def pytest_configure(config): + """Configure pytest markers.""" + config.addinivalue_line( + "markers", + "integration: mark test as integration test" + ) + + +def pytest_collection_modifyitems(config, items): + """Automatically mark integration tests.""" + for item in items: + if "integration" in str(item.fspath): + item.add_marker(pytest.mark.integration) diff --git a/tests/agentgateway/integration/test_agw_bdd.py b/tests/agentgateway/integration/test_agw_bdd.py new file mode 100644 index 0000000..37704b0 --- /dev/null +++ b/tests/agentgateway/integration/test_agw_bdd.py @@ -0,0 +1,244 @@ +"""BDD step definitions for Agent Gateway auth integration tests. + +Run against a live BTP tenant: + + APPFND_CONHOS_LANDSCAPE= \\ + TENANT_SUBDOMAIN= \\ + CLOUD_SDK_CFG_DESTINATION_DEFAULT_CLIENTID=... \\ + CLOUD_SDK_CFG_DESTINATION_DEFAULT_CLIENTSECRET=... \\ + CLOUD_SDK_CFG_DESTINATION_DEFAULT_URL=... \\ + CLOUD_SDK_CFG_DESTINATION_DEFAULT_URI=... \\ + CLOUD_SDK_CFG_DESTINATION_DEFAULT_IDENTITYZONE=... \\ + AGW_USER_TOKEN= \\ + pytest tests/agentgateway/integration/ -v +""" + +import asyncio +import os +from typing import Optional + +import pytest +from pytest_bdd import scenarios, given, when, then, parsers + +from sap_cloud_sdk.agentgateway import AgentGatewayClient, AuthResult, AgentGatewaySDKError +from sap_cloud_sdk.agentgateway._models import MCPTool + +scenarios("agw_auth.feature") + + +# ==================== HELPERS ==================== + + +def run(coro): + """Run a coroutine in a new event loop.""" + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + +# ==================== CONTEXT ==================== + + +class ScenarioContext: + """Context to store test state between steps.""" + + def __init__(self): + self.system_auth_result: Optional[AuthResult] = None + self.user_auth_result: Optional[AuthResult] = None + self.last_result: Optional[AuthResult] = None + self.operation_error: Optional[Exception] = None + self.user_token: Optional[str] = None + self.tools: Optional[list[MCPTool]] = None + self.tool_result: Optional[str] = None + + +@pytest.fixture +def context(): + """Provide a fresh context for each scenario.""" + return ScenarioContext() + + +# ==================== GIVEN ==================== + + +@given("the Agent Gateway client is available") +def agent_gateway_client_available(agw_client: AgentGatewayClient): + """Verify that the Agent Gateway client is available.""" + assert agw_client is not None + + +@given("I have a valid user token") +def have_valid_user_token(context: ScenarioContext): + """Load user token from environment variable.""" + token = os.environ.get("AGW_USER_TOKEN", "") + if not token: + pytest.skip("AGW_USER_TOKEN is not set — skipping user auth scenario") + context.user_token = token + + +# ==================== WHEN ==================== + + +@when("I call get_system_auth") +def call_get_system_auth(context: ScenarioContext, agw_client: AgentGatewayClient): + """Call get_system_auth and store the result.""" + context.system_auth_result = run(agw_client.get_system_auth()) + context.last_result = context.system_auth_result + + +@when("I call get_user_auth with the user token") +def call_get_user_auth(context: ScenarioContext, agw_client: AgentGatewayClient): + """Call get_user_auth with the user token and store the result.""" + context.user_auth_result = run( + agw_client.get_user_auth(user_token=context.user_token) + ) + context.last_result = context.user_auth_result + + +@when("I call get_user_auth with a callable returning the user token") +def call_get_user_auth_callable(context: ScenarioContext, agw_client: AgentGatewayClient): + """Call get_user_auth with a callable and store the result.""" + token = context.user_token + context.last_result = run( + agw_client.get_user_auth(user_token=lambda: token) + ) + + +@when("I call get_user_auth with an empty user token") +def call_get_user_auth_empty_token(context: ScenarioContext, agw_client: AgentGatewayClient): + """Call get_user_auth with an empty token and capture the error.""" + try: + run(agw_client.get_user_auth(user_token="")) + except AgentGatewaySDKError as e: + context.operation_error = e + + +@when("I call list_mcp_tools") +def call_list_mcp_tools(context: ScenarioContext, agw_client: AgentGatewayClient): + """Call list_mcp_tools and store the result.""" + context.tools = run(agw_client.list_mcp_tools()) + + +@when(parsers.parse('I call call_mcp_tool with "{tool_name}" and the user token')) +def call_call_mcp_tool( + context: ScenarioContext, agw_client: AgentGatewayClient, tool_name: str +): + """Find tool by name from list_mcp_tools result and call it.""" + assert context.tools is not None, "call list_mcp_tools before calling a tool" + tool = next((t for t in context.tools if t.name == tool_name), None) + if tool is None: + pytest.fail(f"Tool '{tool_name}' not found in list_mcp_tools result") + context.tool_result = run( + agw_client.call_mcp_tool(tool, user_token=context.user_token) + ) + + +# ==================== THEN ==================== + + +@then("the result should be an AuthResult") +def result_is_auth_result(context: ScenarioContext): + """Verify the result is an AuthResult instance.""" + assert isinstance(context.last_result, AuthResult) + + +@then("the access_token should be a non-empty string") +def access_token_non_empty(context: ScenarioContext): + """Verify access_token is a non-empty string.""" + assert context.last_result is not None + assert isinstance(context.last_result.access_token, str) + assert context.last_result.access_token.strip() + + +@then("the gateway_url should be a non-empty string") +def gateway_url_non_empty(context: ScenarioContext): + """Verify gateway_url is a non-empty string.""" + assert context.last_result is not None + assert isinstance(context.last_result.gateway_url, str) + assert context.last_result.gateway_url.strip() + + +@then("the gateway_url should have no trailing slash") +def gateway_url_no_trailing_slash(context: ScenarioContext): + """Verify gateway_url does not end with a slash.""" + assert context.last_result is not None + assert not context.last_result.gateway_url.endswith("/") + + +@then(parsers.parse('the access_token should not start with "{prefix}"')) +def access_token_not_starts_with(context: ScenarioContext, prefix: str): + """Verify access_token does not start with the given prefix.""" + assert context.last_result is not None + assert not context.last_result.access_token.startswith(prefix), ( + f"Expected access_token NOT to start with '{prefix}', " + f"got: {context.last_result.access_token[:40]}..." + ) + + +@then("both gateway URLs should match") +def gateway_urls_match(context: ScenarioContext): + """Verify system auth and user auth return the same gateway URL.""" + assert context.system_auth_result is not None + assert context.user_auth_result is not None + assert context.system_auth_result.gateway_url == context.user_auth_result.gateway_url + + +@then("the operation should fail with AgentGatewaySDKError") +def operation_fails_with_sdk_error(context: ScenarioContext): + """Verify the operation raised an AgentGatewaySDKError.""" + assert isinstance(context.operation_error, AgentGatewaySDKError), ( + f"Expected AgentGatewaySDKError, got: {context.operation_error}" + ) + + +@then(parsers.parse('the error message should mention "{expected}"')) +def error_message_mentions(context: ScenarioContext, expected: str): + """Verify the error message contains the expected text.""" + assert expected in str(context.operation_error), ( + f"Expected '{expected}' in error: {context.operation_error}" + ) + + +@then("the result should be a list of MCPTool") +def result_is_list_of_mcp_tool(context: ScenarioContext): + """Verify the result is a list of MCPTool instances.""" + assert isinstance(context.tools, list) + for tool in context.tools: + assert isinstance(tool, MCPTool), f"Expected MCPTool, got {type(tool)}" + + +@then("the list should be non-empty") +def list_is_non_empty(context: ScenarioContext): + """Verify the tools list is not empty.""" + assert context.tools is not None + assert len(context.tools) > 0, "Expected at least one MCP tool" + + +@then("each tool should have a non-empty name") +def each_tool_has_non_empty_name(context: ScenarioContext): + """Verify every tool has a non-empty name.""" + assert context.tools is not None + for tool in context.tools: + assert isinstance(tool.name, str) and tool.name.strip(), ( + f"Tool has empty name: {tool}" + ) + + +@then("each tool should have a non-empty url") +def each_tool_has_non_empty_url(context: ScenarioContext): + """Verify every tool has a non-empty url.""" + assert context.tools is not None + for tool in context.tools: + assert isinstance(tool.url, str) and tool.url.strip(), ( + f"Tool '{tool.name}' has empty url" + ) + + +@then("the tool result should be a non-empty string") +def tool_result_is_non_empty_string(context: ScenarioContext): + """Verify the tool invocation returned a non-empty string.""" + assert context.tool_result is not None + assert isinstance(context.tool_result, str) + assert context.tool_result.strip(), "Expected a non-empty tool result" diff --git a/tests/agentgateway/unit/test_agw_client.py b/tests/agentgateway/unit/test_agw_client.py index c60e8bb..3521aae 100644 --- a/tests/agentgateway/unit/test_agw_client.py +++ b/tests/agentgateway/unit/test_agw_client.py @@ -1,12 +1,13 @@ """Unit tests for Agent Gateway client.""" -from unittest.mock import patch, AsyncMock +from unittest.mock import patch, AsyncMock, MagicMock import pytest from sap_cloud_sdk.agentgateway import ( create_client, AgentGatewayClient, + AuthResult, MCPTool, AgentGatewaySDKError, ) @@ -96,6 +97,223 @@ def test_raises_on_callable_returning_empty(self): AgentGatewayClient._resolve_value(get_empty, "test error") +# ============================================================ +# Test: get_system_auth +# ============================================================ + + +class TestGetSystemAuth: + """Tests for get_system_auth async method.""" + + @pytest.mark.asyncio + async def test_lob_flow_returns_auth_result(self): + """Return AuthResult from LoB flow.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + return_value=("raw-system-jwt-token", "https://agw.example.com"), + ) as mock_auth: + agw_client = create_client(tenant_subdomain="my-tenant") + + result = await agw_client.get_system_auth() + + assert isinstance(result, AuthResult) + assert result.access_token == "raw-system-jwt-token" + assert result.gateway_url == "https://agw.example.com" + mock_auth.assert_called_once_with("my-tenant") + + @pytest.mark.asyncio + async def test_customer_flow_returns_auth_result(self): + """Return AuthResult from customer flow.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + ) as mock_load, patch( + "sap_cloud_sdk.agentgateway.agw_client.get_system_token_mtls", + return_value="customer-system-token", + ) as mock_mtls: + mock_creds = MagicMock() + mock_creds.gateway_url = "https://agw.customer.com" + mock_load.return_value = mock_creds + + agw_client = create_client() + + result = await agw_client.get_system_auth(app_tid="test-tid") + + assert isinstance(result, AuthResult) + assert result.access_token == "customer-system-token" + assert result.gateway_url == "https://agw.customer.com" + mock_load.assert_called_once_with("/path/to/credentials") + mock_mtls.assert_called_once_with(mock_creds, 60.0, "test-tid") + + @pytest.mark.asyncio + async def test_missing_tenant_raises_for_lob(self): + """Raise AgentGatewaySDKError when tenant_subdomain is missing for LoB.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ): + agw_client = create_client() + + with pytest.raises(AgentGatewaySDKError, match="tenant_subdomain is required"): + await agw_client.get_system_auth() + + @pytest.mark.asyncio + async def test_callable_tenant_subdomain(self): + """Accept callable for tenant_subdomain in LoB flow.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + return_value=("token", "https://agw.example.com"), + ) as mock_auth: + get_tenant = lambda: "dynamic-tenant" + agw_client = create_client(tenant_subdomain=get_tenant) + + await agw_client.get_system_auth() + + mock_auth.assert_called_once_with("dynamic-tenant") + + @pytest.mark.asyncio + async def test_wraps_unexpected_errors(self): + """Wrap unexpected errors in AgentGatewaySDKError.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + side_effect=RuntimeError("unexpected"), + ): + agw_client = create_client(tenant_subdomain="my-tenant") + + with pytest.raises(AgentGatewaySDKError, match="System auth acquisition failed"): + await agw_client.get_system_auth() + + +# ============================================================ +# Test: get_user_auth +# ============================================================ + + +class TestGetUserAuth: + """Tests for get_user_auth async method.""" + + @pytest.mark.asyncio + async def test_lob_flow_returns_auth_result(self): + """Return AuthResult from LoB flow.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("raw-user-jwt-token", "https://agw.example.com"), + ) as mock_auth: + agw_client = create_client(tenant_subdomain="my-tenant") + + result = await agw_client.get_user_auth(user_token="user-jwt") + + assert isinstance(result, AuthResult) + assert result.access_token == "raw-user-jwt-token" + assert result.gateway_url == "https://agw.example.com" + mock_auth.assert_called_once_with("user-jwt", "my-tenant") + + @pytest.mark.asyncio + async def test_customer_flow_exchanges_token(self): + """Exchange token via customer flow and return AuthResult.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + ) as mock_load, patch( + "sap_cloud_sdk.agentgateway.agw_client.exchange_user_token", + return_value="exchanged-token", + ) as mock_exchange: + mock_creds = MagicMock() + mock_creds.gateway_url = "https://agw.customer.com" + mock_load.return_value = mock_creds + + agw_client = create_client() + + result = await agw_client.get_user_auth( + user_token="user-jwt", app_tid="test-tid" + ) + + assert isinstance(result, AuthResult) + assert result.access_token == "exchanged-token" + assert result.gateway_url == "https://agw.customer.com" + mock_exchange.assert_called_once_with( + mock_creds, "user-jwt", 60.0, "test-tid" + ) + + @pytest.mark.asyncio + async def test_missing_user_token_raises(self): + """Raise AgentGatewaySDKError when user_token is empty.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ): + agw_client = create_client(tenant_subdomain="my-tenant") + + with pytest.raises(AgentGatewaySDKError, match="user_token is required"): + await agw_client.get_user_auth(user_token="") + + @pytest.mark.asyncio + async def test_callable_user_token(self): + """Accept callable for user_token.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("token", "https://agw.example.com"), + ) as mock_auth: + agw_client = create_client(tenant_subdomain="my-tenant") + get_token = lambda: "dynamic-user-jwt" + + await agw_client.get_user_auth(user_token=get_token) + + mock_auth.assert_called_once_with("dynamic-user-jwt", "my-tenant") + + @pytest.mark.asyncio + async def test_missing_tenant_raises_for_lob(self): + """Raise AgentGatewaySDKError when tenant_subdomain is missing for LoB.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ): + agw_client = create_client() + + with pytest.raises(AgentGatewaySDKError, match="tenant_subdomain is required"): + await agw_client.get_user_auth(user_token="user-jwt") + + @pytest.mark.asyncio + async def test_wraps_unexpected_errors(self): + """Wrap unexpected errors in AgentGatewaySDKError.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + side_effect=RuntimeError("unexpected"), + ): + agw_client = create_client(tenant_subdomain="my-tenant") + + with pytest.raises(AgentGatewaySDKError, match="User auth exchange failed"): + await agw_client.get_user_auth(user_token="user-jwt") + + # ============================================================ # Test: list_mcp_tools # ============================================================ @@ -154,6 +372,11 @@ async def test_with_callable_tenant(self): "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + return_value=("system-token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.get_mcp_tools_lob", new_callable=AsyncMock, @@ -165,16 +388,21 @@ async def test_with_callable_tenant(self): await agw_client.list_mcp_tools() - mock_lob.assert_called_once_with("my-tenant", 60.0) + mock_lob.assert_called_once_with("my-tenant", "system-token", 60.0) @pytest.mark.asyncio - async def test_calls_lob_flow(self): - """list_mcp_tools should call LoB flow with correct parameters.""" + async def test_calls_lob_flow_with_system_token(self): + """list_mcp_tools should call LoB flow with system token.""" with ( patch( "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + return_value=("system-token-xyz", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.get_mcp_tools_lob", new_callable=AsyncMock, @@ -185,7 +413,7 @@ async def test_calls_lob_flow(self): await agw_client.list_mcp_tools() - mock_lob.assert_called_once_with("my-tenant", 60.0) + mock_lob.assert_called_once_with("my-tenant", "system-token-xyz", 60.0) @pytest.mark.asyncio async def test_returns_tools_from_lob_flow(self): @@ -206,6 +434,11 @@ async def test_returns_tools_from_lob_flow(self): "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + return_value=("token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.get_mcp_tools_lob", new_callable=AsyncMock, @@ -220,6 +453,34 @@ async def test_returns_tools_from_lob_flow(self): assert len(result) == 1 assert result[0].name == "tool1" + @pytest.mark.asyncio + async def test_customer_flow_passes_system_token(self): + """Customer flow passes pre-fetched system token to get_mcp_tools_customer.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + ) as mock_load, patch( + "sap_cloud_sdk.agentgateway.agw_client.get_system_token_mtls", + return_value="customer-system-token", + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.get_mcp_tools_customer", + new_callable=AsyncMock, + return_value=[], + ) as mock_customer: + mock_creds = MagicMock() + mock_creds.gateway_url = "https://agw.customer.com" + mock_load.return_value = mock_creds + + agw_client = create_client() + + await agw_client.list_mcp_tools(app_tid="tid") + + mock_customer.assert_called_once_with( + mock_creds, "customer-system-token", 60.0 + ) + # ============================================================ # Test: call_mcp_tool @@ -289,6 +550,11 @@ async def test_with_callable_user_token(self, mock_tool): "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("exchanged-token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_lob", new_callable=AsyncMock, @@ -306,7 +572,7 @@ async def test_with_callable_user_token(self, mock_tool): assert result == "result" mock_lob.assert_called_once_with( - mock_tool, "my-jwt", "my-tenant", 60.0, param1="value1" + mock_tool, "exchanged-token", 60.0, param1="value1" ) @pytest.mark.asyncio @@ -317,6 +583,11 @@ async def test_with_callable_tenant_subdomain(self, mock_tool): "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("exchanged-token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_lob", new_callable=AsyncMock, @@ -332,7 +603,7 @@ async def test_with_callable_tenant_subdomain(self, mock_tool): ) assert result == "result" - mock_lob.assert_called_once_with(mock_tool, "my-jwt", "my-tenant", 60.0) + mock_lob.assert_called_once_with(mock_tool, "exchanged-token", 60.0) @pytest.mark.asyncio async def test_customer_credentials_calls_customer_flow(self, mock_tool): @@ -345,12 +616,20 @@ async def test_customer_credentials_calls_customer_flow(self, mock_tool): patch( "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", ) as mock_load, + patch( + "sap_cloud_sdk.agentgateway.agw_client.exchange_user_token", + return_value="exchanged-token", + ), patch( "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_customer", new_callable=AsyncMock, return_value="customer result", ) as mock_customer, ): + mock_creds = MagicMock() + mock_creds.gateway_url = "https://agw.customer.com" + mock_load.return_value = mock_creds + agw_client = create_client(tenant_subdomain="my-tenant") result = await agw_client.call_mcp_tool( @@ -359,17 +638,62 @@ async def test_customer_credentials_calls_customer_flow(self, mock_tool): ) assert result == "customer result" + # load_customer_credentials is called once in get_user_auth() mock_load.assert_called_once_with("/path/to/credentials") - mock_customer.assert_called_once() + mock_customer.assert_called_once_with( + mock_tool, "exchanged-token", 60.0 + ) @pytest.mark.asyncio - async def test_calls_lob_flow(self, mock_tool): - """call_mcp_tool should call LoB flow with correct parameters.""" + async def test_customer_flow_falls_back_to_system_token(self, mock_tool): + """Customer flow falls back to system token when user_token is None.""" + with ( + patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + ) as mock_load, + patch( + "sap_cloud_sdk.agentgateway.agw_client.get_system_token_mtls", + return_value="system-token", + ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_customer", + new_callable=AsyncMock, + return_value="result with system token", + ) as mock_customer, + ): + mock_creds = MagicMock() + mock_creds.gateway_url = "https://agw.customer.com" + mock_load.return_value = mock_creds + + agw_client = create_client() + + result = await agw_client.call_mcp_tool( + tool=mock_tool, + user_token=None, + ) + + assert result == "result with system token" + mock_customer.assert_called_once_with( + mock_tool, "system-token", 60.0 + ) + + @pytest.mark.asyncio + async def test_calls_lob_flow_with_exchanged_token(self, mock_tool): + """call_mcp_tool should exchange user token and pass to LoB flow.""" with ( patch( "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("exchanged-user-token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_lob", new_callable=AsyncMock, @@ -386,7 +710,7 @@ async def test_calls_lob_flow(self, mock_tool): assert result == "tool result" mock_lob.assert_called_once_with( - mock_tool, "jwt-token", "my-tenant", 60.0, order_id="12345" + mock_tool, "exchanged-user-token", 60.0, order_id="12345" ) @pytest.mark.asyncio @@ -397,6 +721,11 @@ async def test_returns_result_from_lob_flow(self, mock_tool): "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_lob", new_callable=AsyncMock, diff --git a/tests/agentgateway/unit/test_customer.py b/tests/agentgateway/unit/test_customer.py index 4ed170b..f8db5fe 100644 --- a/tests/agentgateway/unit/test_customer.py +++ b/tests/agentgateway/unit/test_customer.py @@ -451,11 +451,11 @@ async def test_raises_when_empty_dependencies(self): with pytest.raises( AgentGatewaySDKError, match="integrationDependencies is empty" ): - await get_mcp_tools_customer(credentials, timeout=60.0) + await get_mcp_tools_customer(credentials, "system-token", 60.0) @pytest.mark.asyncio async def test_discovers_tools_from_credentials(self, credentials): - """Discover tools from integrationDependencies in credentials.""" + """Discover tools from integrationDependencies using pre-fetched token.""" mock_tools = [ MCPTool( name="list_cost_centers", @@ -467,21 +467,22 @@ async def test_discovers_tools_from_credentials(self, credentials): ] with ( - patch( - "sap_cloud_sdk.agentgateway._customer.get_system_token_mtls", - return_value="system-token", - ), patch( "sap_cloud_sdk.agentgateway._customer._list_server_tools", new_callable=AsyncMock, return_value=mock_tools, ) as mock_list, ): - result = await get_mcp_tools_customer(credentials, timeout=60.0) + result = await get_mcp_tools_customer( + credentials, "pre-fetched-system-token", 60.0 + ) assert len(result) == 1 assert result[0].name == "list_cost_centers" mock_list.assert_called_once() + # Verify the pre-fetched token was passed + call_args = mock_list.call_args[0] + assert call_args[1] == "pre-fetched-system-token" @pytest.mark.asyncio async def test_handles_server_error_gracefully(self): @@ -516,16 +517,14 @@ async def mock_list_tools(*args, **kwargs): return [mock_tool] with ( - patch( - "sap_cloud_sdk.agentgateway._customer.get_system_token_mtls", - return_value="system-token", - ), patch( "sap_cloud_sdk.agentgateway._customer._list_server_tools", side_effect=mock_list_tools, ), ): - result = await get_mcp_tools_customer(credentials, timeout=60.0) + result = await get_mcp_tools_customer( + credentials, "system-token", 60.0 + ) # Should still return tools from server2 assert len(result) == 1 @@ -572,13 +571,9 @@ def mock_tool(self): ) @pytest.mark.asyncio - async def test_exchanges_user_token_before_call(self, credentials, mock_tool): - """Exchange user token before making tool call.""" + async def test_calls_tool_with_pre_fetched_token(self, credentials, mock_tool): + """Call tool using pre-fetched auth token.""" with ( - patch( - "sap_cloud_sdk.agentgateway._customer.exchange_user_token", - return_value="exchanged-token", - ) as mock_exchange, patch( "httpx.AsyncClient", ) as mock_client_class, @@ -615,25 +610,19 @@ async def test_exchanges_user_token_before_call(self, credentials, mock_tool): mock_session_class.return_value = mock_session_ctx result = await call_mcp_tool_customer( - credentials, mock_tool, "user-jwt", 60.0, order_id="12345" + mock_tool, "pre-fetched-token", 60.0, order_id="12345" ) assert result == "Order created successfully" - mock_exchange.assert_called_once_with(credentials, "user-jwt", 60.0, None) + # Verify the token was used in the Authorization header + mock_client_class.assert_called_once() + call_kwargs = mock_client_class.call_args.kwargs + assert call_kwargs["headers"]["Authorization"] == "Bearer pre-fetched-token" @pytest.mark.asyncio - async def test_uses_system_token_when_user_token_not_provided( - self, credentials, mock_tool - ): - """Fall back to system token when user_token is None (IBD workaround).""" + async def test_returns_empty_string_when_no_content(self, credentials, mock_tool): + """Return empty string when tool returns no content.""" with ( - patch( - "sap_cloud_sdk.agentgateway._customer.get_system_token_mtls", - return_value="system-token", - ) as mock_system_token, - patch( - "sap_cloud_sdk.agentgateway._customer.exchange_user_token", - ) as mock_exchange, patch( "httpx.AsyncClient", ) as mock_client_class, @@ -644,7 +633,6 @@ async def test_uses_system_token_when_user_token_not_provided( "sap_cloud_sdk.agentgateway._customer.ClientSession", ) as mock_session_class, ): - # Set up mock chain mock_client = AsyncMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) @@ -660,21 +648,15 @@ async def test_uses_system_token_when_user_token_not_provided( mock_session = AsyncMock() mock_session.initialize = AsyncMock() mock_result = MagicMock() - mock_content = MagicMock() - mock_content.text = "Result with system token" - mock_result.content = [mock_content] + mock_result.content = [] mock_session.call_tool = AsyncMock(return_value=mock_result) mock_session_ctx = AsyncMock() mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session) mock_session_ctx.__aexit__ = AsyncMock(return_value=None) mock_session_class.return_value = mock_session_ctx - # Call without user_token (None) result = await call_mcp_tool_customer( - credentials, mock_tool, None, 60.0, order_id="12345" + mock_tool, "auth-token", 60.0 ) - assert result == "Result with system token" - # Should use system token, not exchange - mock_system_token.assert_called_once_with(credentials, 60.0, None) - mock_exchange.assert_not_called() + assert result == "" diff --git a/tests/agentgateway/unit/test_lob.py b/tests/agentgateway/unit/test_lob.py index 35b5e05..8088a4c 100644 --- a/tests/agentgateway/unit/test_lob.py +++ b/tests/agentgateway/unit/test_lob.py @@ -10,13 +10,15 @@ _fetch_auth_token, list_mcp_fragments, get_ias_fragment_name, - get_system_auth, - get_user_auth, + get_ias_user_fragment_name, + fetch_system_auth, + fetch_user_auth, get_mcp_tools_lob, call_mcp_tool_lob, _LABEL_KEY, _MCP_LABEL_VALUE, _IAS_LABEL_VALUE, + _IAS_USER_LABEL_VALUE, ) from sap_cloud_sdk.agentgateway._models import MCPTool from sap_cloud_sdk.agentgateway.exceptions import MCPServerNotFoundError @@ -61,11 +63,13 @@ def test_raises_when_env_not_set(self): class TestFetchAuthToken: """Tests for _fetch_auth_token function.""" - def test_fetches_token_successfully(self): - """Fetch auth token from destination service.""" + def test_fetches_and_decodes_token_and_url(self): + """Strip Bearer prefix from auth header and return raw JWT with gateway URL.""" + header_value = "Bearer my-raw-jwt-token-123" mock_dest = MagicMock() mock_dest.auth_tokens = [MagicMock()] - mock_dest.auth_tokens[0].http_header = {"value": "Bearer test-token"} + mock_dest.auth_tokens[0].http_header = {"value": header_value} + mock_dest.url = "https://agw.example.com/" with patch( "sap_cloud_sdk.agentgateway._lob.create_destination_client" @@ -74,7 +78,7 @@ def test_fetches_token_successfully(self): result = _fetch_auth_token("dest-name", "tenant-sub") - assert result == "Bearer test-token" + assert result == ("my-raw-jwt-token-123", "https://agw.example.com") mock_client.return_value.get_destination.assert_called_once_with( "dest-name", level=ConsumptionLevel.PROVIDER_SUBACCOUNT, @@ -82,6 +86,21 @@ def test_fetches_token_successfully(self): tenant="tenant-sub", ) + def test_strips_trailing_slashes_from_url(self): + """Strip trailing slashes from gateway URL.""" + header_value = "Bearer token" + mock_dest = MagicMock() + mock_dest.auth_tokens = [MagicMock()] + mock_dest.auth_tokens[0].http_header = {"value": header_value} + mock_dest.url = "https://agw.example.com/v1/mcp///" + + with patch("sap_cloud_sdk.agentgateway._lob.create_destination_client") as mock_client: + mock_client.return_value.get_destination.return_value = mock_dest + + result = _fetch_auth_token("dest-name", "tenant-sub") + + assert result == ("token", "https://agw.example.com/v1/mcp") + def test_raises_when_no_destination(self): """Raise MCPServerNotFoundError when destination is None.""" with patch( @@ -105,8 +124,8 @@ def test_raises_when_no_auth_tokens(self): with pytest.raises(MCPServerNotFoundError, match="No auth token"): _fetch_auth_token("dest-name", "tenant-sub") - def test_raises_when_empty_auth_header(self): - """Raise MCPServerNotFoundError when auth header is empty.""" + def test_raises_when_empty_token_value(self): + """Raise MCPServerNotFoundError when http_header value is empty.""" mock_dest = MagicMock() mock_dest.auth_tokens = [MagicMock()] mock_dest.auth_tokens[0].http_header = {"value": ""} @@ -116,7 +135,7 @@ def test_raises_when_empty_auth_header(self): ) as mock_client: mock_client.return_value.get_destination.return_value = mock_dest - with pytest.raises(MCPServerNotFoundError, match="Empty Authorization"): + with pytest.raises(MCPServerNotFoundError, match="Empty auth header"): _fetch_auth_token("dest-name", "tenant-sub") def test_passes_options_to_destination(self): @@ -124,6 +143,7 @@ def test_passes_options_to_destination(self): mock_dest = MagicMock() mock_dest.auth_tokens = [MagicMock()] mock_dest.auth_tokens[0].http_header = {"value": "Bearer token"} + mock_dest.url = "https://agw.example.com" mock_options = MagicMock() with patch( @@ -242,16 +262,65 @@ def test_raises_when_no_fragment_found(self): # ============================================================ -# Test: get_system_auth +# Test: get_ias_user_fragment_name # ============================================================ -class TestGetSystemAuth: - """Tests for get_system_auth async function.""" +class TestGetIasUserFragmentName: + """Tests for get_ias_user_fragment_name function.""" + + def test_returns_fragment_name(self): + """Return name of first IAS user fragment found.""" + fragment = MagicMock() + fragment.name = "sap-managed-runtime-agw-subscriber-ias-user-abc123" + + with patch("sap_cloud_sdk.agentgateway._lob.create_fragment_client") as mock_client: + mock_client.return_value.list_instance_fragments.return_value = [fragment] + + result = get_ias_user_fragment_name("tenant-sub") + + assert result == "sap-managed-runtime-agw-subscriber-ias-user-abc123" + + def test_uses_correct_filter_labels(self): + """Use correct label filter for IAS user fragments.""" + fragment = MagicMock() + fragment.name = "ias-user-fragment" + + with patch("sap_cloud_sdk.agentgateway._lob.create_fragment_client") as mock_client: + mock_client.return_value.list_instance_fragments.return_value = [fragment] + + get_ias_user_fragment_name("tenant-sub") + + call_args = mock_client.return_value.list_instance_fragments.call_args + filter_opt = call_args.kwargs.get("filter") + assert filter_opt is not None + assert len(filter_opt.filter_labels) == 1 + assert filter_opt.filter_labels[0].key == _LABEL_KEY + assert filter_opt.filter_labels[0].values == [_IAS_USER_LABEL_VALUE] + + def test_raises_when_no_fragment_found(self): + """Raise MCPServerNotFoundError when no IAS user fragment exists.""" + with patch("sap_cloud_sdk.agentgateway._lob.create_fragment_client") as mock_client: + mock_client.return_value.list_instance_fragments.return_value = [] + + with pytest.raises(MCPServerNotFoundError, match="No IAS user fragment found"): + get_ias_user_fragment_name("tenant-sub") + + +# ============================================================ +# Test: fetch_system_auth +# ============================================================ + + +class TestFetchSystemAuth: + """Tests for fetch_system_auth async function.""" @pytest.mark.asyncio async def test_fetches_system_auth(self): - """Fetch system auth using IAS fragment looked up by label.""" + """Fetch system auth using IAS fragment and return tuple (token, url).""" + raw_token = "system-jwt-token-xyz" + gateway_url = "https://agw.example.com" + with patch.dict(os.environ, {"APPFND_CONHOS_LANDSCAPE": "eu10"}): with ( patch( @@ -262,11 +331,11 @@ async def test_fetches_system_auth(self): ) as mock_fetch, ): mock_ias.return_value = "sap-managed-runtime-agw-subscriber-ias-abc" - mock_fetch.return_value = "Bearer system-token" + mock_fetch.return_value = (raw_token, gateway_url) - result = await get_system_auth("tenant-sub") + result = await fetch_system_auth("tenant-sub") - assert result == "Bearer system-token" + assert result == (raw_token, gateway_url) mock_ias.assert_called_once_with("tenant-sub") mock_fetch.assert_called_once() call_args = mock_fetch.call_args @@ -280,32 +349,38 @@ async def test_fetches_system_auth(self): # ============================================================ -# Test: get_user_auth +# Test: fetch_user_auth # ============================================================ -class TestGetUserAuth: - """Tests for get_user_auth async function.""" +class TestFetchUserAuth: + """Tests for fetch_user_auth async function.""" @pytest.mark.asyncio - async def test_fetches_user_auth_with_token_exchange(self): - """Fetch user auth with token exchange.""" + async def test_fetches_user_auth_with_ias_user_fragment(self): + """Fetch user auth using IAS user fragment and user_token, return tuple.""" + raw_token = "exchanged-user-jwt-token" + gateway_url = "https://agw.example.com" + with patch.dict(os.environ, {"APPFND_CONHOS_LANDSCAPE": "eu10"}): - with patch( - "sap_cloud_sdk.agentgateway._lob._fetch_auth_token" - ) as mock_fetch: - mock_fetch.return_value = "Bearer user-token" + with ( + patch("sap_cloud_sdk.agentgateway._lob.get_ias_user_fragment_name") as mock_ias_user, + patch("sap_cloud_sdk.agentgateway._lob._fetch_auth_token") as mock_fetch, + ): + mock_ias_user.return_value = "sap-managed-runtime-agw-subscriber-ias-user-abc" + mock_fetch.return_value = (raw_token, gateway_url) - result = await get_user_auth("mcp-fragment", "user-jwt", "tenant-sub") + result = await fetch_user_auth("user-jwt", "tenant-sub") - assert result == "Bearer user-token" + assert result == (raw_token, gateway_url) + mock_ias_user.assert_called_once_with("tenant-sub") mock_fetch.assert_called_once() call_args = mock_fetch.call_args assert call_args[0][0] == "sap-managed-runtime-ias-eu10" assert call_args[0][1] == "tenant-sub" options = call_args[0][2] assert options.user_token == "user-jwt" - assert options.fragment_name == "mcp-fragment" + assert options.fragment_name == "sap-managed-runtime-agw-subscriber-ias-user-abc" assert options.fragment_level == ConsumptionLevel.INSTANCE @@ -323,7 +398,7 @@ async def test_returns_empty_when_no_fragments(self): with patch("sap_cloud_sdk.agentgateway._lob.list_mcp_fragments") as mock_list: mock_list.return_value = [] - result = await get_mcp_tools_lob("tenant-sub", 60.0) + result = await get_mcp_tools_lob("tenant-sub", "system-token", 60.0) assert result == [] @@ -337,13 +412,13 @@ async def test_skips_fragments_without_url(self): with patch("sap_cloud_sdk.agentgateway._lob.list_mcp_fragments") as mock_list: mock_list.return_value = [fragment] - result = await get_mcp_tools_lob("tenant-sub", 60.0) + result = await get_mcp_tools_lob("tenant-sub", "system-token", 60.0) assert result == [] @pytest.mark.asyncio - async def test_uses_fragment_name_directly(self): - """Use fragment name as-is (no -technical stripping).""" + async def test_uses_pre_fetched_system_token(self): + """Use the pre-fetched system token for MCP server calls.""" fragment = MagicMock() fragment.name = "mcp-server-a" fragment.properties = {"URL": "https://example.com/mcp"} @@ -359,27 +434,20 @@ async def test_uses_fragment_name_directly(self): with ( patch("sap_cloud_sdk.agentgateway._lob.list_mcp_fragments") as mock_list, - patch( - "sap_cloud_sdk.agentgateway._lob.get_system_auth", - new_callable=AsyncMock, - ) as mock_auth, patch( "sap_cloud_sdk.agentgateway._lob.list_server_tools", new_callable=AsyncMock, ) as mock_tools, ): mock_list.return_value = [fragment] - mock_auth.return_value = "Bearer token" mock_tools.return_value = [mock_tool] - await get_mcp_tools_lob("tenant-sub", 60.0) + await get_mcp_tools_lob("tenant-sub", "pre-fetched-token", 60.0) - # Verify get_system_auth called with just tenant_subdomain - mock_auth.assert_called_once_with("tenant-sub") - # Verify list_server_tools called with the unchanged fragment name - mock_tools.assert_called_once() - call_args = mock_tools.call_args[0] - assert call_args[2] == "mcp-server-a" + # Verify list_server_tools called with the pre-fetched token + mock_tools.assert_called_once_with( + "https://example.com/mcp", "pre-fetched-token", "mcp-server-a", 60.0 + ) @pytest.mark.asyncio async def test_handles_exception_for_single_fragment(self): @@ -401,24 +469,25 @@ async def test_handles_exception_for_single_fragment(self): fragment_name="mcp-server2", ) + call_count = 0 + + async def mock_list_tools_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise Exception("Server connection failed") + return [mock_tool] + with ( patch("sap_cloud_sdk.agentgateway._lob.list_mcp_fragments") as mock_list, - patch( - "sap_cloud_sdk.agentgateway._lob.get_system_auth", - new_callable=AsyncMock, - ) as mock_auth, patch( "sap_cloud_sdk.agentgateway._lob.list_server_tools", - new_callable=AsyncMock, - ) as mock_tools, + side_effect=mock_list_tools_fn, + ), ): mock_list.return_value = [fragment1, fragment2] - # First fragment fails, second succeeds - mock_auth.side_effect = [Exception("Auth failed"), "Bearer token"] - mock_tools.return_value = [mock_tool] - - result = await get_mcp_tools_lob("tenant-sub", 60.0) + result = await get_mcp_tools_lob("tenant-sub", "system-token", 60.0) # Should still get tools from second fragment assert len(result) == 1 @@ -434,8 +503,8 @@ class TestCallMcpToolLob: """Tests for call_mcp_tool_lob async function.""" @pytest.mark.asyncio - async def test_calls_tool_with_user_auth(self): - """Call tool using user authentication.""" + async def test_calls_tool_with_pre_fetched_token(self): + """Call tool using pre-fetched user auth token.""" tool = MCPTool( name="test-tool", server_name="test-server", @@ -450,17 +519,12 @@ async def test_calls_tool_with_user_auth(self): mock_result.content[0].text = "Tool result" with ( - patch( - "sap_cloud_sdk.agentgateway._lob.get_user_auth", new_callable=AsyncMock - ) as mock_auth, patch("sap_cloud_sdk.agentgateway._lob.httpx.AsyncClient") as mock_http, patch( "sap_cloud_sdk.agentgateway._lob.streamable_http_client" ) as mock_stream, patch("sap_cloud_sdk.agentgateway._lob.ClientSession") as mock_session, ): - mock_auth.return_value = "Bearer user-token" - # Setup async context managers mock_http_instance = AsyncMock() mock_http.return_value.__aenter__.return_value = mock_http_instance @@ -477,15 +541,19 @@ async def test_calls_tool_with_user_auth(self): mock_session.return_value.__aenter__.return_value = mock_session_instance result = await call_mcp_tool_lob( - tool, "user-jwt", "tenant-sub", 60.0, param1="value1" + tool, "user-auth-token", 60.0, param1="value1" ) assert result == "Tool result" - mock_auth.assert_called_once_with("test-fragment", "user-jwt", "tenant-sub") mock_session_instance.call_tool.assert_called_once_with( "test-tool", {"param1": "value1"} ) + # Verify the Authorization header uses Bearer + raw token + mock_http.assert_called_once() + call_kwargs = mock_http.call_args.kwargs + assert call_kwargs["headers"]["Authorization"] == "Bearer user-auth-token" + @pytest.mark.asyncio async def test_returns_empty_string_when_no_content(self): """Return empty string when tool returns no content.""" @@ -502,17 +570,12 @@ async def test_returns_empty_string_when_no_content(self): mock_result.content = [] with ( - patch( - "sap_cloud_sdk.agentgateway._lob.get_user_auth", new_callable=AsyncMock - ) as mock_auth, patch("sap_cloud_sdk.agentgateway._lob.httpx.AsyncClient") as mock_http, patch( "sap_cloud_sdk.agentgateway._lob.streamable_http_client" ) as mock_stream, patch("sap_cloud_sdk.agentgateway._lob.ClientSession") as mock_session, ): - mock_auth.return_value = "Bearer user-token" - mock_http_instance = AsyncMock() mock_http.return_value.__aenter__.return_value = mock_http_instance @@ -527,6 +590,6 @@ async def test_returns_empty_string_when_no_content(self): mock_session_instance.call_tool = AsyncMock(return_value=mock_result) mock_session.return_value.__aenter__.return_value = mock_session_instance - result = await call_mcp_tool_lob(tool, "user-jwt", "tenant-sub", 60.0) + result = await call_mcp_tool_lob(tool, "user-auth-token", 60.0) assert result == "" diff --git a/tests/core/unit/telemetry/test_operation.py b/tests/core/unit/telemetry/test_operation.py index 1205626..651b3ec 100644 --- a/tests/core/unit/telemetry/test_operation.py +++ b/tests/core/unit/telemetry/test_operation.py @@ -180,5 +180,5 @@ def test_operation_count(self): """Test that we have the expected number of operations.""" all_operations = list(Operation) # 3 auditlog + 11 destination + 10 certificate + 10 fragment + 8 objectstore - # + 2 extensibility + 2 aicore + 23 dms + 2 agentgateway + 13 agent_memory = 84 - assert len(all_operations) == 84 + # + 2 extensibility + 2 aicore + 23 dms + 4 agentgateway + 13 agent_memory = 86 + assert len(all_operations) == 86 diff --git a/uv.lock b/uv.lock index 0e5cebd..1e99357 100644 --- a/uv.lock +++ b/uv.lock @@ -2924,7 +2924,7 @@ wheels = [ [[package]] name = "sap-cloud-sdk" -version = "0.19.3" +version = "0.20.0" source = { editable = "." } dependencies = [ { name = "grpcio" },