Skip to content

Commit 21acb10

Browse files
authored
feat(cloud-agent): user-authored-prs (#1453)
1 parent d092804 commit 21acb10

File tree

12 files changed

+268
-6
lines changed

12 files changed

+268
-6
lines changed

apps/code/src/main/services/git/schemas.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,14 @@ export const ghStatusOutput = z.object({
223223

224224
export type GhStatusOutput = z.infer<typeof ghStatusOutput>;
225225

226+
export const ghAuthTokenOutput = z.object({
227+
success: z.boolean(),
228+
token: z.string().nullable(),
229+
error: z.string().nullable(),
230+
});
231+
232+
export type GhAuthTokenOutput = z.infer<typeof ghAuthTokenOutput>;
233+
226234
// Pull request status
227235
export const prStatusInput = directoryPathInput;
228236
export const prStatusOutput = z.object({

apps/code/src/main/services/git/service.test.ts

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,61 @@ describe("GitService.getPrChangedFiles", () => {
127127
).rejects.toThrow("Failed to fetch PR files");
128128
});
129129
});
130+
131+
describe("GitService.getGhAuthToken", () => {
132+
let service: GitService;
133+
134+
beforeEach(() => {
135+
vi.clearAllMocks();
136+
service = new GitService({} as LlmGatewayService);
137+
});
138+
139+
it("returns the authenticated GitHub CLI token", async () => {
140+
mockExecGh.mockResolvedValue({
141+
exitCode: 0,
142+
stdout: "ghu_test_token\n",
143+
stderr: "",
144+
});
145+
146+
const result = await service.getGhAuthToken();
147+
148+
expect(mockExecGh).toHaveBeenCalledWith(["auth", "token"]);
149+
expect(result).toEqual({
150+
success: true,
151+
token: "ghu_test_token",
152+
error: null,
153+
});
154+
});
155+
156+
it("returns the gh error when auth token lookup fails", async () => {
157+
mockExecGh.mockResolvedValue({
158+
exitCode: 1,
159+
stdout: "",
160+
stderr: "authentication required",
161+
});
162+
163+
const result = await service.getGhAuthToken();
164+
165+
expect(result).toEqual({
166+
success: false,
167+
token: null,
168+
error: "authentication required",
169+
});
170+
});
171+
172+
it("returns error when stdout is empty", async () => {
173+
mockExecGh.mockResolvedValue({
174+
exitCode: 0,
175+
stdout: "",
176+
stderr: "",
177+
});
178+
179+
const result = await service.getGhAuthToken();
180+
181+
expect(result).toEqual({
182+
success: false,
183+
token: null,
184+
error: "GitHub auth token is empty",
185+
});
186+
});
187+
});

apps/code/src/main/services/git/service.ts

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import type {
4646
DiscardFileChangesOutput,
4747
GetCommitConventionsOutput,
4848
GetPrTemplateOutput,
49+
GhAuthTokenOutput,
4950
GhStatusOutput,
5051
GitCommitInfo,
5152
GitFileStatus,
@@ -706,6 +707,33 @@ export class GitService extends TypedEventEmitter<GitServiceEvents> {
706707
};
707708
}
708709

710+
public async getGhAuthToken(): Promise<GhAuthTokenOutput> {
711+
const result = await execGh(["auth", "token"]);
712+
if (result.exitCode !== 0) {
713+
return {
714+
success: false,
715+
token: null,
716+
error:
717+
result.stderr || result.error || "Failed to read GitHub auth token",
718+
};
719+
}
720+
721+
const token = result.stdout.trim();
722+
if (!token) {
723+
return {
724+
success: false,
725+
token: null,
726+
error: "GitHub auth token is empty",
727+
};
728+
}
729+
730+
return {
731+
success: true,
732+
token,
733+
error: null,
734+
};
735+
}
736+
709737
public async getPrStatus(directoryPath: string): Promise<PrStatusOutput> {
710738
const base: PrStatusOutput = {
711739
hasRemote: false,

apps/code/src/main/trpc/routers/git.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import {
4444
getPrChangedFilesOutput,
4545
getPrTemplateInput,
4646
getPrTemplateOutput,
47+
ghAuthTokenOutput,
4748
ghStatusOutput,
4849
gitStateSnapshotSchema,
4950
openPrInput,
@@ -264,6 +265,10 @@ export const gitRouter = router({
264265
.output(ghStatusOutput)
265266
.query(() => getService().getGhStatus()),
266267

268+
getGhAuthToken: publicProcedure
269+
.output(ghAuthTokenOutput)
270+
.query(() => getService().getGhAuthToken()),
271+
267272
getPrStatus: publicProcedure
268273
.input(prStatusInput)
269274
.output(prStatusOutput)

apps/code/src/renderer/api/posthogClient.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import type {
1818
Task,
1919
TaskRun,
2020
} from "@shared/types";
21+
import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud";
2122
import type { StoredLogEntry } from "@shared/types/session-events";
2223
import { logger } from "@utils/logger";
2324
import { buildApiFetcher } from "./fetcher";
@@ -732,6 +733,10 @@ export class PostHogAPIClient {
732733
resumeFromRunId?: string;
733734
pendingUserMessage?: string;
734735
sandboxEnvironmentId?: string;
736+
prAuthorshipMode?: PrAuthorshipMode;
737+
runSource?: CloudRunSource;
738+
signalReportId?: string;
739+
githubUserToken?: string;
735740
},
736741
): Promise<Task> {
737742
const teamId = await this.getTeamId();
@@ -748,6 +753,18 @@ export class PostHogAPIClient {
748753
if (options?.sandboxEnvironmentId) {
749754
body.sandbox_environment_id = options.sandboxEnvironmentId;
750755
}
756+
if (options?.prAuthorshipMode) {
757+
body.pr_authorship_mode = options.prAuthorshipMode;
758+
}
759+
if (options?.runSource) {
760+
body.run_source = options.runSource;
761+
}
762+
if (options?.signalReportId) {
763+
body.signal_report_id = options.signalReportId;
764+
}
765+
if (options?.githubUserToken) {
766+
body.github_user_token = options.githubUserToken;
767+
}
751768

752769
const data = await this.api.post(
753770
`/api/projects/{project_id}/tasks/{id}/run/`,

apps/code/src/renderer/features/inbox/stores/inboxCloudTaskStore.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ export const useInboxCloudTaskStore = create<InboxCloudTaskStore>()(
6161
workspaceMode: "cloud",
6262
githubIntegrationId: params.githubIntegrationId,
6363
repository: selectedRepo,
64+
cloudPrAuthorshipMode: "user",
65+
cloudRunSource: "signal_report",
66+
signalReportId: params.reportId,
6467
});
6568

6669
if (result.success) {

apps/code/src/renderer/features/sessions/service/service.ts

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import { isNotification, POSTHOG_NOTIFICATIONS } from "@posthog/agent";
3535
import { DEFAULT_GATEWAY_MODEL } from "@posthog/agent/gateway-models";
3636
import { getIsOnline } from "@renderer/stores/connectivityStore";
3737
import { trpcClient } from "@renderer/trpc/client";
38+
import { getGhUserTokenOrThrow } from "@renderer/utils/github";
3839
import { toast } from "@renderer/utils/toast";
3940
import { getCloudUrlFromRegion } from "@shared/constants/oauth";
4041
import {
@@ -45,6 +46,7 @@ import {
4546
type Task,
4647
} from "@shared/types";
4748
import { ANALYTICS_EVENTS } from "@shared/types/analytics";
49+
import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud";
4850
import type { AcpMessage, StoredLogEntry } from "@shared/types/session-events";
4951
import { isJsonRpcRequest } from "@shared/types/session-events";
5052
import { buildPermissionToolMetadata, track } from "@utils/analytics";
@@ -1342,6 +1344,35 @@ export class SessionService {
13421344

13431345
const { blocks, promptText } = await this.prepareCloudPrompt(prompt);
13441346

1347+
const [previousRun, task] = await Promise.all([
1348+
client.getTaskRun(session.taskId, session.taskRunId),
1349+
client.getTask(session.taskId),
1350+
]);
1351+
const hasGitHubRepo = !!task.repository && !!task.github_integration;
1352+
const previousState = previousRun.state as Record<string, unknown>;
1353+
const previousOutput = (previousRun.output ?? {}) as Record<
1354+
string,
1355+
unknown
1356+
>;
1357+
// Prefer the actual working branch the agent last pushed to (synced by
1358+
// agent-server after each turn), then the run-level branch field, then
1359+
// the original base branch from state. This preserves unmerged work when
1360+
// the snapshot has expired and the sandbox is rebuilt from scratch.
1361+
const previousBaseBranch =
1362+
(typeof previousOutput.head_branch === "string"
1363+
? previousOutput.head_branch
1364+
: null) ??
1365+
previousRun.branch ??
1366+
(typeof previousState.pr_base_branch === "string"
1367+
? previousState.pr_base_branch
1368+
: null) ??
1369+
session.cloudBranch;
1370+
const prAuthorshipMode = this.getCloudPrAuthorshipMode(previousState);
1371+
const githubUserToken =
1372+
prAuthorshipMode === "user" && hasGitHubRepo
1373+
? await getGhUserTokenOrThrow()
1374+
: undefined;
1375+
13451376
log.info("Creating resume run for terminal cloud task", {
13461377
taskId: session.taskId,
13471378
previousRunId: session.taskRunId,
@@ -1353,10 +1384,17 @@ export class SessionService {
13531384
// The agent will load conversation history and restore the sandbox snapshot.
13541385
const updatedTask = await client.runTaskInCloud(
13551386
session.taskId,
1356-
session.cloudBranch,
1387+
previousBaseBranch,
13571388
{
13581389
resumeFromRunId: session.taskRunId,
13591390
pendingUserMessage: serializeCloudPrompt(blocks),
1391+
prAuthorshipMode,
1392+
runSource: this.getCloudRunSource(previousState),
1393+
signalReportId:
1394+
typeof previousState.signal_report_id === "string"
1395+
? previousState.signal_report_id
1396+
: undefined,
1397+
githubUserToken,
13601398
},
13611399
);
13621400
const newRun = updatedTask.latest_run;
@@ -2081,6 +2119,20 @@ export class SessionService {
20812119
}
20822120
}
20832121

2122+
private getCloudPrAuthorshipMode(
2123+
state: Record<string, unknown>,
2124+
): PrAuthorshipMode {
2125+
const explicitMode = state.pr_authorship_mode;
2126+
if (explicitMode === "user" || explicitMode === "bot") {
2127+
return explicitMode;
2128+
}
2129+
return state.run_source === "signal_report" ? "bot" : "user";
2130+
}
2131+
2132+
private getCloudRunSource(state: Record<string, unknown>): CloudRunSource {
2133+
return state.run_source === "signal_report" ? "signal_report" : "manual";
2134+
}
2135+
20842136
/**
20852137
* Filter out session/prompt events that should be skipped during resume.
20862138
* When resuming a cloud run, the initial session/prompt from the new run's

apps/code/src/renderer/sagas/task/task-creation.test.ts

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ describe("TaskCreationSaga", () => {
153153
{
154154
pendingUserMessage: "Ship the fix",
155155
sandboxEnvironmentId: undefined,
156+
prAuthorshipMode: "bot",
157+
runSource: "manual",
158+
signalReportId: undefined,
159+
githubUserToken: undefined,
156160
},
157161
);
158162
expect(sendRunCommandMock).not.toHaveBeenCalled();
@@ -212,12 +216,14 @@ describe("TaskCreationSaga", () => {
212216
expect(runTaskInCloudMock).toHaveBeenCalledWith(
213217
"task-123",
214218
"release/remembered-branch",
215-
{
219+
expect.objectContaining({
216220
pendingUserMessage: expect.stringContaining(
217221
"__twig_cloud_prompt_v1__:",
218222
),
219223
sandboxEnvironmentId: undefined,
220-
},
224+
prAuthorshipMode: "bot",
225+
runSource: "manual",
226+
}),
221227
);
222228
expect(sendRunCommandMock).not.toHaveBeenCalled();
223229
expect(runTaskInCloudMock.mock.invocationCallOrder[0]).toBeLessThan(

apps/code/src/renderer/sagas/task/task-creation.ts

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import { trpcClient } from "@renderer/trpc";
2121
import { generateTitleAndSummary } from "@renderer/utils/generateTitle";
2222
import { getTaskRepository } from "@renderer/utils/repository";
2323
import type { ExecutionMode, Task } from "@shared/types";
24+
import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud";
25+
import { getGhUserTokenOrThrow } from "@utils/github";
2426
import { logger } from "@utils/logger";
2527
import { queryClient } from "@utils/queryClient";
2628

@@ -78,6 +80,8 @@ export interface TaskCreationInput {
7880
reasoningLevel?: string;
7981
environmentId?: string;
8082
sandboxEnvironmentId?: string;
83+
cloudPrAuthorshipMode?: PrAuthorshipMode;
84+
cloudRunSource?: CloudRunSource;
8185
signalReportId?: string;
8286
}
8387

@@ -275,13 +279,27 @@ export class TaskCreationSaga extends Saga<
275279
if (shouldStartCloudRun) {
276280
task = await this.step({
277281
name: "cloud_run",
278-
execute: () =>
279-
this.deps.posthogClient.runTaskInCloud(task.id, branch, {
282+
execute: async () => {
283+
const hasGitHubRepo = !!task.repository && !!task.github_integration;
284+
const prAuthorshipMode =
285+
input.cloudPrAuthorshipMode ?? (hasGitHubRepo ? "user" : "bot");
286+
let githubUserToken: string | undefined;
287+
288+
if (prAuthorshipMode === "user" && hasGitHubRepo) {
289+
githubUserToken = await getGhUserTokenOrThrow();
290+
}
291+
292+
return this.deps.posthogClient.runTaskInCloud(task.id, branch, {
280293
pendingUserMessage: initialCloudPrompt
281294
? serializeCloudPrompt(initialCloudPrompt)
282295
: undefined,
283296
sandboxEnvironmentId: input.sandboxEnvironmentId,
284-
}),
297+
prAuthorshipMode,
298+
runSource: input.cloudRunSource ?? "manual",
299+
signalReportId: input.signalReportId,
300+
githubUserToken,
301+
});
302+
},
285303
rollback: async () => {
286304
log.info("Rolling back: cloud run (no-op)", { taskId: task.id });
287305
},
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import { trpcClient } from "@renderer/trpc";
2+
3+
export async function getGhUserTokenOrThrow(): Promise<string> {
4+
const tokenResult = await trpcClient.git.getGhAuthToken.query();
5+
if (!tokenResult.success || !tokenResult.token) {
6+
throw new Error(
7+
tokenResult.error ||
8+
"Authenticate GitHub CLI with `gh auth login` before starting a cloud task.",
9+
);
10+
}
11+
return tokenResult.token;
12+
}

0 commit comments

Comments
 (0)