Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 59 additions & 52 deletions src/aws_durable_execution_sdk_python_testing/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from dataclasses import replace
from datetime import UTC, datetime
from threading import Lock
from typing import Any
from uuid import uuid4

Expand Down Expand Up @@ -46,11 +47,17 @@ def __init__(
self.updates: list[OperationUpdate] = []
self.used_tokens: set[str] = set()
# TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
self.token_sequence: int = 0
self._token_sequence: int = 0
self._state_lock: Lock = Lock()
self.is_complete: bool = False
self.result: DurableExecutionInvocationOutput | None = None
self.consecutive_failed_invocation_attempts: int = 0

@property
def token_sequence(self) -> int:
    """Return the current checkpoint token sequence value.

    Read-only view of the private ``_token_sequence`` counter; the counter
    itself is only modified by methods of this class.
    """
    return self._token_sequence

@staticmethod
def new(input: StartDurableExecutionInput) -> Execution: # noqa: A002
# make a nicer arn
Expand All @@ -68,7 +75,7 @@ def to_dict(self) -> dict[str, Any]:
"Operations": [op.to_dict() for op in self.operations],
"Updates": [update.to_dict() for update in self.updates],
"UsedTokens": list(self.used_tokens),
"TokenSequence": self.token_sequence,
"TokenSequence": self._token_sequence,
"IsComplete": self.is_complete,
"Result": self.result.to_dict() if self.result else None,
"ConsecutiveFailedInvocationAttempts": self.consecutive_failed_invocation_attempts,
Expand All @@ -95,7 +102,7 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
]
execution.used_tokens = set(data["UsedTokens"])
execution.token_sequence = data["TokenSequence"]
execution._token_sequence = data["TokenSequence"] # noqa: SLF001
execution.is_complete = data["IsComplete"]
execution.result = (
DurableExecutionInvocationOutput.from_dict(data["Result"])
Expand All @@ -109,23 +116,23 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
return execution

def start(self) -> None:
    """Record the start of this execution as an EXECUTION operation.

    Appends an ``Operation`` of type ``OperationType.EXECUTION`` with status
    ``OperationStatus.STARTED`` to ``self.operations``. The operation id is
    the invocation id from ``self.start_input`` and the input is stored as a
    JSON-encoded payload.

    Raises:
        InvalidParameterValueException: If ``start_input.invocation_id`` is
            None.
    """
    if self.start_input.invocation_id is None:
        msg: str = "invocation_id is required"
        raise InvalidParameterValueException(msg)

    # Build the operation before taking the lock so the critical section
    # covers only the mutation of the shared operations list.
    started_operation = Operation(
        operation_id=self.start_input.invocation_id,
        parent_id=None,
        name=self.start_input.execution_name,
        start_timestamp=datetime.now(UTC),
        operation_type=OperationType.EXECUTION,
        status=OperationStatus.STARTED,
        execution_details=ExecutionDetails(
            input_payload=json.dumps(self.start_input.input)
        ),
    )
    with self._state_lock:
        self.operations.append(started_operation)

def get_operation_execution_started(self) -> Operation:
if not self.operations:
Expand All @@ -137,15 +144,16 @@ def get_operation_execution_started(self) -> Operation:

def get_new_checkpoint_token(self) -> str:
    """Generate a new checkpoint token with an incremented sequence.

    The sequence increment, token construction and registration in
    ``self.used_tokens`` all happen while holding ``self._state_lock`` so
    that concurrent callers each observe a distinct sequence value.

    Returns:
        The string form of the newly created ``CheckpointToken``.
    """
    with self._state_lock:
        self._token_sequence += 1
        token = CheckpointToken(
            execution_arn=self.durable_execution_arn,
            token_sequence=self._token_sequence,
        )
        token_str = token.to_str()
        self.used_tokens.add(token_str)
        return token_str

def get_navigable_operations(self) -> list[Operation]:
"""Get list of operations, but exclude child operations where the parent has already completed."""
Expand Down Expand Up @@ -205,17 +213,16 @@ def complete_wait(self, operation_id: str) -> Operation:
)
raise IllegalStateException(msg_not_wait)

# TODO: make thread-safe. Increment sequence
self.token_sequence += 1

# Build and assign updated operation
self.operations[index] = replace(
operation,
status=OperationStatus.SUCCEEDED,
end_timestamp=datetime.now(UTC),
)

