Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions src/app/(app)/integrations/discord/link/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import type { NextRequest } from 'next/server';
import { NextResponse } from 'next/server';
import * as z from 'zod';
import { APP_URL } from '@/lib/constants';
import { getUserFromAuth } from '@/lib/user.server';
import { createOAuthState, type OAuthStateContext } from '@/lib/integrations/oauth-state';
import { getDiscordUserLinkOAuthUrl, getInstallation } from '@/lib/integrations/discord-service';
import { isOrganizationMember } from '@/lib/organizations/organizations';
import type { Owner } from '@/lib/integrations/core/types';

const DISCORD_SNOWFLAKE_REGEX = /^\d+$/;

const LinkRequestSchema = z
.discriminatedUnion('ownerType', [
z.object({
ownerType: z.literal('org'),
ownerId: z.uuid(),
guildId: z.string().regex(DISCORD_SNOWFLAKE_REGEX).optional(),
channelId: z.string().regex(DISCORD_SNOWFLAKE_REGEX).optional(),
messageId: z.string().regex(DISCORD_SNOWFLAKE_REGEX).optional(),
}),
z.object({
ownerType: z.literal('user'),
ownerId: z.string().min(1),
guildId: z.string().regex(DISCORD_SNOWFLAKE_REGEX).optional(),
channelId: z.string().regex(DISCORD_SNOWFLAKE_REGEX).optional(),
messageId: z.string().regex(DISCORD_SNOWFLAKE_REGEX).optional(),
}),
])
.superRefine((value, ctx) => {
const presentCount = [value.guildId, value.channelId, value.messageId].filter(Boolean).length;
if (presentCount > 0 && presentCount < 3) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
message: 'guildId, channelId, and messageId must be provided together',
});
}
});

function buildIntegrationPath(owner: Owner, queryParam?: string): string {
const basePath =
owner.type === 'org'
? `/organizations/${owner.id}/integrations/discord`
: '/integrations/discord';

return queryParam ? `${basePath}?${queryParam}` : basePath;
}

function buildSignInPath(callbackPath: string): string {
return `/users/sign_in?callbackPath=${encodeURIComponent(callbackPath)}`;
}

export async function GET(request: NextRequest) {
const parsed = LinkRequestSchema.safeParse({
ownerType: request.nextUrl.searchParams.get('ownerType'),
ownerId: request.nextUrl.searchParams.get('ownerId'),
guildId: request.nextUrl.searchParams.get('guildId') ?? undefined,
channelId: request.nextUrl.searchParams.get('channelId') ?? undefined,
messageId: request.nextUrl.searchParams.get('messageId') ?? undefined,
});

if (!parsed.success) {
return NextResponse.redirect(new URL('/integrations/discord?error=invalid_link', APP_URL));
}

const owner: Owner =
parsed.data.ownerType === 'org'
? { type: 'org', id: parsed.data.ownerId }
: { type: 'user', id: parsed.data.ownerId };

const callbackPath = `${request.nextUrl.pathname}${request.nextUrl.search}`;
const authResult = await getUserFromAuth({ adminOnly: false });
if (!authResult.user) {
return NextResponse.redirect(new URL(buildSignInPath(callbackPath), APP_URL));
}

if (owner.type === 'org' && !authResult.user.is_admin) {
const isMember = await isOrganizationMember(owner.id, authResult.user.id);
if (!isMember) {
return NextResponse.redirect(
new URL(buildIntegrationPath(owner, 'error=unauthorized'), APP_URL)
);
}
}

if (owner.type === 'user' && authResult.user.id !== owner.id) {
return NextResponse.redirect(
new URL(buildIntegrationPath(owner, 'error=unauthorized'), APP_URL)
);
}

const installation = await getInstallation(owner);
if (!installation) {
return NextResponse.redirect(
new URL(buildIntegrationPath(owner, 'error=installation_missing'), APP_URL)
);
}

const replayContext: OAuthStateContext | undefined =
parsed.data.guildId && parsed.data.channelId && parsed.data.messageId
? {
discordReplayGuildId: parsed.data.guildId,
discordReplayChannelId: parsed.data.channelId,
discordReplayMessageId: parsed.data.messageId,
}
: undefined;

if (
replayContext?.discordReplayGuildId &&
installation.platform_installation_id !== replayContext.discordReplayGuildId
) {
return NextResponse.redirect(
new URL(buildIntegrationPath(owner, 'error=invalid_link'), APP_URL)
);
}

const statePrefix = owner.type === 'org' ? `org_${owner.id}` : `user_${owner.id}`;
const state = createOAuthState(statePrefix, authResult.user.id, replayContext);
return NextResponse.redirect(getDiscordUserLinkOAuthUrl(state));
}
197 changes: 190 additions & 7 deletions src/app/api/integrations/discord/callback/route.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,154 @@
import type { NextRequest } from 'next/server';
import { NextResponse } from 'next/server';
import { after, NextResponse } from 'next/server';
import { getUserFromAuth } from '@/lib/user.server';
import { ensureOrganizationAccess } from '@/routers/organizations/utils';
import type { Owner } from '@/lib/integrations/core/types';
import { captureException, captureMessage } from '@sentry/nextjs';
import { exchangeDiscordCode, upsertDiscordInstallation } from '@/lib/integrations/discord-service';
import {
exchangeDiscordCode,
getDiscordBotUserId,
getDiscordChannelMessage,
getDiscordOAuthUserId,
linkDiscordRequesterToOwner,
postDiscordMessage,
upsertDiscordInstallation,
} from '@/lib/integrations/discord-service';
import { verifyOAuthState } from '@/lib/integrations/oauth-state';
import { APP_URL } from '@/lib/constants';
import { processDiscordBotMessage } from '@/lib/discord-bot';
import { getDevUserSuffix } from '@/lib/slack-bot/dev-user-info';
import {
isDiscordBotMessage,
replaceDiscordUserMentionsWithNames,
stripDiscordBotMention,
truncateForDiscord,
} from '@/lib/discord-bot/discord-utils';
import { z } from 'zod';

