Skip to content
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { CloudRegion } from "@shared/types/oauth";
import type { CloudRegion } from "@shared/types/regions";
import { eq } from "drizzle-orm";
import { inject, injectable } from "inversify";
import { MAIN_TOKENS } from "../../di/tokens";
Expand Down
8 changes: 3 additions & 5 deletions apps/code/src/main/services/auth/service.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import {
getCloudUrlFromRegion,
OAUTH_SCOPE_VERSION,
} from "@shared/constants/oauth";
import type { CloudRegion } from "@shared/types/oauth";
import { OAUTH_SCOPE_VERSION } from "@shared/constants/oauth";
import type { CloudRegion } from "@shared/types/regions";
import { type BackoffOptions, sleepWithBackoff } from "@shared/utils/backoff";
import { getCloudUrlFromRegion } from "@shared/utils/urls";
import { powerMonitor } from "electron";
import { inject, injectable, postConstruct, preDestroy } from "inversify";
import type { IAuthPreferenceRepository } from "../../db/repositories/auth-preference-repository";
Expand Down
2 changes: 1 addition & 1 deletion apps/code/src/main/services/github-integration/service.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { getCloudUrlFromRegion } from "@shared/constants/oauth";
import { getCloudUrlFromRegion } from "@shared/utils/urls";
import { shell } from "electron";
import { injectable } from "inversify";
import { logger } from "../../utils/logger";
Expand Down
2 changes: 1 addition & 1 deletion apps/code/src/main/services/linear-integration/service.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { getCloudUrlFromRegion } from "@shared/constants/oauth.js";
import { getCloudUrlFromRegion } from "@shared/utils/urls.js";
import { shell } from "electron";
import { injectable } from "inversify";
import { logger } from "../../utils/logger.js";
Expand Down
2 changes: 1 addition & 1 deletion apps/code/src/main/services/oauth/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ import * as crypto from "node:crypto";
import * as http from "node:http";
import type { Socket } from "node:net";
import {
getCloudUrlFromRegion,
getOauthClientIdFromRegion,
OAUTH_SCOPES,
} from "@shared/constants/oauth";
import { getCloudUrlFromRegion } from "@shared/utils/urls";
import { shell } from "electron";
import { inject, injectable } from "inversify";
import { MAIN_TOKENS } from "../../di/tokens";
Expand Down
184 changes: 151 additions & 33 deletions apps/code/src/renderer/api/posthogClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,29 @@ import type {
Task,
TaskRun,
} from "@shared/types";
import type { SeatData } from "@shared/types/seat";
import { SEAT_PRODUCT_KEY } from "@shared/types/seat";
import type { StoredLogEntry } from "@shared/types/session-events";
import { logger } from "@utils/logger";
import { buildApiFetcher } from "./fetcher";
import { createApiClient, type Schemas } from "./generated";

export class SeatSubscriptionRequiredError extends Error {
redirectUrl: string;
constructor(redirectUrl: string) {
super("Billing subscription required");
this.name = "SeatSubscriptionRequiredError";
this.redirectUrl = redirectUrl;
}
}

export class SeatPaymentFailedError extends Error {
constructor(message?: string) {
super(message ?? "Payment failed");
this.name = "SeatPaymentFailedError";
}
}

const log = logger.scope("posthog-client");

export type McpRecommendedServer = Schemas.RecommendedServer;
Expand Down Expand Up @@ -848,39 +866,6 @@ export class PostHogAPIClient {
return await response.json();
}

/**
* Get billing information for a specific organization.
*/
async getOrgBilling(orgId: string): Promise<{
has_active_subscription: boolean;
customer_id: string | null;
}> {
const url = new URL(
`${this.api.baseUrl}/api/organizations/${orgId}/billing/`,
);
const response = await this.api.fetcher.fetch({
method: "get",
url,
path: `/api/organizations/${orgId}/billing/`,
});

if (!response.ok) {
throw new Error(
`Failed to fetch organization billing: ${response.statusText}`,
);
}

const data = await response.json();
return {
has_active_subscription:
typeof data.has_active_subscription === "boolean"
? data.has_active_subscription
: false,
customer_id:
typeof data.customer_id === "string" ? data.customer_id : null,
};
}

