From af0522b3350e252ef72db8bc838823357d7be110 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Fri, 19 Sep 2025 17:12:34 -0700 Subject: [PATCH] chore: use updated sdk model --- .../lambda_service.py | 20 ++++++-- src/aws_durable_execution_sdk_python/state.py | 2 + tests/e2e/execution_int_test.py | 28 +++++++++-- tests/lambda_service_test.py | 48 +++++++++++++------ tests/state_test.py | 23 +++++++-- 5 files changed, 95 insertions(+), 26 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py index b421c01..8f99f5d 100644 --- a/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/src/aws_durable_execution_sdk_python/lambda_service.py @@ -780,13 +780,18 @@ class DurableServiceClient(Protocol): def checkpoint( self, + durable_execution_arn: str, checkpoint_token: str, updates: list[OperationUpdate], client_token: str | None, ) -> CheckpointOutput: ... # pragma: no cover def get_execution_state( - self, checkpoint_token: str, next_marker: str, max_items: int = 1000 + self, + durable_execution_arn: str, + checkpoint_token: str, + next_marker: str, + max_items: int = 1000, ) -> StateOutput: ... # pragma: no cover def stop( @@ -866,12 +871,14 @@ def initialize_from_env() -> LambdaClient: def checkpoint( self, + durable_execution_arn: str, checkpoint_token: str, updates: list[OperationUpdate], client_token: str | None, ) -> CheckpointOutput: try: params = { + "DurableExecutionArn": durable_execution_arn, "CheckpointToken": checkpoint_token, "Updates": [o.to_dict() for o in updates], } @@ -888,10 +895,17 @@ def checkpoint( raise CheckpointError(e) from e def get_execution_state( - self, checkpoint_token: str, next_marker: str, max_items: int = 1000 + self, + durable_execution_arn: str, + checkpoint_token: str, + next_marker: str, + max_items: int = 1000, ) -> StateOutput: result: MutableMapping[str, Any] = self.client.get_durable_execution_state( - CheckpointToken=checkpoint_token, Marker=next_marker, MaxItems=max_items + DurableExecutionArn=durable_execution_arn, + CheckpointToken=checkpoint_token, + Marker=next_marker, + MaxItems=max_items, ) return StateOutput.from_dict(result) diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index dd101a0..ff43e87 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -162,6 +162,7 @@ def fetch_paginated_operations( ) while next_marker: output: StateOutput = self._service_client.get_execution_state( + durable_execution_arn=self.durable_execution_arn, checkpoint_token=checkpoint_token, next_marker=next_marker, ) @@ -227,6 +228,7 @@ def create_checkpoint( [operation_update] if operation_update is not None else [] ) output: CheckpointOutput = self._service_client.checkpoint( + durable_execution_arn=self.durable_execution_arn, checkpoint_token=self._current_checkpoint_token, updates=updates, client_token=None, diff --git a/tests/e2e/execution_int_test.py b/tests/e2e/execution_int_test.py index a236bdc..5ff8d64 100644 --- a/tests/e2e/execution_int_test.py +++ b/tests/e2e/execution_int_test.py @@ -65,7 +65,12 @@ def my_handler(event, context: DurableContext) -> list[str]: # Mock the checkpoint method to track calls checkpoint_calls = [] - def mock_checkpoint(checkpoint_token, updates, client_token="token"): # noqa: S107 + def mock_checkpoint( + durable_execution_arn, + checkpoint_token, + updates, + client_token="token", # noqa: S107 + ): checkpoint_calls.append(updates) return CheckpointOutput( @@ -142,7 +147,12 @@ def my_handler(event, context: DurableContext): # Mock the checkpoint method to track calls checkpoint_calls = [] - def mock_checkpoint(checkpoint_token, updates, client_token="token"): # noqa: S107 + def mock_checkpoint( + durable_execution_arn, + checkpoint_token, + updates, + client_token="token", # noqa: S107 + ): checkpoint_calls.append(updates) return CheckpointOutput( @@ -224,7 +234,12 @@ def my_handler(event, context): # Mock the checkpoint method to track calls checkpoint_calls = [] - def mock_checkpoint(checkpoint_token, updates, client_token="token"): # noqa: S107 + def mock_checkpoint( + durable_execution_arn, + checkpoint_token, + updates, + client_token="token", # noqa: S107 + ): checkpoint_calls.append(updates) return CheckpointOutput( @@ -310,7 +325,12 @@ def my_handler(event: Any, context: DurableContext): # Mock the checkpoint method to track calls checkpoint_calls = [] - def mock_checkpoint(checkpoint_token, updates, client_token="token"): # noqa: S107 + def mock_checkpoint( + durable_execution_arn, + checkpoint_token, + updates, + client_token="token", # noqa: S107 + ): checkpoint_calls.append(updates) return CheckpointOutput( diff --git a/tests/lambda_service_test.py b/tests/lambda_service_test.py index 137b66d..328e536 100644 --- a/tests/lambda_service_test.py +++ b/tests/lambda_service_test.py @@ -838,10 +838,12 @@ def test_lambda_client_checkpoint(): action=OperationAction.START, ) - result = lambda_client.checkpoint("token123", [update], None) + result = lambda_client.checkpoint("arn123", "token123", [update], None) mock_client.checkpoint_durable_execution.assert_called_once_with( - CheckpointToken="token123", Updates=[update.to_dict()] + DurableExecutionArn="arn123", + CheckpointToken="token123", + Updates=[update.to_dict()], ) assert isinstance(result, CheckpointOutput) assert result.checkpoint_token == "new_token" # noqa: S105 @@ -862,9 +864,12 @@ def test_lambda_client_checkpoint_with_client_token(): action=OperationAction.START, ) - result = lambda_client.checkpoint("token123", [update], "client-token-123") + result = lambda_client.checkpoint( + "arn123", "token123", [update], "client-token-123" + ) mock_client.checkpoint_durable_execution.assert_called_once_with( + DurableExecutionArn="arn123", CheckpointToken="token123", Updates=[update.to_dict()], ClientToken="client-token-123", @@ -888,10 +893,12 @@ def test_lambda_client_checkpoint_with_explicit_none_client_token(): action=OperationAction.START, ) - result = lambda_client.checkpoint("token123", [update], None) + result = lambda_client.checkpoint("arn123", "token123", [update], None) mock_client.checkpoint_durable_execution.assert_called_once_with( - CheckpointToken="token123", Updates=[update.to_dict()] + DurableExecutionArn="arn123", + CheckpointToken="token123", + Updates=[update.to_dict()], ) assert isinstance(result, CheckpointOutput) assert result.checkpoint_token == "new_token" # noqa: S105 @@ -912,10 +919,13 @@ def test_lambda_client_checkpoint_with_empty_string_client_token(): action=OperationAction.START, ) - result = lambda_client.checkpoint("token123", [update], "") + result = lambda_client.checkpoint("arn123", "token123", [update], "") mock_client.checkpoint_durable_execution.assert_called_once_with( - CheckpointToken="token123", Updates=[update.to_dict()], ClientToken="" + DurableExecutionArn="arn123", + CheckpointToken="token123", + Updates=[update.to_dict()], + ClientToken="", ) assert isinstance(result, CheckpointOutput) assert result.checkpoint_token == "new_token" # noqa: S105 @@ -936,9 +946,10 @@ def test_lambda_client_checkpoint_with_string_value_client_token(): action=OperationAction.START, ) - result = lambda_client.checkpoint("token123", [update], "my-client-token") + result = lambda_client.checkpoint("arn123", "token123", [update], "my-client-token") mock_client.checkpoint_durable_execution.assert_called_once_with( + DurableExecutionArn="arn123", CheckpointToken="token123", Updates=[update.to_dict()], ClientToken="my-client-token", @@ -960,7 +971,7 @@ def test_lambda_client_checkpoint_with_exception(): ) with pytest.raises(CheckpointError): - lambda_client.checkpoint("token123", [update], None) + lambda_client.checkpoint("arn123", "token123", [update], None) def test_lambda_client_get_execution_state(): @@ -971,10 +982,13 @@ def test_lambda_client_get_execution_state(): } lambda_client = LambdaClient(mock_client) - result = lambda_client.get_execution_state("token123", "marker", 500) + result = lambda_client.get_execution_state("arn123", "token123", "marker", 500) mock_client.get_durable_execution_state.assert_called_once_with( - CheckpointToken="token123", Marker="marker", MaxItems=500 + DurableExecutionArn="arn123", + CheckpointToken="token123", + Marker="marker", + MaxItems=500, ) assert len(result.operations) == 1 @@ -1018,9 +1032,11 @@ def test_durable_service_client_protocol_checkpoint(): ) ] - result = mock_client.checkpoint("token", updates, "client_token") + result = mock_client.checkpoint("arn123", "token", updates, "client_token") - mock_client.checkpoint.assert_called_once_with("token", updates, "client_token") + mock_client.checkpoint.assert_called_once_with( + "arn123", "token", updates, "client_token" + ) assert result == mock_output @@ -1030,9 +1046,11 @@ def test_durable_service_client_protocol_get_execution_state(): mock_output = StateOutput(operations=[], next_marker="marker") mock_client.get_execution_state.return_value = mock_output - result = mock_client.get_execution_state("token", "marker", 1000) + result = mock_client.get_execution_state("arn123", "token", "marker", 1000) - mock_client.get_execution_state.assert_called_once_with("token", "marker", 1000) + mock_client.get_execution_state.assert_called_once_with( + "arn123", "token", "marker", 1000 + ) assert result == mock_output diff --git a/tests/state_test.py b/tests/state_test.py index e17aca6..6a952f1 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -384,6 +384,7 @@ def test_create_checkpoint(): # Verify the checkpoint was called mock_lambda_client.checkpoint.assert_called_once_with( + durable_execution_arn="test_arn", checkpoint_token="token123", # noqa: S106 updates=[operation_update], client_token=None, @@ -416,6 +417,7 @@ def test_create_checkpoint_with_none(): # Verify the checkpoint was called with empty updates mock_lambda_client.checkpoint.assert_called_once_with( + durable_execution_arn="test_arn", checkpoint_token="token123", # noqa: S106 updates=[], client_token=None, @@ -444,6 +446,7 @@ def test_create_checkpoint_with_no_args(): # Verify the checkpoint was called with empty updates mock_lambda_client.checkpoint.assert_called_once_with( + durable_execution_arn="test_arn", checkpoint_token="token123", # noqa: S106 updates=[], client_token=None, @@ -514,7 +517,7 @@ def test_checkpointed_result_is_timed_out_false_for_other_statuses(): def test_fetch_paginated_operations_with_marker(): mock_lambda_client = Mock(spec=LambdaClient) - def mock_get_execution_state(checkpoint_token, next_marker): + def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marker): resp = { "marker1": StateOutput( operations=[ @@ -573,9 +576,21 @@ def mock_get_execution_state(checkpoint_token, next_marker): assert mock_lambda_client.get_execution_state.call_count == 3 mock_lambda_client.get_execution_state.assert_has_calls( [ - call(checkpoint_token="test_token", next_marker="marker1"), # noqa: S106 - call(checkpoint_token="test_token", next_marker="marker2"), # noqa: S106 - call(checkpoint_token="test_token", next_marker="marker3"), # noqa: S106 + call( + durable_execution_arn="test_arn", + checkpoint_token="test_token", # noqa: S106 + next_marker="marker1", + ), + call( + durable_execution_arn="test_arn", + checkpoint_token="test_token", # noqa: S106 + next_marker="marker2", + ), + call( + durable_execution_arn="test_arn", + checkpoint_token="test_token", # noqa: S106 + next_marker="marker3", + ), ] )