Skip to content

Commit 4598067

Browse files
committed
feat(mcp): OAuth 2.1 + PKCE for outbound MCP servers
Adds spec-compliant OAuth support for MCP servers that require it (Linear, Slack, Notion, Atlassian, etc.) using the SDK's OAuthClientProvider. Tokens are persisted per-user-per-server and refreshed automatically. Also supports pre-registered OAuth clients for servers that don't expose Dynamic Client Registration.
1 parent 9eeb1b2 commit 4598067

26 files changed

Lines changed: 16644 additions & 117 deletions

File tree

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js'
2+
import { db } from '@sim/db'
3+
import { mcpServers } from '@sim/db/schema'
4+
import { createLogger } from '@sim/logger'
5+
import { toError } from '@sim/utils/errors'
6+
import { eq } from 'drizzle-orm'
7+
import type { NextRequest } from 'next/server'
8+
import { NextResponse } from 'next/server'
9+
import { getSession } from '@/lib/auth'
10+
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
11+
import {
12+
clearState,
13+
clearVerifier,
14+
loadOauthRowByState,
15+
loadPreregisteredClient,
16+
SimMcpOauthProvider,
17+
} from '@/lib/mcp/oauth'
18+
import { mcpService } from '@/lib/mcp/service'
19+
20+
const logger = createLogger('McpOauthCallbackAPI')
21+
22+
export const dynamic = 'force-dynamic'
23+
24+
function escapeHtml(value: string): string {
25+
return value
26+
.replace(/&/g, '&')
27+
.replace(/</g, '&lt;')
28+
.replace(/>/g, '&gt;')
29+
.replace(/"/g, '&quot;')
30+
.replace(/'/g, '&#39;')
31+
}
32+
33+
function htmlClose(message: string, ok: boolean): NextResponse {
34+
const safeMessage = escapeHtml(message)
35+
const title = ok ? 'Connected' : 'Connection failed'
36+
const body = `<!doctype html><html><head><meta charset="utf-8"><title>${title}</title></head><body style="font-family: system-ui; padding: 24px"><p>${safeMessage}</p><script>
37+
try { window.opener && window.opener.postMessage({ type: 'mcp-oauth', ok: ${ok ? 'true' : 'false'} }, window.location.origin) } catch (e) {}
38+
setTimeout(function () { window.close() }, 800)
39+
</script></body></html>`
40+
return new NextResponse(body, {
41+
headers: { 'Content-Type': 'text/html; charset=utf-8' },
42+
})
43+
}
44+
45+
export const GET = withRouteHandler(async (request: NextRequest) => {
46+
const url = new URL(request.url)
47+
const state = url.searchParams.get('state')
48+
const code = url.searchParams.get('code')
49+
const errorParam = url.searchParams.get('error')
50+
51+
if (errorParam) {
52+
logger.warn(`MCP OAuth callback received error: ${errorParam}`)
53+
return htmlClose(`Authorization failed: ${errorParam}`, false)
54+
}
55+
if (!state || !code) {
56+
return htmlClose('Missing state or code in callback URL.', false)
57+
}
58+
59+
try {
60+
const session = await getSession()
61+
if (!session?.user?.id) {
62+
return htmlClose('You must be signed in to complete authorization.', false)
63+
}
64+
65+
const row = await loadOauthRowByState(state)
66+
if (!row) {
67+
return htmlClose('Invalid or expired authorization state.', false)
68+
}
69+
70+
if (session.user.id !== row.userId) {
71+
return htmlClose('You must be signed in as the same user that initiated the flow.', false)
72+
}
73+
74+
const [server] = await db
75+
.select({ id: mcpServers.id, url: mcpServers.url, workspaceId: mcpServers.workspaceId })
76+
.from(mcpServers)
77+
.where(eq(mcpServers.id, row.mcpServerId))
78+
.limit(1)
79+
if (!server || !server.url) {
80+
return htmlClose('Server no longer exists.', false)
81+
}
82+
83+
// Burn state before token exchange so a replayed callback cannot reuse it.
84+
await clearState(row.id)
85+
86+
const preregistered = await loadPreregisteredClient(server.id)
87+
const provider = new SimMcpOauthProvider({ row, preregistered })
88+
const result = await mcpAuth(provider, {
89+
serverUrl: server.url,
90+
authorizationCode: code,
91+
})
92+
93+
await clearVerifier(row.id)
94+
95+
if (result !== 'AUTHORIZED') {
96+
return htmlClose('Authorization did not complete.', false)
97+
}
98+
99+
try {
100+
await mcpService.clearCache(server.workspaceId)
101+
await mcpService.discoverServerTools(row.userId, server.id, server.workspaceId)
102+
} catch (e) {
103+
logger.warn('Post-auth tools refresh failed', toError(e).message)
104+
}
105+
106+
return htmlClose('Connected. You can close this window.', true)
107+
} catch (error) {
108+
logger.error('MCP OAuth callback failed', error)
109+
return htmlClose(`Authorization failed: ${toError(error).message}`, false)
110+
}
111+
})
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js'
2+
import { db } from '@sim/db'
3+
import { mcpServers } from '@sim/db/schema'
4+
import { createLogger } from '@sim/logger'
5+
import { toError } from '@sim/utils/errors'
6+
import { and, eq, isNull } from 'drizzle-orm'
7+
import type { NextRequest } from 'next/server'
8+
import { NextResponse } from 'next/server'
9+
import { startMcpOauthQuerySchema } from '@/lib/api/contracts/mcp'
10+
import { validationErrorResponse } from '@/lib/api/server'
11+
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
12+
import { withMcpAuth } from '@/lib/mcp/middleware'
13+
import {
14+
getOrCreateOauthRow,
15+
loadPreregisteredClient,
16+
McpOauthRedirectRequired,
17+
SimMcpOauthProvider,
18+
} from '@/lib/mcp/oauth'
19+
import { createMcpErrorResponse } from '@/lib/mcp/utils'
20+
21+
const logger = createLogger('McpOauthStartAPI')
22+
23+
export const dynamic = 'force-dynamic'
24+
25+
export const GET = withRouteHandler(
26+
withMcpAuth('write')(async (request: NextRequest, { userId, workspaceId, requestId }) => {
27+
try {
28+
const queryResult = startMcpOauthQuerySchema.safeParse(
29+
Object.fromEntries(new URL(request.url).searchParams)
30+
)
31+
if (!queryResult.success) {
32+
return validationErrorResponse(queryResult.error)
33+
}
34+
const { serverId } = queryResult.data
35+
36+
const [server] = await db
37+
.select()
38+
.from(mcpServers)
39+
.where(
40+
and(
41+
eq(mcpServers.id, serverId),
42+
eq(mcpServers.workspaceId, workspaceId),
43+
isNull(mcpServers.deletedAt)
44+
)
45+
)
46+
.limit(1)
47+
48+
if (!server) {
49+
return createMcpErrorResponse(new Error('Server not found'), 'Server not found', 404)
50+
}
51+
if (server.authType !== 'oauth') {
52+
return createMcpErrorResponse(
53+
new Error(`Server authType is "${server.authType}", not oauth`),
54+
'Server is not configured for OAuth',
55+
400
56+
)
57+
}
58+
if (!server.url) {
59+
return createMcpErrorResponse(new Error('Server has no URL'), 'Missing server URL', 400)
60+
}
61+
62+
const row = await getOrCreateOauthRow({
63+
mcpServerId: server.id,
64+
userId,
65+
workspaceId,
66+
})
67+
const preregistered = await loadPreregisteredClient(server.id)
68+
const provider = new SimMcpOauthProvider({ row, preregistered })
69+
70+
try {
71+
const result = await mcpAuth(provider, { serverUrl: server.url })
72+
if (result === 'AUTHORIZED') {
73+
return NextResponse.json({ status: 'already_authorized' })
74+
}
75+
return createMcpErrorResponse(
76+
new Error('Provider did not capture redirect URL'),
77+
'Failed to start OAuth flow',
78+
500
79+
)
80+
} catch (e) {
81+
if (e instanceof McpOauthRedirectRequired) {
82+
logger.info(`[${requestId}] OAuth redirect for server ${serverId}`)
83+
return NextResponse.json({
84+
status: 'redirect',
85+
authorizationUrl: e.authorizationUrl,
86+
})
87+
}
88+
throw e
89+
}
90+
} catch (error) {
91+
logger.error(`[${requestId}] Error starting MCP OAuth flow:`, error)
92+
return createMcpErrorResponse(toError(error), 'Failed to start OAuth flow', 500)
93+
}
94+
})
95+
)

apps/sim/app/api/mcp/servers/[id]/route.ts

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import { AuditAction, AuditResourceType, recordAudit } from '@sim/audit'
22
import { db } from '@sim/db'
3-
import { mcpServers } from '@sim/db/schema'
3+
import { mcpServerOauth, mcpServers } from '@sim/db/schema'
44
import { createLogger } from '@sim/logger'
55
import { toError } from '@sim/utils/errors'
66
import { and, eq, isNull } from 'drizzle-orm'
77
import type { NextRequest } from 'next/server'
88
import { updateMcpServerBodySchema } from '@/lib/api/contracts/mcp'
9+
import { encryptSecret } from '@/lib/core/security/encryption'
910
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
1011
import {
1112
McpDnsResolutionError,
@@ -53,7 +54,13 @@ export const PATCH = withRouteHandler(
5354
)
5455

5556
// Remove workspaceId from body to prevent it from being updated
56-
const { workspaceId: _, ...updateData } = body
57+
const { workspaceId: _, oauthClientSecret, ...updateData } = body
58+
const finalUpdateData: Record<string, unknown> = { ...updateData }
59+
if (oauthClientSecret !== undefined) {
60+
finalUpdateData.oauthClientSecret = oauthClientSecret
61+
? (await encryptSecret(oauthClientSecret)).encrypted
62+
: null
63+
}
5764

5865
if (updateData.url) {
5966
try {
@@ -94,7 +101,7 @@ export const PATCH = withRouteHandler(
94101
const [updatedServer] = await db
95102
.update(mcpServers)
96103
.set({
97-
...updateData,
104+
...finalUpdateData,
98105
updatedAt: new Date(),
99106
})
100107
.where(
@@ -114,8 +121,17 @@ export const PATCH = withRouteHandler(
114121
)
115122
}
116123

124+
const urlChanged = body.url !== undefined && currentServer?.url !== body.url
125+
126+
if (urlChanged) {
127+
await db.delete(mcpServerOauth).where(eq(mcpServerOauth.mcpServerId, serverId))
128+
logger.info(
129+
`[${requestId}] Cleared OAuth credentials for server ${serverId} due to URL change`
130+
)
131+
}
132+
117133
const shouldClearCache =
118-
(body.url !== undefined && currentServer?.url !== body.url) ||
134+
urlChanged ||
119135
body.enabled !== undefined ||
120136
body.headers !== undefined ||
121137
body.timeout !== undefined ||
@@ -149,7 +165,8 @@ export const PATCH = withRouteHandler(
149165
request,
150166
})
151167

152-
return createMcpSuccessResponse({ server: updatedServer })
168+
const { oauthClientSecret: _secret, ...safeServer } = updatedServer
169+
return createMcpSuccessResponse({ server: safeServer })
153170
} catch (error) {
154171
logger.error(`[${requestId}] Error updating MCP server:`, error)
155172
return createMcpErrorResponse(toError(error), 'Failed to update MCP server', 500)

apps/sim/app/api/mcp/servers/route.ts

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import { and, eq, isNull } from 'drizzle-orm'
88
import type { NextRequest } from 'next/server'
99
import { createMcpServerBodySchema, deleteMcpServerByQuerySchema } from '@/lib/api/contracts/mcp'
1010
import { validationErrorResponse } from '@/lib/api/server'
11+
import { encryptSecret } from '@/lib/core/security/encryption'
1112
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
1213
import {
1314
McpDnsResolutionError,
@@ -17,6 +18,7 @@ import {
1718
validateMcpServerSsrf,
1819
} from '@/lib/mcp/domain-check'
1920
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
21+
import { detectMcpAuthType } from '@/lib/mcp/oauth'
2022
import { mcpService } from '@/lib/mcp/service'
2123
import {
2224
createMcpErrorResponse,
@@ -37,11 +39,16 @@ export const GET = withRouteHandler(
3739
try {
3840
logger.info(`[${requestId}] Listing MCP servers for workspace ${workspaceId}`)
3941

40-
const servers = await db
42+
const rows = await db
4143
.select()
4244
.from(mcpServers)
4345
.where(and(eq(mcpServers.workspaceId, workspaceId), isNull(mcpServers.deletedAt)))
4446

47+
const servers = rows.map(({ oauthClientSecret: _secret, ...rest }) => ({
48+
...rest,
49+
hasOauthClientSecret: !!_secret,
50+
}))
51+
4552
logger.info(
4653
`[${requestId}] Listed ${servers.length} MCP servers for workspace ${workspaceId}`
4754
)
@@ -105,6 +112,25 @@ export const POST = withRouteHandler(
105112

106113
const serverId = body.url ? generateMcpServerId(workspaceId, body.url) : generateId()
107114

115+
let resolvedAuthType: 'none' | 'headers' | 'oauth' = body.authType ?? 'headers'
116+
if (!body.authType && body.url && !body.headers) {
117+
try {
118+
resolvedAuthType = await detectMcpAuthType(body.url)
119+
logger.info(`[${requestId}] Probed ${body.url}: authType=${resolvedAuthType}`)
120+
} catch (e) {
121+
logger.warn(`[${requestId}] Probe failed for ${body.url}, defaulting to headers`, e)
122+
resolvedAuthType = 'headers'
123+
}
124+
}
125+
126+
// User-supplied client credentials imply OAuth; pin authType regardless of probe.
127+
if (body.oauthClientId) resolvedAuthType = 'oauth'
128+
129+
const oauthClientSecretEncrypted = body.oauthClientSecret
130+
? (await encryptSecret(body.oauthClientSecret)).encrypted
131+
: null
132+
const oauthClientId = body.oauthClientId || null
133+
108134
const [existingServer] = await db
109135
.select({ id: mcpServers.id, deletedAt: mcpServers.deletedAt })
110136
.from(mcpServers)
@@ -123,12 +149,15 @@ export const POST = withRouteHandler(
123149
description: body.description,
124150
transport: body.transport,
125151
url: body.url,
152+
authType: resolvedAuthType,
153+
oauthClientId,
154+
oauthClientSecret: oauthClientSecretEncrypted,
126155
headers: body.headers || {},
127156
timeout: body.timeout || 30000,
128157
retries: body.retries || 3,
129158
enabled: body.enabled !== false,
130-
connectionStatus: 'connected',
131-
lastConnected: new Date(),
159+
connectionStatus: resolvedAuthType === 'oauth' ? 'disconnected' : 'connected',
160+
lastConnected: resolvedAuthType === 'oauth' ? null : new Date(),
132161
updatedAt: new Date(),
133162
deletedAt: null,
134163
})
@@ -140,7 +169,10 @@ export const POST = withRouteHandler(
140169
`[${requestId}] Successfully updated MCP server: ${body.name} (ID: ${serverId})`
141170
)
142171

143-
return createMcpSuccessResponse({ serverId, updated: true }, 200)
172+
return createMcpSuccessResponse(
173+
{ serverId, updated: true, authType: resolvedAuthType },
174+
200
175+
)
144176
}
145177

146178
await db
@@ -153,12 +185,15 @@ export const POST = withRouteHandler(
153185
description: body.description,
154186
transport: body.transport,
155187
url: body.url,
188+
authType: resolvedAuthType,
189+
oauthClientId,
190+
oauthClientSecret: oauthClientSecretEncrypted,
156191
headers: body.headers || {},
157192
timeout: body.timeout || 30000,
158193
retries: body.retries || 3,
159194
enabled: body.enabled !== false,
160-
connectionStatus: 'connected',
161-
lastConnected: new Date(),
195+
connectionStatus: resolvedAuthType === 'oauth' ? 'disconnected' : 'connected',
196+
lastConnected: resolvedAuthType === 'oauth' ? null : new Date(),
162197
createdAt: new Date(),
163198
updatedAt: new Date(),
164199
})
@@ -217,7 +252,7 @@ export const POST = withRouteHandler(
217252
request,
218253
})
219254

220-
return createMcpSuccessResponse({ serverId }, 201)
255+
return createMcpSuccessResponse({ serverId, authType: resolvedAuthType }, 201)
221256
} catch (error) {
222257
logger.error(`[${requestId}] Error registering MCP server:`, error)
223258
return createMcpErrorResponse(toError(error), 'Failed to register MCP server', 500)

0 commit comments

Comments
 (0)