From df7742380c79835e24ce85f517176f77467785ed Mon Sep 17 00:00:00 2001 From: yaythomas Date: Thu, 25 Sep 2025 16:44:12 -0700 Subject: [PATCH] fix: preserve step retry attempt & InvokeOptions update - Fix checkpoint processor to preserve attempt count and timestamp - Update InvokeOptions with latest svc signature - Add DurableChildContextTestRunner --- .../checkpoint/processors/base.py | 36 +++-- .../observer.py | 11 +- .../runner.py | 75 +++++++--- tests/checkpoint/processors/base_test.py | 43 ++++-- tests/e2e/basic_success_path_test.py | 8 +- tests/runner_test.py | 133 ++++++++++++++++-- 6 files changed, 241 insertions(+), 65 deletions(-) diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py index 1444b91..5749c94 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py @@ -70,13 +70,27 @@ def _create_context_details(self, update: OperationUpdate) -> ContextDetails | N else None ) - def _create_step_details(self, update: OperationUpdate) -> StepDetails | None: + def _create_step_details( + self, update: OperationUpdate, current_operation: Operation | None = None + ) -> StepDetails | None: """Create StepDetails from OperationUpdate.""" - return ( - StepDetails(result=update.payload, error=update.error) - if update.operation_type == OperationType.STEP - else None - ) + attempt: int = 0 + next_attempt_timestamp: str | None = None + + if update.operation_type is OperationType.STEP: + if current_operation and current_operation.step_details: + attempt = current_operation.step_details.attempt + next_attempt_timestamp = ( + current_operation.step_details.next_attempt_timestamp + ) + return StepDetails( + attempt=attempt, + next_attempt_timestamp=next_attempt_timestamp, + result=update.payload, + error=update.error, + ) + + return None def _create_callback_details( self, update: OperationUpdate @@ -93,12 +107,10 @@ def _create_callback_details( def _create_invoke_details(self, update: OperationUpdate) -> InvokeDetails | None: """Create InvokeDetails from OperationUpdate.""" if update.operation_type == OperationType.INVOKE and update.invoke_options: - qualifier = ( - update.invoke_options.function_qualifier - or update.invoke_options.function_name - ) + # Create a basic ARN using the function name + # In a real implementation, this would need more context about the execution # TODO: To confirm how or if this works - arn = f"arn:aws:lambda:us-west-2:123456789012:durable-execution:{update.invoke_options.function_name}:{update.invoke_options.durable_execution_name}:{qualifier}" + arn = f"arn:aws:lambda:us-west-2:123456789012:durable-execution:{update.invoke_options.function_name}:execution-name" return InvokeDetails( durable_execution_arn=arn, result=update.payload, error=update.error ) @@ -134,7 +146,7 @@ def _translate_update_to_operation( execution_details = self._create_execution_details(update) context_details = self._create_context_details(update) - step_details = self._create_step_details(update) + step_details = self._create_step_details(update, current_operation) callback_details = self._create_callback_details(update) invoke_details = self._create_invoke_details(update) wait_details = self._create_wait_details(update, current_operation) diff --git a/src/aws_durable_execution_sdk_python_testing/observer.py b/src/aws_durable_execution_sdk_python_testing/observer.py index e8c6dbc..be49416 100644 --- a/src/aws_durable_execution_sdk_python_testing/observer.py +++ b/src/aws_durable_execution_sdk_python_testing/observer.py @@ -1,10 +1,15 @@ """Checkpoint processors can notify the Execution of notable event state changes. Observer pattern.""" +from __future__ import annotations + import threading from abc import ABC, abstractmethod -from collections.abc import Callable +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable -from aws_durable_execution_sdk_python.lambda_service import ErrorObject + from aws_durable_execution_sdk_python.lambda_service import ErrorObject class ExecutionObserver(ABC): @@ -34,7 +39,7 @@ def on_step_retry_scheduled( class ExecutionNotifier: """Notifies observers about execution events. Thread-safe.""" - def __init__(self): + def __init__(self) -> None: self._observers: list[ExecutionObserver] = [] self._lock = threading.RLock() diff --git a/src/aws_durable_execution_sdk_python_testing/runner.py b/src/aws_durable_execution_sdk_python_testing/runner.py index 0648848..cc53ade 100644 --- a/src/aws_durable_execution_sdk_python_testing/runner.py +++ b/src/aws_durable_execution_sdk_python_testing/runner.py @@ -1,8 +1,13 @@ from __future__ import annotations +import json from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Protocol, TypeVar, cast +from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Protocol, TypeVar, cast +from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_handler, +) from aws_durable_execution_sdk_python.lambda_service import ( ErrorObject, OperationStatus, @@ -31,6 +36,7 @@ import datetime from collections.abc import Callable, MutableMapping + from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.execution import InvocationStatus from aws_durable_execution_sdk_python_testing.execution import Execution @@ -49,6 +55,7 @@ class Operation: T = TypeVar("T", bound=Operation) +P = ParamSpec("P") class OperationFactory(Protocol): @@ -90,7 +97,7 @@ def from_svc_operation( @dataclass(frozen=True) class ContextOperation(Operation): child_operations: list[Operation] - result: str | None = None + result: Any = None error: ErrorObject | None = None @staticmethod @@ -119,9 +126,11 @@ def from_svc_operation( start_timestamp=operation.start_timestamp, end_timestamp=operation.end_timestamp, child_operations=child_operations, - result=operation.context_details.result - if operation.context_details - else None, + result=( + json.loads(operation.context_details.result) + if operation.context_details and operation.context_details.result + else None + ), error=operation.context_details.error if operation.context_details else None, @@ -157,8 +166,7 @@ def get_execution(self, name: str) -> ExecutionOperation: class StepOperation(ContextOperation): attempt: int = 0 next_attempt_timestamp: str | None = None - # TODO: deserialize? - result: str | None = None + result: Any = None error: ErrorObject | None = None @staticmethod @@ -193,7 +201,11 @@ def from_svc_operation( if operation.step_details else None ), - result=operation.step_details.result if operation.step_details else None, + result=( + json.loads(operation.step_details.result) + if operation.step_details and operation.step_details.result + else None + ), error=operation.step_details.error if operation.step_details else None, ) @@ -230,7 +242,7 @@ def from_svc_operation( @dataclass(frozen=True) class CallbackOperation(ContextOperation): callback_id: str | None = None - result: str | None = None + result: Any = None error: ErrorObject | None = None @staticmethod @@ -264,9 +276,11 @@ def from_svc_operation( if operation.callback_details else None ), - result=operation.callback_details.result - if operation.callback_details - else None, + result=( + json.loads(operation.callback_details.result) + if operation.callback_details and operation.callback_details.result + else None + ), error=operation.callback_details.error if operation.callback_details else None, @@ -276,7 +290,7 @@ def from_svc_operation( @dataclass(frozen=True) class InvokeOperation(Operation): durable_execution_arn: str | None = None - result: str | None = None + result: Any = None error: ErrorObject | None = None @staticmethod @@ -301,9 +315,11 @@ def from_svc_operation( if operation.invoke_details else None ), - result=operation.invoke_details.result - if operation.invoke_details - else None, + result=( + json.loads(operation.invoke_details.result) + if operation.invoke_details and operation.invoke_details.result + else None + ), error=operation.invoke_details.error if operation.invoke_details else None, ) @@ -334,7 +350,7 @@ def create_operation( class DurableFunctionTestResult: status: InvocationStatus operations: list[Operation] - result: str | None = None + result: Any = None error: ErrorObject | None = None @classmethod @@ -352,10 +368,14 @@ def create(cls, execution: Execution) -> DurableFunctionTestResult: msg: str = "Execution result must exist to create test result." raise DurableFunctionsTestError(msg) + deserialized_result = ( + json.loads(execution.result.result) if execution.result.result else None + ) + return cls( status=execution.result.status, operations=operations, - result=execution.result.result, + result=deserialized_result, error=execution.result.error, ) @@ -413,7 +433,7 @@ def close(self): def run( self, - input: str, # noqa: A002 + input: str | None = None, # noqa: A002 timeout: int = 900, function_name: str = "test-function", execution_name: str = "execution-name", @@ -451,4 +471,19 @@ def run( execution: Execution = self._store.load(output.execution_arn) return DurableFunctionTestResult.create(execution=execution) - # return execution + +class DurableChildContextTestRunner(DurableFunctionTestRunner): + """Test a durable block, annotated with @durable_with_child_context, in isolation.""" + + def __init__( + self, + context_function: Callable[Concatenate[DurableContext, P], Any], + *args, + **kwargs, + ): + # wrap the durable context around a durable handler as a convenience to run directly + @durable_handler + def handler(event: Any, context: DurableContext): # noqa: ARG001 + return context_function(*args, **kwargs)(context) + + super().__init__(handler) diff --git a/tests/checkpoint/processors/base_test.py b/tests/checkpoint/processors/base_test.py index fa3ac0a..d2ff4e6 100644 --- a/tests/checkpoint/processors/base_test.py +++ b/tests/checkpoint/processors/base_test.py @@ -66,9 +66,9 @@ def create_context_details(self, update): """Public method to access _create_context_details for testing.""" return self._create_context_details(update) - def create_step_details(self, update): + def create_step_details(self, update, current_operation): """Public method to access _create_step_details for testing.""" - return self._create_step_details(update) + return self._create_step_details(update, current_operation) def create_callback_details(self, update): """Public method to access _create_callback_details for testing.""" @@ -187,7 +187,11 @@ def test_create_step_details(): error=error, ) - result = processor.create_step_details(update) + current_op = Mock() + current_op.step_details = Mock() + current_op.step_details.attempt = Mock() + + result = processor.create_step_details(update, current_op) assert isinstance(result, StepDetails) assert result.result == "test-payload" @@ -203,11 +207,34 @@ def test_create_step_details_non_step_type(): payload="test-payload", ) - result = processor.create_step_details(update) + current_op = Mock() + current_op.step_details = Mock() + current_op.step_details.attempt = Mock() + + result = processor.create_step_details(update, current_op) assert result is None +def test_create_step_details_without_current_operation(): + processor = MockProcessor() + error = ErrorObject.from_message("test error") + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + payload="test-payload", + error=error, + ) + + result = processor.create_step_details(update, None) + + assert isinstance(result, StepDetails) + assert result.result == "test-payload" + assert result.error == error + assert result.attempt == 0 + + def test_create_callback_details(): processor = MockProcessor() error = ErrorObject.from_message("test error") @@ -244,11 +271,7 @@ def test_create_callback_details_non_callback_type(): def test_create_invoke_details(): processor = MockProcessor() error = ErrorObject.from_message("test error") - invoke_options = InvokeOptions( - function_name="test-function", - function_qualifier="test-qualifier", - durable_execution_name="test-execution", - ) + invoke_options = InvokeOptions(function_name="test-function") update = OperationUpdate( operation_id="test-id", operation_type=OperationType.INVOKE, @@ -262,8 +285,6 @@ def test_create_invoke_details(): assert isinstance(result, InvokeDetails) assert "test-function" in result.durable_execution_arn - assert "test-execution" in result.durable_execution_arn - assert "test-qualifier" in result.durable_execution_arn assert result.result == "test-payload" assert result.error == error diff --git a/tests/e2e/basic_success_path_test.py b/tests/e2e/basic_success_path_test.py index faee614..d84686b 100644 --- a/tests/e2e/basic_success_path_test.py +++ b/tests/e2e/basic_success_path_test.py @@ -68,16 +68,16 @@ def function_under_test(event: Any, context: DurableContext) -> list[str]: result: DurableFunctionTestResult = runner.run(input="input str", timeout=10) assert result.status is InvocationStatus.SUCCEEDED - assert result.result == '["1 2", "3 4 4 3", "5 6"]' + assert result.result == ["1 2", "3 4 4 3", "5 6"] one_result: StepOperation = result.get_step("one") - assert one_result.result == '"1 2"' + assert one_result.result == "1 2" two_result: ContextOperation = result.get_context("two") - assert two_result.result == '"3 4 4 3"' + assert two_result.result == "3 4 4 3" three_result: StepOperation = result.get_step("three") - assert three_result.result == '"5 6"' + assert three_result.result == "5 6" # currently has the optimization where it's not saving child checkpoints after parent done # prob should unpick that for test diff --git a/tests/runner_test.py b/tests/runner_test.py index 9fdcef4..723ffa4 100644 --- a/tests/runner_test.py +++ b/tests/runner_test.py @@ -1,6 +1,7 @@ """Unit tests for runner module.""" import datetime +import json from unittest.mock import Mock, patch import pytest @@ -29,6 +30,7 @@ OPERATION_FACTORIES, CallbackOperation, ContextOperation, + DurableChildContextTestRunner, DurableFunctionTestResult, DurableFunctionTestRunner, ExecutionOperation, @@ -94,7 +96,7 @@ def test_execution_operation_wrong_type(): def test_context_operation_from_svc_operation(): """Test ContextOperation creation from service operation.""" - context_details = ContextDetails(result="test-result", error=None) + context_details = ContextDetails(result=json.dumps("test-result"), error=None) svc_op = SvcOperation( operation_id="ctx-id", operation_type=OperationType.CONTEXT, @@ -116,7 +118,7 @@ def test_context_operation_with_children(): operation_id="parent-id", operation_type=OperationType.CONTEXT, status=OperationStatus.SUCCEEDED, - context_details=ContextDetails(result="parent-result"), + context_details=ContextDetails(result=json.dumps("parent-result")), ) child_op = SvcOperation( @@ -125,7 +127,7 @@ def test_context_operation_with_children(): status=OperationStatus.SUCCEEDED, parent_id="parent-id", name="child-step", - step_details=StepDetails(result="child-result"), + step_details=StepDetails(result=json.dumps("child-result")), ) all_ops = [parent_op, child_op] @@ -301,7 +303,7 @@ def test_context_operation_get_execution(): def test_step_operation_from_svc_operation(): """Test StepOperation creation from service operation.""" - step_details = StepDetails(attempt=2, result="step-result", error=None) + step_details = StepDetails(attempt=2, result=json.dumps("step-result"), error=None) svc_op = SvcOperation( operation_id="step-id", operation_type=OperationType.STEP, @@ -365,7 +367,9 @@ def test_wait_operation_wrong_type(): def test_callback_operation_from_svc_operation(): """Test CallbackOperation creation from service operation.""" - callback_details = CallbackDetails(callback_id="cb-123", result="callback-result") + callback_details = CallbackDetails( + callback_id="cb-123", result=json.dumps("callback-result") + ) svc_op = SvcOperation( operation_id="callback-id", operation_type=OperationType.CALLBACK, @@ -399,7 +403,7 @@ def test_invoke_operation_from_svc_operation(): """Test InvokeOperation creation from service operation.""" invoke_details = InvokeDetails( durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test", - result="invoke-result", + result=json.dumps("invoke-result"), ) svc_op = SvcOperation( operation_id="invoke-id", @@ -453,7 +457,7 @@ def test_create_operation_step(): operation_id="step-id", operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, - step_details=StepDetails(result="test-result"), + step_details=StepDetails(result=json.dumps("test-result")), ) operation = create_operation(svc_op) @@ -490,14 +494,14 @@ def test_durable_function_test_result_create(): step_op.operation_id = "step-id" step_op.status = OperationStatus.SUCCEEDED step_op.name = "test-step" - step_op.step_details = StepDetails(result="step-result") + step_op.step_details = StepDetails(result=json.dumps("step-result")) execution.operations = [exec_op, step_op] # Mock execution result execution.result = Mock() execution.result.status = InvocationStatus.SUCCEEDED - execution.result.result = "test-result" + execution.result.result = json.dumps("test-result") execution.result.error = None result = DurableFunctionTestResult.create(execution) @@ -732,7 +736,7 @@ def test_durable_function_test_runner_run(): mock_execution.operations = [] mock_execution.result = Mock() mock_execution.result.status = InvocationStatus.SUCCEEDED - mock_execution.result.result = "test-result" + mock_execution.result.result = json.dumps("test-result") mock_execution.result.error = None mock_store.load.return_value = mock_execution @@ -781,7 +785,7 @@ def test_durable_function_test_runner_run_with_custom_params(): mock_execution.operations = [] mock_execution.result = Mock() mock_execution.result.status = InvocationStatus.SUCCEEDED - mock_execution.result.result = "test-result" + mock_execution.result.result = json.dumps("test-result") mock_execution.result.error = None mock_store.load.return_value = mock_execution @@ -854,7 +858,7 @@ def test_context_operation_with_child_operations_none(): operation_id="ctx-id", operation_type=OperationType.CONTEXT, status=OperationStatus.SUCCEEDED, - context_details=ContextDetails(result="test-result"), + context_details=ContextDetails(result=json.dumps("test-result")), ) ctx_op = ContextOperation.from_svc_operation(svc_op, None) @@ -882,7 +886,7 @@ def test_step_operation_with_child_operations_none(): operation_id="step-id", operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, - step_details=StepDetails(result="step-result"), + step_details=StepDetails(result=json.dumps("step-result")), ) step_op = StepOperation.from_svc_operation(svc_op, None) @@ -906,14 +910,113 @@ def test_durable_function_test_result_create_with_parent_operations(): root_op.operation_id = "root-id" root_op.status = OperationStatus.SUCCEEDED root_op.name = "root-step" - root_op.step_details = StepDetails(result="root-result") + root_op.step_details = StepDetails(result=json.dumps("root-result")) execution.operations = [child_op, root_op] execution.result = Mock() execution.result.status = InvocationStatus.SUCCEEDED - execution.result.result = "test-result" + execution.result.result = json.dumps("test-result") execution.result.error = None result = DurableFunctionTestResult.create(execution) assert len(result.operations) == 1 # Only root operation included + + +@patch("aws_durable_execution_sdk_python_testing.runner.Scheduler") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore") +@patch("aws_durable_execution_sdk_python_testing.runner.CheckpointProcessor") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryServiceClient") +@patch("aws_durable_execution_sdk_python_testing.runner.InProcessInvoker") +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +@patch("aws_durable_execution_sdk_python_testing.runner.durable_handler") +def test_durable_context_test_runner_init( + mock_durable_handler, + mock_executor, + mock_invoker, + mock_client, + mock_processor, + mock_store, + mock_scheduler, +): + """Test DurableContextTestRunner initialization.""" + handler = Mock() + decorated_handler = Mock() + mock_durable_handler.return_value = decorated_handler + + DurableChildContextTestRunner(handler) # type: ignore + + # Verify all components are initialized + mock_scheduler.assert_called_once() + mock_scheduler.return_value.start.assert_called_once() + mock_store.assert_called_once() + mock_processor.assert_called_once() + mock_client.assert_called_once() + mock_invoker.assert_called_once_with(decorated_handler, mock_client.return_value) + mock_executor.assert_called_once() + + # Verify observer pattern setup + mock_processor.return_value.add_execution_observer.assert_called_once_with( + mock_executor.return_value + ) + + # Verify durable_handler was called (with internal lambda function) + mock_durable_handler.assert_called_once() + + # Verify the lambda function calls our handler + durable_handler_func = mock_durable_handler.call_args.args[0] + assert callable(durable_handler_func) + + # verify handler is called when durable function is invoked + durable_handler_func(Mock(), Mock()) + handler.assert_called_once() + + +@patch("aws_durable_execution_sdk_python_testing.runner.Scheduler") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore") +@patch("aws_durable_execution_sdk_python_testing.runner.CheckpointProcessor") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryServiceClient") +@patch("aws_durable_execution_sdk_python_testing.runner.InProcessInvoker") +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +@patch("aws_durable_execution_sdk_python_testing.runner.durable_handler") +def test_durable_child_context_test_runner_init_with_args( + mock_durable_handler, + mock_executor, + mock_invoker, + mock_client, + mock_processor, + mock_store, + mock_scheduler, +): + """Test DurableChildContextTestRunner initialization with additional args.""" + handler = Mock() + decorated_handler = Mock() + mock_durable_handler.return_value = decorated_handler + + str_input = "a random string input" + num_input = 10 + DurableChildContextTestRunner(handler, str_input, num=num_input) # type: ignore + + # Verify all components are initialized + mock_scheduler.assert_called_once() + mock_scheduler.return_value.start.assert_called_once() + mock_store.assert_called_once() + mock_processor.assert_called_once() + mock_client.assert_called_once() + mock_invoker.assert_called_once_with(decorated_handler, mock_client.return_value) + mock_executor.assert_called_once() + + # Verify observer pattern setup + mock_processor.return_value.add_execution_observer.assert_called_once_with( + mock_executor.return_value + ) + + # Verify durable_handler was called (with internal lambda function) + mock_durable_handler.assert_called_once() + # Verify the lambda function calls our handler + durable_handler_func = mock_durable_handler.call_args.args[0] + assert callable(durable_handler_func) + + # verify that handler is called with expected args when durable function is invoked + durable_handler_func(Mock(), Mock()) + handler.assert_called_once_with(str_input, num=num_input)