diff --git a/README.md b/README.md index 0cdd70c..6f8e319 100644 --- a/README.md +++ b/README.md @@ -182,6 +182,35 @@ This approach automatically uses the latest version without requiring global ins MCP_HTTP_MODE=true SOCKET_API_KEY=your-api-key npx @socketsecurity/mcp@latest --http ``` + HTTP mode supports these environment variables: + + | Variable | Required | Default | Description | + |---|---|---|---| + | `SOCKET_API_KEY` | Required unless OAuth is enabled | None | Socket API key used for outbound API calls. If unset in OAuth-enabled HTTP mode, the validated incoming bearer token is forwarded upstream instead. | + | `SOCKET_OAUTH_ISSUER` | Set together with the two introspection vars to enable OAuth | None | OAuth issuer URL used for metadata discovery and incoming bearer-token validation. | + | `SOCKET_OAUTH_INTROSPECTION_CLIENT_ID` | With OAuth | None | Client ID used for token introspection. | + | `SOCKET_OAUTH_INTROSPECTION_CLIENT_SECRET` | With OAuth | None | Client secret used for token introspection. | + | `SOCKET_OAUTH_REQUIRED_SCOPES` | No | `packages:list` | Space-delimited scopes required on incoming access tokens. | + | `SOCKET_API_URL` | No | Production Socket API URL, or localhost when `SOCKET_DEBUG=true` | Override the upstream Socket API endpoint. Useful for local development and testing. | + | `SOCKET_DEBUG` | No | `false` | Switches the default upstream Socket API endpoint to localhost when `SOCKET_API_URL` is unset. | + | `TRUST_PROXY` | No | `false` | When `true`, trust `X-Forwarded-Host` and `X-Forwarded-Proto` when building OAuth metadata URLs. Enable only behind a trusted reverse proxy that rewrites these headers. | + | `MCP_PORT` | HTTP mode only | `3000` | Port to bind the HTTP server to. | + + `SOCKET_API_URL` and `SOCKET_DEBUG` also apply in stdio mode. + In OAuth-enabled HTTP mode, if `SOCKET_API_KEY` is unset, the authenticated client's bearer token is forwarded to the Socket API. That token therefore must also be accepted by the configured upstream Socket API. + + To enable OAuth-backed auth for incoming MCP requests: + + ```bash + MCP_HTTP_MODE=true \ + SOCKET_OAUTH_ISSUER=https://issuer.example.com \ + SOCKET_OAUTH_INTROSPECTION_CLIENT_ID=your-client-id \ + SOCKET_OAUTH_INTROSPECTION_CLIENT_SECRET=your-client-secret \ + npx @socketsecurity/mcp@latest --http + ``` + + Add `TRUST_PROXY=true` only when the server is deployed behind a trusted reverse proxy or load balancer that normalizes the forwarded host and protocol headers. + 2. Configure your MCP client to connect to the HTTP server: ```json { diff --git a/index.ts b/index.ts index f979f50..50b6bba 100755 --- a/index.ts +++ b/index.ts @@ -1,5 +1,6 @@ #!/usr/bin/env -S node --experimental-strip-types import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import type { AuthInfo } from '@modelcontextprotocol/sdk/server/auth/types.js' import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js' import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' @@ -11,7 +12,7 @@ import readline from 'readline' import { join } from 'path' import { readFileSync } from 'fs' import { tmpdir } from 'os' -import { createServer } from 'http' +import { createServer, type IncomingMessage, type ServerResponse } from 'http' const __dirname = import.meta.dirname @@ -38,10 +39,36 @@ const logger = pino({ } }) +interface OAuthAuthorizationServerMetadata { + issuer: string + authorization_endpoint: string + token_endpoint: string + introspection_endpoint: string + [key: string]: unknown +} + +type AuthenticatedRequest = IncomingMessage & { auth?: AuthInfo } + // Socket API URL - use localhost when debugging is enabled, otherwise use production -const SOCKET_API_URL = process.env['SOCKET_DEBUG'] === 'true' +const DEFAULT_SOCKET_API_URL = process.env['SOCKET_DEBUG'] === 'true' ? 'http://localhost:8866/v0/purl?alerts=false&compact=false&fixable=false&licenseattrib=false&licensedetails=false' : 'https://api.socket.dev/v0/purl?alerts=false&compact=false&fixable=false&licenseattrib=false&licensedetails=false' +const SOCKET_API_URL = process.env['SOCKET_API_URL'] || DEFAULT_SOCKET_API_URL +const SOCKET_OAUTH_ISSUER = process.env['SOCKET_OAUTH_ISSUER'] || '' +const SOCKET_OAUTH_INTROSPECTION_CLIENT_ID = + process.env['SOCKET_OAUTH_INTROSPECTION_CLIENT_ID'] || '' +const SOCKET_OAUTH_INTROSPECTION_CLIENT_SECRET = + process.env['SOCKET_OAUTH_INTROSPECTION_CLIENT_SECRET'] || '' +const SOCKET_OAUTH_REQUIRED_SCOPES = ( + process.env['SOCKET_OAUTH_REQUIRED_SCOPES'] || 'packages:list' +) + .split(/\s+/u) + .map(scope => scope.trim()) + .filter(Boolean) +const TRUST_PROXY = process.env['TRUST_PROXY'] === 'true' +const OAUTH_WELL_KNOWN_PATH = '/.well-known/oauth-authorization-server' +const OAUTH_PROTECTED_RESOURCE_METADATA_PATH = + '/.well-known/oauth-protected-resource' // Function to get API key interactively (only for HTTP mode) async function getApiKeyInteractively (): Promise { @@ -68,13 +95,288 @@ async function getApiKeyInteractively (): Promise { // Initialize API key let SOCKET_API_KEY = process.env['SOCKET_API_KEY'] || '' -// Build headers dynamically to reflect current API key -function buildSocketHeaders (): Record { +// Build Socket API request headers with the provided access token. +function buildSocketHeaders (accessToken?: string): Record { return { 'user-agent': `socket-mcp/${VERSION}`, accept: 'application/x-ndjson', 'content-type': 'application/json', - authorization: `Bearer ${SOCKET_API_KEY}` + ...(accessToken ? { authorization: `Bearer ${accessToken}` } : {}) + } +} + +function splitScopes (scope: unknown): string[] { + if (typeof scope !== 'string') { + return [] + } + + return scope + .split(/\s+/u) + .map(value => value.trim()) + .filter(Boolean) +} + +function getRequestHeaderValue (header: string | string[] | undefined): string { + if (Array.isArray(header)) { + return header[0] || '' + } + + return header || '' +} + +function getForwardedHeaderValue (header: string | string[] | undefined): string { + return getRequestHeaderValue(header) + .split(',', 1)[0] + ?.trim() || '' +} + +function getRequestBaseUrl (req: IncomingMessage, fallbackPort: number): URL { + const forwardedProto = TRUST_PROXY + ? getForwardedHeaderValue(req.headers['x-forwarded-proto']).toLowerCase() + : '' + const forwardedHost = TRUST_PROXY + ? getForwardedHeaderValue(req.headers['x-forwarded-host']) + : '' + const host = forwardedHost || getRequestHeaderValue(req.headers.host).trim() || `localhost:${fallbackPort}` + const socketWithTls = req.socket as { encrypted?: boolean } + const protocol = forwardedProto === 'https' || forwardedProto === 'http' + ? forwardedProto + : (socketWithTls.encrypted ? 'https' : 'http') + + return new URL(`${protocol}://${host}/`) +} + +function parseJsonObject ( + responseText: string, + context: string +): Record { + try { + const parsed = JSON.parse(responseText) + + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + throw new Error('expected a JSON object') + } + + return parsed as Record + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + throw new Error(`${context} returned invalid JSON: ${message}`) + } +} + +function getProtectedResourceMetadataUrl (baseUrl: URL): string { + return new URL(OAUTH_PROTECTED_RESOURCE_METADATA_PATH, baseUrl).href +} + +function buildProtectedResourceMetadata ( + baseUrl: URL, + oauthMetadata: OAuthAuthorizationServerMetadata +): Record { + return { + resource: new URL('/', baseUrl).href, + authorization_servers: [oauthMetadata.issuer], + scopes_supported: SOCKET_OAUTH_REQUIRED_SCOPES, + resource_name: 'Socket MCP Server' + } +} + +function writeJson ( + res: ServerResponse, + statusCode: number, + body: unknown, + headers: Record = {} +): void { + res.writeHead(statusCode, { + 'Content-Type': 'application/json', + ...headers + }) + res.end(JSON.stringify(body)) +} + +function writeOAuthError ( + res: ServerResponse, + statusCode: number, + errorCode: string, + message: string, + resourceMetadataUrl?: string +): void { + const authenticateValue = resourceMetadataUrl + ? `Bearer error="${errorCode}", error_description="${message}", resource_metadata="${resourceMetadataUrl}"` + : `Bearer error="${errorCode}", error_description="${message}"` + + writeJson( + res, + statusCode, + { + error: errorCode, + error_description: message + }, + { 'WWW-Authenticate': authenticateValue } + ) +} + +const useHttp = process.env['MCP_HTTP_MODE'] === 'true' || process.argv.includes('--http') +const port = parseInt(process.env['MCP_PORT'] || '3000', 10) +const hasAnyOAuthConfig = Boolean( + SOCKET_OAUTH_ISSUER || + SOCKET_OAUTH_INTROSPECTION_CLIENT_ID || + SOCKET_OAUTH_INTROSPECTION_CLIENT_SECRET +) +const oauthEnabled = useHttp && Boolean( + SOCKET_OAUTH_ISSUER && + SOCKET_OAUTH_INTROSPECTION_CLIENT_ID && + SOCKET_OAUTH_INTROSPECTION_CLIENT_SECRET +) + +let oauthMetadataPromise: Promise | undefined + +async function loadOAuthMetadata (): Promise { + if (!oauthEnabled) { + return null + } + + if (!oauthMetadataPromise) { + const metadataPromise = (async () => { + const issuerUrl = new URL(SOCKET_OAUTH_ISSUER) + const response = await fetch(new URL(OAUTH_WELL_KNOWN_PATH, issuerUrl)) + const responseText = await response.text() + + if (!response.ok) { + throw new Error(`OAuth metadata discovery failed with status ${response.status}: ${responseText}`) + } + + const metadata = parseJsonObject(responseText, 'OAuth metadata discovery') + + for (const field of [ + 'issuer', + 'authorization_endpoint', + 'token_endpoint', + 'introspection_endpoint' + ] as const) { + if (typeof metadata[field] !== 'string' || !metadata[field]) { + throw new Error(`OAuth metadata missing required field: ${field}`) + } + } + + return metadata as OAuthAuthorizationServerMetadata + })() + + const retryableMetadataPromise = metadataPromise.catch((error) => { + if (oauthMetadataPromise === retryableMetadataPromise) { + oauthMetadataPromise = undefined + } + + throw error + }) + + oauthMetadataPromise = retryableMetadataPromise + } + + return await oauthMetadataPromise +} + +async function verifyAccessToken (token: string): Promise { + const oauthMetadata = await loadOAuthMetadata() + if (!oauthMetadata) { + throw new Error('OAuth is not configured for this server') + } + + const response = await fetch(oauthMetadata.introspection_endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + authorization: `Basic ${Buffer.from(`${SOCKET_OAUTH_INTROSPECTION_CLIENT_ID}:${SOCKET_OAUTH_INTROSPECTION_CLIENT_SECRET}`).toString('base64')}` + }, + body: new URLSearchParams({ token }).toString() + }) + const responseText = await response.text() + + if (!response.ok) { + throw new Error(`Token introspection failed with status ${response.status}: ${responseText}`) + } + + const introspection = parseJsonObject(responseText, 'Token introspection') + if (!introspection['active']) { + return null + } + + const expiresAt = typeof introspection['exp'] === 'number' + ? introspection['exp'] + : Number(introspection['exp']) + + return { + token, + clientId: typeof introspection['client_id'] === 'string' + ? introspection['client_id'] + : 'unknown', + scopes: splitScopes(introspection['scope']), + ...(Number.isFinite(expiresAt) ? { expiresAt } : {}), + extra: introspection + } +} + +async function authenticateRequest ( + req: AuthenticatedRequest, + res: ServerResponse, + resourceMetadataUrl: string +): Promise<{ ok: false } | { ok: true, authInfo: AuthInfo }> { + const authHeader = getRequestHeaderValue(req.headers.authorization).trim() + if (!authHeader) { + writeOAuthError(res, 401, 'invalid_request', 'Missing Authorization header', resourceMetadataUrl) + return { ok: false } + } + + const [type, token] = authHeader.split(/\s+/u) + if ((type || '').toLowerCase() !== 'bearer' || !token) { + writeOAuthError( + res, + 401, + 'invalid_request', + "Invalid Authorization header format, expected 'Bearer TOKEN'", + resourceMetadataUrl + ) + return { ok: false } + } + + let authInfo: AuthInfo | null + try { + authInfo = await verifyAccessToken(token) + } catch (error) { + logger.error(`Token verification failed: ${error instanceof Error ? error.message : String(error)}`) + writeJson(res, 500, { + error: 'server_error', + error_description: 'Token verification failed' + }) + return { ok: false } + } + + if (!authInfo) { + writeOAuthError(res, 401, 'invalid_token', 'Invalid or expired token', resourceMetadataUrl) + return { ok: false } + } + + if (typeof authInfo.expiresAt === 'number' && + authInfo.expiresAt < Date.now() / 1000) { + writeOAuthError(res, 401, 'invalid_token', 'Token has expired', resourceMetadataUrl) + return { ok: false } + } + + const missingScopes = SOCKET_OAUTH_REQUIRED_SCOPES.filter(scope => !authInfo.scopes.includes(scope)) + if (missingScopes.length > 0) { + writeOAuthError( + res, + 403, + 'insufficient_scope', + `Missing required scopes: ${missingScopes.join(', ')}`, + resourceMetadataUrl + ) + return { ok: false } + } + + req.auth = authInfo + return { + ok: true, + authInfo } } @@ -97,8 +399,17 @@ function createConfiguredServer (): McpServer { readOnlyHint: true, }, }, - async ({ packages }) => { + async ({ packages }, extra) => { logger.info(`Received request for ${packages.length} packages`) + const accessToken = extra.authInfo?.token || SOCKET_API_KEY + if (!accessToken) { + const errorMsg = 'Authentication is required. Configure SOCKET_API_KEY for stdio mode or connect through OAuth-enabled HTTP mode.' + logger.error(errorMsg) + return { + content: [{ type: 'text', text: errorMsg }], + isError: true + } + } // Build components array for the API request const components = packages.map((pkg: { ecosystem?: string; depname: string; version?: string }) => { @@ -118,12 +429,30 @@ function createConfiguredServer (): McpServer { // Make a POST request to the Socket API with all packages const response = await fetch(SOCKET_API_URL, { method: 'POST', - headers: buildSocketHeaders(), + headers: buildSocketHeaders(accessToken), body: JSON.stringify({ components }) }) const responseText = await response.text() + if (response.status === 401) { + const errorMsg = `Socket authentication failed [401]. Re-authenticate and retry. ${responseText}` + logger.error(errorMsg) + return { + content: [{ type: 'text', text: errorMsg }], + isError: true + } + } + + if (response.status === 403) { + const errorMsg = `Socket denied access [403]. Re-authenticate with the correct organization or repository permissions and retry. ${responseText}` + logger.error(errorMsg) + return { + content: [{ type: 'text', text: errorMsg }], + isError: true + } + } + if (response.status !== 200) { const errorMsg = `Error processing packages: [${response.status}] ${responseText}` logger.error(errorMsg) @@ -227,12 +556,13 @@ function createConfiguredServer (): McpServer { return srv } -// Determine transport mode from environment or arguments -const useHttp = process.env['MCP_HTTP_MODE'] === 'true' || process.argv.includes('--http') -const port = parseInt(process.env['MCP_PORT'] || '3000', 10) +if (useHttp && hasAnyOAuthConfig && !oauthEnabled) { + logger.error('Incomplete OAuth configuration for HTTP mode. Set SOCKET_OAUTH_ISSUER, SOCKET_OAUTH_INTROSPECTION_CLIENT_ID, and SOCKET_OAUTH_INTROSPECTION_CLIENT_SECRET together.') + process.exit(1) +} // Validate API key - in stdio mode, we can't prompt interactively -if (!SOCKET_API_KEY) { +if (!SOCKET_API_KEY && !(useHttp && oauthEnabled)) { if (useHttp) { // In HTTP mode, we can prompt for the API key logger.error('SOCKET_API_KEY environment variable is not set') @@ -245,6 +575,16 @@ if (!SOCKET_API_KEY) { } } +if (oauthEnabled) { + try { + await loadOAuthMetadata() + logger.info(`Enabled OAuth-backed MCP auth with issuer ${SOCKET_OAUTH_ISSUER}`) + } catch (error) { + logger.error(`Failed to initialize OAuth metadata: ${error instanceof Error ? error.message : String(error)}`) + process.exit(1) + } +} + if (useHttp) { // HTTP mode with Server-Sent Events logger.info(`Starting HTTP server on port ${port}`) @@ -279,41 +619,41 @@ if (useHttp) { reapInterval.unref() // don't keep the process alive just for the reaper const httpServer = createServer(async (req, res) => { + const authenticatedReq = req as AuthenticatedRequest + // Parse URL first to check for health endpoint let url: URL try { url = new URL(req.url!, `http://localhost:${port}`) } catch (error) { logger.warn(`Invalid URL in request: ${req.url} - ${error}`) - res.writeHead(400, { 'Content-Type': 'application/json' }) - res.end(JSON.stringify({ + writeJson(res, 400, { jsonrpc: '2.0', error: { code: -32000, message: 'Bad Request: Invalid URL' }, id: null - })) + }) return } // Health check endpoint for K8s/Docker - bypass origin validation if (url.pathname === '/health') { - res.writeHead(200, { 'Content-Type': 'application/json' }) - res.end(JSON.stringify({ + writeJson(res, 200, { status: 'healthy', service: 'socket-mcp', version: VERSION, timestamp: new Date().toISOString() - })) + }) return } // Validate Origin header as required by MCP spec (for non-health endpoints) - const origin = req.headers.origin + const origin = getRequestHeaderValue(req.headers.origin).trim() // Check if origin is from localhost (any port) - safe for local development const isLocalhostOrigin = (originUrl: string): boolean => { try { - const url = new URL(originUrl) - return url.hostname === 'localhost' || url.hostname === '127.0.0.1' + const originValue = new URL(originUrl) + return originValue.hostname === 'localhost' || originValue.hostname === '127.0.0.1' } catch { return false } @@ -326,7 +666,7 @@ if (useHttp) { // Check if request is from localhost (for same-origin requests that don't send Origin header) // Use strict matching to prevent spoofing via subdomains like "malicious-localhost.evil.com" - const host = req.headers.host || '' + const host = getRequestHeaderValue(req.headers.host).trim() // Extract hostnames from allowedOrigins for Host header validation const allowedHosts = allowedOrigins.map(o => new URL(o).hostname) @@ -346,12 +686,11 @@ if (useHttp) { if (!isValidOrigin) { logger.warn(`Rejected request from invalid origin: ${origin || 'missing'} (host: ${host})`) - res.writeHead(403, { 'Content-Type': 'application/json' }) - res.end(JSON.stringify({ + writeJson(res, 403, { jsonrpc: '2.0', error: { code: -32000, message: 'Forbidden: Invalid origin' }, id: null - })) + }) return } @@ -360,8 +699,8 @@ if (useHttp) { if (origin) { res.setHeader('Access-Control-Allow-Origin', origin) res.setHeader('Access-Control-Allow-Methods', 'GET, POST, DELETE, OPTIONS') - res.setHeader('Access-Control-Allow-Headers', 'Content-Type, Accept, Mcp-Session-Id') - res.setHeader('Access-Control-Expose-Headers', 'Mcp-Session-Id') + res.setHeader('Access-Control-Allow-Headers', 'Authorization, Content-Type, Accept, Mcp-Session-Id') + res.setHeader('Access-Control-Expose-Headers', 'Mcp-Session-Id, WWW-Authenticate') } if (req.method === 'OPTIONS') { @@ -370,6 +709,21 @@ if (useHttp) { return } + const baseUrl = getRequestBaseUrl(req, port) + if (oauthEnabled && url.pathname === OAUTH_PROTECTED_RESOURCE_METADATA_PATH) { + const oauthMetadata = await loadOAuthMetadata() + if (!oauthMetadata) { + writeJson(res, 500, { + error: 'server_error', + error_description: 'OAuth metadata is unavailable' + }) + return + } + + writeJson(res, 200, buildProtectedResourceMetadata(baseUrl, oauthMetadata)) + return + } + if (url.pathname === '/') { // Ensure Accept header includes required MIME types for MCP Streamable HTTP spec. // Some clients (e.g. Cursor) may not send these, causing the SDK to reject with 406. @@ -386,6 +740,18 @@ if (useHttp) { } } + if (oauthEnabled) { + const authResult = await authenticateRequest( + authenticatedReq, + res, + getProtectedResourceMetadataUrl(baseUrl) + ) + + if (!authResult.ok) { + return + } + } + if (req.method === 'POST') { // Buffer the body, then pass it as parsedBody so hono doesn't re-read the consumed stream. let body = '' @@ -393,7 +759,7 @@ if (useHttp) { req.on('end', async () => { try { const jsonData = JSON.parse(body) - const sessionId = (req.headers['mcp-session-id'] as string) || undefined + const sessionId = getRequestHeaderValue(req.headers['mcp-session-id']) || undefined const session = sessionId ? sessions.get(sessionId) : undefined let transport = session?.transport @@ -419,12 +785,11 @@ if (useHttp) { } if (!transport) { - res.writeHead(400, { 'Content-Type': 'application/json' }) - res.end(JSON.stringify({ + writeJson(res, 400, { jsonrpc: '2.0', error: { code: -32000, message: 'Bad Request: No valid session. Send initialize first.' }, id: null - })) + }) return } @@ -434,68 +799,63 @@ if (useHttp) { if (activeSession) activeSession.lastActivity = Date.now() } - await transport.handleRequest(req, res, jsonData) + await transport.handleRequest(authenticatedReq, res, jsonData) } catch (error) { logger.error(`Error processing POST request: ${error}`) if (!res.headersSent) { - res.writeHead(500) - res.end(JSON.stringify({ + writeJson(res, 500, { jsonrpc: '2.0', error: { code: -32603, message: 'Internal server error' }, id: null - })) + }) } } }) } else if (req.method === 'GET') { - const sessionId = (req.headers['mcp-session-id'] as string) || undefined + const sessionId = getRequestHeaderValue(req.headers['mcp-session-id']) || undefined const session = sessionId ? sessions.get(sessionId) : undefined if (!session) { - res.writeHead(404, { 'Content-Type': 'application/json' }) - res.end(JSON.stringify({ + writeJson(res, 404, { jsonrpc: '2.0', error: { code: -32000, message: 'Not Found: Invalid or expired session. Re-initialize.' }, id: null - })) + }) return } try { session.lastActivity = Date.now() - await session.transport.handleRequest(req, res) + await session.transport.handleRequest(authenticatedReq, res) } catch (error) { logger.error(`Error processing GET request: ${error}`) if (!res.headersSent) { - res.writeHead(500) - res.end(JSON.stringify({ + writeJson(res, 500, { jsonrpc: '2.0', error: { code: -32603, message: 'Internal server error' }, id: null - })) + }) } } } else if (req.method === 'DELETE') { - const sessionId = (req.headers['mcp-session-id'] as string) || undefined + const sessionId = getRequestHeaderValue(req.headers['mcp-session-id']) || undefined const transport = sessionId ? sessions.get(sessionId)?.transport : undefined if (!transport) { - res.writeHead(404, { 'Content-Type': 'application/json' }) - res.end(JSON.stringify({ + writeJson(res, 404, { jsonrpc: '2.0', error: { code: -32000, message: 'Not Found: Invalid or expired session.' }, id: null - })) + }) return } try { - await transport.handleRequest(req, res) + await transport.handleRequest(authenticatedReq, res) } catch (error) { logger.error(`Error processing DELETE request: ${error}`) if (!res.headersSent) { - res.writeHead(500) - res.end(JSON.stringify({ + writeJson(res, 500, { jsonrpc: '2.0', error: { code: -32603, message: 'Internal server error' }, id: null - })) + }) } } } else { diff --git a/oauth.test.ts b/oauth.test.ts new file mode 100644 index 0000000..774fb72 --- /dev/null +++ b/oauth.test.ts @@ -0,0 +1,396 @@ +#!/usr/bin/env node +import { test } from 'node:test' +import assert from 'node:assert/strict' +import { once } from 'node:events' +import { createServer, type IncomingMessage, type ServerResponse } from 'node:http' +import type { AddressInfo } from 'node:net' +import { spawn } from 'node:child_process' +import { setTimeout as delay } from 'node:timers/promises' +import { join } from 'node:path' +import { Client } from '@modelcontextprotocol/sdk/client/index.js' +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' + +const serverPath = join(import.meta.dirname, 'index.ts') +const inheritedEnv = Object.fromEntries( + Object.entries(process.env).filter(([, value]) => value !== undefined) +) as Record +const oauthWellKnownPath = '/.well-known/oauth-authorization-server' +const protectedResourceMetadataPath = '/.well-known/oauth-protected-resource' +const mockIntrospectionResponses: Record> = { + 'token-without-exp': { + active: true, + client_id: 'oauth-test-client', + scope: 'packages:list' + }, + 'token-with-wrong-scope': { + active: true, + client_id: 'oauth-test-client', + scope: 'packages:write' + } +} + +function closeHttpServer (server: ReturnType): Promise { + return new Promise((resolve, reject) => { + server.close((error) => { + if (error) { + reject(error) + return + } + + resolve() + }) + }) +} + +async function getFreePort (): Promise { + const server = createServer() + server.listen(0, '127.0.0.1') + await once(server, 'listening') + const address = server.address() as AddressInfo + await closeHttpServer(server) + return address.port +} + +async function readRequestBody (req: IncomingMessage): Promise { + let body = '' + for await (const chunk of req) { + body += chunk + } + + return body +} + +async function assertOAuthErrorResponse ( + response: Response, + serverBaseUrl: string, + expected: { + status: number + error: string + errorDescription: string + } +): Promise { + const body = await response.json() as { + error?: string + error_description?: string + } + + assert.equal(response.status, expected.status) + assert.equal(body.error, expected.error) + assert.equal(body.error_description, expected.errorDescription) + assert.equal( + response.headers.get('www-authenticate'), + `Bearer error="${expected.error}", error_description="${expected.errorDescription}", resource_metadata="${serverBaseUrl}${protectedResourceMetadataPath}"` + ) +} + +async function startMockIssuer (): Promise<{ + baseUrl: string + close: () => Promise +}> { + const server = createServer(async (req: IncomingMessage, res: ServerResponse) => { + const baseUrl = `http://${req.headers.host}` + const url = new URL(req.url || '/', baseUrl) + + if (req.method === 'GET' && url.pathname === oauthWellKnownPath) { + res.writeHead(200, { 'Content-Type': 'application/json' }) + res.end(JSON.stringify({ + issuer: baseUrl, + authorization_endpoint: `${baseUrl}/authorize`, + token_endpoint: `${baseUrl}/token`, + introspection_endpoint: `${baseUrl}/introspect` + })) + return + } + + if (req.method === 'POST' && url.pathname === '/introspect') { + const body = await readRequestBody(req) + const token = new URLSearchParams(body).get('token') + const introspectionResponse = token ? mockIntrospectionResponses[token] : undefined + + res.writeHead(200, { 'Content-Type': 'application/json' }) + res.end(JSON.stringify(introspectionResponse || { active: false })) + return + } + + res.writeHead(404) + res.end('Not found') + }) + + server.listen(0, '127.0.0.1') + await once(server, 'listening') + const address = server.address() as AddressInfo + + return { + baseUrl: `http://127.0.0.1:${address.port}`, + close: async () => { await closeHttpServer(server) } + } +} + +async function stopChildProcess (child: ReturnType): Promise { + if (child.exitCode !== null) { + return + } + + let exited = false + const onExit = once(child, 'exit').then(() => { + exited = true + }) + + child.kill('SIGTERM') + await Promise.race([ + onExit, + delay(3000).then(() => { + if (exited || child.exitCode !== null) { + return + } + + child.kill('SIGKILL') + return onExit + }) + ]) +} + +async function waitForHealth ( + baseUrl: string, + child: ReturnType, + getOutput: () => string +): Promise { + const timeoutAt = Date.now() + 5000 + + while (Date.now() < timeoutAt) { + if (child.exitCode !== null) { + throw new Error(`HTTP server exited before becoming ready:\n${getOutput()}`) + } + + try { + const response = await fetch(`${baseUrl}/health`) + if (response.ok) { + return + } + } catch {} + + await delay(100) + } + + throw new Error(`Timed out waiting for HTTP server readiness:\n${getOutput()}`) +} + +async function startOAuthHttpServer ( + issuerBaseUrl: string, + extraEnv: Record = {} +): Promise<{ + baseUrl: string + close: () => Promise +}> { + const port = await getFreePort() + let output = '' + + const child = spawn('node', ['--experimental-strip-types', serverPath], { + cwd: import.meta.dirname, + env: { + ...inheritedEnv, + MCP_HTTP_MODE: 'true', + MCP_PORT: String(port), + SOCKET_OAUTH_ISSUER: issuerBaseUrl, + SOCKET_OAUTH_INTROSPECTION_CLIENT_ID: 'oauth-test-client-id', + SOCKET_OAUTH_INTROSPECTION_CLIENT_SECRET: 'oauth-test-client-secret', + ...extraEnv + }, + stdio: ['ignore', 'pipe', 'pipe'] + }) + + child.stdout.on('data', (chunk: Buffer) => { output += chunk.toString() }) + child.stderr.on('data', (chunk: Buffer) => { output += chunk.toString() }) + + const baseUrl = `http://127.0.0.1:${port}` + await waitForHealth(baseUrl, child, () => output) + + return { + baseUrl, + close: async () => { await stopChildProcess(child) } + } +} + +test('stdio mode ignores partial OAuth config', async (t) => { + const transport = new StdioClientTransport({ + command: 'node', + args: ['--experimental-strip-types', serverPath], + env: { + ...inheritedEnv, + SOCKET_API_KEY: 'test-api-key', + SOCKET_OAUTH_ISSUER: 'https://issuer.example.test' + } + }) + + const client = new Client( + { name: 'oauth-stdio-test-client', version: '1.0.0' }, + { capabilities: {} } + ) + + t.after(async () => { + await client.close().catch(() => {}) + }) + + await client.connect(transport) + const tools = await client.listTools() + assert.ok(tools.tools.some(tool => tool.name === 'depscore')) +}) + +test('HTTP OAuth metadata and auth semantics', async (t) => { + const issuer = await startMockIssuer() + const server = await startOAuthHttpServer(issuer.baseUrl) + + t.after(async () => { + await server.close() + await issuer.close() + }) + + await t.test('does not expose upstream authorization server metadata', async () => { + const response = await fetch(`${server.baseUrl}${oauthWellKnownPath}`) + + assert.equal(response.status, 404) + assert.match(await response.text(), /not found/i) + }) + + await t.test('serves protected resource metadata pointing to the issuer', async () => { + const response = await fetch(`${server.baseUrl}${protectedResourceMetadataPath}`) + const metadata = await response.json() as { + authorization_servers?: string[] + resource?: string + } + + assert.equal(response.status, 200) + assert.deepEqual(metadata.authorization_servers, [issuer.baseUrl]) + assert.equal(metadata.resource, `${server.baseUrl}/`) + }) + + await t.test('ignores forwarded host and proto headers unless TRUST_PROXY is enabled', async () => { + const response = await fetch(`${server.baseUrl}${protectedResourceMetadataPath}`, { + headers: { + 'X-Forwarded-Host': 'attacker.example.com', + 'X-Forwarded-Proto': 'https' + } + }) + const metadata = await response.json() as { resource?: string } + + assert.equal(response.status, 200) + assert.equal(metadata.resource, `${server.baseUrl}/`) + + const unauthenticatedResponse = await fetch(`${server.baseUrl}/`, { + method: 'POST', + headers: { + 'X-Forwarded-Host': 'attacker.example.com', + 'X-Forwarded-Proto': 'https' + } + }) + + await assertOAuthErrorResponse(unauthenticatedResponse, server.baseUrl, { + status: 401, + error: 'invalid_request', + errorDescription: 'Missing Authorization header' + }) + }) + + await t.test('returns invalid_request when the Authorization header is missing', async () => { + const response = await fetch(`${server.baseUrl}/`, { method: 'POST' }) + await assertOAuthErrorResponse(response, server.baseUrl, { + status: 401, + error: 'invalid_request', + errorDescription: 'Missing Authorization header' + }) + }) + + await t.test('returns invalid_token when introspection reports an inactive token', async () => { + const response = await fetch(`${server.baseUrl}/`, { + method: 'POST', + headers: { + Authorization: 'Bearer inactive-token' + } + }) + + await assertOAuthErrorResponse(response, server.baseUrl, { + status: 401, + error: 'invalid_token', + errorDescription: 'Invalid or expired token' + }) + }) + + await t.test('returns insufficient_scope when the token is missing required scopes', async () => { + const response = await fetch(`${server.baseUrl}/`, { + method: 'POST', + headers: { + Authorization: 'Bearer token-with-wrong-scope' + } + }) + + await assertOAuthErrorResponse(response, server.baseUrl, { + status: 403, + error: 'insufficient_scope', + errorDescription: 'Missing required scopes: packages:list' + }) + }) + + await t.test('accepts an active token even when introspection omits exp', async () => { + const response = await fetch(`${server.baseUrl}/`, { + method: 'POST', + headers: { + Authorization: 'Bearer token-without-exp', + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '0.1.0', + capabilities: {}, + clientInfo: { + name: 'oauth-http-test-client', + version: '1.0.0' + } + } + }) + }) + + const body = await response.json() as { result?: { serverInfo?: { name?: string } } } + + assert.equal(response.status, 200) + assert.equal(body.result?.serverInfo?.name, 'socket') + }) +}) + +test('TRUST_PROXY enables forwarded host and proto for OAuth metadata URLs', async (t) => { + const issuer = await startMockIssuer() + const server = await startOAuthHttpServer(issuer.baseUrl, { TRUST_PROXY: 'true' }) + + t.after(async () => { + await server.close() + await issuer.close() + }) + + const response = await fetch(`${server.baseUrl}${protectedResourceMetadataPath}`, { + headers: { + 'X-Forwarded-Host': 'proxy.example.com', + 'X-Forwarded-Proto': 'https' + } + }) + const metadata = await response.json() as { resource?: string } + + assert.equal(response.status, 200) + assert.equal(metadata.resource, 'https://proxy.example.com/') + + const unauthenticatedResponse = await fetch(`${server.baseUrl}/`, { + method: 'POST', + headers: { + 'X-Forwarded-Host': 'proxy.example.com', + 'X-Forwarded-Proto': 'https' + } + }) + + await assertOAuthErrorResponse(unauthenticatedResponse, 'https://proxy.example.com', { + status: 401, + error: 'invalid_request', + errorDescription: 'Missing Authorization header' + }) +})