diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index a60a1ff..2e450e9 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -42,7 +42,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install Hatch - run: python -m pip install --upgrade hatch + run: python -m pip install hatch==1.15.0 - name: Setup and run Testing SDK working-directory: testing-sdk diff --git a/src/aws_durable_execution_sdk_python/concurrency/executor.py b/src/aws_durable_execution_sdk_python/concurrency/executor.py index c6e6b46..77e8529 100644 --- a/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -381,12 +381,11 @@ def _execute_item_in_child_context( executor_context._parent_id, # noqa: SLF001 name, ) - child_context.state.track_replay(operation_id=operation_id) def run_in_child_handler(): return self.execute_item(child_context, executable) - return child_handler( + result: ResultType = child_handler( run_in_child_handler, child_context.state, operation_identifier=operation_identifier, @@ -396,6 +395,8 @@ def run_in_child_handler(): summary_generator=self.summary_generator, ), ) + child_context.state.track_replay(operation_id=operation_id) + return result def replay(self, execution_state: ExecutionState, executor_context: DurableContext): """ diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index 20c3659..a12a47e 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -271,7 +271,6 @@ def create_callback( if not config: config = CallbackConfig() operation_id: str = self._create_step_id() - self.state.track_replay(operation_id=operation_id) callback_id: str = create_callback_handler( state=self.state, operation_identifier=OperationIdentifier( @@ -279,13 +278,14 @@ def create_callback( ), config=config, ) - - return Callback( + result: Callback = Callback( callback_id=callback_id, operation_id=operation_id, state=self.state, serdes=config.serdes, ) + self.state.track_replay(operation_id=operation_id) + return result def invoke( self, @@ -306,8 +306,7 @@ def invoke( The result of the invoked function """ operation_id = self._create_step_id() - self.state.track_replay(operation_id=operation_id) - return invoke_handler( + result: R = invoke_handler( function_name=function_name, payload=payload, state=self.state, @@ -318,6 +317,8 @@ def invoke( ), config=config, ) + self.state.track_replay(operation_id=operation_id) + return result def map( self, @@ -330,7 +331,6 @@ def map( map_name: str | None = self._resolve_step_name(name, func) operation_id = self._create_step_id() - self.state.track_replay(operation_id=operation_id) operation_identifier = OperationIdentifier( operation_id=operation_id, parent_id=self._parent_id, name=map_name ) @@ -350,7 +350,7 @@ def map_in_child_context() -> BatchResult[R]: operation_identifier=operation_identifier, ) - return child_handler( + result: BatchResult[R] = child_handler( func=map_in_child_context, state=self.state, operation_identifier=operation_identifier, @@ -363,6 +363,8 @@ def map_in_child_context() -> BatchResult[R]: item_serdes=None, ), ) + self.state.track_replay(operation_id=operation_id) + return result def parallel( self, @@ -373,7 +375,6 @@ def parallel( """Execute multiple callables in parallel.""" # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id operation_id = self._create_step_id() - self.state.track_replay(operation_id=operation_id) parallel_context = self.create_child_context(parent_id=operation_id) operation_identifier = OperationIdentifier( operation_id=operation_id, parent_id=self._parent_id, name=name @@ -392,7 +393,7 @@ def parallel_in_child_context() -> BatchResult[T]: operation_identifier=operation_identifier, ) - return child_handler( + result: BatchResult[T] = child_handler( func=parallel_in_child_context, state=self.state, operation_identifier=operation_identifier, @@ -405,6 +406,8 @@ def parallel_in_child_context() -> BatchResult[T]: item_serdes=None, ), ) + self.state.track_replay(operation_id=operation_id) + return result def run_in_child_context( self, @@ -427,12 +430,11 @@ def run_in_child_context( step_name: str | None = self._resolve_step_name(name, func) # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id operation_id = self._create_step_id() - self.state.track_replay(operation_id=operation_id) def callable_with_child_context(): return func(self.create_child_context(parent_id=operation_id)) - return child_handler( + result: T = child_handler( func=callable_with_child_context, state=self.state, operation_identifier=OperationIdentifier( @@ -440,6 +442,8 @@ def callable_with_child_context(): ), config=config, ) + self.state.track_replay(operation_id=operation_id) + return result def step( self, @@ -450,9 +454,7 @@ def step( step_name = self._resolve_step_name(name, func) logger.debug("Step name: %s", step_name) operation_id = self._create_step_id() - self.state.track_replay(operation_id=operation_id) - - return step_handler( + result: T = step_handler( func=func, config=config, state=self.state, @@ -463,6 +465,8 @@ def step( ), context_logger=self.logger, ) + self.state.track_replay(operation_id=operation_id) + return result def wait(self, duration: Duration, name: str | None = None) -> None: """Wait for a specified amount of time. @@ -476,7 +480,6 @@ def wait(self, duration: Duration, name: str | None = None) -> None: msg = "duration must be at least 1 second" raise ValidationError(msg) operation_id = self._create_step_id() - self.state.track_replay(operation_id=operation_id) wait_handler( seconds=seconds, state=self.state, @@ -486,6 +489,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None: name=name, ), ) + self.state.track_replay(operation_id=operation_id) def wait_for_callback( self, @@ -528,8 +532,7 @@ def wait_for_condition( raise ValidationError(msg) operation_id = self._create_step_id() - self.state.track_replay(operation_id=operation_id) - return wait_for_condition_handler( + result: T = wait_for_condition_handler( check=check, config=config, state=self.state, @@ -540,6 +543,8 @@ def wait_for_condition( ), context_logger=self.logger, ) + self.state.track_replay(operation_id=operation_id) + return result # endregion Operations diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 685664e..5174ce6 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -258,6 +258,7 @@ def __init__( self._parent_done_lock: Lock = Lock() self._replay_status: ReplayStatus = replay_status self._replay_status_lock: Lock = Lock() + self._visited_operations: set[str] = set() def fetch_paginated_operations( self, @@ -301,14 +302,20 @@ def track_replay(self, operation_id: str) -> None: """ with self._replay_status_lock: if self._replay_status == ReplayStatus.REPLAY: - operation = self.operations.get(operation_id) - # Transition if operation doesn't exist OR isn't in a completed state - if not operation or operation.status not in { - OperationStatus.SUCCEEDED, - OperationStatus.FAILED, - OperationStatus.CANCELLED, - OperationStatus.STOPPED, - }: + self._visited_operations.add(operation_id) + completed_ops = { + op_id + for op_id, op in self.operations.items() + if op.operation_type != OperationType.EXECUTION + and op.status + in { + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + } + } + if completed_ops.issubset(self._visited_operations): logger.debug( "Transitioning from REPLAY to NEW status at operation %s", operation_id, diff --git a/tests/logger_test.py b/tests/logger_test.py index 0cc03ea..f503538 100644 --- a/tests/logger_test.py +++ b/tests/logger_test.py @@ -381,22 +381,27 @@ def test_logger_replay_no_logging(): log_info = LogInfo(replay_execution_state, "parent123", "test_name", 5) mock_logger = Mock() logger = Logger.from_log_info(mock_logger, log_info) - replay_execution_state.track_replay(operation_id="op1") logger.info("logging info") + replay_execution_state.track_replay(operation_id="op1") mock_logger.info.assert_not_called() def test_logger_replay_then_new_logging(): - operation = Operation( + operation1 = Operation( operation_id="op1", operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, ) + operation2 = Operation( + operation_id="op2", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) execution_state = ExecutionState( durable_execution_arn="arn:aws:test", initial_checkpoint_token="test_token", # noqa: S106 - operations={"op1": operation}, + operations={"op1": operation1, "op2": operation2}, service_client=Mock(), replay_status=ReplayStatus.REPLAY, ) diff --git a/tests/state_test.py b/tests/state_test.py index 0831533..1e016d1 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -3246,21 +3246,25 @@ def test_create_checkpoint_sync_always_synchronous(): def test_state_replay_mode(): - operation = Operation( + operation1 = Operation( operation_id="op1", operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, ) + operation2 = Operation( + operation_id="op2", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) execution_state = ExecutionState( durable_execution_arn="arn:aws:test", initial_checkpoint_token="test_token", # noqa: S106 - operations={"op1": operation}, + operations={"op1": operation1, "op2": operation2}, service_client=Mock(), replay_status=ReplayStatus.REPLAY, ) - + assert execution_state.is_replaying() is True execution_state.track_replay(operation_id="op1") assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op2") assert execution_state.is_replaying() is False