async getSignalReports(
params?: SignalReportsQueryParams,
): Promise<SignalReportsResponse> {
Expand Down Expand Up @@ -1135,6 +1120,139 @@ export class PostHogAPIClient {
}
}

async getMySeat(): Promise<SeatData | null> {
try {
const url = new URL(`${this.api.baseUrl}/api/seats/me/`);
url.searchParams.set("product_key", SEAT_PRODUCT_KEY);
const response = await this.api.fetcher.fetch({
method: "get",
url,
path: "/api/seats/me/",
});
return (await response.json()) as SeatData;
} catch (error) {
if (this.isFetcherStatusError(error, 404)) {
return null;
}
throw error;
}
}

async createSeat(planKey: string): Promise<SeatData> {
try {
const url = new URL(`${this.api.baseUrl}/api/seats/`);
const response = await this.api.fetcher.fetch({
method: "post",
url,
path: "/api/seats/",
overrides: {
body: JSON.stringify({
product_key: SEAT_PRODUCT_KEY,
plan_key: planKey,
}),
},
});
return (await response.json()) as SeatData;
} catch (error) {
this.throwSeatError(error);
}
}

async upgradeSeat(planKey: string): Promise<SeatData> {
try {
const url = new URL(`${this.api.baseUrl}/api/seats/me/`);
const response = await this.api.fetcher.fetch({
method: "patch",
url,
path: "/api/seats/me/",
overrides: {
body: JSON.stringify({
product_key: SEAT_PRODUCT_KEY,
plan_key: planKey,
}),
},
});
return (await response.json()) as SeatData;
} catch (error) {
this.throwSeatError(error);
}
}

async cancelSeat(): Promise<void> {
try {
const url = new URL(`${this.api.baseUrl}/api/seats/me/`);
url.searchParams.set("product_key", SEAT_PRODUCT_KEY);
await this.api.fetcher.fetch({
method: "delete",
url,
path: "/api/seats/me/",
});
} catch (error) {
if (this.isFetcherStatusError(error, 204)) {
return;
}
this.throwSeatError(error);
}
}

async reactivateSeat(): Promise<SeatData> {
try {
const url = new URL(`${this.api.baseUrl}/api/seats/me/reactivate/`);
const response = await this.api.fetcher.fetch({
method: "post",
url,
path: "/api/seats/me/reactivate/",
overrides: {
body: JSON.stringify({ product_key: SEAT_PRODUCT_KEY }),
},
});
return (await response.json()) as SeatData;
} catch (error) {
this.throwSeatError(error);
}
}

private isFetcherStatusError(error: unknown, status: number): boolean {
return error instanceof Error && error.message.includes(`[${status}]`);
}

private parseFetcherError(error: unknown): {
status: number;
body: Record<string, unknown>;
} | null {
if (!(error instanceof Error)) return null;
const match = error.message.match(/\[(\d+)\]\s*(.*)/);
if (!match) return null;
try {
return {
status: Number.parseInt(match[1], 10),
body: JSON.parse(match[2]) as Record<string, unknown>,
};
} catch {
return { status: Number.parseInt(match[1], 10), body: {} };
}
}

private throwSeatError(error: unknown): never {
const parsed = this.parseFetcherError(error);

if (parsed) {
if (
parsed.status === 400 &&
typeof parsed.body.redirect_url === "string"
) {
throw new SeatSubscriptionRequiredError(parsed.body.redirect_url);
}
if (parsed.status === 402) {
const message =
typeof parsed.body.error === "string" ? parsed.body.error : undefined;
throw new SeatPaymentFailedError(message);
}
}

throw error;
}