return self.operations[index]
# Thread-safe increment sequence and operation update
with self._state_lock:
self._token_sequence += 1
# Build and assign updated operation
self.operations[index] = replace(
operation,
status=OperationStatus.SUCCEEDED,
end_timestamp=datetime.now(UTC),
)
return self.operations[index]

def complete_retry(self, operation_id: str) -> Operation:
"""Complete STEP retry when timer fires."""
Expand All @@ -231,21 +238,21 @@ def complete_retry(self, operation_id: str) -> Operation:
)
raise IllegalStateException(msg_not_step)

# TODO: make thread-safe. Increment sequence
self.token_sequence += 1

# Build updated step_details with cleared next_attempt_timestamp
new_step_details = None
if operation.step_details:
new_step_details = replace(
operation.step_details, next_attempt_timestamp=None
# Thread-safe increment sequence and operation update
with self._state_lock:
self._token_sequence += 1
# Build updated step_details with cleared next_attempt_timestamp
new_step_details = None
if operation.step_details:
new_step_details = replace(
operation.step_details, next_attempt_timestamp=None
)

# Build updated operation
updated_operation = replace(
operation, status=OperationStatus.READY, step_details=new_step_details
)

# Build updated operation
updated_operation = replace(
operation, status=OperationStatus.READY, step_details=new_step_details
)

# Assign
self.operations[index] = updated_operation
return updated_operation
# Assign
self.operations[index] = updated_operation
return updated_operation
5 changes: 4 additions & 1 deletion src/aws_durable_execution_sdk_python_testing/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@
StepOptions,
WaitOptions,
)
from aws_durable_execution_sdk_python.types import (
LambdaContext as LambdaContextProtocol,
)

from aws_durable_execution_sdk_python_testing.exceptions import (
InvalidParameterValueException,
)


@dataclass(frozen=True)
class LambdaContext:
class LambdaContext(LambdaContextProtocol):
"""Lambda context for testing."""

aws_request_id: str
Expand Down
14 changes: 10 additions & 4 deletions src/aws_durable_execution_sdk_python_testing/stores/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from threading import Lock
from typing import TYPE_CHECKING


Expand All @@ -14,15 +15,20 @@ class InMemoryExecutionStore:

def __init__(self) -> None:
    """Create an empty store together with the lock that guards it."""
    # Maps durable execution ARN -> Execution.
    self._store: dict[str, Execution] = {}
    # Held around every read and write of ``_store``.
    self._lock: Lock = Lock()

def save(self, execution: Execution) -> None:
    """Store ``execution`` under its ``durable_execution_arn``.

    An existing entry with the same ARN is overwritten. The write happens
    while holding the store lock.
    """
    with self._lock:
        self._store[execution.durable_execution_arn] = execution

def load(self, execution_arn: str) -> Execution:
    """Return the execution stored under ``execution_arn``.

    The lookup happens while holding the store lock.

    Raises:
        KeyError: If no execution is stored under ``execution_arn``.
    """
    with self._lock:
        return self._store[execution_arn]

def update(self, execution: Execution) -> None:
    """Replace the stored entry for ``execution.durable_execution_arn``.

    The entry is created if it does not exist yet. The write happens while
    holding the store lock.
    """
    with self._lock:
        self._store[execution.durable_execution_arn] = execution

def list_all(self) -> list[Execution]:
    """Return a new list containing every stored execution.

    The snapshot is taken while holding the store lock; the returned list is
    a copy, so mutating it does not affect the store.
    """
    with self._lock:
        return list(self._store.values())
