Skip to content

Commit d539e2b

Browse files
author
Rares Polenciuc
committed
Add thread safety to execution operations and unit tests
Added thread-safe synchronization to the Execution and InMemoryExecutionStore classes using the standard threading.Lock.
- execution.py: Added _state_lock (Lock) to synchronize all state modifications, including token sequence increments, operations list updates, and used-token modifications, in start(), get_new_checkpoint_token(), complete_wait(), and complete_retry().
- stores/memory.py: Added _lock (Lock) to ensure atomic operations for the save(), load(), update(), and list_all() methods.
- Added unit tests for execution operations:
  - tests/execution_concurrent_test.py: concurrent access tests verifying thread-safe operations under multi-threaded scenarios.
  - tests/execution_wait_retry_test.py: tests for wait and retry operations.
  - tests/stores/concurrent_test.py: concurrent access tests for InMemoryExecutionStore verifying thread-safe operations.
- Made token_sequence a read-only property without a setter, to encapsulate mutations within the class.
1 parent be4af22 commit d539e2b

File tree

5 files changed

+341
-56
lines changed

5 files changed

+341
-56
lines changed

src/aws_durable_execution_sdk_python_testing/execution.py

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
from dataclasses import replace
55
from datetime import UTC, datetime
6+
from threading import Lock
67
from typing import Any
78
from uuid import uuid4
89

@@ -46,11 +47,17 @@ def __init__(
4647
self.updates: list[OperationUpdate] = []
4748
self.used_tokens: set[str] = set()
4849
# TODO: this will need to persist/rehydrate depending on in-memory vs SQLite store
49-
self.token_sequence: int = 0
50+
self._token_sequence: int = 0
51+
self._state_lock: Lock = Lock()
5052
self.is_complete: bool = False
5153
self.result: DurableExecutionInvocationOutput | None = None
5254
self.consecutive_failed_invocation_attempts: int = 0
5355

56+
@property
57+
def token_sequence(self) -> int:
58+
"""Get current token sequence value."""
59+
return self._token_sequence
60+
5461
@staticmethod
5562
def new(input: StartDurableExecutionInput) -> Execution: # noqa: A002
5663
# make a nicer arn
@@ -68,7 +75,7 @@ def to_dict(self) -> dict[str, Any]:
6875
"Operations": [op.to_dict() for op in self.operations],
6976
"Updates": [update.to_dict() for update in self.updates],
7077
"UsedTokens": list(self.used_tokens),
71-
"TokenSequence": self.token_sequence,
78+
"TokenSequence": self._token_sequence,
7279
"IsComplete": self.is_complete,
7380
"Result": self.result.to_dict() if self.result else None,
7481
"ConsecutiveFailedInvocationAttempts": self.consecutive_failed_invocation_attempts,
@@ -95,7 +102,7 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
95102
OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
96103
]
97104
execution.used_tokens = set(data["UsedTokens"])
98-
execution.token_sequence = data["TokenSequence"]
105+
execution._token_sequence = data["TokenSequence"] # noqa: SLF001
99106
execution.is_complete = data["IsComplete"]
100107
execution.result = (
101108
DurableExecutionInvocationOutput.from_dict(data["Result"])
@@ -109,23 +116,23 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
109116
return execution
110117

111118
def start(self) -> None:
112-
# not thread safe, prob should be
113119
if self.start_input.invocation_id is None:
114120
msg: str = "invocation_id is required"
115121
raise InvalidParameterValueException(msg)
116-
self.operations.append(
117-
Operation(
118-
operation_id=self.start_input.invocation_id,
119-
parent_id=None,
120-
name=self.start_input.execution_name,
121-
start_timestamp=datetime.now(UTC),
122-
operation_type=OperationType.EXECUTION,
123-
status=OperationStatus.STARTED,
124-
execution_details=ExecutionDetails(
125-
input_payload=json.dumps(self.start_input.input)
126-
),
122+
with self._state_lock:
123+
self.operations.append(
124+
Operation(
125+
operation_id=self.start_input.invocation_id,
126+
parent_id=None,
127+
name=self.start_input.execution_name,
128+
start_timestamp=datetime.now(UTC),
129+
operation_type=OperationType.EXECUTION,
130+
status=OperationStatus.STARTED,
131+
execution_details=ExecutionDetails(
132+
input_payload=json.dumps(self.start_input.input)
133+
),
134+
)
127135
)
128-
)
129136

130137
def get_operation_execution_started(self) -> Operation:
131138
if not self.operations:
@@ -137,15 +144,16 @@ def get_operation_execution_started(self) -> Operation:
137144

138145
def get_new_checkpoint_token(self) -> str:
139146
"""Generate a new checkpoint token with incremented sequence"""
140-
# TODO: not thread safe and it should be
141-
self.token_sequence += 1
142-
new_token_sequence = self.token_sequence
143-
token = CheckpointToken(
144-
execution_arn=self.durable_execution_arn, token_sequence=new_token_sequence
145-
)
146-
token_str = token.to_str()
147-
self.used_tokens.add(token_str)
148-
return token_str
147+
with self._state_lock:
148+
self._token_sequence += 1
149+
new_token_sequence = self._token_sequence
150+
token = CheckpointToken(
151+
execution_arn=self.durable_execution_arn,
152+
token_sequence=new_token_sequence,
153+
)
154+
token_str = token.to_str()
155+
self.used_tokens.add(token_str)
156+
return token_str
149157

150158
def get_navigable_operations(self) -> list[Operation]:
151159
"""Get list of operations, but exclude child operations where the parent has already completed."""
@@ -205,17 +213,16 @@ def complete_wait(self, operation_id: str) -> Operation:
205213
)
206214
raise IllegalStateException(msg_not_wait)
207215

208-
# TODO: make thread-safe. Increment sequence
209-
self.token_sequence += 1
210-
211-
# Build and assign updated operation
212-
self.operations[index] = replace(
213-
operation,
214-
status=OperationStatus.SUCCEEDED,
215-
end_timestamp=datetime.now(UTC),
216-
)
217-
218-
return self.operations[index]
216+
# Thread-safe increment sequence and operation update
217+
with self._state_lock:
218+
self._token_sequence += 1
219+
# Build and assign updated operation
220+
self.operations[index] = replace(
221+
operation,
222+
status=OperationStatus.SUCCEEDED,
223+
end_timestamp=datetime.now(UTC),
224+
)
225+
return self.operations[index]
219226

220227
def complete_retry(self, operation_id: str) -> Operation:
221228
"""Complete STEP retry when timer fires."""
@@ -231,21 +238,21 @@ def complete_retry(self, operation_id: str) -> Operation:
231238
)
232239
raise IllegalStateException(msg_not_step)
233240

