Skip to content

Commit 14cbc22

Browse files
committed
optional assume role
1 parent e9fc7af commit 14cbc22

File tree

4 files changed

+501
-8
lines changed

4 files changed

+501
-8
lines changed

apps/webapp/app/v3/getDeploymentImageRef.server.ts

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,91 @@ import {
77
RepositoryNotFoundException,
88
GetAuthorizationTokenCommand,
99
} from "@aws-sdk/client-ecr";
10+
import { STSClient, AssumeRoleCommand } from "@aws-sdk/client-sts";
1011
import { tryCatch } from "@trigger.dev/core";
1112
import { logger } from "~/services/logger.server";
1213

14+
// Optional configuration for cross-account access
15+
export type CrossAccountConfig = {
16+
assumeRole: boolean;
17+
roleName: string;
18+
};
19+
20+
const DEFAULT_CROSS_ACCOUNT_CONFIG: CrossAccountConfig = {
21+
assumeRole: false,
22+
roleName: "OrganizationAccountAccessRole",
23+
};
24+
25+
async function getAssumedRoleCredentials(
26+
region: string,
27+
accountId: string,
28+
config: CrossAccountConfig
29+
): Promise<{
30+
accessKeyId: string;
31+
secretAccessKey: string;
32+
sessionToken: string;
33+
}> {
34+
const sts = new STSClient({ region });
35+
const roleArn = `arn:aws:iam::${accountId}:role/${config.roleName}`;
36+
37+
// Generate a unique session name using timestamp and random string
38+
// This helps with debugging but doesn't affect concurrent sessions
39+
const timestamp = Date.now();
40+
const randomSuffix = Math.random().toString(36).substring(2, 8);
41+
const sessionName = `TriggerWebappECRAccess_${timestamp}_${randomSuffix}`;
42+
43+
try {
44+
const response = await sts.send(
45+
new AssumeRoleCommand({
46+
RoleArn: roleArn,
47+
RoleSessionName: sessionName,
48+
// Sessions automatically expire after 1 hour
49+
// AWS allows 5000 concurrent sessions by default
50+
DurationSeconds: 3600,
51+
})
52+
);
53+
54+
if (!response.Credentials) {
55+
throw new Error("STS: No credentials returned from assumed role");
56+
}
57+
58+
if (
59+
!response.Credentials.AccessKeyId ||
60+
!response.Credentials.SecretAccessKey ||
61+
!response.Credentials.SessionToken
62+
) {
63+
throw new Error("STS: Invalid credentials returned from assumed role");
64+
}
65+
66+
return {
67+
accessKeyId: response.Credentials.AccessKeyId,
68+
secretAccessKey: response.Credentials.SecretAccessKey,
69+
sessionToken: response.Credentials.SessionToken,
70+
};
71+
} catch (error) {
72+
logger.error("Failed to assume role", { roleArn, sessionName, error });
73+
throw error;
74+
}
75+
}
76+
77+
async function createEcrClient(
78+
region: string,
79+
registryId?: string,
80+
crossAccountConfig: CrossAccountConfig = DEFAULT_CROSS_ACCOUNT_CONFIG
81+
) {
82+
// If no registryId or role assumption is disabled, use default credentials
83+
if (!registryId || !crossAccountConfig.assumeRole) {
84+
return new ECRClient({ region });
85+
}
86+
87+
// Get credentials for cross-account access
88+
const credentials = await getAssumedRoleCredentials(region, registryId, crossAccountConfig);
89+
return new ECRClient({
90+
region,
91+
credentials,
92+
});
93+
}
94+
1395
export async function getDeploymentImageRef({
1496
host,
1597
namespace,
@@ -18,6 +100,7 @@ export async function getDeploymentImageRef({
18100
environmentSlug,
19101
registryId,
20102
registryTags,
103+
crossAccountConfig,
21104
}: {
22105
host: string;
23106
namespace: string;
@@ -26,6 +109,7 @@ export async function getDeploymentImageRef({
26109
environmentSlug: string;
27110
registryId?: string;
28111
registryTags?: string;
112+
crossAccountConfig?: CrossAccountConfig;
29113
}): Promise<{
30114
imageRef: string;
31115
isEcr: boolean;
@@ -41,7 +125,13 @@ export async function getDeploymentImageRef({
41125
}
42126

43127
const [ecrRepoError] = await tryCatch(
44-
ensureEcrRepositoryExists({ repositoryName, registryHost: host, registryId, registryTags })
128+
ensureEcrRepositoryExists({
129+
repositoryName,
130+
registryHost: host,
131+
registryId,
132+
registryTags,
133+
crossAccountConfig,
134+
})
45135
);
46136

47137
if (ecrRepoError) {
@@ -75,13 +165,15 @@ async function createEcrRepository({
75165
region,
76166
registryId,
77167
registryTags,
168+
crossAccountConfig,
78169
}: {
79170
repositoryName: string;
80171
region: string;
81172
registryId?: string;
82173
registryTags?: string;
174+
crossAccountConfig?: CrossAccountConfig;
83175
}): Promise<Repository> {
84-
const ecr = new ECRClient({ region });
176+
const ecr = await createEcrClient(region, registryId, crossAccountConfig);
85177

86178
const result = await ecr.send(
87179
new CreateRepositoryCommand({
@@ -107,12 +199,14 @@ async function getEcrRepository({
107199
repositoryName,
108200
region,
109201
registryId,
202+
crossAccountConfig,
110203
}: {
111204
repositoryName: string;
112205
region: string;
113206
registryId?: string;
207+
crossAccountConfig?: CrossAccountConfig;
114208
}): Promise<Repository | undefined> {
115-
const ecr = new ECRClient({ region });
209+
const ecr = await createEcrClient(region, registryId, crossAccountConfig);
116210

117211
try {
118212
const result = await ecr.send(
@@ -153,11 +247,13 @@ async function ensureEcrRepositoryExists({
153247
registryHost,
154248
registryId,
155249
registryTags,
250+
crossAccountConfig,
156251
}: {
157252
repositoryName: string;
158253
registryHost: string;
159254
registryId?: string;
160255
registryTags?: string;
256+
crossAccountConfig?: CrossAccountConfig;
161257
}): Promise<Repository> {
162258
const region = getEcrRegion(registryHost);
163259

@@ -166,7 +262,7 @@ async function ensureEcrRepositoryExists({
166262
}
167263

168264
const [getRepoError, existingRepo] = await tryCatch(
169-
getEcrRepository({ repositoryName, region, registryId })
265+
getEcrRepository({ repositoryName, region, registryId, crossAccountConfig })
170266
);
171267

172268
if (getRepoError) {
@@ -180,7 +276,7 @@ async function ensureEcrRepositoryExists({
180276
}
181277

182278
const [createRepoError, newRepo] = await tryCatch(
183-
createEcrRepository({ repositoryName, region, registryId, registryTags })
279+
createEcrRepository({ repositoryName, region, registryId, registryTags, crossAccountConfig })
184280
);
185281

186282
if (createRepoError) {
@@ -201,17 +297,19 @@ async function ensureEcrRepositoryExists({
201297
export async function getEcrAuthToken({
202298
registryHost,
203299
registryId,
300+
crossAccountConfig,
204301
}: {
205302
registryHost: string;
206303
registryId?: string;
304+
crossAccountConfig?: CrossAccountConfig;
207305
}): Promise<{ username: string; password: string }> {
208306
const region = getEcrRegion(registryHost);
209307
if (!region) {
210308
logger.error("Invalid ECR registry host", { registryHost });
211309
throw new Error("Invalid ECR registry host");
212310
}
213311

214-
const ecr = new ECRClient({ region });
312+
const ecr = await createEcrClient(region, registryId, crossAccountConfig);
215313
const response = await ecr.send(
216314
new GetAuthorizationTokenCommand({
217315
registryIds: registryId ? [registryId] : undefined,

apps/webapp/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"@ariakit/react-core": "^0.4.6",
3636
"@aws-sdk/client-ecr": "^3.839.0",
3737
"@aws-sdk/client-sqs": "^3.445.0",
38+
"@aws-sdk/client-sts": "^3.840.0",
3839
"@codemirror/autocomplete": "^6.3.1",
3940
"@codemirror/commands": "^6.1.2",
4041
"@codemirror/lang-javascript": "^6.1.1",

apps/webapp/test/getDeploymentImageRef.test.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ describe.skipIf(process.env.RUN_REGISTRY_TESTS !== "1")("getDeploymentImageRef",
1414
const registryId = process.env.DEPLOY_REGISTRY_ID;
1515
const registryTags = "test=test,test2=test2";
1616

17+
const assumeRole = process.env.ASSUME_ROLE === "1";
18+
19+
const crossAccountConfig = {
20+
assumeRole,
21+
roleName: "OrganizationAccountAccessRole",
22+
};
23+
1724
// Clean up test repository after tests
1825
afterAll(async () => {
1926
if (!registryId) {
@@ -65,6 +72,7 @@ describe.skipIf(process.env.RUN_REGISTRY_TESTS !== "1")("getDeploymentImageRef",
6572
environmentSlug: "test",
6673
registryId,
6774
registryTags,
75+
crossAccountConfig,
6876
});
6977

7078
expect(imageRef.imageRef).toBe(
@@ -83,6 +91,7 @@ describe.skipIf(process.env.RUN_REGISTRY_TESTS !== "1")("getDeploymentImageRef",
8391
environmentSlug: "prod",
8492
registryId,
8593
registryTags,
94+
crossAccountConfig,
8695
});
8796

8897
expect(imageRef.imageRef).toBe(

0 commit comments

Comments
 (0)