Skip to content

Commit 48015b9

Browse files
FullyTypedAstraea Quinn S
authored and committed
parity(Termination Paths): Match exit paths with reference
- Introduce `ExecutionError` and `InvocationError` for non-retriable and retriable errors respectively. - align callback waiting behaviour according to spec: - terminated without result will raise an error - terminated and successful will return appropriate value - anything else will suspend execution
1 parent 1171521 commit 48015b9

20 files changed

+274
-87
lines changed

src/aws_durable_execution_sdk_python/context.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
WaitForConditionConfig,
1717
)
1818
from aws_durable_execution_sdk_python.exceptions import (
19-
FatalError,
19+
CallbackError,
2020
SuspendExecution,
2121
ValidationError,
2222
)
@@ -125,11 +125,17 @@ def result(self) -> T | None:
125125
checkpointed_result: CheckpointedResult = self.state.get_checkpoint_result(
126126
self.operation_id
127127
)
128-
if checkpointed_result.is_started():
129-
msg: str = "Calback result not received yet. Suspending execution while waiting for result."
130-
raise SuspendExecution(msg)
131128

132-
if checkpointed_result.is_failed() or checkpointed_result.is_timed_out():
129+
if not checkpointed_result.is_existent():
130+
msg = "Callback operation must exist"
131+
raise CallbackError(msg)
132+
133+
if (
134+
checkpointed_result.is_failed()
135+
or checkpointed_result.is_cancelled()
136+
or checkpointed_result.is_timed_out()
137+
or checkpointed_result.is_stopped()
138+
):
133139
checkpointed_result.raise_callable_error()
134140

135141
if checkpointed_result.is_succeeded():
@@ -143,8 +149,10 @@ def result(self) -> T | None:
143149
durable_execution_arn=self.state.durable_execution_arn,
144150
)
145151

146-
msg = "Callback must be started before you can await the result."
147-
raise FatalError(msg)
152+
# operation exists; it has not terminated (successfully or otherwise)
153+
# therefore we should wait
154+
msg = "Callback result not received yet. Suspending execution while waiting for result."
155+
raise SuspendExecution(msg)
148156

149157

150158
class DurableContext(DurableContextProtocol):

src/aws_durable_execution_sdk_python/exceptions.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,90 @@
77

88
from dataclasses import dataclass
99
from datetime import UTC, datetime, timedelta
10+
from enum import Enum
11+
12+
13+
class TerminationReason(Enum):
14+
"""Reasons why a durable execution terminated."""
15+
16+
UNHANDLED_ERROR = "UNHANDLED_ERROR"
17+
INVOCATION_ERROR = "INVOCATION_ERROR"
18+
EXECUTION_ERROR = "EXECUTION_ERROR"
19+
CHECKPOINT_FAILED = "CHECKPOINT_FAILED"
20+
NON_DETERMINISTIC_EXECUTION = "NON_DETERMINISTIC_EXECUTION"
21+
STEP_INTERRUPTED = "STEP_INTERRUPTED"
22+
CALLBACK_ERROR = "CALLBACK_ERROR"
23+
SERIALIZATION_ERROR = "SERIALIZATION_ERROR"
1024

1125

1226
class DurableExecutionsError(Exception):
1327
"""Base class for Durable Executions exceptions"""
1428

1529

16-
class FatalError(DurableExecutionsError):
17-
"""Unrecoverable error. Will not retry."""
30+
class UnrecoverableError(DurableExecutionsError):
31+
"""Base class for errors that terminate execution."""
32+
33+
def __init__(self, message: str, termination_reason: TerminationReason):
34+
super().__init__(message)
35+
self.termination_reason = termination_reason
36+
37+
38+
class ExecutionError(UnrecoverableError):
39+
"""Error that returns FAILED status without retry."""
40+
41+
def __init__(
42+
self,
43+
message: str,
44+
termination_reason: TerminationReason = TerminationReason.EXECUTION_ERROR,
45+
):
46+
super().__init__(message, termination_reason)
47+
48+
49+
class InvocationError(UnrecoverableError):
50+
"""Error that should cause Lambda retry by throwing from handler."""
51+
52+
def __init__(
53+
self,
54+
message: str,
55+
termination_reason: TerminationReason = TerminationReason.INVOCATION_ERROR,
56+
):
57+
super().__init__(message, termination_reason)
58+
59+
60+
class CallbackError(ExecutionError):
61+
"""Error in callback handling."""
62+
63+
def __init__(self, message: str, callback_id: str | None = None):
64+
super().__init__(message, TerminationReason.CALLBACK_ERROR)
65+
self.callback_id = callback_id
66+
67+
68+
class CheckpointFailedError(InvocationError):
69+
"""Error when checkpoint operation fails."""
70+
71+
def __init__(self, message: str, step_id: str | None = None):
72+
super().__init__(message, TerminationReason.CHECKPOINT_FAILED)
73+
self.step_id = step_id
74+
75+
76+
class NonDeterministicExecutionError(ExecutionError):
77+
"""Error when execution is non-deterministic."""
1878

