From 9aea20fb9c50f7215513aec7bf0ba32c127041c5 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 19 Mar 2026 16:01:11 +0000 Subject: [PATCH 1/5] feat: add TokenProvider for composable bearer-token auth (non-breaking) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a minimal `() => Promise` function type as a lightweight alternative to OAuthClientProvider, for scenarios where bearer tokens are managed externally (gateway/proxy patterns, service accounts, API keys). - New TokenProvider type + withBearerAuth(getToken, fetchFn?) helper - New tokenProvider option on StreamableHTTPClientTransport and SSEClientTransport, used as fallback after authProvider in _commonHeaders(). authProvider takes precedence when both set. - On 401 with tokenProvider (no authProvider), transports throw UnauthorizedError — no retry, since tokenProvider() is already called before every request and would likely return the same rejected token. Callers catch UnauthorizedError, invalidate external cache, reconnect. - Exported previously-internal auth helpers for building custom flows: applyBasicAuth, applyPostAuth, applyPublicAuth, executeTokenRequest. - Tests, example, docs, changeset. Zero breakage. Bughunter fleet review: 28 findings submitted, 2 confirmed, both addressed. --- .changeset/token-provider-composable-auth.md | 10 + docs/client.md | 16 +- examples/client/src/clientGuide.examples.ts | 13 +- examples/client/src/simpleTokenProvider.ts | 69 ++++++ packages/client/src/client/auth.ts | 8 +- packages/client/src/client/sse.ts | 67 ++++-- packages/client/src/client/streamableHttp.ts | 85 ++++--- packages/client/src/client/tokenProvider.ts | 53 +++++ packages/client/src/index.ts | 1 + .../client/test/client/tokenProvider.test.ts | 208 ++++++++++++++++++ 10 files changed, 478 insertions(+), 52 deletions(-) create mode 100644 .changeset/token-provider-composable-auth.md create mode 100644 examples/client/src/simpleTokenProvider.ts create mode 100644 packages/client/src/client/tokenProvider.ts create mode 100644 packages/client/test/client/tokenProvider.test.ts diff --git a/.changeset/token-provider-composable-auth.md b/.changeset/token-provider-composable-auth.md new file mode 100644 index 000000000..50b296298 --- /dev/null +++ b/.changeset/token-provider-composable-auth.md @@ -0,0 +1,10 @@ +--- +'@modelcontextprotocol/client': minor +--- + +Add `TokenProvider` for simple bearer-token authentication and export composable auth primitives + +- New `TokenProvider` type — a minimal `() => Promise` function interface for supplying bearer tokens. Use this instead of `OAuthClientProvider` when tokens are managed externally (gateway/proxy patterns, service accounts, upfront API tokens, or any scenario where the full OAuth redirect flow is not needed). +- New `tokenProvider` option on `StreamableHTTPClientTransport` and `SSEClientTransport`. Called before every request to obtain a fresh token. If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. +- New `withBearerAuth(getToken, fetchFn?)` helper that wraps a fetch function to inject `Authorization: Bearer` headers — useful for composing with other fetch middleware. +- Exported previously-internal auth helpers for building custom auth flows: `applyBasicAuth`, `applyPostAuth`, `applyPublicAuth`, `executeTokenRequest`. diff --git a/docs/client.md b/docs/client.md index 782ab885b..467df2789 100644 --- a/docs/client.md +++ b/docs/client.md @@ -13,7 +13,7 @@ A client connects to a server, discovers what it offers — tools, resources, pr The examples below use these imports. Adjust based on which features and transport you need: ```ts source="../examples/client/src/clientGuide.examples.ts#imports" -import type { Prompt, Resource, Tool } from '@modelcontextprotocol/client'; +import type { Prompt, Resource, TokenProvider, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -113,7 +113,19 @@ console.log(systemPrompt); ## Authentication -MCP servers can require OAuth 2.0 authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). Pass an `authProvider` to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport} to enable this — the SDK provides built-in providers for common machine-to-machine flows, or you can implement the full {@linkcode @modelcontextprotocol/client!client/auth.OAuthClientProvider | OAuthClientProvider} interface for user-facing OAuth. +MCP servers can require authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). For servers that accept plain bearer tokens, pass a `tokenProvider` function to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}. For servers that require OAuth 2.0, pass an `authProvider` — the SDK provides built-in providers for common machine-to-machine flows, or you can implement the full {@linkcode @modelcontextprotocol/client!client/auth.OAuthClientProvider | OAuthClientProvider} interface for user-facing OAuth. + +### Token provider + +For servers that accept bearer tokens managed outside the SDK — API keys, tokens from a gateway or proxy, service-account credentials, or tokens obtained through a separate auth flow — pass a {@linkcode @modelcontextprotocol/client!client/tokenProvider.TokenProvider | TokenProvider} function. It is called before every request, so it can handle expiry and refresh internally. If the server rejects the token with 401, the transport throws {@linkcode @modelcontextprotocol/client!client/auth.UnauthorizedError | UnauthorizedError} without retrying — catch it to invalidate any external cache and reconnect: + +```ts source="../examples/client/src/clientGuide.examples.ts#auth_tokenProvider" +const tokenProvider: TokenProvider = async () => getStoredToken(); + +const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { tokenProvider }); +``` + +See [`simpleTokenProvider.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleTokenProvider.ts) for a complete runnable example. For finer control, {@linkcode @modelcontextprotocol/client!client/tokenProvider.withBearerAuth | withBearerAuth} wraps a fetch function directly. ### Client credentials diff --git a/examples/client/src/clientGuide.examples.ts b/examples/client/src/clientGuide.examples.ts index 389059024..c34a3a574 100644 --- a/examples/client/src/clientGuide.examples.ts +++ b/examples/client/src/clientGuide.examples.ts @@ -8,7 +8,7 @@ */ //#region imports -import type { Prompt, Resource, Tool } from '@modelcontextprotocol/client'; +import type { Prompt, Resource, TokenProvider, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -107,6 +107,16 @@ async function serverInstructions_basic(client: Client) { // Authentication // --------------------------------------------------------------------------- +/** Example: TokenProvider for bearer auth with externally-managed tokens. */ +async function auth_tokenProvider(getStoredToken: () => Promise) { + //#region auth_tokenProvider + const tokenProvider: TokenProvider = async () => getStoredToken(); + + const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { tokenProvider }); + //#endregion auth_tokenProvider + return transport; +} + /** Example: Client credentials auth for service-to-service communication. */ async function auth_clientCredentials() { //#region auth_clientCredentials @@ -540,6 +550,7 @@ void connect_stdio; void connect_sseFallback; void disconnect_streamableHttp; void serverInstructions_basic; +void auth_tokenProvider; void auth_clientCredentials; void auth_privateKeyJwt; void auth_crossAppAccess; diff --git a/examples/client/src/simpleTokenProvider.ts b/examples/client/src/simpleTokenProvider.ts new file mode 100644 index 000000000..f6829f556 --- /dev/null +++ b/examples/client/src/simpleTokenProvider.ts @@ -0,0 +1,69 @@ +#!/usr/bin/env node + +/** + * Example demonstrating TokenProvider for simple bearer token authentication. + * + * TokenProvider is a lightweight alternative to OAuthClientProvider for cases + * where tokens are managed externally — e.g., pre-configured API tokens, + * gateway/proxy patterns, or tokens obtained through a separate auth flow. + * + * Environment variables: + * MCP_SERVER_URL - Server URL (default: http://localhost:3000/mcp) + * MCP_TOKEN - Bearer token to use for authentication (required) + * + * Two approaches are demonstrated: + * 1. Using `tokenProvider` option on the transport (simplest) + * 2. Using `withBearerAuth` to wrap a custom fetch function (more flexible) + */ + +import type { TokenProvider } from '@modelcontextprotocol/client'; +import { Client, StreamableHTTPClientTransport, withBearerAuth } from '@modelcontextprotocol/client'; + +const DEFAULT_SERVER_URL = process.env.MCP_SERVER_URL || 'http://localhost:3000/mcp'; + +async function main() { + const token = process.env.MCP_TOKEN; + if (!token) { + console.error('MCP_TOKEN environment variable is required'); + process.exit(1); + } + + // A TokenProvider is just an async function that returns a token string. + // It is called before every request, so it can handle refresh logic internally. + const tokenProvider: TokenProvider = async () => token; + + const client = new Client({ name: 'token-provider-example', version: '1.0.0' }, { capabilities: {} }); + + // Approach 1: Pass tokenProvider directly to the transport. + // This is the simplest way to add bearer auth. + const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { + tokenProvider + }); + + // Approach 2 (alternative): Use withBearerAuth to wrap fetch. + // This is useful when you need more control over the fetch behavior, + // or when composing with other fetch wrappers. + // + // const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { + // fetch: withBearerAuth(tokenProvider), + // }); + + await client.connect(transport); + console.log('Connected successfully.'); + + const tools = await client.listTools(); + console.log('Available tools:', tools.tools.map(t => t.name).join(', ') || '(none)'); + + await transport.close(); +} + +try { + await main(); +} catch (error) { + console.error('Error running client:', error); + // eslint-disable-next-line unicorn/no-process-exit + process.exit(1); +} + +// Referenced in the commented-out Approach 2 above; kept so uncommenting it type-checks. +void withBearerAuth; diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index 58ec23ddd..c47a57f27 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -381,7 +381,7 @@ export function applyClientAuthentication( /** * Applies HTTP Basic authentication (RFC 6749 Section 2.3.1) */ -function applyBasicAuth(clientId: string, clientSecret: string | undefined, headers: Headers): void { +export function applyBasicAuth(clientId: string, clientSecret: string | undefined, headers: Headers): void { if (!clientSecret) { throw new Error('client_secret_basic authentication requires a client_secret'); } @@ -393,7 +393,7 @@ function applyBasicAuth(clientId: string, clientSecret: string | undefined, head /** * Applies POST body authentication (RFC 6749 Section 2.3.1) */ -function applyPostAuth(clientId: string, clientSecret: string | undefined, params: URLSearchParams): void { +export function applyPostAuth(clientId: string, clientSecret: string | undefined, params: URLSearchParams): void { params.set('client_id', clientId); if (clientSecret) { params.set('client_secret', clientSecret); @@ -403,7 +403,7 @@ function applyPostAuth(clientId: string, clientSecret: string | undefined, param /** * Applies public client authentication (RFC 6749 Section 2.1) */ -function applyPublicAuth(clientId: string, params: URLSearchParams): void { +export function applyPublicAuth(clientId: string, params: URLSearchParams): void { params.set('client_id', clientId); } @@ -1265,7 +1265,7 @@ export function prepareAuthorizationCodeRequest( * Internal helper to execute a token request with the given parameters. * Used by {@linkcode exchangeAuthorization}, {@linkcode refreshAuthorization}, and {@linkcode fetchToken}. */ -async function executeTokenRequest( +export async function executeTokenRequest( authorizationServerUrl: string | URL, { metadata, diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index 133aa0004..e5b04a258 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -5,6 +5,7 @@ import { EventSource } from 'eventsource'; import type { AuthResult, OAuthClientProvider } from './auth.js'; import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; +import type { TokenProvider } from './tokenProvider.js'; export class SseError extends Error { constructor( @@ -36,6 +37,16 @@ export type SSEClientTransportOptions = { */ authProvider?: OAuthClientProvider; + /** + * A simple token provider for bearer authentication. + * + * Use this instead of `authProvider` when tokens are managed externally + * (e.g., upfront auth, gateway/proxy patterns, service accounts). + * + * If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. + */ + tokenProvider?: TokenProvider; + /** * Customizes the initial SSE request to the server (the request that begins the stream). * @@ -72,6 +83,7 @@ export class SSEClientTransport implements Transport { private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; + private _tokenProvider?: TokenProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _protocolVersion?: string; @@ -87,6 +99,7 @@ export class SSEClientTransport implements Transport { this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; + this._tokenProvider = opts?.tokenProvider; this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); } @@ -123,6 +136,11 @@ export class SSEClientTransport implements Transport { if (tokens) { headers['Authorization'] = `Bearer ${tokens.access_token}`; } + } else if (this._tokenProvider) { + const token = await this._tokenProvider(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; + } } if (this._protocolVersion) { headers['mcp-protocol-version'] = this._protocolVersion; @@ -161,9 +179,17 @@ export class SSEClientTransport implements Transport { this._abortController = new AbortController(); this._eventSource.onerror = event => { - if (event.code === 401 && this._authProvider) { - this._authThenStart().then(resolve, reject); - return; + if (event.code === 401) { + if (this._authProvider) { + this._authThenStart().then(resolve, reject); + return; + } + if (this._tokenProvider) { + const error = new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + reject(error); + this.onerror?.(error); + return; + } } const error = new SseError(event.code, event.message, event); @@ -263,23 +289,28 @@ export class SSEClientTransport implements Transport { if (!response.ok) { const text = await response.text?.().catch(() => null); - if (response.status === 401 && this._authProvider) { - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; + if (response.status === 401) { + if (this._authProvider) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; - const result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); + const result = await auth(this._authProvider, { + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl, + scope: this._scope, + fetchFn: this._fetchWithInit + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } + + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } + if (this._tokenProvider) { + throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); } - - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); } throw new Error(`Error POSTing to endpoint (HTTP ${response.status}): ${text}`); diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index dab9b37ab..180893472 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -15,6 +15,7 @@ import { EventSourceParserStream } from 'eventsource-parser/stream'; import type { AuthResult, OAuthClientProvider } from './auth.js'; import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; +import type { TokenProvider } from './tokenProvider.js'; // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { @@ -98,6 +99,16 @@ export type StreamableHTTPClientTransportOptions = { */ authProvider?: OAuthClientProvider; + /** + * A simple token provider for bearer authentication. + * + * Use this instead of `authProvider` when tokens are managed externally + * (e.g., upfront auth, gateway/proxy patterns, service accounts). + * + * If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. + */ + tokenProvider?: TokenProvider; + /** * Customizes HTTP requests to the server. */ @@ -132,6 +143,7 @@ export class StreamableHTTPClientTransport implements Transport { private _scope?: string; private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; + private _tokenProvider?: TokenProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _sessionId?: string; @@ -152,6 +164,7 @@ export class StreamableHTTPClientTransport implements Transport { this._scope = undefined; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; + this._tokenProvider = opts?.tokenProvider; this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); this._sessionId = opts?.sessionId; @@ -190,6 +203,11 @@ export class StreamableHTTPClientTransport implements Transport { if (tokens) { headers['Authorization'] = `Bearer ${tokens.access_token}`; } + } else if (this._tokenProvider) { + const token = await this._tokenProvider(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; + } } if (this._sessionId) { @@ -231,9 +249,13 @@ export class StreamableHTTPClientTransport implements Transport { if (!response.ok) { await response.text?.().catch(() => {}); - if (response.status === 401 && this._authProvider) { - // Need to authenticate - return await this._authThenStart(); + if (response.status === 401) { + if (this._authProvider) { + return await this._authThenStart(); + } + if (this._tokenProvider) { + throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + } } // 405 indicates that the server does not offer an SSE stream at GET endpoint @@ -494,33 +516,42 @@ export class StreamableHTTPClientTransport implements Transport { if (!response.ok) { const text = await response.text?.().catch(() => null); - if (response.status === 401 && this._authProvider) { - // Prevent infinite recursion when server returns 401 after successful auth - if (this._hasCompletedAuthFlow) { - throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after successful authentication', { - status: 401, - text - }); - } + if (response.status === 401) { + if (this._authProvider) { + // Prevent infinite recursion when server returns 401 after successful auth + if (this._hasCompletedAuthFlow) { + throw new SdkError( + SdkErrorCode.ClientHttpAuthentication, + 'Server returned 401 after successful authentication', + { + status: 401, + text + } + ); + } - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; - const result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } + const result = await auth(this._authProvider, { + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl, + scope: this._scope, + fetchFn: this._fetchWithInit + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } - // Mark that we completed auth flow - this._hasCompletedAuthFlow = true; - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); + // Mark that we completed auth flow + this._hasCompletedAuthFlow = true; + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } + if (this._tokenProvider) { + throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + } } if (response.status === 403 && this._authProvider) { diff --git a/packages/client/src/client/tokenProvider.ts b/packages/client/src/client/tokenProvider.ts new file mode 100644 index 000000000..ab8f2bbc9 --- /dev/null +++ b/packages/client/src/client/tokenProvider.ts @@ -0,0 +1,53 @@ +/** + * Minimal interface for providing bearer tokens to MCP transports. + * + * Unlike `OAuthClientProvider` which assumes interactive browser-redirect OAuth, + * `TokenProvider` is a simple function that returns a token string. + * Use this for upfront auth, gateway/proxy patterns, service accounts, + * or any scenario where tokens are managed externally. + * + * The provider is called before every request. If the server responds with 401, + * the transport throws `UnauthorizedError` without retrying — the provider is + * assumed to have already returned its freshest token. Catch `UnauthorizedError` + * to invalidate any external cache and reconnect. + * + * @example + * ```typescript + * // Static token + * const provider: TokenProvider = async () => "my-api-token"; + * + * // Token from secure storage with refresh + * const provider: TokenProvider = async () => { + * const token = await storage.getToken(); + * if (isExpiringSoon(token)) { + * return (await refreshToken(token)).accessToken; + * } + * return token.accessToken; + * }; + * ``` + */ +export type TokenProvider = () => Promise; + +/** + * Wraps a fetch function to automatically inject Bearer authentication headers. + * + * @example + * ```typescript + * const authedFetch = withBearerAuth(async () => getStoredToken()); + * const transport = new StreamableHTTPClientTransport(url, { fetch: authedFetch }); + * ``` + */ +export function withBearerAuth( + getToken: TokenProvider, + fetchFn: (url: string | URL, init?: RequestInit) => Promise = globalThis.fetch +): (url: string | URL, init?: RequestInit) => Promise { + return async (url, init) => { + const token = await getToken(); + if (token) { + const headers = new Headers(init?.headers); + headers.set('Authorization', `Bearer ${token}`); + return fetchFn(url, { ...init, headers }); + } + return fetchFn(url, init); + }; +} diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index c37d9fe28..b72b3e2d9 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -6,6 +6,7 @@ export * from './client/middleware.js'; export * from './client/sse.js'; export * from './client/stdio.js'; export * from './client/streamableHttp.js'; +export * from './client/tokenProvider.js'; export * from './client/websocket.js'; // experimental exports diff --git a/packages/client/test/client/tokenProvider.test.ts b/packages/client/test/client/tokenProvider.test.ts new file mode 100644 index 000000000..111a7b6a5 --- /dev/null +++ b/packages/client/test/client/tokenProvider.test.ts @@ -0,0 +1,208 @@ +import type { JSONRPCMessage } from '@modelcontextprotocol/core'; +import type { Mock } from 'vitest'; + +import type { TokenProvider } from '../../src/client/tokenProvider.js'; +import { withBearerAuth } from '../../src/client/tokenProvider.js'; +import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; +import { UnauthorizedError } from '../../src/client/auth.js'; + +describe('withBearerAuth', () => { + it('should inject Authorization header when token is available', async () => { + const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); + const getToken: TokenProvider = async () => 'test-token-123'; + + const authedFetch = withBearerAuth(getToken, mockFetch); + await authedFetch('https://example.com/api', { method: 'POST' }); + + expect(mockFetch).toHaveBeenCalledOnce(); + const [url, init] = mockFetch.mock.calls[0]!; + expect(url).toBe('https://example.com/api'); + expect(new Headers(init.headers).get('Authorization')).toBe('Bearer test-token-123'); + }); + + it('should not inject Authorization header when token is undefined', async () => { + const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); + const getToken: TokenProvider = async () => undefined; + + const authedFetch = withBearerAuth(getToken, mockFetch); + await authedFetch('https://example.com/api', { method: 'POST' }); + + expect(mockFetch).toHaveBeenCalledOnce(); + const [, init] = mockFetch.mock.calls[0]!; + expect(new Headers(init?.headers).has('Authorization')).toBe(false); + }); + + it('should preserve existing headers', async () => { + const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); + const getToken: TokenProvider = async () => 'my-token'; + + const authedFetch = withBearerAuth(getToken, mockFetch); + await authedFetch('https://example.com/api', { + headers: { 'Content-Type': 'application/json', 'X-Custom': 'value' } + }); + + const [, init] = mockFetch.mock.calls[0]!; + const headers = new Headers(init.headers); + expect(headers.get('Authorization')).toBe('Bearer my-token'); + expect(headers.get('Content-Type')).toBe('application/json'); + expect(headers.get('X-Custom')).toBe('value'); + }); + + it('should call getToken on every request', async () => { + const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); + let callCount = 0; + const getToken: TokenProvider = async () => `token-${++callCount}`; + + const authedFetch = withBearerAuth(getToken, mockFetch); + await authedFetch('https://example.com/1'); + await authedFetch('https://example.com/2'); + + expect(new Headers(mockFetch.mock.calls[0]![1]!.headers).get('Authorization')).toBe('Bearer token-1'); + expect(new Headers(mockFetch.mock.calls[1]![1]!.headers).get('Authorization')).toBe('Bearer token-2'); + }); +}); + +describe('StreamableHTTPClientTransport with tokenProvider', () => { + let transport: StreamableHTTPClientTransport; + + afterEach(async () => { + await transport?.close().catch(() => {}); + vi.clearAllMocks(); + }); + + it('should set Authorization header from tokenProvider', async () => { + const tokenProvider: TokenProvider = vi.fn(async () => 'my-bearer-token'); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + vi.spyOn(globalThis, 'fetch'); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); + + await transport.send(message); + + expect(tokenProvider).toHaveBeenCalled(); + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.get('Authorization')).toBe('Bearer my-bearer-token'); + }); + + it('should not set Authorization header when tokenProvider returns undefined', async () => { + const tokenProvider: TokenProvider = vi.fn(async () => undefined); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + vi.spyOn(globalThis, 'fetch'); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); + + await transport.send(message); + + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.has('Authorization')).toBe(false); + }); + + it('should throw UnauthorizedError on 401 when using tokenProvider', async () => { + const tokenProvider: TokenProvider = vi.fn(async () => 'rejected-token'); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + vi.spyOn(globalThis, 'fetch'); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 401, + headers: new Headers(), + text: async () => 'unauthorized' + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(tokenProvider).toHaveBeenCalledTimes(1); + }); + + it('should prefer authProvider over tokenProvider when both are set', async () => { + const tokenProvider: TokenProvider = vi.fn(async () => 'token-provider-value'); + const authProvider = { + get redirectUrl() { + return 'http://localhost/callback'; + }, + get clientMetadata() { + return { redirect_uris: ['http://localhost/callback'] }; + }, + clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-secret' })), + tokens: vi.fn(() => ({ access_token: 'auth-provider-value', token_type: 'bearer' })), + saveTokens: vi.fn(), + redirectToAuthorization: vi.fn(), + saveCodeVerifier: vi.fn(), + codeVerifier: vi.fn() + }; + + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider, tokenProvider }); + vi.spyOn(globalThis, 'fetch'); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); + + await transport.send(message); + + // authProvider should be used, not tokenProvider + expect(tokenProvider).not.toHaveBeenCalled(); + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.get('Authorization')).toBe('Bearer auth-provider-value'); + }); + + it('should work with no auth at all', async () => { + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + vi.spyOn(globalThis, 'fetch'); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); + + await transport.send(message); + + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.has('Authorization')).toBe(false); + }); +}); From 2961101713146be8919d62058dc5d1507d69ef47 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 19 Mar 2026 16:01:11 +0000 Subject: [PATCH 2/5] BREAKING: unify client auth around minimal AuthProvider interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Transports now accept AuthProvider { token(), onUnauthorized() } instead of being typed as OAuthClientProvider. OAuthClientProvider extends AuthProvider, so built-in providers work unchanged — custom implementations add two methods (both TypeScript-enforced). Core changes: - New AuthProvider interface — transports only need token() + onUnauthorized(), not the full 21-member OAuth interface - OAuthClientProvider extends AuthProvider; onUnauthorized() is required (not optional) on OAuthClientProvider since OAuth providers that omit it lose all 401 recovery. The 4 built-in providers implement both methods, delegating to new handleOAuthUnauthorized helper. - Transports call authProvider.token() in _commonHeaders() — one code path, no precedence rules - Transports call authProvider.onUnauthorized() on 401, retry once — ~50 lines of inline OAuth orchestration removed per transport. Circuit breaker via _authRetryInFlight (reset in outer catch so transient onUnauthorized failures don't permanently disable retries). - Response body consumption deferred until after the onUnauthorized branch so custom implementations can read ctx.response.text() - WWW-Authenticate extraction guarded with headers.has() check (pre-existing inconsistency; the SSE connect path already did this) - finishAuth() and 403 upscoping gated on isOAuthClientProvider() - TokenProvider type + tokenProvider option deleted — subsumed by { token: async () => ... } as authProvider Simple case: { authProvider: { token: async () => apiKey } } — no class needed, TypeScript structural typing. auth() and authInternal() (227 LOC of OAuth orchestration) untouched. They still take OAuthClientProvider. Only the transport/provider boundary moved. See docs/migration.md and docs/migration-SKILL.md for before/after. --- .changeset/token-provider-composable-auth.md | 19 +- docs/client.md | 14 +- docs/migration-SKILL.md | 132 +++++---- docs/migration.md | 155 +++++++---- examples/client/src/clientGuide.examples.ts | 8 +- .../client/src/simpleOAuthClientProvider.ts | 17 +- examples/client/src/simpleTokenProvider.ts | 47 ++-- packages/client/src/client/auth.examples.ts | 10 +- packages/client/src/client/auth.ts | 91 +++++- packages/client/src/client/authExtensions.ts | 35 ++- packages/client/src/client/sse.ts | 141 ++++------ packages/client/src/client/streamableHttp.ts | 155 ++++------- packages/client/src/client/tokenProvider.ts | 53 ---- packages/client/src/index.ts | 1 - packages/client/test/client/auth.test.ts | 10 + .../client/test/client/middleware.test.ts | 4 + packages/client/test/client/sse.test.ts | 124 ++++++++- .../client/test/client/streamableHttp.test.ts | 10 +- .../client/test/client/tokenProvider.test.ts | 263 ++++++++---------- 19 files changed, 736 insertions(+), 553 deletions(-) delete mode 100644 packages/client/src/client/tokenProvider.ts diff --git a/.changeset/token-provider-composable-auth.md b/.changeset/token-provider-composable-auth.md index 50b296298..c4ea7f5e3 100644 --- a/.changeset/token-provider-composable-auth.md +++ b/.changeset/token-provider-composable-auth.md @@ -1,10 +1,17 @@ --- -'@modelcontextprotocol/client': minor +'@modelcontextprotocol/client': major --- -Add `TokenProvider` for simple bearer-token authentication and export composable auth primitives +Unify client auth around a minimal `AuthProvider` interface -- New `TokenProvider` type — a minimal `() => Promise` function interface for supplying bearer tokens. Use this instead of `OAuthClientProvider` when tokens are managed externally (gateway/proxy patterns, service accounts, upfront API tokens, or any scenario where the full OAuth redirect flow is not needed). -- New `tokenProvider` option on `StreamableHTTPClientTransport` and `SSEClientTransport`. Called before every request to obtain a fresh token. If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. -- New `withBearerAuth(getToken, fetchFn?)` helper that wraps a fetch function to inject `Authorization: Bearer` headers — useful for composing with other fetch middleware. -- Exported previously-internal auth helpers for building custom auth flows: `applyBasicAuth`, `applyPostAuth`, `applyPublicAuth`, `executeTokenRequest`. +**Breaking:** Transport `authProvider` option now accepts the new minimal `AuthProvider` interface instead of being typed as `OAuthClientProvider`. `OAuthClientProvider` now extends `AuthProvider`, so most existing code continues to work — but custom implementations must add a `token()` method. + +- New `AuthProvider` interface: `{ token(): Promise; onUnauthorized?(ctx): Promise }`. Transports call `token()` before every request and `onUnauthorized()` on 401 (then retry once). +- `OAuthClientProvider` extends `AuthProvider`. Custom implementations must add `token()` (typically `return (await this.tokens())?.access_token`) and optionally `onUnauthorized()` (typically `return handleOAuthUnauthorized(this, ctx)`). +- Built-in providers (`ClientCredentialsProvider`, `PrivateKeyJwtProvider`, `StaticPrivateKeyJwtProvider`, `CrossAppAccessProvider`) implement both methods — existing user code is unchanged. +- New `handleOAuthUnauthorized(provider, ctx)` helper runs the standard OAuth flow from `onUnauthorized`. +- New `isOAuthClientProvider()` type guard for gating OAuth-specific transport features like `finishAuth()`. +- Transports no longer inline OAuth orchestration — ~50 lines of `auth()` calls, WWW-Authenticate parsing, and circuit-breaker state moved into `onUnauthorized()` implementations. +- Exported previously-internal auth helpers for building custom flows: `applyBasicAuth`, `applyPostAuth`, `applyPublicAuth`, `executeTokenRequest`. + +See `docs/migration.md` for before/after examples. diff --git a/docs/client.md b/docs/client.md index 467df2789..b5086f531 100644 --- a/docs/client.md +++ b/docs/client.md @@ -13,7 +13,7 @@ A client connects to a server, discovers what it offers — tools, resources, pr The examples below use these imports. Adjust based on which features and transport you need: ```ts source="../examples/client/src/clientGuide.examples.ts#imports" -import type { Prompt, Resource, TokenProvider, Tool } from '@modelcontextprotocol/client'; +import type { AuthProvider, Prompt, Resource, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -113,19 +113,19 @@ console.log(systemPrompt); ## Authentication -MCP servers can require authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). For servers that accept plain bearer tokens, pass a `tokenProvider` function to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}. For servers that require OAuth 2.0, pass an `authProvider` — the SDK provides built-in providers for common machine-to-machine flows, or you can implement the full {@linkcode @modelcontextprotocol/client!client/auth.OAuthClientProvider | OAuthClientProvider} interface for user-facing OAuth. +MCP servers can require authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). Pass an {@linkcode @modelcontextprotocol/client!client/auth.AuthProvider | AuthProvider} to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}. The transport calls `token()` before every request and `onUnauthorized()` (if provided) on 401, then retries once. -### Token provider +### Bearer tokens -For servers that accept bearer tokens managed outside the SDK — API keys, tokens from a gateway or proxy, service-account credentials, or tokens obtained through a separate auth flow — pass a {@linkcode @modelcontextprotocol/client!client/tokenProvider.TokenProvider | TokenProvider} function. It is called before every request, so it can handle expiry and refresh internally. If the server rejects the token with 401, the transport throws {@linkcode @modelcontextprotocol/client!client/auth.UnauthorizedError | UnauthorizedError} without retrying — catch it to invalidate any external cache and reconnect: +For servers that accept bearer tokens managed outside the SDK — API keys, tokens from a gateway or proxy, service-account credentials — implement only `token()`. With no `onUnauthorized()`, a 401 throws {@linkcode @modelcontextprotocol/client!client/auth.UnauthorizedError | UnauthorizedError} immediately: ```ts source="../examples/client/src/clientGuide.examples.ts#auth_tokenProvider" -const tokenProvider: TokenProvider = async () => getStoredToken(); +const authProvider: AuthProvider = { token: async () => getStoredToken() }; -const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { tokenProvider }); +const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { authProvider }); ``` -See [`simpleTokenProvider.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleTokenProvider.ts) for a complete runnable example. For finer control, {@linkcode @modelcontextprotocol/client!client/tokenProvider.withBearerAuth | withBearerAuth} wraps a fetch function directly. +See [`simpleTokenProvider.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleTokenProvider.ts) for a complete runnable example. ### Client credentials diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 9dffe4418..cdec2b9a9 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -47,15 +47,15 @@ Replace all `@modelcontextprotocol/sdk/...` imports using this table. ### Server imports -| v1 import path | v2 package | -| ---------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `@modelcontextprotocol/sdk/server/mcp.js` | `@modelcontextprotocol/server` | -| `@modelcontextprotocol/sdk/server/index.js` | `@modelcontextprotocol/server` | -| `@modelcontextprotocol/sdk/server/stdio.js` | `@modelcontextprotocol/server` | +| v1 import path | v2 package | +| ---------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `@modelcontextprotocol/sdk/server/mcp.js` | `@modelcontextprotocol/server` | +| `@modelcontextprotocol/sdk/server/index.js` | `@modelcontextprotocol/server` | +| `@modelcontextprotocol/sdk/server/stdio.js` | `@modelcontextprotocol/server` | | `@modelcontextprotocol/sdk/server/streamableHttp.js` | `@modelcontextprotocol/node` (class renamed to `NodeStreamableHTTPServerTransport`) OR `@modelcontextprotocol/server` (web-standard `WebStandardStreamableHTTPServerTransport` for Cloudflare Workers, Deno, etc.) | -| `@modelcontextprotocol/sdk/server/sse.js` | REMOVED (migrate to Streamable HTTP) | -| `@modelcontextprotocol/sdk/server/auth/*` | REMOVED (use external auth library) | -| `@modelcontextprotocol/sdk/server/middleware.js` | `@modelcontextprotocol/express` (signature changed, see section 8) | +| `@modelcontextprotocol/sdk/server/sse.js` | REMOVED (migrate to Streamable HTTP) | +| `@modelcontextprotocol/sdk/server/auth/*` | REMOVED (use external auth library) | +| `@modelcontextprotocol/sdk/server/middleware.js` | `@modelcontextprotocol/express` (signature changed, see section 8) | ### Types / shared imports @@ -203,7 +203,41 @@ import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; if (error instanceof OAuthError && error.code === OAuthErrorCode.InvalidClient) { ... } ``` -**Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). +### Client `OAuthClientProvider` now extends `AuthProvider` + +Transport `authProvider` options now accept the minimal `AuthProvider` interface. `OAuthClientProvider` extends it, so built-in providers work unchanged — custom implementations must add `token()`. + +| v1 pattern | v2 equivalent | +| ----------------------------------------------------- | --------------------------------------------------------------------------- | +| `authProvider?: OAuthClientProvider` (option type) | `authProvider?: AuthProvider` (accepts `OAuthClientProvider` via extension) | +| Transport reads `authProvider.tokens()?.access_token` | Transport calls `authProvider.token()` | +| Transport inlines `auth()` on 401 | Transport calls `authProvider.onUnauthorized()` then retries once | +| `_hasCompletedAuthFlow` circuit breaker | `_authRetryInFlight` circuit breaker | +| N/A | `handleOAuthUnauthorized(provider, ctx)` — standard `onUnauthorized` impl | +| N/A | `isOAuthClientProvider(provider)` — type guard | +| N/A | `UnauthorizedContext` — `{ response, serverUrl, fetchFn }` | + +**For custom `OAuthClientProvider` implementations**, add both methods (both required — TypeScript enforces this): + +```typescript +async token(): Promise { + return (await this.tokens())?.access_token; +} + +async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); +} +``` + +**For simple bearer tokens** (previously required stubbing 8 `OAuthClientProvider` members): + +```typescript +// v2: one-liner +const authProvider: AuthProvider = { token: async () => process.env.API_KEY }; +``` + +**Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all +Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). ## 6. McpServer API Changes @@ -279,12 +313,12 @@ Note: the third argument (`metadata`) is required — pass `{}` if no metadata. ### Schema Migration Quick Reference -| v1 (raw shape) | v2 (Zod schema) | -|----------------|-----------------| -| `{ name: z.string() }` | `z.object({ name: z.string() })` | +| v1 (raw shape) | v2 (Zod schema) | +| ---------------------------------- | -------------------------------------------- | +| `{ name: z.string() }` | `z.object({ name: z.string() })` | | `{ count: z.number().optional() }` | `z.object({ count: z.number().optional() })` | -| `{}` (empty) | `z.object({})` | -| `undefined` (no schema) | `undefined` or omit the field | +| `{}` (empty) | `z.object({})` | +| `undefined` (no schema) | `undefined` or omit the field | ## 7. Headers API @@ -370,31 +404,31 @@ Request/notification params remain fully typed. Remove unused schema imports aft `RequestHandlerExtra` → structured context types with nested groups. Rename `extra` → `ctx` in all handler callbacks. -| v1 | v2 | -|----|-----| -| `RequestHandlerExtra` | `ServerContext` (server) / `ClientContext` (client) / `BaseContext` (base) | -| `extra` (param name) | `ctx` | -| `extra.signal` | `ctx.mcpReq.signal` | -| `extra.requestId` | `ctx.mcpReq.id` | -| `extra._meta` | `ctx.mcpReq._meta` | -| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | -| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | -| `extra.authInfo` | `ctx.http?.authInfo` | -| `extra.sessionId` | `ctx.sessionId` | -| `extra.requestInfo` | `ctx.http?.req` (only `ServerContext`) | -| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only `ServerContext`) | -| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only `ServerContext`) | -| `extra.taskStore` | `ctx.task?.store` | -| `extra.taskId` | `ctx.task?.id` | -| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | +| v1 | v2 | +| -------------------------------- | -------------------------------------------------------------------------- | +| `RequestHandlerExtra` | `ServerContext` (server) / `ClientContext` (client) / `BaseContext` (base) | +| `extra` (param name) | `ctx` | +| `extra.signal` | `ctx.mcpReq.signal` | +| `extra.requestId` | `ctx.mcpReq.id` | +| `extra._meta` | `ctx.mcpReq._meta` | +| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | +| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | +| `extra.authInfo` | `ctx.http?.authInfo` | +| `extra.sessionId` | `ctx.sessionId` | +| `extra.requestInfo` | `ctx.http?.req` (only `ServerContext`) | +| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only `ServerContext`) | +| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only `ServerContext`) | +| `extra.taskStore` | `ctx.task?.store` | +| `extra.taskId` | `ctx.task?.id` | +| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | `ServerContext` convenience methods (new in v2, no v1 equivalent): -| Method | Description | Replaces | -|--------|-------------|----------| -| `ctx.mcpReq.log(level, data, logger?)` | Send log notification (respects client's level filter) | `server.sendLoggingMessage(...)` from within handler | -| `ctx.mcpReq.elicitInput(params, options?)` | Elicit user input (form or URL) | `server.elicitInput(...)` from within handler | -| `ctx.mcpReq.requestSampling(params, options?)` | Request LLM sampling from client | `server.createMessage(...)` from within handler | +| Method | Description | Replaces | +| ---------------------------------------------- | ------------------------------------------------------ | ---------------------------------------------------- | +| `ctx.mcpReq.log(level, data, logger?)` | Send log notification (respects client's level filter) | `server.sendLoggingMessage(...)` from within handler | +| `ctx.mcpReq.elicitInput(params, options?)` | Elicit user input (form or URL) | `server.elicitInput(...)` from within handler | +| `ctx.mcpReq.requestSampling(params, options?)` | Request LLM sampling from client | `server.createMessage(...)` from within handler | ## 11. Schema parameter removed from `request()`, `send()`, and `callTool()` @@ -413,14 +447,14 @@ const elicit = await ctx.mcpReq.send({ method: 'elicitation/create', params: { . const tool = await client.callTool({ name: 'my-tool', arguments: {} }); ``` -| v1 call | v2 call | -|---------|---------| -| `client.request(req, ResultSchema)` | `client.request(req)` | -| `client.request(req, ResultSchema, options)` | `client.request(req, options)` | -| `ctx.mcpReq.send(req, ResultSchema)` | `ctx.mcpReq.send(req)` | -| `ctx.mcpReq.send(req, ResultSchema, options)` | `ctx.mcpReq.send(req, options)` | -| `client.callTool(params, CompatibilityCallToolResultSchema)` | `client.callTool(params)` | -| `client.callTool(params, schema, options)` | `client.callTool(params, options)` | +| v1 call | v2 call | +| ------------------------------------------------------------ | ---------------------------------- | +| `client.request(req, ResultSchema)` | `client.request(req)` | +| `client.request(req, ResultSchema, options)` | `client.request(req, options)` | +| `ctx.mcpReq.send(req, ResultSchema)` | `ctx.mcpReq.send(req)` | +| `ctx.mcpReq.send(req, ResultSchema, options)` | `ctx.mcpReq.send(req, options)` | +| `client.callTool(params, CompatibilityCallToolResultSchema)` | `client.callTool(params)` | +| `client.callTool(params, schema, options)` | `client.callTool(params, options)` | Remove unused schema imports: `CallToolResultSchema`, `CompatibilityCallToolResultSchema`, `ElicitResultSchema`, `CreateMessageResultSchema`, etc., when they were only used in `request()`/`send()`/`callTool()` calls. @@ -431,6 +465,7 @@ Remove unused schema imports: `CallToolResultSchema`, `CompatibilityCallToolResu ## 13. Runtime-Specific JSON Schema Validators (Enhancement) The SDK now auto-selects the appropriate JSON Schema validator based on runtime: + - Node.js → `AjvJsonSchemaValidator` (no change from v1) - Cloudflare Workers (workerd) → `CfWorkerJsonSchemaValidator` (previously required manual config) @@ -438,9 +473,12 @@ The SDK now auto-selects the appropriate JSON Schema validator based on runtime: ```typescript // v1 (Cloudflare Workers): Required explicit validator -new McpServer({ name: 'server', version: '1.0.0' }, { - jsonSchemaValidator: new CfWorkerJsonSchemaValidator() -}); +new McpServer( + { name: 'server', version: '1.0.0' }, + { + jsonSchemaValidator: new CfWorkerJsonSchemaValidator() + } +); // v2 (Cloudflare Workers): Auto-selected, explicit config optional new McpServer({ name: 'server', version: '1.0.0' }, {}); diff --git a/docs/migration.md b/docs/migration.md index 59b2b50ed..e541bb01b 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -252,6 +252,7 @@ server.registerTool('ping', { ``` This applies to: + - `inputSchema` in `registerTool()` - `outputSchema` in `registerTool()` - `argsSchema` in `registerPrompt()` @@ -339,25 +340,21 @@ Common method string replacements: ### `Protocol.request()`, `ctx.mcpReq.send()`, and `Client.callTool()` no longer take a schema parameter -The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas like `CallToolResultSchema` or `ElicitResultSchema` when making requests. +The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas +like `CallToolResultSchema` or `ElicitResultSchema` when making requests. **`client.request()` — Before (v1):** ```typescript import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; -const result = await client.request( - { method: 'tools/call', params: { name: 'my-tool', arguments: {} } }, - CallToolResultSchema -); +const result = await client.request({ method: 'tools/call', params: { name: 'my-tool', arguments: {} } }, CallToolResultSchema); ``` **After (v2):** ```typescript -const result = await client.request( - { method: 'tools/call', params: { name: 'my-tool', arguments: {} } } -); +const result = await client.request({ method: 'tools/call', params: { name: 'my-tool', arguments: {} } }); ``` **`ctx.mcpReq.send()` — Before (v1):** @@ -390,10 +387,7 @@ server.setRequestHandler('tools/call', async (request, ctx) => { ```typescript import { CompatibilityCallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; -const result = await client.callTool( - { name: 'my-tool', arguments: {} }, - CompatibilityCallToolResultSchema -); +const result = await client.callTool({ name: 'my-tool', arguments: {} }, CompatibilityCallToolResultSchema); ``` **After (v2):** @@ -452,32 +446,32 @@ import { JSONRPCErrorResponse, ResourceTemplateReference, isJSONRPCErrorResponse The `RequestHandlerExtra` type has been replaced with a structured context type hierarchy using nested groups: -| v1 | v2 | -|----|-----| +| v1 | v2 | +| ---------------------------------------- | ---------------------------------------------------------------------- | | `RequestHandlerExtra` (flat, all fields) | `ServerContext` (server handlers) or `ClientContext` (client handlers) | -| `extra` parameter name | `ctx` parameter name | -| `extra.signal` | `ctx.mcpReq.signal` | -| `extra.requestId` | `ctx.mcpReq.id` | -| `extra._meta` | `ctx.mcpReq._meta` | -| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | -| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | -| `extra.authInfo` | `ctx.http?.authInfo` | -| `extra.requestInfo` | `ctx.http?.req` (only on `ServerContext`) | -| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only on `ServerContext`) | -| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only on `ServerContext`) | -| `extra.sessionId` | `ctx.sessionId` | -| `extra.taskStore` | `ctx.task?.store` | -| `extra.taskId` | `ctx.task?.id` | -| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | +| `extra` parameter name | `ctx` parameter name | +| `extra.signal` | `ctx.mcpReq.signal` | +| `extra.requestId` | `ctx.mcpReq.id` | +| `extra._meta` | `ctx.mcpReq._meta` | +| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | +| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | +| `extra.authInfo` | `ctx.http?.authInfo` | +| `extra.requestInfo` | `ctx.http?.req` (only on `ServerContext`) | +| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only on `ServerContext`) | +| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only on `ServerContext`) | +| `extra.sessionId` | `ctx.sessionId` | +| `extra.taskStore` | `ctx.task?.store` | +| `extra.taskId` | `ctx.task?.id` | +| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | **Before (v1):** ```typescript server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - const headers = extra.requestInfo?.headers; - const taskStore = extra.taskStore; - await extra.sendNotification({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); - return { content: [{ type: 'text', text: 'result' }] }; + const headers = extra.requestInfo?.headers; + const taskStore = extra.taskStore; + await extra.sendNotification({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); + return { content: [{ type: 'text', text: 'result' }] }; }); ``` @@ -485,10 +479,10 @@ server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { ```typescript server.setRequestHandler('tools/call', async (request, ctx) => { - const headers = ctx.http?.req?.headers; - const taskStore = ctx.task?.store; - await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); - return { content: [{ type: 'text', text: 'result' }] }; + const headers = ctx.http?.req?.headers; + const taskStore = ctx.task?.store; + await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); + return { content: [{ type: 'text', text: 'result' }] }; }); ``` @@ -504,22 +498,22 @@ Context fields are organized into 4 groups: ```typescript server.setRequestHandler('tools/call', async (request, ctx) => { - // Send a log message (respects client's log level filter) - await ctx.mcpReq.log('info', 'Processing tool call', 'my-logger'); - - // Request client to sample an LLM - const samplingResult = await ctx.mcpReq.requestSampling({ - messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], - maxTokens: 100, - }); - - // Elicit user input via a form - const elicitResult = await ctx.mcpReq.elicitInput({ - message: 'Please provide details', - requestedSchema: { type: 'object', properties: { name: { type: 'string' } } }, - }); - - return { content: [{ type: 'text', text: 'done' }] }; + // Send a log message (respects client's log level filter) + await ctx.mcpReq.log('info', 'Processing tool call', 'my-logger'); + + // Request client to sample an LLM + const samplingResult = await ctx.mcpReq.requestSampling({ + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 100 + }); + + // Elicit user input via a form + const elicitResult = await ctx.mcpReq.elicitInput({ + message: 'Please provide details', + requestedSchema: { type: 'object', properties: { name: { type: 'string' } } } + }); + + return { content: [{ type: 'text', text: 'done' }] }; }); ``` @@ -646,13 +640,52 @@ try { #### Why this change? -Previously, `ErrorCode.RequestTimeout` (-32001) and `ErrorCode.ConnectionClosed` (-32000) were used for local timeout/connection errors. However, these errors never cross the wire as JSON-RPC responses - they are rejected locally. Using protocol error codes for local errors was semantically inconsistent. +Previously, `ErrorCode.RequestTimeout` (-32001) and `ErrorCode.ConnectionClosed` (-32000) were used for local timeout/connection errors. However, these errors never cross the wire as JSON-RPC responses - they are rejected locally. Using protocol error codes for local errors was +semantically inconsistent. The new design: - `ProtocolError` with `ProtocolErrorCode`: For errors that are serialized and sent as JSON-RPC error responses - `SdkError` with `SdkErrorCode`: For local errors that are thrown/rejected locally and never leave the SDK +### Client `authProvider` unified around `AuthProvider` + +Transport `authProvider` options now accept the minimal `AuthProvider` interface rather than being typed as `OAuthClientProvider`. `OAuthClientProvider` extends `AuthProvider`, so built-in providers and most existing code continue to work unchanged — but custom +`OAuthClientProvider` implementations must add a `token()` method. + +**What changed:** transports now call `authProvider.token()` before every request (instead of `authProvider.tokens()?.access_token`), and call `authProvider.onUnauthorized()` on 401 (instead of inlining OAuth orchestration). One code path handles both simple bearer tokens and +full OAuth. + +**If you implement `OAuthClientProvider` directly** (the interactive browser-redirect pattern), add: + +```ts +class MyProvider implements OAuthClientProvider { + // ...existing 8 required members... + + // Required: return the current access token + async token(): Promise { + return (await this.tokens())?.access_token; + } + + // Required: runs the OAuth flow on 401 — without this, 401 throws with no recovery + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } +} +``` + +**If you use `ClientCredentialsProvider`, `PrivateKeyJwtProvider`, `StaticPrivateKeyJwtProvider`, or `CrossAppAccessProvider`** — no change. These already implement both methods. + +**If you have simple bearer tokens** (API keys, gateway tokens, externally-managed tokens), you can now skip `OAuthClientProvider` entirely: + +```ts +// Before: had to implement 8 OAuthClientProvider members with no-op stubs +// After: +const transport = new StreamableHTTPClientTransport(url, { + authProvider: { token: async () => process.env.API_KEY } +}); +``` + ### OAuth error refactoring The OAuth error classes have been consolidated into a single `OAuthError` class with an `OAuthErrorCode` enum. @@ -743,11 +776,11 @@ This means Cloudflare Workers users no longer need to explicitly pass the valida import { McpServer, CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/server'; const server = new McpServer( - { name: 'my-server', version: '1.0.0' }, - { - capabilities: { tools: {} }, - jsonSchemaValidator: new CfWorkerJsonSchemaValidator() // Required in v1 - } + { name: 'my-server', version: '1.0.0' }, + { + capabilities: { tools: {} }, + jsonSchemaValidator: new CfWorkerJsonSchemaValidator() // Required in v1 + } ); ``` @@ -757,9 +790,9 @@ const server = new McpServer( import { McpServer } from '@modelcontextprotocol/server'; const server = new McpServer( - { name: 'my-server', version: '1.0.0' }, - { capabilities: { tools: {} } } - // Validator auto-selected based on runtime + { name: 'my-server', version: '1.0.0' }, + { capabilities: { tools: {} } } + // Validator auto-selected based on runtime ); ``` diff --git a/examples/client/src/clientGuide.examples.ts b/examples/client/src/clientGuide.examples.ts index c34a3a574..f07d272db 100644 --- a/examples/client/src/clientGuide.examples.ts +++ b/examples/client/src/clientGuide.examples.ts @@ -8,7 +8,7 @@ */ //#region imports -import type { Prompt, Resource, TokenProvider, Tool } from '@modelcontextprotocol/client'; +import type { AuthProvider, Prompt, Resource, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -107,12 +107,12 @@ async function serverInstructions_basic(client: Client) { // Authentication // --------------------------------------------------------------------------- -/** Example: TokenProvider for bearer auth with externally-managed tokens. */ +/** Example: Minimal AuthProvider for bearer auth with externally-managed tokens. */ async function auth_tokenProvider(getStoredToken: () => Promise) { //#region auth_tokenProvider - const tokenProvider: TokenProvider = async () => getStoredToken(); + const authProvider: AuthProvider = { token: async () => getStoredToken() }; - const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { tokenProvider }); + const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { authProvider }); //#endregion auth_tokenProvider return transport; } diff --git a/examples/client/src/simpleOAuthClientProvider.ts b/examples/client/src/simpleOAuthClientProvider.ts index 96655c9f6..6248d1f90 100644 --- a/examples/client/src/simpleOAuthClientProvider.ts +++ b/examples/client/src/simpleOAuthClientProvider.ts @@ -1,4 +1,11 @@ -import type { OAuthClientInformationMixed, OAuthClientMetadata, OAuthClientProvider, OAuthTokens } from '@modelcontextprotocol/client'; +import type { + OAuthClientInformationMixed, + OAuthClientMetadata, + OAuthClientProvider, + OAuthTokens, + UnauthorizedContext +} from '@modelcontextprotocol/client'; +import { handleOAuthUnauthorized } from '@modelcontextprotocol/client'; /** * In-memory OAuth client provider for demonstration purposes @@ -24,6 +31,14 @@ export class InMemoryOAuthClientProvider implements OAuthClientProvider { private _onRedirect: (url: URL) => void; + async token(): Promise { + return this._tokens?.access_token; + } + + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } + get redirectUrl(): string | URL { return this._redirectUrl; } diff --git a/examples/client/src/simpleTokenProvider.ts b/examples/client/src/simpleTokenProvider.ts index f6829f556..7b5f1a4c1 100644 --- a/examples/client/src/simpleTokenProvider.ts +++ b/examples/client/src/simpleTokenProvider.ts @@ -1,23 +1,22 @@ #!/usr/bin/env node /** - * Example demonstrating TokenProvider for simple bearer token authentication. + * Example demonstrating the minimal AuthProvider for bearer token authentication. * - * TokenProvider is a lightweight alternative to OAuthClientProvider for cases - * where tokens are managed externally — e.g., pre-configured API tokens, - * gateway/proxy patterns, or tokens obtained through a separate auth flow. + * AuthProvider is the base interface for all client auth. For simple cases where + * tokens are managed externally — pre-configured API tokens, gateway/proxy patterns, + * or tokens obtained through a separate auth flow — implement only `token()`. + * + * For OAuth flows (client_credentials, private_key_jwt, etc.), use the built-in + * providers which implement both `token()` and `onUnauthorized()`. * * Environment variables: * MCP_SERVER_URL - Server URL (default: http://localhost:3000/mcp) * MCP_TOKEN - Bearer token to use for authentication (required) - * - * Two approaches are demonstrated: - * 1. Using `tokenProvider` option on the transport (simplest) - * 2. Using `withBearerAuth` to wrap a custom fetch function (more flexible) */ -import type { TokenProvider } from '@modelcontextprotocol/client'; -import { Client, StreamableHTTPClientTransport, withBearerAuth } from '@modelcontextprotocol/client'; +import type { AuthProvider } from '@modelcontextprotocol/client'; +import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; const DEFAULT_SERVER_URL = process.env.MCP_SERVER_URL || 'http://localhost:3000/mcp'; @@ -28,25 +27,16 @@ async function main() { process.exit(1); } - // A TokenProvider is just an async function that returns a token string. - // It is called before every request, so it can handle refresh logic internally. - const tokenProvider: TokenProvider = async () => token; + // AuthProvider with just token() — the simplest possible auth. + // token() is called before every request, so it can handle refresh internally. + // With no onUnauthorized(), a 401 throws UnauthorizedError immediately. + const authProvider: AuthProvider = { + token: async () => token + }; - const client = new Client({ name: 'token-provider-example', version: '1.0.0' }, { capabilities: {} }); + const client = new Client({ name: 'auth-provider-example', version: '1.0.0' }, { capabilities: {} }); - // Approach 1: Pass tokenProvider directly to the transport. - // This is the simplest way to add bearer auth. - const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { - tokenProvider - }); - - // Approach 2 (alternative): Use withBearerAuth to wrap fetch. - // This is useful when you need more control over the fetch behavior, - // or when composing with other fetch wrappers. - // - // const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { - // fetch: withBearerAuth(tokenProvider), - // }); + const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { authProvider }); await client.connect(transport); console.log('Connected successfully.'); @@ -64,6 +54,3 @@ try { // eslint-disable-next-line unicorn/no-process-exit process.exit(1); } - -// Referenced in the commented-out Approach 2 above; kept so uncommenting it type-checks. -void withBearerAuth; diff --git a/packages/client/src/client/auth.examples.ts b/packages/client/src/client/auth.examples.ts index 17c04e6a0..15b6487a7 100644 --- a/packages/client/src/client/auth.examples.ts +++ b/packages/client/src/client/auth.examples.ts @@ -9,8 +9,8 @@ import type { AuthorizationServerMetadata } from '@modelcontextprotocol/core'; -import type { OAuthClientProvider } from './auth.js'; -import { fetchToken } from './auth.js'; +import type { OAuthClientProvider, UnauthorizedContext } from './auth.js'; +import { fetchToken, handleOAuthUnauthorized } from './auth.js'; /** * Base class providing no-op implementations of required OAuthClientProvider methods. @@ -29,6 +29,12 @@ abstract class MyProviderBase implements OAuthClientProvider { tokens(): undefined { return; } + async token(): Promise { + return undefined; + } + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } saveTokens() { return Promise.resolve(); } diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index c47a57f27..bca45ad66 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -34,14 +34,103 @@ export type AddClientAuthentication = ( metadata?: AuthorizationServerMetadata ) => void | Promise; +/** + * Context passed to {@linkcode AuthProvider.onUnauthorized} when the server + * responds with 401. Provides everything needed to refresh credentials. + */ +export interface UnauthorizedContext { + /** The 401 response — inspect `WWW-Authenticate` for resource metadata, scope, etc. */ + response: Response; + /** The MCP server URL, for passing to {@linkcode auth} or discovery helpers. */ + serverUrl: URL; + /** Fetch function configured with the transport's `requestInit`, for making auth requests. */ + fetchFn: FetchLike; +} + +/** + * Minimal interface for authenticating MCP client transports with bearer tokens. + * + * Transports call {@linkcode AuthProvider.token | token()} before every request + * to obtain the current token, and {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * (if provided) when the server responds with 401, giving the provider a chance + * to refresh credentials before the transport retries once. + * + * For simple cases (API keys, gateway-managed tokens), implement only `token()`: + * ```typescript + * const authProvider: AuthProvider = { token: async () => process.env.API_KEY }; + * ``` + * + * For OAuth flows, use {@linkcode OAuthClientProvider} which extends this interface, + * or one of the built-in providers ({@linkcode index.ClientCredentialsProvider | ClientCredentialsProvider} etc.). + */ +export interface AuthProvider { + /** + * Returns the current bearer token, or `undefined` if no token is available. + * Called before every request. + */ + token(): Promise; + + /** + * Called when the server responds with 401. If provided, the transport will + * await this, then retry the request once. If the retry also gets 401, or if + * this method is not provided, the transport throws {@linkcode UnauthorizedError}. + * + * Implementations should refresh tokens, re-authenticate, etc. — whatever is + * needed so the next `token()` call returns a valid token. + */ + onUnauthorized?(ctx: UnauthorizedContext): Promise; +} + +/** + * Type guard: checks whether an `AuthProvider` is a full `OAuthClientProvider`. + * Use this to gate OAuth-specific transport features like `finishAuth()` and + * 403 scope upscoping. + */ +export function isOAuthClientProvider(provider: AuthProvider | undefined): provider is OAuthClientProvider { + return provider !== undefined && 'tokens' in provider && 'clientMetadata' in provider; +} + +/** + * Default `onUnauthorized` implementation for OAuth providers: extracts + * `WWW-Authenticate` parameters from the 401 response and runs {@linkcode auth}. + * Built-in providers ({@linkcode index.ClientCredentialsProvider | ClientCredentialsProvider} etc.) + * delegate to this. Custom `OAuthClientProvider` implementations can do the same. + */ +export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx: UnauthorizedContext): Promise { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(ctx.response); + const result = await auth(provider, { + serverUrl: ctx.serverUrl, + resourceMetadataUrl, + scope, + fetchFn: ctx.fetchFn + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } +} + /** * Implements an end-to-end OAuth client to be used with one MCP server. * * This client relies upon a concept of an authorized "session," the exact * meaning of which is application-defined. Tokens, authorization codes, and * code verifiers should not cross different sessions. + * + * Extends {@linkcode AuthProvider} — implementations must provide `token()` + * (typically `return (await this.tokens())?.access_token`) and `onUnauthorized()` + * (typically `return handleOAuthUnauthorized(this, ctx)`). Without `onUnauthorized()`, + * 401 responses throw immediately with no token refresh or reauth. */ -export interface OAuthClientProvider { +export interface OAuthClientProvider extends AuthProvider { + /** + * Runs the OAuth re-authentication flow on 401. Required on `OAuthClientProvider` + * (optional on the base `AuthProvider`) because OAuth providers that omit this lose + * all 401 recovery — no token refresh, no redirect to authorization. + * + * Most implementations should delegate: `return handleOAuthUnauthorized(this, ctx)`. + */ + onUnauthorized(ctx: UnauthorizedContext): Promise; + /** * The URL to redirect the user agent to after authorization. * Return `undefined` for non-interactive flows that don't require user interaction diff --git a/packages/client/src/client/authExtensions.ts b/packages/client/src/client/authExtensions.ts index ae614f7ba..7508298b7 100644 --- a/packages/client/src/client/authExtensions.ts +++ b/packages/client/src/client/authExtensions.ts @@ -8,7 +8,8 @@ import type { FetchLike, OAuthClientInformation, OAuthClientMetadata, OAuthTokens } from '@modelcontextprotocol/core'; import type { CryptoKey, JWK } from 'jose'; -import type { AddClientAuthentication, OAuthClientProvider } from './auth.js'; +import type { AddClientAuthentication, OAuthClientProvider, UnauthorizedContext } from './auth.js'; +import { handleOAuthUnauthorized } from './auth.js'; /** * Helper to produce a `private_key_jwt` client authentication function. @@ -150,6 +151,14 @@ export class ClientCredentialsProvider implements OAuthClientProvider { }; } + async token(): Promise { + return this._tokens?.access_token; + } + + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } + get redirectUrl(): undefined { return undefined; } @@ -269,6 +278,14 @@ export class PrivateKeyJwtProvider implements OAuthClientProvider { }); } + async token(): Promise { + return this._tokens?.access_token; + } + + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } + get redirectUrl(): undefined { return undefined; } @@ -366,6 +383,14 @@ export class StaticPrivateKeyJwtProvider implements OAuthClientProvider { }; } + async token(): Promise { + return this._tokens?.access_token; + } + + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } + get redirectUrl(): undefined { return undefined; } @@ -564,6 +589,14 @@ export class CrossAppAccessProvider implements OAuthClientProvider { this._fetchFn = options.fetchFn ?? fetch; } + async token(): Promise { + return this._tokens?.access_token; + } + + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } + get redirectUrl(): undefined { return undefined; } diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index e5b04a258..025c785ea 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -3,9 +3,8 @@ import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders, SdkError, import type { ErrorEvent, EventSourceInit } from 'eventsource'; import { EventSource } from 'eventsource'; -import type { AuthResult, OAuthClientProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; -import type { TokenProvider } from './tokenProvider.js'; +import type { AuthProvider } from './auth.js'; +import { auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; export class SseError extends Error { constructor( @@ -24,28 +23,19 @@ export type SSEClientTransportOptions = { /** * An OAuth client provider to use for authentication. * - * When an `authProvider` is specified and the SSE connection is started: - * 1. The connection is attempted with any existing access token from the `authProvider`. - * 2. If the access token has expired, the `authProvider` is used to refresh the token. - * 3. If token refresh fails or no access token exists, and auth is required, {@linkcode OAuthClientProvider.redirectToAuthorization} is called, and an {@linkcode UnauthorizedError} will be thrown from {@linkcode index.Protocol.connect | connect}/{@linkcode SSEClientTransport.start | start}. + * {@linkcode AuthProvider.token | token()} is called before every request to obtain the + * bearer token. When the server responds with 401, {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * is called (if provided) to refresh credentials, then the request is retried once. If + * the retry also gets 401, or `onUnauthorized` is not provided, {@linkcode UnauthorizedError} + * is thrown. * - * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call {@linkcode SSEClientTransport.finishAuth} with the authorization code before retrying the connection. + * For simple bearer tokens: `{ token: async () => myApiKey }`. * - * If an `authProvider` is not provided, and auth is required, an {@linkcode UnauthorizedError} will be thrown. - * - * {@linkcode UnauthorizedError} might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected. - */ - authProvider?: OAuthClientProvider; - - /** - * A simple token provider for bearer authentication. - * - * Use this instead of `authProvider` when tokens are managed externally - * (e.g., upfront auth, gateway/proxy patterns, service accounts). - * - * If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. + * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation. + * Interactive flows: after {@linkcode UnauthorizedError}, redirect the user, then call + * {@linkcode SSEClientTransport.finishAuth | finishAuth} with the authorization code before reconnecting. */ - tokenProvider?: TokenProvider; + authProvider?: AuthProvider; /** * Customizes the initial SSE request to the server (the request that begins the stream). @@ -82,8 +72,7 @@ export class SSEClientTransport implements Transport { private _scope?: string; private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; - private _authProvider?: OAuthClientProvider; - private _tokenProvider?: TokenProvider; + private _authProvider?: AuthProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _protocolVersion?: string; @@ -99,48 +88,18 @@ export class SSEClientTransport implements Transport { this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; - this._tokenProvider = opts?.tokenProvider; this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); } - private async _authThenStart(): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); - } - - let result: AuthResult; - try { - result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - } catch (error) { - this.onerror?.(error as Error); - throw error; - } - - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - return await this._startOrAuth(); - } + private _authRetryInFlight = false; + private _last401Response?: Response; private async _commonHeaders(): Promise { const headers: RequestInit['headers'] & Record = {}; - if (this._authProvider) { - const tokens = await this._authProvider.tokens(); - if (tokens) { - headers['Authorization'] = `Bearer ${tokens.access_token}`; - } - } else if (this._tokenProvider) { - const token = await this._tokenProvider(); - if (token) { - headers['Authorization'] = `Bearer ${token}`; - } + const token = await this._authProvider?.token(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; } if (this._protocolVersion) { headers['mcp-protocol-version'] = this._protocolVersion; @@ -167,10 +126,13 @@ export class SSEClientTransport implements Transport { headers }); - if (response.status === 401 && response.headers.has('www-authenticate')) { - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; + if (response.status === 401) { + this._last401Response = response; + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; + } } return response; @@ -179,17 +141,23 @@ export class SSEClientTransport implements Transport { this._abortController = new AbortController(); this._eventSource.onerror = event => { - if (event.code === 401) { - if (this._authProvider) { - this._authThenStart().then(resolve, reject); - return; - } - if (this._tokenProvider) { - const error = new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); - reject(error); - this.onerror?.(error); + if (event.code === 401 && this._authProvider) { + if (this._authProvider.onUnauthorized && this._last401Response && !this._authRetryInFlight) { + this._authRetryInFlight = true; + const response = this._last401Response; + this._authProvider + .onUnauthorized({ response, serverUrl: this._url, fetchFn: this._fetchWithInit }) + .then(() => this._startOrAuth()) + .then(resolve, reject) + .finally(() => { + this._authRetryInFlight = false; + }); return; } + const error = new UnauthorizedError(); + reject(error); + this.onerror?.(error); + return; } const error = new SseError(event.code, event.message, event); @@ -247,8 +215,8 @@ export class SSEClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); + if (!isOAuthClientProvider(this._authProvider)) { + throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } const result = await auth(this._authProvider, { @@ -287,38 +255,39 @@ export class SSEClientTransport implements Transport { const response = await (this._fetch ?? fetch)(this._endpoint, init); if (!response.ok) { - const text = await response.text?.().catch(() => null); - if (response.status === 401) { - if (this._authProvider) { + if (response.headers.has('www-authenticate')) { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); this._resourceMetadataUrl = resourceMetadataUrl; this._scope = scope; + } - const result = await auth(this._authProvider, { + if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + this._authRetryInFlight = true; + await this._authProvider.onUnauthorized({ + response, serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, fetchFn: this._fetchWithInit }); - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } - if (this._tokenProvider) { - throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + if (this._authProvider) { + await response.text?.().catch(() => {}); + throw new UnauthorizedError(); } } + const text = await response.text?.().catch(() => null); throw new Error(`Error POSTing to endpoint (HTTP ${response.status}): ${text}`); } + this._authRetryInFlight = false; + // Release connection - POST responses don't have content we need await response.text?.().catch(() => {}); } catch (error) { + this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index 180893472..6a3b6dd00 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -13,9 +13,8 @@ import { } from '@modelcontextprotocol/core'; import { EventSourceParserStream } from 'eventsource-parser/stream'; -import type { AuthResult, OAuthClientProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; -import type { TokenProvider } from './tokenProvider.js'; +import type { AuthProvider } from './auth.js'; +import { auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { @@ -86,28 +85,20 @@ export type StreamableHTTPClientTransportOptions = { /** * An OAuth client provider to use for authentication. * - * When an `authProvider` is specified and the connection is started: - * 1. The connection is attempted with any existing access token from the `authProvider`. - * 2. If the access token has expired, the `authProvider` is used to refresh the token. - * 3. If token refresh fails or no access token exists, and auth is required, {@linkcode OAuthClientProvider.redirectToAuthorization} is called, and an {@linkcode UnauthorizedError} will be thrown from {@linkcode index.Protocol.connect | connect}/{@linkcode StreamableHTTPClientTransport.start | start}. + * {@linkcode AuthProvider.token | token()} is called before every request to obtain the + * bearer token. When the server responds with 401, {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * is called (if provided) to refresh credentials, then the request is retried once. If + * the retry also gets 401, or `onUnauthorized` is not provided, {@linkcode UnauthorizedError} + * is thrown. * - * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call {@linkcode StreamableHTTPClientTransport.finishAuth} with the authorization code before retrying the connection. + * For simple bearer tokens: `{ token: async () => myApiKey }`. * - * If an `authProvider` is not provided, and auth is required, an {@linkcode UnauthorizedError} will be thrown. - * - * {@linkcode UnauthorizedError} might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected. + * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation + * (which extends `AuthProvider`). Interactive flows: after {@linkcode UnauthorizedError}, redirect the + * user, then call {@linkcode StreamableHTTPClientTransport.finishAuth | finishAuth} with the authorization + * code before reconnecting. */ - authProvider?: OAuthClientProvider; - - /** - * A simple token provider for bearer authentication. - * - * Use this instead of `authProvider` when tokens are managed externally - * (e.g., upfront auth, gateway/proxy patterns, service accounts). - * - * If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. - */ - tokenProvider?: TokenProvider; + authProvider?: AuthProvider; /** * Customizes HTTP requests to the server. @@ -142,14 +133,13 @@ export class StreamableHTTPClientTransport implements Transport { private _resourceMetadataUrl?: URL; private _scope?: string; private _requestInit?: RequestInit; - private _authProvider?: OAuthClientProvider; - private _tokenProvider?: TokenProvider; + private _authProvider?: AuthProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; private _protocolVersion?: string; - private _hasCompletedAuthFlow = false; // Circuit breaker: detect auth success followed by immediate 401 + private _authRetryInFlight = false; // Circuit breaker: single retry per operation on 401 private _lastUpscopingHeader?: string; // Track last upscoping header to prevent infinite upscoping. private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field private _reconnectionTimeout?: ReturnType; @@ -164,50 +154,17 @@ export class StreamableHTTPClientTransport implements Transport { this._scope = undefined; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; - this._tokenProvider = opts?.tokenProvider; this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); this._sessionId = opts?.sessionId; this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; } - private async _authThenStart(): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); - } - - let result: AuthResult; - try { - result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - } catch (error) { - this.onerror?.(error as Error); - throw error; - } - - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - return await this._startOrAuthSse({ resumptionToken: undefined }); - } - private async _commonHeaders(): Promise { const headers: RequestInit['headers'] & Record = {}; - if (this._authProvider) { - const tokens = await this._authProvider.tokens(); - if (tokens) { - headers['Authorization'] = `Bearer ${tokens.access_token}`; - } - } else if (this._tokenProvider) { - const token = await this._tokenProvider(); - if (token) { - headers['Authorization'] = `Bearer ${token}`; - } + const token = await this._authProvider?.token(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; } if (this._sessionId) { @@ -247,17 +204,34 @@ export class StreamableHTTPClientTransport implements Transport { }); if (!response.ok) { - await response.text?.().catch(() => {}); - if (response.status === 401) { - if (this._authProvider) { - return await this._authThenStart(); + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; } - if (this._tokenProvider) { - throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + + if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + this._authRetryInFlight = true; + try { + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit + }); + return await this._startOrAuthSse(options); + } finally { + this._authRetryInFlight = false; + } + } + if (this._authProvider) { + await response.text?.().catch(() => {}); + throw new UnauthorizedError(); } } + await response.text?.().catch(() => {}); + // 405 indicates that the server does not offer an SSE stream at GET endpoint // This is an expected case that should not trigger an error if (response.status === 405) { @@ -453,8 +427,8 @@ export class StreamableHTTPClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); + if (!isOAuthClientProvider(this._authProvider)) { + throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } const result = await auth(this._authProvider, { @@ -514,47 +488,33 @@ export class StreamableHTTPClientTransport implements Transport { } if (!response.ok) { - const text = await response.text?.().catch(() => null); - if (response.status === 401) { - if (this._authProvider) { - // Prevent infinite recursion when server returns 401 after successful auth - if (this._hasCompletedAuthFlow) { - throw new SdkError( - SdkErrorCode.ClientHttpAuthentication, - 'Server returned 401 after successful authentication', - { - status: 401, - text - } - ); - } - + // Store WWW-Authenticate params for interactive finishAuth() path + if (response.headers.has('www-authenticate')) { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); this._resourceMetadataUrl = resourceMetadataUrl; this._scope = scope; + } - const result = await auth(this._authProvider, { + if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + this._authRetryInFlight = true; + await this._authProvider.onUnauthorized({ + response, serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, fetchFn: this._fetchWithInit }); - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - // Mark that we completed auth flow - this._hasCompletedAuthFlow = true; // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } - if (this._tokenProvider) { - throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + if (this._authProvider) { + await response.text?.().catch(() => {}); + throw new UnauthorizedError(); } } - if (response.status === 403 && this._authProvider) { + const text = await response.text?.().catch(() => null); + + if (response.status === 403 && isOAuthClientProvider(this._authProvider)) { const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response); if (error === 'insufficient_scope') { @@ -600,7 +560,7 @@ export class StreamableHTTPClientTransport implements Transport { } // Reset auth loop flag on successful response - this._hasCompletedAuthFlow = false; + this._authRetryInFlight = false; this._lastUpscopingHeader = undefined; // If the response is 202 Accepted, there's no body to process @@ -650,6 +610,7 @@ export class StreamableHTTPClientTransport implements Transport { await response.text?.().catch(() => {}); } } catch (error) { + this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } diff --git a/packages/client/src/client/tokenProvider.ts b/packages/client/src/client/tokenProvider.ts deleted file mode 100644 index ab8f2bbc9..000000000 --- a/packages/client/src/client/tokenProvider.ts +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Minimal interface for providing bearer tokens to MCP transports. - * - * Unlike `OAuthClientProvider` which assumes interactive browser-redirect OAuth, - * `TokenProvider` is a simple function that returns a token string. - * Use this for upfront auth, gateway/proxy patterns, service accounts, - * or any scenario where tokens are managed externally. - * - * The provider is called before every request. If the server responds with 401, - * the transport throws `UnauthorizedError` without retrying — the provider is - * assumed to have already returned its freshest token. Catch `UnauthorizedError` - * to invalidate any external cache and reconnect. - * - * @example - * ```typescript - * // Static token - * const provider: TokenProvider = async () => "my-api-token"; - * - * // Token from secure storage with refresh - * const provider: TokenProvider = async () => { - * const token = await storage.getToken(); - * if (isExpiringSoon(token)) { - * return (await refreshToken(token)).accessToken; - * } - * return token.accessToken; - * }; - * ``` - */ -export type TokenProvider = () => Promise; - -/** - * Wraps a fetch function to automatically inject Bearer authentication headers. - * - * @example - * ```typescript - * const authedFetch = withBearerAuth(async () => getStoredToken()); - * const transport = new StreamableHTTPClientTransport(url, { fetch: authedFetch }); - * ``` - */ -export function withBearerAuth( - getToken: TokenProvider, - fetchFn: (url: string | URL, init?: RequestInit) => Promise = globalThis.fetch -): (url: string | URL, init?: RequestInit) => Promise { - return async (url, init) => { - const token = await getToken(); - if (token) { - const headers = new Headers(init?.headers); - headers.set('Authorization', `Bearer ${token}`); - return fetchFn(url, { ...init, headers }); - } - return fetchFn(url, init); - }; -} diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index b72b3e2d9..c37d9fe28 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -6,7 +6,6 @@ export * from './client/middleware.js'; export * from './client/sse.js'; export * from './client/stdio.js'; export * from './client/streamableHttp.js'; -export * from './client/tokenProvider.js'; export * from './client/websocket.js'; // experimental exports diff --git a/packages/client/test/client/auth.test.ts b/packages/client/test/client/auth.test.ts index 9d8f5cf6b..12d6793af 100644 --- a/packages/client/test/client/auth.test.ts +++ b/packages/client/test/client/auth.test.ts @@ -1038,6 +1038,8 @@ describe('OAuth Authorization', () => { client_secret: 'test-client-secret' }), tokens: vi.fn().mockResolvedValue(undefined), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -1983,6 +1985,8 @@ describe('OAuth Authorization', () => { }, clientInformation: vi.fn(), tokens: vi.fn(), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -2056,6 +2060,8 @@ describe('OAuth Authorization', () => { client_id: 'client-id' }), tokens: vi.fn().mockResolvedValue(undefined), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn().mockResolvedValue(undefined), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -2971,6 +2977,8 @@ describe('OAuth Authorization', () => { client_secret: 'secret123' }), tokens: vi.fn().mockResolvedValue(undefined), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -3424,6 +3432,8 @@ describe('OAuth Authorization', () => { clientInformation: vi.fn().mockResolvedValue(undefined), saveClientInformation: vi.fn().mockResolvedValue(undefined), tokens: vi.fn().mockResolvedValue(undefined), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn().mockResolvedValue(undefined), redirectToAuthorization: vi.fn().mockResolvedValue(undefined), saveCodeVerifier: vi.fn().mockResolvedValue(undefined), diff --git a/packages/client/test/client/middleware.test.ts b/packages/client/test/client/middleware.test.ts index 64bbfa673..d2084af99 100644 --- a/packages/client/test/client/middleware.test.ts +++ b/packages/client/test/client/middleware.test.ts @@ -33,6 +33,8 @@ describe('withOAuth', () => { return { redirect_uris: ['http://localhost/callback'] }; }, tokens: vi.fn(), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), clientInformation: vi.fn(), redirectToAuthorization: vi.fn(), @@ -759,6 +761,8 @@ describe('Integration Tests', () => { return { redirect_uris: ['http://localhost/callback'] }; }, tokens: vi.fn(), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), clientInformation: vi.fn(), redirectToAuthorization: vi.fn(), diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index 0b0aff67b..3e1a3f895 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -7,8 +7,8 @@ import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import type { Mock, Mocked, MockedFunction, MockInstance } from 'vitest'; -import type { OAuthClientProvider } from '../../src/client/auth.js'; -import { UnauthorizedError } from '../../src/client/auth.js'; +import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; +import { handleOAuthUnauthorized, UnauthorizedError } from '../../src/client/auth.js'; import { SSEClientTransport } from '../../src/client/sse.js'; /** @@ -430,11 +430,15 @@ describe('SSEClientTransport', () => { }, clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), tokens: vi.fn(), + token: vi.fn(async () => undefined), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn(), - invalidateCredentials: vi.fn() + invalidateCredentials: vi.fn(), + onUnauthorized: vi.fn(async ctx => { + await handleOAuthUnauthorized(mockAuthProvider, ctx); + }) }; }); @@ -443,6 +447,7 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); + mockAuthProvider.token.mockResolvedValue('test-token'); transport = new SSEClientTransport(resourceBaseUrl, { authProvider: mockAuthProvider @@ -451,7 +456,7 @@ describe('SSEClientTransport', () => { await transport.start(); expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); - expect(mockAuthProvider.tokens).toHaveBeenCalled(); + expect(mockAuthProvider.token).toHaveBeenCalled(); }); it('attaches custom header from provider on initial SSE connection', async () => { @@ -459,6 +464,7 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); + mockAuthProvider.token.mockResolvedValue('test-token'); const customHeaders = { 'X-Custom-Header': 'custom-value' }; @@ -474,7 +480,7 @@ describe('SSEClientTransport', () => { expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); expect(lastServerRequest.headers['x-custom-header']).toBe('custom-value'); - expect(mockAuthProvider.tokens).toHaveBeenCalled(); + expect(mockAuthProvider.token).toHaveBeenCalled(); }); it('attaches auth header from provider on POST requests', async () => { @@ -482,6 +488,7 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); + mockAuthProvider.token.mockResolvedValue('test-token'); transport = new SSEClientTransport(resourceBaseUrl, { authProvider: mockAuthProvider @@ -499,7 +506,7 @@ describe('SSEClientTransport', () => { await transport.send(message); expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); - expect(mockAuthProvider.tokens).toHaveBeenCalled(); + expect(mockAuthProvider.token).toHaveBeenCalled(); }); it('attempts auth flow on 401 during SSE connection', async () => { @@ -631,6 +638,7 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); + mockAuthProvider.token.mockResolvedValue('test-token'); const customHeaders = { 'X-Custom-Header': 'custom-value' @@ -666,6 +674,7 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -795,6 +804,7 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -948,6 +958,7 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -1218,11 +1229,15 @@ describe('SSEClientTransport', () => { }, clientInformation: vi.fn().mockResolvedValue(clientInfo), tokens: vi.fn().mockResolvedValue(tokens), + token: vi.fn(async () => tokens?.access_token), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn().mockResolvedValue('test-verifier'), - invalidateCredentials: vi.fn() + invalidateCredentials: vi.fn(), + onUnauthorized: vi.fn(async ctx => { + await handleOAuthUnauthorized(mockAuthProvider, ctx); + }) }; }; @@ -1528,4 +1543,99 @@ describe('SSEClientTransport', () => { expect(globalFetchSpy).not.toHaveBeenCalled(); }); }); + + describe('minimal AuthProvider (non-OAuth)', () => { + let postResponses: number[]; + let postCount: number; + + async function setupServer(): Promise { + await resourceServer.close(); + + postCount = 0; + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + if (req.method === 'GET') { + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}post\n\n`); + return; + } + + if (req.method === 'POST') { + const status = postResponses[postCount] ?? 200; + postCount++; + res.writeHead(status).end(); + return; + } + }); + + resourceBaseUrl = await listenOnRandomPort(resourceServer); + } + + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'test', params: {}, id: '1' }; + + it('throws UnauthorizedError on POST 401 when onUnauthorized is not provided', async () => { + postResponses = [401]; + await setupServer(); + + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + }); + + it('enforces circuit breaker on double-401: onUnauthorized called once, then throws', async () => { + postResponses = [401, 401]; + await setupServer(); + + const authProvider: AuthProvider = { + token: vi.fn(async () => 'still-bad'), + onUnauthorized: vi.fn(async () => {}) + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(postCount).toBe(2); + }); + + it('resets retry guard when onUnauthorized throws, allowing retry on next send', async () => { + postResponses = [401, 401, 200]; + await setupServer(); + + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized: vi.fn().mockRejectedValueOnce(new Error('transient network error')).mockResolvedValueOnce(undefined) + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + // First send: 401 → onUnauthorized throws transient error + await expect(transport.send(message)).rejects.toThrow('transient network error'); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + + // Second send: flag should be reset, so 401 → onUnauthorized (succeeds) → retry → 200 + await transport.send(message); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); + expect(postCount).toBe(3); + }); + + it('throws when finishAuth is called with a non-OAuth AuthProvider', async () => { + postResponses = []; + await setupServer(); + + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + await expect(transport.finishAuth('auth-code')).rejects.toThrow('finishAuth requires an OAuthClientProvider'); + }); + }); }); diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 0398964d3..abebaed33 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -3,7 +3,7 @@ import { OAuthError, OAuthErrorCode, SdkError, SdkErrorCode } from '@modelcontex import type { Mock, Mocked } from 'vitest'; import type { OAuthClientProvider } from '../../src/client/auth.js'; -import { UnauthorizedError } from '../../src/client/auth.js'; +import { handleOAuthUnauthorized, UnauthorizedError } from '../../src/client/auth.js'; import type { StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js'; import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; @@ -21,11 +21,15 @@ describe('StreamableHTTPClientTransport', () => { }, clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), tokens: vi.fn(), + token: vi.fn(async () => undefined), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn(), - invalidateCredentials: vi.fn() + invalidateCredentials: vi.fn(), + onUnauthorized: vi.fn(async ctx => { + await handleOAuthUnauthorized(mockAuthProvider, ctx); + }) }; transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider: mockAuthProvider }); vi.spyOn(globalThis, 'fetch'); @@ -1678,7 +1682,7 @@ describe('StreamableHTTPClientTransport', () => { // Retry the original request - still 401 (broken server) .mockResolvedValueOnce(unauthedResponse); - await expect(transport.send(message)).rejects.toThrow('Server returned 401 after successful authentication'); + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ access_token: 'new-access-token', token_type: 'Bearer', diff --git a/packages/client/test/client/tokenProvider.test.ts b/packages/client/test/client/tokenProvider.test.ts index 111a7b6a5..c683a4012 100644 --- a/packages/client/test/client/tokenProvider.test.ts +++ b/packages/client/test/client/tokenProvider.test.ts @@ -1,68 +1,11 @@ import type { JSONRPCMessage } from '@modelcontextprotocol/core'; import type { Mock } from 'vitest'; -import type { TokenProvider } from '../../src/client/tokenProvider.js'; -import { withBearerAuth } from '../../src/client/tokenProvider.js'; -import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; +import type { AuthProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; +import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; -describe('withBearerAuth', () => { - it('should inject Authorization header when token is available', async () => { - const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); - const getToken: TokenProvider = async () => 'test-token-123'; - - const authedFetch = withBearerAuth(getToken, mockFetch); - await authedFetch('https://example.com/api', { method: 'POST' }); - - expect(mockFetch).toHaveBeenCalledOnce(); - const [url, init] = mockFetch.mock.calls[0]!; - expect(url).toBe('https://example.com/api'); - expect(new Headers(init.headers).get('Authorization')).toBe('Bearer test-token-123'); - }); - - it('should not inject Authorization header when token is undefined', async () => { - const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); - const getToken: TokenProvider = async () => undefined; - - const authedFetch = withBearerAuth(getToken, mockFetch); - await authedFetch('https://example.com/api', { method: 'POST' }); - - expect(mockFetch).toHaveBeenCalledOnce(); - const [, init] = mockFetch.mock.calls[0]!; - expect(new Headers(init?.headers).has('Authorization')).toBe(false); - }); - - it('should preserve existing headers', async () => { - const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); - const getToken: TokenProvider = async () => 'my-token'; - - const authedFetch = withBearerAuth(getToken, mockFetch); - await authedFetch('https://example.com/api', { - headers: { 'Content-Type': 'application/json', 'X-Custom': 'value' } - }); - - const [, init] = mockFetch.mock.calls[0]!; - const headers = new Headers(init.headers); - expect(headers.get('Authorization')).toBe('Bearer my-token'); - expect(headers.get('Content-Type')).toBe('application/json'); - expect(headers.get('X-Custom')).toBe('value'); - }); - - it('should call getToken on every request', async () => { - const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); - let callCount = 0; - const getToken: TokenProvider = async () => `token-${++callCount}`; - - const authedFetch = withBearerAuth(getToken, mockFetch); - await authedFetch('https://example.com/1'); - await authedFetch('https://example.com/2'); - - expect(new Headers(mockFetch.mock.calls[0]![1]!.headers).get('Authorization')).toBe('Bearer token-1'); - expect(new Headers(mockFetch.mock.calls[1]![1]!.headers).get('Authorization')).toBe('Bearer token-2'); - }); -}); - -describe('StreamableHTTPClientTransport with tokenProvider', () => { +describe('StreamableHTTPClientTransport with AuthProvider', () => { let transport: StreamableHTTPClientTransport; afterEach(async () => { @@ -70,48 +13,28 @@ describe('StreamableHTTPClientTransport with tokenProvider', () => { vi.clearAllMocks(); }); - it('should set Authorization header from tokenProvider', async () => { - const tokenProvider: TokenProvider = vi.fn(async () => 'my-bearer-token'); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'test', params: {}, id: 'test-id' }; + + it('should set Authorization header from AuthProvider.token()', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => 'my-bearer-token') }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); - const message: JSONRPCMessage = { - jsonrpc: '2.0', - method: 'test', - params: {}, - id: 'test-id' - }; - - (globalThis.fetch as Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers() - }); + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); await transport.send(message); - expect(tokenProvider).toHaveBeenCalled(); + expect(authProvider.token).toHaveBeenCalled(); const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; expect(init.headers.get('Authorization')).toBe('Bearer my-bearer-token'); }); - it('should not set Authorization header when tokenProvider returns undefined', async () => { - const tokenProvider: TokenProvider = vi.fn(async () => undefined); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + it('should not set Authorization header when token() returns undefined', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => undefined) }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); - const message: JSONRPCMessage = { - jsonrpc: '2.0', - method: 'test', - params: {}, - id: 'test-id' - }; - - (globalThis.fetch as Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers() - }); + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); await transport.send(message); @@ -119,18 +42,11 @@ describe('StreamableHTTPClientTransport with tokenProvider', () => { expect(init.headers.has('Authorization')).toBe(false); }); - it('should throw UnauthorizedError on 401 when using tokenProvider', async () => { - const tokenProvider: TokenProvider = vi.fn(async () => 'rejected-token'); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + it('should throw UnauthorizedError on 401 when onUnauthorized is not provided', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => 'rejected-token') }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); - const message: JSONRPCMessage = { - jsonrpc: '2.0', - method: 'test', - params: {}, - id: 'test-id' - }; - (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: false, status: 401, @@ -139,70 +55,125 @@ describe('StreamableHTTPClientTransport with tokenProvider', () => { }); await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); - expect(tokenProvider).toHaveBeenCalledTimes(1); + expect(authProvider.token).toHaveBeenCalledTimes(1); }); - it('should prefer authProvider over tokenProvider when both are set', async () => { - const tokenProvider: TokenProvider = vi.fn(async () => 'token-provider-value'); - const authProvider = { - get redirectUrl() { - return 'http://localhost/callback'; - }, - get clientMetadata() { - return { redirect_uris: ['http://localhost/callback'] }; - }, - clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-secret' })), - tokens: vi.fn(() => ({ access_token: 'auth-provider-value', token_type: 'bearer' })), - saveTokens: vi.fn(), - redirectToAuthorization: vi.fn(), - saveCodeVerifier: vi.fn(), - codeVerifier: vi.fn() + it('should call onUnauthorized and retry once on 401', async () => { + let currentToken = 'old-token'; + const authProvider: AuthProvider = { + token: vi.fn(async () => currentToken), + onUnauthorized: vi.fn(async () => { + currentToken = 'new-token'; + }) }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + await transport.send(message); + + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(authProvider.token).toHaveBeenCalledTimes(2); + const [, retryInit] = (globalThis.fetch as Mock).mock.calls[1]!; + expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); + }); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider, tokenProvider }); + it('should throw UnauthorizedError if retry after onUnauthorized also gets 401', async () => { + const authProvider: AuthProvider = { + token: vi.fn(async () => 'still-bad'), + onUnauthorized: vi.fn(async () => {}) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); - const message: JSONRPCMessage = { - jsonrpc: '2.0', - method: 'test', - params: {}, - id: 'test-id' + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + }); + + it('should reset retry guard when onUnauthorized throws, allowing retry on next send', async () => { + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized: vi.fn().mockRejectedValueOnce(new Error('transient network error')).mockResolvedValueOnce(undefined) }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); - (globalThis.fetch as Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers() - }); + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); - await transport.send(message); + // First send: onUnauthorized throws transient error + await expect(transport.send(message)).rejects.toThrow('transient network error'); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); - // authProvider should be used, not tokenProvider - expect(tokenProvider).not.toHaveBeenCalled(); - const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; - expect(init.headers.get('Authorization')).toBe('Bearer auth-provider-value'); + // Second send: flag should be reset, so onUnauthorized gets a second chance + await transport.send(message); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); }); - it('should work with no auth at all', async () => { + it('should work with no authProvider at all', async () => { transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); vi.spyOn(globalThis, 'fetch'); - const message: JSONRPCMessage = { - jsonrpc: '2.0', - method: 'test', - params: {}, - id: 'test-id' - }; - - (globalThis.fetch as Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers() - }); + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); await transport.send(message); const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; expect(init.headers.has('Authorization')).toBe(false); }); + + it('should throw when finishAuth is called with a non-OAuth AuthProvider', async () => { + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + + await expect(transport.finishAuth('auth-code')).rejects.toThrow('finishAuth requires an OAuthClientProvider'); + }); + + it('should throw UnauthorizedError on GET-SSE 401 with no onUnauthorized (via resumeStream)', async () => { + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 401, + headers: new Headers(), + text: async () => 'unauthorized' + }); + + await expect(transport.resumeStream('last-event-id')).rejects.toThrow(UnauthorizedError); + }); + + it('should call onUnauthorized and retry on GET-SSE 401 (via resumeStream)', async () => { + let currentToken = 'old-token'; + const authProvider: AuthProvider = { + token: vi.fn(async () => currentToken), + onUnauthorized: vi.fn(async () => { + currentToken = 'new-token'; + }) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + // First GET: 401. Second GET (retry): 405 (server doesn't offer SSE — clean exit) + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 405, headers: new Headers(), text: async () => '' }); + + await transport.resumeStream('last-event-id'); + + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(authProvider.token).toHaveBeenCalledTimes(2); + const [, retryInit] = (globalThis.fetch as Mock).mock.calls[1]!; + expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); + }); }); From 65b5099d1064d4489c1c097337fd2254f8dab0c1 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 19 Mar 2026 21:44:00 +0000 Subject: [PATCH 3/5] refactor: adapt OAuthClientProvider at transport boundary (non-breaking) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Alternative to the breaking 'extends AuthProvider' approach. Instead of requiring OAuthClientProvider implementations to add token() + onUnauthorized(), the transport constructor classifies the authProvider option once and adapts OAuth providers via adaptOAuthProvider(). - OAuthClientProvider interface is unchanged from v1 - Transport option: authProvider?: AuthProvider | OAuthClientProvider - Constructor: if OAuth, store both original (for finishAuth/403) and adapted (for _commonHeaders/401) — classification happens once, no runtime type guards in the hot path - 4 built-in providers no longer need token()/onUnauthorized() - migration.md/migration-SKILL.md entries removed — nothing to migrate - Changeset downgraded to minor Net -142 lines vs the breaking approach. Same transport simplification, zero migration burden. Duck-typing via isOAuthClientProvider() ('tokens' + 'clientMetadata' in provider) at construction only. --- .changeset/token-provider-composable-auth.md | 19 ++++--- docs/migration-SKILL.md | 33 ------------ docs/migration.md | 38 -------------- .../client/src/simpleOAuthClientProvider.ts | 17 +------ packages/client/src/client/auth.examples.ts | 10 +--- packages/client/src/client/auth.ts | 50 ++++++++++--------- packages/client/src/client/authExtensions.ts | 35 +------------ packages/client/src/client/sse.ts | 18 ++++--- packages/client/src/client/streamableHttp.ts | 29 +++++++---- packages/client/test/client/auth.test.ts | 10 ---- .../client/test/client/middleware.test.ts | 4 -- packages/client/test/client/sse.test.ts | 27 +++------- .../client/test/client/streamableHttp.test.ts | 8 +-- 13 files changed, 78 insertions(+), 220 deletions(-) diff --git a/.changeset/token-provider-composable-auth.md b/.changeset/token-provider-composable-auth.md index c4ea7f5e3..f5c064e7f 100644 --- a/.changeset/token-provider-composable-auth.md +++ b/.changeset/token-provider-composable-auth.md @@ -1,17 +1,16 @@ --- -'@modelcontextprotocol/client': major +'@modelcontextprotocol/client': minor --- -Unify client auth around a minimal `AuthProvider` interface - -**Breaking:** Transport `authProvider` option now accepts the new minimal `AuthProvider` interface instead of being typed as `OAuthClientProvider`. `OAuthClientProvider` now extends `AuthProvider`, so most existing code continues to work — but custom implementations must add a `token()` method. +Add `AuthProvider` for composable bearer-token auth; transports adapt `OAuthClientProvider` automatically - New `AuthProvider` interface: `{ token(): Promise; onUnauthorized?(ctx): Promise }`. Transports call `token()` before every request and `onUnauthorized()` on 401 (then retry once). -- `OAuthClientProvider` extends `AuthProvider`. Custom implementations must add `token()` (typically `return (await this.tokens())?.access_token`) and optionally `onUnauthorized()` (typically `return handleOAuthUnauthorized(this, ctx)`). -- Built-in providers (`ClientCredentialsProvider`, `PrivateKeyJwtProvider`, `StaticPrivateKeyJwtProvider`, `CrossAppAccessProvider`) implement both methods — existing user code is unchanged. -- New `handleOAuthUnauthorized(provider, ctx)` helper runs the standard OAuth flow from `onUnauthorized`. -- New `isOAuthClientProvider()` type guard for gating OAuth-specific transport features like `finishAuth()`. -- Transports no longer inline OAuth orchestration — ~50 lines of `auth()` calls, WWW-Authenticate parsing, and circuit-breaker state moved into `onUnauthorized()` implementations. +- Transport `authProvider` option now accepts `AuthProvider | OAuthClientProvider`. OAuth providers are adapted internally via `adaptOAuthProvider()` — no changes needed to existing `OAuthClientProvider` implementations. +- For simple bearer tokens (API keys, gateway-managed tokens, service accounts): `{ authProvider: { token: async () => myKey } }` — one-line object literal, no class. +- New `adaptOAuthProvider(provider)` export for explicit adaptation. +- New `handleOAuthUnauthorized(provider, ctx)` helper — the standard OAuth `onUnauthorized` behavior. +- New `isOAuthClientProvider()` type guard. +- New `UnauthorizedContext` type. - Exported previously-internal auth helpers for building custom flows: `applyBasicAuth`, `applyPostAuth`, `applyPublicAuth`, `executeTokenRequest`. -See `docs/migration.md` for before/after examples. +Transports are simplified internally — ~50 lines of inline OAuth orchestration (auth() calls, WWW-Authenticate parsing, circuit-breaker state) moved into the adapter's `onUnauthorized()` implementation. `OAuthClientProvider` itself is unchanged. diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index cdec2b9a9..cfd540ae3 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -203,39 +203,6 @@ import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; if (error instanceof OAuthError && error.code === OAuthErrorCode.InvalidClient) { ... } ``` -### Client `OAuthClientProvider` now extends `AuthProvider` - -Transport `authProvider` options now accept the minimal `AuthProvider` interface. `OAuthClientProvider` extends it, so built-in providers work unchanged — custom implementations must add `token()`. - -| v1 pattern | v2 equivalent | -| ----------------------------------------------------- | --------------------------------------------------------------------------- | -| `authProvider?: OAuthClientProvider` (option type) | `authProvider?: AuthProvider` (accepts `OAuthClientProvider` via extension) | -| Transport reads `authProvider.tokens()?.access_token` | Transport calls `authProvider.token()` | -| Transport inlines `auth()` on 401 | Transport calls `authProvider.onUnauthorized()` then retries once | -| `_hasCompletedAuthFlow` circuit breaker | `_authRetryInFlight` circuit breaker | -| N/A | `handleOAuthUnauthorized(provider, ctx)` — standard `onUnauthorized` impl | -| N/A | `isOAuthClientProvider(provider)` — type guard | -| N/A | `UnauthorizedContext` — `{ response, serverUrl, fetchFn }` | - -**For custom `OAuthClientProvider` implementations**, add both methods (both required — TypeScript enforces this): - -```typescript -async token(): Promise { - return (await this.tokens())?.access_token; -} - -async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); -} -``` - -**For simple bearer tokens** (previously required stubbing 8 `OAuthClientProvider` members): - -```typescript -// v2: one-liner -const authProvider: AuthProvider = { token: async () => process.env.API_KEY }; -``` - **Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). diff --git a/docs/migration.md b/docs/migration.md index e541bb01b..14f888314 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -648,44 +648,6 @@ The new design: - `ProtocolError` with `ProtocolErrorCode`: For errors that are serialized and sent as JSON-RPC error responses - `SdkError` with `SdkErrorCode`: For local errors that are thrown/rejected locally and never leave the SDK -### Client `authProvider` unified around `AuthProvider` - -Transport `authProvider` options now accept the minimal `AuthProvider` interface rather than being typed as `OAuthClientProvider`. `OAuthClientProvider` extends `AuthProvider`, so built-in providers and most existing code continue to work unchanged — but custom -`OAuthClientProvider` implementations must add a `token()` method. - -**What changed:** transports now call `authProvider.token()` before every request (instead of `authProvider.tokens()?.access_token`), and call `authProvider.onUnauthorized()` on 401 (instead of inlining OAuth orchestration). One code path handles both simple bearer tokens and -full OAuth. - -**If you implement `OAuthClientProvider` directly** (the interactive browser-redirect pattern), add: - -```ts -class MyProvider implements OAuthClientProvider { - // ...existing 8 required members... - - // Required: return the current access token - async token(): Promise { - return (await this.tokens())?.access_token; - } - - // Required: runs the OAuth flow on 401 — without this, 401 throws with no recovery - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } -} -``` - -**If you use `ClientCredentialsProvider`, `PrivateKeyJwtProvider`, `StaticPrivateKeyJwtProvider`, or `CrossAppAccessProvider`** — no change. These already implement both methods. - -**If you have simple bearer tokens** (API keys, gateway tokens, externally-managed tokens), you can now skip `OAuthClientProvider` entirely: - -```ts -// Before: had to implement 8 OAuthClientProvider members with no-op stubs -// After: -const transport = new StreamableHTTPClientTransport(url, { - authProvider: { token: async () => process.env.API_KEY } -}); -``` - ### OAuth error refactoring The OAuth error classes have been consolidated into a single `OAuthError` class with an `OAuthErrorCode` enum. diff --git a/examples/client/src/simpleOAuthClientProvider.ts b/examples/client/src/simpleOAuthClientProvider.ts index 6248d1f90..96655c9f6 100644 --- a/examples/client/src/simpleOAuthClientProvider.ts +++ b/examples/client/src/simpleOAuthClientProvider.ts @@ -1,11 +1,4 @@ -import type { - OAuthClientInformationMixed, - OAuthClientMetadata, - OAuthClientProvider, - OAuthTokens, - UnauthorizedContext -} from '@modelcontextprotocol/client'; -import { handleOAuthUnauthorized } from '@modelcontextprotocol/client'; +import type { OAuthClientInformationMixed, OAuthClientMetadata, OAuthClientProvider, OAuthTokens } from '@modelcontextprotocol/client'; /** * In-memory OAuth client provider for demonstration purposes @@ -31,14 +24,6 @@ export class InMemoryOAuthClientProvider implements OAuthClientProvider { private _onRedirect: (url: URL) => void; - async token(): Promise { - return this._tokens?.access_token; - } - - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } - get redirectUrl(): string | URL { return this._redirectUrl; } diff --git a/packages/client/src/client/auth.examples.ts b/packages/client/src/client/auth.examples.ts index 15b6487a7..17c04e6a0 100644 --- a/packages/client/src/client/auth.examples.ts +++ b/packages/client/src/client/auth.examples.ts @@ -9,8 +9,8 @@ import type { AuthorizationServerMetadata } from '@modelcontextprotocol/core'; -import type { OAuthClientProvider, UnauthorizedContext } from './auth.js'; -import { fetchToken, handleOAuthUnauthorized } from './auth.js'; +import type { OAuthClientProvider } from './auth.js'; +import { fetchToken } from './auth.js'; /** * Base class providing no-op implementations of required OAuthClientProvider methods. @@ -29,12 +29,6 @@ abstract class MyProviderBase implements OAuthClientProvider { tokens(): undefined { return; } - async token(): Promise { - return undefined; - } - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } saveTokens() { return Promise.resolve(); } diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index bca45ad66..d26bb5727 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -60,8 +60,8 @@ export interface UnauthorizedContext { * const authProvider: AuthProvider = { token: async () => process.env.API_KEY }; * ``` * - * For OAuth flows, use {@linkcode OAuthClientProvider} which extends this interface, - * or one of the built-in providers ({@linkcode index.ClientCredentialsProvider | ClientCredentialsProvider} etc.). + * For OAuth flows, pass an {@linkcode OAuthClientProvider} directly — transports + * accept either shape and adapt OAuth providers automatically via {@linkcode adaptOAuthProvider}. */ export interface AuthProvider { /** @@ -82,19 +82,17 @@ export interface AuthProvider { } /** - * Type guard: checks whether an `AuthProvider` is a full `OAuthClientProvider`. - * Use this to gate OAuth-specific transport features like `finishAuth()` and - * 403 scope upscoping. + * Type guard distinguishing `OAuthClientProvider` from a minimal `AuthProvider`. + * Transports use this at construction time to classify the `authProvider` option. */ -export function isOAuthClientProvider(provider: AuthProvider | undefined): provider is OAuthClientProvider { +export function isOAuthClientProvider(provider: AuthProvider | OAuthClientProvider | undefined): provider is OAuthClientProvider { return provider !== undefined && 'tokens' in provider && 'clientMetadata' in provider; } /** - * Default `onUnauthorized` implementation for OAuth providers: extracts + * Standard `onUnauthorized` behavior for OAuth providers: extracts * `WWW-Authenticate` parameters from the 401 response and runs {@linkcode auth}. - * Built-in providers ({@linkcode index.ClientCredentialsProvider | ClientCredentialsProvider} etc.) - * delegate to this. Custom `OAuthClientProvider` implementations can do the same. + * Used by {@linkcode adaptOAuthProvider} to bridge `OAuthClientProvider` to `AuthProvider`. */ export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx: UnauthorizedContext): Promise { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(ctx.response); @@ -109,6 +107,22 @@ export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx } } +/** + * Adapts an `OAuthClientProvider` to the minimal `AuthProvider` interface that + * transports consume. Called once at transport construction — the transport stores + * the adapted provider for `_commonHeaders()` and 401 handling, while keeping the + * original `OAuthClientProvider` for OAuth-specific paths (`finishAuth()`, 403 upscoping). + */ +export function adaptOAuthProvider(provider: OAuthClientProvider): AuthProvider { + return { + token: async () => { + const tokens = await provider.tokens(); + return tokens?.access_token; + }, + onUnauthorized: async ctx => handleOAuthUnauthorized(provider, ctx) + }; +} + /** * Implements an end-to-end OAuth client to be used with one MCP server. * @@ -116,21 +130,11 @@ export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx * meaning of which is application-defined. Tokens, authorization codes, and * code verifiers should not cross different sessions. * - * Extends {@linkcode AuthProvider} — implementations must provide `token()` - * (typically `return (await this.tokens())?.access_token`) and `onUnauthorized()` - * (typically `return handleOAuthUnauthorized(this, ctx)`). Without `onUnauthorized()`, - * 401 responses throw immediately with no token refresh or reauth. + * Transports accept `OAuthClientProvider` directly via the `authProvider` option — + * they adapt it to {@linkcode AuthProvider} internally via {@linkcode adaptOAuthProvider}. + * No changes are needed to existing implementations. */ -export interface OAuthClientProvider extends AuthProvider { - /** - * Runs the OAuth re-authentication flow on 401. Required on `OAuthClientProvider` - * (optional on the base `AuthProvider`) because OAuth providers that omit this lose - * all 401 recovery — no token refresh, no redirect to authorization. - * - * Most implementations should delegate: `return handleOAuthUnauthorized(this, ctx)`. - */ - onUnauthorized(ctx: UnauthorizedContext): Promise; - +export interface OAuthClientProvider { /** * The URL to redirect the user agent to after authorization. * Return `undefined` for non-interactive flows that don't require user interaction diff --git a/packages/client/src/client/authExtensions.ts b/packages/client/src/client/authExtensions.ts index 7508298b7..ae614f7ba 100644 --- a/packages/client/src/client/authExtensions.ts +++ b/packages/client/src/client/authExtensions.ts @@ -8,8 +8,7 @@ import type { FetchLike, OAuthClientInformation, OAuthClientMetadata, OAuthTokens } from '@modelcontextprotocol/core'; import type { CryptoKey, JWK } from 'jose'; -import type { AddClientAuthentication, OAuthClientProvider, UnauthorizedContext } from './auth.js'; -import { handleOAuthUnauthorized } from './auth.js'; +import type { AddClientAuthentication, OAuthClientProvider } from './auth.js'; /** * Helper to produce a `private_key_jwt` client authentication function. @@ -151,14 +150,6 @@ export class ClientCredentialsProvider implements OAuthClientProvider { }; } - async token(): Promise { - return this._tokens?.access_token; - } - - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } - get redirectUrl(): undefined { return undefined; } @@ -278,14 +269,6 @@ export class PrivateKeyJwtProvider implements OAuthClientProvider { }); } - async token(): Promise { - return this._tokens?.access_token; - } - - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } - get redirectUrl(): undefined { return undefined; } @@ -383,14 +366,6 @@ export class StaticPrivateKeyJwtProvider implements OAuthClientProvider { }; } - async token(): Promise { - return this._tokens?.access_token; - } - - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } - get redirectUrl(): undefined { return undefined; } @@ -589,14 +564,6 @@ export class CrossAppAccessProvider implements OAuthClientProvider { this._fetchFn = options.fetchFn ?? fetch; } - async token(): Promise { - return this._tokens?.access_token; - } - - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } - get redirectUrl(): undefined { return undefined; } diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index 025c785ea..c613bd2b1 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -3,8 +3,8 @@ import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders, SdkError, import type { ErrorEvent, EventSourceInit } from 'eventsource'; import { EventSource } from 'eventsource'; -import type { AuthProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; +import type { AuthProvider, OAuthClientProvider } from './auth.js'; +import { adaptOAuthProvider, auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; export class SseError extends Error { constructor( @@ -35,7 +35,7 @@ export type SSEClientTransportOptions = { * Interactive flows: after {@linkcode UnauthorizedError}, redirect the user, then call * {@linkcode SSEClientTransport.finishAuth | finishAuth} with the authorization code before reconnecting. */ - authProvider?: AuthProvider; + authProvider?: AuthProvider | OAuthClientProvider; /** * Customizes the initial SSE request to the server (the request that begins the stream). @@ -73,6 +73,7 @@ export class SSEClientTransport implements Transport { private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; private _authProvider?: AuthProvider; + private _oauthProvider?: OAuthClientProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _protocolVersion?: string; @@ -87,7 +88,12 @@ export class SSEClientTransport implements Transport { this._scope = undefined; this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; + if (isOAuthClientProvider(opts?.authProvider)) { + this._oauthProvider = opts.authProvider; + this._authProvider = adaptOAuthProvider(opts.authProvider); + } else { + this._authProvider = opts?.authProvider; + } this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); } @@ -215,11 +221,11 @@ export class SSEClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!isOAuthClientProvider(this._authProvider)) { + if (!this._oauthProvider) { throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index 6a3b6dd00..5d1d4612c 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -13,8 +13,8 @@ import { } from '@modelcontextprotocol/core'; import { EventSourceParserStream } from 'eventsource-parser/stream'; -import type { AuthProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; +import type { AuthProvider, OAuthClientProvider } from './auth.js'; +import { adaptOAuthProvider, auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { @@ -94,11 +94,12 @@ export type StreamableHTTPClientTransportOptions = { * For simple bearer tokens: `{ token: async () => myApiKey }`. * * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation - * (which extends `AuthProvider`). Interactive flows: after {@linkcode UnauthorizedError}, redirect the - * user, then call {@linkcode StreamableHTTPClientTransport.finishAuth | finishAuth} with the authorization - * code before reconnecting. + * directly — the transport adapts it to `AuthProvider` internally. Interactive flows: after + * {@linkcode UnauthorizedError}, redirect the user, then call + * {@linkcode StreamableHTTPClientTransport.finishAuth | finishAuth} with the authorization code before + * reconnecting. */ - authProvider?: AuthProvider; + authProvider?: AuthProvider | OAuthClientProvider; /** * Customizes HTTP requests to the server. @@ -134,6 +135,7 @@ export class StreamableHTTPClientTransport implements Transport { private _scope?: string; private _requestInit?: RequestInit; private _authProvider?: AuthProvider; + private _oauthProvider?: OAuthClientProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _sessionId?: string; @@ -153,7 +155,12 @@ export class StreamableHTTPClientTransport implements Transport { this._resourceMetadataUrl = undefined; this._scope = undefined; this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; + if (isOAuthClientProvider(opts?.authProvider)) { + this._oauthProvider = opts.authProvider; + this._authProvider = adaptOAuthProvider(opts.authProvider); + } else { + this._authProvider = opts?.authProvider; + } this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); this._sessionId = opts?.sessionId; @@ -427,11 +434,11 @@ export class StreamableHTTPClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!isOAuthClientProvider(this._authProvider)) { + if (!this._oauthProvider) { throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, @@ -514,7 +521,7 @@ export class StreamableHTTPClientTransport implements Transport { const text = await response.text?.().catch(() => null); - if (response.status === 403 && isOAuthClientProvider(this._authProvider)) { + if (response.status === 403 && this._oauthProvider) { const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response); if (error === 'insufficient_scope') { @@ -538,7 +545,7 @@ export class StreamableHTTPClientTransport implements Transport { // Mark that upscoping was tried. this._lastUpscopingHeader = wwwAuthHeader ?? undefined; - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, scope: this._scope, diff --git a/packages/client/test/client/auth.test.ts b/packages/client/test/client/auth.test.ts index 12d6793af..9d8f5cf6b 100644 --- a/packages/client/test/client/auth.test.ts +++ b/packages/client/test/client/auth.test.ts @@ -1038,8 +1038,6 @@ describe('OAuth Authorization', () => { client_secret: 'test-client-secret' }), tokens: vi.fn().mockResolvedValue(undefined), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -1985,8 +1983,6 @@ describe('OAuth Authorization', () => { }, clientInformation: vi.fn(), tokens: vi.fn(), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -2060,8 +2056,6 @@ describe('OAuth Authorization', () => { client_id: 'client-id' }), tokens: vi.fn().mockResolvedValue(undefined), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn().mockResolvedValue(undefined), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -2977,8 +2971,6 @@ describe('OAuth Authorization', () => { client_secret: 'secret123' }), tokens: vi.fn().mockResolvedValue(undefined), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -3432,8 +3424,6 @@ describe('OAuth Authorization', () => { clientInformation: vi.fn().mockResolvedValue(undefined), saveClientInformation: vi.fn().mockResolvedValue(undefined), tokens: vi.fn().mockResolvedValue(undefined), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn().mockResolvedValue(undefined), redirectToAuthorization: vi.fn().mockResolvedValue(undefined), saveCodeVerifier: vi.fn().mockResolvedValue(undefined), diff --git a/packages/client/test/client/middleware.test.ts b/packages/client/test/client/middleware.test.ts index d2084af99..64bbfa673 100644 --- a/packages/client/test/client/middleware.test.ts +++ b/packages/client/test/client/middleware.test.ts @@ -33,8 +33,6 @@ describe('withOAuth', () => { return { redirect_uris: ['http://localhost/callback'] }; }, tokens: vi.fn(), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), clientInformation: vi.fn(), redirectToAuthorization: vi.fn(), @@ -761,8 +759,6 @@ describe('Integration Tests', () => { return { redirect_uris: ['http://localhost/callback'] }; }, tokens: vi.fn(), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), clientInformation: vi.fn(), redirectToAuthorization: vi.fn(), diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index 3e1a3f895..10fcd76bc 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -8,7 +8,7 @@ import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import type { Mock, Mocked, MockedFunction, MockInstance } from 'vitest'; import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; -import { handleOAuthUnauthorized, UnauthorizedError } from '../../src/client/auth.js'; +import { UnauthorizedError } from '../../src/client/auth.js'; import { SSEClientTransport } from '../../src/client/sse.js'; /** @@ -430,15 +430,11 @@ describe('SSEClientTransport', () => { }, clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), tokens: vi.fn(), - token: vi.fn(async () => undefined), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn(), - invalidateCredentials: vi.fn(), - onUnauthorized: vi.fn(async ctx => { - await handleOAuthUnauthorized(mockAuthProvider, ctx); - }) + invalidateCredentials: vi.fn() }; }); @@ -447,7 +443,6 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); - mockAuthProvider.token.mockResolvedValue('test-token'); transport = new SSEClientTransport(resourceBaseUrl, { authProvider: mockAuthProvider @@ -456,7 +451,7 @@ describe('SSEClientTransport', () => { await transport.start(); expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); - expect(mockAuthProvider.token).toHaveBeenCalled(); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); }); it('attaches custom header from provider on initial SSE connection', async () => { @@ -464,7 +459,6 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); - mockAuthProvider.token.mockResolvedValue('test-token'); const customHeaders = { 'X-Custom-Header': 'custom-value' }; @@ -480,7 +474,7 @@ describe('SSEClientTransport', () => { expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); expect(lastServerRequest.headers['x-custom-header']).toBe('custom-value'); - expect(mockAuthProvider.token).toHaveBeenCalled(); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); }); it('attaches auth header from provider on POST requests', async () => { @@ -488,7 +482,6 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); - mockAuthProvider.token.mockResolvedValue('test-token'); transport = new SSEClientTransport(resourceBaseUrl, { authProvider: mockAuthProvider @@ -506,7 +499,7 @@ describe('SSEClientTransport', () => { await transport.send(message); expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); - expect(mockAuthProvider.token).toHaveBeenCalled(); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); }); it('attempts auth flow on 401 during SSE connection', async () => { @@ -638,7 +631,6 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); - mockAuthProvider.token.mockResolvedValue('test-token'); const customHeaders = { 'X-Custom-Header': 'custom-value' @@ -674,7 +666,6 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); - mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -804,7 +795,6 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); - mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -958,7 +948,6 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); - mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -1229,15 +1218,11 @@ describe('SSEClientTransport', () => { }, clientInformation: vi.fn().mockResolvedValue(clientInfo), tokens: vi.fn().mockResolvedValue(tokens), - token: vi.fn(async () => tokens?.access_token), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn().mockResolvedValue('test-verifier'), - invalidateCredentials: vi.fn(), - onUnauthorized: vi.fn(async ctx => { - await handleOAuthUnauthorized(mockAuthProvider, ctx); - }) + invalidateCredentials: vi.fn() }; }; diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index abebaed33..ddd24608f 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -3,7 +3,7 @@ import { OAuthError, OAuthErrorCode, SdkError, SdkErrorCode } from '@modelcontex import type { Mock, Mocked } from 'vitest'; import type { OAuthClientProvider } from '../../src/client/auth.js'; -import { handleOAuthUnauthorized, UnauthorizedError } from '../../src/client/auth.js'; +import { UnauthorizedError } from '../../src/client/auth.js'; import type { StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js'; import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; @@ -21,15 +21,11 @@ describe('StreamableHTTPClientTransport', () => { }, clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), tokens: vi.fn(), - token: vi.fn(async () => undefined), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn(), - invalidateCredentials: vi.fn(), - onUnauthorized: vi.fn(async ctx => { - await handleOAuthUnauthorized(mockAuthProvider, ctx); - }) + invalidateCredentials: vi.fn() }; transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider: mockAuthProvider }); vi.spyOn(globalThis, 'fetch'); From f2c32e8b0c743e32ea378a9a1abd69d6fc3ea1e5 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 19 Mar 2026 22:05:49 +0000 Subject: [PATCH 4/5] fix: address round-4 review comments on 401 handling Four fixes from claude[bot] review on the AuthProvider approach: 1. Drain 401 response body after onUnauthorized() succeeds, before the retry. Unconsumed bodies block socket recycling in undici. All three 401 sites now drain before return. 2. _startOrAuthSse() 401 retry was return await, causing onerror to fire twice (recursive call's catch + outer catch both fire). Changed to return (not awaited) matching the send() pattern. Removed the try/finally, added flag reset to success path + outer catch instead. 3. Migration docs still referenced SdkErrorCode.ClientHttpAuthentication for the 401-after-auth case, but that throw site was replaced by _authRetryInFlight which throws UnauthorizedError. Updated both migration.md and migration-SKILL.md. 4. Pre-existing: 403 upscoping auth() call passed this._fetch instead of this._fetchWithInit, dropping custom requestInit options during token requests. All other auth() calls in this transport already used _fetchWithInit. --- docs/migration-SKILL.md | 3 +-- docs/migration.md | 9 ++++---- packages/client/src/client/sse.ts | 1 + packages/client/src/client/streamableHttp.ts | 23 ++++++++++---------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index cfd540ae3..6add1bc02 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -116,7 +116,7 @@ Two error classes now exist: | Invalid params (server response) | `McpError` with `ErrorCode.InvalidParams` | `ProtocolError` with `ProtocolErrorCode.InvalidParams` | | HTTP transport error | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttp*` | | Failed to open SSE stream | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToOpenStream` | -| 401 after auth flow | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpAuthentication` | +| 401 after auth flow | `StreamableHTTPError` | `UnauthorizedError` | | 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | | Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | | Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | @@ -131,7 +131,6 @@ New `SdkErrorCode` enum values: - `SdkErrorCode.ConnectionClosed` = `'CONNECTION_CLOSED'` - `SdkErrorCode.SendFailed` = `'SEND_FAILED'` - `SdkErrorCode.ClientHttpNotImplemented` = `'CLIENT_HTTP_NOT_IMPLEMENTED'` -- `SdkErrorCode.ClientHttpAuthentication` = `'CLIENT_HTTP_AUTHENTICATION'` - `SdkErrorCode.ClientHttpForbidden` = `'CLIENT_HTTP_FORBIDDEN'` - `SdkErrorCode.ClientHttpUnexpectedContent` = `'CLIENT_HTTP_UNEXPECTED_CONTENT'` - `SdkErrorCode.ClientHttpFailedToOpenStream` = `'CLIENT_HTTP_FAILED_TO_OPEN_STREAM'` diff --git a/docs/migration.md b/docs/migration.md index 14f888314..6ec9616dc 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -585,7 +585,7 @@ The new `SdkErrorCode` enum contains string-valued codes for local SDK errors: | `SdkErrorCode.ConnectionClosed` | Connection was closed | | `SdkErrorCode.SendFailed` | Failed to send message | | `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | -| `SdkErrorCode.ClientHttpAuthentication` | Server returned 401 after successful auth | +| `UnauthorizedError` (thrown, not `SdkError`) | Server returned 401 after re-auth attempt | | `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | | `SdkErrorCode.ClientHttpUnexpectedContent` | Unexpected content type in HTTP response | | `SdkErrorCode.ClientHttpFailedToOpenStream` | Failed to open SSE stream | @@ -617,11 +617,10 @@ import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; try { await transport.send(message); } catch (error) { - if (error instanceof SdkError) { + if (error instanceof UnauthorizedError) { + console.log('Token rejected — reconnect with fresh credentials'); + } else if (error instanceof SdkError) { switch (error.code) { - case SdkErrorCode.ClientHttpAuthentication: - console.log('Auth failed after completing auth flow'); - break; case SdkErrorCode.ClientHttpForbidden: console.log('Forbidden after upscoping attempt'); break; diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index c613bd2b1..a8646adea 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -275,6 +275,7 @@ export class SSEClientTransport implements Transport { serverUrl: this._url, fetchFn: this._fetchWithInit }); + await response.text?.().catch(() => {}); // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index 5d1d4612c..aad670c97 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -220,16 +220,14 @@ export class StreamableHTTPClientTransport implements Transport { if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { this._authRetryInFlight = true; - try { - await this._authProvider.onUnauthorized({ - response, - serverUrl: this._url, - fetchFn: this._fetchWithInit - }); - return await this._startOrAuthSse(options); - } finally { - this._authRetryInFlight = false; - } + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit + }); + await response.text?.().catch(() => {}); + // Purposely _not_ awaited, so we don't call onerror twice + return this._startOrAuthSse(options); } if (this._authProvider) { await response.text?.().catch(() => {}); @@ -251,8 +249,10 @@ export class StreamableHTTPClientTransport implements Transport { }); } + this._authRetryInFlight = false; this._handleSseStream(response.body, options, true); } catch (error) { + this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } @@ -510,6 +510,7 @@ export class StreamableHTTPClientTransport implements Transport { serverUrl: this._url, fetchFn: this._fetchWithInit }); + await response.text?.().catch(() => {}); // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } @@ -549,7 +550,7 @@ export class StreamableHTTPClientTransport implements Transport { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, scope: this._scope, - fetchFn: this._fetch + fetchFn: this._fetchWithInit }); if (result !== 'AUTHORIZED') { From 3aa0cd6917c90a956cdcf7c37da51e6180b59387 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 19 Mar 2026 22:11:34 +0000 Subject: [PATCH 5/5] fix: restore SdkError(ClientHttpAuthentication) for circuit-breaker case MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 401-after-re-auth case (circuit breaker trips) should throw a distinct error from the normal 'token rejected' case: - First 401 with no onUnauthorized → UnauthorizedError — caller re-auths externally and reconnects - Second 401 after onUnauthorized succeeded → SdkError with ClientHttpAuthentication — server is misbehaving, don't blindly retry, escalate The previous commit collapsed these into UnauthorizedError, which risks callers catching it, re-authing, and looping. Restored the SdkError throw at all three 401 sites when _authRetryInFlight is already set. Reverted migration doc changes — ClientHttpAuthentication is not dead code. --- docs/migration-SKILL.md | 27 +++++++------- docs/migration.md | 37 ++++++++++--------- packages/client/src/client/sse.ts | 7 +++- packages/client/src/client/streamableHttp.ts | 14 ++++++- packages/client/test/client/sse.test.ts | 8 ++-- .../client/test/client/streamableHttp.test.ts | 4 +- .../client/test/client/tokenProvider.test.ts | 7 +++- 7 files changed, 64 insertions(+), 40 deletions(-) diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 6add1bc02..957a583df 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -107,19 +107,19 @@ Two error classes now exist: - **`ProtocolError`** (renamed from `McpError`): Protocol errors that cross the wire as JSON-RPC responses - **`SdkError`** (new): Local SDK errors that never cross the wire -| Error scenario | v1 type | v2 type | -| -------------------------------- | -------------------------------------------- | ----------------------------------------------------------------- | -| Request timeout | `McpError` with `ErrorCode.RequestTimeout` | `SdkError` with `SdkErrorCode.RequestTimeout` | -| Connection closed | `McpError` with `ErrorCode.ConnectionClosed` | `SdkError` with `SdkErrorCode.ConnectionClosed` | -| Capability not supported | `new Error(...)` | `SdkError` with `SdkErrorCode.CapabilityNotSupported` | -| Not connected | `new Error('Not connected')` | `SdkError` with `SdkErrorCode.NotConnected` | -| Invalid params (server response) | `McpError` with `ErrorCode.InvalidParams` | `ProtocolError` with `ProtocolErrorCode.InvalidParams` | -| HTTP transport error | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttp*` | -| Failed to open SSE stream | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToOpenStream` | -| 401 after auth flow | `StreamableHTTPError` | `UnauthorizedError` | -| 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | -| Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | -| Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | +| Error scenario | v1 type | v2 type | +| --------------------------------- | -------------------------------------------- | ----------------------------------------------------------------- | +| Request timeout | `McpError` with `ErrorCode.RequestTimeout` | `SdkError` with `SdkErrorCode.RequestTimeout` | +| Connection closed | `McpError` with `ErrorCode.ConnectionClosed` | `SdkError` with `SdkErrorCode.ConnectionClosed` | +| Capability not supported | `new Error(...)` | `SdkError` with `SdkErrorCode.CapabilityNotSupported` | +| Not connected | `new Error('Not connected')` | `SdkError` with `SdkErrorCode.NotConnected` | +| Invalid params (server response) | `McpError` with `ErrorCode.InvalidParams` | `ProtocolError` with `ProtocolErrorCode.InvalidParams` | +| HTTP transport error | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttp*` | +| Failed to open SSE stream | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToOpenStream` | +| 401 after re-auth (circuit break) | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpAuthentication` | +| 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | +| Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | +| Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | New `SdkErrorCode` enum values: @@ -131,6 +131,7 @@ New `SdkErrorCode` enum values: - `SdkErrorCode.ConnectionClosed` = `'CONNECTION_CLOSED'` - `SdkErrorCode.SendFailed` = `'SEND_FAILED'` - `SdkErrorCode.ClientHttpNotImplemented` = `'CLIENT_HTTP_NOT_IMPLEMENTED'` +- `SdkErrorCode.ClientHttpAuthentication` = `'CLIENT_HTTP_AUTHENTICATION'` - `SdkErrorCode.ClientHttpForbidden` = `'CLIENT_HTTP_FORBIDDEN'` - `SdkErrorCode.ClientHttpUnexpectedContent` = `'CLIENT_HTTP_UNEXPECTED_CONTENT'` - `SdkErrorCode.ClientHttpFailedToOpenStream` = `'CLIENT_HTTP_FAILED_TO_OPEN_STREAM'` diff --git a/docs/migration.md b/docs/migration.md index 6ec9616dc..0614b572e 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -575,21 +575,21 @@ try { The new `SdkErrorCode` enum contains string-valued codes for local SDK errors: -| Code | Description | -| ------------------------------------------------- | ------------------------------------------ | -| `SdkErrorCode.NotConnected` | Transport is not connected | -| `SdkErrorCode.AlreadyConnected` | Transport is already connected | -| `SdkErrorCode.NotInitialized` | Protocol is not initialized | -| `SdkErrorCode.CapabilityNotSupported` | Required capability is not supported | -| `SdkErrorCode.RequestTimeout` | Request timed out waiting for response | -| `SdkErrorCode.ConnectionClosed` | Connection was closed | -| `SdkErrorCode.SendFailed` | Failed to send message | -| `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | -| `UnauthorizedError` (thrown, not `SdkError`) | Server returned 401 after re-auth attempt | -| `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | -| `SdkErrorCode.ClientHttpUnexpectedContent` | Unexpected content type in HTTP response | -| `SdkErrorCode.ClientHttpFailedToOpenStream` | Failed to open SSE stream | -| `SdkErrorCode.ClientHttpFailedToTerminateSession` | Failed to terminate session | +| Code | Description | +| ------------------------------------------------- | ------------------------------------------- | +| `SdkErrorCode.NotConnected` | Transport is not connected | +| `SdkErrorCode.AlreadyConnected` | Transport is already connected | +| `SdkErrorCode.NotInitialized` | Protocol is not initialized | +| `SdkErrorCode.CapabilityNotSupported` | Required capability is not supported | +| `SdkErrorCode.RequestTimeout` | Request timed out waiting for response | +| `SdkErrorCode.ConnectionClosed` | Connection was closed | +| `SdkErrorCode.SendFailed` | Failed to send message | +| `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | +| `SdkErrorCode.ClientHttpAuthentication` | Server returned 401 after re-authentication | +| `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | +| `SdkErrorCode.ClientHttpUnexpectedContent` | Unexpected content type in HTTP response | +| `SdkErrorCode.ClientHttpFailedToOpenStream` | Failed to open SSE stream | +| `SdkErrorCode.ClientHttpFailedToTerminateSession` | Failed to terminate session | #### `StreamableHTTPError` removed @@ -617,10 +617,11 @@ import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; try { await transport.send(message); } catch (error) { - if (error instanceof UnauthorizedError) { - console.log('Token rejected — reconnect with fresh credentials'); - } else if (error instanceof SdkError) { + if (error instanceof SdkError) { switch (error.code) { + case SdkErrorCode.ClientHttpAuthentication: + console.log('Auth failed — server rejected token after re-auth'); + break; case SdkErrorCode.ClientHttpForbidden: console.log('Forbidden after upscoping attempt'); break; diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index a8646adea..e714ad4d2 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -279,8 +279,13 @@ export class SSEClientTransport implements Transport { // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } + await response.text?.().catch(() => {}); + if (this._authRetryInFlight) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } if (this._authProvider) { - await response.text?.().catch(() => {}); throw new UnauthorizedError(); } } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index aad670c97..1213e418e 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -229,8 +229,13 @@ export class StreamableHTTPClientTransport implements Transport { // Purposely _not_ awaited, so we don't call onerror twice return this._startOrAuthSse(options); } + await response.text?.().catch(() => {}); + if (this._authRetryInFlight) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } if (this._authProvider) { - await response.text?.().catch(() => {}); throw new UnauthorizedError(); } } @@ -514,8 +519,13 @@ export class StreamableHTTPClientTransport implements Transport { // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } + await response.text?.().catch(() => {}); + if (this._authRetryInFlight) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } if (this._authProvider) { - await response.text?.().catch(() => {}); throw new UnauthorizedError(); } } diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index 10fcd76bc..fd2b184cc 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -3,7 +3,7 @@ import { createServer } from 'node:http'; import type { AddressInfo } from 'node:net'; import type { JSONRPCMessage, OAuthTokens } from '@modelcontextprotocol/core'; -import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; +import { OAuthError, OAuthErrorCode, SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import type { Mock, Mocked, MockedFunction, MockInstance } from 'vitest'; @@ -1575,7 +1575,7 @@ describe('SSEClientTransport', () => { await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); }); - it('enforces circuit breaker on double-401: onUnauthorized called once, then throws', async () => { + it('enforces circuit breaker on double-401: onUnauthorized called once, then throws SdkError', async () => { postResponses = [401, 401]; await setupServer(); @@ -1586,7 +1586,9 @@ describe('SSEClientTransport', () => { transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); await transport.start(); - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); expect(postCount).toBe(2); }); diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index ddd24608f..a0ae90b73 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -1678,7 +1678,9 @@ describe('StreamableHTTPClientTransport', () => { // Retry the original request - still 401 (broken server) .mockResolvedValueOnce(unauthedResponse); - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ access_token: 'new-access-token', token_type: 'Bearer', diff --git a/packages/client/test/client/tokenProvider.test.ts b/packages/client/test/client/tokenProvider.test.ts index c683a4012..3ab6c9623 100644 --- a/packages/client/test/client/tokenProvider.test.ts +++ b/packages/client/test/client/tokenProvider.test.ts @@ -1,4 +1,5 @@ import type { JSONRPCMessage } from '@modelcontextprotocol/core'; +import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; import type { Mock } from 'vitest'; import type { AuthProvider } from '../../src/client/auth.js'; @@ -81,7 +82,7 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); }); - it('should throw UnauthorizedError if retry after onUnauthorized also gets 401', async () => { + it('should throw SdkError(ClientHttpAuthentication) if retry after onUnauthorized also gets 401', async () => { const authProvider: AuthProvider = { token: vi.fn(async () => 'still-bad'), onUnauthorized: vi.fn(async () => {}) @@ -93,7 +94,9 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }); - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); });