Skip to content

Commit 84e9bdf

Browse files
wangyb-A and Alex Wang authored
feat: [Parallel] Do not checkpoint parallel during replay (#233)
* Do not checkpoint during replay
* Update checkpoint condition
* Remove runJoin(), add mock backend to fix tests

---------

Co-authored-by: Alex Wang <wangyb@amazon.com>
1 parent 8840167 commit 84e9bdf

4 files changed

Lines changed: 146 additions & 27 deletions

File tree

examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ public Output handleRequest(Input input, DurableContext context) {
6262

6363
var deliveries = futures.stream().map(DurableFuture::get).toList();
6464
logger.info("All {} notifications delivered", deliveries.size());
65+
// Test replay
66+
context.wait("wait for finalization", Duration.ofSeconds(5));
6567
return new Output(deliveries);
6668
}
6769
}

sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import software.amazon.lambda.durable.TypeToken;
1313
import software.amazon.lambda.durable.context.DurableContextImpl;
1414
import software.amazon.lambda.durable.exception.ConcurrencyExecutionException;
15+
import software.amazon.lambda.durable.execution.ExecutionManager;
1516
import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus;
1617
import software.amazon.lambda.durable.model.OperationIdentifier;
1718
import software.amazon.lambda.durable.model.OperationSubType;
@@ -42,6 +43,8 @@
4243
*/
4344
public class ParallelOperation<T> extends ConcurrencyOperation<T> {
4445

46+
private boolean skipCheckpoint = false;
47+
4548
public ParallelOperation(
4649
OperationIdentifier operationIdentifier,
4750
TypeToken<T> resultTypeToken,
@@ -79,6 +82,10 @@ protected <R> ChildContextOperation<R> createItem(
7982

8083
@Override
8184
protected void handleSuccess() {
85+
if (skipCheckpoint) {
86+
// Do not send checkpoint during replay
87+
return;
88+
}
8289
sendOperationUpdate(OperationUpdate.builder()
8390
.action(OperationAction.SUCCEED)
8491
.subType(getSubType().getValue())
@@ -99,8 +106,9 @@ protected void start() {
99106

100107
@Override
101108
protected void replay(Operation existing) {
102-
// Always replay child branches for parallel
103-
start();
109+
// Child branches handle their own replay via ChildContextOperation.replay(), so nothing is restarted here.
110+
// Set skipCheckpoint when the existing operation is already in a terminal status, so handleSuccess() skips re-checkpointing the already-completed parallel context.
111+
skipCheckpoint = ExecutionManager.isTerminalStatus(existing.status());
104112
}
105113

106114
@Override

sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
package software.amazon.lambda.durable.operation;
44

55
import static org.junit.jupiter.api.Assertions.*;
6+
import static org.mockito.ArgumentMatchers.any;
67
import static org.mockito.ArgumentMatchers.anyString;
78
import static org.mockito.Mockito.*;
89

910
import java.lang.reflect.Field;
11+
import java.util.concurrent.CompletableFuture;
1012
import java.util.concurrent.Executors;
1113
import java.util.concurrent.atomic.AtomicBoolean;
1214
import java.util.concurrent.atomic.AtomicInteger;
@@ -70,6 +72,18 @@ void setUp() {
7072
when(mockIdGenerator.nextOperationId()).thenAnswer(inv -> "child-" + operationIdCounter.incrementAndGet());
7173
// All child operations are NOT in replay
7274
when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null);
75+
// Simulate the real backend: the parent concurrency operation is available in storage after completion
76+
// so that waitForOperationCompletion() can find it. TestConcurrencyOperation.handleSuccess/Failure are no-ops
77+
// (no checkpoint sent), so we stub this unconditionally for OPERATION_ID.
78+
when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID))
79+
.thenReturn(Operation.builder()
80+
.id(OPERATION_ID)
81+
.name("test-concurrency")
82+
.type(OperationType.CONTEXT)
83+
.subType(OperationSubType.PARALLEL.getValue())
84+
.status(OperationStatus.SUCCEEDED)
85+
.build());
86+
when(executionManager.sendOperationUpdate(any())).thenReturn(CompletableFuture.completedFuture(null));
7387
}
7488

7589
private TestConcurrencyOperation createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount)
@@ -138,7 +152,7 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception {
138152
TypeToken.get(String.class),
139153
SER_DES);
140154

141-
runJoin(op);
155+
op.exposedJoin();
142156

143157
assertTrue(op.isSuccessHandled());
144158
assertFalse(op.isFailureHandled());
@@ -171,7 +185,7 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception {
171185
TypeToken.get(String.class),
172186
SER_DES);
173187

174-
runJoin(op);
188+
op.exposedJoin();
175189

176190
assertTrue(op.isSuccessHandled());
177191
assertEquals(1, op.getSucceededCount());
@@ -191,14 +205,6 @@ void addItem_usesRootChildContextAsParent() throws Exception {
191205
assertSame(childContext, op.getLastParentContext());
192206
}
193207

194-
// ===== Helpers =====
195-
196-
private void runJoin(TestConcurrencyOperation op) throws InterruptedException {
197-
var t = new Thread(op::exposedJoin);
198-
t.start();
199-
t.join(2000);
200-
}
201-
202208
// ===== Test subclass =====
203209

204210
static class TestConcurrencyOperation extends ConcurrencyOperation<Void> {

sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java

Lines changed: 118 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.mockito.Mockito.*;
1010

1111
import java.lang.reflect.Field;
12+
import java.util.concurrent.CompletableFuture;
1213
import java.util.concurrent.Executors;
1314
import java.util.concurrent.atomic.AtomicInteger;
1415
import org.junit.jupiter.api.BeforeEach;
@@ -67,6 +68,26 @@ void setUp() {
6768
mockIdGenerator = mock(OperationIdGenerator.class);
6869
when(mockIdGenerator.nextOperationId()).thenAnswer(inv -> "child-" + operationIdCounter.incrementAndGet());
6970
when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null);
71+
72+
// Simulate the real backend: when a SUCCEED checkpoint is sent for the parallel op,
73+
// make getOperationAndUpdateReplayState return a SUCCEEDED operation so waitForOperationCompletion() can find
74+
// it.
75+
var succeededParallelOp = Operation.builder()
76+
.id(OPERATION_ID)
77+
.name("test-parallel")
78+
.type(OperationType.CONTEXT)
79+
.subType(OperationSubType.PARALLEL.getValue())
80+
.status(OperationStatus.SUCCEEDED)
81+
.build();
82+
when(executionManager.sendOperationUpdate(argThat(u -> u != null
83+
&& u.id() != null
84+
&& u.id().equals(OPERATION_ID)
85+
&& u.action() == OperationAction.SUCCEED)))
86+
.thenAnswer(inv -> {
87+
when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID))
88+
.thenReturn(succeededParallelOp);
89+
return CompletableFuture.completedFuture(null);
90+
});
7091
}
7192

