diff --git a/src/aws_durable_execution_sdk_python_testing/execution.py b/src/aws_durable_execution_sdk_python_testing/execution.py index 51342fb..77aadbb 100644 --- a/src/aws_durable_execution_sdk_python_testing/execution.py +++ b/src/aws_durable_execution_sdk_python_testing/execution.py @@ -3,6 +3,7 @@ import json from dataclasses import replace from datetime import UTC, datetime +from threading import Lock from typing import Any from uuid import uuid4 @@ -46,11 +47,17 @@ def __init__( self.updates: list[OperationUpdate] = [] self.used_tokens: set[str] = set() # TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store - self.token_sequence: int = 0 + self._token_sequence: int = 0 + self._state_lock: Lock = Lock() self.is_complete: bool = False self.result: DurableExecutionInvocationOutput | None = None self.consecutive_failed_invocation_attempts: int = 0 + @property + def token_sequence(self) -> int: + """Get current token sequence value.""" + return self._token_sequence + @staticmethod def new(input: StartDurableExecutionInput) -> Execution: # noqa: A002 # make a nicer arn @@ -68,7 +75,7 @@ def to_dict(self) -> dict[str, Any]: "Operations": [op.to_dict() for op in self.operations], "Updates": [update.to_dict() for update in self.updates], "UsedTokens": list(self.used_tokens), - "TokenSequence": self.token_sequence, + "TokenSequence": self._token_sequence, "IsComplete": self.is_complete, "Result": self.result.to_dict() if self.result else None, "ConsecutiveFailedInvocationAttempts": self.consecutive_failed_invocation_attempts, @@ -95,7 +102,7 @@ def from_dict(cls, data: dict[str, Any]) -> Execution: OperationUpdate.from_dict(update_data) for update_data in data["Updates"] ] execution.used_tokens = set(data["UsedTokens"]) - execution.token_sequence = data["TokenSequence"] + execution._token_sequence = data["TokenSequence"] # noqa: SLF001 execution.is_complete = data["IsComplete"] execution.result = ( DurableExecutionInvocationOutput.from_dict(data["Result"]) @@ -109,23 +116,23 @@ def from_dict(cls, data: dict[str, Any]) -> Execution: return execution def start(self) -> None: - # not thread safe, prob should be if self.start_input.invocation_id is None: msg: str = "invocation_id is required" raise InvalidParameterValueException(msg) - self.operations.append( - Operation( - operation_id=self.start_input.invocation_id, - parent_id=None, - name=self.start_input.execution_name, - start_timestamp=datetime.now(UTC), - operation_type=OperationType.EXECUTION, - status=OperationStatus.STARTED, - execution_details=ExecutionDetails( - input_payload=json.dumps(self.start_input.input) - ), + with self._state_lock: + self.operations.append( + Operation( + operation_id=self.start_input.invocation_id, + parent_id=None, + name=self.start_input.execution_name, + start_timestamp=datetime.now(UTC), + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails( + input_payload=json.dumps(self.start_input.input) + ), + ) ) - ) def get_operation_execution_started(self) -> Operation: if not self.operations: @@ -137,15 +144,16 @@ def get_operation_execution_started(self) -> Operation: def get_new_checkpoint_token(self) -> str: """Generate a new checkpoint token with incremented sequence""" - # TODO: not thread safe and it should be - self.token_sequence += 1 - new_token_sequence = self.token_sequence - token = CheckpointToken( - execution_arn=self.durable_execution_arn, token_sequence=new_token_sequence - ) - token_str = token.to_str() - self.used_tokens.add(token_str) - return token_str + with self._state_lock: + self._token_sequence += 1 + new_token_sequence = self._token_sequence + token = CheckpointToken( + execution_arn=self.durable_execution_arn, + token_sequence=new_token_sequence, + ) + token_str = token.to_str() + self.used_tokens.add(token_str) + return token_str def get_navigable_operations(self) -> list[Operation]: """Get list of operations, but exclude child operations where the parent has already completed.""" @@ -205,17 +213,16 @@ def complete_wait(self, operation_id: str) -> Operation: ) raise IllegalStateException(msg_not_wait) - # TODO: make thread-safe. Increment sequence - self.token_sequence += 1 - - # Build and assign updated operation - self.operations[index] = replace( - operation, - status=OperationStatus.SUCCEEDED, - end_timestamp=datetime.now(UTC), - ) - - return self.operations[index] + # Thread-safe increment sequence and operation update + with self._state_lock: + self._token_sequence += 1 + # Build and assign updated operation + self.operations[index] = replace( + operation, + status=OperationStatus.SUCCEEDED, + end_timestamp=datetime.now(UTC), + ) + return self.operations[index] def complete_retry(self, operation_id: str) -> Operation: """Complete STEP retry when timer fires.""" @@ -231,21 +238,21 @@ def complete_retry(self, operation_id: str) -> Operation: ) raise IllegalStateException(msg_not_step) - # TODO: make thread-safe. Increment sequence - self.token_sequence += 1 - - # Build updated step_details with cleared next_attempt_timestamp - new_step_details = None - if operation.step_details: - new_step_details = replace( - operation.step_details, next_attempt_timestamp=None + # Thread-safe increment sequence and operation update + with self._state_lock: + self._token_sequence += 1 + # Build updated step_details with cleared next_attempt_timestamp + new_step_details = None + if operation.step_details: + new_step_details = replace( + operation.step_details, next_attempt_timestamp=None + ) + + # Build updated operation + updated_operation = replace( + operation, status=OperationStatus.READY, step_details=new_step_details ) - # Build updated operation - updated_operation = replace( - operation, status=OperationStatus.READY, step_details=new_step_details - ) - - # Assign - self.operations[index] = updated_operation - return updated_operation + # Assign + self.operations[index] = updated_operation + return updated_operation diff --git a/src/aws_durable_execution_sdk_python_testing/model.py b/src/aws_durable_execution_sdk_python_testing/model.py index f13e47c..4adf96d 100644 --- a/src/aws_durable_execution_sdk_python_testing/model.py +++ b/src/aws_durable_execution_sdk_python_testing/model.py @@ -19,6 +19,9 @@ StepOptions, WaitOptions, ) +from aws_durable_execution_sdk_python.types import ( + LambdaContext as LambdaContextProtocol, +) from aws_durable_execution_sdk_python_testing.exceptions import ( InvalidParameterValueException, @@ -26,7 +29,7 @@ @dataclass(frozen=True) -class LambdaContext: +class LambdaContext(LambdaContextProtocol): """Lambda context for testing.""" aws_request_id: str diff --git a/src/aws_durable_execution_sdk_python_testing/stores/memory.py b/src/aws_durable_execution_sdk_python_testing/stores/memory.py index 482bef9..9dfc91d 100644 --- a/src/aws_durable_execution_sdk_python_testing/stores/memory.py +++ b/src/aws_durable_execution_sdk_python_testing/stores/memory.py @@ -2,6 +2,7 @@ from __future__ import annotations +from threading import Lock from typing import TYPE_CHECKING @@ -14,15 +15,20 @@ class InMemoryExecutionStore: def __init__(self) -> None: self._store: dict[str, Execution] = {} + self._lock: Lock = Lock() def save(self, execution: Execution) -> None: - self._store[execution.durable_execution_arn] = execution + with self._lock: + self._store[execution.durable_execution_arn] = execution def load(self, execution_arn: str) -> Execution: - return self._store[execution_arn] + with self._lock: + return self._store[execution_arn] def update(self, execution: Execution) -> None: - self._store[execution.durable_execution_arn] = execution + with self._lock: + self._store[execution.durable_execution_arn] = execution def list_all(self) -> list[Execution]: - return list(self._store.values()) + with self._lock: + return list(self._store.values()) diff --git a/tests/execution_concurrent_test.py b/tests/execution_concurrent_test.py new file mode 100644 index 0000000..6ea2bef --- /dev/null +++ b/tests/execution_concurrent_test.py @@ -0,0 +1,83 @@ +"""Concurrent access tests for Execution class.""" + +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed + +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput + + +def test_concurrent_token_generation(): + """Test concurrent checkpoint token generation.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-inv-id", + input='{"test": "data"}', + ) + execution = Execution.new(input_data) + tokens = [] + tokens_lock = threading.Lock() + + def generate_token(): + token = execution.get_new_checkpoint_token() + with tokens_lock: + tokens.append(token) + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(generate_token) for _ in range(20)] + + for future in as_completed(futures): + future.result() + + # All tokens should be unique and sequential + assert len(tokens) == 20 + assert len(set(tokens)) == 20 # All unique + assert execution.token_sequence == 20 + + +def test_concurrent_operations_modification(): + """Test concurrent operations list modifications.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-inv-id", + input='{"test": "data"}', + ) + execution = Execution.new(input_data) + results = [] + results_lock = threading.Lock() + + def start_execution(): + execution.start() + with results_lock: + results.append("started") + + def get_operations(): + ops = execution.get_navigable_operations() + with results_lock: + results.append(f"ops-{len(ops)}") + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [] + # One start operation + futures.append(executor.submit(start_execution)) + # Multiple read operations + futures.extend([executor.submit(get_operations) for _ in range(4)]) + + for future in as_completed(futures): + future.result() + + assert len(results) == 5 + assert "started" in results + # Should have at least one operation after start + final_ops = execution.get_navigable_operations() + assert len(final_ops) >= 1 diff --git a/tests/execution_wait_retry_test.py b/tests/execution_wait_retry_test.py new file mode 100644 index 0000000..b0c9db3 --- /dev/null +++ b/tests/execution_wait_retry_test.py @@ -0,0 +1,80 @@ +"""Additional concurrent tests for wait and retry operations.""" + +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import UTC, datetime + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationStatus, + OperationType, + StepDetails, +) + +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput + + +def test_concurrent_wait_and_retry_completion(): + """Test concurrent complete_wait and complete_retry operations.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-inv-id", + input='{"test": "data"}', + ) + execution = Execution.new(input_data) + + # Add WAIT and STEP operations + wait_op = Operation( + operation_id="wait-1", + parent_id=None, + name="test-wait", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + ) + + step_op = Operation( + operation_id="step-1", + parent_id=None, + name="test-step", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=StepDetails(), + ) + + execution.operations.extend([wait_op, step_op]) + + results = [] + results_lock = threading.Lock() + + def complete_wait(): + result = execution.complete_wait("wait-1") + with results_lock: + results.append(f"wait-completed-{result.status.value}") + + def complete_retry(): + result = execution.complete_retry("step-1") + with results_lock: + results.append(f"retry-completed-{result.status.value}") + + with ThreadPoolExecutor(max_workers=2) as executor: + futures = [] + futures.append(executor.submit(complete_wait)) + futures.append(executor.submit(complete_retry)) + + for future in as_completed(futures): + future.result() + + assert len(results) == 2 + assert "wait-completed-SUCCEEDED" in results + assert "retry-completed-READY" in results + + # Verify token sequence was incremented twice + assert execution.token_sequence == 2 diff --git a/tests/stores/concurrent_test.py b/tests/stores/concurrent_test.py new file mode 100644 index 0000000..bb06e77 --- /dev/null +++ b/tests/stores/concurrent_test.py @@ -0,0 +1,109 @@ +"""Concurrent access tests for InMemoryExecutionStore.""" + +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed + +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput +from aws_durable_execution_sdk_python_testing.stores.memory import ( + InMemoryExecutionStore, +) + + +def test_concurrent_save_load(): + """Test concurrent save and load operations.""" + store = InMemoryExecutionStore() + results = [] + results_lock = threading.Lock() + + def save_execution(i: int): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"test-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"inv-{i}", + input=f'{{"test": {i}}}', + ) + execution = Execution.new(input_data) + execution.durable_execution_arn = f"arn-{i}" + store.save(execution) + with results_lock: + results.append(f"saved-{i}") + + def load_execution(i: int): + try: + execution = store.load(f"arn-{i}") + with results_lock: + results.append(f"loaded-{execution.start_input.execution_name}") + except KeyError: + with results_lock: + results.append(f"not-found-{i}") + + with ThreadPoolExecutor(max_workers=10) as executor: + # Submit save operations first + futures = [executor.submit(save_execution, i) for i in range(5)] + # Wait for saves to complete + for future in as_completed(futures): + future.result() + + # Then submit load operations + futures = [] + for i in range(5): + futures.append(executor.submit(load_execution, i)) + # Wait for loads to complete + for future in as_completed(futures): + future.result() + + assert len(results) == 10 + + +def test_concurrent_update_list(): + """Test concurrent update and list operations.""" + store = InMemoryExecutionStore() + results = [] + results_lock = threading.Lock() + + # Pre-populate store + for i in range(3): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"test-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"inv-{i}", + input=f'{{"test": {i}}}', + ) + execution = Execution.new(input_data) + execution.durable_execution_arn = f"arn-{i}" + store.save(execution) + + def update_execution(i: int): + execution = store.load(f"arn-{i}") + execution.is_complete = True + store.update(execution) + with results_lock: + results.append(f"updated-{i}") + + def list_executions(): + executions = store.list_all() + with results_lock: + results.append(f"listed-{len(executions)}") + + with ThreadPoolExecutor(max_workers=6) as executor: + # Submit update operations + futures = [executor.submit(update_execution, i) for i in range(3)] + # Submit list operations + futures.extend([executor.submit(list_executions) for _ in range(3)]) + + # Wait for all operations to complete + for future in as_completed(futures): + future.result() + + assert len(results) == 6 + final_list = store.list_all() + assert len(final_list) == 3 diff --git a/tests/stores/filesystem_store_test.py b/tests/stores/filesystem_store_test.py index 1eb1538..b80c86f 100644 --- a/tests/stores/filesystem_store_test.py +++ b/tests/stores/filesystem_store_test.py @@ -76,7 +76,8 @@ def test_filesystem_execution_store_update(store, sample_execution): store.save(sample_execution) sample_execution.is_complete = True - sample_execution.token_sequence = 5 + for _ in range(5): + sample_execution.get_new_checkpoint_token() store.update(sample_execution) loaded_execution = store.load(sample_execution.durable_execution_arn) @@ -97,7 +98,8 @@ def test_filesystem_execution_store_update_overwrites(store, temp_storage_dir): execution1 = Execution.new(input_data) execution2 = Execution.new(input_data) execution2.durable_execution_arn = execution1.durable_execution_arn - execution2.token_sequence = 10 + for _ in range(10): + execution2.get_new_checkpoint_token() store.save(execution1) store.update(execution2)