234-
# TODO: make thread-safe. Increment sequence
235-
self.token_sequence += 1
236-
237-
# Build updated step_details with cleared next_attempt_timestamp
238-
new_step_details = None
239-
if operation.step_details:
240-
new_step_details = replace(
241-
operation.step_details, next_attempt_timestamp=None
241+
# Thread-safe increment sequence and operation update
242+
with self._state_lock:
243+
self._token_sequence += 1
244+
# Build updated step_details with cleared next_attempt_timestamp
245+
new_step_details = None
246+
if operation.step_details:
247+
new_step_details = replace(
248+
operation.step_details, next_attempt_timestamp=None
249+
)
250+
251+
# Build updated operation
252+
updated_operation = replace(
253+
operation, status=OperationStatus.READY, step_details=new_step_details
242254
)
243255

244-
# Build updated operation
245-
updated_operation = replace(
246-
operation, status=OperationStatus.READY, step_details=new_step_details
247-
)
248-
249-
# Assign
250-
self.operations[index] = updated_operation
251-
return updated_operation
256+
# Assign
257+
self.operations[index] = updated_operation
258+
return updated_operation

src/aws_durable_execution_sdk_python_testing/stores/memory.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from threading import Lock
56
from typing import TYPE_CHECKING
67

78

@@ -14,15 +15,20 @@ class InMemoryExecutionStore:
1415

1516
def __init__(self) -> None:
1617
self._store: dict[str, Execution] = {}
18+
self._lock: Lock = Lock()
1719

1820
def save(self, execution: Execution) -> None:
19-
self._store[execution.durable_execution_arn] = execution
21+
with self._lock:
22+
self._store[execution.durable_execution_arn] = execution
2023

2124
def load(self, execution_arn: str) -> Execution:
22-
return self._store[execution_arn]
25+
with self._lock:
26+
return self._store[execution_arn]
2327

2428
def update(self, execution: Execution) -> None:
25-
self._store[execution.durable_execution_arn] = execution
29+
with self._lock:
30+
self._store[execution.durable_execution_arn] = execution
2631

2732
def list_all(self) -> list[Execution]:
28-
return list(self._store.values())
33+
with self._lock:
34+
return list(self._store.values())

tests/execution_concurrent_test.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""Concurrent access tests for Execution class."""
2+
3+
import threading
4+
from concurrent.futures import ThreadPoolExecutor, as_completed
5+
6+
from aws_durable_execution_sdk_python_testing.execution import Execution
7+
from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
8+
9+
10+
def test_concurrent_token_generation():
    """Generate checkpoint tokens from many threads; none may be lost or duplicated.

    Twenty calls to ``get_new_checkpoint_token`` are spread over ten worker
    threads that share one ``Execution``. Afterwards every call must have
    returned a distinct token, every returned token must be recorded in
    ``used_tokens``, and the sequence counter must equal the number of calls.
    """
    token_count = 20
    input_data = StartDurableExecutionInput(
        account_id="123456789012",
        function_name="test-function",
        function_qualifier="$LATEST",
        execution_name="test-execution",
        execution_timeout_seconds=300,
        execution_retention_period_days=7,
        invocation_id="test-inv-id",
        input='{"test": "data"}',
    )
    execution = Execution.new(input_data)
    tokens = []
    tokens_lock = threading.Lock()

    def generate_token():
        token = execution.get_new_checkpoint_token()
        # Guard the test's own bookkeeping so it cannot be the source of a race.
        with tokens_lock:
            tokens.append(token)

    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(generate_token) for _ in range(token_count)]

        for future in as_completed(futures):
            future.result()  # re-raises any exception raised inside a worker

    # Every call returned a token and no two calls returned the same one.
    assert len(tokens) == token_count
    assert len(set(tokens)) == token_count
    # Each issued token was recorded as used, and nothing else was.
    # NOTE(review): assumes Execution.new() starts with an empty used_tokens set.
    assert execution.used_tokens == set(tokens)
    # One increment per call: the counter ends at the number of calls.
    assert execution.token_sequence == token_count
41+
42+
43+
def test_concurrent_operations_modification():
    """Run one ``start()`` alongside several navigable-operation reads.

    A single writer thread and four reader threads share one ``Execution``.
    All five must finish without raising, and at least one navigable
    operation must exist once they are done.
    """
    start_input = StartDurableExecutionInput(
        account_id="123456789012",
        function_name="test-function",
        function_qualifier="$LATEST",
        execution_name="test-execution",
        execution_timeout_seconds=300,
        execution_retention_period_days=7,
        invocation_id="test-inv-id",
        input='{"test": "data"}',
    )
    execution = Execution.new(start_input)
    outcomes = []
    outcomes_guard = threading.Lock()

    def start_execution():
        execution.start()
        with outcomes_guard:
            outcomes.append("started")

    def get_operations():
        ops = execution.get_navigable_operations()
        with outcomes_guard:
            outcomes.append(f"ops-{len(ops)}")

    with ThreadPoolExecutor(max_workers=5) as pool:
        # One writer is submitted first, then four readers.
        pending = [pool.submit(start_execution)]
        pending.extend([pool.submit(get_operations) for _ in range(4)])

        for finished in as_completed(pending):
            finished.result()  # surfaces any exception from the worker

    assert len(outcomes) == 5
    assert "started" in outcomes
    # After start() has run there must be at least one navigable operation.
    remaining_ops = execution.get_navigable_operations()
    assert len(remaining_ops) >= 1

tests/execution_wait_retry_test.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Additional concurrent tests for wait and retry operations."""
2+
3+
import threading
4+
from concurrent.futures import ThreadPoolExecutor, as_completed
5+
from datetime import UTC, datetime
6+
7+
from aws_durable_execution_sdk_python.lambda_service import (
8+
Operation,
9+
OperationStatus,
10+
OperationType,
11+
StepDetails,
12+
)
13+
14+
from aws_durable_execution_sdk_python_testing.execution import Execution
15+
from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
16+
17+
18+
def test_concurrent_wait_and_retry_completion():
    """Complete a WAIT and a STEP retry from two threads on one ``Execution``.

    Both calls must return their updated operation (SUCCEEDED for the wait,
    READY for the retried step), and the token sequence must have advanced
    once per call.
    """
    start_input = StartDurableExecutionInput(
        account_id="123456789012",
        function_name="test-function",
        function_qualifier="$LATEST",
        execution_name="test-execution",
        execution_timeout_seconds=300,
        execution_retention_period_days=7,
        invocation_id="test-inv-id",
        input='{"test": "data"}',
    )
    execution = Execution.new(start_input)

    # Seed the execution with one started WAIT and one pending STEP.
    pending_wait = Operation(
        operation_id="wait-1",
        parent_id=None,
        name="test-wait",
        start_timestamp=datetime.now(UTC),
        operation_type=OperationType.WAIT,
        status=OperationStatus.STARTED,
    )
    pending_step = Operation(
        operation_id="step-1",
        parent_id=None,
        name="test-step",
        start_timestamp=datetime.now(UTC),
        operation_type=OperationType.STEP,
        status=OperationStatus.PENDING,
        step_details=StepDetails(),
    )
    execution.operations.extend([pending_wait, pending_step])

    outcomes = []
    outcomes_guard = threading.Lock()

    def complete_wait():
        updated = execution.complete_wait("wait-1")
        with outcomes_guard:
            outcomes.append(f"wait-completed-{updated.status.value}")

    def complete_retry():
        updated = execution.complete_retry("step-1")
        with outcomes_guard:
            outcomes.append(f"retry-completed-{updated.status.value}")

    with ThreadPoolExecutor(max_workers=2) as pool:
        pending = [pool.submit(complete_wait), pool.submit(complete_retry)]

        for finished in as_completed(pending):
            finished.result()  # surfaces any exception from the worker

    assert len(outcomes) == 2
    assert "wait-completed-SUCCEEDED" in outcomes
    assert "retry-completed-READY" in outcomes

    # Each of the two completions increments the sequence exactly once.
    assert execution.token_sequence == 2

0 commit comments

Comments
 (0)