const DISCORD_SNOWFLAKE_REGEX = /^\d+$/;

const DiscordReplayContextSchema = z.object({
discordReplayGuildId: z.string().regex(DISCORD_SNOWFLAKE_REGEX),
discordReplayChannelId: z.string().regex(DISCORD_SNOWFLAKE_REGEX),
discordReplayMessageId: z.string().regex(DISCORD_SNOWFLAKE_REGEX),
});

type DiscordReplayContext = {
guildId: string;
channelId: string;
messageId: string;
};

function getDiscordReplayContext(
value: Record<string, string> | undefined
): DiscordReplayContext | null {
if (!value) {
return null;
}

const parsed = DiscordReplayContextSchema.safeParse(value);
if (!parsed.success) {
return null;
}

return {
guildId: parsed.data.discordReplayGuildId,
channelId: parsed.data.discordReplayChannelId,
messageId: parsed.data.discordReplayMessageId,
};
}

async function replayLinkedDiscordMessage(
replayContext: DiscordReplayContext,
linkedDiscordUserId: string
): Promise<void> {
const messageResult = await getDiscordChannelMessage(
replayContext.channelId,
replayContext.messageId
);
if (!messageResult.ok) {
captureMessage('Discord replay failed to fetch original message', {
level: 'warning',
tags: { endpoint: 'discord/callback', source: 'discord_replay' },
extra: { replayContext, error: messageResult.error },
});
return;
}

const message = messageResult.message;
if (message.author.id !== linkedDiscordUserId) {
captureMessage('Discord replay skipped due to author mismatch', {
level: 'warning',
tags: { endpoint: 'discord/callback', source: 'discord_replay' },
extra: {
replayContext,
expectedDiscordUserId: linkedDiscordUserId,
messageAuthorId: message.author.id,
},
});
return;
}

if (isDiscordBotMessage({ author: { bot: message.author.bot } })) {
return;
}

const botUserResult = await getDiscordBotUserId();
if (!botUserResult.ok) {
captureMessage('Discord replay failed to resolve bot user', {
level: 'warning',
tags: { endpoint: 'discord/callback', source: 'discord_replay' },
extra: { replayContext, error: botUserResult.error },
});
return;
}

const botUserId = botUserResult.userId;
const mentionsBot = message.mentions.some(mention => mention.id === botUserId);
if (!mentionsBot) {
captureMessage('Discord replay skipped because message no longer mentions bot', {
level: 'info',
tags: { endpoint: 'discord/callback', source: 'discord_replay' },
extra: { replayContext, botUserId },
});
return;
}

const cleanedText = stripDiscordBotMention(message.content, botUserId);
if (!cleanedText) {
return;
}

const resolvedText = await replaceDiscordUserMentionsWithNames(
cleanedText,
replayContext.guildId
);
const result = await processDiscordBotMessage(resolvedText, replayContext.guildId, {
channelId: replayContext.channelId,
guildId: replayContext.guildId,
userId: linkedDiscordUserId,
messageId: replayContext.messageId,
});

const responseText = truncateForDiscord(result.response + getDevUserSuffix());
const postResult = await postDiscordMessage(replayContext.channelId, responseText, {
messageReference: { message_id: replayContext.messageId },
linkButton: result.linkDiscordAccountUrl
? {
label: 'Link My Discord Account',
url: result.linkDiscordAccountUrl,
}
: undefined,
});

if (!postResult.ok) {
captureMessage('Discord replay failed to post response', {
level: 'warning',
tags: { endpoint: 'discord/callback', source: 'discord_replay' },
extra: { replayContext, error: postResult.error },
});
}
}

