diff --git a/src/aws_durable_execution_sdk_python_testing/executor.py b/src/aws_durable_execution_sdk_python_testing/executor.py index d2ff0c6..3df4ed8 100644 --- a/src/aws_durable_execution_sdk_python_testing/executor.py +++ b/src/aws_durable_execution_sdk_python_testing/executor.py @@ -360,7 +360,7 @@ def get_execution_state( execution = self.get_execution(execution_arn) # TODO: Validate checkpoint token if provided - if checkpoint_token and checkpoint_token not in execution.used_tokens: + if checkpoint_token and checkpoint_token in execution.used_tokens: msg: str = f"Invalid checkpoint token: {checkpoint_token}" raise InvalidParameterValueException(msg) @@ -469,7 +469,7 @@ def checkpoint_execution( execution = self.get_execution(execution_arn) # Validate checkpoint token - if checkpoint_token not in execution.used_tokens: + if checkpoint_token in execution.used_tokens: msg: str = f"Invalid checkpoint token: {checkpoint_token}" raise InvalidParameterValueException(msg) diff --git a/tests/executor_test.py b/tests/executor_test.py index 78a067e..a53797a 100644 --- a/tests/executor_test.py +++ b/tests/executor_test.py @@ -1916,7 +1916,7 @@ def test_get_execution_state(executor, mock_store): mock_store.load.return_value = mock_execution - result = executor.get_execution_state("test-arn", checkpoint_token="token1") # noqa: S106 + result = executor.get_execution_state("test-arn", checkpoint_token="token3") # noqa: S106 assert len(result.operations) == 2 assert result.next_marker is None @@ -1932,7 +1932,7 @@ def test_get_execution_state_invalid_token(executor, mock_store): with pytest.raises( InvalidParameterValueException, match="Invalid checkpoint token" ): - executor.get_execution_state("test-arn", checkpoint_token="invalid-token") # noqa: S106 + executor.get_execution_state("test-arn", checkpoint_token="token1") # noqa: S106 def test_get_execution_history(executor, mock_store): @@ -1954,7 +1954,7 @@ def test_checkpoint_execution(executor, mock_store): mock_execution.get_new_checkpoint_token.return_value = "new-token" mock_store.load.return_value = mock_execution - result = executor.checkpoint_execution("test-arn", "token1") + result = executor.checkpoint_execution("test-arn", "new-token") assert result.checkpoint_token == "new-token" # noqa: S105 assert result.new_execution_state is None @@ -1971,7 +1971,7 @@ def test_checkpoint_execution_invalid_token(executor, mock_store): with pytest.raises( InvalidParameterValueException, match="Invalid checkpoint token" ): - executor.checkpoint_execution("test-arn", "invalid-token") + executor.checkpoint_execution("test-arn", "token1") # Callback method tests