Skip to content

Commit ae1b192

Browse files
authored
refactor: move auth to main (#1360)
## Problem We have a lot of transient auth issues which are hard to debug. The auth store is a giant 1000 line monolith. It handles auth, oauth, onboarding, project switching etc. We're passing the auth token by value instead of reference throughout the entire repository. Improving this setup could fix some of the issues, or at least makes them easier to debug. <!-- Who is this for and what problem does it solve? --> <!-- Closes #ISSUE_ID --> ## Changes This is the first PR in the auth refactor, and mostly scaffolds stuff: - Move the auth logic to a service in main. - Added a table for storing auth sessions - Changed boot order so that auth starts before window creation - Renderer auth store no longer persists auth/handles refresh - Auth store is now a temporary wrapper around the new main service - `PostHogAPIClient` now asks main for a valid token instead of storing it locally We will replace all usages of the auth store TRPC queries later on. This PR breaks several things: - Token handling in main - Onboarding persistence These are added back in later PRs in the stack, to keep this PR somewhat readable.
1 parent 56aba1c commit ae1b192

22 files changed

Lines changed: 1433 additions & 1140 deletions
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
CREATE TABLE `auth_sessions` (
2+
`id` integer PRIMARY KEY NOT NULL CHECK (`id` = 1),
3+
`refresh_token_encrypted` text NOT NULL,
4+
`cloud_region` text NOT NULL,
5+
`selected_project_id` integer,
6+
`scope_version` integer NOT NULL,
7+
`created_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
8+
`updated_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL
9+
);

apps/code/src/main/db/migrations/meta/_journal.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222
"when": 1773335630838,
2323
"tag": "0002_massive_bishop",
2424
"breakpoints": true
25+
},
26+
{
27+
"idx": 3,
28+
"version": "6",
29+
"when": 1774890000000,
30+
"tag": "0003_fair_whiplash",
31+
"breakpoints": true
2532
}
2633
]
2734
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import type {
2+
AuthSession,
3+
IAuthSessionRepository,
4+
PersistAuthSessionInput,
5+
} from "./auth-session-repository";
6+
7+
export interface MockAuthSessionRepository extends IAuthSessionRepository {
8+
_session: AuthSession | null;
9+
}
10+
11+
export function createMockAuthSessionRepository(): MockAuthSessionRepository {
12+
let session: AuthSession | null = null;
13+
14+
const clone = (value: AuthSession | null): AuthSession | null =>
15+
value ? { ...value } : null;
16+
17+
return {
18+
get _session() {
19+
return clone(session);
20+
},
21+
set _session(value) {
22+
session = clone(value);
23+
},
24+
getCurrent: () => clone(session),
25+
saveCurrent: (input: PersistAuthSessionInput) => {
26+
const timestamp = new Date().toISOString();
27+
session = {
28+
id: 1,
29+
refreshTokenEncrypted: input.refreshTokenEncrypted,
30+
cloudRegion: input.cloudRegion,
31+
selectedProjectId: input.selectedProjectId,
32+
scopeVersion: input.scopeVersion,
33+
createdAt: session?.createdAt ?? timestamp,
34+
updatedAt: timestamp,
35+
};
36+
return { ...session };
37+
},
38+
clearCurrent: () => {
39+
session = null;
40+
},
41+
};
42+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import type { CloudRegion } from "@shared/types/oauth";
2+
import { eq } from "drizzle-orm";
3+
import { inject, injectable } from "inversify";
4+
import { MAIN_TOKENS } from "../../di/tokens";
5+
import { authSessions } from "../schema";
6+
import type { DatabaseService } from "../service";
7+
8+
export type AuthSession = typeof authSessions.$inferSelect;
9+
export type NewAuthSession = typeof authSessions.$inferInsert;
10+
11+
export interface PersistAuthSessionInput {
12+
refreshTokenEncrypted: string;
13+
cloudRegion: CloudRegion;
14+
selectedProjectId: number | null;
15+
scopeVersion: number;
16+
}
17+
18+
export interface IAuthSessionRepository {
19+
getCurrent(): AuthSession | null;
20+
saveCurrent(input: PersistAuthSessionInput): AuthSession;
21+
clearCurrent(): void;
22+
}
23+
24+
const CURRENT_AUTH_SESSION_ID = 1;
25+
const byId = eq(authSessions.id, CURRENT_AUTH_SESSION_ID);
26+
const now = () => new Date().toISOString();
27+
28+
@injectable()
29+
export class AuthSessionRepository implements IAuthSessionRepository {
30+
constructor(
31+
@inject(MAIN_TOKENS.DatabaseService)
32+
private readonly databaseService: DatabaseService,
33+
) {}
34+
35+
private get db() {
36+
return this.databaseService.db;
37+
}
38+
39+
getCurrent(): AuthSession | null {
40+
return (
41+
this.db.select().from(authSessions).where(byId).limit(1).get() ?? null
42+
);
43+
}
44+
45+
saveCurrent(input: PersistAuthSessionInput): AuthSession {
46+
const timestamp = now();
47+
const existing = this.getCurrent();
48+
49+
const row: NewAuthSession = {
50+
id: CURRENT_AUTH_SESSION_ID,
51+
refreshTokenEncrypted: input.refreshTokenEncrypted,
52+
cloudRegion: input.cloudRegion,
53+
selectedProjectId: input.selectedProjectId,
54+
scopeVersion: input.scopeVersion,
55+
createdAt: existing?.createdAt ?? timestamp,
56+
updatedAt: timestamp,
57+
};
58+
59+
if (existing) {
60+
this.db.update(authSessions).set(row).where(byId).run();
61+
} else {
62+
this.db.insert(authSessions).values(row).run();
63+
}
64+
65+
const saved = this.getCurrent();
66+
if (!saved) {
67+
throw new Error("Failed to persist current auth session");
68+
}
69+
return saved;
70+
}
71+
72+
clearCurrent(): void {
73+
this.db.delete(authSessions).where(byId).run();
74+
}
75+
}

apps/code/src/main/db/schema.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { sql } from "drizzle-orm";
2-
import { index, sqliteTable, text } from "drizzle-orm/sqlite-core";
2+
import { index, integer, sqliteTable, text } from "drizzle-orm/sqlite-core";
33

44
const id = () =>
55
text()
@@ -76,3 +76,13 @@ export const suspensions = sqliteTable("suspensions", {
7676
createdAt: createdAt(),
7777
updatedAt: updatedAt(),
7878
});
79+
80+
export const authSessions = sqliteTable("auth_sessions", {
81+
id: integer().primaryKey(),
82+
refreshTokenEncrypted: text().notNull(),
83+
cloudRegion: text({ enum: ["us", "eu", "dev"] }).notNull(),
84+
selectedProjectId: integer(),
85+
scopeVersion: integer().notNull(),
86+
createdAt: createdAt(),
87+
updatedAt: updatedAt(),
88+
});

apps/code/src/main/di/container.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import "reflect-metadata";
22

33
import { Container } from "inversify";
44
import { ArchiveRepository } from "../db/repositories/archive-repository";
5+
import { AuthSessionRepository } from "../db/repositories/auth-session-repository";
56
import { RepositoryRepository } from "../db/repositories/repository-repository";
67
import { SuspensionRepositoryImpl } from "../db/repositories/suspension-repository";
78
import { WorkspaceRepository } from "../db/repositories/workspace-repository";
@@ -10,6 +11,7 @@ import { DatabaseService } from "../db/service";
1011
import { AgentService } from "../services/agent/service";
1112
import { AppLifecycleService } from "../services/app-lifecycle/service";
1213
import { ArchiveService } from "../services/archive/service";
14+
import { AuthService } from "../services/auth/service";
1315
import { AuthProxyService } from "../services/auth-proxy/service";
1416
import { CloudTaskService } from "../services/cloud-task/service";
1517
import { ConnectivityService } from "../services/connectivity/service";
@@ -49,12 +51,14 @@ export const container = new Container({
4951
});
5052

5153
container.bind(MAIN_TOKENS.DatabaseService).to(DatabaseService);
54+
container.bind(MAIN_TOKENS.AuthSessionRepository).to(AuthSessionRepository);
5255
container.bind(MAIN_TOKENS.RepositoryRepository).to(RepositoryRepository);
5356
container.bind(MAIN_TOKENS.WorkspaceRepository).to(WorkspaceRepository);
5457
container.bind(MAIN_TOKENS.WorktreeRepository).to(WorktreeRepository);
5558
container.bind(MAIN_TOKENS.ArchiveRepository).to(ArchiveRepository);
5659
container.bind(MAIN_TOKENS.SuspensionRepository).to(SuspensionRepositoryImpl);
5760
container.bind(MAIN_TOKENS.AgentService).to(AgentService);
61+
container.bind(MAIN_TOKENS.AuthService).to(AuthService);
5862
container.bind(MAIN_TOKENS.AuthProxyService).to(AuthProxyService);
5963
container.bind(MAIN_TOKENS.ArchiveService).to(ArchiveService);
6064
container.bind(MAIN_TOKENS.SuspensionService).to(SuspensionService);

apps/code/src/main/di/tokens.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ export const MAIN_TOKENS = Object.freeze({
1010

1111
// Database
1212
DatabaseService: Symbol.for("Main.DatabaseService"),
13+
AuthSessionRepository: Symbol.for("Main.AuthSessionRepository"),
1314
RepositoryRepository: Symbol.for("Main.RepositoryRepository"),
1415
WorkspaceRepository: Symbol.for("Main.WorkspaceRepository"),
1516
WorktreeRepository: Symbol.for("Main.WorktreeRepository"),
@@ -18,6 +19,7 @@ export const MAIN_TOKENS = Object.freeze({
1819

1920
// Services
2021
AgentService: Symbol.for("Main.AgentService"),
22+
AuthService: Symbol.for("Main.AuthService"),
2123
AuthProxyService: Symbol.for("Main.AuthProxyService"),
2224
ArchiveService: Symbol.for("Main.ArchiveService"),
2325
SuspensionService: Symbol.for("Main.SuspensionService"),

apps/code/src/main/index.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { container } from "./di/container";
1111
import { MAIN_TOKENS } from "./di/tokens";
1212
import { registerMcpSandboxProtocol } from "./protocols/mcp-sandbox";
1313
import type { AppLifecycleService } from "./services/app-lifecycle/service";
14+
import type { AuthService } from "./services/auth/service";
1415
import type { ExternalAppsService } from "./services/external-apps/service";
1516
import type { NotificationService } from "./services/notification/service";
1617
import type { OAuthService } from "./services/oauth/service";
@@ -35,15 +36,18 @@ if (!gotTheLock) {
3536
process.exit(0);
3637
}
3738

38-
function initializeServices(): void {
39+
async function initializeServices(): Promise<void> {
3940
container.get<DatabaseService>(MAIN_TOKENS.DatabaseService);
4041
container.get<OAuthService>(MAIN_TOKENS.OAuthService);
42+
const authService = container.get<AuthService>(MAIN_TOKENS.AuthService);
4143
container.get<NotificationService>(MAIN_TOKENS.NotificationService);
4244
container.get<UpdatesService>(MAIN_TOKENS.UpdatesService);
4345
container.get<TaskLinkService>(MAIN_TOKENS.TaskLinkService);
4446
container.get<ExternalAppsService>(MAIN_TOKENS.ExternalAppsService);
4547
container.get<PosthogPluginService>(MAIN_TOKENS.PosthogPluginService);
4648

49+
await authService.initialize();
50+
4751
// Initialize workspace branch watcher for live branch rename detection
4852
const workspaceService = container.get<WorkspaceService>(
4953
MAIN_TOKENS.WorkspaceService,
@@ -69,7 +73,7 @@ registerDeepLinkHandlers();
6973
// Initialize PostHog analytics
7074
initializePostHog();
7175

72-
app.whenReady().then(() => {
76+
app.whenReady().then(async () => {
7377
const commit = __BUILD_COMMIT__ ?? "dev";
7478
const buildDate = __BUILD_DATE__ ?? "dev";
7579
log.info(
@@ -87,8 +91,9 @@ app.whenReady().then(() => {
8791
ensureClaudeConfigDir();
8892
registerMcpSandboxProtocol();
8993
createWindow();
90-
initializeServices();
94+
await initializeServices();
9195
initializeDeepLinks();
96+
await initializeServices();
9297
powerMonitor.on("suspend", () => {
9398
log.info("System entering sleep");
9499
});
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import { z } from "zod";
2+
import { cloudRegion, type oAuthTokenResponse } from "../oauth/schemas";
3+
4+
export const authStatusSchema = z.enum(["anonymous", "authenticated"]);
5+
export type AuthStatus = z.infer<typeof authStatusSchema>;
6+
7+
export const authStateSchema = z.object({
8+
status: authStatusSchema,
9+
bootstrapComplete: z.boolean(),
10+
cloudRegion: cloudRegion.nullable(),
11+
projectId: z.number().nullable(),
12+
availableProjectIds: z.array(z.number()),
13+
availableOrgIds: z.array(z.string()),
14+
hasCodeAccess: z.boolean().nullable(),
15+
needsScopeReauth: z.boolean(),
16+
});
17+
export type AuthState = z.infer<typeof authStateSchema>;
18+
19+
export const loginInput = z.object({
20+
region: cloudRegion,
21+
});
22+
export type LoginInput = z.infer<typeof loginInput>;
23+
24+
export const loginOutput = z.object({
25+
state: authStateSchema,
26+
});
27+
export type LoginOutput = z.infer<typeof loginOutput>;
28+
29+
export const redeemInviteCodeInput = z.object({
30+
code: z.string().min(1),
31+
});
32+
33+
export const selectProjectInput = z.object({
34+
projectId: z.number(),
35+
});
36+
37+
export const validAccessTokenOutput = z.object({
38+
accessToken: z.string(),
39+
apiHost: z.string(),
40+
});
41+
export type ValidAccessTokenOutput = z.infer<typeof validAccessTokenOutput>;
42+
43+
export const AuthServiceEvent = {
44+
StateChanged: "state-changed",
45+
} as const;
46+
47+
export interface AuthServiceEvents {
48+
[AuthServiceEvent.StateChanged]: AuthState;
49+
}
50+
51+
export type AuthTokenResponse = z.infer<typeof oAuthTokenResponse>;

0 commit comments

Comments
 (0)