const buildDiscordRedirectPath = (state: string | null, queryParam: string): string => {
// Try to extract the owner from a signed state for best-effort redirects on error paths.
Expand Down Expand Up @@ -89,6 +231,8 @@ export async function GET(request: NextRequest) {
return NextResponse.redirect(new URL('/integrations?error=unauthorized', APP_URL));
}

const replayContext = getDiscordReplayContext(verified.context);

// 5. Parse owner from verified state payload
let owner: Owner;
const ownerStr = verified.owner;
Expand Down Expand Up @@ -121,14 +265,53 @@ export async function GET(request: NextRequest) {
// 7. Exchange code for access token
const oauthData = await exchangeDiscordCode(code);

// 8. Store installation in database
await upsertDiscordInstallation(owner, oauthData);
// 8. Resolve the Discord requester identity and persist authorization mapping
const discordUserId = await getDiscordOAuthUserId(oauthData.access_token);
const authorizedRequester = {
kiloUserId: user.id,
discordUserId,
};

const isInstallFlow = Boolean(oauthData.guild?.id);
if (isInstallFlow) {
await upsertDiscordInstallation(owner, oauthData, authorizedRequester);
} else {
const linked = await linkDiscordRequesterToOwner(owner, authorizedRequester);
if (!linked) {
captureMessage('Discord user link callback without an existing installation', {
level: 'warning',
tags: { endpoint: 'discord/callback', source: 'discord_oauth' },
extra: { owner, userId: user.id },
});

return NextResponse.redirect(
new URL(buildDiscordRedirectPath(state, 'error=installation_missing'), APP_URL)
);
}

if (replayContext && replayContext.guildId === linked.platform_installation_id) {
after(async () => {
await replayLinkedDiscordMessage(replayContext, discordUserId);
});
} else if (replayContext) {
captureMessage('Discord replay context guild mismatch; replay skipped', {
level: 'warning',
tags: { endpoint: 'discord/callback', source: 'discord_replay' },
extra: {
replayContext,
linkedInstallationGuildId: linked.platform_installation_id,
owner,
},
});
}
}

// 9. Redirect to success page
const successPath =
owner.type === 'org'
const successPath = isInstallFlow
? owner.type === 'org'
? `/organizations/${owner.id}/integrations/discord?success=installed`
: `/integrations/discord?success=installed`;
: '/integrations/discord?success=installed'
: '/integrations/discord/link/success';

return NextResponse.redirect(new URL(successPath, APP_URL));
} catch (error) {
Expand Down
6 changes: 6 additions & 0 deletions src/app/discord/webhook/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ async function processGatewayMessage(event: ForwardedGatewayEvent) {
const responseText = truncateForDiscord(responseWithDevInfo);
const postResult = await postDiscordMessage(channelId, responseText, {
messageReference: { message_id: messageId },
linkButton: result.linkDiscordAccountUrl
? {
label: 'Link My Discord Account',
url: result.linkDiscordAccountUrl,
}
: undefined,
});

console.log(
Expand Down
22 changes: 22 additions & 0 deletions src/app/integrations/discord/link/success/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { KiloCardLayout } from '@/components/KiloCardLayout';
import { CheckCircle2 } from 'lucide-react';
import { getUserFromAuthOrRedirect } from '@/lib/user.server';

export default async function DiscordLinkSuccessPage() {
await getUserFromAuthOrRedirect('/users/sign_in?callbackPath=/integrations/discord/link/success');

return (
<KiloCardLayout
className="max-w-xl"
contentClassName="flex flex-col items-center gap-6 py-12 text-center"
>
<CheckCircle2 className="h-20 w-20 text-green-600" />
<div className="space-y-2">
<h1 className="text-3xl font-semibold tracking-tight">Discord account linked</h1>
<p className="text-muted-foreground text-lg">
Your account is now linked to Kilo. You can close this tab and return to Discord.
</p>
</div>
</KiloCardLayout>
);
}
10 changes: 10 additions & 0 deletions src/components/integrations/DiscordIntegrationDetails.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,17 @@ export function DiscordIntegrationDetails({
</div>

{/* Actions */}
<Alert>
<AlertDescription>
Each team member who wants to use Kilo in Discord must link their own Discord
account.
</AlertDescription>
</Alert>

<div className="flex flex-wrap gap-3">
<Button variant="outline" onClick={handleInstall} disabled={!oauthUrlData?.url}>
Link My Discord Account
</Button>
<Button
variant="outline"
onClick={handleTestConnection}
Expand Down
Loading