Skip to content

Commit 0414e69

Browse files
committed
Make IdempotencyConfig thread-safe to avoid Lamdba context leaks in config singleton.
1 parent 6623d89 commit 0414e69

3 files changed

Lines changed: 141 additions & 4 deletions

File tree

powertools-idempotency/powertools-idempotency-core/src/main/java/software/amazon/lambda/powertools/idempotency/IdempotencyConfig.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public class IdempotencyConfig {
3434
private final boolean throwOnNoIdempotencyKey;
3535
private final String hashFunction;
3636
private final BiFunction<Object, DataRecord, Object> responseHook;
37-
private Context lambdaContext;
37+
private final InheritableThreadLocal<Context> lambdaContext = new InheritableThreadLocal<>();
3838

3939
private IdempotencyConfig(String eventKeyJMESPath, String payloadValidationJMESPath,
4040
boolean throwOnNoIdempotencyKey, boolean useLocalCache, int localCacheMaxItems,
@@ -87,11 +87,11 @@ public String getHashFunction() {
8787
}
8888

8989
public Context getLambdaContext() {
90-
return lambdaContext;
90+
return lambdaContext.get();
9191
}
9292

9393
public void setLambdaContext(Context lambdaContext) {
94-
this.lambdaContext = lambdaContext;
94+
this.lambdaContext.set(lambdaContext);
9595
}
9696

9797
public BiFunction<Object, DataRecord, Object> getResponseHook() {

powertools-idempotency/powertools-idempotency-core/src/test/java/software/amazon/lambda/powertools/idempotency/PowertoolsIdempotencyTest.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,72 @@ void firstCall_withExplicitIdempotencyKey_shouldPutInStore() {
165165
verify(store).saveSuccess(any(), resultCaptor.capture(), any());
166166
assertThat(resultCaptor.getValue()).isEqualTo(basket);
167167
}
168+
169+
@Test
170+
void concurrentInvocations_shouldNotLeakContext() throws Exception {
171+
Idempotency.config()
172+
.withPersistenceStore(store)
173+
.configure();
174+
175+
PowertoolsIdempotencyMultiArgFunction function = new PowertoolsIdempotencyMultiArgFunction();
176+
177+
// GIVEN
178+
int threadCount = 10;
179+
Thread[] threads = new Thread[threadCount];
180+
Context[] capturedContexts = new Context[threadCount];
181+
int[] capturedRemainingTimes = new int[threadCount];
182+
boolean[] success = new boolean[threadCount];
183+
184+
// WHEN - Multiple threads call handleRequest with different contexts
185+
for (int i = 0; i < threadCount; i++) {
186+
final int threadIndex = i;
187+
final int expectedTime = (i + 1) * 2000; // 2000, 4000, 6000, ..., 20000
188+
189+
final Context threadContext = new TestLambdaContext() {
190+
@Override
191+
public int getRemainingTimeInMillis() {
192+
return expectedTime;
193+
}
194+
};
195+
196+
threads[i] = new Thread(() -> {
197+
try {
198+
Product p = new Product(threadIndex, "product" + threadIndex, 10);
199+
function.handleRequest(p, threadContext);
200+
201+
// Capture the context that was actually stored in ThreadLocal by this thread
202+
Context captured = Idempotency.getInstance().getConfig().getLambdaContext();
203+
capturedContexts[threadIndex] = captured;
204+
capturedRemainingTimes[threadIndex] = captured != null ? captured.getRemainingTimeInMillis() : -1;
205+
success[threadIndex] = true;
206+
} catch (Exception e) {
207+
success[threadIndex] = false;
208+
}
209+
});
210+
}
211+
212+
// Start all threads
213+
for (Thread thread : threads) {
214+
thread.start();
215+
}
216+
217+
// Wait for all threads to complete
218+
for (Thread thread : threads) {
219+
thread.join();
220+
}
221+
222+
// THEN - All threads should complete successfully
223+
for (boolean result : success) {
224+
assertThat(result).isTrue();
225+
}
226+
227+
// THEN - Each thread should have captured its own context (no leakage)
228+
for (int i = 0; i < threadCount; i++) {
229+
int expectedTime = (i + 1) * 2000;
230+
assertThat(capturedRemainingTimes[i])
231+
.as("Thread %d should have remaining time %d", i, expectedTime)
232+
.isEqualTo(expectedTime);
233+
assertThat(capturedContexts[i]).as("Thread %d should have non-null context", i).isNotNull();
234+
}
235+
}
168236
}

powertools-idempotency/powertools-idempotency-core/src/test/java/software/amazon/lambda/powertools/idempotency/internal/IdempotencyAspectTest.java

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ void secondCall_notExpired_shouldNotGetFromStoreIfPresentOnIdempotencyException(
194194
"Test message",
195195
new RuntimeException("Test Cause"),
196196
dr))
197-
.when(store).saveInProgress(any(), any(), any());
197+
.when(store).saveInProgress(any(), any(), any());
198198

199199
// WHEN
200200
IdempotencyEnabledFunction function = new IdempotencyEnabledFunction();
@@ -538,4 +538,73 @@ void idempotencyOnSubMethodVoid_shouldThrowException() {
538538
IdempotencyConfigurationException.class);
539539
}
540540

