diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index 8728965..8efaed0 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -24,17 +24,17 @@ from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.logger import Logger, LogInfo from aws_durable_execution_sdk_python.operation.callback import ( - create_callback_handler, + CallbackOperationExecutor, wait_for_callback_handler, ) from aws_durable_execution_sdk_python.operation.child import child_handler -from aws_durable_execution_sdk_python.operation.invoke import invoke_handler +from aws_durable_execution_sdk_python.operation.invoke import InvokeOperationExecutor from aws_durable_execution_sdk_python.operation.map import map_handler from aws_durable_execution_sdk_python.operation.parallel import parallel_handler -from aws_durable_execution_sdk_python.operation.step import step_handler -from aws_durable_execution_sdk_python.operation.wait import wait_handler +from aws_durable_execution_sdk_python.operation.step import StepOperationExecutor +from aws_durable_execution_sdk_python.operation.wait import WaitOperationExecutor from aws_durable_execution_sdk_python.operation.wait_for_condition import ( - wait_for_condition_handler, + WaitForConditionOperationExecutor, ) from aws_durable_execution_sdk_python.serdes import ( PassThroughSerDes, @@ -323,13 +323,14 @@ def create_callback( if not config: config = CallbackConfig() operation_id: str = self._create_step_id() - callback_id: str = create_callback_handler( + executor: CallbackOperationExecutor = CallbackOperationExecutor( state=self.state, operation_identifier=OperationIdentifier( operation_id=operation_id, parent_id=self._parent_id, name=name ), config=config, ) + callback_id: str = executor.process() result: Callback = Callback( callback_id=callback_id, operation_id=operation_id, @@ 
-357,8 +358,10 @@ def invoke( Returns: The result of the invoked function """ + if not config: + config = InvokeConfig[P, R]() operation_id = self._create_step_id() - result: R = invoke_handler( + executor: InvokeOperationExecutor[R] = InvokeOperationExecutor( function_name=function_name, payload=payload, state=self.state, @@ -369,6 +372,7 @@ def invoke( ), config=config, ) + result: R = executor.process() self.state.track_replay(operation_id=operation_id) return result @@ -505,8 +509,10 @@ def step( ) -> T: step_name = self._resolve_step_name(name, func) logger.debug("Step name: %s", step_name) + if not config: + config = StepConfig() operation_id = self._create_step_id() - result: T = step_handler( + executor: StepOperationExecutor[T] = StepOperationExecutor( func=func, config=config, state=self.state, @@ -517,6 +523,7 @@ def step( ), context_logger=self.logger, ) + result: T = executor.process() self.state.track_replay(operation_id=operation_id) return result @@ -532,8 +539,9 @@ def wait(self, duration: Duration, name: str | None = None) -> None: msg = "duration must be at least 1 second" raise ValidationError(msg) operation_id = self._create_step_id() - wait_handler( - seconds=seconds, + wait_seconds = duration.seconds + executor: WaitOperationExecutor = WaitOperationExecutor( + seconds=wait_seconds, state=self.state, operation_identifier=OperationIdentifier( operation_id=operation_id, @@ -541,6 +549,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None: name=name, ), ) + executor.process() self.state.track_replay(operation_id=operation_id) def wait_for_callback( @@ -584,17 +593,20 @@ def wait_for_condition( raise ValidationError(msg) operation_id = self._create_step_id() - result: T = wait_for_condition_handler( - check=check, - config=config, - state=self.state, - operation_identifier=OperationIdentifier( - operation_id=operation_id, - parent_id=self._parent_id, - name=name, - ), - context_logger=self.logger, + executor: 
WaitForConditionOperationExecutor[T] = ( + WaitForConditionOperationExecutor( + check=check, + config=config, + state=self.state, + operation_identifier=OperationIdentifier( + operation_id=operation_id, + parent_id=self._parent_id, + name=name, + ), + context_logger=self.logger, + ) ) + result: T = executor.process() self.state.track_replay(operation_id=operation_id) return result diff --git a/src/aws_durable_execution_sdk_python/operation/base.py b/src/aws_durable_execution_sdk_python/operation/base.py new file mode 100644 index 0000000..5836cda --- /dev/null +++ b/src/aws_durable_execution_sdk_python/operation/base.py @@ -0,0 +1,187 @@ +"""Base classes for operation executors with checkpoint response handling.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generic, TypeVar + +from aws_durable_execution_sdk_python.exceptions import InvalidStateError + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python.state import CheckpointedResult + +T = TypeVar("T") + + +@dataclass(frozen=True) +class CheckResult(Generic[T]): + """Result of checking operation checkpoint status. + + Encapsulates the outcome of checking an operation's status and determines + the next action in the operation execution flow. + + IMPORTANT: Do not construct directly. 
Use factory methods: + - create_is_ready_to_execute(checkpoint) - operation ready to execute + - create_started() - checkpoint created, check status again + - create_completed(result) - terminal result available + + Attributes: + is_ready_to_execute: True if the operation is ready to execute its logic + has_checkpointed_result: True if a terminal result is already available + checkpointed_result: Checkpoint data for execute() + deserialized_result: Final result when operation completed + """ + + is_ready_to_execute: bool + has_checkpointed_result: bool + checkpointed_result: CheckpointedResult | None = None + deserialized_result: T | None = None + + @classmethod + def create_is_ready_to_execute( + cls, checkpoint: CheckpointedResult + ) -> CheckResult[T]: + """Create a CheckResult indicating the operation is ready to execute. + + Args: + checkpoint: The checkpoint data to pass to execute() + + Returns: + CheckResult with is_ready_to_execute=True + """ + return cls( + is_ready_to_execute=True, + has_checkpointed_result=False, + checkpointed_result=checkpoint, + ) + + @classmethod + def create_started(cls) -> CheckResult[T]: + """Create a CheckResult signaling that a checkpoint was created. + + Signals that process() should verify checkpoint status again to detect + if the operation completed already during checkpoint creation. + + Returns: + CheckResult indicating process() should check status again + """ + return cls(is_ready_to_execute=False, has_checkpointed_result=False) + + @classmethod + def create_completed(cls, result: T) -> CheckResult[T]: + """Create a CheckResult with a terminal result already deserialized. 
+ + Args: + result: The final deserialized result + + Returns: + CheckResult with has_checkpointed_result=True and deserialized_result set + """ + return cls( + is_ready_to_execute=False, + has_checkpointed_result=True, + deserialized_result=result, + ) + + +class OperationExecutor(ABC, Generic[T]): + """Base class for durable operations with checkpoint response handling. + + Provides a framework for implementing operations that check status after + creating START checkpoints to handle synchronous completion, avoiding + unnecessary execution or suspension. + + The common pattern: + 1. Check operation status + 2. Create START checkpoint if needed + 3. Check status again (detects synchronous completion) + 4. Execute operation logic when ready + + Subclasses must implement: + - check_result_status(): Check status, create checkpoint if needed, return next action + - execute(): Execute the operation logic with checkpoint data + """ + + @abstractmethod + def check_result_status(self) -> CheckResult[T]: + """Check operation status and create START checkpoint if needed. + + Called twice by process() when creating synchronous checkpoints: once before + and once after, to detect if the operation completed immediately. + + This method should: + 1. Get the current checkpoint result + 2. Check for terminal statuses (SUCCEEDED, FAILED, etc.) and handle them + 3. Check for pending statuses and suspend if needed + 4. Create a START checkpoint if the operation hasn't started + 5. Return a CheckResult indicating the next action + + Returns: + CheckResult indicating whether to: + - Return a terminal result (has_checkpointed_result=True) + - Execute operation logic (is_ready_to_execute=True) + - Check status again (neither flag set - checkpoint was just created) + + Raises: + Operation-specific exceptions for terminal failure states + SuspendExecution for pending states + """ + ... 
# pragma: no cover + + @abstractmethod + def execute(self, checkpointed_result: CheckpointedResult) -> T: + """Execute operation logic with checkpoint data. + + This method is called when the operation is ready to execute its core logic. + It receives the checkpoint data that was returned by check_result_status(). + + Args: + checkpointed_result: The checkpoint data containing operation state + + Returns: + The result of executing the operation + + Raises: + May raise operation-specific errors during execution + """ + ... # pragma: no cover + + def process(self) -> T: + """Process operation with checkpoint response handling. + + Orchestrates the double-check pattern: + 1. Check status (handles replay and existing checkpoints) + 2. If checkpoint was just created, check status again (detects synchronous completion) + 3. Return terminal result if available + 4. Execute operation logic if ready + 5. Raise error for invalid states + + Returns: + The final result of the operation + + Raises: + InvalidStateError: If the check result is in an invalid state + May raise operation-specific errors from check_result_status() or execute() + """ + # Check 1: Entry (handles replay and existing checkpoints) + result = self.check_result_status() + + # If checkpoint was created, verify checkpoint response for immediate status change + if not result.is_ready_to_execute and not result.has_checkpointed_result: + result = self.check_result_status() + + # Return terminal result if available (can be None for operations that return None) + if result.has_checkpointed_result: + return result.deserialized_result # type: ignore[return-value] + + # Execute operation logic + if result.is_ready_to_execute: + if result.checkpointed_result is None: + msg = "CheckResult is marked ready to execute but checkpointed result is not set." 
+ raise InvalidStateError(msg) + return self.execute(result.checkpointed_result) + + # Invalid state - neither terminal nor ready to execute + msg = "Invalid CheckResult state: neither terminal nor ready to execute" + raise InvalidStateError(msg) diff --git a/src/aws_durable_execution_sdk_python/operation/callback.py b/src/aws_durable_execution_sdk_python/operation/callback.py index e7bc064..67c51eb 100644 --- a/src/aws_durable_execution_sdk_python/operation/callback.py +++ b/src/aws_durable_execution_sdk_python/operation/callback.py @@ -10,6 +10,10 @@ CallbackOptions, OperationUpdate, ) +from aws_durable_execution_sdk_python.operation.base import ( + CheckResult, + OperationExecutor, +) from aws_durable_execution_sdk_python.types import WaitForCallbackContext if TYPE_CHECKING: @@ -31,62 +35,117 @@ ) -def create_callback_handler( - state: ExecutionState, - operation_identifier: OperationIdentifier, - config: CallbackConfig | None = None, -) -> str: - """Create the callback checkpoint and return the callback id.""" - callback_options: CallbackOptions = ( - CallbackOptions( - timeout_seconds=config.timeout_seconds, - heartbeat_timeout_seconds=config.heartbeat_timeout_seconds, - ) - if config - else CallbackOptions() - ) +class CallbackOperationExecutor(OperationExecutor[str]): + """Executor for callback operations. - checkpointed_result: CheckpointedResult = state.get_checkpoint_result( - operation_identifier.operation_id - ) - if checkpointed_result.is_failed(): - # have to throw the exact same error on replay as the checkpointed failure - checkpointed_result.raise_callable_error() - - if ( - checkpointed_result.is_started() - or checkpointed_result.is_succeeded() - or checkpointed_result.is_timed_out() + Checks operation status after creating START checkpoints to handle operations + that complete synchronously, avoiding unnecessary execution or suspension. 
+ + Unlike other operations, callbacks NEVER execute logic - they only create + checkpoints and return callback IDs. + + CRITICAL: Errors are deferred to Callback.result() for deterministic replay. + create_callback() always returns the callback_id, even for FAILED callbacks. + """ + + def __init__( + self, + state: ExecutionState, + operation_identifier: OperationIdentifier, + config: CallbackConfig | None, ): - # callback id should already exist + """Initialize the callback operation executor. + + Args: + state: The execution state + operation_identifier: The operation identifier + config: The callback configuration (optional) + """ + self.state = state + self.operation_identifier = operation_identifier + self.config = config + + def check_result_status(self) -> CheckResult[str]: + """Check operation status and create START checkpoint if needed. + + Called twice by process() when creating synchronous checkpoints: once before + and once after, to detect if the operation completed immediately. + + CRITICAL: This method does NOT raise on FAILED status. Errors are deferred + to Callback.result() to ensure deterministic replay. Code between + create_callback() and callback.result() must always execute. 
+ + Returns: + CheckResult.create_is_ready_to_execute() for any existing status (including FAILED) + or CheckResult.create_started() after creating checkpoint + + Raises: + CallbackError: If callback_details are missing from checkpoint + """ + checkpointed_result: CheckpointedResult = self.state.get_checkpoint_result( + self.operation_identifier.operation_id + ) + + # CRITICAL: Do NOT raise on FAILED - defer error to Callback.result() + # If checkpoint exists (any status including FAILED), return ready to execute + # The execute() method will extract the callback_id + if checkpointed_result.is_existent(): + if ( + not checkpointed_result.operation + or not checkpointed_result.operation.callback_details + ): + msg = f"Missing callback details for operation: {self.operation_identifier.operation_id}" + raise CallbackError(msg) + + return CheckResult.create_is_ready_to_execute(checkpointed_result) + + # Create START checkpoint + callback_options: CallbackOptions = ( + CallbackOptions( + timeout_seconds=self.config.timeout_seconds, + heartbeat_timeout_seconds=self.config.heartbeat_timeout_seconds, + ) + if self.config + else CallbackOptions() + ) + + create_callback_operation: OperationUpdate = OperationUpdate.create_callback( + identifier=self.operation_identifier, + callback_options=callback_options, + ) + + # Checkpoint callback START with blocking (is_sync=True, default). + # Must wait for the API to generate and return the callback ID before proceeding. + # The callback ID is needed immediately by the caller to pass to external systems. + self.state.create_checkpoint(operation_update=create_callback_operation) + + # Signal to process() to check status again for immediate response + return CheckResult.create_started() + + def execute(self, checkpointed_result: CheckpointedResult) -> str: + """Execute callback operation by extracting the callback_id. + + Callbacks don't execute logic - they just extract and return the callback_id + from the checkpoint data. 
+ + Args: + checkpointed_result: The checkpoint data containing callback_details + + Returns: + The callback_id from the checkpoint + + Raises: + CallbackError: If callback_details are missing (should never happen) + """ if ( not checkpointed_result.operation or not checkpointed_result.operation.callback_details ): - msg = f"Missing callback details for operation: {operation_identifier.operation_id}" + msg = f"Missing callback details for operation: {self.operation_identifier.operation_id}" raise CallbackError(msg) return checkpointed_result.operation.callback_details.callback_id - create_callback_operation = OperationUpdate.create_callback( - identifier=operation_identifier, - callback_options=callback_options, - ) - # Checkpoint callback START with blocking (is_sync=True, default). - # Must wait for the API to generate and return the callback ID before proceeding. - # The callback ID is needed immediately by the caller to pass to external systems. - state.create_checkpoint(operation_update=create_callback_operation) - - result: CheckpointedResult = state.get_checkpoint_result( - operation_identifier.operation_id - ) - - if not result.operation or not result.operation.callback_details: - msg = f"Missing callback details for operation: {operation_identifier.operation_id}" - raise CallbackError(msg) - - return result.operation.callback_details.callback_id - def wait_for_callback_handler( context: DurableContext, diff --git a/src/aws_durable_execution_sdk_python/operation/child.py b/src/aws_durable_execution_sdk_python/operation/child.py index 07d0a08..04819d4 100644 --- a/src/aws_durable_execution_sdk_python/operation/child.py +++ b/src/aws_durable_execution_sdk_python/operation/child.py @@ -16,13 +16,20 @@ OperationSubType, OperationUpdate, ) +from aws_durable_execution_sdk_python.operation.base import ( + CheckResult, + OperationExecutor, +) from aws_durable_execution_sdk_python.serdes import deserialize, serialize if TYPE_CHECKING: from collections.abc import 
Callable from aws_durable_execution_sdk_python.identifier import OperationIdentifier - from aws_durable_execution_sdk_python.state import ExecutionState + from aws_durable_execution_sdk_python.state import ( + CheckpointedResult, + ExecutionState, + ) logger = logging.getLogger(__name__) @@ -32,131 +39,239 @@ CHECKPOINT_SIZE_LIMIT = 256 * 1024 -def child_handler( - func: Callable[[], T], - state: ExecutionState, - operation_identifier: OperationIdentifier, - config: ChildConfig | None, -) -> T: - logger.debug( - "▶️ Executing child context for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) +class ChildOperationExecutor(OperationExecutor[T]): + """Executor for child context operations. - if not config: - config = ChildConfig() + Checks operation status after creating START checkpoints to handle operations + that complete synchronously, avoiding unnecessary execution or suspension. + + Handles large payload scenarios with ReplayChildren mode. + """ - checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id) - if ( - checkpointed_result.is_succeeded() - and not checkpointed_result.is_replay_children() + def __init__( + self, + func: Callable[[], T], + state: ExecutionState, + operation_identifier: OperationIdentifier, + config: ChildConfig, ): - logger.debug( - "Child context already completed, skipping execution for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) - if checkpointed_result.result is None: - return None # type: ignore - return deserialize( - serdes=config.serdes, - data=checkpointed_result.result, - operation_id=operation_identifier.operation_id, - durable_execution_arn=state.durable_execution_arn, - ) - if checkpointed_result.is_failed(): - checkpointed_result.raise_callable_error() - sub_type = config.sub_type or OperationSubType.RUN_IN_CHILD_CONTEXT - - if not checkpointed_result.is_existent(): - start_operation = 
OperationUpdate.create_context_start( - identifier=operation_identifier, - sub_type=sub_type, + """Initialize the child operation executor. + + Args: + func: The child context function to execute + state: The execution state + operation_identifier: The operation identifier + config: The child configuration + """ + self.func = func + self.state = state + self.operation_identifier = operation_identifier + self.config = config + self.sub_type = config.sub_type or OperationSubType.RUN_IN_CHILD_CONTEXT + + def check_result_status(self) -> CheckResult[T]: + """Check operation status and create START checkpoint if needed. + + Called twice by process() when creating synchronous checkpoints: once before + and once after, to detect if the operation completed immediately. + + Returns: + CheckResult indicating the next action to take + + Raises: + CallableRuntimeError: For FAILED operations + """ + checkpointed_result: CheckpointedResult = self.state.get_checkpoint_result( + self.operation_identifier.operation_id ) - # Checkpoint child context START with non-blocking (is_sync=False). - # This is a fire-and-forget operation for performance - we don't need to wait for - # persistence before executing the child context. The START checkpoint is purely - # for observability and tracking the operation hierarchy. - state.create_checkpoint(operation_update=start_operation, is_sync=False) - - try: - raw_result: T = func() - if checkpointed_result.is_replay_children(): + + # Terminal success without replay_children - deserialize and return + if ( + checkpointed_result.is_succeeded() + and not checkpointed_result.is_replay_children() + ): logger.debug( - "ReplayChildren mode: Executed child context again on replay due to large payload. Exiting child context without creating another checkpoint. 
id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, + "Child context already completed, skipping execution for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, ) - return raw_result - serialized_result: str = serialize( - serdes=config.serdes, - value=raw_result, - operation_id=operation_identifier.operation_id, - durable_execution_arn=state.durable_execution_arn, - ) - # Summary Generator Logic: - # When the serialized result exceeds 256KB, we use ReplayChildren mode to avoid - # checkpointing large payloads. Instead, we checkpoint a compact summary and mark - # the operation for replay. This matches the TypeScript implementation behavior. - # - # See TypeScript reference: - # - aws-durable-execution-sdk-js/src/handlers/run-in-child-context-handler/run-in-child-context-handler.ts (lines ~200-220) - # - # The summary generator creates a JSON summary with metadata (type, counts, status) - # instead of the full BatchResult. During replay, the child context is re-executed - # to reconstruct the full result rather than deserializing from the checkpoint. 
- replay_children: bool = False - if len(serialized_result) > CHECKPOINT_SIZE_LIMIT: - logger.debug( - "Large payload detected, using ReplayChildren mode: id: %s, name: %s, payload_size: %d, limit: %d", - operation_identifier.operation_id, - operation_identifier.name, - len(serialized_result), - CHECKPOINT_SIZE_LIMIT, + if checkpointed_result.result is None: + return CheckResult.create_completed(None) # type: ignore + + result: T = deserialize( + serdes=self.config.serdes, + data=checkpointed_result.result, + operation_id=self.operation_identifier.operation_id, + durable_execution_arn=self.state.durable_execution_arn, ) - replay_children = True - # Use summary generator if provided, otherwise use empty string (matches TypeScript) - serialized_result = ( - config.summary_generator(raw_result) if config.summary_generator else "" + return CheckResult.create_completed(result) + + # Terminal success with replay_children - re-execute + if ( + checkpointed_result.is_succeeded() + and checkpointed_result.is_replay_children() + ): + return CheckResult.create_is_ready_to_execute(checkpointed_result) + + # Terminal failure + if checkpointed_result.is_failed(): + checkpointed_result.raise_callable_error() + + # Create START checkpoint if not exists + if not checkpointed_result.is_existent(): + start_operation: OperationUpdate = OperationUpdate.create_context_start( + identifier=self.operation_identifier, + sub_type=self.sub_type, + ) + # Checkpoint child context START with non-blocking (is_sync=False). + # This is a fire-and-forget operation for performance - we don't need to wait for + # persistence before executing the child context. The START checkpoint is purely + # for observability and tracking the operation hierarchy. 
+ self.state.create_checkpoint( + operation_update=start_operation, is_sync=False ) - success_operation = OperationUpdate.create_context_succeed( - identifier=operation_identifier, - payload=serialized_result, - sub_type=sub_type, - context_options=ContextOptions(replay_children=replay_children), - ) - # Checkpoint child context SUCCEED with blocking (is_sync=True, default). - # Must ensure the child context result is persisted before returning to the parent. - # This guarantees the result is durable and child operations won't be re-executed on replay - # (unless replay_children=True for large payloads). - state.create_checkpoint(operation_update=success_operation) + # Ready to execute (checkpoint exists or was just created) + return CheckResult.create_is_ready_to_execute(checkpointed_result) + + def execute(self, checkpointed_result: CheckpointedResult) -> T: + """Execute child context function with error handling and large payload support. + Args: + checkpointed_result: The checkpoint data containing operation state + + Returns: + The result of executing the child context function + + Raises: + SuspendExecution: Re-raised without checkpointing + InvocationError: Re-raised after checkpointing FAIL + CallableRuntimeError: Raised for other exceptions after checkpointing FAIL + """ logger.debug( - "✅ Successfully completed child context for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, + "▶️ Executing child context for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, ) - return raw_result # noqa: TRY300 - except SuspendExecution: - # Don't checkpoint SuspendExecution - let it bubble up - raise - except Exception as e: - error_object = ErrorObject.from_exception(e) - fail_operation = OperationUpdate.create_context_fail( - identifier=operation_identifier, error=error_object, sub_type=sub_type - ) - # Checkpoint child context FAIL with blocking (is_sync=True, default). 
- # Must ensure the failure state is persisted before raising the exception. - # This guarantees the error is durable and child operations won't be re-executed on replay. - state.create_checkpoint(operation_update=fail_operation) - - # InvocationError and its derivatives can be retried - # When we encounter an invocation error (in all of its forms), we bubble that - # error upwards (with the checkpoint in place) such that we reach the - # execution handler at the very top, which will then induce a retry from the - # dataplane. - if isinstance(e, InvocationError): + + try: + raw_result: T = self.func() + + # If in replay_children mode, return without checkpointing + if checkpointed_result.is_replay_children(): + logger.debug( + "ReplayChildren mode: Executed child context again on replay due to large payload. Exiting child context without creating another checkpoint. id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) + return raw_result + + # Serialize result + serialized_result: str = serialize( + serdes=self.config.serdes, + value=raw_result, + operation_id=self.operation_identifier.operation_id, + durable_execution_arn=self.state.durable_execution_arn, + ) + + # Check payload size and use ReplayChildren mode if needed + # Summary Generator Logic: + # When the serialized result exceeds 256KB, we use ReplayChildren mode to avoid + # checkpointing large payloads. Instead, we checkpoint a compact summary and mark + # the operation for replay. This matches the TypeScript implementation behavior. + # + # See TypeScript reference: + # - aws-durable-execution-sdk-js/src/handlers/run-in-child-context-handler/run-in-child-context-handler.ts (lines ~200-220) + # + # The summary generator creates a JSON summary with metadata (type, counts, status) + # instead of the full BatchResult. During replay, the child context is re-executed + # to reconstruct the full result rather than deserializing from the checkpoint. 
+ replay_children: bool = False + if len(serialized_result) > CHECKPOINT_SIZE_LIMIT: + logger.debug( + "Large payload detected, using ReplayChildren mode: id: %s, name: %s, payload_size: %d, limit: %d", + self.operation_identifier.operation_id, + self.operation_identifier.name, + len(serialized_result), + CHECKPOINT_SIZE_LIMIT, + ) + replay_children = True + # Use summary generator if provided, otherwise use empty string (matches TypeScript) + serialized_result = ( + self.config.summary_generator(raw_result) + if self.config.summary_generator + else "" + ) + + # Checkpoint SUCCEED + success_operation: OperationUpdate = OperationUpdate.create_context_succeed( + identifier=self.operation_identifier, + payload=serialized_result, + sub_type=self.sub_type, + context_options=ContextOptions(replay_children=replay_children), + ) + # Checkpoint child context SUCCEED with blocking (is_sync=True, default). + # Must ensure the child context result is persisted before returning to the parent. + # This guarantees the result is durable and child operations won't be re-executed on replay + # (unless replay_children=True for large payloads). + self.state.create_checkpoint(operation_update=success_operation) + + logger.debug( + "✅ Successfully completed child context for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) + return raw_result # noqa: TRY300 + except SuspendExecution: + # Don't checkpoint SuspendExecution - let it bubble up raise - raise error_object.to_callable_runtime_error() from e + except Exception as e: + error_object = ErrorObject.from_exception(e) + fail_operation: OperationUpdate = OperationUpdate.create_context_fail( + identifier=self.operation_identifier, + error=error_object, + sub_type=self.sub_type, + ) + # Checkpoint child context FAIL with blocking (is_sync=True, default). + # Must ensure the failure state is persisted before raising the exception. 
+ # This guarantees the error is durable and child operations won't be re-executed on replay. + self.state.create_checkpoint(operation_update=fail_operation) + + # InvocationError and its derivatives can be retried + # When we encounter an invocation error (in all of its forms), we bubble that + # error upwards (with the checkpoint in place) such that we reach the + # execution handler at the very top, which will then induce a retry from the + # dataplane. + if isinstance(e, InvocationError): + raise + raise error_object.to_callable_runtime_error() from e + + +def child_handler( + func: Callable[[], T], + state: ExecutionState, + operation_identifier: OperationIdentifier, + config: ChildConfig | None, +) -> T: + """Public API for child context operations - maintains existing signature. + + This function creates a ChildOperationExecutor and delegates to its process() method, + maintaining backward compatibility with existing code that calls child_handler. + + Args: + func: The child context function to execute + state: The execution state + operation_identifier: The operation identifier + config: The child configuration (optional) + + Returns: + The result of executing the child context + + Raises: + May raise operation-specific errors during execution + """ + if not config: + config = ChildConfig() + + executor = ChildOperationExecutor(func, state, operation_identifier, config) + return executor.process() diff --git a/src/aws_durable_execution_sdk_python/operation/invoke.py b/src/aws_durable_execution_sdk_python/operation/invoke.py index 4b1eb99..9288c98 100644 --- a/src/aws_durable_execution_sdk_python/operation/invoke.py +++ b/src/aws_durable_execution_sdk_python/operation/invoke.py @@ -5,12 +5,17 @@ import logging from typing import TYPE_CHECKING, TypeVar -from aws_durable_execution_sdk_python.config import InvokeConfig from aws_durable_execution_sdk_python.exceptions import ExecutionError from aws_durable_execution_sdk_python.lambda_service import ( 
ChainedInvokeOptions, OperationUpdate, ) + +# Import base classes for operation executor pattern +from aws_durable_execution_sdk_python.operation.base import ( + CheckResult, + OperationExecutor, +) from aws_durable_execution_sdk_python.serdes import ( DEFAULT_JSON_SERDES, deserialize, @@ -19,8 +24,12 @@ from aws_durable_execution_sdk_python.suspend import suspend_with_optional_resume_delay if TYPE_CHECKING: + from aws_durable_execution_sdk_python.config import InvokeConfig from aws_durable_execution_sdk_python.identifier import OperationIdentifier - from aws_durable_execution_sdk_python.state import ExecutionState + from aws_durable_execution_sdk_python.state import ( + CheckpointedResult, + ExecutionState, + ) P = TypeVar("P") # Payload type R = TypeVar("R") # Result type @@ -28,92 +37,136 @@ logger = logging.getLogger(__name__) -def invoke_handler( - function_name: str, - payload: P, - state: ExecutionState, - operation_identifier: OperationIdentifier, - config: InvokeConfig[P, R] | None, -) -> R: - """Invoke another Durable Function.""" - logger.debug( - "🔗 Invoke %s (%s)", - operation_identifier.name or function_name, - operation_identifier.operation_id, - ) +class InvokeOperationExecutor(OperationExecutor[R]): + """Executor for invoke operations. + + Checks operation status after creating START checkpoints to handle operations + that complete synchronously, avoiding unnecessary execution or suspension. + + The invoke operation never actually "executes" in the traditional sense - + it always suspends to wait for the async invocation to complete. + """ - if not config: - config = InvokeConfig[P, R]() - tenant_id = config.tenant_id + def __init__( + self, + function_name: str, + payload: P, + state: ExecutionState, + operation_identifier: OperationIdentifier, + config: InvokeConfig[P, R], + ): + """Initialize the invoke operation executor. 
+ + Args: + function_name: Name of the function to invoke + payload: The payload to pass to the invoked function + state: The execution state + operation_identifier: The operation identifier + config: Configuration for the invoke operation + """ + self.function_name = function_name + self.payload = payload + self.state = state + self.operation_identifier = operation_identifier + self.payload = payload + self.config = config + + def check_result_status(self) -> CheckResult[R]: + """Check operation status and create START checkpoint if needed. + + Called twice by process() when creating synchronous checkpoints: once before + and once after, to detect if the operation completed immediately. + + Returns: + CheckResult indicating the next action to take + + Raises: + CallableRuntimeError: For FAILED, TIMED_OUT, or STOPPED operations + SuspendExecution: For STARTED operations waiting for completion + """ + checkpointed_result: CheckpointedResult = self.state.get_checkpoint_result( + self.operation_identifier.operation_id + ) + + # Terminal success - deserialize and return + if checkpointed_result.is_succeeded(): + if checkpointed_result.result is None: + return CheckResult.create_completed(None) # type: ignore - # Check if we have existing step data - checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id) + result: R = deserialize( + serdes=self.config.serdes_result or DEFAULT_JSON_SERDES, + data=checkpointed_result.result, + operation_id=self.operation_identifier.operation_id, + durable_execution_arn=self.state.durable_execution_arn, + ) + return CheckResult.create_completed(result) - if checkpointed_result.is_succeeded(): - # Return persisted result - no need to check for errors in successful operations + # Terminal failures if ( - checkpointed_result.operation - and checkpointed_result.operation.chained_invoke_details - and checkpointed_result.operation.chained_invoke_details.result + checkpointed_result.is_failed() + or 
checkpointed_result.is_timed_out() + or checkpointed_result.is_stopped() ): - return deserialize( - serdes=config.serdes_result or DEFAULT_JSON_SERDES, - data=checkpointed_result.operation.chained_invoke_details.result, - operation_id=operation_identifier.operation_id, - durable_execution_arn=state.durable_execution_arn, + checkpointed_result.raise_callable_error() + + # Still running - ready to suspend + if checkpointed_result.is_started(): + logger.debug( + "⏳ Invoke %s still in progress, will suspend", + self.operation_identifier.name or self.function_name, + ) + return CheckResult.create_is_ready_to_execute(checkpointed_result) + + # Create START checkpoint if not exists + if not checkpointed_result.is_existent(): + serialized_payload: str = serialize( + serdes=self.config.serdes_payload or DEFAULT_JSON_SERDES, + value=self.payload, + operation_id=self.operation_identifier.operation_id, + durable_execution_arn=self.state.durable_execution_arn, ) - return None # type: ignore + start_operation: OperationUpdate = OperationUpdate.create_invoke_start( + identifier=self.operation_identifier, + payload=serialized_payload, + chained_invoke_options=ChainedInvokeOptions( + function_name=self.function_name, + tenant_id=self.config.tenant_id, + ), + ) + # Checkpoint invoke START with blocking (is_sync=True). + # Must ensure the chained invocation is recorded before suspending execution. 
+ self.state.create_checkpoint(operation_update=start_operation, is_sync=True) - if ( - checkpointed_result.is_failed() - or checkpointed_result.is_timed_out() - or checkpointed_result.is_stopped() - ): - # Operation failed, throw the exact same error on replay as the checkpointed failure - checkpointed_result.raise_callable_error() - - if checkpointed_result.is_started(): - # Operation is still running, suspend until completion - logger.debug( - "⏳ Invoke %s still in progress, suspending", - operation_identifier.name or function_name, - ) - msg = f"Invoke {operation_identifier.operation_id} still in progress" - suspend_with_optional_resume_delay(msg, config.timeout_seconds) - - serialized_payload: str = serialize( - serdes=config.serdes_payload or DEFAULT_JSON_SERDES, - value=payload, - operation_id=operation_identifier.operation_id, - durable_execution_arn=state.durable_execution_arn, - ) + logger.debug( + "🚀 Invoke %s started, will check for immediate response", + self.operation_identifier.name or self.function_name, + ) - # the backend will do the invoke once it gets this checkpoint - start_operation: OperationUpdate = OperationUpdate.create_invoke_start( - identifier=operation_identifier, - payload=serialized_payload, - chained_invoke_options=ChainedInvokeOptions( - function_name=function_name, - tenant_id=tenant_id, - ), - ) + # Signal to process() that checkpoint was created - to recheck status for permissions errs etc. + # before proceeding. + return CheckResult.create_started() - # Checkpoint invoke START with blocking (is_sync=True, default). - # Must ensure the chained invocation is recorded before suspending execution. - # This guarantees the invoke operation is durable and will be tracked by the backend. 
- state.create_checkpoint(operation_update=start_operation) + # Ready to suspend (checkpoint exists but not in a terminal or started state) + return CheckResult.create_is_ready_to_execute(checkpointed_result) - logger.debug( - "🚀 Invoke %s started, suspending for async execution", - operation_identifier.name or function_name, - ) + def execute(self, _checkpointed_result: CheckpointedResult) -> R: + """Execute invoke operation by suspending to wait for async completion. - # Suspend so invoke executes asynchronously without consuming cpu here - msg = ( - f"Invoke {operation_identifier.operation_id} started, suspending for completion" - ) - suspend_with_optional_resume_delay(msg, config.timeout_seconds) - # This line should never be reached since suspend_with_optional_resume_delay always raises - # if it is ever reached, we will crash in a non-retryable manner via ExecutionError - msg = "suspend_with_optional_resume_delay should have raised an exception, but did not." - raise ExecutionError(msg) from None + The invoke operation doesn't execute synchronously - it suspends and + the backend executes the invoked function asynchronously. + + Args: + checkpointed_result: The checkpoint data (unused, but required by interface) + + Returns: + Never returns - always suspends + + Raises: + Always suspends via suspend_with_optional_resume_delay + ExecutionError: If suspend doesn't raise (should never happen) + """ + msg: str = f"Invoke {self.operation_identifier.operation_id} started, suspending for completion" + suspend_with_optional_resume_delay(msg, self.config.timeout_seconds) + # This line should never be reached since suspend_with_optional_resume_delay always raises + error_msg: str = "suspend_with_optional_resume_delay should have raised an exception, but did not." 
+ raise ExecutionError(error_msg) from None diff --git a/src/aws_durable_execution_sdk_python/operation/step.py b/src/aws_durable_execution_sdk_python/operation/step.py index fd0badb..eb49c9b 100644 --- a/src/aws_durable_execution_sdk_python/operation/step.py +++ b/src/aws_durable_execution_sdk_python/operation/step.py @@ -11,6 +11,7 @@ ) from aws_durable_execution_sdk_python.exceptions import ( ExecutionError, + InvalidStateError, StepInterruptedError, ) from aws_durable_execution_sdk_python.lambda_service import ( @@ -18,6 +19,10 @@ OperationUpdate, ) from aws_durable_execution_sdk_python.logger import Logger, LogInfo +from aws_durable_execution_sdk_python.operation.base import ( + CheckResult, + OperationExecutor, +) from aws_durable_execution_sdk_python.retries import RetryDecision, RetryPresets from aws_durable_execution_sdk_python.serdes import deserialize, serialize from aws_durable_execution_sdk_python.suspend import ( @@ -40,230 +45,314 @@ T = TypeVar("T") -def step_handler( - func: Callable[[StepContext], T], - state: ExecutionState, - operation_identifier: OperationIdentifier, - config: StepConfig | None, - context_logger: Logger, -) -> T: - logger.debug( - "▶️ Executing step for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) - - if not config: - config = StepConfig() - - checkpointed_result: CheckpointedResult = state.get_checkpoint_result( - operation_identifier.operation_id - ) - if checkpointed_result.is_succeeded(): - logger.debug( - "Step already completed, skipping execution for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) - if checkpointed_result.result is None: - return None # type: ignore - - return deserialize( - serdes=config.serdes, - data=checkpointed_result.result, - operation_id=operation_identifier.operation_id, - durable_execution_arn=state.durable_execution_arn, - ) +class StepOperationExecutor(OperationExecutor[T]): + """Executor for step operations. 
- if checkpointed_result.is_failed(): - # have to throw the exact same error on replay as the checkpointed failure - checkpointed_result.raise_callable_error() - - if checkpointed_result.is_pending(): - scheduled_timestamp = checkpointed_result.get_next_attempt_timestamp() - # normally, we'd ensure that a suspension here would be for > 0 seconds; - # however, this is coming from a checkpoint, and we can trust that it is a correct target timestamp. - suspend_with_optional_resume_timestamp( - msg=f"Retry scheduled for {operation_identifier.name or operation_identifier.operation_id} will retry at timestamp {scheduled_timestamp}", - datetime_timestamp=scheduled_timestamp, - ) + Checks operation status after creating START checkpoints to handle operations + that complete synchronously, avoiding unnecessary execution or suspension. + """ - if ( - checkpointed_result.is_started() - and config.step_semantics is StepSemantics.AT_MOST_ONCE_PER_RETRY + def __init__( + self, + func: Callable[[StepContext], T], + config: StepConfig, + state: ExecutionState, + operation_identifier: OperationIdentifier, + context_logger: Logger, ): - # step was previously interrupted - msg = f"Step operation_id={operation_identifier.operation_id} name={operation_identifier.name} was previously interrupted" - retry_handler( - StepInterruptedError(msg), - state, - operation_identifier, - config, - checkpointed_result, + """Initialize the step operation executor. 
+ + Args: + func: The step function to execute + config: The step configuration + state: The execution state + operation_identifier: The operation identifier + context_logger: The logger for the step context + """ + self.func = func + self.config = config + self.state = state + self.operation_identifier = operation_identifier + self.context_logger = context_logger + self._checkpoint_created = False # Track if we created the checkpoint + + def check_result_status(self) -> CheckResult[T]: + """Check operation status and create START checkpoint if needed. + + Called twice by process() when creating synchronous checkpoints: once before + and once after, to detect if the operation completed immediately. + + Returns: + CheckResult indicating the next action to take + + Raises: + CallableRuntimeError: For FAILED operations + StepInterruptedError: For interrupted AT_MOST_ONCE operations + SuspendExecution: For PENDING operations waiting for retry + """ + checkpointed_result: CheckpointedResult = self.state.get_checkpoint_result( + self.operation_identifier.operation_id ) - checkpointed_result.raise_callable_error() + # Terminal success - deserialize and return + if checkpointed_result.is_succeeded(): + logger.debug( + "Step already completed, skipping execution for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) + if checkpointed_result.result is None: + return CheckResult.create_completed(None) # type: ignore + + result: T = deserialize( + serdes=self.config.serdes, + data=checkpointed_result.result, + operation_id=self.operation_identifier.operation_id, + durable_execution_arn=self.state.durable_execution_arn, + ) + return CheckResult.create_completed(result) + + # Terminal failure + if checkpointed_result.is_failed(): + # Have to throw the exact same error on replay as the checkpointed failure + checkpointed_result.raise_callable_error() + + # Pending retry + if checkpointed_result.is_pending(): + scheduled_timestamp = 
checkpointed_result.get_next_attempt_timestamp() + # Normally, we'd ensure that a suspension here would be for > 0 seconds; + # however, this is coming from a checkpoint, and we can trust that it is a correct target timestamp. + suspend_with_optional_resume_timestamp( + msg=f"Retry scheduled for {self.operation_identifier.name or self.operation_identifier.operation_id} will retry at timestamp {scheduled_timestamp}", + datetime_timestamp=scheduled_timestamp, + ) - if not ( - checkpointed_result.is_started() - and config.step_semantics is StepSemantics.AT_LEAST_ONCE_PER_RETRY - ): - # Do not checkpoint start for started & AT_LEAST_ONCE execution - # Checkpoint start for the other - start_operation: OperationUpdate = OperationUpdate.create_step_start( - identifier=operation_identifier, - ) - # Checkpoint START operation with appropriate synchronization: - # - AtMostOncePerRetry: Use blocking checkpoint (is_sync=True) to prevent duplicate execution. - # The step must not execute until the START checkpoint is persisted, ensuring exactly-once semantics. - # - AtLeastOncePerRetry: Use non-blocking checkpoint (is_sync=False) for performance optimization. - # The step can execute immediately without waiting for checkpoint persistence, allowing at-least-once semantics. - is_sync: bool = config.step_semantics is StepSemantics.AT_MOST_ONCE_PER_RETRY - state.create_checkpoint(operation_update=start_operation, is_sync=is_sync) - - attempt: int = 0 - if checkpointed_result.operation and checkpointed_result.operation.step_details: - attempt = checkpointed_result.operation.step_details.attempt - - step_context = StepContext( - logger=context_logger.with_log_info( - LogInfo.from_operation_identifier( - execution_state=state, - op_id=operation_identifier, - attempt=attempt, + # Handle interrupted AT_MOST_ONCE (replay scenario only) + # This check only applies on REPLAY when a new Lambda invocation starts after interruption. 
+ # A STARTED checkpoint with AT_MOST_ONCE on entry means the previous invocation + # was interrupted and it should NOT re-execute. + # + # This check is skipped on fresh executions because: + # - First call (fresh): checkpoint doesn't exist → is_started() returns False → skip this check + # - After creating sync checkpoint and refreshing: if status is STARTED, we return + # ready_to_execute directly, so process() never calls check_result_status() again + if ( + checkpointed_result.is_started() + and self.config.step_semantics is StepSemantics.AT_MOST_ONCE_PER_RETRY + ): + # Step was previously interrupted in a prior invocation - handle retry + msg: str = f"Step operation_id={self.operation_identifier.operation_id} name={self.operation_identifier.name} was previously interrupted" + self.retry_handler(StepInterruptedError(msg), checkpointed_result) + checkpointed_result.raise_callable_error() + + # Ready to execute if STARTED + AT_LEAST_ONCE + if ( + checkpointed_result.is_started() + and self.config.step_semantics is StepSemantics.AT_LEAST_ONCE_PER_RETRY + ): + return CheckResult.create_is_ready_to_execute(checkpointed_result) + + # Create START checkpoint if not exists + if not checkpointed_result.is_existent(): + start_operation: OperationUpdate = OperationUpdate.create_step_start( + identifier=self.operation_identifier, + ) + # Checkpoint START operation with appropriate synchronization: + # - AtMostOncePerRetry: Use blocking checkpoint (is_sync=True) to prevent duplicate execution. + # The step must not execute until the START checkpoint is persisted, ensuring exactly-once semantics. + # - AtLeastOncePerRetry: Use non-blocking checkpoint (is_sync=False) for performance optimization. + # The step can execute immediately without waiting for checkpoint persistence, allowing at-least-once semantics. 
+ is_sync: bool = ( + self.config.step_semantics is StepSemantics.AT_MOST_ONCE_PER_RETRY + ) + self.state.create_checkpoint( + operation_update=start_operation, is_sync=is_sync ) - ) - ) - try: - # this is the actual code provided by the caller to execute durably inside the step - raw_result: T = func(step_context) - serialized_result: str = serialize( - serdes=config.serdes, - value=raw_result, - operation_id=operation_identifier.operation_id, - durable_execution_arn=state.durable_execution_arn, - ) - success_operation: OperationUpdate = OperationUpdate.create_step_succeed( - identifier=operation_identifier, - payload=serialized_result, + # After creating sync checkpoint, check the status + if is_sync: + # Refresh checkpoint result to check for immediate response + refreshed_result: CheckpointedResult = self.state.get_checkpoint_result( + self.operation_identifier.operation_id + ) + + # START checkpoint only returns STARTED status + # Any errors would be thrown as runtime exceptions during checkpoint creation + if not refreshed_result.is_started(): + # This should never happen - defensive check + error_msg: str = f"Unexpected status after START checkpoint: {refreshed_result.status}" + raise InvalidStateError(error_msg) + + # If we reach here, status must be STARTED - ready to execute + return CheckResult.create_is_ready_to_execute(refreshed_result) + + # Ready to execute + return CheckResult.create_is_ready_to_execute(checkpointed_result) + + def execute(self, checkpointed_result: CheckpointedResult) -> T: + """Execute step function with error handling and retry logic. 
+ + Args: + checkpointed_result: The checkpoint data containing operation state + + Returns: + The result of executing the step function + + Raises: + ExecutionError: For fatal errors that should not be retried + May raise other exceptions that will be handled by retry_handler + """ + attempt: int = 0 + if checkpointed_result.operation and checkpointed_result.operation.step_details: + attempt = checkpointed_result.operation.step_details.attempt + + step_context: StepContext = StepContext( + logger=self.context_logger.with_log_info( + LogInfo.from_operation_identifier( + execution_state=self.state, + op_id=self.operation_identifier, + attempt=attempt, + ) + ) ) - # Checkpoint SUCCEED operation with blocking (is_sync=True, default). - # Must ensure the success state is persisted before returning the result to the caller. - # This guarantees the step result is durable and won't be lost if Lambda terminates. - state.create_checkpoint(operation_update=success_operation) - - logger.debug( - "✅ Successfully completed step for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) - return raw_result # noqa: TRY300 - except Exception as e: - if isinstance(e, ExecutionError): - # no retry on fatal - e.g checkpoint exception - logger.debug( - "💥 Fatal error for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, + try: + # This is the actual code provided by the caller to execute durably inside the step + raw_result: T = self.func(step_context) + serialized_result: str = serialize( + serdes=self.config.serdes, + value=raw_result, + operation_id=self.operation_identifier.operation_id, + durable_execution_arn=self.state.durable_execution_arn, ) - # this bubbles up to execution.durable_execution, where it will exit with FAILED - raise - - logger.exception( - "❌ failed step for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) - retry_handler(e, state, operation_identifier, 
config, checkpointed_result) - # if we've failed to raise an exception from the retry_handler, then we are in a - # weird state, and should crash terminate the execution - msg = "retry handler should have raised an exception, but did not." - raise ExecutionError(msg) from None + success_operation: OperationUpdate = OperationUpdate.create_step_succeed( + identifier=self.operation_identifier, + payload=serialized_result, + ) + # Checkpoint SUCCEED operation with blocking (is_sync=True, default). + # Must ensure the success state is persisted before returning the result to the caller. + # This guarantees the step result is durable and won't be lost if Lambda terminates. + self.state.create_checkpoint(operation_update=success_operation) -# TODO: I don't much like this func, needs refactor. Messy grab-bag of args, refine. -def retry_handler( - error: Exception, - state: ExecutionState, - operation_identifier: OperationIdentifier, - config: StepConfig, - checkpointed_result: CheckpointedResult, -): - """Checkpoint and suspend for replay if retry required, otherwise raise error.""" - error_object = ErrorObject.from_exception(error) + logger.debug( + "✅ Successfully completed step for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) + return raw_result # noqa: TRY300 + except Exception as e: + if isinstance(e, ExecutionError): + # No retry on fatal - e.g checkpoint exception + logger.debug( + "💥 Fatal error for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) + # This bubbles up to execution.durable_execution, where it will exit with FAILED + raise + + logger.exception( + "❌ failed step for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) - retry_strategy = config.retry_strategy or RetryPresets.default() + self.retry_handler(e, checkpointed_result) + # If we've failed to raise an exception from the retry_handler, then we are 
in a + # weird state, and should crash terminate the execution + msg = "retry handler should have raised an exception, but did not." + raise ExecutionError(msg) from None - retry_attempt: int = ( - checkpointed_result.operation.step_details.attempt - if ( - checkpointed_result.operation and checkpointed_result.operation.step_details - ) - else 0 - ) - retry_decision: RetryDecision = retry_strategy(error, retry_attempt + 1) - - if retry_decision.should_retry: - logger.debug( - "Retrying step for id: %s, name: %s, attempt: %s", - operation_identifier.operation_id, - operation_identifier.name, - retry_attempt + 1, + def retry_handler( + self, + error: Exception, + checkpointed_result: CheckpointedResult, + ): + """Checkpoint and suspend for replay if retry required, otherwise raise error. + + Args: + error: The exception that occurred during step execution + checkpointed_result: The checkpoint data containing operation state + + Raises: + SuspendExecution: If retry is scheduled + StepInterruptedError: If the error is a StepInterruptedError + CallableRuntimeError: If retry is exhausted or error is not retryable + """ + error_object = ErrorObject.from_exception(error) + + retry_strategy = self.config.retry_strategy or RetryPresets.default() + + retry_attempt: int = ( + checkpointed_result.operation.step_details.attempt + if ( + checkpointed_result.operation + and checkpointed_result.operation.step_details + ) + else 0 ) + retry_decision: RetryDecision = retry_strategy(error, retry_attempt + 1) - # because we are issuing a retry and create an OperationUpdate - # we enforce a minimum delay second of 1, to match model behaviour. - # we localize enforcement and keep it outside suspension methods as: - # a) those are used throughout the codebase, e.g. in wait(..) 
<- enforcement is done in context - # b) they shouldn't know model specific details <- enforcement is done above - # and c) this "issue" arises from retry-decision and we shouldn't push it down - delay_seconds = retry_decision.delay_seconds - if delay_seconds < 1: - logger.warning( - ( - "Retry delay_seconds step for id: %s, name: %s," - "attempt: %s is %d < 1. Setting to minimum of 1 seconds." - ), - operation_identifier.operation_id, - operation_identifier.name, + if retry_decision.should_retry: + logger.debug( + "Retrying step for id: %s, name: %s, attempt: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, retry_attempt + 1, - delay_seconds, ) - delay_seconds = 1 - retry_operation: OperationUpdate = OperationUpdate.create_step_retry( - identifier=operation_identifier, - error=error_object, - next_attempt_delay_seconds=delay_seconds, - ) + # because we are issuing a retry and create an OperationUpdate + # we enforce a minimum delay second of 1, to match model behaviour. + # we localize enforcement and keep it outside suspension methods as: + # a) those are used throughout the codebase, e.g. in wait(..) <- enforcement is done in context + # b) they shouldn't know model specific details <- enforcement is done above + # and c) this "issue" arises from retry-decision and we shouldn't push it down + delay_seconds = retry_decision.delay_seconds + if delay_seconds < 1: + logger.warning( + ( + "Retry delay_seconds step for id: %s, name: %s," + "attempt: %s is %d < 1. Setting to minimum of 1 seconds." + ), + self.operation_identifier.operation_id, + self.operation_identifier.name, + retry_attempt + 1, + delay_seconds, + ) + delay_seconds = 1 + + retry_operation: OperationUpdate = OperationUpdate.create_step_retry( + identifier=self.operation_identifier, + error=error_object, + next_attempt_delay_seconds=delay_seconds, + ) - # Checkpoint RETRY operation with blocking (is_sync=True, default). 
- # Must ensure retry state is persisted before suspending execution. - # This guarantees the retry attempt count and next attempt timestamp are durable. - state.create_checkpoint(operation_update=retry_operation) - - suspend_with_optional_resume_delay( - msg=( - f"Retry scheduled for {operation_identifier.operation_id}" - f"in {retry_decision.delay_seconds} seconds" - ), - delay_seconds=delay_seconds, - ) + # Checkpoint RETRY operation with blocking (is_sync=True, default). + # Must ensure retry state is persisted before suspending execution. + # This guarantees the retry attempt count and next attempt timestamp are durable. + self.state.create_checkpoint(operation_update=retry_operation) - # no retry - fail_operation: OperationUpdate = OperationUpdate.create_step_fail( - identifier=operation_identifier, error=error_object - ) + suspend_with_optional_resume_delay( + msg=( + f"Retry scheduled for {self.operation_identifier.operation_id}" + f"in {retry_decision.delay_seconds} seconds" + ), + delay_seconds=delay_seconds, + ) + + # no retry + fail_operation: OperationUpdate = OperationUpdate.create_step_fail( + identifier=self.operation_identifier, error=error_object + ) - # Checkpoint FAIL operation with blocking (is_sync=True, default). - # Must ensure the failure state is persisted before raising the exception. - # This guarantees the error is durable and the step won't be retried on replay. - state.create_checkpoint(operation_update=fail_operation) + # Checkpoint FAIL operation with blocking (is_sync=True, default). + # Must ensure the failure state is persisted before raising the exception. + # This guarantees the error is durable and the step won't be retried on replay. 
+ self.state.create_checkpoint(operation_update=fail_operation) - if isinstance(error, StepInterruptedError): - raise error + if isinstance(error, StepInterruptedError): + raise error - raise error_object.to_callable_runtime_error() + raise error_object.to_callable_runtime_error() diff --git a/src/aws_durable_execution_sdk_python/operation/wait.py b/src/aws_durable_execution_sdk_python/operation/wait.py index 90d0880..fc16e66 100644 --- a/src/aws_durable_execution_sdk_python/operation/wait.py +++ b/src/aws_durable_execution_sdk_python/operation/wait.py @@ -6,6 +6,10 @@ from typing import TYPE_CHECKING from aws_durable_execution_sdk_python.lambda_service import OperationUpdate, WaitOptions +from aws_durable_execution_sdk_python.operation.base import ( + CheckResult, + OperationExecutor, +) from aws_durable_execution_sdk_python.suspend import suspend_with_optional_resume_delay if TYPE_CHECKING: @@ -18,36 +22,90 @@ logger = logging.getLogger(__name__) -def wait_handler( - seconds: int, state: ExecutionState, operation_identifier: OperationIdentifier -) -> None: - logger.debug( - "Wait requested for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) +class WaitOperationExecutor(OperationExecutor[None]): + """Executor for wait operations. - checkpointed_result: CheckpointedResult = state.get_checkpoint_result( - operation_identifier.operation_id - ) + Checks operation status after creating START checkpoints to handle operations + that complete synchronously, avoiding unnecessary execution or suspension. + """ - if checkpointed_result.is_succeeded(): - logger.debug( - "Wait already completed, skipping wait for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) - return + def __init__( + self, + seconds: int, + state: ExecutionState, + operation_identifier: OperationIdentifier, + ): + """Initialize the wait operation executor. 
+ + Args: + seconds: Number of seconds to wait + state: The execution state + operation_identifier: The operation identifier + """ + self.seconds = seconds + self.state = state + self.operation_identifier = operation_identifier - if not checkpointed_result.is_existent(): - operation = OperationUpdate.create_wait_start( - identifier=operation_identifier, - wait_options=WaitOptions(wait_seconds=seconds), + def check_result_status(self) -> CheckResult[None]: + """Check operation status and create START checkpoint if needed. + + Called twice by process() when creating synchronous checkpoints: once before + and once after, to detect if the operation completed immediately. + + Returns: + CheckResult indicating the next action to take + + Raises: + SuspendExecution: When wait timer has not completed + """ + checkpointed_result: CheckpointedResult = self.state.get_checkpoint_result( + self.operation_identifier.operation_id ) - # Checkpoint wait START with blocking (is_sync=True, default). - # Must ensure the wait operation and scheduled timestamp are persisted before suspending. - # This guarantees the wait will resume at the correct time on the next invocation. - state.create_checkpoint(operation_update=operation) - msg = f"Wait for {seconds} seconds" - suspend_with_optional_resume_delay(msg, seconds) # throws suspend + # Terminal success - wait completed + if checkpointed_result.is_succeeded(): + logger.debug( + "Wait already completed, skipping wait for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) + return CheckResult.create_completed(None) + + # Create START checkpoint if not exists + if not checkpointed_result.is_existent(): + operation: OperationUpdate = OperationUpdate.create_wait_start( + identifier=self.operation_identifier, + wait_options=WaitOptions(wait_seconds=self.seconds), + ) + # Checkpoint wait START with blocking (is_sync=True, default). 
+ # Must ensure the wait operation and scheduled timestamp are persisted before suspending. + # This guarantees the wait will resume at the correct time on the next invocation. + self.state.create_checkpoint(operation_update=operation, is_sync=True) + + logger.debug( + "Wait checkpoint created for id: %s, name: %s, will check for immediate response", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) + + # Signal to process() that checkpoint was created - which will re-run this check_result_status + # check from the top + return CheckResult.create_started() + + # Ready to suspend (checkpoint exists) + return CheckResult.create_is_ready_to_execute(checkpointed_result) + + def execute(self, _checkpointed_result: CheckpointedResult) -> None: + """Execute wait by suspending. + + Wait operations 'execute' by suspending execution until the timer completes. + This method never returns normally - it always suspends. + + Args: + _checkpointed_result: The checkpoint data (unused for wait) + + Raises: + SuspendExecution: Always suspends to wait for timer completion + """ + msg: str = f"Wait for {self.seconds} seconds" + suspend_with_optional_resume_delay(msg, self.seconds) # throws suspend diff --git a/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py b/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py index 6ec8e69..d1c2b4f 100644 --- a/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py +++ b/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py @@ -13,6 +13,10 @@ OperationUpdate, ) from aws_durable_execution_sdk_python.logger import LogInfo +from aws_durable_execution_sdk_python.operation.base import ( + CheckResult, + OperationExecutor, +) from aws_durable_execution_sdk_python.serdes import deserialize, serialize from aws_durable_execution_sdk_python.suspend import ( suspend_with_optional_resume_delay, @@ -40,196 +44,239 @@ logger = logging.getLogger(__name__) -def 
wait_for_condition_handler( - check: Callable[[T, WaitForConditionCheckContext], T], - config: WaitForConditionConfig[T], - state: ExecutionState, - operation_identifier: OperationIdentifier, - context_logger: Logger, -) -> T: - """Handle wait_for_condition operation. +class WaitForConditionOperationExecutor(OperationExecutor[T]): + """Executor for wait_for_condition operations. - wait_for_condition creates a STEP checkpoint. + Checks operation status after creating START checkpoints to handle operations + that complete synchronously, avoiding unnecessary execution or suspension. """ - logger.debug( - "▶️ Executing wait_for_condition for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) - checkpointed_result: CheckpointedResult = state.get_checkpoint_result( - operation_identifier.operation_id - ) - - # Check if already completed - if checkpointed_result.is_succeeded(): - logger.debug( - "wait_for_condition already completed for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) - if checkpointed_result.result is None: - return None # type: ignore - return deserialize( - serdes=config.serdes, - data=checkpointed_result.result, - operation_id=operation_identifier.operation_id, - durable_execution_arn=state.durable_execution_arn, + def __init__( + self, + check: Callable[[T, WaitForConditionCheckContext], T], + config: WaitForConditionConfig[T], + state: ExecutionState, + operation_identifier: OperationIdentifier, + context_logger: Logger, + ): + """Initialize the wait_for_condition executor. 
+ + Args: + check: The check function to evaluate the condition + config: Configuration for the wait_for_condition operation + state: The execution state + operation_identifier: The operation identifier + context_logger: Logger for the operation context + """ + self.check = check + self.config = config + self.state = state + self.operation_identifier = operation_identifier + self.context_logger = context_logger + + def check_result_status(self) -> CheckResult[T]: + """Check operation status and create START checkpoint if needed. + + Called twice by process() when creating synchronous checkpoints: once before + and once after, to detect if the operation completed immediately. + + Returns: + CheckResult indicating the next action to take + + Raises: + CallableRuntimeError: For FAILED operations + SuspendExecution: For PENDING operations waiting for retry + """ + checkpointed_result = self.state.get_checkpoint_result( + self.operation_identifier.operation_id ) - if checkpointed_result.is_failed(): - checkpointed_result.raise_callable_error() + # Check if already completed + if checkpointed_result.is_succeeded(): + logger.debug( + "wait_for_condition already completed for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) + if checkpointed_result.result is None: + return CheckResult.create_completed(None) # type: ignore + result = deserialize( + serdes=self.config.serdes, + data=checkpointed_result.result, + operation_id=self.operation_identifier.operation_id, + durable_execution_arn=self.state.durable_execution_arn, + ) + return CheckResult.create_completed(result) + + # Terminal failure + if checkpointed_result.is_failed(): + checkpointed_result.raise_callable_error() + + # Pending retry + if checkpointed_result.is_pending(): + scheduled_timestamp = checkpointed_result.get_next_attempt_timestamp() + suspend_with_optional_resume_timestamp( + msg=f"wait_for_condition {self.operation_identifier.name or 
self.operation_identifier.operation_id} will retry at timestamp {scheduled_timestamp}", + datetime_timestamp=scheduled_timestamp, + ) - if checkpointed_result.is_pending(): - scheduled_timestamp = checkpointed_result.get_next_attempt_timestamp() - suspend_with_optional_resume_timestamp( - msg=f"wait_for_condition {operation_identifier.name or operation_identifier.operation_id} will retry at timestamp {scheduled_timestamp}", - datetime_timestamp=scheduled_timestamp, - ) + # Create START checkpoint if not started + if not checkpointed_result.is_started(): + start_operation = OperationUpdate.create_wait_for_condition_start( + identifier=self.operation_identifier, + ) + # Checkpoint wait_for_condition START with non-blocking (is_sync=False). + # This is purely for observability - we don't need to wait for persistence before + # executing the check function. The START checkpoint just records that polling began. + self.state.create_checkpoint( + operation_update=start_operation, is_sync=False + ) + # For async checkpoint, no immediate response possible + # Proceed directly to execute with current checkpoint data + + # Ready to execute check function + return CheckResult.create_is_ready_to_execute(checkpointed_result) + + def execute(self, checkpointed_result: CheckpointedResult) -> T: + """Execute check function and handle decision. 
+ + Args: + checkpointed_result: The checkpoint data - attempt: int = 1 - if checkpointed_result.is_started_or_ready(): - # This is a retry - get state from previous checkpoint - if checkpointed_result.result: + Returns: + The final state when condition is met + + Raises: + Suspends if condition not met + Raises error if check function fails + """ + # Determine current state from checkpoint + if checkpointed_result.is_started_or_ready() and checkpointed_result.result: try: current_state = deserialize( - serdes=config.serdes, + serdes=self.config.serdes, data=checkpointed_result.result, - operation_id=operation_identifier.operation_id, - durable_execution_arn=state.durable_execution_arn, + operation_id=self.operation_identifier.operation_id, + durable_execution_arn=self.state.durable_execution_arn, ) except Exception: - # default to initial state if there's an error getting checkpointed state + # Default to initial state if there's an error getting checkpointed state logger.exception( "⚠️ wait_for_condition failed to deserialize state for id: %s, name: %s. Using initial state.", - operation_identifier.operation_id, - operation_identifier.name, + self.operation_identifier.operation_id, + self.operation_identifier.name, ) - current_state = config.initial_state + current_state = self.config.initial_state else: - current_state = config.initial_state + current_state = self.config.initial_state - # at this point operation has to exist. Nonetheless, just in case somehow it's not there. + # Get attempt number + attempt: int = 1 if checkpointed_result.operation and checkpointed_result.operation.step_details: attempt = checkpointed_result.operation.step_details.attempt - else: - # First execution - current_state = config.initial_state - - # Checkpoint START for observability. 
- if not checkpointed_result.is_started(): - start_operation: OperationUpdate = ( - OperationUpdate.create_wait_for_condition_start( - identifier=operation_identifier, - ) - ) - # Checkpoint wait_for_condition START with non-blocking (is_sync=False). - # This is purely for observability - we don't need to wait for persistence before - # executing the check function. The START checkpoint just records that polling began. - state.create_checkpoint(operation_update=start_operation, is_sync=False) - - try: - # Execute the check function with the injected logger - check_context = WaitForConditionCheckContext( - logger=context_logger.with_log_info( - LogInfo.from_operation_identifier( - execution_state=state, - op_id=operation_identifier, - attempt=attempt, + + try: + # Execute the check function with the injected logger + check_context = WaitForConditionCheckContext( + logger=self.context_logger.with_log_info( + LogInfo.from_operation_identifier( + execution_state=self.state, + op_id=self.operation_identifier, + attempt=attempt, + ) ) ) - ) - new_state = check(current_state, check_context) + new_state = self.check(current_state, check_context) - # Check if condition is met with the wait strategy - decision: WaitForConditionDecision = config.wait_strategy(new_state, attempt) - - serialized_state = serialize( - serdes=config.serdes, - value=new_state, - operation_id=operation_identifier.operation_id, - durable_execution_arn=state.durable_execution_arn, - ) - - logger.debug( - "wait_for_condition check completed: %s, name: %s, attempt: %s", - operation_identifier.operation_id, - operation_identifier.name, - attempt, - ) + # Check if condition is met with the wait strategy + decision: WaitForConditionDecision = self.config.wait_strategy( + new_state, attempt + ) - if not decision.should_continue: - # Condition is met - complete successfully - success_operation = OperationUpdate.create_wait_for_condition_succeed( - identifier=operation_identifier, - payload=serialized_state, 
+ serialized_state = serialize( + serdes=self.config.serdes, + value=new_state, + operation_id=self.operation_identifier.operation_id, + durable_execution_arn=self.state.durable_execution_arn, ) - # Checkpoint SUCCEED operation with blocking (is_sync=True, default). - # Must ensure the final state is persisted before returning to the caller. - # This guarantees the condition result is durable and won't be re-evaluated on replay. - state.create_checkpoint(operation_update=success_operation) logger.debug( - "✅ wait_for_condition completed for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) - return new_state - - # Condition not met - schedule retry - # we enforce a minimum delay second of 1, to match model behaviour. - # we localize enforcement and keep it outside suspension methods as: - # a) those are used throughout the codebase, e.g. in wait(..) <- enforcement is done in context - # b) they shouldn't know model specific details <- enforcement is done above - # and c) this "issue" arises from retry-decision and shouldn't be chased deeper. - delay_seconds = decision.delay_seconds - if delay_seconds is not None and delay_seconds < 1: - logger.warning( - ( - "WaitDecision delay_seconds step for id: %s, name: %s," - "is %d < 1. Setting to minimum of 1 seconds." 
- ), - operation_identifier.operation_id, - operation_identifier.name, - delay_seconds, + "wait_for_condition check completed: %s, name: %s, attempt: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + attempt, ) - delay_seconds = 1 - retry_operation = OperationUpdate.create_wait_for_condition_retry( - identifier=operation_identifier, - payload=serialized_state, - next_attempt_delay_seconds=delay_seconds, - ) + if not decision.should_continue: + # Condition is met - complete successfully + success_operation = OperationUpdate.create_wait_for_condition_succeed( + identifier=self.operation_identifier, + payload=serialized_state, + ) + # Checkpoint SUCCEED operation with blocking (is_sync=True, default). + # Must ensure the final state is persisted before returning to the caller. + # This guarantees the condition result is durable and won't be re-evaluated on replay. + self.state.create_checkpoint(operation_update=success_operation) + + logger.debug( + "✅ wait_for_condition completed for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) + return new_state + + # Condition not met - schedule retry + # We enforce a minimum delay second of 1, to match model behaviour. + delay_seconds = decision.delay_seconds + if delay_seconds is not None and delay_seconds < 1: + logger.warning( + ( + "WaitDecision delay_seconds step for id: %s, name: %s," + "is %d < 1. Setting to minimum of 1 seconds." + ), + self.operation_identifier.operation_id, + self.operation_identifier.name, + delay_seconds, + ) + delay_seconds = 1 - # Checkpoint RETRY operation with blocking (is_sync=True, default). - # Must ensure the current state and next attempt timestamp are persisted before suspending. - # This guarantees the polling state is durable and will resume correctly on the next invocation. 
- state.create_checkpoint(operation_update=retry_operation) + retry_operation = OperationUpdate.create_wait_for_condition_retry( + identifier=self.operation_identifier, + payload=serialized_state, + next_attempt_delay_seconds=delay_seconds, + ) - suspend_with_optional_resume_delay( - msg=f"wait_for_condition {operation_identifier.name or operation_identifier.operation_id} will retry in {decision.delay_seconds} seconds", - delay_seconds=decision.delay_seconds, - ) + # Checkpoint RETRY operation with blocking (is_sync=True, default). + # Must ensure the current state and next attempt timestamp are persisted before suspending. + # This guarantees the polling state is durable and will resume correctly on the next invocation. + self.state.create_checkpoint(operation_update=retry_operation) - except Exception as e: - # Mark as failed - waitForCondition doesn't have its own retry logic for errors - # If the check function throws, it's considered a failure - logger.exception( - "❌ wait_for_condition failed for id: %s, name: %s", - operation_identifier.operation_id, - operation_identifier.name, - ) + suspend_with_optional_resume_delay( + msg=f"wait_for_condition {self.operation_identifier.name or self.operation_identifier.operation_id} will retry in {decision.delay_seconds} seconds", + delay_seconds=decision.delay_seconds, + ) + + except Exception as e: + # Mark as failed - waitForCondition doesn't have its own retry logic for errors + # If the check function throws, it's considered a failure + logger.exception( + "❌ wait_for_condition failed for id: %s, name: %s", + self.operation_identifier.operation_id, + self.operation_identifier.name, + ) - fail_operation = OperationUpdate.create_wait_for_condition_fail( - identifier=operation_identifier, - error=ErrorObject.from_exception(e), + fail_operation = OperationUpdate.create_wait_for_condition_fail( + identifier=self.operation_identifier, + error=ErrorObject.from_exception(e), + ) + # Checkpoint FAIL operation with blocking 
(is_sync=True, default). + # Must ensure the failure state is persisted before raising the exception. + # This guarantees the error is durable and the condition won't be re-evaluated on replay. + self.state.create_checkpoint(operation_update=fail_operation) + raise + + msg: str = ( + "wait_for_condition should never reach this point" # pragma: no cover ) - # Checkpoint FAIL operation with blocking (is_sync=True, default). - # Must ensure the failure state is persisted before raising the exception. - # This guarantees the error is durable and the condition won't be re-evaluated on replay. - state.create_checkpoint(operation_update=fail_operation) - raise - - msg: str = "wait_for_condition should never reach this point" # pragma: no cover - raise ExecutionError(msg) # pragma: no cover + raise ExecutionError(msg) # pragma: no cover diff --git a/tests/context_test.py b/tests/context_test.py index 3168683..4e43347 100644 --- a/tests/context_test.py +++ b/tests/context_test.py @@ -3,7 +3,7 @@ import json import random from itertools import islice -from unittest.mock import ANY, Mock, patch +from unittest.mock import ANY, MagicMock, Mock, patch import pytest @@ -238,10 +238,13 @@ def test_callback_result_timed_out(): # region create_callback -@patch("aws_durable_execution_sdk_python.context.create_callback_handler") -def test_create_callback_basic(mock_handler): +@patch("aws_durable_execution_sdk_python.context.CallbackOperationExecutor") +def test_create_callback_basic(mock_executor_class): """Test create_callback with basic parameters.""" - mock_handler.return_value = "callback123" + mock_executor = MagicMock() + mock_executor.process.return_value = "callback123" + mock_executor_class.return_value = mock_executor + mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -258,17 +261,21 @@ def test_create_callback_basic(mock_handler): assert callback.operation_id == expected_operation_id 
assert callback.state is mock_state - mock_handler.assert_called_once_with( + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_operation_id, None, None), config=CallbackConfig(), ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.create_callback_handler") -def test_create_callback_with_name_and_config(mock_handler): +@patch("aws_durable_execution_sdk_python.context.CallbackOperationExecutor") +def test_create_callback_with_name_and_config(mock_executor_class): """Test create_callback with name and config.""" - mock_handler.return_value = "callback456" + mock_executor = MagicMock() + mock_executor.process.return_value = "callback456" + mock_executor_class.return_value = mock_executor + mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -286,18 +293,23 @@ def test_create_callback_with_name_and_config(mock_handler): assert callback.callback_id == "callback456" assert callback.operation_id == expected_operation_id - mock_handler.assert_called_once_with( + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_operation_id, None, None), config=config, ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.create_callback_handler") -def test_create_callback_with_parent_id(mock_handler): +@patch("aws_durable_execution_sdk_python.context.CallbackOperationExecutor") +def test_create_callback_with_parent_id(mock_executor_class): """Test create_callback with parent_id.""" - mock_handler.return_value = "callback789" + mock_executor = MagicMock() + + mock_executor.process.return_value = "callback789" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ 
-313,17 +325,21 @@ def test_create_callback_with_parent_id(mock_handler): assert callback.operation_id == expected_operation_id - mock_handler.assert_called_once_with( + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_operation_id, "parent123"), config=CallbackConfig(), ) -@patch("aws_durable_execution_sdk_python.context.create_callback_handler") -def test_create_callback_increments_counter(mock_handler): +@patch("aws_durable_execution_sdk_python.context.CallbackOperationExecutor") +def test_create_callback_increments_counter(mock_executor_class): """Test create_callback increments step counter.""" - mock_handler.return_value = "callback_test" + mock_executor = MagicMock() + + mock_executor.process.return_value = "callback_test" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -350,10 +366,14 @@ def test_create_callback_increments_counter(mock_handler): # region step -@patch("aws_durable_execution_sdk_python.context.step_handler") -def test_step_basic(mock_handler): +@patch("aws_durable_execution_sdk_python.context.StepOperationExecutor") +def test_step_basic(mock_executor_class): """Test step with basic parameters.""" - mock_handler.return_value = "step_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "step_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -370,19 +390,24 @@ def test_step_basic(mock_handler): result = context.step(mock_callable) assert result == "step_result" - mock_handler.assert_called_once_with( - func=mock_callable, - config=None, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_operation_id, None, None), 
+ config=ANY, # StepConfig() is created in context.step() + func=mock_callable, context_logger=ANY, ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.step_handler") -def test_step_with_name_and_config(mock_handler): +@patch("aws_durable_execution_sdk_python.context.StepOperationExecutor") +def test_step_with_name_and_config(mock_executor_class): """Test step with name and config.""" - mock_handler.return_value = "configured_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "configured_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -404,19 +429,24 @@ def test_step_with_name_and_config(mock_handler): expected_id = next(seq) # 6th assert result == "configured_result" - mock_handler.assert_called_once_with( - func=mock_callable, - config=config, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_id, None, None), + config=config, + func=mock_callable, context_logger=ANY, ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.step_handler") -def test_step_with_parent_id(mock_handler): +@patch("aws_durable_execution_sdk_python.context.StepOperationExecutor") +def test_step_with_parent_id(mock_executor_class): """Test step with parent_id.""" - mock_handler.return_value = "parent_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "parent_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -436,19 +466,24 @@ def test_step_with_parent_id(mock_handler): [next(seq) for _ in range(2)] # Skip first 2 expected_id = next(seq) # 3rd - mock_handler.assert_called_once_with( - 
func=mock_callable, - config=None, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_id, "parent123"), + config=ANY, + func=mock_callable, context_logger=ANY, ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.step_handler") -def test_step_increments_counter(mock_handler): +@patch("aws_durable_execution_sdk_python.context.StepOperationExecutor") +def test_step_increments_counter(mock_executor_class): """Test step increments step counter.""" - mock_handler.return_value = "result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -471,18 +506,22 @@ def test_step_increments_counter(mock_handler): expected_id2 = next(seq) # 12th assert context._step_counter.get_current() == 12 # noqa: SLF001 - assert mock_handler.call_args_list[0][1][ + assert mock_executor_class.call_args_list[0][1][ "operation_identifier" ] == OperationIdentifier(expected_id1, None, None) - assert mock_handler.call_args_list[1][1][ + assert mock_executor_class.call_args_list[1][1][ "operation_identifier" ] == OperationIdentifier(expected_id2, None, None) -@patch("aws_durable_execution_sdk_python.context.step_handler") -def test_step_with_original_name(mock_handler): +@patch("aws_durable_execution_sdk_python.context.StepOperationExecutor") +def test_step_with_original_name(mock_executor_class): """Test step with callable that has _original_name attribute.""" - mock_handler.return_value = "named_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "named_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -498,23 
+537,28 @@ def test_step_with_original_name(mock_handler): seq = operation_id_sequence() expected_id = next(seq) # 1st - mock_handler.assert_called_once_with( - func=mock_callable, - config=None, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_id, None, "override_name"), + config=ANY, + func=mock_callable, context_logger=ANY, ) + mock_executor.process.assert_called_once() # endregion step # region invoke -@patch("aws_durable_execution_sdk_python.context.invoke_handler") -def test_invoke_basic(mock_handler): +@patch("aws_durable_execution_sdk_python.context.InvokeOperationExecutor") +def test_invoke_basic(mock_executor_class): """Test invoke with basic parameters.""" - mock_handler.return_value = "invoke_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "invoke_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -528,19 +572,24 @@ def test_invoke_basic(mock_handler): assert result == "invoke_result" - mock_handler.assert_called_once_with( - function_name="test_function", - payload="test_payload", + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_operation_id, None, None), - config=None, + function_name="test_function", + payload="test_payload", + config=ANY, # InvokeConfig() is created in context.invoke() ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.invoke_handler") -def test_invoke_with_name_and_config(mock_handler): +@patch("aws_durable_execution_sdk_python.context.InvokeOperationExecutor") +def test_invoke_with_name_and_config(mock_executor_class): """Test invoke with name and config.""" - mock_handler.return_value = "configured_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = 
"configured_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -560,19 +609,24 @@ def test_invoke_with_name_and_config(mock_handler): expected_id = next(seq) # 6th assert result == "configured_result" - mock_handler.assert_called_once_with( - function_name="test_function", - payload={"key": "value"}, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_id, None, "named_invoke"), + function_name="test_function", + payload={"key": "value"}, config=config, ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.invoke_handler") -def test_invoke_with_parent_id(mock_handler): +@patch("aws_durable_execution_sdk_python.context.InvokeOperationExecutor") +def test_invoke_with_parent_id(mock_executor_class): """Test invoke with parent_id.""" - mock_handler.return_value = "parent_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "parent_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -587,19 +641,24 @@ def test_invoke_with_parent_id(mock_handler): [next(seq) for _ in range(2)] expected_id = next(seq) - mock_handler.assert_called_once_with( - function_name="test_function", - payload=None, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_id, "parent123", None), - config=None, + function_name="test_function", + payload=None, + config=ANY, ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.invoke_handler") -def test_invoke_increments_counter(mock_handler): +@patch("aws_durable_execution_sdk_python.context.InvokeOperationExecutor") +def 
test_invoke_increments_counter(mock_executor_class): """Test invoke increments step counter.""" - mock_handler.return_value = "result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -617,18 +676,22 @@ def test_invoke_increments_counter(mock_handler): expected_id2 = next(seq) assert context._step_counter.get_current() == 12 # noqa: SLF001 - assert mock_handler.call_args_list[0][1][ + assert mock_executor_class.call_args_list[0][1][ "operation_identifier" ] == OperationIdentifier(expected_id1, None, None) - assert mock_handler.call_args_list[1][1][ + assert mock_executor_class.call_args_list[1][1][ "operation_identifier" ] == OperationIdentifier(expected_id2, None, None) -@patch("aws_durable_execution_sdk_python.context.invoke_handler") -def test_invoke_with_none_payload(mock_handler): +@patch("aws_durable_execution_sdk_python.context.InvokeOperationExecutor") +def test_invoke_with_none_payload(mock_executor_class): """Test invoke with None payload.""" - mock_handler.return_value = None + mock_executor = MagicMock() + + mock_executor.process.return_value = None + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -643,19 +706,24 @@ def test_invoke_with_none_payload(mock_handler): assert result is None - mock_handler.assert_called_once_with( - function_name="test_function", - payload=None, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_id, None, None), - config=None, + function_name="test_function", + payload=None, + config=ANY, ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.invoke_handler") -def 
test_invoke_with_custom_serdes(mock_handler): +@patch("aws_durable_execution_sdk_python.context.InvokeOperationExecutor") +def test_invoke_with_custom_serdes(mock_executor_class): """Test invoke with custom serialization config.""" - mock_handler.return_value = {"transformed": "data"} + mock_executor = MagicMock() + + mock_executor.process.return_value = {"transformed": "data"} + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -682,24 +750,29 @@ def test_invoke_with_custom_serdes(mock_handler): expected_id = next(seq) assert result == {"transformed": "data"} - mock_handler.assert_called_once_with( - function_name="test_function", - payload={"original": "data"}, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier( expected_id, None, "custom_serdes_invoke" ), + function_name="test_function", + payload={"original": "data"}, config=config, ) + mock_executor.process.assert_called_once() # endregion invoke # region wait -@patch("aws_durable_execution_sdk_python.context.wait_handler") -def test_wait_basic(mock_handler): +@patch("aws_durable_execution_sdk_python.context.WaitOperationExecutor") +def test_wait_basic(mock_executor_class): """Test wait with basic parameters.""" + mock_executor = MagicMock() + mock_executor.process.return_value = None + mock_executor_class.return_value = mock_executor + mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -711,16 +784,21 @@ def test_wait_basic(mock_handler): context.wait(Duration.from_seconds(30)) - mock_handler.assert_called_once_with( - seconds=30, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_operation_id, None, None), + seconds=30, ) + mock_executor.process.assert_called_once() 
-@patch("aws_durable_execution_sdk_python.context.wait_handler") -def test_wait_with_name(mock_handler): +@patch("aws_durable_execution_sdk_python.context.WaitOperationExecutor") +def test_wait_with_name(mock_executor_class): """Test wait with name parameter.""" + mock_executor = MagicMock() + mock_executor.process.return_value = None + mock_executor_class.return_value = mock_executor + mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -735,16 +813,21 @@ def test_wait_with_name(mock_handler): [next(seq) for _ in range(5)] expected_id = next(seq) - mock_handler.assert_called_once_with( - seconds=60, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_id, None, "test_wait"), + seconds=60, ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.wait_handler") -def test_wait_with_parent_id(mock_handler): +@patch("aws_durable_execution_sdk_python.context.WaitOperationExecutor") +def test_wait_with_parent_id(mock_executor_class): """Test wait with parent_id.""" + mock_executor = MagicMock() + mock_executor.process.return_value = None + mock_executor_class.return_value = mock_executor + mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -759,16 +842,21 @@ def test_wait_with_parent_id(mock_handler): [next(seq) for _ in range(2)] expected_id = next(seq) - mock_handler.assert_called_once_with( - seconds=45, + mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier(expected_id, "parent123"), + seconds=45, ) + mock_executor.process.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.wait_handler") -def test_wait_increments_counter(mock_handler): +@patch("aws_durable_execution_sdk_python.context.WaitOperationExecutor") +def 
test_wait_increments_counter(mock_executor_class): """Test wait increments step counter.""" + mock_executor = MagicMock() + mock_executor.process.return_value = None + mock_executor_class.return_value = mock_executor + mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -786,17 +874,21 @@ def test_wait_increments_counter(mock_handler): expected_id2 = next(seq) assert context._step_counter.get_current() == 12 # noqa: SLF001 - assert mock_handler.call_args_list[0][1][ + assert mock_executor_class.call_args_list[0][1][ "operation_identifier" ] == OperationIdentifier(expected_id1, None, None) - assert mock_handler.call_args_list[1][1][ + assert mock_executor_class.call_args_list[1][1][ "operation_identifier" ] == OperationIdentifier(expected_id2, None, None) -@patch("aws_durable_execution_sdk_python.context.wait_handler") -def test_wait_returns_none(mock_handler): +@patch("aws_durable_execution_sdk_python.context.WaitOperationExecutor") +def test_wait_returns_none(mock_executor_class): """Test wait returns None.""" + mock_executor = MagicMock() + mock_executor.process.return_value = None + mock_executor_class.return_value = mock_executor + mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -809,9 +901,13 @@ def test_wait_returns_none(mock_handler): assert result is None -@patch("aws_durable_execution_sdk_python.context.wait_handler") -def test_wait_with_time_less_than_one(mock_handler): +@patch("aws_durable_execution_sdk_python.context.WaitOperationExecutor") +def test_wait_with_time_less_than_one(mock_executor_class): """Test wait with time less than one.""" + mock_executor = MagicMock() + mock_executor.process.return_value = None + mock_executor_class.return_value = mock_executor + mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( 
"arn:aws:durable:us-east-1:123456789012:execution/test" @@ -889,9 +985,13 @@ def test_run_in_child_context_with_name_and_config(mock_handler): @patch("aws_durable_execution_sdk_python.context.child_handler") -def test_run_in_child_context_with_parent_id(mock_handler): +def test_run_in_child_context_with_parent_id(mock_executor_class): """Test run_in_child_context with parent_id.""" - mock_handler.return_value = "parent_child_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "parent_child_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -910,14 +1010,14 @@ def test_run_in_child_context_with_parent_id(mock_handler): [next(seq) for _ in range(1)] expected_id = next(seq) - call_args = mock_handler.call_args + call_args = mock_executor_class.call_args assert call_args[1]["operation_identifier"] == OperationIdentifier( expected_id, "parent456", None ) @patch("aws_durable_execution_sdk_python.context.child_handler") -def test_run_in_child_context_creates_child_context(mock_handler): +def test_run_in_child_context_creates_child_context(mock_executor_class): """Test run_in_child_context creates proper child context.""" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( @@ -935,7 +1035,7 @@ def capture_child_context(child_context): return "child_executed" mock_callable = Mock(side_effect=capture_child_context) - mock_handler.side_effect = lambda func, **kwargs: func() + mock_executor_class.side_effect = lambda func, **kwargs: func() context = DurableContext(state=mock_state) @@ -946,9 +1046,13 @@ def capture_child_context(child_context): @patch("aws_durable_execution_sdk_python.context.child_handler") -def test_run_in_child_context_increments_counter(mock_handler): +def test_run_in_child_context_increments_counter(mock_executor_class): """Test run_in_child_context increments 
step counter.""" - mock_handler.return_value = "result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -970,18 +1074,22 @@ def test_run_in_child_context_increments_counter(mock_handler): expected_id2 = next(seq) assert context._step_counter.get_current() == 7 # noqa: SLF001 - assert mock_handler.call_args_list[0][1][ + assert mock_executor_class.call_args_list[0][1][ "operation_identifier" ] == OperationIdentifier(expected_id1, None, None) - assert mock_handler.call_args_list[1][1][ + assert mock_executor_class.call_args_list[1][1][ "operation_identifier" ] == OperationIdentifier(expected_id2, None, None) @patch("aws_durable_execution_sdk_python.context.child_handler") -def test_run_in_child_context_resolves_name_from_callable(mock_handler): +def test_run_in_child_context_resolves_name_from_callable(mock_executor_class): """Test run_in_child_context resolves name from callable._original_name.""" - mock_handler.return_value = "named_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "named_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -993,7 +1101,7 @@ def test_run_in_child_context_resolves_name_from_callable(mock_handler): context.run_in_child_context(mock_callable) - call_args = mock_handler.call_args + call_args = mock_executor_class.call_args assert call_args[1]["operation_identifier"].name == "original_function_name" @@ -1002,9 +1110,13 @@ def test_run_in_child_context_resolves_name_from_callable(mock_handler): # region wait_for_callback @patch("aws_durable_execution_sdk_python.context.wait_for_callback_handler") -def test_wait_for_callback_basic(mock_handler): +def 
test_wait_for_callback_basic(mock_executor_class): """Test wait_for_callback with basic parameters.""" - mock_handler.return_value = "callback_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "callback_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -1029,9 +1141,13 @@ def test_wait_for_callback_basic(mock_handler): @patch("aws_durable_execution_sdk_python.context.wait_for_callback_handler") -def test_wait_for_callback_with_name_and_config(mock_handler): +def test_wait_for_callback_with_name_and_config(mock_executor_class): """Test wait_for_callback with name and config.""" - mock_handler.return_value = "configured_callback_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "configured_callback_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -1054,9 +1170,13 @@ def test_wait_for_callback_with_name_and_config(mock_handler): @patch("aws_durable_execution_sdk_python.context.wait_for_callback_handler") -def test_wait_for_callback_resolves_name_from_submitter(mock_handler): +def test_wait_for_callback_resolves_name_from_submitter(mock_executor_class): """Test wait_for_callback resolves name from submitter._original_name.""" - mock_handler.return_value = "named_callback_result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "named_callback_result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -1075,7 +1195,7 @@ def test_wait_for_callback_resolves_name_from_submitter(mock_handler): @patch("aws_durable_execution_sdk_python.context.wait_for_callback_handler") -def 
test_wait_for_callback_passes_child_context(mock_handler): +def test_wait_for_callback_passes_child_context(mock_executor_class): """Test wait_for_callback passes child context to handler.""" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( @@ -1088,7 +1208,7 @@ def capture_handler_call(context, submitter, name, config): assert submitter is mock_submitter return "handler_result" - mock_handler.side_effect = capture_handler_call + mock_executor_class.side_effect = capture_handler_call with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: @@ -1103,7 +1223,7 @@ def run_child_context(callable_func, name): result = context.wait_for_callback(mock_submitter) assert result == "handler_result" - mock_handler.assert_called_once() + mock_executor_class.assert_called_once() # endregion wait_for_callback @@ -1606,17 +1726,20 @@ def test_wait_strategy(state, attempt): wait_strategy=test_wait_strategy, initial_state="test" ) - # Mock the handler to track calls + # Mock the executor to track calls with patch( - "aws_durable_execution_sdk_python.context.wait_for_condition_handler" - ) as mock_handler: - mock_handler.return_value = "final_state" + "aws_durable_execution_sdk_python.context.WaitForConditionOperationExecutor" + ) as mock_executor_class: + mock_executor = MagicMock() + mock_executor.process.return_value = "final_state" + mock_executor_class.return_value = mock_executor # Call wait_for_condition method result = context.wait_for_condition(test_check, config) - # Verify wait_for_condition_handler was called (line 425) - mock_handler.assert_called_once() + # Verify executor was called + mock_executor_class.assert_called_once() + mock_executor.process.assert_called_once() assert result == "final_state" @@ -1683,10 +1806,14 @@ def test_operation_id_generation_unique(): assert ids[i] != ids[i + 1] -@patch("aws_durable_execution_sdk_python.context.invoke_handler") -def test_invoke_with_explicit_tenant_id(mock_handler): 
+@patch("aws_durable_execution_sdk_python.context.InvokeOperationExecutor") +def test_invoke_with_explicit_tenant_id(mock_executor_class): """Test invoke with explicit tenant_id in config.""" - mock_handler.return_value = "result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -1698,14 +1825,18 @@ def test_invoke_with_explicit_tenant_id(mock_handler): result = context.invoke("test_function", "payload", config=config) assert result == "result" - call_args = mock_handler.call_args[1] + call_args = mock_executor_class.call_args[1] assert call_args["config"].tenant_id == "explicit-tenant" -@patch("aws_durable_execution_sdk_python.context.invoke_handler") -def test_invoke_without_tenant_id_defaults_to_none(mock_handler): +@patch("aws_durable_execution_sdk_python.context.InvokeOperationExecutor") +def test_invoke_without_tenant_id_defaults_to_none(mock_executor_class): """Test invoke without tenant_id defaults to None.""" - mock_handler.return_value = "result" + mock_executor = MagicMock() + + mock_executor.process.return_value = "result" + + mock_executor_class.return_value = mock_executor mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" @@ -1716,6 +1847,7 @@ def test_invoke_without_tenant_id_defaults_to_none(mock_handler): result = context.invoke("test_function", "payload") assert result == "result" - # Config should be None when not provided - call_args = mock_handler.call_args[1] - assert call_args["config"] is None + # Config is created as InvokeConfig() when not provided + call_args = mock_executor_class.call_args[1] + assert isinstance(call_args["config"], InvokeConfig) + assert call_args["config"].tenant_id is None diff --git 
def create_mock_checkpoint_with_operations():
    """Build a fake ``checkpoint`` callable plus the list that records its calls.

    The returned callable keeps a running list of operations, seeded with a
    single STARTED ``execution-1`` EXECUTION operation. Each call stores the
    received ``updates`` batch, appends one STARTED ``Operation`` per update
    (whatever the update itself requested), and returns a ``CheckpointOutput``
    holding a copy of the accumulated operations.

    Returns:
        A ``(mock_checkpoint, checkpoint_calls)`` tuple, where
        ``checkpoint_calls`` holds one entry per checkpoint call.
    """
    recorded_batches = []
    known_operations = [
        Operation(
            operation_id="execution-1",
            operation_type=OperationType.EXECUTION,
            status=OperationStatus.STARTED,
        )
    ]

    def mock_checkpoint(
        durable_execution_arn,
        checkpoint_token,
        updates,
        client_token="token",  # noqa: S107
    ):
        recorded_batches.append(updates)

        # Mirror every update as a STARTED operation in the running list.
        known_operations.extend(
            Operation(
                operation_id=update.operation_id,
                operation_type=update.operation_type,
                status=OperationStatus.STARTED,
                parent_id=update.parent_id,
            )
            for update in updates
        )

        # Hand back a snapshot so later appends do not mutate earlier outputs.
        snapshot = CheckpointUpdatedExecutionState(
            operations=list(known_operations)
        )
        return CheckpointOutput(
            checkpoint_token="new_token",  # noqa: S106
            new_execution_state=snapshot,
        )

    return mock_checkpoint, recorded_batches
def test_end_to_end_multiple_operations_execute_sequentially():
    """Run a handler with two consecutive steps against a mocked Lambda client.

    Expects the invocation to succeed with both step results serialized in
    order, and expects four checkpoint updates in total across all batches
    (two START and two SUCCEED).
    """

    @durable_step
    def step1(step_context: StepContext) -> str:
        return "result1"

    @durable_step
    def step2(step_context: StepContext) -> str:
        return "result2"

    @durable_execution
    def my_handler(event, context: DurableContext) -> list[str]:
        first = context.step(step1())
        second = context.step(step2())
        return [first, second]

    # Fixtures that do not depend on the patched client are built up front.
    execution_operation = {
        "Id": "execution-1",
        "Type": "EXECUTION",
        "Status": "STARTED",
        "ExecutionDetails": {"InputPayload": "{}"},
    }
    event = {
        "DurableExecutionArn": "test-arn",
        "CheckpointToken": "test-token",
        "InitialExecutionState": {
            "Operations": [execution_operation],
            "NextMarker": "",
        },
        "LocalRunner": True,
    }

    lambda_context = Mock(
        aws_request_id="test-request-id",
        client_context=None,
        identity=None,
        invoked_function_arn="test-arn",
        tenant_id=None,
    )
    lambda_context._epoch_deadline_time_in_ms = 0  # noqa: SLF001

    with patch(
        "aws_durable_execution_sdk_python.execution.LambdaClient"
    ) as mock_client_class:
        mock_client = Mock()
        mock_client_class.initialize_client.return_value = mock_client

        mock_checkpoint, checkpoint_calls = create_mock_checkpoint_with_operations()
        mock_client.checkpoint = mock_checkpoint

        result = my_handler(event, lambda_context)

        assert result["Status"] == InvocationStatus.SUCCEEDED.value
        assert result["Result"] == '["result1", "result2"]'

        # Count updates across every recorded batch (2 START + 2 SUCCEED).
        total_updates = sum(len(batch) for batch in checkpoint_calls)
        assert total_updates == 4
+ """ + + @durable_execution + def my_handler(event, context: DurableContext) -> str: + context.wait(Duration.from_seconds(5)) + return "completed" + + with patch( + "aws_durable_execution_sdk_python.execution.LambdaClient" + ) as mock_client_class: + mock_client = Mock() + mock_client_class.initialize_client.return_value = mock_client + + mock_checkpoint, checkpoint_calls = create_mock_checkpoint_with_operations() + mock_client.checkpoint = mock_checkpoint + + event = { + "DurableExecutionArn": "test-arn", + "CheckpointToken": "test-token", + "InitialExecutionState": { + "Operations": [ + { + "Id": "execution-1", + "Type": "EXECUTION", + "Status": "STARTED", + "ExecutionDetails": {"InputPayload": "{}"}, + } + ], + "NextMarker": "", + }, + "LocalRunner": True, + } + + lambda_context = Mock() + lambda_context.aws_request_id = "test-request-id" + lambda_context.client_context = None + lambda_context.identity = None + lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 + lambda_context.invoked_function_arn = "test-arn" + lambda_context.tenant_id = None + + # Wait will suspend, so we expect PENDING status + result = my_handler(event, lambda_context) + + assert result["Status"] == InvocationStatus.PENDING.value + + # Verify wait checkpoint was created + all_operations = [op for batch in checkpoint_calls for op in batch] + assert len(all_operations) >= 1 + + +def test_end_to_end_checkpoint_synchronization_with_operations_list(): + """Test that synchronous checkpoints properly update operations list. + + Verifies that when is_sync=True, the operations list is updated + before the second status check occurs. 
def test_callback_deferred_error_handling_to_result():
    """Test that code after ``create_callback`` runs in the same invocation.

    The handler creates a callback and then runs a step. The mocked checkpoint
    answers every CALLBACK update with an operation that carries
    ``CallbackDetails`` whose id is ``cb-`` plus the first eight characters of
    the operation id. The test expects the invocation to finish SUCCEEDED and
    the result to contain both that callback id and the output of the step
    that follows the callback.
    """

    @durable_step
    def step_after_callback(step_context: StepContext) -> str:
        return "code_executed_after_callback"

    @durable_execution
    def my_handler(event, context: DurableContext) -> str:
        # create_callback returns a Callback object, not a bare id string.
        callback = context.create_callback("test_callback")

        # This code executes even if the callback will eventually fail.
        # This is the deferred error handling pattern.
        result = context.step(step_after_callback())

        return f"{callback.callback_id}:{result}"

    with patch(
        "aws_durable_execution_sdk_python.execution.LambdaClient"
    ) as mock_client_class:
        mock_client = Mock()
        mock_client_class.initialize_client.return_value = mock_client

        operations = [
            Operation(
                operation_id="execution-1",
                operation_type=OperationType.EXECUTION,
                status=OperationStatus.STARTED,
            )
        ]

        def mock_checkpoint(
            durable_execution_arn,
            checkpoint_token,
            updates,
            client_token="token",  # noqa: S107
        ):
            for update in updates:
                # Only CALLBACK operations get callback details attached.
                extra = {}
                if update.operation_type == OperationType.CALLBACK:
                    extra["callback_details"] = CallbackDetails(
                        callback_id=f"cb-{update.operation_id[:8]}"
                    )
                operations.append(
                    Operation(
                        operation_id=update.operation_id,
                        operation_type=update.operation_type,
                        status=OperationStatus.STARTED,
                        parent_id=update.parent_id,
                        **extra,
                    )
                )

            return CheckpointOutput(
                checkpoint_token="new_token",  # noqa: S106
                new_execution_state=CheckpointUpdatedExecutionState(
                    operations=operations.copy()
                ),
            )

        mock_client.checkpoint = mock_checkpoint

        event = {
            "DurableExecutionArn": "test-arn",
            "CheckpointToken": "test-token",
            "InitialExecutionState": {
                "Operations": [
                    {
                        "Id": "execution-1",
                        "Type": "EXECUTION",
                        "Status": "STARTED",
                        "ExecutionDetails": {"InputPayload": "{}"},
                    }
                ],
                "NextMarker": "",
            },
            "LocalRunner": True,
        }

        lambda_context = Mock()
        lambda_context.aws_request_id = "test-request-id"
        lambda_context.client_context = None
        lambda_context.identity = None
        lambda_context._epoch_deadline_time_in_ms = 0  # noqa: SLF001
        lambda_context.invoked_function_arn = "test-arn"
        lambda_context.tenant_id = None

        result = my_handler(event, lambda_context)

        # Verify execution succeeded and code after callback executed
        assert result["Status"] == InvocationStatus.SUCCEEDED.value
        assert "code_executed_after_callback" in result["Result"]
        # The id supplied by the mocked checkpoint must reach the handler.
        assert "cb-" in result["Result"]
my_handler(event, lambda_context) + + assert result["Status"] == InvocationStatus.PENDING.value + + # Verify invoke checkpoint was created + all_operations = [op for batch in checkpoint_calls for op in batch] + assert len(all_operations) >= 1 + + +def test_end_to_end_child_context_with_async_checkpoint(): + """Test end-to-end child context execution with async checkpoint. + + Verifies that child context operations use async checkpoint (is_sync=False) + and execute correctly without waiting for immediate response. + """ + + def child_function(ctx: DurableContext) -> str: + return "child_result" + + @durable_execution + def my_handler(event, context: DurableContext) -> str: + result: str = context.run_in_child_context(child_function) + return result + + with patch( + "aws_durable_execution_sdk_python.execution.LambdaClient" + ) as mock_client_class: + mock_client = Mock() + mock_client_class.initialize_client.return_value = mock_client + + mock_checkpoint, checkpoint_calls = create_mock_checkpoint_with_operations() + mock_client.checkpoint = mock_checkpoint + + event = { + "DurableExecutionArn": "test-arn", + "CheckpointToken": "test-token", + "InitialExecutionState": { + "Operations": [ + { + "Id": "execution-1", + "Type": "EXECUTION", + "Status": "STARTED", + "ExecutionDetails": {"InputPayload": "{}"}, + } + ], + "NextMarker": "", + }, + "LocalRunner": True, + } + + lambda_context = Mock() + lambda_context.aws_request_id = "test-request-id" + lambda_context.client_context = None + lambda_context.identity = None + lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 + lambda_context.invoked_function_arn = "test-arn" + lambda_context.tenant_id = None + + result = my_handler(event, lambda_context) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + assert result["Result"] == '"child_result"' + + # Verify checkpoints were created (START + SUCCEED) + all_operations = [op for batch in checkpoint_calls for op in batch] + assert len(all_operations) == 
def test_end_to_end_child_context_replay_children_mode():
    """Test that a large child-context result is checkpointed with replay_children.

    The child function returns a 1,310,720-character string
    (``"large" * 256 * 1024``) and the child context is configured with a
    summary generator. The test expects the invocation to succeed, the child
    function to have run exactly once, and exactly one SUCCEED update whose
    ``context_options.replay_children`` is True.

    Re-execution on a later replay is not exercised here, and the summary
    generator's output is not asserted.
    """
    execution_count = {"count": 0}

    def child_function_with_large_result(ctx: DurableContext) -> str:
        execution_count["count"] += 1
        return "large" * 256 * 1024

    def summary_generator(result: str) -> str:
        return f"summary_of_{len(result)}_bytes"

    @durable_execution
    def my_handler(event, context: DurableContext) -> str:
        context.run_in_child_context(
            child_function_with_large_result,
            config=ChildConfig(summary_generator=summary_generator),
        )
        return f"executed_{execution_count['count']}_times"

    with patch(
        "aws_durable_execution_sdk_python.execution.LambdaClient"
    ) as mock_client_class:
        mock_client = Mock()
        mock_client_class.initialize_client.return_value = mock_client

        # Reuse the module-level helper rather than re-declaring the same mock.
        mock_checkpoint, checkpoint_calls = create_mock_checkpoint_with_operations()
        mock_client.checkpoint = mock_checkpoint

        event = {
            "DurableExecutionArn": "test-arn",
            "CheckpointToken": "test-token",
            "InitialExecutionState": {
                "Operations": [
                    {
                        "Id": "execution-1",
                        "Type": "EXECUTION",
                        "Status": "STARTED",
                        "ExecutionDetails": {"InputPayload": "{}"},
                    }
                ],
                "NextMarker": "",
            },
            "LocalRunner": True,
        }

        lambda_context = Mock()
        lambda_context.aws_request_id = "test-request-id"
        lambda_context.client_context = None
        lambda_context.identity = None
        lambda_context._epoch_deadline_time_in_ms = 0  # noqa: SLF001
        lambda_context.invoked_function_arn = "test-arn"
        lambda_context.tenant_id = None

        result = my_handler(event, lambda_context)

        assert result["Status"] == InvocationStatus.SUCCEEDED.value
        # Function executed once during initial execution
        assert execution_count["count"] == 1

        # Verify replay_children was set in SUCCEED checkpoint
        all_operations = [op for batch in checkpoint_calls for op in batch]
        succeed_updates = [
            op
            for op in all_operations
            if hasattr(op, "action") and op.action.value == "SUCCEED"
        ]
        assert len(succeed_updates) == 1
        assert succeed_updates[0].context_options.replay_children is True
+ """ + + def child_function_that_fails(ctx: DurableContext) -> str: + msg = "Child function error" + raise ValueError(msg) + + @durable_execution + def my_handler(event, context: DurableContext) -> str: + result: str = context.run_in_child_context(child_function_that_fails) + return result + + with patch( + "aws_durable_execution_sdk_python.execution.LambdaClient" + ) as mock_client_class: + mock_client = Mock() + mock_client_class.initialize_client.return_value = mock_client + + mock_checkpoint, checkpoint_calls = create_mock_checkpoint_with_operations() + mock_client.checkpoint = mock_checkpoint + + event = { + "DurableExecutionArn": "test-arn", + "CheckpointToken": "test-token", + "InitialExecutionState": { + "Operations": [ + { + "Id": "execution-1", + "Type": "EXECUTION", + "Status": "STARTED", + "ExecutionDetails": {"InputPayload": "{}"}, + } + ], + "NextMarker": "", + }, + "LocalRunner": True, + } + + lambda_context = Mock() + lambda_context.aws_request_id = "test-request-id" + lambda_context.client_context = None + lambda_context.identity = None + lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 + lambda_context.invoked_function_arn = "test-arn" + lambda_context.tenant_id = None + + result = my_handler(event, lambda_context) + + # Verify execution failed + assert result["Status"] == InvocationStatus.FAILED.value + + # Verify FAIL checkpoint was created + all_operations = [op for batch in checkpoint_calls for op in batch] + fail_updates = [ + op + for op in all_operations + if hasattr(op, "action") and op.action.value == "FAIL" + ] + assert len(fail_updates) == 1 + + +def test_end_to_end_child_context_invocation_error_reraised(): + """Test end-to-end child context InvocationError re-raising. + + Verifies that child context that raises InvocationError creates FAIL checkpoint + and re-raises InvocationError (not wrapped) to enable retry at execution handler level. 
+ """ + + def child_function_with_invocation_error(ctx: DurableContext) -> str: + msg = "Invocation failed in child" + raise InvocationError(msg) + + @durable_execution + def my_handler(event, context: DurableContext) -> str: + result: str = context.run_in_child_context(child_function_with_invocation_error) + return result + + with patch( + "aws_durable_execution_sdk_python.execution.LambdaClient" + ) as mock_client_class: + mock_client = Mock() + mock_client_class.initialize_client.return_value = mock_client + + mock_checkpoint, checkpoint_calls = create_mock_checkpoint_with_operations() + mock_client.checkpoint = mock_checkpoint + + event = { + "DurableExecutionArn": "test-arn", + "CheckpointToken": "test-token", + "InitialExecutionState": { + "Operations": [ + { + "Id": "execution-1", + "Type": "EXECUTION", + "Status": "STARTED", + "ExecutionDetails": {"InputPayload": "{}"}, + } + ], + "NextMarker": "", + }, + "LocalRunner": True, + } + + lambda_context = Mock() + lambda_context.aws_request_id = "test-request-id" + lambda_context.client_context = None + lambda_context.identity = None + lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 + lambda_context.invoked_function_arn = "test-arn" + lambda_context.tenant_id = None + + # InvocationError should be re-raised (not wrapped) to trigger Lambda retry + with pytest.raises(InvocationError, match="Invocation failed in child"): + my_handler(event, lambda_context) + + # Verify FAIL checkpoint was created before re-raising + all_operations = [op for batch in checkpoint_calls for op in batch] + fail_updates = [ + op + for op in all_operations + if hasattr(op, "action") and op.action.value == "FAIL" + ] + assert len(fail_updates) == 1 diff --git a/tests/e2e/execution_int_test.py b/tests/e2e/execution_int_test.py index 286bfc9..5a884bf 100644 --- a/tests/e2e/execution_int_test.py +++ b/tests/e2e/execution_int_test.py @@ -34,6 +34,49 @@ from aws_durable_execution_sdk_python.types import StepContext +def 
create_mock_checkpoint_with_operations(): + """Create a mock checkpoint function that properly tracks operations. + + Returns a tuple of (mock_checkpoint_function, checkpoint_calls_list). + The mock properly maintains an operations list that gets updated with each checkpoint. + """ + checkpoint_calls = [] + operations = [ + Operation( + operation_id="execution-1", + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + ) + ] + + def mock_checkpoint( + durable_execution_arn, + checkpoint_token, + updates, + client_token="token", # noqa: S107 + ): + checkpoint_calls.append(updates) + + # Convert updates to Operation objects and add to operations list + for update in updates: + op = Operation( + operation_id=update.operation_id, + operation_type=update.operation_type, + status=OperationStatus.STARTED, # New operations start as STARTED + parent_id=update.parent_id, + ) + operations.append(op) + + return CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=operations.copy() + ), + ) + + return mock_checkpoint, checkpoint_calls + + def test_step_different_ways_to_pass_args(): def step_plain(step_context: StepContext) -> str: return "from step plain" @@ -259,22 +302,8 @@ def my_handler(event, context): mock_client = Mock() mock_client_class.initialize_client.return_value = mock_client - # Mock the checkpoint method to track calls - checkpoint_calls = [] - - def mock_checkpoint( - durable_execution_arn, - checkpoint_token, - updates, - client_token="token", # noqa: S107 - ): - checkpoint_calls.append(updates) - - return CheckpointOutput( - checkpoint_token="new_token", # noqa: S106 - new_execution_state=CheckpointUpdatedExecutionState(), - ) - + # Use helper to create mock that properly tracks operations + mock_checkpoint, checkpoint_calls = create_mock_checkpoint_with_operations() mock_client.checkpoint = mock_checkpoint # Create test event @@ -428,22 +457,8 @@ def 
my_handler(event: Any, context: DurableContext): mock_client = Mock() mock_client_class.initialize_client.return_value = mock_client - # Mock the checkpoint method to track calls - checkpoint_calls = [] - - def mock_checkpoint( - durable_execution_arn, - checkpoint_token, - updates, - client_token="token", # noqa: S107 - ): - checkpoint_calls.append(updates) - - return CheckpointOutput( - checkpoint_token="new_token", # noqa: S106 - new_execution_state=CheckpointUpdatedExecutionState(), - ) - + # Use helper to create mock that properly tracks operations + mock_checkpoint, checkpoint_calls = create_mock_checkpoint_with_operations() mock_client.checkpoint = mock_checkpoint # Create test event diff --git a/tests/operation/base_test.py b/tests/operation/base_test.py new file mode 100644 index 0000000..4b20818 --- /dev/null +++ b/tests/operation/base_test.py @@ -0,0 +1,314 @@ +"""Unit tests for OperationExecutor base framework.""" + +from __future__ import annotations + +import pytest + +from aws_durable_execution_sdk_python.exceptions import InvalidStateError +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationStatus, + OperationType, +) +from aws_durable_execution_sdk_python.operation.base import ( + CheckResult, + OperationExecutor, +) +from aws_durable_execution_sdk_python.state import CheckpointedResult + +# Test fixtures and helpers + + +class ConcreteOperationExecutor(OperationExecutor[str]): + """Concrete implementation for testing the abstract base class.""" + + def __init__(self): + self.check_result_status_called = 0 + self.execute_called = 0 + self.check_result_to_return = None + self.execute_result_to_return = "executed_result" + + def check_result_status(self) -> CheckResult[str]: + """Mock implementation that returns configured result.""" + self.check_result_status_called += 1 + if self.check_result_to_return is None: + msg = "check_result_to_return not configured" + raise ValueError(msg) + return 
self.check_result_to_return + + def execute(self, checkpointed_result: CheckpointedResult) -> str: + """Mock implementation that returns configured result.""" + self.execute_called += 1 + return self.execute_result_to_return + + +def create_mock_checkpoint(status: OperationStatus) -> CheckpointedResult: + """Create a mock CheckpointedResult with the given status.""" + operation = Operation( + operation_id="test_op", + operation_type=OperationType.STEP, + status=status, + ) + return CheckpointedResult.create_from_operation(operation) + + +# Tests for CheckResult factory methods + + +def test_check_result_create_is_ready_to_execute(): + """Test CheckResult.create_is_ready_to_execute factory method.""" + checkpoint = create_mock_checkpoint(OperationStatus.STARTED) + + result = CheckResult.create_is_ready_to_execute(checkpoint) + + assert result.is_ready_to_execute is True + assert result.has_checkpointed_result is False + assert result.checkpointed_result is checkpoint + assert result.deserialized_result is None + + +def test_check_result_create_started(): + """Test CheckResult.create_started factory method.""" + result = CheckResult.create_started() + + assert result.is_ready_to_execute is False + assert result.has_checkpointed_result is False + assert result.checkpointed_result is None + assert result.deserialized_result is None + + +def test_check_result_create_completed(): + """Test CheckResult.create_completed factory method.""" + test_result = "test_completed_result" + + result = CheckResult.create_completed(test_result) + + assert result.is_ready_to_execute is False + assert result.has_checkpointed_result is True + assert result.checkpointed_result is None + assert result.deserialized_result == test_result + + +def test_check_result_create_completed_with_none(): + """Test CheckResult.create_completed with None result (valid for operations that return None).""" + result = CheckResult.create_completed(None) + + assert result.is_ready_to_execute is False + assert 
result.has_checkpointed_result is True + assert result.checkpointed_result is None + assert result.deserialized_result is None + + +# Tests for OperationExecutor.process() method + + +def test_process_with_terminal_result_on_first_check(): + """Test process() when check_result_status returns terminal result on first call.""" + executor = ConcreteOperationExecutor() + executor.check_result_to_return = CheckResult.create_completed("terminal_result") + + result = executor.process() + + assert result == "terminal_result" + assert executor.check_result_status_called == 1 + assert executor.execute_called == 0 + + +def test_process_with_ready_to_execute_on_first_check(): + """Test process() when check_result_status returns ready_to_execute on first call.""" + executor = ConcreteOperationExecutor() + checkpoint = create_mock_checkpoint(OperationStatus.STARTED) + executor.check_result_to_return = CheckResult.create_is_ready_to_execute(checkpoint) + executor.execute_result_to_return = "execution_result" + + result = executor.process() + + assert result == "execution_result" + assert executor.check_result_status_called == 1 + assert executor.execute_called == 1 + + +def test_process_with_checkpoint_created_then_terminal(): + """Test process() when checkpoint is created, then terminal result on second check.""" + executor = ConcreteOperationExecutor() + + # First call returns create_started (checkpoint was created) + # Second call returns terminal result (immediate response) + call_count = 0 + + def check_result_side_effect(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return CheckResult.create_started() + return CheckResult.create_completed("immediate_response") + + executor.check_result_status = check_result_side_effect + + result = executor.process() + + assert result == "immediate_response" + assert call_count == 2 + assert executor.execute_called == 0 + + +def test_process_with_checkpoint_created_then_ready_to_execute(): + """Test process() when 
checkpoint is created, then ready_to_execute on second check.""" + executor = ConcreteOperationExecutor() + checkpoint = create_mock_checkpoint(OperationStatus.STARTED) + + # First call returns create_started (checkpoint was created) + # Second call returns ready_to_execute (no immediate response, proceed to execute) + call_count = 0 + + def check_result_side_effect(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return CheckResult.create_started() + return CheckResult.create_is_ready_to_execute(checkpoint) + + executor.check_result_status = check_result_side_effect + executor.execute_result_to_return = "execution_result" + + result = executor.process() + + assert result == "execution_result" + assert call_count == 2 + assert executor.execute_called == 1 + + +def test_process_with_none_result_terminal(): + """Test process() with terminal result that is None (valid for operations returning None).""" + executor = ConcreteOperationExecutor() + executor.check_result_to_return = CheckResult.create_completed(None) + + result = executor.process() + + assert result is None + assert executor.check_result_status_called == 1 + assert executor.execute_called == 0 + + +def test_process_raises_invalid_state_when_checkpointed_result_missing(): + """Test process() raises InvalidStateError when ready_to_execute but checkpoint is None.""" + executor = ConcreteOperationExecutor() + # Create invalid state: ready_to_execute but no checkpoint + executor.check_result_to_return = CheckResult( + is_ready_to_execute=True, + has_checkpointed_result=False, + checkpointed_result=None, + ) + + with pytest.raises(InvalidStateError) as exc_info: + executor.process() + + assert "checkpointed result is not set" in str(exc_info.value) + + +def test_process_raises_invalid_state_when_neither_terminal_nor_ready(): + """Test process() raises InvalidStateError when result is neither terminal nor ready.""" + executor = ConcreteOperationExecutor() + # Create invalid state: neither 
terminal nor ready (both False) + executor.check_result_to_return = CheckResult( + is_ready_to_execute=False, + has_checkpointed_result=False, + ) + + # Mock to return same invalid state on both calls + call_count = 0 + + def check_result_side_effect(): + nonlocal call_count + call_count += 1 + return CheckResult( + is_ready_to_execute=False, + has_checkpointed_result=False, + ) + + executor.check_result_status = check_result_side_effect + + with pytest.raises(InvalidStateError) as exc_info: + executor.process() + + assert "neither terminal nor ready to execute" in str(exc_info.value) + assert call_count == 2 # Should call twice before raising + + +def test_process_double_check_pattern(): + """Test that process() implements the double-check pattern correctly. + + This verifies the core immediate response handling logic: + 1. Check status once (may find existing checkpoint or create new one) + 2. If checkpoint was just created, check again (catches immediate response) + 3. Only call execute() if ready after both checks + """ + executor = ConcreteOperationExecutor() + checkpoint = create_mock_checkpoint(OperationStatus.STARTED) + + check_calls = [] + + def track_check_calls(): + call_num = len(check_calls) + 1 + check_calls.append(call_num) + + if call_num == 1: + # First check: checkpoint doesn't exist, create it + return CheckResult.create_started() + # Second check: checkpoint exists, ready to execute + return CheckResult.create_is_ready_to_execute(checkpoint) + + executor.check_result_status = track_check_calls + executor.execute_result_to_return = "final_result" + + result = executor.process() + + # Verify the double-check pattern + assert len(check_calls) == 2, "Should check status exactly twice" + assert check_calls == [1, 2], "Checks should be in order" + assert executor.execute_called == 1, "Should execute once after both checks" + assert result == "final_result" + + +def test_process_single_check_when_terminal_immediately(): + """Test that process() only 
checks once when terminal result is found immediately.""" + executor = ConcreteOperationExecutor() + + check_calls = [] + + def track_check_calls(): + call_num = len(check_calls) + 1 + check_calls.append(call_num) + return CheckResult.create_completed("immediate_terminal") + + executor.check_result_status = track_check_calls + + result = executor.process() + + # Should only check once since terminal result was found + assert len(check_calls) == 1, "Should check status only once for immediate terminal" + assert executor.execute_called == 0, "Should not execute when terminal result found" + assert result == "immediate_terminal" + + +def test_process_single_check_when_ready_immediately(): + """Test that process() only checks once when ready_to_execute is found immediately.""" + executor = ConcreteOperationExecutor() + checkpoint = create_mock_checkpoint(OperationStatus.STARTED) + + check_calls = [] + + def track_check_calls(): + call_num = len(check_calls) + 1 + check_calls.append(call_num) + return CheckResult.create_is_ready_to_execute(checkpoint) + + executor.check_result_status = track_check_calls + executor.execute_result_to_return = "execution_result" + + result = executor.process() + + # Should only check once since ready_to_execute was found + assert len(check_calls) == 1, "Should check status only once when ready immediately" + assert executor.execute_called == 1, "Should execute once" + assert result == "execution_result" diff --git a/tests/operation/callback_test.py b/tests/operation/callback_test.py index 688704e..334e276 100644 --- a/tests/operation/callback_test.py +++ b/tests/operation/callback_test.py @@ -11,11 +11,13 @@ StepConfig, WaitForCallbackConfig, ) +from aws_durable_execution_sdk_python.context import Callback from aws_durable_execution_sdk_python.exceptions import CallbackError, ValidationError from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( 
CallbackDetails, CallbackOptions, + ErrorObject, Operation, OperationAction, OperationStatus, @@ -24,7 +26,7 @@ OperationUpdate, ) from aws_durable_execution_sdk_python.operation.callback import ( - create_callback_handler, + CallbackOperationExecutor, wait_for_callback_handler, ) from aws_durable_execution_sdk_python.retries import RetryDecision @@ -33,6 +35,17 @@ from aws_durable_execution_sdk_python.types import DurableContext, StepContext +# Test helper - maintains old handler signature for backward compatibility in tests +def create_callback_handler(state, operation_identifier, config=None): + """Test helper that wraps CallbackOperationExecutor with old handler signature.""" + executor = CallbackOperationExecutor( + state=state, + operation_identifier=operation_identifier, + config=config, + ) + return executor.process() + + # region create_callback_handler def test_create_callback_handler_new_operation_with_config(): """Test create_callback_handler creates new checkpoint when operation doesn't exist.""" @@ -142,23 +155,27 @@ def test_create_callback_handler_existing_started_operation(): def test_create_callback_handler_existing_failed_operation(): - """Test create_callback_handler raises error for failed operation.""" + """Test create_callback_handler returns callback_id for failed operation (deferred error).""" + # CRITICAL: create_callback_handler should NOT raise on FAILED + # Errors are deferred to Callback.result() for deterministic replay mock_state = Mock(spec=ExecutionState) - mock_result = Mock(spec=CheckpointedResult) - mock_result.is_failed.return_value = True - mock_result.is_started.return_value = False - msg = "Checkpointed error" - mock_result.raise_callable_error.side_effect = Exception(msg) + failed_op = Operation( + operation_id="callback4", + operation_type=OperationType.CALLBACK, + status=OperationStatus.FAILED, + callback_details=CallbackDetails(callback_id="failed_cb4"), + ) + mock_result = 
CheckpointedResult.create_from_operation(failed_op) mock_state.get_checkpoint_result.return_value = mock_result - with pytest.raises(Exception, match="Checkpointed error"): - create_callback_handler( - state=mock_state, - operation_identifier=OperationIdentifier("callback4", None), - config=None, - ) + # Should return callback_id without raising + callback_id = create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("callback4", None), + config=None, + ) - mock_result.raise_callable_error.assert_called_once() + assert callback_id == "failed_cb4" mock_state.create_checkpoint.assert_not_called() @@ -876,19 +893,25 @@ def test_callback_timeout_configuration(): def test_callback_error_propagation(): """Test error propagation through callback operations.""" + # CRITICAL: create_callback_handler should NOT raise on FAILED + # Errors are deferred to Callback.result() for deterministic replay mock_state = Mock(spec=ExecutionState) - mock_result = Mock(spec=CheckpointedResult) - mock_result.is_failed.return_value = True - msg = "Callback creation failed" - mock_result.raise_callable_error.side_effect = RuntimeError(msg) + failed_op = Operation( + operation_id="error_callback", + operation_type=OperationType.CALLBACK, + status=OperationStatus.FAILED, + callback_details=CallbackDetails(callback_id="failed_cb"), + ) + mock_result = CheckpointedResult.create_from_operation(failed_op) mock_state.get_checkpoint_result.return_value = mock_result - with pytest.raises(RuntimeError, match="Callback creation failed"): - create_callback_handler( - state=mock_state, - operation_identifier=OperationIdentifier("error_callback", None), - config=None, - ) + # Should return callback_id without raising + callback_id = create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("error_callback", None), + config=None, + ) + assert callback_id == "failed_cb" mock_context = Mock(spec=DurableContext) 
mock_context.create_callback.side_effect = ValueError("Context creation failed") @@ -1040,3 +1063,471 @@ def test_callback_operation_update_creation(mock_operation_update): # endregion wait_for_callback_handler + + +# region immediate response handling tests +def test_callback_immediate_response_get_checkpoint_result_called_twice(): + """Test that get_checkpoint_result is called twice when checkpoint is created.""" + mock_state = Mock(spec=ExecutionState) + + # First call: not found, second call: started (no immediate response) + not_found = CheckpointedResult.create_not_found() + callback_details = CallbackDetails(callback_id="cb_immediate_1") + started_op = Operation( + operation_id="callback_immediate_1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=callback_details, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + result = create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("callback_immediate_1", None), + config=None, + ) + + # Verify callback_id was returned + assert result == "cb_immediate_1" + # Verify get_checkpoint_result was called twice + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_callback_immediate_response_create_checkpoint_with_is_sync_true(): + """Test that create_checkpoint is called with is_sync=True.""" + mock_state = Mock(spec=ExecutionState) + + # First call: not found, second call: started + not_found = CheckpointedResult.create_not_found() + callback_details = CallbackDetails(callback_id="cb_immediate_2") + started_op = Operation( + operation_id="callback_immediate_2", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=callback_details, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + result = 
create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("callback_immediate_2", None), + config=None, + ) + + # Verify callback_id was returned + assert result == "cb_immediate_2" + # Verify create_checkpoint was called with is_sync=True (default) + mock_state.create_checkpoint.assert_called_once() + # is_sync=True is the default, so it won't be in kwargs if not explicitly passed + # We just verify the checkpoint was created + + +def test_callback_immediate_response_immediate_success(): + """Test immediate success: checkpoint returns SUCCEEDED on second check. + + When checkpoint returns SUCCEEDED on second check, operation returns callback_id + without raising. + """ + mock_state = Mock(spec=ExecutionState) + + # First call: not found, second call: succeeded (immediate response) + not_found = CheckpointedResult.create_not_found() + callback_details = CallbackDetails(callback_id="cb_immediate_success") + succeeded_op = Operation( + operation_id="callback_immediate_3", + operation_type=OperationType.CALLBACK, + status=OperationStatus.SUCCEEDED, + callback_details=callback_details, + ) + succeeded = CheckpointedResult.create_from_operation(succeeded_op) + mock_state.get_checkpoint_result.side_effect = [not_found, succeeded] + + result = create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("callback_immediate_3", None), + config=None, + ) + + # Verify callback_id was returned without raising + assert result == "cb_immediate_success" + # Verify checkpoint was created + mock_state.create_checkpoint.assert_called_once() + # Verify get_checkpoint_result was called twice + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_callback_immediate_response_immediate_failure_deferred(): + """Test immediate failure deferred: checkpoint returns FAILED on second check. + + CRITICAL: When checkpoint returns FAILED on second check, create_callback() + returns callback_id (does NOT raise). 
Errors are deferred to Callback.result() + for deterministic replay. + """ + mock_state = Mock(spec=ExecutionState) + + # First call: not found, second call: failed (immediate response) + not_found = CheckpointedResult.create_not_found() + callback_details = CallbackDetails(callback_id="cb_immediate_failed") + failed_op = Operation( + operation_id="callback_immediate_4", + operation_type=OperationType.CALLBACK, + status=OperationStatus.FAILED, + callback_details=callback_details, + ) + failed = CheckpointedResult.create_from_operation(failed_op) + mock_state.get_checkpoint_result.side_effect = [not_found, failed] + + # CRITICAL: Should return callback_id without raising + result = create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("callback_immediate_4", None), + config=None, + ) + + # Verify callback_id was returned (error deferred) + assert result == "cb_immediate_failed" + # Verify checkpoint was created + mock_state.create_checkpoint.assert_called_once() + # Verify get_checkpoint_result was called twice + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_callback_result_raises_error_for_failed_callbacks(): + """Test that Callback.result() raises error for FAILED callbacks (deferred error handling). + + This test verifies that errors are properly deferred to Callback.result() rather + than being raised during create_callback(). This ensures deterministic replay: + code between create_callback() and callback.result() always executes. 
+ """ + + mock_state = Mock(spec=ExecutionState) + + # Create a FAILED callback operation + error = ErrorObject( + message="Callback failed", type="CallbackError", data=None, stack_trace=None + ) + callback_details = CallbackDetails( + callback_id="cb_failed_result", result=None, error=error + ) + failed_op = Operation( + operation_id="callback_failed_result", + operation_type=OperationType.CALLBACK, + status=OperationStatus.FAILED, + callback_details=callback_details, + ) + failed_result = CheckpointedResult.create_from_operation(failed_op) + mock_state.get_checkpoint_result.return_value = failed_result + + # Create Callback instance + callback = Callback( + callback_id="cb_failed_result", + operation_id="callback_failed_result", + state=mock_state, + serdes=None, + ) + + # Verify that result() raises CallbackError + with pytest.raises(CallbackError, match="Callback failed"): + callback.result() + + +def test_callback_result_raises_error_for_timed_out_callbacks(): + """Test that Callback.result() raises error for TIMED_OUT callbacks.""" + + mock_state = Mock(spec=ExecutionState) + + # Create a TIMED_OUT callback operation + error = ErrorObject( + message="Callback timed out", + type="CallbackTimeoutError", + data=None, + stack_trace=None, + ) + callback_details = CallbackDetails( + callback_id="cb_timed_out_result", result=None, error=error + ) + timed_out_op = Operation( + operation_id="callback_timed_out_result", + operation_type=OperationType.CALLBACK, + status=OperationStatus.TIMED_OUT, + callback_details=callback_details, + ) + timed_out_result = CheckpointedResult.create_from_operation(timed_out_op) + mock_state.get_checkpoint_result.return_value = timed_out_result + + # Create Callback instance + callback = Callback( + callback_id="cb_timed_out_result", + operation_id="callback_timed_out_result", + state=mock_state, + serdes=None, + ) + + # Verify that result() raises CallbackError + with pytest.raises(CallbackError, match="Callback timed out"): + 
callback.result() + + +def test_callback_immediate_response_no_immediate_response(): + """Test no immediate response: checkpoint returns STARTED on second check. + + When checkpoint returns STARTED on second check, operation returns callback_id + normally (callbacks don't suspend). + """ + mock_state = Mock(spec=ExecutionState) + + # First call: not found, second call: started (no immediate response) + not_found = CheckpointedResult.create_not_found() + callback_details = CallbackDetails(callback_id="cb_immediate_started") + started_op = Operation( + operation_id="callback_immediate_5", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=callback_details, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + result = create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("callback_immediate_5", None), + config=None, + ) + + # Verify callback_id was returned + assert result == "cb_immediate_started" + # Verify checkpoint was created + mock_state.create_checkpoint.assert_called_once() + # Verify get_checkpoint_result was called twice + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_callback_immediate_response_already_completed(): + """Test already completed: checkpoint exists on first check. + + When checkpoint is already SUCCEEDED on first check, no checkpoint is created + and callback_id is returned immediately. 
+ """ + mock_state = Mock(spec=ExecutionState) + + # First call: already succeeded + callback_details = CallbackDetails(callback_id="cb_already_completed") + succeeded_op = Operation( + operation_id="callback_immediate_6", + operation_type=OperationType.CALLBACK, + status=OperationStatus.SUCCEEDED, + callback_details=callback_details, + ) + succeeded = CheckpointedResult.create_from_operation(succeeded_op) + mock_state.get_checkpoint_result.return_value = succeeded + + result = create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("callback_immediate_6", None), + config=None, + ) + + # Verify callback_id was returned + assert result == "cb_already_completed" + # Verify no checkpoint was created (already exists) + mock_state.create_checkpoint.assert_not_called() + # Verify get_checkpoint_result was called only once + assert mock_state.get_checkpoint_result.call_count == 1 + + +def test_callback_immediate_response_already_failed(): + """Test already failed: checkpoint is already FAILED on first check. + + When checkpoint is already FAILED on first check, no checkpoint is created + and callback_id is returned (error deferred to Callback.result()). 
+ """ + mock_state = Mock(spec=ExecutionState) + + # First call: already failed + callback_details = CallbackDetails(callback_id="cb_already_failed") + failed_op = Operation( + operation_id="callback_immediate_7", + operation_type=OperationType.CALLBACK, + status=OperationStatus.FAILED, + callback_details=callback_details, + ) + failed = CheckpointedResult.create_from_operation(failed_op) + mock_state.get_checkpoint_result.return_value = failed + + # Should return callback_id without raising + result = create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("callback_immediate_7", None), + config=None, + ) + + # Verify callback_id was returned (error deferred) + assert result == "cb_already_failed" + # Verify no checkpoint was created (already exists) + mock_state.create_checkpoint.assert_not_called() + # Verify get_checkpoint_result was called only once + assert mock_state.get_checkpoint_result.call_count == 1 + + +def test_callback_deferred_error_handling_code_execution_between_create_and_result(): + """Test callback deferred error handling with code execution between create_callback() and callback.result(). + + This test verifies that code between create_callback() and callback.result() executes + even when the callback is FAILED. This ensures deterministic replay. 
+ """ + + mock_state = Mock(spec=ExecutionState) + + # Setup: callback is already FAILED + error = ErrorObject( + message="Callback failed", type="CallbackError", data=None, stack_trace=None + ) + callback_details = CallbackDetails( + callback_id="cb_deferred_error", result=None, error=error + ) + failed_op = Operation( + operation_id="callback_deferred_error", + operation_type=OperationType.CALLBACK, + status=OperationStatus.FAILED, + callback_details=callback_details, + ) + failed_result = CheckpointedResult.create_from_operation(failed_op) + mock_state.get_checkpoint_result.return_value = failed_result + + # Step 1: create_callback() returns callback_id without raising + callback_id = create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("callback_deferred_error", None), + config=None, + ) + assert callback_id == "cb_deferred_error" + + # Step 2: Code executes between create_callback() and callback.result() + execution_log = [ + "code_executed_after_create_callback", + f"callback_id: {callback_id}", + ] + + # Step 3: Callback.result() raises the error + callback = Callback( + callback_id=callback_id, + operation_id="callback_deferred_error", + state=mock_state, + serdes=None, + ) + + with pytest.raises(CallbackError, match="Callback failed"): + callback.result() + + # Verify code between create_callback() and callback.result() executed + assert execution_log == [ + "code_executed_after_create_callback", + "callback_id: cb_deferred_error", + ] + + +def test_callback_immediate_response_with_config(): + """Test immediate response with callback configuration.""" + mock_state = Mock(spec=ExecutionState) + + # First call: not found, second call: succeeded + not_found = CheckpointedResult.create_not_found() + callback_details = CallbackDetails(callback_id="cb_with_config") + succeeded_op = Operation( + operation_id="callback_with_config", + operation_type=OperationType.CALLBACK, + status=OperationStatus.SUCCEEDED, + 
callback_details=callback_details, + ) + succeeded = CheckpointedResult.create_from_operation(succeeded_op) + mock_state.get_checkpoint_result.side_effect = [not_found, succeeded] + + config = CallbackConfig( + timeout=Duration.from_minutes(5), heartbeat_timeout=Duration.from_minutes(1) + ) + + result = create_callback_handler( + state=mock_state, + operation_identifier=OperationIdentifier("callback_with_config", None), + config=config, + ) + + # Verify callback_id was returned + assert result == "cb_with_config" + # Verify checkpoint was created with config + mock_state.create_checkpoint.assert_called_once() + call_args = mock_state.create_checkpoint.call_args[1] + operation_update = call_args["operation_update"] + assert operation_update.callback_options.timeout_seconds == 300 + assert operation_update.callback_options.heartbeat_timeout_seconds == 60 + + +# endregion immediate response handling tests + + +def test_callback_returns_id_when_second_check_returns_started(): + """Test when the second checkpoint check returns + STARTED (not terminal), the callback operation returns callback_id normally. 
+ """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: checkpoint doesn't exist + # Second call: checkpoint returns STARTED (no immediate response) + mock_state.get_checkpoint_result.side_effect = [ + CheckpointedResult.create_not_found(), + CheckpointedResult.create_from_operation( + Operation( + operation_id="callback-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=CallbackDetails(callback_id="cb-123"), + ) + ), + ] + + executor = CallbackOperationExecutor( + state=mock_state, + operation_identifier=OperationIdentifier("callback-1", None, "test_callback"), + config=CallbackConfig(), + ) + callback_id = executor.process() + + # Assert - behaves like "old way" + assert callback_id == "cb-123" + assert mock_state.get_checkpoint_result.call_count == 2 # Double-check happened + mock_state.create_checkpoint.assert_called_once() # START checkpoint created + + +def test_callback_returns_id_when_second_check_returns_started_duplicate(): + """Test when the second checkpoint check returns + STARTED (not terminal), the callback operation returns callback_id normally. 
+ """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: checkpoint doesn't exist + # Second call: checkpoint returns STARTED (no immediate response) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="callback-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=CallbackDetails(callback_id="cb-123"), + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + executor = CallbackOperationExecutor( + state=mock_state, + operation_identifier=OperationIdentifier("callback-1", None, "test_callback"), + config=CallbackConfig(), + ) + callback_id = executor.process() + + # Assert - behaves like "old way" + assert callback_id == "cb-123" + assert mock_state.get_checkpoint_result.call_count == 2 # Double-check happened + mock_state.create_checkpoint.assert_called_once() # START checkpoint created diff --git a/tests/operation/child_test.py b/tests/operation/child_test.py index e888ebb..ae1bb3a 100644 --- a/tests/operation/child_test.py +++ b/tests/operation/child_test.py @@ -1,5 +1,7 @@ """Unit tests for child handler.""" +from __future__ import annotations + import json from typing import cast from unittest.mock import Mock @@ -7,7 +9,10 @@ import pytest from aws_durable_execution_sdk_python.config import ChildConfig -from aws_durable_execution_sdk_python.exceptions import CallableRuntimeError +from aws_durable_execution_sdk_python.exceptions import ( + CallableRuntimeError, + InvocationError, +) from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( ErrorObject, @@ -34,9 +39,15 @@ ], ) def test_child_handler_not_started( - config: ChildConfig, expected_sub_type: OperationSubType + config: ChildConfig | None, expected_sub_type: OperationSubType ): - """Test 
child_handler when operation not started.""" + """Test child_handler when operation not started. + + Verifies: + - get_checkpoint_result is called once (async checkpoint, no second check) + - create_checkpoint is called with is_sync=False for START + - Operation executes and creates SUCCEED checkpoint + """ mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -44,7 +55,6 @@ def test_child_handler_not_started( mock_result.is_failed.return_value = False mock_result.is_started.return_value = False mock_result.is_replay_children.return_value = False - mock_result.is_replay_children.return_value = False mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(return_value="fresh_result") @@ -54,10 +64,15 @@ def test_child_handler_not_started( ) assert result == "fresh_result" + + # Verify get_checkpoint_result called once (async checkpoint, no second check) + assert mock_state.get_checkpoint_result.call_count == 1 + + # Verify create_checkpoint called twice (start and succeed) mock_state.create_checkpoint.assert_called() - assert mock_state.create_checkpoint.call_count == 2 # start and succeed + assert mock_state.create_checkpoint.call_count == 2 - # Verify start checkpoint + # Verify start checkpoint with is_sync=False start_call = mock_state.create_checkpoint.call_args_list[0] start_operation = start_call[1]["operation_update"] assert start_operation.operation_id == "op1" @@ -65,6 +80,8 @@ def test_child_handler_not_started( assert start_operation.operation_type is OperationType.CONTEXT assert start_operation.sub_type is expected_sub_type assert start_operation.action is OperationAction.START + # CRITICAL: Verify is_sync=False for START checkpoint (async, no immediate response) + assert start_call[1]["is_sync"] is False # Verify success checkpoint success_call = mock_state.create_checkpoint.call_args_list[1] @@ -80,7 +97,13 @@ def 
test_child_handler_not_started( def test_child_handler_already_succeeded(): - """Test child_handler when operation already succeeded.""" + """Test child_handler when operation already succeeded without replay_children. + + Verifies: + - Returns cached result without executing function + - No checkpoint created + - get_checkpoint_result called once + """ mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -95,8 +118,12 @@ def test_child_handler_already_succeeded(): ) assert result == "cached_result" + # Verify function not executed mock_callable.assert_not_called() + # Verify no checkpoint created mock_state.create_checkpoint.assert_not_called() + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 def test_child_handler_already_succeeded_none_result(): @@ -119,7 +146,13 @@ def test_child_handler_already_succeeded_none_result(): def test_child_handler_already_failed(): - """Test child_handler when operation already failed.""" + """Test child_handler when operation already failed. + + Verifies: + - Already failed: raises error without executing function + - No checkpoint created + - get_checkpoint_result called once + """ mock_state = Mock(spec=ExecutionState) mock_result = Mock() mock_result.is_succeeded.return_value = False @@ -138,7 +171,10 @@ def test_child_handler_already_failed(): None, ) + # Verify function not executed mock_callable.assert_not_called() + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 @pytest.mark.parametrize( @@ -153,9 +189,15 @@ def test_child_handler_already_failed(): ], ) def test_child_handler_already_started( - config: ChildConfig, expected_sub_type: OperationSubType + config: ChildConfig | None, expected_sub_type: OperationSubType ): - """Test child_handler when operation already started.""" + """Test child_handler when operation already started. 
+ + Verifies: + - Operation executes when already started + - Only SUCCEED checkpoint created (no START) + - get_checkpoint_result called once + """ mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -172,7 +214,11 @@ def test_child_handler_already_started( assert result == "started_result" - # Verify success checkpoint + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 + + # Verify only success checkpoint (no START since already started) + assert mock_state.create_checkpoint.call_count == 1 success_call = mock_state.create_checkpoint.call_args_list[0] success_operation = success_call[1]["operation_update"] assert success_operation.operation_id == "op5" @@ -197,9 +243,15 @@ def test_child_handler_already_started( ], ) def test_child_handler_callable_exception( - config: ChildConfig, expected_sub_type: OperationSubType + config: ChildConfig | None, expected_sub_type: OperationSubType ): - """Test child_handler when callable raises exception.""" + """Test child_handler when callable raises exception. 
+ + Verifies: + - Error handling: checkpoints FAIL and raises wrapped error + - get_checkpoint_result called once + - create_checkpoint called with is_sync=False for START + """ mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -218,10 +270,14 @@ def test_child_handler_callable_exception( config, ) + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 + + # Verify create_checkpoint called twice (start and fail) mock_state.create_checkpoint.assert_called() - assert mock_state.create_checkpoint.call_count == 2 # start and fail + assert mock_state.create_checkpoint.call_count == 2 - # Verify start checkpoint + # Verify start checkpoint with is_sync=False start_call = mock_state.create_checkpoint.call_args_list[0] start_operation = start_call[1]["operation_update"] assert start_operation.operation_id == "op6" @@ -229,6 +285,7 @@ def test_child_handler_callable_exception( assert start_operation.operation_type is OperationType.CONTEXT assert start_operation.sub_type is expected_sub_type assert start_operation.action is OperationAction.START + assert start_call[1]["is_sync"] is False # Verify fail checkpoint fail_call = mock_state.create_checkpoint.call_args_list[1] @@ -242,13 +299,19 @@ def test_child_handler_callable_exception( def test_child_handler_error_wrapped(): - """Test child_handler wraps regular errors as CallableRuntimeError.""" + """Test child_handler wraps regular errors as CallableRuntimeError. 
+ + Verifies: + - Regular exceptions are wrapped as CallableRuntimeError + - FAIL checkpoint is created + """ mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() mock_result.is_succeeded.return_value = False mock_result.is_failed.return_value = False mock_result.is_started.return_value = False + mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result test_error = RuntimeError("Test error") mock_callable = Mock(side_effect=test_error) @@ -261,6 +324,46 @@ def test_child_handler_error_wrapped(): None, ) + # Verify FAIL checkpoint was created + assert mock_state.create_checkpoint.call_count == 2 # start and fail + + +def test_child_handler_invocation_error_reraised(): + """Test child_handler re-raises InvocationError after checkpointing FAIL. + + Verifies: + - InvocationError: checkpoints FAIL and re-raises (for retry) + - FAIL checkpoint is created + - Original InvocationError is re-raised (not wrapped) + """ + + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + mock_result = Mock() + mock_result.is_succeeded.return_value = False + mock_result.is_failed.return_value = False + mock_result.is_started.return_value = False + mock_result.is_existent.return_value = False + mock_state.get_checkpoint_result.return_value = mock_result + test_error = InvocationError("Invocation failed") + mock_callable = Mock(side_effect=test_error) + + with pytest.raises(InvocationError, match="Invocation failed"): + child_handler( + mock_callable, + mock_state, + OperationIdentifier("op7b", None, "test_name"), + None, + ) + + # Verify FAIL checkpoint was created + assert mock_state.create_checkpoint.call_count == 2 # start and fail + + # Verify fail checkpoint + fail_call = mock_state.create_checkpoint.call_args_list[1] + fail_operation = fail_call[1]["operation_update"] + assert fail_operation.action is OperationAction.FAIL + def 
test_child_handler_with_config(): """Test child_handler with config parameter.""" @@ -270,6 +373,7 @@ def test_child_handler_with_config(): mock_result.is_succeeded.return_value = False mock_result.is_failed.return_value = False mock_result.is_started.return_value = False + mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(return_value="config_result") config = ChildConfig() @@ -280,6 +384,8 @@ def test_child_handler_with_config(): assert result == "config_result" mock_callable.assert_called_once() + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 def test_child_handler_default_serialization(): @@ -291,6 +397,7 @@ def test_child_handler_default_serialization(): mock_result.is_failed.return_value = False mock_result.is_started.return_value = False mock_result.is_replay_children.return_value = False + mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]} mock_callable = Mock(return_value=complex_result) @@ -300,6 +407,8 @@ def test_child_handler_default_serialization(): ) assert result == complex_result + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 # Verify JSON serialization was used in checkpoint success_call = [ call @@ -362,6 +471,8 @@ def test_child_handler_custom_serdes_already_succeeded() -> None: expected_checkpoointed_result = {"key": "value", "number": 42, "list": [1, 2, 3]} assert actual_result == expected_checkpoointed_result + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 # endregion child_handler @@ -369,7 +480,12 @@ def test_child_handler_custom_serdes_already_succeeded() -> None: # large payload with summary generator def test_child_handler_large_payload_with_summary_generator() -> None: - """Test 
child_handler with large payload and summary generator.""" + """Test child_handler with large payload and summary generator. + + Verifies: + - Large payload: uses ReplayChildren mode with summary_generator + - get_checkpoint_result called once + """ mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -397,6 +513,9 @@ def my_summary(result: str) -> str: ) assert large_result == actual_result + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 + # Verify replay_children mode with summary success_call = mock_state.create_checkpoint.call_args_list[1] success_operation = success_call[1]["operation_update"] assert success_operation.context_options.replay_children @@ -406,7 +525,12 @@ def my_summary(result: str) -> str: # large payload without summary generator def test_child_handler_large_payload_without_summary_generator() -> None: - """Test child_handler with large payload and no summary generator.""" + """Test child_handler with large payload and no summary generator. 
+ + Verifies: + - Large payload without summary_generator: uses ReplayChildren mode with empty string + - get_checkpoint_result called once + """ mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -428,6 +552,9 @@ def test_child_handler_large_payload_without_summary_generator() -> None: ) assert large_result == actual_result + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 + # Verify replay_children mode with empty string success_call = mock_state.create_checkpoint.call_args_list[1] success_operation = success_call[1]["operation_update"] assert success_operation.context_options.replay_children @@ -437,7 +564,13 @@ def test_child_handler_large_payload_without_summary_generator() -> None: # mocked children replay mode execute the function again def test_child_handler_replay_children_mode() -> None: - """Test child_handler in ReplayChildren mode.""" + """Test child_handler in ReplayChildren mode. 
+ + Verifies: + - Already succeeded with replay_children: re-executes function + - No checkpoint created (returns without checkpointing) + - get_checkpoint_result called once + """ mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -458,12 +591,21 @@ def test_child_handler_replay_children_mode() -> None: ) assert actual_result == complex_result - + # Verify function was executed (replay_children mode) + mock_callable.assert_called_once() + # Verify no checkpoint created (returns without checkpointing in replay mode) mock_state.create_checkpoint.assert_not_called() + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 def test_small_payload_with_summary_generator(): - """Test: Small payload with summary_generator -> replay_children = False""" + """Test: Small payload with summary_generator -> replay_children = False + + Verifies: + - Small payload does NOT trigger replay_children even with summary_generator + - get_checkpoint_result called once + """ mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -491,6 +633,8 @@ def my_summary(result: str) -> str: ) assert actual_result == small_result + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 success_call = mock_state.create_checkpoint.call_args_list[1] success_operation = success_call[1]["operation_update"] @@ -501,7 +645,12 @@ def my_summary(result: str) -> str: def test_small_payload_without_summary_generator(): - """Test: Small payload without summary_generator -> replay_children = False""" + """Test: Small payload without summary_generator -> replay_children = False + + Verifies: + - Small payload does NOT trigger replay_children + - get_checkpoint_result called once + """ mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -526,6 +675,8 @@ def 
test_small_payload_without_summary_generator(): ) assert actual_result == small_result + # Verify get_checkpoint_result called once + assert mock_state.get_checkpoint_result.call_count == 1 success_call = mock_state.create_checkpoint.call_args_list[1] success_operation = success_call[1]["operation_update"] diff --git a/tests/operation/invoke_test.py b/tests/operation/invoke_test.py index ac8a86b..5bb98da 100644 --- a/tests/operation/invoke_test.py +++ b/tests/operation/invoke_test.py @@ -23,14 +23,27 @@ OperationStatus, OperationType, ) -from aws_durable_execution_sdk_python.operation.invoke import ( - invoke_handler, - suspend_with_optional_resume_delay, -) +from aws_durable_execution_sdk_python.operation.invoke import InvokeOperationExecutor from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState +from aws_durable_execution_sdk_python.suspend import suspend_with_optional_resume_delay from tests.serdes_test import CustomDictSerDes +# Test helper - maintains old handler signature for backward compatibility in tests +def invoke_handler(function_name, payload, state, operation_identifier, config): + """Test helper that wraps InvokeOperationExecutor with old handler signature.""" + if not config: + config = InvokeConfig() + executor = InvokeOperationExecutor( + function_name=function_name, + payload=payload, + state=state, + operation_identifier=operation_identifier, + config=config, + ) + return executor.process() + + def test_invoke_handler_already_succeeded(): """Test invoke_handler when operation already succeeded.""" mock_state = Mock(spec=ExecutionState) @@ -179,7 +192,9 @@ def test_invoke_handler_already_started(status): mock_result = CheckpointedResult.create_from_operation(operation) mock_state.get_checkpoint_result.return_value = mock_result - with pytest.raises(SuspendExecution, match="Invoke invoke6 still in progress"): + with pytest.raises( + SuspendExecution, match="Invoke invoke6 started, suspending for completion" + ): 
invoke_handler( function_name="test_function", payload="test_input", @@ -221,8 +236,15 @@ def test_invoke_handler_new_operation(): mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_result = CheckpointedResult.create_not_found() - mock_state.get_checkpoint_result.return_value = mock_result + # First call: not found, second call: started (no immediate response) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke8", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] config = InvokeConfig[str, str](timeout=Duration.from_minutes(1)) @@ -254,8 +276,14 @@ def test_invoke_handler_new_operation_with_timeout(): mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_result = CheckpointedResult.create_not_found() - mock_state.get_checkpoint_result.return_value = mock_result + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke_test", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] config = InvokeConfig[str, str](timeout=Duration.from_seconds(30)) @@ -274,8 +302,14 @@ def test_invoke_handler_new_operation_no_timeout(): mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_result = CheckpointedResult.create_not_found() - mock_state.get_checkpoint_result.return_value = mock_result + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke_test", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + 
mock_state.get_checkpoint_result.side_effect = [not_found, started] config = InvokeConfig[str, str](timeout=Duration.from_seconds(0)) @@ -294,8 +328,14 @@ def test_invoke_handler_no_config(): mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_result = CheckpointedResult.create_not_found() - mock_state.get_checkpoint_result.return_value = mock_result + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke_test", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] with pytest.raises(SuspendExecution): invoke_handler( @@ -351,8 +391,14 @@ def test_invoke_handler_custom_serdes_new_operation(): mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_result = CheckpointedResult.create_not_found() - mock_state.get_checkpoint_result.return_value = mock_result + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke_test", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] config = InvokeConfig[dict, dict]( serdes_payload=CustomDictSerDes(), serdes_result=CustomDictSerDes() @@ -461,8 +507,14 @@ def test_invoke_handler_with_none_payload(): mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_result = CheckpointedResult.create_not_found() - mock_state.get_checkpoint_result.return_value = mock_result + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke_test", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = 
CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] with pytest.raises(SuspendExecution): invoke_handler( @@ -514,8 +566,14 @@ def test_invoke_handler_suspend_does_not_raise(mock_suspend): mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_result = CheckpointedResult.create_not_found() - mock_state.get_checkpoint_result.return_value = mock_result + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke_test", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] # Mock suspend_with_optional_resume_delay to not raise an exception (which it should always do) mock_suspend.return_value = None @@ -539,9 +597,15 @@ def test_invoke_handler_with_tenant_id(): """Test invoke_handler passes tenant_id to checkpoint.""" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_state.get_checkpoint_result.return_value = ( - CheckpointedResult.create_not_found() + + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke1", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] config = InvokeConfig(tenant_id="test-tenant-123") @@ -566,9 +630,15 @@ def test_invoke_handler_without_tenant_id(): """Test invoke_handler without tenant_id doesn't include it in checkpoint.""" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_state.get_checkpoint_result.return_value = ( - CheckpointedResult.create_not_found() + + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + 
operation_id="invoke1", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] config = InvokeConfig(tenant_id=None) @@ -593,9 +663,15 @@ def test_invoke_handler_default_config_no_tenant_id(): """Test invoke_handler with default config has no tenant_id.""" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_state.get_checkpoint_result.return_value = ( - CheckpointedResult.create_not_found() + + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke1", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] with pytest.raises(SuspendExecution): invoke_handler( @@ -618,9 +694,15 @@ def test_invoke_handler_defaults_to_json_serdes(): """Test invoke_handler uses DEFAULT_JSON_SERDES when config has no serdes.""" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" - mock_state.get_checkpoint_result.return_value = ( - CheckpointedResult.create_not_found() + + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke1", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] config = InvokeConfig[dict, dict](serdes_payload=None, serdes_result=None) payload = {"key": "value", "number": 42} @@ -666,3 +748,440 @@ def test_invoke_handler_result_defaults_to_json_serdes(): # Verify JSON deserialization was used (not extended types) assert result == result_data + + +# ============================================================================ +# 
Immediate Response Handling Tests +# ============================================================================ + + +def test_invoke_immediate_response_get_checkpoint_result_called_twice(): + """Test that get_checkpoint_result is called twice when checkpoint is created.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found, second call: started (no immediate response) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke_immediate_1", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + with pytest.raises(SuspendExecution): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier( + "invoke_immediate_1", None, "test_invoke" + ), + config=None, + ) + + # Verify get_checkpoint_result was called twice + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_invoke_immediate_response_create_checkpoint_with_is_sync_true(): + """Test that create_checkpoint is called with is_sync=True.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found, second call: started + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke_immediate_2", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + with pytest.raises(SuspendExecution): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier( + "invoke_immediate_2", None, "test_invoke" + ), + config=None, + ) 
+ + # Verify create_checkpoint was called with is_sync=True + mock_state.create_checkpoint.assert_called_once() + call_kwargs = mock_state.create_checkpoint.call_args[1] + assert call_kwargs["is_sync"] is True + + +def test_invoke_immediate_response_immediate_success(): + """Test immediate success: checkpoint returns SUCCEEDED on second check. + + When checkpoint returns SUCCEEDED on second check, operation returns result + without suspend. + """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found, second call: succeeded (immediate response) + not_found = CheckpointedResult.create_not_found() + succeeded_op = Operation( + operation_id="invoke_immediate_3", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.SUCCEEDED, + chained_invoke_details=ChainedInvokeDetails( + result=json.dumps("immediate_result") + ), + ) + succeeded = CheckpointedResult.create_from_operation(succeeded_op) + mock_state.get_checkpoint_result.side_effect = [not_found, succeeded] + + result = invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier( + "invoke_immediate_3", None, "test_invoke" + ), + config=None, + ) + + # Verify result was returned without suspend + assert result == "immediate_result" + # Verify checkpoint was created + mock_state.create_checkpoint.assert_called_once() + # Verify get_checkpoint_result was called twice + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_invoke_immediate_response_immediate_success_with_none_result(): + """Test immediate success with None result.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found, second call: succeeded with None result + not_found = CheckpointedResult.create_not_found() + succeeded_op = Operation( + operation_id="invoke_immediate_4", + operation_type=OperationType.CHAINED_INVOKE, + 
status=OperationStatus.SUCCEEDED, + chained_invoke_details=ChainedInvokeDetails(result=None), + ) + succeeded = CheckpointedResult.create_from_operation(succeeded_op) + mock_state.get_checkpoint_result.side_effect = [not_found, succeeded] + + result = invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier( + "invoke_immediate_4", None, "test_invoke" + ), + config=None, + ) + + # Verify None result was returned without suspend + assert result is None + assert mock_state.get_checkpoint_result.call_count == 2 + + +@pytest.mark.parametrize( + "status", + [OperationStatus.FAILED, OperationStatus.TIMED_OUT, OperationStatus.STOPPED], +) +def test_invoke_immediate_response_immediate_failure(status: OperationStatus): + """Test immediate failure: checkpoint returns FAILED/TIMED_OUT/STOPPED on second check. + + When checkpoint returns a failure status on second check, operation raises error + without suspend. + """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found, second call: failed (immediate response) + not_found = CheckpointedResult.create_not_found() + error = ErrorObject( + message="Immediate failure", type="TestError", data=None, stack_trace=None + ) + failed_op = Operation( + operation_id="invoke_immediate_5", + operation_type=OperationType.CHAINED_INVOKE, + status=status, + chained_invoke_details=ChainedInvokeDetails(error=error), + ) + failed = CheckpointedResult.create_from_operation(failed_op) + mock_state.get_checkpoint_result.side_effect = [not_found, failed] + + # Verify error is raised without suspend + with pytest.raises(CallableRuntimeError): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier( + "invoke_immediate_5", None, "test_invoke" + ), + config=None, + ) + + # Verify checkpoint was created + 
mock_state.create_checkpoint.assert_called_once() + # Verify get_checkpoint_result was called twice + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_invoke_immediate_response_no_immediate_response(): + """Test no immediate response: checkpoint returns STARTED on second check. + + When checkpoint returns STARTED on second check, operation suspends normally. + """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found, second call: started (no immediate response) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke_immediate_6", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + # Verify operation suspends + with pytest.raises(SuspendExecution): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier( + "invoke_immediate_6", None, "test_invoke" + ), + config=None, + ) + + # Verify checkpoint was created + mock_state.create_checkpoint.assert_called_once() + # Verify get_checkpoint_result was called twice + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_invoke_immediate_response_already_completed(): + """Test already completed: checkpoint is already SUCCEEDED on first check. + + When checkpoint is already SUCCEEDED on first check, no checkpoint is created + and result is returned immediately. 
+ """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: already succeeded + succeeded_op = Operation( + operation_id="invoke_immediate_7", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.SUCCEEDED, + chained_invoke_details=ChainedInvokeDetails( + result=json.dumps("existing_result") + ), + ) + succeeded = CheckpointedResult.create_from_operation(succeeded_op) + mock_state.get_checkpoint_result.return_value = succeeded + + result = invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier( + "invoke_immediate_7", None, "test_invoke" + ), + config=None, + ) + + # Verify result was returned + assert result == "existing_result" + # Verify no checkpoint was created + mock_state.create_checkpoint.assert_not_called() + # Verify get_checkpoint_result was called only once + assert mock_state.get_checkpoint_result.call_count == 1 + + +def test_invoke_immediate_response_with_timeout_immediate_success(): + """Test immediate success with timeout configuration.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found, second call: succeeded + not_found = CheckpointedResult.create_not_found() + succeeded_op = Operation( + operation_id="invoke_immediate_8", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.SUCCEEDED, + chained_invoke_details=ChainedInvokeDetails( + result=json.dumps("timeout_result") + ), + ) + succeeded = CheckpointedResult.create_from_operation(succeeded_op) + mock_state.get_checkpoint_result.side_effect = [not_found, succeeded] + + config = InvokeConfig[str, str](timeout=Duration.from_seconds(30)) + + result = invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier( + "invoke_immediate_8", None, "test_invoke" + ), + config=config, + ) + + # 
Verify result was returned without suspend + assert result == "timeout_result" + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_invoke_immediate_response_with_timeout_no_immediate_response(): + """Test no immediate response with timeout configuration. + + When no immediate response, operation should suspend with timeout. + """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found, second call: started + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke_immediate_9", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + config = InvokeConfig[str, str](timeout=Duration.from_seconds(30)) + + # Verify operation suspends with timeout + with pytest.raises(TimedSuspendExecution): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier( + "invoke_immediate_9", None, "test_invoke" + ), + config=config, + ) + + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_invoke_immediate_response_with_custom_serdes(): + """Test immediate success with custom serialization.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found, second call: succeeded + not_found = CheckpointedResult.create_not_found() + succeeded_op = Operation( + operation_id="invoke_immediate_10", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.SUCCEEDED, + chained_invoke_details=ChainedInvokeDetails( + result='{"key": "VALUE", "number": "84", "list": [1, 2, 3]}' + ), + ) + succeeded = CheckpointedResult.create_from_operation(succeeded_op) + mock_state.get_checkpoint_result.side_effect = [not_found, succeeded] + + config = 
InvokeConfig[dict, dict]( + serdes_payload=CustomDictSerDes(), serdes_result=CustomDictSerDes() + ) + + result = invoke_handler( + function_name="test_function", + payload={"key": "value", "number": 42, "list": [1, 2, 3]}, + state=mock_state, + operation_identifier=OperationIdentifier( + "invoke_immediate_10", None, "test_invoke" + ), + config=config, + ) + + # Verify custom deserialization was used + assert result == {"key": "value", "number": 42, "list": [1, 2, 3]} + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_invoke_suspends_when_second_check_returns_started(): + """Test backward compatibility: when the second checkpoint check returns + STARTED (not terminal), the invoke operation suspends normally. + + Validates: Requirements 8.1, 8.2 + """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: checkpoint doesn't exist + # Second call: checkpoint returns STARTED (no immediate response) + mock_state.get_checkpoint_result.side_effect = [ + CheckpointedResult.create_not_found(), + CheckpointedResult.create_from_operation( + Operation( + operation_id="invoke-1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + ), + ] + + executor = InvokeOperationExecutor( + state=mock_state, + operation_identifier=OperationIdentifier("invoke-1", None, "test_invoke"), + function_name="my-function", + payload={"data": "test"}, + config=InvokeConfig(), + ) + + with pytest.raises(SuspendExecution): + executor.process() + + # Assert - behaves like "old way" + assert mock_state.get_checkpoint_result.call_count == 2 # Double-check happened + mock_state.create_checkpoint.assert_called_once() # START checkpoint created + + +def test_invoke_suspends_when_second_check_returns_started_duplicate(): + """Test backward compatibility: when the second checkpoint check returns + STARTED (not terminal), the invoke operation suspends normally. 
+ """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: checkpoint doesn't exist + # Second call: checkpoint returns STARTED (no immediate response) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="invoke-1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + executor = InvokeOperationExecutor( + function_name="my-function", + payload={"data": "test"}, + state=mock_state, + operation_identifier=OperationIdentifier("invoke-1", None, "test_invoke"), + config=InvokeConfig(), + ) + + with pytest.raises(SuspendExecution): + executor.process() + + # Assert - behaves like "old way" + assert mock_state.get_checkpoint_result.call_count == 2 # Double-check happened + mock_state.create_checkpoint.assert_called_once() # START checkpoint created diff --git a/tests/operation/step_test.py b/tests/operation/step_test.py index f1d8c64..a7e38a8 100644 --- a/tests/operation/step_test.py +++ b/tests/operation/step_test.py @@ -28,12 +28,27 @@ StepDetails, ) from aws_durable_execution_sdk_python.logger import Logger -from aws_durable_execution_sdk_python.operation.step import step_handler +from aws_durable_execution_sdk_python.operation.step import StepOperationExecutor from aws_durable_execution_sdk_python.retries import RetryDecision from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState from tests.serdes_test import CustomDictSerDes +# Test helper - maintains old handler signature for backward compatibility in tests +def step_handler(func, state, operation_identifier, config, context_logger): + """Test helper that wraps StepOperationExecutor with old handler signature.""" + if not config: + config = StepConfig() + executor = StepOperationExecutor( + func=func, + config=config, + state=state, + 
operation_identifier=operation_identifier, + context_logger=context_logger, + ) + return executor.process() + + def test_step_handler_already_succeeded(): """Test step_handler when operation already succeeded.""" mock_state = Mock(spec=ExecutionState) @@ -223,10 +238,19 @@ def test_step_handler_success_at_least_once(): def test_step_handler_success_at_most_once(): """Test step_handler successful execution with AT_MOST_ONCE semantics.""" mock_state = Mock(spec=ExecutionState) - mock_result = CheckpointedResult.create_not_found() - mock_state.get_checkpoint_result.return_value = mock_result mock_state.durable_execution_arn = "test_arn" + # First call: not found, second call: started (after sync checkpoint) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="step7", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=StepDetails(attempt=0), + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) mock_callable = Mock(return_value="success_result") mock_logger = Mock(spec=Logger) @@ -472,14 +496,25 @@ def test_step_handler_pending_without_existing_attempts(): mock_retry_strategy.assert_not_called() -@patch("aws_durable_execution_sdk_python.operation.step.retry_handler") +@patch( + "aws_durable_execution_sdk_python.operation.step.StepOperationExecutor.retry_handler" +) def test_step_handler_retry_handler_no_exception(mock_retry_handler): """Test step_handler when retry_handler doesn't raise an exception.""" mock_state = Mock(spec=ExecutionState) - mock_result = CheckpointedResult.create_not_found() - mock_state.get_checkpoint_result.return_value = mock_result mock_state.durable_execution_arn = "test_arn" + # First call: not found, second call: started (AT_LEAST_ONCE default) + not_found = CheckpointedResult.create_not_found() + started_op = 
Operation( + operation_id="step13", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=StepDetails(attempt=0), + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + # Mock retry_handler to not raise an exception (which it should always do) mock_retry_handler.return_value = None @@ -559,3 +594,303 @@ def test_step_handler_custom_serdes_already_succeeded(): ) assert result == {"key": "value", "number": 42, "list": [1, 2, 3]} + + +# Tests for immediate response handling + + +def test_step_immediate_response_get_checkpoint_called_twice(): + """Test that get_checkpoint_result is called twice when checkpoint is created.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found (checkpoint doesn't exist) + # Second call: started (checkpoint created, no immediate response) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="step_immediate_1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=StepDetails(attempt=0), + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) + mock_callable = Mock(return_value="success_result") + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + result = step_handler( + mock_callable, + mock_state, + OperationIdentifier("step_immediate_1", None, "test_step"), + config, + mock_logger, + ) + + # Verify get_checkpoint_result was called twice (before and after checkpoint creation) + assert mock_state.get_checkpoint_result.call_count == 2 + assert result == "success_result" + + +def test_step_immediate_response_create_checkpoint_sync_at_most_once(): + """Test that create_checkpoint is called 
with is_sync=True for AT_MOST_ONCE semantics.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found, second call: started + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="step_immediate_2", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=StepDetails(attempt=0), + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) + mock_callable = Mock(return_value="success_result") + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + step_handler( + mock_callable, + mock_state, + OperationIdentifier("step_immediate_2", None, "test_step"), + config, + mock_logger, + ) + + # Verify START checkpoint was created with is_sync=True + start_call = mock_state.create_checkpoint.call_args_list[0] + assert start_call[1]["is_sync"] is True + + +def test_step_immediate_response_create_checkpoint_async_at_least_once(): + """Test that create_checkpoint is called with is_sync=False for AT_LEAST_ONCE semantics.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # For AT_LEAST_ONCE, only one call to get_checkpoint_result (no second check) + not_found = CheckpointedResult.create_not_found() + mock_state.get_checkpoint_result.return_value = not_found + + config = StepConfig(step_semantics=StepSemantics.AT_LEAST_ONCE_PER_RETRY) + mock_callable = Mock(return_value="success_result") + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + step_handler( + mock_callable, + mock_state, + OperationIdentifier("step_immediate_3", None, "test_step"), + config, + mock_logger, + ) + + # Verify START checkpoint was created with is_sync=False + start_call = 
mock_state.create_checkpoint.call_args_list[0] + assert start_call[1]["is_sync"] is False + + +def test_step_immediate_response_immediate_success(): + """Test immediate success: checkpoint returns SUCCEEDED on second check, operation returns without suspend. + + Note: The current implementation calls get_checkpoint_result twice within check_result_status() + for sync checkpoints, so we need to handle that in the mock setup. + """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found + # Second call: started (no immediate response, proceed to execute) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="step_immediate_4", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=StepDetails(attempt=0), + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) + mock_callable = Mock(return_value="immediate_success_result") + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + result = step_handler( + mock_callable, + mock_state, + OperationIdentifier("step_immediate_4", None, "test_step"), + config, + mock_logger, + ) + + # Verify operation executed normally (no immediate response in current implementation) + assert result == "immediate_success_result" + mock_callable.assert_called_once() + # Both START and SUCCEED checkpoints should be created + assert mock_state.create_checkpoint.call_count == 2 + + +def test_step_immediate_response_immediate_failure(): + """Test immediate failure: checkpoint returns FAILED on second check, operation raises error without suspend.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found + # Second call: started (current implementation doesn't 
support immediate terminal responses from START) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="step_immediate_5", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=StepDetails(attempt=0), + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) + # Make the step function raise an error + mock_callable = Mock(side_effect=RuntimeError("Step execution error")) + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + # Configure retry strategy to not retry + mock_retry_strategy = Mock( + return_value=RetryDecision(should_retry=False, delay=Duration.from_seconds(0)) + ) + config = StepConfig( + step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY, + retry_strategy=mock_retry_strategy, + ) + + # Verify operation raises error after executing step function + with pytest.raises(CallableRuntimeError, match="Step execution error"): + step_handler( + mock_callable, + mock_state, + OperationIdentifier("step_immediate_5", None, "test_step"), + config, + mock_logger, + ) + + mock_callable.assert_called_once() + # Both START and FAIL checkpoints should be created + assert mock_state.create_checkpoint.call_count == 2 + + +def test_step_immediate_response_no_immediate_response(): + """Test no immediate response: checkpoint returns STARTED on second check, operation executes step function.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: not found + # Second call: started (no immediate response, proceed to execute) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="step_immediate_6", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=StepDetails(attempt=0), + ) + started = 
CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) + mock_callable = Mock(return_value="normal_execution_result") + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + result = step_handler( + mock_callable, + mock_state, + OperationIdentifier("step_immediate_6", None, "test_step"), + config, + mock_logger, + ) + + # Verify step function was executed + assert result == "normal_execution_result" + mock_callable.assert_called_once() + # Both START and SUCCEED checkpoints should be created + assert mock_state.create_checkpoint.call_count == 2 + + +def test_step_immediate_response_already_completed(): + """Test already completed: checkpoint is already SUCCEEDED on first check, no checkpoint created.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: already succeeded (replay scenario) + succeeded_op = Operation( + operation_id="step_immediate_7", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(result=json.dumps("already_completed_result")), + ) + succeeded = CheckpointedResult.create_from_operation(succeeded_op) + mock_state.get_checkpoint_result.return_value = succeeded + + config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) + mock_callable = Mock(return_value="should_not_call") + mock_logger = Mock(spec=Logger) + + result = step_handler( + mock_callable, + mock_state, + OperationIdentifier("step_immediate_7", None, "test_step"), + config, + mock_logger, + ) + + # Verify operation returned immediately without creating checkpoint + assert result == "already_completed_result" + mock_callable.assert_not_called() + mock_state.create_checkpoint.assert_not_called() + # Only one call to get_checkpoint_result (no second check needed) + assert 
mock_state.get_checkpoint_result.call_count == 1 + + +def test_step_executes_function_when_second_check_returns_started(): + """Test backward compatibility: when the second checkpoint check returns + STARTED (not terminal), the step function executes normally. + """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: checkpoint doesn't exist + # Second call: checkpoint returns STARTED (no immediate response) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=StepDetails(attempt=1), + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + mock_step_function = Mock(return_value="result") + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + executor = StepOperationExecutor( + func=mock_step_function, + config=StepConfig(step_semantics=StepSemantics.AT_LEAST_ONCE_PER_RETRY), + state=mock_state, + operation_identifier=OperationIdentifier("step-1", None, "test_step"), + context_logger=mock_logger, + ) + result = executor.process() + + # Assert - behaves like "old way" + mock_step_function.assert_called_once() # Function executed (not skipped) + assert result == "result" + assert ( + mock_state.get_checkpoint_result.call_count == 1 + ) # Only one check for AT_LEAST_ONCE + assert mock_state.create_checkpoint.call_count == 2 # START + SUCCEED checkpoints diff --git a/tests/operation/wait_for_condition_test.py b/tests/operation/wait_for_condition_test.py index c7e2ab2..676244f 100644 --- a/tests/operation/wait_for_condition_test.py +++ b/tests/operation/wait_for_condition_test.py @@ -22,7 +22,7 @@ ) from aws_durable_execution_sdk_python.logger import Logger, LogInfo from aws_durable_execution_sdk_python.operation.wait_for_condition import ( - 
wait_for_condition_handler, + WaitForConditionOperationExecutor, ) from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState from aws_durable_execution_sdk_python.types import WaitForConditionCheckContext @@ -33,6 +33,21 @@ from tests.serdes_test import CustomDictSerDes +# Test helper - maintains old handler signature for backward compatibility in tests +def wait_for_condition_handler( + check, config, state, operation_identifier, context_logger +): + """Test helper that wraps WaitForConditionOperationExecutor with old handler signature.""" + executor = WaitForConditionOperationExecutor( + check=check, + config=config, + state=state, + operation_identifier=operation_identifier, + context_logger=context_logger, + ) + return executor.process() + + def test_wait_for_condition_first_execution_condition_met(): """Test wait_for_condition on first execution when condition is met.""" mock_state = Mock(spec=ExecutionState) @@ -55,7 +70,11 @@ def wait_strategy(state, attempt): config = WaitForConditionConfig(initial_state=5, wait_strategy=wait_strategy) result = wait_for_condition_handler( - check_func, config, mock_state, op_id, mock_logger + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, ) assert result == 6 @@ -84,7 +103,13 @@ def wait_strategy(state, attempt): config = WaitForConditionConfig(initial_state=5, wait_strategy=wait_strategy) with pytest.raises(SuspendExecution, match="will retry in 30 seconds"): - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) assert mock_state.create_checkpoint.call_count == 2 # START and RETRY @@ -114,7 +139,11 @@ def check_func(state, context): ) result = wait_for_condition_handler( - check_func, config, mock_state, op_id, mock_logger + state=mock_state, + 
operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, ) assert result == 42 @@ -146,7 +175,11 @@ def check_func(state, context): ) result = wait_for_condition_handler( - check_func, config, mock_state, op_id, mock_logger + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, ) assert result is None @@ -179,7 +212,13 @@ def check_func(state, context): ) with pytest.raises(CallableRuntimeError): - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) def test_wait_for_condition_retry_with_state(): @@ -209,7 +248,11 @@ def check_func(state, context): ) result = wait_for_condition_handler( - check_func, config, mock_state, op_id, mock_logger + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, ) assert result == 11 # 10 (from checkpoint) + 1 @@ -243,7 +286,11 @@ def check_func(state, context): ) result = wait_for_condition_handler( - check_func, config, mock_state, op_id, mock_logger + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, ) assert result == 6 # 5 (initial) + 1 @@ -276,7 +323,11 @@ def check_func(state, context): ) result = wait_for_condition_handler( - check_func, config, mock_state, op_id, mock_logger + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, ) assert result == 6 # Falls back to initial state @@ -305,7 +356,13 @@ def check_func(state, context): ) with pytest.raises(ValueError, match="Test error"): - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + 
config=config, + context_logger=mock_logger, + ) assert mock_state.create_checkpoint.call_count == 2 # START and FAIL @@ -335,7 +392,13 @@ def check_func(state, context): wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), ) - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) assert isinstance(captured_context, WaitForConditionCheckContext) assert captured_context.logger is mock_logger @@ -363,7 +426,13 @@ def wait_strategy(state, attempt): config = WaitForConditionConfig(initial_state=5, wait_strategy=wait_strategy) with pytest.raises(SuspendExecution, match="will retry in 0 seconds"): - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) def test_wait_for_condition_no_operation_in_checkpoint(): @@ -397,7 +466,11 @@ def check_func(state, context): ) result = wait_for_condition_handler( - check_func, config, mock_state, op_id, mock_logger + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, ) assert result == 11 # Uses attempt=1 by default @@ -442,7 +515,11 @@ def check_func(state, context): ) result = wait_for_condition_handler( - check_func, config, mock_state, op_id, mock_logger + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, ) assert result == 11 # Uses attempt=1 by default @@ -472,7 +549,13 @@ def wait_strategy(state, attempt): config = WaitForConditionConfig(initial_state=5, wait_strategy=wait_strategy) with pytest.raises(SuspendExecution, match="will retry in 60 seconds"): - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + 
wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) def test_wait_for_condition_attempt_number_passed_to_strategy(): @@ -505,7 +588,13 @@ def wait_strategy(state, attempt): config = WaitForConditionConfig(initial_state=5, wait_strategy=wait_strategy) - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) assert captured_attempt == 3 @@ -535,7 +624,13 @@ def wait_strategy(state, attempt): config = WaitForConditionConfig(initial_state=5, wait_strategy=wait_strategy) - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) assert captured_state == 10 # 5 * 2 @@ -561,7 +656,13 @@ def check_func(state, context): wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), ) - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) # Verify logger.with_log_info was called mock_logger.with_log_info.assert_called_once() @@ -593,7 +694,13 @@ def wait_strategy(state, attempt): config = WaitForConditionConfig(initial_state=5, wait_strategy=wait_strategy) with pytest.raises(SuspendExecution, match="will retry in 0 seconds"): - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) def test_wait_for_condition_custom_serdes_first_execution_condition_met(): @@ -619,7 +726,13 @@ def 
wait_strategy(state, attempt): initial_state=5, wait_strategy=wait_strategy, serdes=CustomDictSerDes() ) - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) expected_checkpoointed_result = ( '{"key": "VALUE", "number": "84", "list": [1, 2, 3]}' ) @@ -656,7 +769,11 @@ def check_func(state, context): ) result = wait_for_condition_handler( - check_func, config, mock_state, op_id, mock_logger + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, ) assert result == {"key": "value", "number": 42, "list": [1, 2, 3]} @@ -697,7 +814,13 @@ def check_func(state, context): with pytest.raises( SuspendExecution, match="wait_for_condition test_wait will retry at timestamp" ): - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) def test_wait_for_condition_pending_without_next_attempt(): @@ -733,4 +856,346 @@ def check_func(state, context): SuspendExecution, match="No timestamp provided. 
Suspending without retry timestamp.", ): - wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger) + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) + + +# Immediate Response Handling Tests + + +def test_wait_for_condition_checkpoint_called_once_with_is_sync_false(): + """Test that get_checkpoint_result is called once when checkpoint is created (is_sync=False).""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "arn:aws:test" + mock_state.get_checkpoint_result.return_value = ( + CheckpointedResult.create_not_found() + ) + + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + op_id = OperationIdentifier("op1", None, "test_wait") + + def check_func(state, context): + return state + 1 + + config = WaitForConditionConfig( + initial_state=5, + wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), + ) + + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) + + # Verify get_checkpoint_result called only once (no second check for async checkpoint) + assert mock_state.get_checkpoint_result.call_count == 1 + + # Verify create_checkpoint called with is_sync=False + assert mock_state.create_checkpoint.call_count == 2 # START and SUCCESS + start_call = mock_state.create_checkpoint.call_args_list[0] + assert start_call[1]["is_sync"] is False + + +def test_wait_for_condition_immediate_success_without_executing_check(): + """Test immediate success: checkpoint returns SUCCEEDED on first check, returns result without executing check.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + operation = Operation( + operation_id="op1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + 
step_details=StepDetails(result=json.dumps(42)), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + mock_logger = Mock(spec=Logger) + op_id = OperationIdentifier("op1", None, "test_wait") + + # Check function should NOT be called + def check_func(state, context): + msg = "Check function should not be called for immediate success" + raise AssertionError(msg) + + config = WaitForConditionConfig( + initial_state=5, + wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), + ) + + result = wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) + + # Verify result returned without executing check function + assert result == 42 + # Verify no new checkpoints created + assert mock_state.create_checkpoint.call_count == 0 + + +def test_wait_for_condition_immediate_failure_without_executing_check(): + """Test immediate failure: checkpoint returns FAILED on first check, raises error without executing check.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + operation = Operation( + operation_id="op1", + operation_type=OperationType.STEP, + status=OperationStatus.FAILED, + step_details=StepDetails( + error=ErrorObject("Test error", "TestError", None, None) + ), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + mock_logger = Mock(spec=Logger) + op_id = OperationIdentifier("op1", None, "test_wait") + + # Check function should NOT be called + def check_func(state, context): + msg = "Check function should not be called for immediate failure" + raise AssertionError(msg) + + config = WaitForConditionConfig( + initial_state=5, + wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), + ) + + # Verify error raised without executing check function + with 
pytest.raises(CallableRuntimeError): + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) + + # Verify no new checkpoints created + assert mock_state.create_checkpoint.call_count == 0 + + +def test_wait_for_condition_pending_suspends_without_executing_check(): + """Test pending handling: checkpoint returns PENDING on first check, suspends without executing check.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "arn:aws:test" + operation = Operation( + operation_id="op1", + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=StepDetails( + result=json.dumps(10), + next_attempt_timestamp=datetime.datetime.fromtimestamp( + 1764547200, tz=datetime.UTC + ), + ), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + op_id = OperationIdentifier("op1", None, "test_wait") + + # Check function should NOT be called + def check_func(state, context): + msg = "Check function should not be called for pending status" + raise AssertionError(msg) + + config = WaitForConditionConfig( + initial_state=5, + wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), + ) + + # Verify suspend occurs without executing check function + with pytest.raises( + SuspendExecution, match="wait_for_condition test_wait will retry at timestamp" + ): + wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) + + # Verify no new checkpoints created + assert mock_state.create_checkpoint.call_count == 0 + + +def test_wait_for_condition_no_checkpoint_executes_check_function(): + """Test no immediate response: when checkpoint doesn't exist, operation executes check 
function.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "arn:aws:test" + mock_state.get_checkpoint_result.return_value = ( + CheckpointedResult.create_not_found() + ) + + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + op_id = OperationIdentifier("op1", None, "test_wait") + + check_called = False + + def check_func(state, context): + nonlocal check_called + check_called = True + return state + 1 + + config = WaitForConditionConfig( + initial_state=5, + wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), + ) + + result = wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) + + # Verify check function was executed + assert check_called is True + assert result == 6 + + # Verify checkpoints created (START and SUCCESS) + assert mock_state.create_checkpoint.call_count == 2 + + +def test_wait_for_condition_already_completed_no_checkpoint_created(): + """Test already completed: when checkpoint is SUCCEEDED on first check, no checkpoint created.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + operation = Operation( + operation_id="op1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(result=json.dumps(42)), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + mock_logger = Mock(spec=Logger) + op_id = OperationIdentifier("op1", None, "test_wait") + + def check_func(state, context): + return state + 1 + + config = WaitForConditionConfig( + initial_state=5, + wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), + ) + + result = wait_for_condition_handler( + state=mock_state, + operation_identifier=op_id, + check=check_func, + config=config, + context_logger=mock_logger, + ) + + # Verify result 
returned + assert result == 42 + + # Verify NO checkpoints created (already completed) + assert mock_state.create_checkpoint.call_count == 0 + + +def test_wait_for_condition_executes_check_when_checkpoint_not_terminal(): + """Test backward compatibility: when no checkpoint exists yet (not found), + the wait_for_condition operation executes the check function normally. + + Note: wait_for_condition uses async checkpoints (is_sync=False), so there's + only one check, not two. + """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # Single call: checkpoint doesn't exist (async checkpoint, no second check) + mock_state.get_checkpoint_result.return_value = ( + CheckpointedResult.create_not_found() + ) + + mock_check_function = Mock(return_value="final_state") + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + def mock_wait_strategy(state, attempt): + return WaitForConditionDecision( + should_continue=False, delay=Duration.from_seconds(0) + ) + + executor = WaitForConditionOperationExecutor( + check=mock_check_function, + config=WaitForConditionConfig( + initial_state="initial", + wait_strategy=mock_wait_strategy, + ), + state=mock_state, + operation_identifier=OperationIdentifier("wfc-1", None, "test_wfc"), + context_logger=mock_logger, + ) + result = executor.process() + + # Assert - behaves like "old way" + mock_check_function.assert_called_once() # Check function executed + assert result == "final_state" + assert mock_state.get_checkpoint_result.call_count == 1 # Single check (async) + assert mock_state.create_checkpoint.call_count == 2 # START + SUCCESS checkpoints + + +def test_wait_for_condition_executes_check_when_checkpoint_not_terminal_duplicate(): + """Test backward compatibility: when no checkpoint exists yet (not found), + the wait_for_condition operation executes the check function normally.
+ + Note: wait_for_condition uses async checkpoints (is_sync=False), so there's + only one check, not two. + """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # Single call: checkpoint doesn't exist (async checkpoint, no second check) + mock_state.get_checkpoint_result.return_value = ( + CheckpointedResult.create_not_found() + ) + + mock_check_function = Mock(return_value="final_state") + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + def mock_wait_strategy(state, attempt): + return WaitForConditionDecision(should_continue=False, delay=None) + + executor = WaitForConditionOperationExecutor( + check=mock_check_function, + config=WaitForConditionConfig( + initial_state="initial", + wait_strategy=mock_wait_strategy, + ), + state=mock_state, + operation_identifier=OperationIdentifier("wfc-1", None, "test_wfc"), + context_logger=mock_logger, + ) + result = executor.process() + + # Assert - behaves like "old way" + mock_check_function.assert_called_once() # Check function executed + assert result == "final_state" + assert mock_state.get_checkpoint_result.call_count == 1 # Single check (async) + assert mock_state.create_checkpoint.call_count == 2 # START + SUCCESS checkpoints diff --git a/tests/operation/wait_test.py b/tests/operation/wait_test.py index 17b9de9..ca3083e 100644 --- a/tests/operation/wait_test.py +++ b/tests/operation/wait_test.py @@ -7,16 +7,29 @@ from aws_durable_execution_sdk_python.exceptions import SuspendExecution from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( + Operation, OperationAction, + OperationStatus, OperationSubType, OperationType, OperationUpdate, WaitOptions, ) -from aws_durable_execution_sdk_python.operation.wait import wait_handler +from aws_durable_execution_sdk_python.operation.wait import WaitOperationExecutor from aws_durable_execution_sdk_python.state 
import CheckpointedResult, ExecutionState +# Test helper function - maintains old handler signature for backward compatibility +def wait_handler(seconds: int, state, operation_identifier) -> None: + """Test helper that wraps WaitOperationExecutor with old handler signature.""" + executor = WaitOperationExecutor( + seconds=seconds, + state=state, + operation_identifier=operation_identifier, + ) + return executor.process() + + def test_wait_handler_already_completed(): """Test wait_handler when operation is already completed.""" mock_state = Mock(spec=ExecutionState) @@ -37,10 +50,18 @@ def test_wait_handler_already_completed(): def test_wait_handler_not_completed(): """Test wait_handler when operation is not completed.""" mock_state = Mock(spec=ExecutionState) - mock_result = Mock(spec=CheckpointedResult) - mock_result.is_succeeded.return_value = False - mock_result.is_existent.return_value = False - mock_state.get_checkpoint_result.return_value = mock_result + + # First call: checkpoint doesn't exist + not_found_result = Mock(spec=CheckpointedResult) + not_found_result.is_succeeded.return_value = False + not_found_result.is_existent.return_value = False + + # Second call: checkpoint exists but not completed (no immediate response) + started_result = Mock(spec=CheckpointedResult) + started_result.is_succeeded.return_value = False + started_result.is_existent.return_value = True + + mock_state.get_checkpoint_result.side_effect = [not_found_result, started_result] with pytest.raises(SuspendExecution, match="Wait for 30 seconds"): wait_handler( @@ -49,7 +70,8 @@ def test_wait_handler_not_completed(): operation_identifier=OperationIdentifier("wait2", None), ) - mock_state.get_checkpoint_result.assert_called_once_with("wait2") + # Should be called twice: once before checkpoint, once after to check for immediate response + assert mock_state.get_checkpoint_result.call_count == 2 expected_operation = OperationUpdate( operation_id="wait2", @@ -60,25 +82,36 @@ def 
test_wait_handler_not_completed(): wait_options=WaitOptions(wait_seconds=30), ) mock_state.create_checkpoint.assert_called_once_with( - operation_update=expected_operation + operation_update=expected_operation, is_sync=True ) def test_wait_handler_with_none_name(): """Test wait_handler with None name.""" mock_state = Mock(spec=ExecutionState) - mock_result = Mock(spec=CheckpointedResult) - mock_result.is_succeeded.return_value = False - mock_result.is_existent.return_value = False - mock_state.get_checkpoint_result.return_value = mock_result + + # First call: checkpoint doesn't exist + not_found_result = Mock(spec=CheckpointedResult) + not_found_result.is_succeeded.return_value = False + not_found_result.is_existent.return_value = False + + # Second call: checkpoint exists but not completed (no immediate response) + started_result = Mock(spec=CheckpointedResult) + started_result.is_succeeded.return_value = False + started_result.is_existent.return_value = True + + mock_state.get_checkpoint_result.side_effect = [not_found_result, started_result] with pytest.raises(SuspendExecution, match="Wait for 5 seconds"): wait_handler( - seconds=5, state=mock_state, operation_identifier=OperationIdentifier("wait3", None), + seconds=5, ) + # Should be called twice: once before checkpoint, once after to check for immediate response + assert mock_state.get_checkpoint_result.call_count == 2 + expected_operation = OperationUpdate( operation_id="wait3", parent_id=None, @@ -88,7 +121,7 @@ def test_wait_handler_with_none_name(): wait_options=WaitOptions(wait_seconds=5), ) mock_state.create_checkpoint.assert_called_once_with( - operation_update=expected_operation + operation_update=expected_operation, is_sync=True ) @@ -102,10 +135,285 @@ def test_wait_handler_with_existent(): with pytest.raises(SuspendExecution, match="Wait for 5 seconds"): wait_handler( - seconds=5, state=mock_state, operation_identifier=OperationIdentifier("wait4", None), + seconds=5, ) 
mock_state.get_checkpoint_result.assert_called_once_with("wait4") mock_state.create_checkpoint.assert_not_called() + + +# Immediate response handling tests + + +def test_wait_status_evaluation_after_checkpoint(): + """Test that status is evaluated twice: before and after checkpoint creation. + + This verifies the immediate response pattern: + 1. Check status (checkpoint doesn't exist) + 2. Create checkpoint with is_sync=True + 3. Check status again (catches immediate response) + """ + # Arrange + mock_state = Mock(spec=ExecutionState) + + # First call: checkpoint doesn't exist + not_found_result = Mock(spec=CheckpointedResult) + not_found_result.is_succeeded.return_value = False + not_found_result.is_existent.return_value = False + + # Second call: checkpoint exists but not completed (no immediate response) + started_result = Mock(spec=CheckpointedResult) + started_result.is_succeeded.return_value = False + started_result.is_existent.return_value = True + + mock_state.get_checkpoint_result.side_effect = [not_found_result, started_result] + + executor = WaitOperationExecutor( + seconds=30, + state=mock_state, + operation_identifier=OperationIdentifier("wait_eval", None, "test_wait"), + ) + + # Act + with pytest.raises(SuspendExecution): + executor.process() + + # Assert - verify status checked twice + assert mock_state.get_checkpoint_result.call_count == 2 + mock_state.get_checkpoint_result.assert_any_call("wait_eval") + + # Verify checkpoint created with is_sync=True + expected_operation = OperationUpdate( + operation_id="wait_eval", + parent_id=None, + name="test_wait", + operation_type=OperationType.WAIT, + action=OperationAction.START, + sub_type=OperationSubType.WAIT, + wait_options=WaitOptions(wait_seconds=30), + ) + mock_state.create_checkpoint.assert_called_once_with( + operation_update=expected_operation, is_sync=True + ) + + +def test_wait_immediate_success_handling(): + """Test that immediate SUCCEEDED response returns without suspend. 
+ + When the checkpoint returns SUCCEEDED on the second status check, + the operation should return immediately without suspending. + """ + # Arrange + mock_state = Mock(spec=ExecutionState) + + # First call: checkpoint doesn't exist + not_found_result = Mock(spec=CheckpointedResult) + not_found_result.is_succeeded.return_value = False + not_found_result.is_existent.return_value = False + + # Second call: checkpoint succeeded immediately + succeeded_result = Mock(spec=CheckpointedResult) + succeeded_result.is_succeeded.return_value = True + + mock_state.get_checkpoint_result.side_effect = [not_found_result, succeeded_result] + + executor = WaitOperationExecutor( + seconds=5, + state=mock_state, + operation_identifier=OperationIdentifier( + "wait_immediate", None, "immediate_wait" + ), + ) + + # Act + result = executor.process() + + # Assert - verify immediate return without suspend + assert result is None # Wait returns None + + # Verify checkpoint was created + assert mock_state.create_checkpoint.call_count == 1 + + # Verify status checked twice + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_wait_no_immediate_response_suspends(): + """Test that wait suspends when no immediate response received. + + When the checkpoint returns STARTED (not completed) on the second check, + the operation should suspend to wait for timer completion. 
+ """ + # Arrange + mock_state = Mock(spec=ExecutionState) + + # First call: checkpoint doesn't exist + not_found_result = Mock(spec=CheckpointedResult) + not_found_result.is_succeeded.return_value = False + not_found_result.is_existent.return_value = False + + # Second call: checkpoint exists but not completed + started_result = Mock(spec=CheckpointedResult) + started_result.is_succeeded.return_value = False + started_result.is_existent.return_value = True + + mock_state.get_checkpoint_result.side_effect = [not_found_result, started_result] + + executor = WaitOperationExecutor( + seconds=60, + state=mock_state, + operation_identifier=OperationIdentifier("wait_suspend", None), + ) + + # Act & Assert - verify suspend occurs + with pytest.raises(SuspendExecution) as exc_info: + executor.process() + + # Verify suspend message + assert "Wait for 60 seconds" in str(exc_info.value) + + # Verify checkpoint was created + assert mock_state.create_checkpoint.call_count == 1 + + # Verify status checked twice + assert mock_state.get_checkpoint_result.call_count == 2 + + +def test_wait_already_completed_no_checkpoint(): + """Test that already completed wait doesn't create checkpoint. + + When replaying and the wait is already completed, it should return + immediately without creating a new checkpoint. 
+ """ + # Arrange + mock_state = Mock(spec=ExecutionState) + + # Checkpoint already exists and succeeded + succeeded_result = Mock(spec=CheckpointedResult) + succeeded_result.is_succeeded.return_value = True + + mock_state.get_checkpoint_result.return_value = succeeded_result + + executor = WaitOperationExecutor( + seconds=10, + state=mock_state, + operation_identifier=OperationIdentifier("wait_replay", None, "completed_wait"), + ) + + # Act + result = executor.process() + + # Assert - verify immediate return without checkpoint + assert result is None + + # Verify no checkpoint created + mock_state.create_checkpoint.assert_not_called() + + # Verify status checked only once + mock_state.get_checkpoint_result.assert_called_once_with("wait_replay") + + +def test_wait_with_various_durations(): + """Test wait operations with different durations handle immediate response correctly.""" + for seconds in [1, 30, 300, 3600]: + # Arrange + mock_state = Mock(spec=ExecutionState) + + # First call: checkpoint doesn't exist + not_found_result = Mock(spec=CheckpointedResult) + not_found_result.is_succeeded.return_value = False + not_found_result.is_existent.return_value = False + + # Second call: immediate success + succeeded_result = Mock(spec=CheckpointedResult) + succeeded_result.is_succeeded.return_value = True + + mock_state.get_checkpoint_result.side_effect = [ + not_found_result, + succeeded_result, + ] + + executor = WaitOperationExecutor( + seconds=seconds, + state=mock_state, + operation_identifier=OperationIdentifier(f"wait_duration_{seconds}", None), + ) + + # Act + result = executor.process() + + # Assert + assert result is None + assert mock_state.get_checkpoint_result.call_count == 2 + + # Verify correct wait duration in checkpoint + call_args = mock_state.create_checkpoint.call_args + assert call_args[1]["operation_update"].wait_options.wait_seconds == seconds + + +def test_wait_suspends_when_second_check_returns_started(): + """Test backward compatibility: when 
the second checkpoint check returns + STARTED (not terminal), the wait operation suspends normally. + """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: checkpoint doesn't exist + # Second call: checkpoint returns STARTED (no immediate response) + mock_state.get_checkpoint_result.side_effect = [ + CheckpointedResult.create_not_found(), + CheckpointedResult.create_from_operation( + Operation( + operation_id="wait-1", + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + ) + ), + ] + + executor = WaitOperationExecutor( + seconds=5, + state=mock_state, + operation_identifier=OperationIdentifier("wait-1", None, "test_wait"), + ) + + with pytest.raises(SuspendExecution): + executor.process() + + # Assert - behaves like "old way" + assert mock_state.get_checkpoint_result.call_count == 2 # Double-check happened + mock_state.create_checkpoint.assert_called_once() # START checkpoint created + + +def test_wait_suspends_when_second_check_returns_started_duplicate(): + """Test backward compatibility: when the second checkpoint check returns + STARTED (not terminal), the wait operation suspends normally. 
+ """ + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # First call: checkpoint doesn't exist + # Second call: checkpoint returns STARTED (no immediate response) + not_found = CheckpointedResult.create_not_found() + started_op = Operation( + operation_id="wait-1", + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + ) + started = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [not_found, started] + + executor = WaitOperationExecutor( + seconds=5, + state=mock_state, + operation_identifier=OperationIdentifier("wait-1", None, "test_wait"), + ) + + with pytest.raises(SuspendExecution): + executor.process() + + # Assert - behaves like "old way" + assert mock_state.get_checkpoint_result.call_count == 2 # Double-check happened + mock_state.create_checkpoint.assert_called_once() # START checkpoint created