7293
private ParallelOperation<Void> createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount) {
@@ -153,7 +174,7 @@ void handleSuccess_sendsSucceedCheckpoint() throws Exception {
153174
op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES);
154175
op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES);
155176

156-
runJoin(op);
177+
op.get();
157178

158179
verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED));
159180
}
@@ -179,7 +200,7 @@ void minSuccessful_joinCompletesWhenThresholdMet() throws Exception {
179200
op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES);
180201

181202
// Should not throw
182-
assertDoesNotThrow(() -> runJoin(op));
203+
op.get();
183204
assertEquals(1, op.getSucceededCount());
184205
}
185206

@@ -199,6 +220,100 @@ void contextHierarchy_branchesUseParallelContextAsParent() throws Exception {
199220
assertNotNull(childOp);
200221
}
201222

223+
// ===== Replay =====
224+
225+
@Test
226+
void replay_doesNotSendStartCheckpoint() throws Exception {
227+
// Simulate the parallel operation already existing in the service (STARTED status)
228+
when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID))
229+
.thenReturn(Operation.builder()
230+
.id(OPERATION_ID)
231+
.name("test-parallel")
232+
.type(OperationType.CONTEXT)
233+
.subType(OperationSubType.PARALLEL.getValue())
234+
.status(OperationStatus.STARTED)
235+
.build());
236+
// Both branches already succeeded
237+
when(executionManager.getOperationAndUpdateReplayState("child-1"))
238+
.thenReturn(Operation.builder()
239+
.id("child-1")
240+
.name("branch-1")
241+
.type(OperationType.CONTEXT)
242+
.subType(OperationSubType.PARALLEL_BRANCH.getValue())
243+
.status(OperationStatus.SUCCEEDED)
244+
.contextDetails(
245+
ContextDetails.builder().result("\"r1\"").build())
246+
.build());
247+
when(executionManager.getOperationAndUpdateReplayState("child-2"))
248+
.thenReturn(Operation.builder()
249+
.id("child-2")
250+
.name("branch-2")
251+
.type(OperationType.CONTEXT)
252+
.subType(OperationSubType.PARALLEL_BRANCH.getValue())
253+
.status(OperationStatus.SUCCEEDED)
254+
.contextDetails(
255+
ContextDetails.builder().result("\"r2\"").build())
256+
.build());
257+
258+
var op = createOperation(-1, -1, 0);
259+
setOperationIdGenerator(op, mockIdGenerator);
260+
op.execute();
261+
op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES);
262+
op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES);
263+
264+
op.get();
265+
266+
verify(executionManager, never())
267+
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.START));
268+
verify(executionManager, times(1))
269+
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED));
270+
}
271+
272+
@Test
273+
void replay_doesNotSendSucceedCheckpointWhenParallelAlreadySucceeded() throws Exception {
274+
when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID))
275+
.thenReturn(Operation.builder()
276+
.id(OPERATION_ID)
277+
.name("test-parallel")
278+
.type(OperationType.CONTEXT)
279+
.subType(OperationSubType.PARALLEL.getValue())
280+
.status(OperationStatus.SUCCEEDED)
281+
.build());
282+
when(executionManager.getOperationAndUpdateReplayState("child-1"))
283+
.thenReturn(Operation.builder()
284+
.id("child-1")
285+
.name("branch-1")
286+
.type(OperationType.CONTEXT)
287+
.subType(OperationSubType.PARALLEL_BRANCH.getValue())
288+
.status(OperationStatus.SUCCEEDED)
289+
.contextDetails(
290+
ContextDetails.builder().result("\"r1\"").build())
291+
.build());
292+
when(executionManager.getOperationAndUpdateReplayState("child-2"))
293+
.thenReturn(Operation.builder()
294+
.id("child-2")
295+
.name("branch-2")
296+
.type(OperationType.CONTEXT)
297+
.subType(OperationSubType.PARALLEL_BRANCH.getValue())
298+
.status(OperationStatus.SUCCEEDED)
299+
.contextDetails(
300+
ContextDetails.builder().result("\"r2\"").build())
301+
.build());
302+
303+
var op = createOperation(-1, -1, 0);
304+
setOperationIdGenerator(op, mockIdGenerator);
305+
op.execute();
306+
op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES);
307+
op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES);
308+
309+
op.get();
310+
311+
verify(executionManager, never())
312+
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.START));
313+
verify(executionManager, never())
314+
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED));
315+
}
316+
202317
// ===== handleFailure still sends SUCCEED =====
203318

204319
@Test
@@ -224,22 +339,10 @@ void handleFailure_sendsSucceedCheckpointEvenWhenFailureToleranceExceeded() thro
224339
TypeToken.get(String.class),
225340
SER_DES);
226341

227-
runJoin(op);
342+
op.get();
228343

229344
verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED));
230345
verify(executionManager, never())
231346
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL));
232347
}
233-
234-
// ===== Helpers =====
235-
236-
private void runJoin(ParallelOperation<?> op) throws InterruptedException {
237-
var t = new Thread(op::get);
238-
t.start();
239-
t.join(2000);
240-
if (t.isAlive()) {
241-
t.interrupt();
242-
fail("join() did not complete within 2 seconds");
243-
}
244-
}
245348
}

0 commit comments

Comments
 (0)