From 72ad963cdf70fa20e3edf97a5b1b0a2714723452 Mon Sep 17 00:00:00 2001 From: Alessandro Pogliaghi Date: Thu, 2 Apr 2026 16:21:27 +0100 Subject: [PATCH] feat(cloud-agent): user-authored-prs --- apps/code/src/main/services/git/schemas.ts | 8 +++ .../src/main/services/git/service.test.ts | 58 +++++++++++++++++++ apps/code/src/main/services/git/service.ts | 28 +++++++++ apps/code/src/main/trpc/routers/git.ts | 5 ++ apps/code/src/renderer/api/posthogClient.ts | 19 ++++++ .../inbox/stores/inboxCloudTaskStore.ts | 3 + .../features/sessions/service/service.ts | 57 +++++++++++++++++- .../src/renderer/sagas/task/task-creation.ts | 26 ++++++++- apps/code/src/renderer/utils/github.ts | 12 ++++ apps/code/src/shared/types/cloud.ts | 2 + packages/agent/src/server/agent-server.ts | 53 +++++++++++++++++ 11 files changed, 267 insertions(+), 4 deletions(-) create mode 100644 apps/code/src/renderer/utils/github.ts create mode 100644 apps/code/src/shared/types/cloud.ts diff --git a/apps/code/src/main/services/git/schemas.ts b/apps/code/src/main/services/git/schemas.ts index ddf9bba2c..5748b43bb 100644 --- a/apps/code/src/main/services/git/schemas.ts +++ b/apps/code/src/main/services/git/schemas.ts @@ -213,6 +213,14 @@ export const ghStatusOutput = z.object({ export type GhStatusOutput = z.infer; +export const ghAuthTokenOutput = z.object({ + success: z.boolean(), + token: z.string().nullable(), + error: z.string().nullable(), +}); + +export type GhAuthTokenOutput = z.infer; + // Pull request status export const prStatusInput = directoryPathInput; export const prStatusOutput = z.object({ diff --git a/apps/code/src/main/services/git/service.test.ts b/apps/code/src/main/services/git/service.test.ts index 7c25a08e3..2e57a5bbc 100644 --- a/apps/code/src/main/services/git/service.test.ts +++ b/apps/code/src/main/services/git/service.test.ts @@ -127,3 +127,61 @@ describe("GitService.getPrChangedFiles", () => { ).rejects.toThrow("Failed to fetch PR files"); }); }); + +describe("GitService.getGhAuthToken", () => { + let service: GitService; + + beforeEach(() => { + vi.clearAllMocks(); + service = new GitService({} as LlmGatewayService); + }); + + it("returns the authenticated GitHub CLI token", async () => { + mockExecGh.mockResolvedValue({ + exitCode: 0, + stdout: "ghu_test_token\n", + stderr: "", + }); + + const result = await service.getGhAuthToken(); + + expect(mockExecGh).toHaveBeenCalledWith(["auth", "token"]); + expect(result).toEqual({ + success: true, + token: "ghu_test_token", + error: null, + }); + }); + + it("returns the gh error when auth token lookup fails", async () => { + mockExecGh.mockResolvedValue({ + exitCode: 1, + stdout: "", + stderr: "authentication required", + }); + + const result = await service.getGhAuthToken(); + + expect(result).toEqual({ + success: false, + token: null, + error: "authentication required", + }); + }); + + it("returns error when stdout is empty", async () => { + mockExecGh.mockResolvedValue({ + exitCode: 0, + stdout: "", + stderr: "", + }); + + const result = await service.getGhAuthToken(); + + expect(result).toEqual({ + success: false, + token: null, + error: "GitHub auth token is empty", + }); + }); +}); diff --git a/apps/code/src/main/services/git/service.ts b/apps/code/src/main/services/git/service.ts index 7483c6443..42c580a12 100644 --- a/apps/code/src/main/services/git/service.ts +++ b/apps/code/src/main/services/git/service.ts @@ -44,6 +44,7 @@ import type { DiscardFileChangesOutput, GetCommitConventionsOutput, GetPrTemplateOutput, + GhAuthTokenOutput, GhStatusOutput, GitCommitInfo, GitFileStatus, @@ -646,6 +647,33 @@ export class GitService extends TypedEventEmitter { }; } + public async getGhAuthToken(): Promise { + const result = await execGh(["auth", "token"]); + if (result.exitCode !== 0) { + return { + success: false, + token: null, + error: + result.stderr || result.error || "Failed to read GitHub auth token", + }; + } + + const token = result.stdout.trim(); + if (!token) { + return { + success: false, + token: null, + error: "GitHub auth token is empty", + }; + } + + return { + success: true, + token, + error: null, + }; + } + public async getPrStatus(directoryPath: string): Promise { const base: PrStatusOutput = { hasRemote: false, diff --git a/apps/code/src/main/trpc/routers/git.ts b/apps/code/src/main/trpc/routers/git.ts index dfd3eab07..c44338023 100644 --- a/apps/code/src/main/trpc/routers/git.ts +++ b/apps/code/src/main/trpc/routers/git.ts @@ -44,6 +44,7 @@ import { getPrChangedFilesOutput, getPrTemplateInput, getPrTemplateOutput, + ghAuthTokenOutput, ghStatusOutput, openPrInput, openPrOutput, @@ -234,6 +235,10 @@ export const gitRouter = router({ .output(ghStatusOutput) .query(() => getService().getGhStatus()), + getGhAuthToken: publicProcedure + .output(ghAuthTokenOutput) + .query(() => getService().getGhAuthToken()), + getPrStatus: publicProcedure .input(prStatusInput) .output(prStatusOutput) diff --git a/apps/code/src/renderer/api/posthogClient.ts b/apps/code/src/renderer/api/posthogClient.ts index 040796173..d003e782f 100644 --- a/apps/code/src/renderer/api/posthogClient.ts +++ b/apps/code/src/renderer/api/posthogClient.ts @@ -10,6 +10,7 @@ import type { Task, TaskRun, } from "@shared/types"; +import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud"; import type { StoredLogEntry } from "@shared/types/session-events"; import { logger } from "@utils/logger"; import { buildApiFetcher } from "./fetcher"; @@ -559,6 +560,12 @@ export class PostHogAPIClient { branch?: string | null, resumeOptions?: { resumeFromRunId: string; pendingUserMessage: string }, sandboxEnvironmentId?: string, + runOptions?: { + prAuthorshipMode?: PrAuthorshipMode; + runSource?: CloudRunSource; + signalReportId?: string; + githubUserToken?: string; + }, ): Promise { const teamId = await this.getTeamId(); const body: Record = { mode: "interactive" }; @@ -572,6 +579,18 @@ export class PostHogAPIClient { if (sandboxEnvironmentId) { body.sandbox_environment_id = sandboxEnvironmentId; } + if (runOptions?.prAuthorshipMode) { + body.pr_authorship_mode = runOptions.prAuthorshipMode; + } + if (runOptions?.runSource) { + body.run_source = runOptions.runSource; + } + if (runOptions?.signalReportId) { + body.signal_report_id = runOptions.signalReportId; + } + if (runOptions?.githubUserToken) { + body.github_user_token = runOptions.githubUserToken; + } const data = await this.api.post( `/api/projects/{project_id}/tasks/{id}/run/`, diff --git a/apps/code/src/renderer/features/inbox/stores/inboxCloudTaskStore.ts b/apps/code/src/renderer/features/inbox/stores/inboxCloudTaskStore.ts index b584f93c1..2a931737a 100644 --- a/apps/code/src/renderer/features/inbox/stores/inboxCloudTaskStore.ts +++ b/apps/code/src/renderer/features/inbox/stores/inboxCloudTaskStore.ts @@ -61,6 +61,9 @@ export const useInboxCloudTaskStore = create()( workspaceMode: "cloud", githubIntegrationId: params.githubIntegrationId, repository: selectedRepo, + cloudPrAuthorshipMode: "bot", + cloudRunSource: "signal_report", + signalReportId: params.reportId, }); if (result.success) { diff --git a/apps/code/src/renderer/features/sessions/service/service.ts b/apps/code/src/renderer/features/sessions/service/service.ts index c8843f34f..a39015c26 100644 --- a/apps/code/src/renderer/features/sessions/service/service.ts +++ b/apps/code/src/renderer/features/sessions/service/service.ts @@ -29,6 +29,7 @@ import { taskViewedApi } from "@features/sidebar/hooks/useTaskViewed"; import { DEFAULT_GATEWAY_MODEL } from "@posthog/agent/gateway-models"; import { getIsOnline } from "@renderer/stores/connectivityStore"; import { trpcClient } from "@renderer/trpc/client"; +import { getGhUserTokenOrThrow } from "@renderer/utils/github"; import { toast } from "@renderer/utils/toast"; import { getCloudUrlFromRegion } from "@shared/constants/oauth"; import { @@ -39,6 +40,7 @@ import { type Task, } from "@shared/types"; import { ANALYTICS_EVENTS } from "@shared/types/analytics"; +import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud"; import type { AcpMessage, StoredLogEntry } from "@shared/types/session-events"; import { isJsonRpcRequest } from "@shared/types/session-events"; import { buildPermissionToolMetadata, track } from "@utils/analytics"; @@ -1364,6 +1366,35 @@ export class SessionService { throw new Error("Authentication required for cloud commands"); } + const [previousRun, task] = await Promise.all([ + client.getTaskRun(session.taskId, session.taskRunId), + client.getTask(session.taskId), + ]); + const hasGitHubRepo = !!task.repository && !!task.github_integration; + const previousState = previousRun.state as Record; + const previousOutput = (previousRun.output ?? {}) as Record< + string, + unknown + >; + // Prefer the actual working branch the agent last pushed to (synced by + // agent-server after each turn), then the run-level branch field, then + // the original base branch from state. This preserves unmerged work when + // the snapshot has expired and the sandbox is rebuilt from scratch. + const previousBaseBranch = + (typeof previousOutput.head_branch === "string" + ? previousOutput.head_branch + : null) ?? + previousRun.branch ?? + (typeof previousState.pr_base_branch === "string" + ? previousState.pr_base_branch + : null) ?? + session.cloudBranch; + const prAuthorshipMode = this.getCloudPrAuthorshipMode(previousState); + const githubUserToken = + prAuthorshipMode === "user" && hasGitHubRepo + ? await getGhUserTokenOrThrow() + : undefined; + log.info("Creating resume run for terminal cloud task", { taskId: session.taskId, previousRunId: session.taskRunId, @@ -1375,11 +1406,21 @@ export class SessionService { // The agent will load conversation history and restore the sandbox snapshot. const updatedTask = await client.runTaskInCloud( session.taskId, - session.cloudBranch, + previousBaseBranch, { resumeFromRunId: session.taskRunId, pendingUserMessage: promptText, }, + undefined, + { + prAuthorshipMode, + runSource: this.getCloudRunSource(previousState), + signalReportId: + typeof previousState.signal_report_id === "string" + ? previousState.signal_report_id + : undefined, + githubUserToken, + }, ); const newRun = updatedTask.latest_run; if (!newRun?.id) { @@ -2102,6 +2143,20 @@ export class SessionService { } } + private getCloudPrAuthorshipMode( + state: Record, + ): PrAuthorshipMode { + const explicitMode = state.pr_authorship_mode; + if (explicitMode === "user" || explicitMode === "bot") { + return explicitMode; + } + return state.run_source === "signal_report" ? "bot" : "user"; + } + + private getCloudRunSource(state: Record): CloudRunSource { + return state.run_source === "signal_report" ? "signal_report" : "manual"; + } + /** * Filter out session/prompt events that should be skipped during resume. * When resuming a cloud run, the initial session/prompt from the new run's diff --git a/apps/code/src/renderer/sagas/task/task-creation.ts b/apps/code/src/renderer/sagas/task/task-creation.ts index bc6371f78..1b50825fc 100644 --- a/apps/code/src/renderer/sagas/task/task-creation.ts +++ b/apps/code/src/renderer/sagas/task/task-creation.ts @@ -17,6 +17,8 @@ import { trpcClient } from "@renderer/trpc"; import { generateTitle } from "@renderer/utils/generateTitle"; import { getTaskRepository } from "@renderer/utils/repository"; import type { ExecutionMode, Task } from "@shared/types"; +import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud"; +import { getGhUserTokenOrThrow } from "@utils/github"; import { logger } from "@utils/logger"; import { queryClient } from "@utils/queryClient"; @@ -72,6 +74,8 @@ export interface TaskCreationInput { reasoningLevel?: string; environmentId?: string; sandboxEnvironmentId?: string; + cloudPrAuthorshipMode?: PrAuthorshipMode; + cloudRunSource?: CloudRunSource; signalReportId?: string; } @@ -256,13 +260,29 @@ export class TaskCreationSaga extends Saga< if (workspaceMode === "cloud" && !task.latest_run) { await this.step({ name: "cloud_run", - execute: () => - this.deps.posthogClient.runTaskInCloud( + execute: async () => { + const hasGitHubRepo = !!task.repository && !!task.github_integration; + const prAuthorshipMode = + input.cloudPrAuthorshipMode ?? (hasGitHubRepo ? "user" : "bot"); + let githubUserToken: string | undefined; + + if (prAuthorshipMode === "user" && hasGitHubRepo) { + githubUserToken = await getGhUserTokenOrThrow(); + } + + return this.deps.posthogClient.runTaskInCloud( task.id, branch, undefined, input.sandboxEnvironmentId, - ), + { + prAuthorshipMode, + runSource: input.cloudRunSource ?? "manual", + signalReportId: input.signalReportId, + githubUserToken, + }, + ); + }, rollback: async () => { log.info("Rolling back: cloud run (no-op)", { taskId: task.id }); }, diff --git a/apps/code/src/renderer/utils/github.ts b/apps/code/src/renderer/utils/github.ts new file mode 100644 index 000000000..721cf619d --- /dev/null +++ b/apps/code/src/renderer/utils/github.ts @@ -0,0 +1,12 @@ +import { trpcClient } from "@renderer/trpc"; + +export async function getGhUserTokenOrThrow(): Promise { + const tokenResult = await trpcClient.git.getGhAuthToken.query(); + if (!tokenResult.success || !tokenResult.token) { + throw new Error( + tokenResult.error || + "Authenticate GitHub CLI with `gh auth login` before starting a cloud task.", + ); + } + return tokenResult.token; +} diff --git a/apps/code/src/shared/types/cloud.ts b/apps/code/src/shared/types/cloud.ts new file mode 100644 index 000000000..d3601a180 --- /dev/null +++ b/apps/code/src/shared/types/cloud.ts @@ -0,0 +1,2 @@ +export type PrAuthorshipMode = "user" | "bot"; +export type CloudRunSource = "manual" | "signal_report"; diff --git a/packages/agent/src/server/agent-server.ts b/packages/agent/src/server/agent-server.ts index bafd1fdf1..7931470d0 100644 --- a/packages/agent/src/server/agent-server.ts +++ b/packages/agent/src/server/agent-server.ts @@ -4,6 +4,7 @@ import { PROTOCOL_VERSION, } from "@agentclientprotocol/sdk"; import { type ServerType, serve } from "@hono/node-server"; +import { getCurrentBranch } from "@posthog/git/queries"; import { Hono } from "hono"; import packageJson from "../../package.json" with { type: "json" }; import { POSTHOG_NOTIFICATIONS } from "../acp-extensions"; @@ -161,6 +162,7 @@ export class AgentServer { private posthogAPI: PostHogAPIClient; private questionRelayedToSlack = false; private detectedPrUrl: string | null = null; + private lastReportedBranch: string | null = null; private resumeState: ResumeState | null = null; // Guards against concurrent session initialization. autoInitializeSession() and // the GET /events SSE handler can both call initializeSession() — the SSE connection @@ -515,6 +517,10 @@ export class AgentServer { stopReason: result.stopReason, }); + if (result.stopReason === "end_turn") { + void this.syncCloudBranchMetadata(this.session.payload); + } + this.broadcastTurnComplete(result.stopReason); if (result.stopReason === "end_turn") { @@ -861,6 +867,10 @@ export class AgentServer { stopReason: result.stopReason, }); + if (result.stopReason === "end_turn") { + void this.syncCloudBranchMetadata(payload); + } + this.broadcastTurnComplete(result.stopReason); if (result.stopReason === "end_turn") { @@ -935,6 +945,10 @@ export class AgentServer { stopReason: result.stopReason, }); + if (result.stopReason === "end_turn") { + void this.syncCloudBranchMetadata(payload); + } + this.broadcastTurnComplete(result.stopReason); if (result.stopReason === "end_turn") { @@ -1117,6 +1131,44 @@ Important: `; } + private async getCurrentGitBranch(): Promise { + if (!this.config.repositoryPath) { + return null; + } + + try { + return await getCurrentBranch(this.config.repositoryPath); + } catch (error) { + this.logger.warn("Failed to determine current git branch", { + repositoryPath: this.config.repositoryPath, + error, + }); + return null; + } + } + + private async syncCloudBranchMetadata(payload: JwtPayload): Promise { + const branchName = await this.getCurrentGitBranch(); + if (!branchName || branchName === this.lastReportedBranch) { + return; + } + + try { + await this.posthogAPI.updateTaskRun(payload.task_id, payload.run_id, { + branch: branchName, + output: { head_branch: branchName }, + }); + this.lastReportedBranch = branchName; + } catch (error) { + this.logger.warn("Failed to attach current branch to task run", { + taskId: payload.task_id, + runId: payload.run_id, + branchName, + error, + }); + } + } + private async signalTaskComplete( payload: JwtPayload, stopReason: string, @@ -1505,6 +1557,7 @@ Important: } this.pendingEvents = []; + this.lastReportedBranch = null; this.session = null; }