diff --git a/apps/docs/content/docs/en/enterprise/data-drains.mdx b/apps/docs/content/docs/en/enterprise/data-drains.mdx index c1e5ced5876..620d6a13027 100644 --- a/apps/docs/content/docs/en/enterprise/data-drains.mdx +++ b/apps/docs/content/docs/en/enterprise/data-drains.mdx @@ -15,6 +15,10 @@ Drains are independent of [Data Retention](/enterprise/data-retention) but desig Go to **Settings → Enterprise → Data Drains** in your workspace, then click **New drain**. +![Data Drains settings page showing two configured drains — one exporting workflow logs to Amazon S3 daily, another exporting Copilot chats to an HTTPS webhook hourly](/static/enterprise/data-drains-list.png) + +![New data drain dialog with fields for name, source, cadence, destination, and S3 credentials](/static/enterprise/data-drains-new.png) + Each drain has four pieces: 1. A **source** — the category of data to export diff --git a/apps/docs/public/static/enterprise/data-drains-list.png b/apps/docs/public/static/enterprise/data-drains-list.png new file mode 100644 index 00000000000..b18af7cade5 Binary files /dev/null and b/apps/docs/public/static/enterprise/data-drains-list.png differ diff --git a/apps/docs/public/static/enterprise/data-drains-new.png b/apps/docs/public/static/enterprise/data-drains-new.png new file mode 100644 index 00000000000..4d85d0fc682 Binary files /dev/null and b/apps/docs/public/static/enterprise/data-drains-new.png differ diff --git a/apps/sim/app/api/mcp/oauth/callback/route.ts b/apps/sim/app/api/mcp/oauth/callback/route.ts new file mode 100644 index 00000000000..8621b200481 --- /dev/null +++ b/apps/sim/app/api/mcp/oauth/callback/route.ts @@ -0,0 +1,123 @@ +import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js' +import { db } from '@sim/db' +import { mcpServers } from '@sim/db/schema' +import { createLogger } from '@sim/logger' +import { toError } from '@sim/utils/errors' +import { and, eq, isNull } from 'drizzle-orm' +import type { NextRequest } from 'next/server' +import { NextResponse } from 'next/server' +import { getSession } from '@/lib/auth' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { + clearState, + clearVerifier, + loadOauthRowByState, + loadPreregisteredClient, + SimMcpOauthProvider, +} from '@/lib/mcp/oauth' +import { mcpService } from '@/lib/mcp/service' + +const logger = createLogger('McpOauthCallbackAPI') + +export const dynamic = 'force-dynamic' + +function escapeHtml(value: string): string { + return value + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, ''') +} + +function htmlClose(message: string, ok: boolean, serverId?: string): NextResponse { + const safeMessage = escapeHtml(message) + const title = ok ? 'Connected' : 'Connection failed' + const serverIdLiteral = serverId + ? JSON.stringify(serverId).replace(//g, '\\u003e') + : 'undefined' + const body = `${title}

${safeMessage}

` + return new NextResponse(body, { + headers: { 'Content-Type': 'text/html; charset=utf-8' }, + }) +} + +export const GET = withRouteHandler(async (request: NextRequest) => { + const url = new URL(request.url) + const state = url.searchParams.get('state') + const code = url.searchParams.get('code') + const errorParam = url.searchParams.get('error') + + if (errorParam) { + logger.warn(`MCP OAuth callback received error: ${errorParam}`) + return htmlClose(`Authorization failed: ${errorParam}`, false) + } + if (!state || !code) { + return htmlClose('Missing state or code in callback URL.', false) + } + + let serverId: string | undefined + try { + const session = await getSession() + if (!session?.user?.id) { + return htmlClose('You must be signed in to complete authorization.', false) + } + + const row = await loadOauthRowByState(state) + if (!row) { + return htmlClose('Invalid or expired authorization state.', false) + } + serverId = row.mcpServerId + + if (session.user.id !== row.userId) { + return htmlClose( + 'You must be signed in as the same user that initiated the flow.', + false, + serverId + ) + } + + const [server] = await db + .select({ id: mcpServers.id, url: mcpServers.url, workspaceId: mcpServers.workspaceId }) + .from(mcpServers) + .where(and(eq(mcpServers.id, row.mcpServerId), isNull(mcpServers.deletedAt))) + .limit(1) + if (!server || !server.url) { + return htmlClose('Server no longer exists.', false, serverId) + } + + // Burn state before token exchange so a replayed callback cannot reuse it. + await clearState(row.id) + + const preregistered = await loadPreregisteredClient(server.id) + const provider = new SimMcpOauthProvider({ row, preregistered }) + let result: Awaited> + try { + result = await mcpAuth(provider, { + serverUrl: server.url, + authorizationCode: code, + }) + } finally { + await clearVerifier(row.id) + } + + if (result !== 'AUTHORIZED') { + return htmlClose('Authorization did not complete.', false, server.id) + } + + try { + await mcpService.clearCache(server.workspaceId) + await mcpService.discoverServerTools(session.user.id, server.id, server.workspaceId) + } catch (e) { + logger.warn('Post-auth tools refresh failed', toError(e).message) + } + + return htmlClose('Connected. You can close this window.', true, server.id) + } catch (error) { + logger.error('MCP OAuth callback failed', error) + return htmlClose('Authorization failed. Please try again.', false, serverId) + } +}) diff --git a/apps/sim/app/api/mcp/oauth/start/route.test.ts b/apps/sim/app/api/mcp/oauth/start/route.test.ts new file mode 100644 index 00000000000..ec88a06831e --- /dev/null +++ b/apps/sim/app/api/mcp/oauth/start/route.test.ts @@ -0,0 +1,152 @@ +/** + * @vitest-environment node + */ +import { + dbChainMock, + dbChainMockFns, + hybridAuthMock, + hybridAuthMockFns, + permissionsMock, + permissionsMockFns, + resetDbChainMock, + schemaMock, +} from '@sim/testing' +import { NextRequest } from 'next/server' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockMcpAuth, + mockGetOrCreateOauthRow, + mockLoadPreregisteredClient, + mockSetOauthRowUser, + MockMcpOauthRedirectRequired, +} = vi.hoisted(() => ({ + mockMcpAuth: vi.fn(), + mockGetOrCreateOauthRow: vi.fn(), + mockLoadPreregisteredClient: vi.fn(), + mockSetOauthRowUser: vi.fn(), + MockMcpOauthRedirectRequired: class MockMcpOauthRedirectRequired extends Error { + constructor(public readonly authorizationUrl: string) { + super('redirect required') + } + }, +})) + +vi.mock('@sim/db', () => dbChainMock) +vi.mock('@sim/db/schema', () => schemaMock) +vi.mock('drizzle-orm', () => ({ + and: vi.fn(), + eq: vi.fn(), + isNull: vi.fn(), +})) +vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({ + auth: mockMcpAuth, +})) +vi.mock('@/lib/auth/hybrid', () => hybridAuthMock) +vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock) +vi.mock('@/lib/mcp/oauth', () => ({ + getOrCreateOauthRow: mockGetOrCreateOauthRow, + loadPreregisteredClient: mockLoadPreregisteredClient, + McpOauthRedirectRequired: MockMcpOauthRedirectRequired, + setOauthRowUser: mockSetOauthRowUser, + SimMcpOauthProvider: vi.fn().mockImplementation((value) => value), +})) + +import { GET } from './route' + +describe('MCP OAuth start route', () => { + beforeEach(() => { + vi.clearAllMocks() + resetDbChainMock() + hybridAuthMockFns.mockCheckSessionOrInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-2', + userName: 'User Two', + userEmail: 'user2@example.com', + authType: 'session', + }) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValue('write') + dbChainMockFns.limit.mockResolvedValue([ + { + id: 'server-1', + name: 'Exa', + url: 'https://mcp.exa.ai/mcp', + workspaceId: 'workspace-1', + authType: 'oauth', + deletedAt: null, + }, + ]) + mockGetOrCreateOauthRow.mockResolvedValue({ + id: 'oauth-row-1', + mcpServerId: 'server-1', + userId: 'user-1', + workspaceId: 'workspace-1', + clientInformation: null, + tokens: null, + codeVerifier: null, + state: null, + updatedAt: new Date(), + }) + mockLoadPreregisteredClient.mockResolvedValue(undefined) + mockMcpAuth.mockRejectedValue(new MockMcpOauthRedirectRequired('https://mcp.exa.ai/authorize')) + }) + + it('requires workspace write permission via MCP auth middleware', async () => { + const request = new NextRequest( + 'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1' + ) + + await GET(request) + + expect(permissionsMockFns.mockGetUserEntityPermissions).toHaveBeenCalledWith( + 'user-2', + 'workspace', + 'workspace-1' + ) + }) + + it('uses a workspace-scoped OAuth row and stamps the latest authorizing user', async () => { + const request = new NextRequest( + 'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1' + ) + + const response = await GET(request) + const body = await response.json() + + expect(response.status).toBe(200) + expect(body).toEqual({ + status: 'redirect', + authorizationUrl: 'https://mcp.exa.ai/authorize', + }) + expect(mockGetOrCreateOauthRow).toHaveBeenCalledWith({ + mcpServerId: 'server-1', + userId: 'user-2', + workspaceId: 'workspace-1', + }) + expect(mockSetOauthRowUser).toHaveBeenCalledWith('oauth-row-1', 'user-2') + }) + + it('rejects a second user starting OAuth while another authorization is active', async () => { + mockGetOrCreateOauthRow.mockResolvedValueOnce({ + id: 'oauth-row-1', + mcpServerId: 'server-1', + userId: 'user-1', + workspaceId: 'workspace-1', + clientInformation: null, + tokens: null, + codeVerifier: null, + state: 'hashed-active-state', + updatedAt: new Date(), + }) + const request = new NextRequest( + 'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1' + ) + + const response = await GET(request) + const body = await response.json() + + expect(response.status).toBe(409) + expect(body.error).toBe('OAuth authorization already in progress for this server') + expect(mockMcpAuth).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/app/api/mcp/oauth/start/route.ts b/apps/sim/app/api/mcp/oauth/start/route.ts new file mode 100644 index 00000000000..b57617388c6 --- /dev/null +++ b/apps/sim/app/api/mcp/oauth/start/route.ts @@ -0,0 +1,109 @@ +import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js' +import { db } from '@sim/db' +import { mcpServers } from '@sim/db/schema' +import { createLogger } from '@sim/logger' +import { toError } from '@sim/utils/errors' +import { and, eq, isNull } from 'drizzle-orm' +import type { NextRequest } from 'next/server' +import { NextResponse } from 'next/server' +import { startMcpOauthQuerySchema } from '@/lib/api/contracts/mcp' +import { validationErrorResponse } from '@/lib/api/server' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { withMcpAuth } from '@/lib/mcp/middleware' +import { + getOrCreateOauthRow, + loadPreregisteredClient, + McpOauthRedirectRequired, + SimMcpOauthProvider, + setOauthRowUser, +} from '@/lib/mcp/oauth' +import { createMcpErrorResponse } from '@/lib/mcp/utils' + +const logger = createLogger('McpOauthStartAPI') +const OAUTH_START_TTL_MS = 10 * 60 * 1000 + +export const dynamic = 'force-dynamic' + +export const GET = withRouteHandler( + withMcpAuth('write')(async (request: NextRequest, { userId, workspaceId, requestId }) => { + try { + const queryResult = startMcpOauthQuerySchema.safeParse( + Object.fromEntries(new URL(request.url).searchParams) + ) + if (!queryResult.success) { + return validationErrorResponse(queryResult.error) + } + const { serverId } = queryResult.data + + const [server] = await db + .select() + .from(mcpServers) + .where( + and( + eq(mcpServers.id, serverId), + eq(mcpServers.workspaceId, workspaceId), + isNull(mcpServers.deletedAt) + ) + ) + .limit(1) + + if (!server) { + return createMcpErrorResponse(new Error('Server not found'), 'Server not found', 404) + } + if (server.authType !== 'oauth') { + return createMcpErrorResponse( + new Error(`Server authType is "${server.authType}", not oauth`), + 'Server is not configured for OAuth', + 400 + ) + } + if (!server.url) { + return createMcpErrorResponse(new Error('Server has no URL'), 'Missing server URL', 400) + } + + const row = await getOrCreateOauthRow({ + mcpServerId: server.id, + userId, + workspaceId, + }) + const hasActiveFlow = !!row.state && row.updatedAt.getTime() > Date.now() - OAUTH_START_TTL_MS + if (hasActiveFlow && row.userId && row.userId !== userId) { + return createMcpErrorResponse( + new Error('OAuth authorization already in progress'), + 'OAuth authorization already in progress for this server', + 409 + ) + } + if (row.userId !== userId) { + await setOauthRowUser(row.id, userId) + row.userId = userId + } + const preregistered = await loadPreregisteredClient(server.id) + const provider = new SimMcpOauthProvider({ row, preregistered }) + + try { + const result = await mcpAuth(provider, { serverUrl: server.url }) + if (result === 'AUTHORIZED') { + return NextResponse.json({ status: 'already_authorized' }) + } + return createMcpErrorResponse( + new Error('Provider did not capture redirect URL'), + 'Failed to start OAuth flow', + 500 + ) + } catch (e) { + if (e instanceof McpOauthRedirectRequired) { + logger.info(`[${requestId}] OAuth redirect for server ${serverId}`) + return NextResponse.json({ + status: 'redirect', + authorizationUrl: e.authorizationUrl, + }) + } + throw e + } + } catch (error) { + logger.error(`[${requestId}] Error starting MCP OAuth flow:`, error) + return createMcpErrorResponse(toError(error), 'Failed to start OAuth flow', 500) + } + }) +) diff --git a/apps/sim/app/api/mcp/servers/[id]/route.ts b/apps/sim/app/api/mcp/servers/[id]/route.ts index b2b3b35f5b9..0e28c4a99a0 100644 --- a/apps/sim/app/api/mcp/servers/[id]/route.ts +++ b/apps/sim/app/api/mcp/servers/[id]/route.ts @@ -1,11 +1,12 @@ import { AuditAction, AuditResourceType, recordAudit } from '@sim/audit' import { db } from '@sim/db' -import { mcpServers } from '@sim/db/schema' +import { mcpServerOauth, mcpServers } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import { and, eq, isNull } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { updateMcpServerBodySchema } from '@/lib/api/contracts/mcp' +import { decryptSecret, encryptSecret } from '@/lib/core/security/encryption' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { McpDnsResolutionError, @@ -52,8 +53,16 @@ export const PATCH = withRouteHandler( } ) - // Remove workspaceId from body to prevent it from being updated - const { workspaceId: _, ...updateData } = body + const { workspaceId: _, oauthClientSecret, ...updateData } = body + const finalUpdateData: Record = { ...updateData } + if (oauthClientSecret !== undefined) { + finalUpdateData.oauthClientSecret = oauthClientSecret + ? (await encryptSecret(oauthClientSecret)).encrypted + : null + } + if (updateData.oauthClientId !== undefined) { + finalUpdateData.oauthClientId = updateData.oauthClientId || null + } if (updateData.url) { try { @@ -78,9 +87,13 @@ export const PATCH = withRouteHandler( } } - // Get the current server to check if URL is changing const [currentServer] = await db - .select({ url: mcpServers.url }) + .select({ + url: mcpServers.url, + authType: mcpServers.authType, + oauthClientId: mcpServers.oauthClientId, + oauthClientSecret: mcpServers.oauthClientSecret, + }) .from(mcpServers) .where( and( @@ -91,20 +104,60 @@ export const PATCH = withRouteHandler( ) .limit(1) - const [updatedServer] = await db - .update(mcpServers) - .set({ - ...updateData, - updatedAt: new Date(), - }) - .where( - and( - eq(mcpServers.id, serverId), - eq(mcpServers.workspaceId, workspaceId), - isNull(mcpServers.deletedAt) + // Adding OAuth client credentials to a non-OAuth server promotes it + // to OAuth so the connect-with-OAuth UI becomes reachable. + if ( + body.oauthClientId && + currentServer && + currentServer.authType !== 'oauth' && + finalUpdateData.authType === undefined + ) { + finalUpdateData.authType = 'oauth' + } + + const urlChanged = body.url !== undefined && currentServer?.url !== body.url + const clientIdChanged = + body.oauthClientId !== undefined && + (body.oauthClientId || null) !== (currentServer?.oauthClientId ?? null) + let clientSecretChanged = false + if (oauthClientSecret !== undefined) { + if (!oauthClientSecret) { + clientSecretChanged = currentServer?.oauthClientSecret != null + } else if (!currentServer?.oauthClientSecret) { + clientSecretChanged = true + } else { + const currentPlaintext = (await decryptSecret(currentServer.oauthClientSecret)) + .decrypted + clientSecretChanged = currentPlaintext !== oauthClientSecret + } + } + const oauthCredsChanged = clientIdChanged || clientSecretChanged + const shouldClearOauth = urlChanged || oauthCredsChanged + + const updatedServer = await db.transaction(async (tx) => { + const [updated] = await tx + .update(mcpServers) + .set({ + ...finalUpdateData, + updatedAt: new Date(), + }) + .where( + and( + eq(mcpServers.id, serverId), + eq(mcpServers.workspaceId, workspaceId), + isNull(mcpServers.deletedAt) + ) ) - ) - .returning() + .returning() + + if (!updated) return null + + if (shouldClearOauth) { + await tx.delete(mcpServerOauth).where(eq(mcpServerOauth.mcpServerId, serverId)) + } + + return updated + }) if (!updatedServer) { return createMcpErrorResponse( @@ -114,8 +167,15 @@ export const PATCH = withRouteHandler( ) } + if (shouldClearOauth) { + logger.info( + `[${requestId}] Cleared OAuth credentials for server ${serverId} due to ${urlChanged ? 'URL' : 'OAuth credential'} change` + ) + } + const shouldClearCache = - (body.url !== undefined && currentServer?.url !== body.url) || + urlChanged || + oauthCredsChanged || body.enabled !== undefined || body.headers !== undefined || body.timeout !== undefined || @@ -149,7 +209,10 @@ export const PATCH = withRouteHandler( request, }) - return createMcpSuccessResponse({ server: updatedServer }) + const { oauthClientSecret: _secret, ...rest } = updatedServer + return createMcpSuccessResponse({ + server: { ...rest, hasOauthClientSecret: !!_secret }, + }) } catch (error) { logger.error(`[${requestId}] Error updating MCP server:`, error) return createMcpErrorResponse(toError(error), 'Failed to update MCP server', 500) diff --git a/apps/sim/app/api/mcp/servers/route.ts b/apps/sim/app/api/mcp/servers/route.ts index d2666431506..0eb68b0553c 100644 --- a/apps/sim/app/api/mcp/servers/route.ts +++ b/apps/sim/app/api/mcp/servers/route.ts @@ -1,6 +1,6 @@ import { AuditAction, AuditResourceType, recordAudit } from '@sim/audit' import { db } from '@sim/db' -import { mcpServers } from '@sim/db/schema' +import { mcpServerOauth, mcpServers } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import { generateId } from '@sim/utils/id' @@ -8,6 +8,7 @@ import { and, eq, isNull } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { createMcpServerBodySchema, deleteMcpServerByQuerySchema } from '@/lib/api/contracts/mcp' import { validationErrorResponse } from '@/lib/api/server' +import { decryptSecret, encryptSecret } from '@/lib/core/security/encryption' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { McpDnsResolutionError, @@ -17,6 +18,7 @@ import { validateMcpServerSsrf, } from '@/lib/mcp/domain-check' import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { detectMcpAuthType } from '@/lib/mcp/oauth' import { mcpService } from '@/lib/mcp/service' import { createMcpErrorResponse, @@ -37,11 +39,16 @@ export const GET = withRouteHandler( try { logger.info(`[${requestId}] Listing MCP servers for workspace ${workspaceId}`) - const servers = await db + const rows = await db .select() .from(mcpServers) .where(and(eq(mcpServers.workspaceId, workspaceId), isNull(mcpServers.deletedAt))) + const servers = rows.map(({ oauthClientSecret: _secret, ...rest }) => ({ + ...rest, + hasOauthClientSecret: !!_secret, + })) + logger.info( `[${requestId}] Listed ${servers.length} MCP servers for workspace ${workspaceId}` ) @@ -105,34 +112,120 @@ export const POST = withRouteHandler( const serverId = body.url ? generateMcpServerId(workspaceId, body.url) : generateId() + const oauthClientSecretProvided = body.oauthClientSecret !== undefined + const oauthClientSecretEncrypted = body.oauthClientSecret + ? (await encryptSecret(body.oauthClientSecret)).encrypted + : null + const oauthClientIdProvided = body.oauthClientId !== undefined + const oauthClientId = body.oauthClientId || null + const [existingServer] = await db - .select({ id: mcpServers.id, deletedAt: mcpServers.deletedAt }) + .select({ + id: mcpServers.id, + deletedAt: mcpServers.deletedAt, + url: mcpServers.url, + authType: mcpServers.authType, + oauthClientId: mcpServers.oauthClientId, + oauthClientSecret: mcpServers.oauthClientSecret, + }) .from(mcpServers) .where(and(eq(mcpServers.id, serverId), eq(mcpServers.workspaceId, workspaceId))) .limit(1) + const urlChanged = existingServer ? existingServer.url !== body.url : true + const hasHeaders = body.headers && Object.keys(body.headers).length > 0 + + let resolvedAuthType: 'none' | 'headers' | 'oauth' = body.authType ?? 'headers' + if (!body.authType) { + if (existingServer && !urlChanged) { + // Preserve existing authType on edits that don't change the URL — re-probing + // can flip a working OAuth+DCR server to 'headers' on a transient 401/timeout. + resolvedAuthType = (existingServer.authType ?? 'headers') as + | 'none' + | 'headers' + | 'oauth' + } else if (body.url && !hasHeaders) { + try { + resolvedAuthType = await detectMcpAuthType(body.url) + logger.info(`[${requestId}] Probed ${body.url}: authType=${resolvedAuthType}`) + } catch (e) { + logger.warn(`[${requestId}] Probe failed for ${body.url}, defaulting to headers`, e) + resolvedAuthType = 'headers' + } + } + } + + // User-supplied client credentials imply OAuth; pin authType regardless of probe. + if (body.oauthClientId) resolvedAuthType = 'oauth' + if (existingServer) { logger.info( `[${requestId}] Server with ID ${serverId} already exists, updating instead of creating` ) - await db - .update(mcpServers) - .set({ + const clientIdChanged = + oauthClientIdProvided && + (oauthClientId || null) !== (existingServer.oauthClientId ?? null) + let clientSecretChanged = false + if (oauthClientSecretProvided) { + if (!body.oauthClientSecret) { + clientSecretChanged = existingServer.oauthClientSecret != null + } else if (!existingServer.oauthClientSecret) { + clientSecretChanged = true + } else { + const currentPlaintext = (await decryptSecret(existingServer.oauthClientSecret)) + .decrypted + clientSecretChanged = currentPlaintext !== body.oauthClientSecret + } + } + const oauthCredsChanged = clientIdChanged || clientSecretChanged + + const isRevival = existingServer.deletedAt !== null + const shouldClearOauth = urlChanged || oauthCredsChanged || isRevival + + await db.transaction(async (tx) => { + if (shouldClearOauth) { + await tx.delete(mcpServerOauth).where(eq(mcpServerOauth.mcpServerId, serverId)) + } + const updateValues: Record = { name: body.name, description: body.description, transport: body.transport, url: body.url, + authType: resolvedAuthType, headers: body.headers || {}, timeout: body.timeout || 30000, retries: body.retries || 3, enabled: body.enabled !== false, - connectionStatus: 'connected', - lastConnected: new Date(), updatedAt: new Date(), deletedAt: null, - }) - .where(eq(mcpServers.id, serverId)) + } + if (resolvedAuthType === 'oauth') { + if (shouldClearOauth) { + updateValues.connectionStatus = 'disconnected' + updateValues.lastConnected = null + } + } else { + updateValues.connectionStatus = 'connected' + updateValues.lastConnected = new Date() + } + if (oauthClientIdProvided) updateValues.oauthClientId = oauthClientId + if (oauthClientSecretProvided) { + updateValues.oauthClientSecret = oauthClientSecretEncrypted + } + await tx.update(mcpServers).set(updateValues).where(eq(mcpServers.id, serverId)) + }) + + if (shouldClearOauth) { + const reason = isRevival + ? 'server revival' + : urlChanged + ? 'URL change' + : 'OAuth credential change' + logger.info( + `[${requestId}] Cleared OAuth credentials for server ${serverId} due to ${reason}` + ) + } await mcpService.clearCache(workspaceId) @@ -140,7 +233,10 @@ export const POST = withRouteHandler( `[${requestId}] Successfully updated MCP server: ${body.name} (ID: ${serverId})` ) - return createMcpSuccessResponse({ serverId, updated: true }, 200) + return createMcpSuccessResponse( + { serverId, updated: true, authType: resolvedAuthType }, + 200 + ) } await db @@ -153,12 +249,15 @@ export const POST = withRouteHandler( description: body.description, transport: body.transport, url: body.url, + authType: resolvedAuthType, + oauthClientId, + oauthClientSecret: oauthClientSecretEncrypted, headers: body.headers || {}, timeout: body.timeout || 30000, retries: body.retries || 3, enabled: body.enabled !== false, - connectionStatus: 'connected', - lastConnected: new Date(), + connectionStatus: resolvedAuthType === 'oauth' ? 'disconnected' : 'connected', + lastConnected: resolvedAuthType === 'oauth' ? null : new Date(), createdAt: new Date(), updatedAt: new Date(), }) @@ -178,9 +277,7 @@ export const POST = withRouteHandler( transport: body.transport, workspaceId, }) - } catch (_e) { - // Silently fail - } + } catch (_e) {} const sourceParam = body.source as string | undefined const source = @@ -217,7 +314,7 @@ export const POST = withRouteHandler( request, }) - return createMcpSuccessResponse({ serverId }, 201) + return createMcpSuccessResponse({ serverId, authType: resolvedAuthType }, 201) } catch (error) { logger.error(`[${requestId}] Error registering MCP server:`, error) return createMcpErrorResponse(toError(error), 'Failed to register MCP server', 500) diff --git a/apps/sim/app/api/mcp/tools/discover/route.ts b/apps/sim/app/api/mcp/tools/discover/route.ts index e94f2f56328..b125fa7ff2b 100644 --- a/apps/sim/app/api/mcp/tools/discover/route.ts +++ b/apps/sim/app/api/mcp/tools/discover/route.ts @@ -1,3 +1,4 @@ +import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { createLogger } from '@sim/logger' import type { NextRequest } from 'next/server' import { mcpToolDiscoveryQuerySchema, refreshMcpToolsBodySchema } from '@/lib/api/contracts/mcp' @@ -5,7 +6,7 @@ import { validationErrorResponse } from '@/lib/api/server' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' import { mcpService } from '@/lib/mcp/service' -import type { McpToolDiscoveryResponse } from '@/lib/mcp/types' +import { McpOauthAuthorizationRequiredError, type McpToolDiscoveryResponse } from '@/lib/mcp/types' import { categorizeError, createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils' const logger = createLogger('McpToolDiscoveryAPI') @@ -46,6 +47,12 @@ export const GET = withRouteHandler( ) return createMcpSuccessResponse(responseData) } catch (error) { + if ( + error instanceof McpOauthAuthorizationRequiredError || + error instanceof UnauthorizedError + ) { + return createMcpErrorResponse(error, 'OAuth re-authorization required', 401) + } logger.error(`[${requestId}] Error discovering MCP tools:`, error) const { message, status } = categorizeError(error) return createMcpErrorResponse(new Error(message), 'Failed to discover MCP tools', status) @@ -100,6 +107,12 @@ export const POST = withRouteHandler( }, }) } catch (error) { + if ( + error instanceof McpOauthAuthorizationRequiredError || + error instanceof UnauthorizedError + ) { + return createMcpErrorResponse(error, 'OAuth re-authorization required', 401) + } logger.error(`[${requestId}] Error refreshing tool discovery:`, error) const { message, status } = categorizeError(error) return createMcpErrorResponse(new Error(message), 'Failed to refresh tool discovery', status) diff --git a/apps/sim/app/api/mcp/tools/execute/route.ts b/apps/sim/app/api/mcp/tools/execute/route.ts index d9458deceab..8599a5fcadf 100644 --- a/apps/sim/app/api/mcp/tools/execute/route.ts +++ b/apps/sim/app/api/mcp/tools/execute/route.ts @@ -1,5 +1,7 @@ +import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { createLogger } from '@sim/logger' import type { NextRequest } from 'next/server' +import { NextResponse } from 'next/server' import { mcpToolExecutionBodySchema } from '@/lib/api/contracts/mcp' import { getHighestPrioritySubscription } from '@/lib/billing/core/plan' import { getExecutionTimeout } from '@/lib/core/execution-limits' @@ -7,8 +9,14 @@ import type { SubscriptionPlan } from '@/lib/core/rate-limiter/types' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { SIM_VIA_HEADER } from '@/lib/execution/call-chain' import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { McpOauthRedirectRequired } from '@/lib/mcp/oauth' import { mcpService } from '@/lib/mcp/service' -import type { McpTool, McpToolCall, McpToolResult } from '@/lib/mcp/types' +import { + McpOauthAuthorizationRequiredError, + type McpTool, + type McpToolCall, + type McpToolResult, +} from '@/lib/mcp/types' import { categorizeError, createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils' import { assertPermissionsAllowed, @@ -43,6 +51,7 @@ function hasType(prop: unknown): prop is SchemaProperty { */ export const POST = withRouteHandler( withMcpAuth('read')(async (request: NextRequest, { userId, workspaceId, requestId }) => { + let serverId: string | undefined try { const rawBody = getParsedBody(request) ?? (await request.json()) const parsedBody = mcpToolExecutionBodySchema.safeParse(rawBody) @@ -63,7 +72,8 @@ export const POST = withRouteHandler( userId: userId, }) - const { serverId, toolName, arguments: rawArgs } = body + const { toolName, arguments: rawArgs } = body + serverId = body.serverId const args = rawArgs || {} try { @@ -101,7 +111,8 @@ export const POST = withRouteHandler( if (tool.inputSchema?.properties) { for (const [paramName, paramSchema] of Object.entries(tool.inputSchema.properties)) { - const schema = paramSchema as any + const schema = hasType(paramSchema) ? paramSchema : null + if (!schema) continue const value = args[paramName] if (value === undefined || value === null) { @@ -185,12 +196,18 @@ export const POST = withRouteHandler( extraHeaders[SIM_VIA_HEADER] = simViaHeader } + let timeoutHandle: ReturnType | undefined const result = await Promise.race([ mcpService.executeTool(userId, serverId, toolCall, workspaceId, extraHeaders), - new Promise((_, reject) => - setTimeout(() => reject(new Error('Tool execution timeout')), executionTimeout) - ), - ]) + new Promise((_, reject) => { + timeoutHandle = setTimeout( + () => reject(new Error('Tool execution timeout')), + executionTimeout + ) + }), + ]).finally(() => { + if (timeoutHandle !== undefined) clearTimeout(timeoutHandle) + }) const transformedResult = transformToolResult(result) @@ -218,6 +235,27 @@ export const POST = withRouteHandler( return createMcpSuccessResponse(transformedResult) } catch (error) { + if ( + error instanceof McpOauthAuthorizationRequiredError || + error instanceof McpOauthRedirectRequired || + error instanceof UnauthorizedError + ) { + const errorServerId = + error instanceof McpOauthAuthorizationRequiredError ? error.serverId : serverId + logger.warn(`[${requestId}] OAuth re-authorization required for MCP tool execution`, { + serverId: errorServerId, + }) + return NextResponse.json( + { + success: false, + error: 'OAuth re-authorization required', + code: 'reauth_required', + serverId: errorServerId, + }, + { status: 401 } + ) + } + logger.error(`[${requestId}] Error executing MCP tool:`, error) const { message, status } = categorizeError(error) diff --git a/apps/sim/app/api/organizations/[id]/data-drains/[drainId]/route.ts b/apps/sim/app/api/organizations/[id]/data-drains/[drainId]/route.ts index b0b291b0807..98f671dcb3a 100644 --- a/apps/sim/app/api/organizations/[id]/data-drains/[drainId]/route.ts +++ b/apps/sim/app/api/organizations/[id]/data-drains/[drainId]/route.ts @@ -12,7 +12,11 @@ import { } from '@/lib/api/contracts/data-drains' import { parseRequest, validationErrorResponse } from '@/lib/api/server' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { authorizeDrainAccess, loadDrain } from '@/lib/data-drains/access' +import { + authorizeDrainAccess, + loadDrain, + validateDrainEnvironmentWorkspace, +} from '@/lib/data-drains/access' import { getDestination } from '@/lib/data-drains/destinations/registry' import { encryptCredentials } from '@/lib/data-drains/encryption' import { serializeDrain } from '@/lib/data-drains/serializers' @@ -75,10 +79,19 @@ export const PUT = withRouteHandler(async (request: NextRequest, context: RouteC return NextResponse.json({ error: 'source cannot be changed after creation' }, { status: 400 }) } + const envWorkspaceError = await validateDrainEnvironmentWorkspace( + organizationId, + body.environmentWorkspaceId + ) + if (envWorkspaceError) return envWorkspaceError + const updates: Partial = { updatedAt: new Date() } if (body.name !== undefined) updates.name = body.name if (body.scheduleCadence !== undefined) updates.scheduleCadence = body.scheduleCadence if (body.enabled !== undefined) updates.enabled = body.enabled + if (body.environmentWorkspaceId !== undefined) { + updates.environmentWorkspaceId = body.environmentWorkspaceId + } if (body.destinationType !== undefined && body.destinationType !== drain.destinationType) { return NextResponse.json( @@ -141,6 +154,7 @@ export const PUT = withRouteHandler(async (request: NextRequest, context: RouteC source: body.source, scheduleCadence: body.scheduleCadence, enabled: body.enabled, + environmentWorkspaceId: body.environmentWorkspaceId, destinationConfigChanged: body.destinationConfig !== undefined, destinationCredentialsChanged: body.destinationCredentials !== undefined, }, diff --git a/apps/sim/app/api/organizations/[id]/data-drains/[drainId]/test/route.ts b/apps/sim/app/api/organizations/[id]/data-drains/[drainId]/test/route.ts index 5550ff9eb4c..808faf92cfe 100644 --- a/apps/sim/app/api/organizations/[id]/data-drains/[drainId]/test/route.ts +++ b/apps/sim/app/api/organizations/[id]/data-drains/[drainId]/test/route.ts @@ -8,6 +8,7 @@ import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { authorizeDrainAccess, loadDrain } from '@/lib/data-drains/access' import { getDestination } from '@/lib/data-drains/destinations/registry' import { decryptCredentials } from '@/lib/data-drains/encryption' +import { resolveDataDrainDestinationEnvVars } from '@/lib/data-drains/resolve-config' const logger = createLogger('DataDrainTestAPI') @@ -36,10 +37,16 @@ export const POST = withRouteHandler(async (request: NextRequest, context: Route ) } - const config = destination.configSchema.parse(drain.destinationConfig) - const credentials = destination.credentialsSchema.parse( - await decryptCredentials(drain.destinationCredentials) + const unresolvedConfig = destination.configSchema.parse(drain.destinationConfig) + const unresolvedCredentials = await decryptCredentials(drain.destinationCredentials) + const resolved = await resolveDataDrainDestinationEnvVars( + unresolvedConfig, + unresolvedCredentials, + drain.createdBy, + drain.environmentWorkspaceId ?? undefined ) + const config = destination.configSchema.parse(resolved.config) + const credentials = destination.credentialsSchema.parse(resolved.credentials) const controller = new AbortController() const timeout = setTimeout(() => controller.abort(), TEST_TIMEOUT_MS) diff --git a/apps/sim/app/api/organizations/[id]/data-drains/route.ts b/apps/sim/app/api/organizations/[id]/data-drains/route.ts index d78655ae28b..763a578f9d3 100644 --- a/apps/sim/app/api/organizations/[id]/data-drains/route.ts +++ b/apps/sim/app/api/organizations/[id]/data-drains/route.ts @@ -9,7 +9,7 @@ import { type NextRequest, NextResponse } from 'next/server' import { createDataDrainContract, listDataDrainsContract } from '@/lib/api/contracts/data-drains' import { parseRequest, validationErrorResponse } from '@/lib/api/server' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { authorizeDrainAccess } from '@/lib/data-drains/access' +import { authorizeDrainAccess, validateDrainEnvironmentWorkspace } from '@/lib/data-drains/access' import { getDestination } from '@/lib/data-drains/destinations/registry' import { encryptCredentials } from '@/lib/data-drains/encryption' import { serializeDrain } from '@/lib/data-drains/serializers' @@ -45,6 +45,12 @@ export const POST = withRouteHandler(async (request: NextRequest, context: Route const body = parsed.data.body + const envWorkspaceError = await validateDrainEnvironmentWorkspace( + organizationId, + body.environmentWorkspaceId + ) + if (envWorkspaceError) return envWorkspaceError + if (!body.destinationCredentials) { return NextResponse.json( { error: 'destinationCredentials is required when creating a drain' }, @@ -84,6 +90,7 @@ export const POST = withRouteHandler(async (request: NextRequest, context: Route destinationType: body.destinationType, destinationConfig: configResult.data as Record, destinationCredentials: encryptedCredentials, + environmentWorkspaceId: body.environmentWorkspaceId, scheduleCadence: body.scheduleCadence, enabled: body.enabled ?? true, cursor: null, diff --git a/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/components/form-field/form-field.tsx b/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/components/form-field/form-field.tsx index 04beeb1484a..cad5381d1d2 100644 --- a/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/components/form-field/form-field.tsx +++ b/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/components/form-field/form-field.tsx @@ -9,7 +9,7 @@ interface FormFieldProps { export function FormField({ label, children, optional }: FormFieldProps) { return (
-