Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions apps/code/src/main/services/git/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,14 @@ export const ghStatusOutput = z.object({

export type GhStatusOutput = z.infer<typeof ghStatusOutput>;

export const ghAuthTokenOutput = z.object({
success: z.boolean(),
token: z.string().nullable(),
error: z.string().nullable(),
});

export type GhAuthTokenOutput = z.infer<typeof ghAuthTokenOutput>;

// Pull request status
export const prStatusInput = directoryPathInput;
export const prStatusOutput = z.object({
Expand Down
58 changes: 58 additions & 0 deletions apps/code/src/main/services/git/service.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,61 @@ describe("GitService.getPrChangedFiles", () => {
).rejects.toThrow("Failed to fetch PR files");
});
});

describe("GitService.getGhAuthToken", () => {
let service: GitService;

beforeEach(() => {
vi.clearAllMocks();
service = new GitService({} as LlmGatewayService);
});

it("returns the authenticated GitHub CLI token", async () => {
mockExecGh.mockResolvedValue({
exitCode: 0,
stdout: "ghu_test_token\n",
stderr: "",
});

const result = await service.getGhAuthToken();

expect(mockExecGh).toHaveBeenCalledWith(["auth", "token"]);
expect(result).toEqual({
success: true,
token: "ghu_test_token",
error: null,
});
});

it("returns the gh error when auth token lookup fails", async () => {
mockExecGh.mockResolvedValue({
exitCode: 1,
stdout: "",
stderr: "authentication required",
});

const result = await service.getGhAuthToken();

expect(result).toEqual({
success: false,
token: null,
error: "authentication required",
});
});

it("returns error when stdout is empty", async () => {
mockExecGh.mockResolvedValue({
exitCode: 0,
stdout: "",
stderr: "",
});

const result = await service.getGhAuthToken();

expect(result).toEqual({
success: false,
token: null,
error: "GitHub auth token is empty",
});
});
});
28 changes: 28 additions & 0 deletions apps/code/src/main/services/git/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import type {
DiscardFileChangesOutput,
GetCommitConventionsOutput,
GetPrTemplateOutput,
GhAuthTokenOutput,
GhStatusOutput,
GitCommitInfo,
GitFileStatus,
Expand Down Expand Up @@ -646,6 +647,33 @@ export class GitService extends TypedEventEmitter<GitServiceEvents> {
};
}

public async getGhAuthToken(): Promise<GhAuthTokenOutput> {
const result = await execGh(["auth", "token"]);
if (result.exitCode !== 0) {
return {
success: false,
token: null,
error:
result.stderr || result.error || "Failed to read GitHub auth token",
};
}

const token = result.stdout.trim();
if (!token) {
return {
success: false,
token: null,
error: "GitHub auth token is empty",
};
}

return {
success: true,
token,
error: null,
};
}

public async getPrStatus(directoryPath: string): Promise<PrStatusOutput> {
const base: PrStatusOutput = {
hasRemote: false,
Expand Down
5 changes: 5 additions & 0 deletions apps/code/src/main/trpc/routers/git.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import {
getPrChangedFilesOutput,
getPrTemplateInput,
getPrTemplateOutput,
ghAuthTokenOutput,
ghStatusOutput,
openPrInput,
openPrOutput,
Expand Down Expand Up @@ -234,6 +235,10 @@ export const gitRouter = router({
.output(ghStatusOutput)
.query(() => getService().getGhStatus()),

getGhAuthToken: publicProcedure
.output(ghAuthTokenOutput)
.query(() => getService().getGhAuthToken()),

getPrStatus: publicProcedure
.input(prStatusInput)
.output(prStatusOutput)
Expand Down
19 changes: 19 additions & 0 deletions apps/code/src/renderer/api/posthogClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import type {
Task,
TaskRun,
} from "@shared/types";
import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud";
import type { StoredLogEntry } from "@shared/types/session-events";
import { logger } from "@utils/logger";
import { buildApiFetcher } from "./fetcher";
Expand Down Expand Up @@ -559,6 +560,12 @@ export class PostHogAPIClient {
branch?: string | null,
resumeOptions?: { resumeFromRunId: string; pendingUserMessage: string },
sandboxEnvironmentId?: string,
runOptions?: {
prAuthorshipMode?: PrAuthorshipMode;
runSource?: CloudRunSource;
signalReportId?: string;
githubUserToken?: string;
},
): Promise<Task> {
const teamId = await this.getTeamId();
const body: Record<string, unknown> = { mode: "interactive" };
Expand All @@ -572,6 +579,18 @@ export class PostHogAPIClient {
if (sandboxEnvironmentId) {
body.sandbox_environment_id = sandboxEnvironmentId;
}
if (runOptions?.prAuthorshipMode) {
body.pr_authorship_mode = runOptions.prAuthorshipMode;
}
if (runOptions?.runSource) {
body.run_source = runOptions.runSource;
}
if (runOptions?.signalReportId) {
body.signal_report_id = runOptions.signalReportId;
}
if (runOptions?.githubUserToken) {
body.github_user_token = runOptions.githubUserToken;
}

const data = await this.api.post(
`/api/projects/{project_id}/tasks/{id}/run/`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ export const useInboxCloudTaskStore = create<InboxCloudTaskStore>()(
workspaceMode: "cloud",
githubIntegrationId: params.githubIntegrationId,
repository: selectedRepo,
cloudPrAuthorshipMode: "bot",
cloudRunSource: "signal_report",
signalReportId: params.reportId,
});

if (result.success) {
Expand Down
57 changes: 56 additions & 1 deletion apps/code/src/renderer/features/sessions/service/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import { taskViewedApi } from "@features/sidebar/hooks/useTaskViewed";
import { DEFAULT_GATEWAY_MODEL } from "@posthog/agent/gateway-models";
import { getIsOnline } from "@renderer/stores/connectivityStore";
import { trpcClient } from "@renderer/trpc/client";
import { getGhUserTokenOrThrow } from "@renderer/utils/github";
import { toast } from "@renderer/utils/toast";
import { getCloudUrlFromRegion } from "@shared/constants/oauth";
import {
Expand All @@ -39,6 +40,7 @@ import {
type Task,
} from "@shared/types";
import { ANALYTICS_EVENTS } from "@shared/types/analytics";
import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud";
import type { AcpMessage, StoredLogEntry } from "@shared/types/session-events";
import { isJsonRpcRequest } from "@shared/types/session-events";
import { buildPermissionToolMetadata, track } from "@utils/analytics";
Expand Down Expand Up @@ -1364,6 +1366,35 @@ export class SessionService {
throw new Error("Authentication required for cloud commands");
}

const [previousRun, task] = await Promise.all([
client.getTaskRun(session.taskId, session.taskRunId),
client.getTask(session.taskId),
]);
const hasGitHubRepo = !!task.repository && !!task.github_integration;
const previousState = previousRun.state as Record<string, unknown>;
const previousOutput = (previousRun.output ?? {}) as Record<
string,
unknown
>;
// Prefer the actual working branch the agent last pushed to (synced by
// agent-server after each turn), then the run-level branch field, then
// the original base branch from state. This preserves unmerged work when
// the snapshot has expired and the sandbox is rebuilt from scratch.
const previousBaseBranch =
(typeof previousOutput.head_branch === "string"
? previousOutput.head_branch
: null) ??
previousRun.branch ??
(typeof previousState.pr_base_branch === "string"
? previousState.pr_base_branch
: null) ??
session.cloudBranch;
const prAuthorshipMode = this.getCloudPrAuthorshipMode(previousState);
const githubUserToken =
prAuthorshipMode === "user" && hasGitHubRepo
? await getGhUserTokenOrThrow()
: undefined;

log.info("Creating resume run for terminal cloud task", {
taskId: session.taskId,
previousRunId: session.taskRunId,
Expand All @@ -1375,11 +1406,21 @@ export class SessionService {
// The agent will load conversation history and restore the sandbox snapshot.
const updatedTask = await client.runTaskInCloud(
session.taskId,
session.cloudBranch,
previousBaseBranch,
{
resumeFromRunId: session.taskRunId,
pendingUserMessage: promptText,
},
undefined,
{
prAuthorshipMode,
runSource: this.getCloudRunSource(previousState),
signalReportId:
typeof previousState.signal_report_id === "string"
? previousState.signal_report_id
: undefined,
githubUserToken,
},
);
const newRun = updatedTask.latest_run;
if (!newRun?.id) {
Expand Down Expand Up @@ -2102,6 +2143,20 @@ export class SessionService {
}
}

private getCloudPrAuthorshipMode(
state: Record<string, unknown>,
): PrAuthorshipMode {
const explicitMode = state.pr_authorship_mode;
if (explicitMode === "user" || explicitMode === "bot") {
return explicitMode;
}
return state.run_source === "signal_report" ? "bot" : "user";
}

private getCloudRunSource(state: Record<string, unknown>): CloudRunSource {
return state.run_source === "signal_report" ? "signal_report" : "manual";
}

/**
* Filter out session/prompt events that should be skipped during resume.
* When resuming a cloud run, the initial session/prompt from the new run's
Expand Down
26 changes: 23 additions & 3 deletions apps/code/src/renderer/sagas/task/task-creation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import { trpcClient } from "@renderer/trpc";
import { generateTitle } from "@renderer/utils/generateTitle";
import { getTaskRepository } from "@renderer/utils/repository";
import type { ExecutionMode, Task } from "@shared/types";
import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud";
import { getGhUserTokenOrThrow } from "@utils/github";
import { logger } from "@utils/logger";
import { queryClient } from "@utils/queryClient";

Expand Down Expand Up @@ -72,6 +74,8 @@ export interface TaskCreationInput {
reasoningLevel?: string;
environmentId?: string;
sandboxEnvironmentId?: string;
cloudPrAuthorshipMode?: PrAuthorshipMode;
cloudRunSource?: CloudRunSource;
signalReportId?: string;
}

Expand Down Expand Up @@ -256,13 +260,29 @@ export class TaskCreationSaga extends Saga<
if (workspaceMode === "cloud" && !task.latest_run) {
await this.step({
name: "cloud_run",
execute: () =>
this.deps.posthogClient.runTaskInCloud(
execute: async () => {
const hasGitHubRepo = !!task.repository && !!task.github_integration;
const prAuthorshipMode =
input.cloudPrAuthorshipMode ?? (hasGitHubRepo ? "user" : "bot");
let githubUserToken: string | undefined;

if (prAuthorshipMode === "user" && hasGitHubRepo) {
githubUserToken = await getGhUserTokenOrThrow();
}

return this.deps.posthogClient.runTaskInCloud(
task.id,
branch,
undefined,
input.sandboxEnvironmentId,
),
{
prAuthorshipMode,
runSource: input.cloudRunSource ?? "manual",
signalReportId: input.signalReportId,
githubUserToken,
},
);
},
rollback: async () => {
log.info("Rolling back: cloud run (no-op)", { taskId: task.id });
},
Expand Down
12 changes: 12 additions & 0 deletions apps/code/src/renderer/utils/github.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { trpcClient } from "@renderer/trpc";

export async function getGhUserTokenOrThrow(): Promise<string> {
const tokenResult = await trpcClient.git.getGhAuthToken.query();
if (!tokenResult.success || !tokenResult.token) {
throw new Error(
tokenResult.error ||
"Authenticate GitHub CLI with `gh auth login` before starting a cloud task.",
);
}
return tokenResult.token;
}
2 changes: 2 additions & 0 deletions apps/code/src/shared/types/cloud.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export type PrAuthorshipMode = "user" | "bot";
export type CloudRunSource = "manual" | "signal_report";
Loading
Loading