/**
* Check if a feature flag is enabled for the current project.
* Returns true if the flag exists and is active, false otherwise.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import { Callout, Flex, Spinner, Text, Theme } from "@radix-ui/themes";
import codeLogo from "@renderer/assets/images/code.svg";
import logomark from "@renderer/assets/images/logomark.svg";
import { trpcClient } from "@renderer/trpc/client";
import { REGION_LABELS } from "@shared/constants/oauth";
import type { CloudRegion } from "@shared/types/oauth";
import type { CloudRegion } from "@shared/types/regions";
import { REGION_LABELS } from "@shared/types/regions";
import { RegionSelect } from "./RegionSelect";

export const getErrorMessage = (error: unknown) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Flex, Select, Text } from "@radix-ui/themes";
import { IS_DEV } from "@shared/constants/environment";
import type { CloudRegion } from "@shared/types/oauth";
import type { CloudRegion } from "@shared/types/regions";
import { useState } from "react";

interface RegionSelectProps {
Expand Down
2 changes: 1 addition & 1 deletion apps/code/src/renderer/features/auth/hooks/authClient.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { PostHogAPIClient } from "@renderer/api/posthogClient";
import { trpcClient } from "@renderer/trpc/client";
import { getCloudUrlFromRegion } from "@shared/constants/oauth";
import { getCloudUrlFromRegion } from "@shared/utils/urls";
import { useMemo } from "react";
import {
type AuthState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { useOnboardingStore } from "@features/onboarding/stores/onboardingStore"
import { resetSessionService } from "@features/sessions/service/service";
import { trpcClient } from "@renderer/trpc/client";
import { ANALYTICS_EVENTS } from "@shared/types/analytics";
import type { CloudRegion } from "@shared/types/oauth";
import type { CloudRegion } from "@shared/types/regions";
import { useNavigationStore } from "@stores/navigationStore";
import { useMutation } from "@tanstack/react-query";
import { track } from "@utils/analytics";
Expand Down
14 changes: 14 additions & 0 deletions apps/code/src/renderer/features/auth/stores/authStore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ vi.mock("@renderer/api/posthogClient", () => ({
this.getCurrentUser = mockGetCurrentUser;
this.setTeamId = vi.fn();
}),
SeatSubscriptionRequiredError: class SeatSubscriptionRequiredError extends Error {
redirectUrl: string;
constructor(redirectUrl: string) {
super("Billing subscription required");
this.name = "SeatSubscriptionRequiredError";
this.redirectUrl = redirectUrl;
}
},
SeatPaymentFailedError: class SeatPaymentFailedError extends Error {
constructor(message?: string) {
super(message ?? "Payment failed");
this.name = "SeatPaymentFailedError";
}
},
}));

vi.mock("@utils/analytics", () => ({
Expand Down
11 changes: 9 additions & 2 deletions apps/code/src/renderer/features/auth/stores/authStore.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import { useSeatStore } from "@features/billing/stores/seatStore";
import { useSettingsDialogStore } from "@features/settings/stores/settingsDialogStore";
import { PostHogAPIClient } from "@renderer/api/posthogClient";
import { trpcClient } from "@renderer/trpc/client";
import { getCloudUrlFromRegion } from "@shared/constants/oauth";
import { ANALYTICS_EVENTS } from "@shared/types/analytics";
import type { CloudRegion } from "@shared/types/oauth";
import type { CloudRegion } from "@shared/types/regions";
import { getCloudUrlFromRegion } from "@shared/utils/urls";
import { useNavigationStore } from "@stores/navigationStore";
import { identifyUser, resetUser, track } from "@utils/analytics";
import { logger } from "@utils/logger";
Expand Down Expand Up @@ -281,6 +283,9 @@ export const useAuthStore = create<AuthStoreState>((set, get) => ({

completeOnboarding: () => {
set({ hasCompletedOnboarding: true });
if (!useSeatStore.getState().seat) {
useSeatStore.getState().provisionFreeSeat();
}
},

selectPlan: (plan: "free" | "pro") => {
Expand All @@ -294,6 +299,8 @@ export const useAuthStore = create<AuthStoreState>((set, get) => ({
logout: async () => {
track(ANALYTICS_EVENTS.USER_LOGGED_OUT);
sessionResetCallback?.();
useSeatStore.getState().reset();
useSettingsDialogStore.getState().close();
clearAuthenticatedRendererState({ clearAllQueries: true });
await trpcClient.auth.logout.mutate();
useNavigationStore.getState().navigateToTaskInput();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { CloudRegion } from "@shared/types/oauth";
import type { CloudRegion } from "@shared/types/regions";
import { create } from "zustand";

interface AuthUiStateStoreState {
Expand Down
Loading
Loading