83 changes: 83 additions & 0 deletions tests/execution_concurrent_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Concurrent access tests for Execution class."""

import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

from aws_durable_execution_sdk_python_testing.execution import Execution
from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput


def test_concurrent_token_generation():
    """Test concurrent checkpoint token generation."""
    input_data = StartDurableExecutionInput(
        account_id="123456789012",
        function_name="test-function",
        function_qualifier="$LATEST",
        execution_name="test-execution",
        execution_timeout_seconds=300,
        execution_retention_period_days=7,
        invocation_id="test-inv-id",
        input='{"test": "data"}',
    )
    execution = Execution.new(input_data)
    tokens = []
    tokens_lock = threading.Lock()

    def generate_token():
        token = execution.get_new_checkpoint_token()
        with tokens_lock:
            tokens.append(token)

    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(generate_token) for _ in range(20)]

        # Surface any exception raised inside a worker thread.
        for future in as_completed(futures):
            future.result()

    # Every call must yield a distinct token, and the sequence counter must
    # have advanced exactly once per call (no lost updates).
    assert len(tokens) == 20
    assert len(set(tokens)) == 20  # All unique
    assert execution.token_sequence == 20
    # Each generated token must also have been recorded as used.
    assert execution.used_tokens == set(tokens)


def test_concurrent_operations_modification():
    """Test concurrent operations list modifications."""
    input_data = StartDurableExecutionInput(
        account_id="123456789012",
        function_name="test-function",
        function_qualifier="$LATEST",
        execution_name="test-execution",
        execution_timeout_seconds=300,
        execution_retention_period_days=7,
        invocation_id="test-inv-id",
        input='{"test": "data"}',
    )
    execution = Execution.new(input_data)
    results = []
    results_lock = threading.Lock()

    def start_execution():
        execution.start()
        with results_lock:
            results.append("started")

    def get_operations():
        ops = execution.get_navigable_operations()
        with results_lock:
            results.append(f"ops-{len(ops)}")

    with ThreadPoolExecutor(max_workers=5) as executor:
        # One start operation
        futures = [executor.submit(start_execution)]
        # Multiple read operations racing with the start
        futures.extend(executor.submit(get_operations) for _ in range(4))

        # Surface any exception raised inside a worker thread.
        for future in as_completed(futures):
            future.result()

    assert len(results) == 5
    # Exactly one writer ran; every other entry comes from a reader.
    assert results.count("started") == 1
    assert sum(entry.startswith("ops-") for entry in results) == 4
    # Should have at least one operation after start
    final_ops = execution.get_navigable_operations()
    assert len(final_ops) >= 1
80 changes: 80 additions & 0 deletions tests/execution_wait_retry_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Additional concurrent tests for wait and retry operations."""

import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import UTC, datetime

from aws_durable_execution_sdk_python.lambda_service import (
Operation,
OperationStatus,
OperationType,
StepDetails,
)

from aws_durable_execution_sdk_python_testing.execution import Execution
from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput


def test_concurrent_wait_and_retry_completion():
    """Test concurrent complete_wait and complete_retry operations."""
    input_data = StartDurableExecutionInput(
        account_id="123456789012",
        function_name="test-function",
        function_qualifier="$LATEST",
        execution_name="test-execution",
        execution_timeout_seconds=300,
        execution_retention_period_days=7,
        invocation_id="test-inv-id",
        input='{"test": "data"}',
    )
    execution = Execution.new(input_data)

    # Add WAIT and STEP operations
    wait_op = Operation(
        operation_id="wait-1",
        parent_id=None,
        name="test-wait",
        start_timestamp=datetime.now(UTC),
        operation_type=OperationType.WAIT,
        status=OperationStatus.STARTED,
    )

    step_op = Operation(
        operation_id="step-1",
        parent_id=None,
        name="test-step",
        start_timestamp=datetime.now(UTC),
        operation_type=OperationType.STEP,
        status=OperationStatus.PENDING,
        step_details=StepDetails(),
    )

    execution.operations.extend([wait_op, step_op])

    results = []
    results_lock = threading.Lock()

    def complete_wait():
        result = execution.complete_wait("wait-1")
        with results_lock:
            results.append(f"wait-completed-{result.status.value}")

    def complete_retry():
        result = execution.complete_retry("step-1")
        with results_lock:
            results.append(f"retry-completed-{result.status.value}")

    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = [
            executor.submit(complete_wait),
            executor.submit(complete_retry),
        ]

        # Surface any exception raised inside a worker thread.
        for future in as_completed(futures):
            future.result()

    assert len(results) == 2
    assert "wait-completed-SUCCEEDED" in results
    assert "retry-completed-READY" in results

    # The updated operations must also be the ones stored on the execution,
    # not just the values returned to the callers.
    stored_status = {op.operation_id: op.status for op in execution.operations}
    assert stored_status["wait-1"] == OperationStatus.SUCCEEDED
    assert stored_status["step-1"] == OperationStatus.READY

    # Verify token sequence was incremented twice
    assert execution.token_sequence == 2
Loading