Skip to content

Commit 07a0b71

Browse files
committed
fix: preserve step retry attempt
• Fix checkpoint processor to preserve attempt count and timestamp • Add DurableChildContextTestRunner
1 parent 548e936 commit 07a0b71

File tree

6 files changed

+237
-53
lines changed

6 files changed

+237
-53
lines changed

src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,27 @@ def _create_context_details(self, update: OperationUpdate) -> ContextDetails | N
7070
else None
7171
)
7272

73-
def _create_step_details(self, update: OperationUpdate) -> StepDetails | None:
73+
def _create_step_details(
74+
self, update: OperationUpdate, current_operation: Operation | None = None
75+
) -> StepDetails | None:
7476
"""Create StepDetails from OperationUpdate."""
75-
return (
76-
StepDetails(result=update.payload, error=update.error)
77-
if update.operation_type == OperationType.STEP
78-
else None
79-
)
77+
attempt: int = 0
78+
next_attempt_timestamp: str | None = None
79+
80+
if update.operation_type is OperationType.STEP:
81+
if current_operation and current_operation.step_details:
82+
attempt = current_operation.step_details.attempt
83+
next_attempt_timestamp = (
84+
current_operation.step_details.next_attempt_timestamp
85+
)
86+
return StepDetails(
87+
attempt=attempt,
88+
next_attempt_timestamp=next_attempt_timestamp,
89+
result=update.payload,
90+
error=update.error,
91+
)
92+
93+
return None
8094

