|
16 | 16 | elicitation-sep1034-client-defaults - Elicitation with default accept callback |
17 | 17 | auth/client-credentials-jwt - Client credentials with private_key_jwt |
18 | 18 | auth/client-credentials-basic - Client credentials with client_secret_basic |
| 19 | + auth/enterprise-token-exchange - Enterprise auth with OIDC ID token (SEP-990) |
| 20 | + auth/enterprise-saml-exchange - Enterprise auth with SAML assertion (SEP-990) |
| 21 | + auth/enterprise-id-jag-validation - Validate ID-JAG token structure (SEP-990) |
19 | 22 | auth/* - Authorization code flow (default for auth scenarios) |
20 | 23 | """ |
21 | 24 |
|
@@ -293,6 +296,255 @@ async def run_auth_code_client(server_url: str) -> None: |
293 | 296 | await _run_auth_session(server_url, oauth_auth) |
294 | 297 |
|
295 | 298 |
|
| 299 | +@register("auth/enterprise-token-exchange") |
| 300 | +async def run_enterprise_token_exchange(server_url: str) -> None: |
| 301 | + """Enterprise managed auth: Token exchange flow (RFC 8693).""" |
| 302 | + from mcp.client.auth.extensions.enterprise_managed_auth import ( |
| 303 | + EnterpriseAuthOAuthClientProvider, |
| 304 | + TokenExchangeParameters, |
| 305 | + ) |
| 306 | + |
| 307 | + context = get_conformance_context() |
| 308 | + id_token = context.get("id_token") |
| 309 | + idp_token_endpoint = context.get("idp_token_endpoint") |
| 310 | + mcp_server_auth_issuer = context.get("mcp_server_auth_issuer") |
| 311 | + mcp_server_resource_id = context.get("mcp_server_resource_id") |
| 312 | + scope = context.get("scope") |
| 313 | + |
| 314 | + if not id_token: |
| 315 | + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'id_token'") |
| 316 | + if not idp_token_endpoint: |
| 317 | + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'") |
| 318 | + if not mcp_server_auth_issuer: |
| 319 | + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_auth_issuer'") |
| 320 | + if not mcp_server_resource_id: |
| 321 | + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_resource_id'") |
| 322 | + |
| 323 | + # Create token exchange parameters |
| 324 | + token_exchange_params = TokenExchangeParameters.from_id_token( |
| 325 | + id_token=id_token, |
| 326 | + mcp_server_auth_issuer=mcp_server_auth_issuer, |
| 327 | + mcp_server_resource_id=mcp_server_resource_id, |
| 328 | + scope=scope, |
| 329 | + ) |
| 330 | + |
| 331 | + # Create enterprise auth provider |
| 332 | + enterprise_auth = EnterpriseAuthOAuthClientProvider( |
| 333 | + server_url=server_url, |
| 334 | + client_metadata=OAuthClientMetadata( |
| 335 | + client_name="conformance-enterprise-client", |
| 336 | + redirect_uris=[AnyUrl("http://localhost:3000/callback")], |
| 337 | + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], |
| 338 | + response_types=["token"], |
| 339 | + ), |
| 340 | + storage=InMemoryTokenStorage(), |
| 341 | + idp_token_endpoint=idp_token_endpoint, |
| 342 | + token_exchange_params=token_exchange_params, |
| 343 | + ) |
| 344 | + |
| 345 | + # Perform token exchange flow |
| 346 | + async with httpx.AsyncClient() as client: |
| 347 | + # Step 1: Set OAuth metadata manually (since we're not going through full OAuth flow) |
| 348 | + logger.debug(f"Setting OAuth metadata for {server_url}") |
| 349 | + from pydantic import AnyUrl as PydanticAnyUrl |
| 350 | + |
| 351 | + from mcp.shared.auth import OAuthMetadata |
| 352 | + |
| 353 | + # Extract base URL from server_url |
| 354 | + base_url = server_url.replace("/mcp", "") |
| 355 | + token_endpoint_url = f"{base_url}/oauth/token" |
| 356 | + auth_endpoint_url = f"{base_url}/oauth/authorize" |
| 357 | + |
| 358 | + enterprise_auth.context.oauth_metadata = OAuthMetadata( |
| 359 | + issuer=mcp_server_auth_issuer, |
| 360 | + authorization_endpoint=PydanticAnyUrl(auth_endpoint_url), |
| 361 | + token_endpoint=PydanticAnyUrl(token_endpoint_url), |
| 362 | + ) |
| 363 | + logger.debug(f"OAuth metadata set, token_endpoint: {token_endpoint_url}") |
| 364 | + |
| 365 | + # Step 2: Exchange ID token for ID-JAG |
| 366 | + logger.debug("Exchanging ID token for ID-JAG") |
| 367 | + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) |
| 368 | + logger.debug(f"Obtained ID-JAG: {id_jag[:50]}...") |
| 369 | + |
| 370 | + # Step 3: Exchange ID-JAG for access token |
| 371 | + logger.debug("Exchanging ID-JAG for access token") |
| 372 | + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) |
| 373 | + logger.debug(f"Obtained access token, expires in: {access_token.expires_in}s") |
| 374 | + |
| 375 | + # Step 4: Verify we can make authenticated requests |
| 376 | + logger.debug("Verifying access token with MCP endpoint") |
| 377 | + auth_client = httpx.AsyncClient(headers={"Authorization": f"Bearer {access_token.access_token}"}) |
| 378 | + response = await auth_client.get(server_url.replace("/mcp", "") + "/mcp") |
| 379 | + if response.status_code == 200: |
| 380 | + logger.debug(f"Successfully authenticated with MCP server: {response.json()}") |
| 381 | + else: |
| 382 | + logger.warning(f"MCP server returned {response.status_code}") |
| 383 | + |
| 384 | + logger.debug("Enterprise auth flow completed successfully") |
| 385 | + |
| 386 | + |
| 387 | +@register("auth/enterprise-saml-exchange") |
| 388 | +async def run_enterprise_saml_exchange(server_url: str) -> None: |
| 389 | + """Enterprise managed auth: SAML assertion exchange flow.""" |
| 390 | + from mcp.client.auth.extensions.enterprise_managed_auth import ( |
| 391 | + EnterpriseAuthOAuthClientProvider, |
| 392 | + TokenExchangeParameters, |
| 393 | + ) |
| 394 | + |
| 395 | + context = get_conformance_context() |
| 396 | + saml_assertion = context.get("saml_assertion") |
| 397 | + idp_token_endpoint = context.get("idp_token_endpoint") |
| 398 | + mcp_server_auth_issuer = context.get("mcp_server_auth_issuer") |
| 399 | + mcp_server_resource_id = context.get("mcp_server_resource_id") |
| 400 | + scope = context.get("scope") |
| 401 | + |
| 402 | + if not saml_assertion: |
| 403 | + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'saml_assertion'") |
| 404 | + if not idp_token_endpoint: |
| 405 | + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'") |
| 406 | + if not mcp_server_auth_issuer: |
| 407 | + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_auth_issuer'") |
| 408 | + if not mcp_server_resource_id: |
| 409 | + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_resource_id'") |
| 410 | + |
| 411 | + # Create token exchange parameters for SAML |
| 412 | + token_exchange_params = TokenExchangeParameters.from_saml_assertion( |
| 413 | + saml_assertion=saml_assertion, |
| 414 | + mcp_server_auth_issuer=mcp_server_auth_issuer, |
| 415 | + mcp_server_resource_id=mcp_server_resource_id, |
| 416 | + scope=scope, |
| 417 | + ) |
| 418 | + |
| 419 | + # Create enterprise auth provider |
| 420 | + enterprise_auth = EnterpriseAuthOAuthClientProvider( |
| 421 | + server_url=server_url, |
| 422 | + client_metadata=OAuthClientMetadata( |
| 423 | + client_name="conformance-enterprise-saml-client", |
| 424 | + redirect_uris=[AnyUrl("http://localhost:3000/callback")], |
| 425 | + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], |
| 426 | + response_types=["token"], |
| 427 | + ), |
| 428 | + storage=InMemoryTokenStorage(), |
| 429 | + idp_token_endpoint=idp_token_endpoint, |
| 430 | + token_exchange_params=token_exchange_params, |
| 431 | + ) |
| 432 | + |
| 433 | + # Perform token exchange flow |
| 434 | + async with httpx.AsyncClient() as client: |
| 435 | + # Step 1: Set OAuth metadata manually (since we're not going through full OAuth flow) |
| 436 | + logger.debug(f"Setting OAuth metadata for {server_url}") |
| 437 | + from pydantic import AnyUrl as PydanticAnyUrl |
| 438 | + |
| 439 | + from mcp.shared.auth import OAuthMetadata |
| 440 | + |
| 441 | + # Extract base URL from server_url |
| 442 | + base_url = server_url.replace("/mcp", "") |
| 443 | + token_endpoint_url = f"{base_url}/oauth/token" |
| 444 | + auth_endpoint_url = f"{base_url}/oauth/authorize" |
| 445 | + |
| 446 | + enterprise_auth.context.oauth_metadata = OAuthMetadata( |
| 447 | + issuer=mcp_server_auth_issuer, |
| 448 | + authorization_endpoint=PydanticAnyUrl(auth_endpoint_url), |
| 449 | + token_endpoint=PydanticAnyUrl(token_endpoint_url), |
| 450 | + ) |
| 451 | + logger.debug(f"OAuth metadata set, token_endpoint: {token_endpoint_url}") |
| 452 | + |
| 453 | + # Step 2: Exchange SAML assertion for ID-JAG |
| 454 | + logger.debug("Exchanging SAML assertion for ID-JAG") |
| 455 | + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) |
| 456 | + logger.debug(f"Obtained ID-JAG from SAML: {id_jag[:50]}...") |
| 457 | + |
| 458 | + # Step 3: Exchange ID-JAG for access token |
| 459 | + logger.debug("Exchanging ID-JAG for access token") |
| 460 | + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) |
| 461 | + logger.debug(f"Obtained access token, expires in: {access_token.expires_in}s") |
| 462 | + |
| 463 | + # Step 4: Verify we can make authenticated requests |
| 464 | + logger.debug("Verifying access token with MCP endpoint") |
| 465 | + auth_client = httpx.AsyncClient(headers={"Authorization": f"Bearer {access_token.access_token}"}) |
| 466 | + response = await auth_client.get(server_url.replace("/mcp", "") + "/mcp") |
| 467 | + if response.status_code == 200: |
| 468 | + logger.debug(f"Successfully authenticated with MCP server: {response.json()}") |
| 469 | + else: |
| 470 | + logger.warning(f"MCP server returned {response.status_code}") |
| 471 | + |
| 472 | + logger.debug("SAML enterprise auth flow completed successfully") |
| 473 | + |
| 474 | + |
| 475 | +@register("auth/enterprise-id-jag-validation") |
| 476 | +async def run_id_jag_validation(server_url: str) -> None: |
| 477 | + """Validate ID-JAG token structure and claims.""" |
| 478 | + from mcp.client.auth.extensions.enterprise_managed_auth import ( |
| 479 | + EnterpriseAuthOAuthClientProvider, |
| 480 | + TokenExchangeParameters, |
| 481 | + decode_id_jag, |
| 482 | + validate_token_exchange_params, |
| 483 | + ) |
| 484 | + |
| 485 | + context = get_conformance_context() |
| 486 | + id_token = context.get("id_token") |
| 487 | + idp_token_endpoint = context.get("idp_token_endpoint") |
| 488 | + mcp_server_auth_issuer = context.get("mcp_server_auth_issuer") |
| 489 | + mcp_server_resource_id = context.get("mcp_server_resource_id") |
| 490 | + |
| 491 | + if not all([id_token, idp_token_endpoint, mcp_server_auth_issuer, mcp_server_resource_id]): |
| 492 | + raise RuntimeError("Missing required context parameters for ID-JAG validation") |
| 493 | + |
| 494 | + # Create and validate token exchange parameters |
| 495 | + token_exchange_params = TokenExchangeParameters.from_id_token( |
| 496 | + id_token=id_token, |
| 497 | + mcp_server_auth_issuer=mcp_server_auth_issuer, |
| 498 | + mcp_server_resource_id=mcp_server_resource_id, |
| 499 | + ) |
| 500 | + |
| 501 | + logger.debug("Validating token exchange parameters") |
| 502 | + validate_token_exchange_params(token_exchange_params) |
| 503 | + logger.debug("Token exchange parameters validated successfully") |
| 504 | + |
| 505 | + # Create enterprise auth provider |
| 506 | + enterprise_auth = EnterpriseAuthOAuthClientProvider( |
| 507 | + server_url=server_url, |
| 508 | + client_metadata=OAuthClientMetadata( |
| 509 | + client_name="conformance-validation-client", |
| 510 | + redirect_uris=[AnyUrl("http://localhost:3000/callback")], |
| 511 | + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], |
| 512 | + response_types=["token"], |
| 513 | + ), |
| 514 | + storage=InMemoryTokenStorage(), |
| 515 | + idp_token_endpoint=idp_token_endpoint, |
| 516 | + token_exchange_params=token_exchange_params, |
| 517 | + ) |
| 518 | + |
| 519 | + async with httpx.AsyncClient() as client: |
| 520 | + # Get ID-JAG |
| 521 | + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) |
| 522 | + logger.debug(f"Obtained ID-JAG for validation: {id_jag[:50]}...") |
| 523 | + |
| 524 | + # Decode and validate ID-JAG claims |
| 525 | + logger.debug("Decoding ID-JAG token") |
| 526 | + claims = decode_id_jag(id_jag) |
| 527 | + |
| 528 | + # Validate required claims |
| 529 | + assert claims.typ == "oauth-id-jag+jwt", f"Invalid typ: {claims.typ}" |
| 530 | + assert claims.jti, "Missing jti claim" |
| 531 | + assert claims.iss == mcp_server_auth_issuer or claims.iss, "Missing or invalid iss claim" |
| 532 | + assert claims.sub, "Missing sub claim" |
| 533 | + assert claims.aud, "Missing aud claim" |
| 534 | + assert claims.resource == mcp_server_resource_id, f"Invalid resource: {claims.resource}" |
| 535 | + assert claims.client_id, "Missing client_id claim" |
| 536 | + assert claims.exp > claims.iat, "Invalid expiration" |
| 537 | + |
| 538 | + logger.debug("ID-JAG validated successfully:") |
| 539 | + logger.debug(f" Subject: {claims.sub}") |
| 540 | + logger.debug(f" Issuer: {claims.iss}") |
| 541 | + logger.debug(f" Audience: {claims.aud}") |
| 542 | + logger.debug(f" Resource: {claims.resource}") |
| 543 | + logger.debug(f" Client ID: {claims.client_id}") |
| 544 | + |
| 545 | + logger.debug("ID-JAG validation completed successfully") |
| 546 | + |
| 547 | + |
296 | 548 | async def _run_auth_session(server_url: str, oauth_auth: OAuthClientProvider) -> None: |
297 | 549 | """Common session logic for all OAuth flows.""" |
298 | 550 | client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0) |
|
0 commit comments