541+
@Test
542+
void concurrentInvocations_shouldNotLeakContext() throws Exception {
543+
Idempotency.config()
544+
.withPersistenceStore(store)
545+
.configure();
546+
547+
// Use IdempotencyInternalFunction which calls registerLambdaContext
548+
IdempotencyInternalFunction function = new IdempotencyInternalFunction(true);
549+
550+
// GIVEN
551+
int threadCount = 10;
552+
Thread[] threads = new Thread[threadCount];
553+
Context[] capturedContexts = new Context[threadCount];
554+
int[] capturedRemainingTimes = new int[threadCount];
555+
boolean[] success = new boolean[threadCount];
556+
557+
// WHEN - Multiple threads call handleRequest with different contexts
558+
for (int i = 0; i < threadCount; i++) {
559+
final int threadIndex = i;
560+
final int expectedTime = (i + 1) * 1000; // 1000, 2000, 3000, ..., 10000
561+
562+
final Context threadContext = new TestLambdaContext() {
563+
@Override
564+
public int getRemainingTimeInMillis() {
565+
return expectedTime;
566+
}
567+
};
568+
569+
threads[i] = new Thread(() -> {
570+
try {
571+
Product p = new Product(threadIndex, "product" + threadIndex, 10);
572+
function.handleRequest(p, threadContext);
573+
574+
// Capture the context that was actually stored in ThreadLocal by this thread
575+
Context captured = Idempotency.getInstance().getConfig().getLambdaContext();
576+
capturedContexts[threadIndex] = captured;
577+
capturedRemainingTimes[threadIndex] = captured != null ? captured.getRemainingTimeInMillis() : -1;
578+
success[threadIndex] = true;
579+
} catch (Exception e) {
580+
success[threadIndex] = false;
581+
}
582+
});
583+
}
584+
585+
// Start all threads
586+
for (Thread thread : threads) {
587+
thread.start();
588+
}
589+
590+
// Wait for all threads to complete
591+
for (Thread thread : threads) {
592+
thread.join();
593+
}
594+
595+
// THEN - All threads should complete successfully
596+
for (boolean result : success) {
597+
assertThat(result).isTrue();
598+
}
599+
600+
// THEN - Each thread should have captured its own context (no leakage)
601+
for (int i = 0; i < threadCount; i++) {
602+
int expectedTime = (i + 1) * 1000;
603+
assertThat(capturedRemainingTimes[i])
604+
.as("Thread %d should have remaining time %d", i, expectedTime)
605+
.isEqualTo(expectedTime);
606+
assertThat(capturedContexts[i]).as("Thread %d should have non-null context", i).isNotNull();
607+
}
608+
}
609+
541610
}

0 commit comments

Comments
 (0)