|
10 | 10 |
|
11 | 11 | import java.lang.reflect.Field; |
12 | 12 | import java.util.concurrent.CompletableFuture; |
| 13 | +import java.util.concurrent.ConcurrentHashMap; |
13 | 14 | import java.util.concurrent.Executors; |
14 | 15 | import java.util.concurrent.atomic.AtomicInteger; |
15 | 16 | import org.junit.jupiter.api.BeforeEach; |
@@ -69,25 +70,41 @@ void setUp() { |
69 | 70 | when(mockIdGenerator.nextOperationId()).thenAnswer(inv -> "child-" + operationIdCounter.incrementAndGet()); |
70 | 71 | when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null); |
71 | 72 |
|
72 | | - // Simulate the real backend: when a SUCCEED checkpoint is sent for the parallel op, |
73 | | - // make getOperationAndUpdateReplayState return a SUCCEEDED operation so waitForOperationCompletion() can find |
74 | | - // it. |
| 73 | + // Capture registered operations so we can drive onCheckpointComplete callbacks. |
| 74 | + var registeredOps = new ConcurrentHashMap<String, BaseDurableOperation<?>>(); |
| 75 | + doAnswer(inv -> { |
| 76 | + BaseDurableOperation<?> op = inv.getArgument(0); |
| 77 | + registeredOps.put(op.getOperationId(), op); |
| 78 | + return null; |
| 79 | + }) |
| 80 | + .when(executionManager) |
| 81 | + .registerOperation(any()); |
| 82 | + |
| 83 | + // Simulate the real backend for all sendOperationUpdate calls: |
| 84 | + // - For SUCCEED on the parallel op: update the stub and fire onCheckpointComplete to unblock join(). |
| 85 | + // - For everything else (START, child checkpoints): just return a completed future. |
75 | 86 | var succeededParallelOp = Operation.builder() |
76 | 87 | .id(OPERATION_ID) |
77 | 88 | .name("test-parallel") |
78 | 89 | .type(OperationType.CONTEXT) |
79 | 90 | .subType(OperationSubType.PARALLEL.getValue()) |
80 | 91 | .status(OperationStatus.SUCCEEDED) |
81 | 92 | .build(); |
82 | | - when(executionManager.sendOperationUpdate(argThat(u -> u != null |
83 | | - && u.id() != null |
84 | | - && u.id().equals(OPERATION_ID) |
85 | | - && u.action() == OperationAction.SUCCEED))) |
86 | | - .thenAnswer(inv -> { |
87 | | - when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)) |
88 | | - .thenReturn(succeededParallelOp); |
| 93 | + doAnswer(inv -> { |
| 94 | + var update = (software.amazon.awssdk.services.lambda.model.OperationUpdate) inv.getArgument(0); |
| 95 | + |
| 96 | + if (OPERATION_ID.equals(update.id()) && update.action() == OperationAction.SUCCEED) { |
| 97 | + when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)) |
| 98 | + .thenReturn(succeededParallelOp); |
| 99 | + var op = registeredOps.get(OPERATION_ID); |
| 100 | + if (op != null) { |
| 101 | + op.onCheckpointComplete(succeededParallelOp); |
| 102 | + } |
| 103 | + } |
89 | 104 | return CompletableFuture.completedFuture(null); |
90 | | - }); |
| 105 | + }) |
| 106 | + .when(executionManager) |
| 107 | + .sendOperationUpdate(any()); |
91 | 108 | } |
92 | 109 |
|
93 | 110 | private ParallelOperation<Void> createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount) { |
|
0 commit comments