8195
def _create_callback_details(
8296
self, update: OperationUpdate
@@ -134,7 +148,7 @@ def _translate_update_to_operation(
134148

135149
execution_details = self._create_execution_details(update)
136150
context_details = self._create_context_details(update)
137-
step_details = self._create_step_details(update)
151+
step_details = self._create_step_details(update, current_operation)
138152
callback_details = self._create_callback_details(update)
139153
invoke_details = self._create_invoke_details(update)
140154
wait_details = self._create_wait_details(update, current_operation)

src/aws_durable_execution_sdk_python_testing/observer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
"""Checkpoint processors can notify the Execution of notable event state changes. Observer pattern."""
22

3+
from __future__ import annotations
4+
35
import threading
46
from abc import ABC, abstractmethod
5-
from collections.abc import Callable
7+
from typing import TYPE_CHECKING
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import Callable
611

7-
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
12+
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
813

914

1015
class ExecutionObserver(ABC):
@@ -34,7 +39,7 @@ def on_step_retry_scheduled(
3439
class ExecutionNotifier:
3540
"""Notifies observers about execution events. Thread-safe."""
3641

37-
def __init__(self):
42+
def __init__(self) -> None:
3843
self._observers: list[ExecutionObserver] = []
3944
self._lock = threading.RLock()
4045

src/aws_durable_execution_sdk_python_testing/runner.py

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from __future__ import annotations
22

3+
import json
34
from dataclasses import dataclass, field
4-
from typing import TYPE_CHECKING, Protocol, TypeVar, cast
5+
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Protocol, TypeVar, cast
56

7+
from aws_durable_execution_sdk_python.execution import (
8+
InvocationStatus,
9+
durable_handler,
10+
)
611
from aws_durable_execution_sdk_python.lambda_service import (
712
ErrorObject,
813
OperationStatus,
@@ -31,6 +36,7 @@
3136
import datetime
3237
from collections.abc import Callable, MutableMapping
3338

39+
from aws_durable_execution_sdk_python.context import DurableContext
3440
from aws_durable_execution_sdk_python.execution import InvocationStatus
3541

3642
from aws_durable_execution_sdk_python_testing.execution import Execution
@@ -49,6 +55,7 @@ class Operation:
4955

5056

5157
T = TypeVar("T", bound=Operation)
58+
P = ParamSpec("P")
5259

5360

5461
class OperationFactory(Protocol):
@@ -90,7 +97,7 @@ def from_svc_operation(
9097
@dataclass(frozen=True)
9198
class ContextOperation(Operation):
9299
child_operations: list[Operation]
93-
result: str | None = None
100+
result: Any = None
94101
error: ErrorObject | None = None
95102

96103
@staticmethod
@@ -119,9 +126,11 @@ def from_svc_operation(
119126
start_timestamp=operation.start_timestamp,
120127
end_timestamp=operation.end_timestamp,
121128
child_operations=child_operations,
122-
result=operation.context_details.result
123-
if operation.context_details
124-
else None,
129+
result=(
130+
json.loads(operation.context_details.result)
131+
if operation.context_details and operation.context_details.result
132+
else None
133+
),
125134
error=operation.context_details.error
126135
if operation.context_details
127136
else None,
@@ -157,8 +166,7 @@ def get_execution(self, name: str) -> ExecutionOperation:
157166
class StepOperation(ContextOperation):
158167
attempt: int = 0
159168
next_attempt_timestamp: str | None = None
160-
# TODO: deserialize?
161-
result: str | None = None
169+
result: Any = None
162170
error: ErrorObject | None = None
163171

164172
@staticmethod
@@ -193,7 +201,11 @@ def from_svc_operation(
193201
if operation.step_details
194202
else None
195203
),
196-
result=operation.step_details.result if operation.step_details else None,
204+
result=(
205+
json.loads(operation.step_details.result)
206+
if operation.step_details and operation.step_details.result
207+
else None
208+
),
197209
error=operation.step_details.error if operation.step_details else None,
198210
)
199211

@@ -230,7 +242,7 @@ def from_svc_operation(
230242
@dataclass(frozen=True)
231243
class CallbackOperation(ContextOperation):
232244
callback_id: str | None = None
233-
result: str | None = None
245+
result: Any = None
234246
error: ErrorObject | None = None
235247

236248
@staticmethod
@@ -264,9 +276,11 @@ def from_svc_operation(
264276
if operation.callback_details
265277
else None
266278
),
267-
result=operation.callback_details.result
268-
if operation.callback_details
269-
else None,
279+
result=(
280+
json.loads(operation.callback_details.result)
281+
if operation.callback_details and operation.callback_details.result
282+
else None
283+
),
270284
error=operation.callback_details.error
271285
if operation.callback_details
272286
else None,
@@ -276,7 +290,7 @@ def from_svc_operation(
276290
@dataclass(frozen=True)
277291
class InvokeOperation(Operation):
278292
durable_execution_arn: str | None = None
279-
result: str | None = None
293+
result: Any = None
280294
error: ErrorObject | None = None
281295

282296
@staticmethod
@@ -301,9 +315,11 @@ def from_svc_operation(
301315
if operation.invoke_details
302316
else None
303317
),
304-
result=operation.invoke_details.result
305-
if operation.invoke_details
306-
else None,
318+
result=(
319+
json.loads(operation.invoke_details.result)
320+
if operation.invoke_details and operation.invoke_details.result
321+
else None
322+
),
307323
error=operation.invoke_details.error if operation.invoke_details else None,
308324
)
309325

@@ -334,7 +350,7 @@ def create_operation(
334350
class DurableFunctionTestResult:
335351
status: InvocationStatus
336352
operations: list[Operation]
337-
result: str | None = None
353+
result: Any = None
338354
error: ErrorObject | None = None
339355

340356
@classmethod
@@ -352,10 +368,14 @@ def create(cls, execution: Execution) -> DurableFunctionTestResult:
352368
msg: str = "Execution result must exist to create test result."
353369
raise DurableFunctionsTestError(msg)
354370

371+
deserialized_result = (
372+
json.loads(execution.result.result) if execution.result.result else None
373+
)
374+
355375
return cls(
356376
status=execution.result.status,
357377
operations=operations,
358-
result=execution.result.result,
378+
result=deserialized_result,
359379
error=execution.result.error,
360380
)
361381

@@ -413,7 +433,7 @@ def close(self):
413433

414434
def run(
415435
self,
416-
input: str, # noqa: A002
436+
input: str | None = None, # noqa: A002
417437
timeout: int = 900,
418438
function_name: str = "test-function",
419439
execution_name: str = "execution-name",
@@ -451,4 +471,19 @@ def run(
451471
execution: Execution = self._store.load(output.execution_arn)
452472
return DurableFunctionTestResult.create(execution=execution)
453473

454-
# return execution
474+
475+
class DurableChildContextTestRunner(DurableFunctionTestRunner):
476+
"""Test a durable block, annotated with @durable_with_child_context, in isolation."""
477+
478+
def __init__(
479+
self,
480+
context_function: Callable[Concatenate[DurableContext, P], Any],
481+
*args,
482+
**kwargs,
483+
):
484+
# wrap the durable context around a durable handler as a convenience to run directly
485+
@durable_handler
486+
def handler(event: Any, context: DurableContext): # noqa: ARG001
487+
return context_function(*args, **kwargs)(context)
488+
489+
super().__init__(handler)

tests/checkpoint/processors/base_test.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ def create_context_details(self, update):
6666
"""Public method to access _create_context_details for testing."""
6767
return self._create_context_details(update)
6868

69-
def create_step_details(self, update):
69+
def create_step_details(self, update, current_operation):
7070
"""Public method to access _create_step_details for testing."""
71-
return self._create_step_details(update)
71+
return self._create_step_details(update, current_operation)
7272

7373
def create_callback_details(self, update):
7474
"""Public method to access _create_callback_details for testing."""
@@ -187,7 +187,11 @@ def test_create_step_details():
187187
error=error,
188188
)
189189

190-
result = processor.create_step_details(update)
190+
current_op = Mock()
191+
current_op.step_details = Mock()
192+
current_op.step_details.attempt = Mock()
193+
194+
result = processor.create_step_details(update, current_op)
191195

192196
assert isinstance(result, StepDetails)
193197
assert result.result == "test-payload"
@@ -203,11 +207,34 @@ def test_create_step_details_non_step_type():
203207
payload="test-payload",
204208
)
205209

206-
result = processor.create_step_details(update)
210+
current_op = Mock()
211+
current_op.step_details = Mock()
212+
current_op.step_details.attempt = Mock()
213+
214+
result = processor.create_step_details(update, current_op)
207215

208216
assert result is None
209217

210218

219+
def test_create_step_details_without_current_operation():
220+
processor = MockProcessor()
221+
error = ErrorObject.from_message("test error")
222+
update = OperationUpdate(
223+
operation_id="test-id",
224+
operation_type=OperationType.STEP,
225+
action=OperationAction.START,
226+
payload="test-payload",
227+
error=error,
228+
)
229+
230+
result = processor.create_step_details(update, None)
231+
232+
assert isinstance(result, StepDetails)
233+
assert result.result == "test-payload"
234+
assert result.error == error
235+
assert result.attempt == 0
236+
237+
211238
def test_create_callback_details():
212239
processor = MockProcessor()
213240
error = ErrorObject.from_message("test error")

tests/e2e/basic_success_path_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,16 @@ def function_under_test(event: Any, context: DurableContext) -> list[str]:
6868
result: DurableFunctionTestResult = runner.run(input="input str", timeout=10)
6969

7070
assert result.status is InvocationStatus.SUCCEEDED
71-
assert result.result == '["1 2", "3 4 4 3", "5 6"]'
71+
assert result.result == ["1 2", "3 4 4 3", "5 6"]
7272

7373
one_result: StepOperation = result.get_step("one")
74-
assert one_result.result == '"1 2"'
74+
assert one_result.result == "1 2"
7575

7676
two_result: ContextOperation = result.get_context("two")
77-
assert two_result.result == '"3 4 4 3"'
77+
assert two_result.result == "3 4 4 3"
7878

7979
three_result: StepOperation = result.get_step("three")
80-
assert three_result.result == '"5 6"'
80+
assert three_result.result == "5 6"
8181

8282
# currently has the optimization where it's not saving child checkpoints after parent done
8383
# prob should unpick that for test

0 commit comments

Comments
 (0)