79+
def __init__(self, message: str, step_id: str | None = None):
80+
super().__init__(message, TerminationReason.NON_DETERMINISTIC_EXECUTION)
81+
self.step_id = step_id
1982

20-
class CheckpointError(FatalError):
83+
84+
class CheckpointError(CheckpointFailedError):
2185
"""Failure to checkpoint. Will terminate the lambda."""
2286

87+
def __init__(self, message: str):
88+
super().__init__(message)
89+
90+
@classmethod
91+
def from_exception(cls, exception: Exception) -> CheckpointError:
92+
return cls(message=str(exception))
93+
2394

2495
class ValidationError(DurableExecutionsError):
2596
"""Incorrect arguments to a Durable Function operation."""
@@ -50,9 +121,13 @@ def __init__(
50121
self.stack_trace = stack_trace
51122

52123

53-
class StepInterruptedError(UserlandError):
124+
class StepInterruptedError(InvocationError):
54125
"""Raised when a step is interrupted before it checkpointed at the end."""
55126

127+
def __init__(self, message: str, step_id: str | None = None):
128+
super().__init__(message, TerminationReason.STEP_INTERRUPTED)
129+
self.step_id = step_id
130+
56131

57132
class SuspendExecution(BaseException):
58133
"""Raise this exception to suspend the current execution by returning PENDING to DAR.

src/aws_durable_execution_sdk_python/execution.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from aws_durable_execution_sdk_python.exceptions import (
1313
CheckpointError,
1414
DurableExecutionsError,
15-
FatalError,
15+
ExecutionError,
16+
InvocationError,
1617
SuspendExecution,
1718
)
1819
from aws_durable_execution_sdk_python.lambda_service import (
@@ -291,10 +292,16 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
291292
logger.exception("Failed to checkpoint")
292293
# Throw the error to terminate the lambda
293294
raise
294-
except FatalError as e:
295-
logger.exception("Fatal error")
295+
296+
except InvocationError:
297+
logger.exception("Invocation error. Must terminate.")
298+
# Throw the error to trigger Lambda retry
299+
raise
300+
except ExecutionError as e:
301+
logger.exception("Execution error. Must terminate without retry.")
296302
return DurableExecutionInvocationOutput(
297-
status=InvocationStatus.PENDING, error=ErrorObject.from_exception(e)
303+
status=InvocationStatus.FAILED,
304+
error=ErrorObject.from_exception(e),
298305
).to_dict()
299306
except Exception as e:
300307
# all user-space errors go here

src/aws_durable_execution_sdk_python/lambda_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,7 @@ def checkpoint(
983983
return CheckpointOutput.from_dict(result)
984984
except Exception as e:
985985
logger.exception("Failed to checkpoint.")
986-
raise CheckpointError(e) from e
986+
raise CheckpointError.from_exception(e) from e
987987

988988
def get_execution_state(
989989
self,

src/aws_durable_execution_sdk_python/operation/callback.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import TYPE_CHECKING, Any
66

77
from aws_durable_execution_sdk_python.config import StepConfig
8-
from aws_durable_execution_sdk_python.exceptions import FatalError
8+
from aws_durable_execution_sdk_python.exceptions import CallbackError
99
from aws_durable_execution_sdk_python.lambda_service import (
1010
CallbackOptions,
1111
OperationUpdate,
@@ -58,8 +58,8 @@ def create_callback_handler(
5858
not checkpointed_result.operation
5959
or not checkpointed_result.operation.callback_details
6060
):
61-
msg = "Missing callback details"
62-
raise FatalError(msg)
61+
msg = f"Missing callback details for operation: {operation_identifier.operation_id}"
62+
raise CallbackError(msg)
6363

6464
return checkpointed_result.operation.callback_details.callback_id
6565

@@ -74,8 +74,8 @@ def create_callback_handler(
7474
)
7575

7676
if not result.operation or not result.operation.callback_details:
77-
msg = "Missing callback details"
78-
raise FatalError(msg)
77+
msg = f"Missing callback details for operation: {operation_identifier.operation_id}"
78+
raise CallbackError(msg)
7979

8080
return result.operation.callback_details.callback_id
8181

src/aws_durable_execution_sdk_python/operation/child.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from typing import TYPE_CHECKING, TypeVar
77

88
from aws_durable_execution_sdk_python.config import ChildConfig
9-
from aws_durable_execution_sdk_python.exceptions import FatalError, SuspendExecution
9+
from aws_durable_execution_sdk_python.exceptions import (
10+
InvocationError,
11+
SuspendExecution,
12+
)
1013
from aws_durable_execution_sdk_python.lambda_service import (
1114
ContextOptions,
1215
ErrorObject,
@@ -138,7 +141,11 @@ def child_handler(
138141
)
139142
state.create_checkpoint(operation_update=fail_operation)
140143

141-
# TODO: rethink FatalError
142-
if isinstance(e, FatalError):
144+
# InvocationError and its derivatives can be retried
145+
# When we encounter an invocation error (in all of its forms), we bubble that
146+
# error upwards (with the checkpoint in place) such that we reach the
147+
# execution handler at the very top, which will then induce a retry from the
148+
# dataplane.
149+
if isinstance(e, InvocationError):
143150
raise
144151
raise error_object.to_callable_runtime_error() from e

src/aws_durable_execution_sdk_python/operation/invoke.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from typing import TYPE_CHECKING, TypeVar
77

88
from aws_durable_execution_sdk_python.config import InvokeConfig
9-
from aws_durable_execution_sdk_python.exceptions import (
10-
FatalError,
11-
)
9+
from aws_durable_execution_sdk_python.exceptions import ExecutionError
1210
from aws_durable_execution_sdk_python.lambda_service import (
1311
ChainedInvokeOptions,
1412
OperationUpdate,
@@ -107,5 +105,6 @@ def invoke_handler(
107105
)
108106
suspend_with_optional_timeout(msg, config.timeout_seconds)
109107
# This line should never be reached since suspend_with_optional_timeout always raises
108+
# if it is ever reached, we will crash in a non-retryable manner via ExecutionError
110109
msg = "suspend_with_optional_timeout should have raised an exception, but did not."
111-
raise FatalError(msg) from None
110+
raise ExecutionError(msg) from None

src/aws_durable_execution_sdk_python/operation/step.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
StepSemantics,
1212
)
1313
from aws_durable_execution_sdk_python.exceptions import (
14-
FatalError,
14+
ExecutionError,
1515
StepInterruptedError,
1616
)
1717
from aws_durable_execution_sdk_python.lambda_service import (
@@ -151,14 +151,14 @@ def step_handler(
151151
)
152152
return raw_result # noqa: TRY300
153153
except Exception as e:
154-
if isinstance(e, FatalError):
154+
if isinstance(e, ExecutionError):
155155
# no retry on fatal - e.g checkpoint exception
156156
logger.debug(
157157
"💥 Fatal error for id: %s, name: %s",
158158
operation_identifier.operation_id,
159159
operation_identifier.name,
160160
)
161-
# this bubbles up to execution.durable_handler, where it will exit with PENDING. TODO: confirm if still correct
161+
# this bubbles up to execution.durable_handler, where it will exit with FAILED
162162
raise
163163

164164
logger.exception(
@@ -168,8 +168,10 @@ def step_handler(
168168
)
169169

170170
retry_handler(e, state, operation_identifier, config, checkpointed_result)
171+
# if we've failed to raise an exception from the retry_handler, then we are in a
172+
# weird state, and should crash terminate the execution
171173
msg = "retry handler should have raised an exception, but did not."
172-
raise FatalError(msg) from None
174+
raise ExecutionError(msg) from None
173175

174176

175177
# TODO: I don't much like this func, needs refactor. Messy grab-bag of args, refine.

src/aws_durable_execution_sdk_python/operation/wait_for_condition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import TYPE_CHECKING, TypeVar
77

88
from aws_durable_execution_sdk_python.exceptions import (
9-
FatalError,
9+
ExecutionError,
1010
)
1111
from aws_durable_execution_sdk_python.lambda_service import (
1212
ErrorObject,
@@ -203,4 +203,4 @@ def wait_for_condition_handler(
203203
raise
204204

205205
msg: str = "wait_for_condition should never reach this point"
206-
raise FatalError(msg)
206+
raise ExecutionError(msg)

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from aws_durable_execution_sdk_python.exceptions import (
3636
DurableExecutionsError,
37-
FatalError,
37+
ExecutionError,
3838
SerDesError,
3939
)
4040

@@ -440,9 +440,12 @@ def serialize(
440440
try:
441441
return active_serdes.serialize(value, serdes_context)
442442
except Exception as e:
443-
logger.exception("⚠️ Serialization failed for id: %s", operation_id)
444-
msg = f"Serialization failed for id: {operation_id}, error: {e}"
445-
raise FatalError(msg) from e
443+
logger.exception(
444+
"⚠️ Serialization failed for id: %s",
445+
operation_id,
446+
)
447+
msg = f"Serialization failed for id: {operation_id}, error: {e}."
448+
raise ExecutionError(msg) from e
446449

447450

448451
def deserialize(
@@ -469,4 +472,4 @@ def deserialize(
469472
except Exception as e:
470473
logger.exception("⚠️ Deserialization failed for id: %s", operation_id)
471474
msg = f"Deserialization failed for id: {operation_id}"
472-
raise FatalError(msg) from e
475+
raise ExecutionError(msg) from e

0 commit comments

Comments
 (0)