diff --git a/.changeset/token-provider-composable-auth.md b/.changeset/token-provider-composable-auth.md new file mode 100644 index 000000000..f5c064e7f --- /dev/null +++ b/.changeset/token-provider-composable-auth.md @@ -0,0 +1,16 @@ +--- +'@modelcontextprotocol/client': minor +--- + +Add `AuthProvider` for composable bearer-token auth; transports adapt `OAuthClientProvider` automatically + +- New `AuthProvider` interface: `{ token(): Promise; onUnauthorized?(ctx): Promise }`. Transports call `token()` before every request and `onUnauthorized()` on 401 (then retry once). +- Transport `authProvider` option now accepts `AuthProvider | OAuthClientProvider`. OAuth providers are adapted internally via `adaptOAuthProvider()` — no changes needed to existing `OAuthClientProvider` implementations. +- For simple bearer tokens (API keys, gateway-managed tokens, service accounts): `{ authProvider: { token: async () => myKey } }` — one-line object literal, no class. +- New `adaptOAuthProvider(provider)` export for explicit adaptation. +- New `handleOAuthUnauthorized(provider, ctx)` helper — the standard OAuth `onUnauthorized` behavior. +- New `isOAuthClientProvider()` type guard. +- New `UnauthorizedContext` type. +- Exported previously-internal auth helpers for building custom flows: `applyBasicAuth`, `applyPostAuth`, `applyPublicAuth`, `executeTokenRequest`. + +Transports are simplified internally — ~50 lines of inline OAuth orchestration (auth() calls, WWW-Authenticate parsing, circuit-breaker state) moved into the adapter's `onUnauthorized()` implementation. `OAuthClientProvider` itself is unchanged. diff --git a/docs/client.md b/docs/client.md index 782ab885b..b5086f531 100644 --- a/docs/client.md +++ b/docs/client.md @@ -13,7 +13,7 @@ A client connects to a server, discovers what it offers — tools, resources, pr The examples below use these imports. Adjust based on which features and transport you need: ```ts source="../examples/client/src/clientGuide.examples.ts#imports" -import type { Prompt, Resource, Tool } from '@modelcontextprotocol/client'; +import type { AuthProvider, Prompt, Resource, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -113,7 +113,19 @@ console.log(systemPrompt); ## Authentication -MCP servers can require OAuth 2.0 authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). Pass an `authProvider` to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport} to enable this — the SDK provides built-in providers for common machine-to-machine flows, or you can implement the full {@linkcode @modelcontextprotocol/client!client/auth.OAuthClientProvider | OAuthClientProvider} interface for user-facing OAuth. +MCP servers can require authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). Pass an {@linkcode @modelcontextprotocol/client!client/auth.AuthProvider | AuthProvider} to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}. The transport calls `token()` before every request and `onUnauthorized()` (if provided) on 401, then retries once. + +### Bearer tokens + +For servers that accept bearer tokens managed outside the SDK — API keys, tokens from a gateway or proxy, service-account credentials — implement only `token()`. With no `onUnauthorized()`, a 401 throws {@linkcode @modelcontextprotocol/client!client/auth.UnauthorizedError | UnauthorizedError} immediately: + +```ts source="../examples/client/src/clientGuide.examples.ts#auth_tokenProvider" +const authProvider: AuthProvider = { token: async () => getStoredToken() }; + +const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { authProvider }); +``` + +See [`simpleTokenProvider.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleTokenProvider.ts) for a complete runnable example. ### Client credentials diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 9dffe4418..957a583df 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -47,15 +47,15 @@ Replace all `@modelcontextprotocol/sdk/...` imports using this table. ### Server imports -| v1 import path | v2 package | -| ---------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `@modelcontextprotocol/sdk/server/mcp.js` | `@modelcontextprotocol/server` | -| `@modelcontextprotocol/sdk/server/index.js` | `@modelcontextprotocol/server` | -| `@modelcontextprotocol/sdk/server/stdio.js` | `@modelcontextprotocol/server` | +| v1 import path | v2 package | +| ---------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `@modelcontextprotocol/sdk/server/mcp.js` | `@modelcontextprotocol/server` | +| `@modelcontextprotocol/sdk/server/index.js` | `@modelcontextprotocol/server` | +| `@modelcontextprotocol/sdk/server/stdio.js` | `@modelcontextprotocol/server` | | `@modelcontextprotocol/sdk/server/streamableHttp.js` | `@modelcontextprotocol/node` (class renamed to `NodeStreamableHTTPServerTransport`) OR `@modelcontextprotocol/server` (web-standard `WebStandardStreamableHTTPServerTransport` for Cloudflare Workers, Deno, etc.) | -| `@modelcontextprotocol/sdk/server/sse.js` | REMOVED (migrate to Streamable HTTP) | -| `@modelcontextprotocol/sdk/server/auth/*` | REMOVED (use external auth library) | -| `@modelcontextprotocol/sdk/server/middleware.js` | `@modelcontextprotocol/express` (signature changed, see section 8) | +| `@modelcontextprotocol/sdk/server/sse.js` | REMOVED (migrate to Streamable HTTP) | +| `@modelcontextprotocol/sdk/server/auth/*` | REMOVED (use external auth library) | +| `@modelcontextprotocol/sdk/server/middleware.js` | `@modelcontextprotocol/express` (signature changed, see section 8) | ### Types / shared imports @@ -107,19 +107,19 @@ Two error classes now exist: - **`ProtocolError`** (renamed from `McpError`): Protocol errors that cross the wire as JSON-RPC responses - **`SdkError`** (new): Local SDK errors that never cross the wire -| Error scenario | v1 type | v2 type | -| -------------------------------- | -------------------------------------------- | ----------------------------------------------------------------- | -| Request timeout | `McpError` with `ErrorCode.RequestTimeout` | `SdkError` with `SdkErrorCode.RequestTimeout` | -| Connection closed | `McpError` with `ErrorCode.ConnectionClosed` | `SdkError` with `SdkErrorCode.ConnectionClosed` | -| Capability not supported | `new Error(...)` | `SdkError` with `SdkErrorCode.CapabilityNotSupported` | -| Not connected | `new Error('Not connected')` | `SdkError` with `SdkErrorCode.NotConnected` | -| Invalid params (server response) | `McpError` with `ErrorCode.InvalidParams` | `ProtocolError` with `ProtocolErrorCode.InvalidParams` | -| HTTP transport error | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttp*` | -| Failed to open SSE stream | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToOpenStream` | -| 401 after auth flow | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpAuthentication` | -| 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | -| Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | -| Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | +| Error scenario | v1 type | v2 type | +| --------------------------------- | -------------------------------------------- | ----------------------------------------------------------------- | +| Request timeout | `McpError` with `ErrorCode.RequestTimeout` | `SdkError` with `SdkErrorCode.RequestTimeout` | +| Connection closed | `McpError` with `ErrorCode.ConnectionClosed` | `SdkError` with `SdkErrorCode.ConnectionClosed` | +| Capability not supported | `new Error(...)` | `SdkError` with `SdkErrorCode.CapabilityNotSupported` | +| Not connected | `new Error('Not connected')` | `SdkError` with `SdkErrorCode.NotConnected` | +| Invalid params (server response) | `McpError` with `ErrorCode.InvalidParams` | `ProtocolError` with `ProtocolErrorCode.InvalidParams` | +| HTTP transport error | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttp*` | +| Failed to open SSE stream | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToOpenStream` | +| 401 after re-auth (circuit break) | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpAuthentication` | +| 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | +| Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | +| Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | New `SdkErrorCode` enum values: @@ -203,7 +203,8 @@ import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; if (error instanceof OAuthError && error.code === OAuthErrorCode.InvalidClient) { ... } ``` -**Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). +**Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all +Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). ## 6. McpServer API Changes @@ -279,12 +280,12 @@ Note: the third argument (`metadata`) is required — pass `{}` if no metadata. ### Schema Migration Quick Reference -| v1 (raw shape) | v2 (Zod schema) | -|----------------|-----------------| -| `{ name: z.string() }` | `z.object({ name: z.string() })` | +| v1 (raw shape) | v2 (Zod schema) | +| ---------------------------------- | -------------------------------------------- | +| `{ name: z.string() }` | `z.object({ name: z.string() })` | | `{ count: z.number().optional() }` | `z.object({ count: z.number().optional() })` | -| `{}` (empty) | `z.object({})` | -| `undefined` (no schema) | `undefined` or omit the field | +| `{}` (empty) | `z.object({})` | +| `undefined` (no schema) | `undefined` or omit the field | ## 7. Headers API @@ -370,31 +371,31 @@ Request/notification params remain fully typed. Remove unused schema imports aft `RequestHandlerExtra` → structured context types with nested groups. Rename `extra` → `ctx` in all handler callbacks. -| v1 | v2 | -|----|-----| -| `RequestHandlerExtra` | `ServerContext` (server) / `ClientContext` (client) / `BaseContext` (base) | -| `extra` (param name) | `ctx` | -| `extra.signal` | `ctx.mcpReq.signal` | -| `extra.requestId` | `ctx.mcpReq.id` | -| `extra._meta` | `ctx.mcpReq._meta` | -| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | -| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | -| `extra.authInfo` | `ctx.http?.authInfo` | -| `extra.sessionId` | `ctx.sessionId` | -| `extra.requestInfo` | `ctx.http?.req` (only `ServerContext`) | -| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only `ServerContext`) | -| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only `ServerContext`) | -| `extra.taskStore` | `ctx.task?.store` | -| `extra.taskId` | `ctx.task?.id` | -| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | +| v1 | v2 | +| -------------------------------- | -------------------------------------------------------------------------- | +| `RequestHandlerExtra` | `ServerContext` (server) / `ClientContext` (client) / `BaseContext` (base) | +| `extra` (param name) | `ctx` | +| `extra.signal` | `ctx.mcpReq.signal` | +| `extra.requestId` | `ctx.mcpReq.id` | +| `extra._meta` | `ctx.mcpReq._meta` | +| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | +| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | +| `extra.authInfo` | `ctx.http?.authInfo` | +| `extra.sessionId` | `ctx.sessionId` | +| `extra.requestInfo` | `ctx.http?.req` (only `ServerContext`) | +| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only `ServerContext`) | +| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only `ServerContext`) | +| `extra.taskStore` | `ctx.task?.store` | +| `extra.taskId` | `ctx.task?.id` | +| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | `ServerContext` convenience methods (new in v2, no v1 equivalent): -| Method | Description | Replaces | -|--------|-------------|----------| -| `ctx.mcpReq.log(level, data, logger?)` | Send log notification (respects client's level filter) | `server.sendLoggingMessage(...)` from within handler | -| `ctx.mcpReq.elicitInput(params, options?)` | Elicit user input (form or URL) | `server.elicitInput(...)` from within handler | -| `ctx.mcpReq.requestSampling(params, options?)` | Request LLM sampling from client | `server.createMessage(...)` from within handler | +| Method | Description | Replaces | +| ---------------------------------------------- | ------------------------------------------------------ | ---------------------------------------------------- | +| `ctx.mcpReq.log(level, data, logger?)` | Send log notification (respects client's level filter) | `server.sendLoggingMessage(...)` from within handler | +| `ctx.mcpReq.elicitInput(params, options?)` | Elicit user input (form or URL) | `server.elicitInput(...)` from within handler | +| `ctx.mcpReq.requestSampling(params, options?)` | Request LLM sampling from client | `server.createMessage(...)` from within handler | ## 11. Schema parameter removed from `request()`, `send()`, and `callTool()` @@ -413,14 +414,14 @@ const elicit = await ctx.mcpReq.send({ method: 'elicitation/create', params: { . const tool = await client.callTool({ name: 'my-tool', arguments: {} }); ``` -| v1 call | v2 call | -|---------|---------| -| `client.request(req, ResultSchema)` | `client.request(req)` | -| `client.request(req, ResultSchema, options)` | `client.request(req, options)` | -| `ctx.mcpReq.send(req, ResultSchema)` | `ctx.mcpReq.send(req)` | -| `ctx.mcpReq.send(req, ResultSchema, options)` | `ctx.mcpReq.send(req, options)` | -| `client.callTool(params, CompatibilityCallToolResultSchema)` | `client.callTool(params)` | -| `client.callTool(params, schema, options)` | `client.callTool(params, options)` | +| v1 call | v2 call | +| ------------------------------------------------------------ | ---------------------------------- | +| `client.request(req, ResultSchema)` | `client.request(req)` | +| `client.request(req, ResultSchema, options)` | `client.request(req, options)` | +| `ctx.mcpReq.send(req, ResultSchema)` | `ctx.mcpReq.send(req)` | +| `ctx.mcpReq.send(req, ResultSchema, options)` | `ctx.mcpReq.send(req, options)` | +| `client.callTool(params, CompatibilityCallToolResultSchema)` | `client.callTool(params)` | +| `client.callTool(params, schema, options)` | `client.callTool(params, options)` | Remove unused schema imports: `CallToolResultSchema`, `CompatibilityCallToolResultSchema`, `ElicitResultSchema`, `CreateMessageResultSchema`, etc., when they were only used in `request()`/`send()`/`callTool()` calls. @@ -431,6 +432,7 @@ Remove unused schema imports: `CallToolResultSchema`, `CompatibilityCallToolResu ## 13. Runtime-Specific JSON Schema Validators (Enhancement) The SDK now auto-selects the appropriate JSON Schema validator based on runtime: + - Node.js → `AjvJsonSchemaValidator` (no change from v1) - Cloudflare Workers (workerd) → `CfWorkerJsonSchemaValidator` (previously required manual config) @@ -438,9 +440,12 @@ The SDK now auto-selects the appropriate JSON Schema validator based on runtime: ```typescript // v1 (Cloudflare Workers): Required explicit validator -new McpServer({ name: 'server', version: '1.0.0' }, { - jsonSchemaValidator: new CfWorkerJsonSchemaValidator() -}); +new McpServer( + { name: 'server', version: '1.0.0' }, + { + jsonSchemaValidator: new CfWorkerJsonSchemaValidator() + } +); // v2 (Cloudflare Workers): Auto-selected, explicit config optional new McpServer({ name: 'server', version: '1.0.0' }, {}); diff --git a/docs/migration.md b/docs/migration.md index 59b2b50ed..0614b572e 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -252,6 +252,7 @@ server.registerTool('ping', { ``` This applies to: + - `inputSchema` in `registerTool()` - `outputSchema` in `registerTool()` - `argsSchema` in `registerPrompt()` @@ -339,25 +340,21 @@ Common method string replacements: ### `Protocol.request()`, `ctx.mcpReq.send()`, and `Client.callTool()` no longer take a schema parameter -The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas like `CallToolResultSchema` or `ElicitResultSchema` when making requests. +The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas +like `CallToolResultSchema` or `ElicitResultSchema` when making requests. **`client.request()` — Before (v1):** ```typescript import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; -const result = await client.request( - { method: 'tools/call', params: { name: 'my-tool', arguments: {} } }, - CallToolResultSchema -); +const result = await client.request({ method: 'tools/call', params: { name: 'my-tool', arguments: {} } }, CallToolResultSchema); ``` **After (v2):** ```typescript -const result = await client.request( - { method: 'tools/call', params: { name: 'my-tool', arguments: {} } } -); +const result = await client.request({ method: 'tools/call', params: { name: 'my-tool', arguments: {} } }); ``` **`ctx.mcpReq.send()` — Before (v1):** @@ -390,10 +387,7 @@ server.setRequestHandler('tools/call', async (request, ctx) => { ```typescript import { CompatibilityCallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; -const result = await client.callTool( - { name: 'my-tool', arguments: {} }, - CompatibilityCallToolResultSchema -); +const result = await client.callTool({ name: 'my-tool', arguments: {} }, CompatibilityCallToolResultSchema); ``` **After (v2):** @@ -452,32 +446,32 @@ import { JSONRPCErrorResponse, ResourceTemplateReference, isJSONRPCErrorResponse The `RequestHandlerExtra` type has been replaced with a structured context type hierarchy using nested groups: -| v1 | v2 | -|----|-----| +| v1 | v2 | +| ---------------------------------------- | ---------------------------------------------------------------------- | | `RequestHandlerExtra` (flat, all fields) | `ServerContext` (server handlers) or `ClientContext` (client handlers) | -| `extra` parameter name | `ctx` parameter name | -| `extra.signal` | `ctx.mcpReq.signal` | -| `extra.requestId` | `ctx.mcpReq.id` | -| `extra._meta` | `ctx.mcpReq._meta` | -| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | -| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | -| `extra.authInfo` | `ctx.http?.authInfo` | -| `extra.requestInfo` | `ctx.http?.req` (only on `ServerContext`) | -| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only on `ServerContext`) | -| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only on `ServerContext`) | -| `extra.sessionId` | `ctx.sessionId` | -| `extra.taskStore` | `ctx.task?.store` | -| `extra.taskId` | `ctx.task?.id` | -| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | +| `extra` parameter name | `ctx` parameter name | +| `extra.signal` | `ctx.mcpReq.signal` | +| `extra.requestId` | `ctx.mcpReq.id` | +| `extra._meta` | `ctx.mcpReq._meta` | +| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | +| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | +| `extra.authInfo` | `ctx.http?.authInfo` | +| `extra.requestInfo` | `ctx.http?.req` (only on `ServerContext`) | +| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only on `ServerContext`) | +| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only on `ServerContext`) | +| `extra.sessionId` | `ctx.sessionId` | +| `extra.taskStore` | `ctx.task?.store` | +| `extra.taskId` | `ctx.task?.id` | +| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | **Before (v1):** ```typescript server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - const headers = extra.requestInfo?.headers; - const taskStore = extra.taskStore; - await extra.sendNotification({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); - return { content: [{ type: 'text', text: 'result' }] }; + const headers = extra.requestInfo?.headers; + const taskStore = extra.taskStore; + await extra.sendNotification({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); + return { content: [{ type: 'text', text: 'result' }] }; }); ``` @@ -485,10 +479,10 @@ server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { ```typescript server.setRequestHandler('tools/call', async (request, ctx) => { - const headers = ctx.http?.req?.headers; - const taskStore = ctx.task?.store; - await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); - return { content: [{ type: 'text', text: 'result' }] }; + const headers = ctx.http?.req?.headers; + const taskStore = ctx.task?.store; + await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); + return { content: [{ type: 'text', text: 'result' }] }; }); ``` @@ -504,22 +498,22 @@ Context fields are organized into 4 groups: ```typescript server.setRequestHandler('tools/call', async (request, ctx) => { - // Send a log message (respects client's log level filter) - await ctx.mcpReq.log('info', 'Processing tool call', 'my-logger'); - - // Request client to sample an LLM - const samplingResult = await ctx.mcpReq.requestSampling({ - messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], - maxTokens: 100, - }); - - // Elicit user input via a form - const elicitResult = await ctx.mcpReq.elicitInput({ - message: 'Please provide details', - requestedSchema: { type: 'object', properties: { name: { type: 'string' } } }, - }); - - return { content: [{ type: 'text', text: 'done' }] }; + // Send a log message (respects client's log level filter) + await ctx.mcpReq.log('info', 'Processing tool call', 'my-logger'); + + // Request client to sample an LLM + const samplingResult = await ctx.mcpReq.requestSampling({ + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 100 + }); + + // Elicit user input via a form + const elicitResult = await ctx.mcpReq.elicitInput({ + message: 'Please provide details', + requestedSchema: { type: 'object', properties: { name: { type: 'string' } } } + }); + + return { content: [{ type: 'text', text: 'done' }] }; }); ``` @@ -581,21 +575,21 @@ try { The new `SdkErrorCode` enum contains string-valued codes for local SDK errors: -| Code | Description | -| ------------------------------------------------- | ------------------------------------------ | -| `SdkErrorCode.NotConnected` | Transport is not connected | -| `SdkErrorCode.AlreadyConnected` | Transport is already connected | -| `SdkErrorCode.NotInitialized` | Protocol is not initialized | -| `SdkErrorCode.CapabilityNotSupported` | Required capability is not supported | -| `SdkErrorCode.RequestTimeout` | Request timed out waiting for response | -| `SdkErrorCode.ConnectionClosed` | Connection was closed | -| `SdkErrorCode.SendFailed` | Failed to send message | -| `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | -| `SdkErrorCode.ClientHttpAuthentication` | Server returned 401 after successful auth | -| `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | -| `SdkErrorCode.ClientHttpUnexpectedContent` | Unexpected content type in HTTP response | -| `SdkErrorCode.ClientHttpFailedToOpenStream` | Failed to open SSE stream | -| `SdkErrorCode.ClientHttpFailedToTerminateSession` | Failed to terminate session | +| Code | Description | +| ------------------------------------------------- | ------------------------------------------- | +| `SdkErrorCode.NotConnected` | Transport is not connected | +| `SdkErrorCode.AlreadyConnected` | Transport is already connected | +| `SdkErrorCode.NotInitialized` | Protocol is not initialized | +| `SdkErrorCode.CapabilityNotSupported` | Required capability is not supported | +| `SdkErrorCode.RequestTimeout` | Request timed out waiting for response | +| `SdkErrorCode.ConnectionClosed` | Connection was closed | +| `SdkErrorCode.SendFailed` | Failed to send message | +| `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | +| `SdkErrorCode.ClientHttpAuthentication` | Server returned 401 after re-authentication | +| `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | +| `SdkErrorCode.ClientHttpUnexpectedContent` | Unexpected content type in HTTP response | +| `SdkErrorCode.ClientHttpFailedToOpenStream` | Failed to open SSE stream | +| `SdkErrorCode.ClientHttpFailedToTerminateSession` | Failed to terminate session | #### `StreamableHTTPError` removed @@ -626,7 +620,7 @@ try { if (error instanceof SdkError) { switch (error.code) { case SdkErrorCode.ClientHttpAuthentication: - console.log('Auth failed after completing auth flow'); + console.log('Auth failed — server rejected token after re-auth'); break; case SdkErrorCode.ClientHttpForbidden: console.log('Forbidden after upscoping attempt'); @@ -646,7 +640,8 @@ try { #### Why this change? -Previously, `ErrorCode.RequestTimeout` (-32001) and `ErrorCode.ConnectionClosed` (-32000) were used for local timeout/connection errors. However, these errors never cross the wire as JSON-RPC responses - they are rejected locally. Using protocol error codes for local errors was semantically inconsistent. +Previously, `ErrorCode.RequestTimeout` (-32001) and `ErrorCode.ConnectionClosed` (-32000) were used for local timeout/connection errors. However, these errors never cross the wire as JSON-RPC responses - they are rejected locally. Using protocol error codes for local errors was +semantically inconsistent. The new design: @@ -743,11 +738,11 @@ This means Cloudflare Workers users no longer need to explicitly pass the valida import { McpServer, CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/server'; const server = new McpServer( - { name: 'my-server', version: '1.0.0' }, - { - capabilities: { tools: {} }, - jsonSchemaValidator: new CfWorkerJsonSchemaValidator() // Required in v1 - } + { name: 'my-server', version: '1.0.0' }, + { + capabilities: { tools: {} }, + jsonSchemaValidator: new CfWorkerJsonSchemaValidator() // Required in v1 + } ); ``` @@ -757,9 +752,9 @@ const server = new McpServer( import { McpServer } from '@modelcontextprotocol/server'; const server = new McpServer( - { name: 'my-server', version: '1.0.0' }, - { capabilities: { tools: {} } } - // Validator auto-selected based on runtime + { name: 'my-server', version: '1.0.0' }, + { capabilities: { tools: {} } } + // Validator auto-selected based on runtime ); ``` diff --git a/examples/client/src/clientGuide.examples.ts b/examples/client/src/clientGuide.examples.ts index 389059024..f07d272db 100644 --- a/examples/client/src/clientGuide.examples.ts +++ b/examples/client/src/clientGuide.examples.ts @@ -8,7 +8,7 @@ */ //#region imports -import type { Prompt, Resource, Tool } from '@modelcontextprotocol/client'; +import type { AuthProvider, Prompt, Resource, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -107,6 +107,16 @@ async function serverInstructions_basic(client: Client) { // Authentication // --------------------------------------------------------------------------- +/** Example: Minimal AuthProvider for bearer auth with externally-managed tokens. */ +async function auth_tokenProvider(getStoredToken: () => Promise) { + //#region auth_tokenProvider + const authProvider: AuthProvider = { token: async () => getStoredToken() }; + + const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { authProvider }); + //#endregion auth_tokenProvider + return transport; +} + /** Example: Client credentials auth for service-to-service communication. */ async function auth_clientCredentials() { //#region auth_clientCredentials @@ -540,6 +550,7 @@ void connect_stdio; void connect_sseFallback; void disconnect_streamableHttp; void serverInstructions_basic; +void auth_tokenProvider; void auth_clientCredentials; void auth_privateKeyJwt; void auth_crossAppAccess; diff --git a/examples/client/src/simpleTokenProvider.ts b/examples/client/src/simpleTokenProvider.ts new file mode 100644 index 000000000..7b5f1a4c1 --- /dev/null +++ b/examples/client/src/simpleTokenProvider.ts @@ -0,0 +1,56 @@ +#!/usr/bin/env node + +/** + * Example demonstrating the minimal AuthProvider for bearer token authentication. + * + * AuthProvider is the base interface for all client auth. For simple cases where + * tokens are managed externally — pre-configured API tokens, gateway/proxy patterns, + * or tokens obtained through a separate auth flow — implement only `token()`. + * + * For OAuth flows (client_credentials, private_key_jwt, etc.), use the built-in + * providers which implement both `token()` and `onUnauthorized()`. + * + * Environment variables: + * MCP_SERVER_URL - Server URL (default: http://localhost:3000/mcp) + * MCP_TOKEN - Bearer token to use for authentication (required) + */ + +import type { AuthProvider } from '@modelcontextprotocol/client'; +import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; + +const DEFAULT_SERVER_URL = process.env.MCP_SERVER_URL || 'http://localhost:3000/mcp'; + +async function main() { + const token = process.env.MCP_TOKEN; + if (!token) { + console.error('MCP_TOKEN environment variable is required'); + process.exit(1); + } + + // AuthProvider with just token() — the simplest possible auth. + // token() is called before every request, so it can handle refresh internally. + // With no onUnauthorized(), a 401 throws UnauthorizedError immediately. + const authProvider: AuthProvider = { + token: async () => token + }; + + const client = new Client({ name: 'auth-provider-example', version: '1.0.0' }, { capabilities: {} }); + + const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { authProvider }); + + await client.connect(transport); + console.log('Connected successfully.'); + + const tools = await client.listTools(); + console.log('Available tools:', tools.tools.map(t => t.name).join(', ') || '(none)'); + + await transport.close(); +} + +try { + await main(); +} catch (error) { + console.error('Error running client:', error); + // eslint-disable-next-line unicorn/no-process-exit + process.exit(1); +} diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index 58ec23ddd..d26bb5727 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -34,12 +34,105 @@ export type AddClientAuthentication = ( metadata?: AuthorizationServerMetadata ) => void | Promise; +/** + * Context passed to {@linkcode AuthProvider.onUnauthorized} when the server + * responds with 401. Provides everything needed to refresh credentials. + */ +export interface UnauthorizedContext { + /** The 401 response — inspect `WWW-Authenticate` for resource metadata, scope, etc. */ + response: Response; + /** The MCP server URL, for passing to {@linkcode auth} or discovery helpers. */ + serverUrl: URL; + /** Fetch function configured with the transport's `requestInit`, for making auth requests. */ + fetchFn: FetchLike; +} + +/** + * Minimal interface for authenticating MCP client transports with bearer tokens. + * + * Transports call {@linkcode AuthProvider.token | token()} before every request + * to obtain the current token, and {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * (if provided) when the server responds with 401, giving the provider a chance + * to refresh credentials before the transport retries once. + * + * For simple cases (API keys, gateway-managed tokens), implement only `token()`: + * ```typescript + * const authProvider: AuthProvider = { token: async () => process.env.API_KEY }; + * ``` + * + * For OAuth flows, pass an {@linkcode OAuthClientProvider} directly — transports + * accept either shape and adapt OAuth providers automatically via {@linkcode adaptOAuthProvider}. + */ +export interface AuthProvider { + /** + * Returns the current bearer token, or `undefined` if no token is available. + * Called before every request. + */ + token(): Promise; + + /** + * Called when the server responds with 401. If provided, the transport will + * await this, then retry the request once. If the retry also gets 401, or if + * this method is not provided, the transport throws {@linkcode UnauthorizedError}. + * + * Implementations should refresh tokens, re-authenticate, etc. — whatever is + * needed so the next `token()` call returns a valid token. + */ + onUnauthorized?(ctx: UnauthorizedContext): Promise; +} + +/** + * Type guard distinguishing `OAuthClientProvider` from a minimal `AuthProvider`. + * Transports use this at construction time to classify the `authProvider` option. + */ +export function isOAuthClientProvider(provider: AuthProvider | OAuthClientProvider | undefined): provider is OAuthClientProvider { + return provider !== undefined && 'tokens' in provider && 'clientMetadata' in provider; +} + +/** + * Standard `onUnauthorized` behavior for OAuth providers: extracts + * `WWW-Authenticate` parameters from the 401 response and runs {@linkcode auth}. + * Used by {@linkcode adaptOAuthProvider} to bridge `OAuthClientProvider` to `AuthProvider`. + */ +export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx: UnauthorizedContext): Promise { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(ctx.response); + const result = await auth(provider, { + serverUrl: ctx.serverUrl, + resourceMetadataUrl, + scope, + fetchFn: ctx.fetchFn + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } +} + +/** + * Adapts an `OAuthClientProvider` to the minimal `AuthProvider` interface that + * transports consume. Called once at transport construction — the transport stores + * the adapted provider for `_commonHeaders()` and 401 handling, while keeping the + * original `OAuthClientProvider` for OAuth-specific paths (`finishAuth()`, 403 upscoping). + */ +export function adaptOAuthProvider(provider: OAuthClientProvider): AuthProvider { + return { + token: async () => { + const tokens = await provider.tokens(); + return tokens?.access_token; + }, + onUnauthorized: async ctx => handleOAuthUnauthorized(provider, ctx) + }; +} + /** * Implements an end-to-end OAuth client to be used with one MCP server. * * This client relies upon a concept of an authorized "session," the exact * meaning of which is application-defined. Tokens, authorization codes, and * code verifiers should not cross different sessions. + * + * Transports accept `OAuthClientProvider` directly via the `authProvider` option — + * they adapt it to {@linkcode AuthProvider} internally via {@linkcode adaptOAuthProvider}. + * No changes are needed to existing implementations. */ export interface OAuthClientProvider { /** @@ -381,7 +474,7 @@ export function applyClientAuthentication( /** * Applies HTTP Basic authentication (RFC 6749 Section 2.3.1) */ -function applyBasicAuth(clientId: string, clientSecret: string | undefined, headers: Headers): void { +export function applyBasicAuth(clientId: string, clientSecret: string | undefined, headers: Headers): void { if (!clientSecret) { throw new Error('client_secret_basic authentication requires a client_secret'); } @@ -393,7 +486,7 @@ function applyBasicAuth(clientId: string, clientSecret: string | undefined, head /** * Applies POST body authentication (RFC 6749 Section 2.3.1) */ -function applyPostAuth(clientId: string, clientSecret: string | undefined, params: URLSearchParams): void { +export function applyPostAuth(clientId: string, clientSecret: string | undefined, params: URLSearchParams): void { params.set('client_id', clientId); if (clientSecret) { params.set('client_secret', clientSecret); @@ -403,7 +496,7 @@ function applyPostAuth(clientId: string, clientSecret: string | undefined, param /** * Applies public client authentication (RFC 6749 Section 2.1) */ -function applyPublicAuth(clientId: string, params: URLSearchParams): void { +export function applyPublicAuth(clientId: string, params: URLSearchParams): void { params.set('client_id', clientId); } @@ -1265,7 +1358,7 @@ export function prepareAuthorizationCodeRequest( * Internal helper to execute a token request with the given parameters. * Used by {@linkcode exchangeAuthorization}, {@linkcode refreshAuthorization}, and {@linkcode fetchToken}. */ -async function executeTokenRequest( +export async function executeTokenRequest( authorizationServerUrl: string | URL, { metadata, diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index 133aa0004..e714ad4d2 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -3,8 +3,8 @@ import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders, SdkError, import type { ErrorEvent, EventSourceInit } from 'eventsource'; import { EventSource } from 'eventsource'; -import type { AuthResult, OAuthClientProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; +import type { AuthProvider, OAuthClientProvider } from './auth.js'; +import { adaptOAuthProvider, auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; export class SseError extends Error { constructor( @@ -23,18 +23,19 @@ export type SSEClientTransportOptions = { /** * An OAuth client provider to use for authentication. * - * When an `authProvider` is specified and the SSE connection is started: - * 1. The connection is attempted with any existing access token from the `authProvider`. - * 2. If the access token has expired, the `authProvider` is used to refresh the token. - * 3. If token refresh fails or no access token exists, and auth is required, {@linkcode OAuthClientProvider.redirectToAuthorization} is called, and an {@linkcode UnauthorizedError} will be thrown from {@linkcode index.Protocol.connect | connect}/{@linkcode SSEClientTransport.start | start}. + * {@linkcode AuthProvider.token | token()} is called before every request to obtain the + * bearer token. When the server responds with 401, {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * is called (if provided) to refresh credentials, then the request is retried once. If + * the retry also gets 401, or `onUnauthorized` is not provided, {@linkcode UnauthorizedError} + * is thrown. * - * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call {@linkcode SSEClientTransport.finishAuth} with the authorization code before retrying the connection. + * For simple bearer tokens: `{ token: async () => myApiKey }`. * - * If an `authProvider` is not provided, and auth is required, an {@linkcode UnauthorizedError} will be thrown. - * - * {@linkcode UnauthorizedError} might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected. + * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation. + * Interactive flows: after {@linkcode UnauthorizedError}, redirect the user, then call + * {@linkcode SSEClientTransport.finishAuth | finishAuth} with the authorization code before reconnecting. */ - authProvider?: OAuthClientProvider; + authProvider?: AuthProvider | OAuthClientProvider; /** * Customizes the initial SSE request to the server (the request that begins the stream). @@ -71,7 +72,8 @@ export class SSEClientTransport implements Transport { private _scope?: string; private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; - private _authProvider?: OAuthClientProvider; + private _authProvider?: AuthProvider; + private _oauthProvider?: OAuthClientProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _protocolVersion?: string; @@ -86,43 +88,24 @@ export class SSEClientTransport implements Transport { this._scope = undefined; this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; + if (isOAuthClientProvider(opts?.authProvider)) { + this._oauthProvider = opts.authProvider; + this._authProvider = adaptOAuthProvider(opts.authProvider); + } else { + this._authProvider = opts?.authProvider; + } this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); } - private async _authThenStart(): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); - } - - let result: AuthResult; - try { - result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - } catch (error) { - this.onerror?.(error as Error); - throw error; - } - - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - return await this._startOrAuth(); - } + private _authRetryInFlight = false; + private _last401Response?: Response; private async _commonHeaders(): Promise { const headers: RequestInit['headers'] & Record = {}; - if (this._authProvider) { - const tokens = await this._authProvider.tokens(); - if (tokens) { - headers['Authorization'] = `Bearer ${tokens.access_token}`; - } + const token = await this._authProvider?.token(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; } if (this._protocolVersion) { headers['mcp-protocol-version'] = this._protocolVersion; @@ -149,10 +132,13 @@ export class SSEClientTransport implements Transport { headers }); - if (response.status === 401 && response.headers.has('www-authenticate')) { - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; + if (response.status === 401) { + this._last401Response = response; + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; + } } return response; @@ -162,7 +148,21 @@ export class SSEClientTransport implements Transport { this._eventSource.onerror = event => { if (event.code === 401 && this._authProvider) { - this._authThenStart().then(resolve, reject); + if (this._authProvider.onUnauthorized && this._last401Response && !this._authRetryInFlight) { + this._authRetryInFlight = true; + const response = this._last401Response; + this._authProvider + .onUnauthorized({ response, serverUrl: this._url, fetchFn: this._fetchWithInit }) + .then(() => this._startOrAuth()) + .then(resolve, reject) + .finally(() => { + this._authRetryInFlight = false; + }); + return; + } + const error = new UnauthorizedError(); + reject(error); + this.onerror?.(error); return; } @@ -221,11 +221,11 @@ export class SSEClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); + if (!this._oauthProvider) { + throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, @@ -261,33 +261,45 @@ export class SSEClientTransport implements Transport { const response = await (this._fetch ?? fetch)(this._endpoint, init); if (!response.ok) { - const text = await response.text?.().catch(() => null); - - if (response.status === 401 && this._authProvider) { - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; + if (response.status === 401) { + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; + } - const result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - if (result !== 'AUTHORIZED') { + if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + this._authRetryInFlight = true; + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit + }); + await response.text?.().catch(() => {}); + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } + await response.text?.().catch(() => {}); + if (this._authRetryInFlight) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } + if (this._authProvider) { throw new UnauthorizedError(); } - - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); } + const text = await response.text?.().catch(() => null); throw new Error(`Error POSTing to endpoint (HTTP ${response.status}): ${text}`); } + this._authRetryInFlight = false; + // Release connection - POST responses don't have content we need await response.text?.().catch(() => {}); } catch (error) { + this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index dab9b37ab..1213e418e 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -13,8 +13,8 @@ import { } from '@modelcontextprotocol/core'; import { EventSourceParserStream } from 'eventsource-parser/stream'; -import type { AuthResult, OAuthClientProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; +import type { AuthProvider, OAuthClientProvider } from './auth.js'; +import { adaptOAuthProvider, auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { @@ -85,18 +85,21 @@ export type StreamableHTTPClientTransportOptions = { /** * An OAuth client provider to use for authentication. * - * When an `authProvider` is specified and the connection is started: - * 1. The connection is attempted with any existing access token from the `authProvider`. - * 2. If the access token has expired, the `authProvider` is used to refresh the token. - * 3. If token refresh fails or no access token exists, and auth is required, {@linkcode OAuthClientProvider.redirectToAuthorization} is called, and an {@linkcode UnauthorizedError} will be thrown from {@linkcode index.Protocol.connect | connect}/{@linkcode StreamableHTTPClientTransport.start | start}. + * {@linkcode AuthProvider.token | token()} is called before every request to obtain the + * bearer token. When the server responds with 401, {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * is called (if provided) to refresh credentials, then the request is retried once. If + * the retry also gets 401, or `onUnauthorized` is not provided, {@linkcode UnauthorizedError} + * is thrown. * - * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call {@linkcode StreamableHTTPClientTransport.finishAuth} with the authorization code before retrying the connection. + * For simple bearer tokens: `{ token: async () => myApiKey }`. * - * If an `authProvider` is not provided, and auth is required, an {@linkcode UnauthorizedError} will be thrown. - * - * {@linkcode UnauthorizedError} might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected. + * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation + * directly — the transport adapts it to `AuthProvider` internally. Interactive flows: after + * {@linkcode UnauthorizedError}, redirect the user, then call + * {@linkcode StreamableHTTPClientTransport.finishAuth | finishAuth} with the authorization code before + * reconnecting. */ - authProvider?: OAuthClientProvider; + authProvider?: AuthProvider | OAuthClientProvider; /** * Customizes HTTP requests to the server. @@ -131,13 +134,14 @@ export class StreamableHTTPClientTransport implements Transport { private _resourceMetadataUrl?: URL; private _scope?: string; private _requestInit?: RequestInit; - private _authProvider?: OAuthClientProvider; + private _authProvider?: AuthProvider; + private _oauthProvider?: OAuthClientProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; private _protocolVersion?: string; - private _hasCompletedAuthFlow = false; // Circuit breaker: detect auth success followed by immediate 401 + private _authRetryInFlight = false; // Circuit breaker: single retry per operation on 401 private _lastUpscopingHeader?: string; // Track last upscoping header to prevent infinite upscoping. private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field private _reconnectionTimeout?: ReturnType; @@ -151,45 +155,23 @@ export class StreamableHTTPClientTransport implements Transport { this._resourceMetadataUrl = undefined; this._scope = undefined; this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; + if (isOAuthClientProvider(opts?.authProvider)) { + this._oauthProvider = opts.authProvider; + this._authProvider = adaptOAuthProvider(opts.authProvider); + } else { + this._authProvider = opts?.authProvider; + } this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); this._sessionId = opts?.sessionId; this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; } - private async _authThenStart(): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); - } - - let result: AuthResult; - try { - result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - } catch (error) { - this.onerror?.(error as Error); - throw error; - } - - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - return await this._startOrAuthSse({ resumptionToken: undefined }); - } - private async _commonHeaders(): Promise { const headers: RequestInit['headers'] & Record = {}; - if (this._authProvider) { - const tokens = await this._authProvider.tokens(); - if (tokens) { - headers['Authorization'] = `Bearer ${tokens.access_token}`; - } + const token = await this._authProvider?.token(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; } if (this._sessionId) { @@ -229,13 +211,37 @@ export class StreamableHTTPClientTransport implements Transport { }); if (!response.ok) { - await response.text?.().catch(() => {}); + if (response.status === 401) { + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; + } - if (response.status === 401 && this._authProvider) { - // Need to authenticate - return await this._authThenStart(); + if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + this._authRetryInFlight = true; + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit + }); + await response.text?.().catch(() => {}); + // Purposely _not_ awaited, so we don't call onerror twice + return this._startOrAuthSse(options); + } + await response.text?.().catch(() => {}); + if (this._authRetryInFlight) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } + if (this._authProvider) { + throw new UnauthorizedError(); + } } + await response.text?.().catch(() => {}); + // 405 indicates that the server does not offer an SSE stream at GET endpoint // This is an expected case that should not trigger an error if (response.status === 405) { @@ -248,8 +254,10 @@ export class StreamableHTTPClientTransport implements Transport { }); } + this._authRetryInFlight = false; this._handleSseStream(response.body, options, true); } catch (error) { + this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } @@ -431,11 +439,11 @@ export class StreamableHTTPClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); + if (!this._oauthProvider) { + throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, @@ -492,38 +500,39 @@ export class StreamableHTTPClientTransport implements Transport { } if (!response.ok) { - const text = await response.text?.().catch(() => null); + if (response.status === 401) { + // Store WWW-Authenticate params for interactive finishAuth() path + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; + } - if (response.status === 401 && this._authProvider) { - // Prevent infinite recursion when server returns 401 after successful auth - if (this._hasCompletedAuthFlow) { - throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after successful authentication', { - status: 401, - text + if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + this._authRetryInFlight = true; + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit }); + await response.text?.().catch(() => {}); + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); } - - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; - - const result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - if (result !== 'AUTHORIZED') { + await response.text?.().catch(() => {}); + if (this._authRetryInFlight) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } + if (this._authProvider) { throw new UnauthorizedError(); } - - // Mark that we completed auth flow - this._hasCompletedAuthFlow = true; - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); } - if (response.status === 403 && this._authProvider) { + const text = await response.text?.().catch(() => null); + + if (response.status === 403 && this._oauthProvider) { const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response); if (error === 'insufficient_scope') { @@ -547,11 +556,11 @@ export class StreamableHTTPClientTransport implements Transport { // Mark that upscoping was tried. this._lastUpscopingHeader = wwwAuthHeader ?? undefined; - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, scope: this._scope, - fetchFn: this._fetch + fetchFn: this._fetchWithInit }); if (result !== 'AUTHORIZED') { @@ -569,7 +578,7 @@ export class StreamableHTTPClientTransport implements Transport { } // Reset auth loop flag on successful response - this._hasCompletedAuthFlow = false; + this._authRetryInFlight = false; this._lastUpscopingHeader = undefined; // If the response is 202 Accepted, there's no body to process @@ -619,6 +628,7 @@ export class StreamableHTTPClientTransport implements Transport { await response.text?.().catch(() => {}); } } catch (error) { + this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index 0b0aff67b..fd2b184cc 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -3,11 +3,11 @@ import { createServer } from 'node:http'; import type { AddressInfo } from 'node:net'; import type { JSONRPCMessage, OAuthTokens } from '@modelcontextprotocol/core'; -import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; +import { OAuthError, OAuthErrorCode, SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import type { Mock, Mocked, MockedFunction, MockInstance } from 'vitest'; -import type { OAuthClientProvider } from '../../src/client/auth.js'; +import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; import { SSEClientTransport } from '../../src/client/sse.js'; @@ -1528,4 +1528,101 @@ describe('SSEClientTransport', () => { expect(globalFetchSpy).not.toHaveBeenCalled(); }); }); + + describe('minimal AuthProvider (non-OAuth)', () => { + let postResponses: number[]; + let postCount: number; + + async function setupServer(): Promise { + await resourceServer.close(); + + postCount = 0; + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + if (req.method === 'GET') { + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}post\n\n`); + return; + } + + if (req.method === 'POST') { + const status = postResponses[postCount] ?? 200; + postCount++; + res.writeHead(status).end(); + return; + } + }); + + resourceBaseUrl = await listenOnRandomPort(resourceServer); + } + + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'test', params: {}, id: '1' }; + + it('throws UnauthorizedError on POST 401 when onUnauthorized is not provided', async () => { + postResponses = [401]; + await setupServer(); + + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + }); + + it('enforces circuit breaker on double-401: onUnauthorized called once, then throws SdkError', async () => { + postResponses = [401, 401]; + await setupServer(); + + const authProvider: AuthProvider = { + token: vi.fn(async () => 'still-bad'), + onUnauthorized: vi.fn(async () => {}) + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(postCount).toBe(2); + }); + + it('resets retry guard when onUnauthorized throws, allowing retry on next send', async () => { + postResponses = [401, 401, 200]; + await setupServer(); + + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized: vi.fn().mockRejectedValueOnce(new Error('transient network error')).mockResolvedValueOnce(undefined) + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + // First send: 401 → onUnauthorized throws transient error + await expect(transport.send(message)).rejects.toThrow('transient network error'); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + + // Second send: flag should be reset, so 401 → onUnauthorized (succeeds) → retry → 200 + await transport.send(message); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); + expect(postCount).toBe(3); + }); + + it('throws when finishAuth is called with a non-OAuth AuthProvider', async () => { + postResponses = []; + await setupServer(); + + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + await expect(transport.finishAuth('auth-code')).rejects.toThrow('finishAuth requires an OAuthClientProvider'); + }); + }); }); diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 0398964d3..a0ae90b73 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -1678,7 +1678,9 @@ describe('StreamableHTTPClientTransport', () => { // Retry the original request - still 401 (broken server) .mockResolvedValueOnce(unauthedResponse); - await expect(transport.send(message)).rejects.toThrow('Server returned 401 after successful authentication'); + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ access_token: 'new-access-token', token_type: 'Bearer', diff --git a/packages/client/test/client/tokenProvider.test.ts b/packages/client/test/client/tokenProvider.test.ts new file mode 100644 index 000000000..3ab6c9623 --- /dev/null +++ b/packages/client/test/client/tokenProvider.test.ts @@ -0,0 +1,182 @@ +import type { JSONRPCMessage } from '@modelcontextprotocol/core'; +import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; +import type { Mock } from 'vitest'; + +import type { AuthProvider } from '../../src/client/auth.js'; +import { UnauthorizedError } from '../../src/client/auth.js'; +import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; + +describe('StreamableHTTPClientTransport with AuthProvider', () => { + let transport: StreamableHTTPClientTransport; + + afterEach(async () => { + await transport?.close().catch(() => {}); + vi.clearAllMocks(); + }); + + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'test', params: {}, id: 'test-id' }; + + it('should set Authorization header from AuthProvider.token()', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => 'my-bearer-token') }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + await transport.send(message); + + expect(authProvider.token).toHaveBeenCalled(); + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.get('Authorization')).toBe('Bearer my-bearer-token'); + }); + + it('should not set Authorization header when token() returns undefined', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => undefined) }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + await transport.send(message); + + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.has('Authorization')).toBe(false); + }); + + it('should throw UnauthorizedError on 401 when onUnauthorized is not provided', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => 'rejected-token') }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 401, + headers: new Headers(), + text: async () => 'unauthorized' + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(authProvider.token).toHaveBeenCalledTimes(1); + }); + + it('should call onUnauthorized and retry once on 401', async () => { + let currentToken = 'old-token'; + const authProvider: AuthProvider = { + token: vi.fn(async () => currentToken), + onUnauthorized: vi.fn(async () => { + currentToken = 'new-token'; + }) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + await transport.send(message); + + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(authProvider.token).toHaveBeenCalledTimes(2); + const [, retryInit] = (globalThis.fetch as Mock).mock.calls[1]!; + expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); + }); + + it('should throw SdkError(ClientHttpAuthentication) if retry after onUnauthorized also gets 401', async () => { + const authProvider: AuthProvider = { + token: vi.fn(async () => 'still-bad'), + onUnauthorized: vi.fn(async () => {}) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }); + + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + }); + + it('should reset retry guard when onUnauthorized throws, allowing retry on next send', async () => { + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized: vi.fn().mockRejectedValueOnce(new Error('transient network error')).mockResolvedValueOnce(undefined) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + // First send: onUnauthorized throws transient error + await expect(transport.send(message)).rejects.toThrow('transient network error'); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + + // Second send: flag should be reset, so onUnauthorized gets a second chance + await transport.send(message); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); + }); + + it('should work with no authProvider at all', async () => { + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + await transport.send(message); + + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.has('Authorization')).toBe(false); + }); + + it('should throw when finishAuth is called with a non-OAuth AuthProvider', async () => { + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + + await expect(transport.finishAuth('auth-code')).rejects.toThrow('finishAuth requires an OAuthClientProvider'); + }); + + it('should throw UnauthorizedError on GET-SSE 401 with no onUnauthorized (via resumeStream)', async () => { + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 401, + headers: new Headers(), + text: async () => 'unauthorized' + }); + + await expect(transport.resumeStream('last-event-id')).rejects.toThrow(UnauthorizedError); + }); + + it('should call onUnauthorized and retry on GET-SSE 401 (via resumeStream)', async () => { + let currentToken = 'old-token'; + const authProvider: AuthProvider = { + token: vi.fn(async () => currentToken), + onUnauthorized: vi.fn(async () => { + currentToken = 'new-token'; + }) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + // First GET: 401. Second GET (retry): 405 (server doesn't offer SSE — clean exit) + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 405, headers: new Headers(), text: async () => '' }); + + await transport.resumeStream('last-event-id'); + + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(authProvider.token).toHaveBeenCalledTimes(2); + const [, retryInit] = (globalThis.fetch as Mock).mock.calls[1]!; + expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); + }); +});