-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathChildContextOperation.java
More file actions
291 lines (261 loc) · 13.2 KB
/
ChildContextOperation.java
File metadata and controls
291 lines (261 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable.operation;
import static software.amazon.lambda.durable.execution.ExecutionManager.isTerminalStatus;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import software.amazon.awssdk.services.lambda.model.ContextOptions;
import software.amazon.awssdk.services.lambda.model.ErrorObject;
import software.amazon.awssdk.services.lambda.model.Operation;
import software.amazon.awssdk.services.lambda.model.OperationAction;
import software.amazon.awssdk.services.lambda.model.OperationStatus;
import software.amazon.awssdk.services.lambda.model.OperationType;
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
import software.amazon.lambda.durable.DurableContext;
import software.amazon.lambda.durable.TypeToken;
import software.amazon.lambda.durable.context.DurableContextImpl;
import software.amazon.lambda.durable.exception.CallbackFailedException;
import software.amazon.lambda.durable.exception.CallbackSubmitterException;
import software.amazon.lambda.durable.exception.CallbackTimeoutException;
import software.amazon.lambda.durable.exception.ChildContextFailedException;
import software.amazon.lambda.durable.exception.DurableOperationException;
import software.amazon.lambda.durable.exception.StepFailedException;
import software.amazon.lambda.durable.exception.StepInterruptedException;
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
import software.amazon.lambda.durable.execution.SuspendExecutionException;
import software.amazon.lambda.durable.model.OperationIdentifier;
import software.amazon.lambda.durable.serde.SerDes;
import software.amazon.lambda.durable.util.ExceptionHelper;
/**
* Manages the lifecycle of a child execution context.
*
* <p>A child context runs a user function in a separate thread with its own operation counter and checkpoint log.
* Operations within the child context use the child's context ID as their parentId.
*
* <p>When created as part of a {@link ConcurrencyOperation} (e.g., parallel execution), the child notifies its parent
* on completion via {@code onItemComplete()} BEFORE closing its own child context. It also skips checkpointing if the
* parent operation has already succeeded.
*/
public class ChildContextOperation<T> extends BaseDurableOperation<T> {
private static final int LARGE_RESULT_THRESHOLD = 256 * 1024;
private final Function<DurableContext, T> function;
private final ExecutorService userExecutor;
private final ConcurrencyOperation<?> parentOperation;
private boolean replayChildContext;
private T reconstructedResult;
public ChildContextOperation(
OperationIdentifier operationIdentifier,
Function<DurableContext, T> function,
TypeToken<T> resultTypeToken,
SerDes resultSerDes,
DurableContextImpl durableContext) {
this(operationIdentifier, function, resultTypeToken, resultSerDes, durableContext, null);
}
public ChildContextOperation(
OperationIdentifier operationIdentifier,
Function<DurableContext, T> function,
TypeToken<T> resultTypeToken,
SerDes resultSerDes,
DurableContextImpl durableContext,
ConcurrencyOperation<?> parentOperation) {
super(operationIdentifier, resultTypeToken, resultSerDes, durableContext);
this.function = function;
this.userExecutor = getContext().getDurableConfig().getExecutorService();
this.parentOperation = parentOperation;
}
/** Starts the operation. */
@Override
protected void start() {
// First execution: fire-and-forget START checkpoint, then run
sendOperationUpdateAsync(OperationUpdate.builder().action(OperationAction.START));
executeChildContext();
}
/** Replays the operation. */
@Override
protected void replay(Operation existing) {
switch (existing.status()) {
case SUCCEEDED -> {
if (existing.contextDetails() != null
&& Boolean.TRUE.equals(existing.contextDetails().replayChildren())) {
// Large result: re-execute child context to reconstruct result
replayChildContext = true;
executeChildContext();
} else {
markAlreadyCompleted();
}
}
case FAILED -> markAlreadyCompleted();
case STARTED -> executeChildContext();
default ->
terminateExecutionWithIllegalDurableOperationException(
"Unexpected child context status: " + existing.status());
}
}
@Override
protected void markAlreadyCompleted() {
super.markAlreadyCompleted();
if (parentOperation != null) {
parentOperation.onItemComplete(this);
}
}
private void executeChildContext() {
// The operationId is already globally unique (prefixed by parent context path via
// DurableContext.nextOperationId), so we use it directly as the contextId.
// E.g., first level child context "hash(1)",
// second level child context "hash(hash(1)-2)",
// third level child context "hash(hash(hash(1)-2)-1)".
var contextId = getOperationId();
// Thread registration is intentionally split across two threads:
// 1. registerActiveThread on the PARENT thread — ensures the child is tracked before the
// parent can deregister and trigger suspension (race prevention).
// 2. setCurrentContext on the CHILD thread — sets the ThreadLocal so operations inside
// the child context know which context they belong to.
// registerActiveThread is idempotent (no-op if already registered).
registerActiveThread(contextId);
Runnable userHandler = () -> {
// use a try-with-resources to
// - add thread id/type to thread local when the step starts
// - clear logger properties when the step finishes
//
// When this child is part of a ConcurrencyOperation (parentOperation != null),
// we notify the parent BEFORE closing the child context. This ensures the parent
// can trigger the next queued branch while the current child context is still valid.
try (var childContext = getContext().createChildContext(contextId, getName())) {
try {
T result = function.apply(childContext);
handleChildContextSuccess(result);
} catch (Throwable e) {
handleChildContextFailure(e);
} finally {
if (parentOperation != null) {
parentOperation.onItemComplete(this);
}
}
}
};
// Execute user provided child context code in user-configured executor
CompletableFuture.runAsync(userHandler, userExecutor);
}
private void handleChildContextSuccess(T result) {
if (replayChildContext) {
// Replaying a SUCCEEDED child with replayChildren=true — skip checkpointing.
// Mark the completableFuture completed so get() doesn't block waiting for a checkpoint response.
this.reconstructedResult = result;
markAlreadyCompleted();
} else {
checkpointSuccess(result);
}
}
private void checkpointSuccess(T result) {
// Skip checkpointing if parent ConcurrencyOperation has already completed —
// prevents race conditions where a child finishes after the parent has already completed.
if (parentOperation != null && parentOperation.isOperationCompleted()) {
this.reconstructedResult = result;
markAlreadyCompleted();
return;
}
var serialized = serializeResult(result);
if (serialized == null || serialized.getBytes(StandardCharsets.UTF_8).length < LARGE_RESULT_THRESHOLD) {
sendOperationUpdate(
OperationUpdate.builder().action(OperationAction.SUCCEED).payload(serialized));
} else {
// Large result: checkpoint with empty payload + ReplayChildren flag.
// Store the result so get() can return it directly without deserializing the empty payload.
this.reconstructedResult = result;
sendOperationUpdate(OperationUpdate.builder()
.action(OperationAction.SUCCEED)
.payload("")
.contextOptions(
ContextOptions.builder().replayChildren(true).build()));
}
}
private void handleChildContextFailure(Throwable exception) {
exception = ExceptionHelper.unwrapCompletableFuture(exception);
if (exception instanceof SuspendExecutionException suspendExecutionException) {
// Rethrow Error immediately — do not checkpoint
throw suspendExecutionException;
}
if (exception instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) {
// terminate the execution and throw the exception if it's not recoverable
terminateExecution(unrecoverableDurableExecutionException);
}
// Skip checkpointing if parent ConcurrencyOperation has already completed —
// prevents race conditions where a child finishes after the parent has already succeeded.
if (parentOperation != null && parentOperation.isOperationCompleted()) {
markAlreadyCompleted();
return;
}
final ErrorObject errorObject;
if (exception instanceof DurableOperationException opEx) {
errorObject = opEx.getErrorObject();
} else {
errorObject = serializeException(exception);
}
sendOperationUpdate(
OperationUpdate.builder().action(OperationAction.FAIL).error(errorObject));
}
@Override
public T get() {
var op = waitForOperationCompletion();
if (op.status() == OperationStatus.SUCCEEDED) {
if (reconstructedResult != null) {
return reconstructedResult;
}
var contextDetails = op.contextDetails();
var result = (contextDetails != null) ? contextDetails.result() : null;
return deserializeResult(result);
} else {
var contextDetails = op.contextDetails();
var errorObject = (contextDetails != null) ? contextDetails.error() : null;
// Attempt to reconstruct and throw the original exception
Throwable original = deserializeException(errorObject);
if (original != null) {
ExceptionHelper.sneakyThrow(original);
}
// throw a general failed exception if a user exception is not reconstructed
return switch (getSubType()) {
case WAIT_FOR_CALLBACK -> handleWaitForCallbackFailure(op);
case MAP -> throw new ChildContextFailedException(op);
case MAP_ITERATION -> throw new ChildContextFailedException(op);
case PARALLEL -> throw new ChildContextFailedException(op);
case PARALLEL_BRANCH -> throw new ChildContextFailedException(op);
case RUN_IN_CHILD_CONTEXT -> throw new ChildContextFailedException(op);
case WAIT_FOR_CONDITION -> throw new ChildContextFailedException(op);
};
}
}
private T handleWaitForCallbackFailure(Operation op) {
var childrenOps = getChildOperations(op.id());
var callbackOp = childrenOps.stream()
.filter(o -> o.type() == OperationType.CALLBACK)
.findFirst()
.orElse(null);
var submitterOp = childrenOps.stream()
.filter(o -> o.type() == OperationType.STEP)
.findFirst()
.orElse(null);
if (callbackOp != null) {
// if callback failed
if (isTerminalStatus(callbackOp.status())) {
switch (callbackOp.status()) {
case FAILED -> throw new CallbackFailedException(callbackOp);
case TIMED_OUT -> throw new CallbackTimeoutException(callbackOp);
}
}
// if submitter failed
if (submitterOp != null
&& isTerminalStatus(submitterOp.status())
&& submitterOp.status() != OperationStatus.SUCCEEDED) {
var stepError = submitterOp.stepDetails().error();
if (StepInterruptedException.isStepInterruptedException(stepError)) {
throw new CallbackSubmitterException(callbackOp, new StepInterruptedException(submitterOp));
} else {
throw new CallbackSubmitterException(callbackOp, new StepFailedException(submitterOp));
}
}
}
throw new IllegalStateException("Unknown waitForCallback status");
}
}