Skip to content

Commit f499a35

Browse files
committed
add a lock around updatedOperations to prevent race conditions
1 parent be020a5 commit f499a35

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed

sdk-testing/src/main/java/software/amazon/lambda/durable/testing/local/LocalMemoryExecutionClient.java

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
import java.time.Instant;
66
import java.util.Collections;
7+
import java.util.HashMap;
78
import java.util.LinkedHashMap;
89
import java.util.List;
910
import java.util.Map;
1011
import java.util.UUID;
11-
import java.util.concurrent.ConcurrentHashMap;
1212
import java.util.concurrent.CopyOnWriteArrayList;
1313
import java.util.concurrent.atomic.AtomicBoolean;
1414
import software.amazon.awssdk.services.lambda.model.*;
@@ -28,23 +28,27 @@ public class LocalMemoryExecutionClient implements DurableExecutionClient {
2828
private final Map<String, Operation> existingOperations = Collections.synchronizedMap(new LinkedHashMap<>());
2929
private final EventProcessor eventProcessor = new EventProcessor();
3030
private final List<OperationUpdate> operationUpdates = new CopyOnWriteArrayList<>();
31-
private final Map<String, Operation> updatedOperations = new ConcurrentHashMap<>();
31+
private final Map<String, Operation> updatedOperations = new HashMap<>();
3232

3333
@Override
3434
public CheckpointDurableExecutionResponse checkpoint(String arn, String token, List<OperationUpdate> updates) {
3535
operationUpdates.addAll(updates);
3636
updates.forEach(this::applyUpdate);
3737

3838
var newToken = UUID.randomUUID().toString();
39-
var response = CheckpointDurableExecutionResponse.builder()
40-
.checkpointToken(newToken)
41-
.newExecutionState(CheckpointUpdatedExecutionState.builder()
42-
.operations(updatedOperations.values())
43-
.build())
44-
.build();
4539

46-
// updatedOperations was copied into response, so clearing it is safe here
47-
updatedOperations.clear();
40+
CheckpointDurableExecutionResponse response;
41+
synchronized (updatedOperations) {
42+
response = CheckpointDurableExecutionResponse.builder()
43+
.checkpointToken(newToken)
44+
.newExecutionState(CheckpointUpdatedExecutionState.builder()
45+
.operations(updatedOperations.values())
46+
.build())
47+
.build();
48+
49+
// updatedOperations was copied into response, so clearing it is safe here
50+
updatedOperations.clear();
51+
}
4852
return response;
4953
}
5054

@@ -65,14 +69,14 @@ public List<OperationUpdate> getOperationUpdates() {
6569
* @return true if any operations were advanced, false otherwise
6670
*/
6771
public boolean advanceTime() {
68-
var replaced = new AtomicBoolean(false);
69-
existingOperations.replaceAll((key, op) -> {
72+
var hasOperationsAdvanced = new AtomicBoolean(false);
73+
// forEach is safe as we're not adding or removing keys here
74+
existingOperations.forEach((key, op) -> {
7075
// advance pending retries
7176
if (op.status() == OperationStatus.PENDING) {
72-
replaced.set(true);
77+
hasOperationsAdvanced.set(true);
7378
var readyOp = op.toBuilder().status(OperationStatus.READY).build();
74-
updatedOperations.put(op.id(), readyOp);
75-
return readyOp;
79+
updateOperation(readyOp);
7680
}
7781

7882
// advance waits
@@ -87,13 +91,11 @@ public boolean advanceTime() {
8791
.action(OperationAction.SUCCEED)
8892
.build();
8993
eventProcessor.processUpdate(update, succeededOp);
90-
replaced.set(true);
91-
updatedOperations.put(op.id(), succeededOp);
92-
return succeededOp;
94+
hasOperationsAdvanced.set(true);
95+
updateOperation(succeededOp);
9396
}
94-
return op;
9597
});
96-
return replaced.get();
98+
return hasOperationsAdvanced.get();
9799
}
98100

99101
/** Completes a chained invoke operation with the given result, simulating a child Lambda response. */
@@ -172,7 +174,9 @@ public void simulateFireAndForgetCheckpointLoss(String stepName) {
172174
throw new IllegalStateException("Operation not found: " + stepName);
173175
}
174176
existingOperations.remove(op.id());
175-
updatedOperations.remove(op.id());
177+
synchronized (updatedOperations) {
178+
updatedOperations.remove(op.id());
179+
}
176180
}
177181

178182
private void applyUpdate(OperationUpdate update) {
@@ -319,6 +323,8 @@ private OperationStatus deriveStatus(OperationAction action) {
319323

320324
private void updateOperation(Operation op) {
321325
existingOperations.put(op.id(), op);
322-
updatedOperations.put(op.id(), op);
326+
synchronized (updatedOperations) {
327+
updatedOperations.put(op.id(), op);
328+
}
323329
}
324330
}

0 commit comments

Comments
 (0)