Skip to content

Commit e05763a

Browse files
author
Rares Polenciuc
committed
Add thread safety to execution operations and unit tests
- Add thread-safe execution operations with proper locking
- Add comprehensive unit tests for execution.py
- Add concurrent execution tests
- Add memory store concurrent tests
1 parent 72879ac commit e05763a

File tree

8 files changed

+516
-66
lines changed

8 files changed

+516
-66
lines changed

src/aws_durable_execution_sdk_python_testing/execution.py

Lines changed: 65 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import json
44
from dataclasses import replace
55
from datetime import UTC, datetime
6+
from aws_durable_execution_sdk_python.threading import OrderedCounter, OrderedLock
67
from typing import Any
78
from uuid import uuid4
89

@@ -46,11 +47,24 @@ def __init__(
4647
self.updates: list[OperationUpdate] = []
4748
self.used_tokens: set[str] = set()
4849
# TODO: this will need to persist/rehydrate depending on in-memory vs SQLite store
49-
self.token_sequence: int = 0
50+
51+
self._token_sequence: int = 0
52+
self._state_lock: OrderedLock = OrderedLock()
5053
self.is_complete: bool = False
5154
self.result: DurableExecutionInvocationOutput | None = None
5255
self.consecutive_failed_invocation_attempts: int = 0
5356

57+
@property
58+
def token_sequence(self) -> int:
59+
"""Get current token sequence value."""
60+
return self._token_sequence
61+
62+
@token_sequence.setter
63+
def token_sequence(self, value: int) -> None:
64+
"""Set token sequence value."""
65+
with self._state_lock:
66+
self._token_sequence = value
67+
5468
@staticmethod
5569
def new(input: StartDurableExecutionInput) -> Execution: # noqa: A002
5670
# make a nicer arn
@@ -68,7 +82,7 @@ def to_dict(self) -> dict[str, Any]:
6882
"Operations": [op.to_dict() for op in self.operations],
6983
"Updates": [update.to_dict() for update in self.updates],
7084
"UsedTokens": list(self.used_tokens),
71-
"TokenSequence": self.token_sequence,
85+
"TokenSequence": self._token_sequence,
7286
"IsComplete": self.is_complete,
7387
"Result": self.result.to_dict() if self.result else None,
7488
"ConsecutiveFailedInvocationAttempts": self.consecutive_failed_invocation_attempts,
@@ -95,7 +109,7 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
95109
OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
96110
]
97111
execution.used_tokens = set(data["UsedTokens"])
98-
execution.token_sequence = data["TokenSequence"]
112+
execution._token_sequence = data["TokenSequence"]
99113
execution.is_complete = data["IsComplete"]
100114
execution.result = (
101115
DurableExecutionInvocationOutput.from_dict(data["Result"])
@@ -109,23 +123,23 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
109123
return execution
110124

111125
def start(self) -> None:
112-
# not thread safe, prob should be
113126
if self.start_input.invocation_id is None:
114127
msg: str = "invocation_id is required"
115128
raise InvalidParameterValueException(msg)
116-
self.operations.append(
117-
Operation(
118-
operation_id=self.start_input.invocation_id,
119-
parent_id=None,
120-
name=self.start_input.execution_name,
121-
start_timestamp=datetime.now(UTC),
122-
operation_type=OperationType.EXECUTION,
123-
status=OperationStatus.STARTED,
124-
execution_details=ExecutionDetails(
125-
input_payload=json.dumps(self.start_input.input)
126-
),
129+
with self._state_lock:
130+
self.operations.append(
131+
Operation(
132+
operation_id=self.start_input.invocation_id,
133+
parent_id=None,
134+
name=self.start_input.execution_name,
135+
start_timestamp=datetime.now(UTC),
136+
operation_type=OperationType.EXECUTION,
137+
status=OperationStatus.STARTED,
138+
execution_details=ExecutionDetails(
139+
input_payload=json.dumps(self.start_input.input)
140+
),
141+
)
127142
)
128-
)
129143

130144
def get_operation_execution_started(self) -> Operation:
131145
if not self.operations:
@@ -137,15 +151,15 @@ def get_operation_execution_started(self) -> Operation:
137151

138152
def get_new_checkpoint_token(self) -> str:
139153
"""Generate a new checkpoint token with incremented sequence"""
140-
# TODO: not thread safe and it should be
141-
self.token_sequence += 1
142-
new_token_sequence = self.token_sequence
143-
token = CheckpointToken(
144-
execution_arn=self.durable_execution_arn, token_sequence=new_token_sequence
145-
)
146-
token_str = token.to_str()
147-
self.used_tokens.add(token_str)
148-
return token_str
154+
with self._state_lock:
155+
self._token_sequence += 1
156+
new_token_sequence = self._token_sequence
157+
token = CheckpointToken(
158+
execution_arn=self.durable_execution_arn, token_sequence=new_token_sequence
159+
)
160+
token_str = token.to_str()
161+
self.used_tokens.add(token_str)
162+
return token_str
149163

150164
def get_navigable_operations(self) -> list[Operation]:
151165
"""Get list of operations, but exclude child operations where the parent has already completed."""
@@ -205,17 +219,16 @@ def complete_wait(self, operation_id: str) -> Operation:
205219
)
206220
raise IllegalStateException(msg_not_wait)
207221

208-
# TODO: make thread-safe. Increment sequence
209-
self.token_sequence += 1
210-
211-
# Build and assign updated operation
212-
self.operations[index] = replace(
213-
operation,
214-
status=OperationStatus.SUCCEEDED,
215-
end_timestamp=datetime.now(UTC),
216-
)
217-
218-
return self.operations[index]
222+
# Thread-safe increment sequence and operation update
223+
with self._state_lock:
224+
self._token_sequence += 1
225+
# Build and assign updated operation
226+
self.operations[index] = replace(
227+
operation,
228+
status=OperationStatus.SUCCEEDED,
229+
end_timestamp=datetime.now(UTC),
230+
)
231+
return self.operations[index]
219232

220233
def complete_retry(self, operation_id: str) -> Operation:
221234
"""Complete STEP retry when timer fires."""
@@ -231,21 +244,21 @@ def complete_retry(self, operation_id: str) -> Operation:
231244
)
232245
raise IllegalStateException(msg_not_step)
233246

234-
# TODO: make thread-safe. Increment sequence
235-
self.token_sequence += 1
236-
237-
# Build updated step_details with cleared next_attempt_timestamp
238-
new_step_details = None
239-
if operation.step_details:
240-
new_step_details = replace(
241-
operation.step_details, next_attempt_timestamp=None
247+
# Thread-safe increment sequence and operation update
248+
with self._state_lock:
249+
self._token_sequence += 1
250+
# Build updated step_details with cleared next_attempt_timestamp
251+
new_step_details = None
252+
if operation.step_details:
253+
new_step_details = replace(
254+
operation.step_details, next_attempt_timestamp=None
255+
)
256+
257+
# Build updated operation
258+
updated_operation = replace(
259+
operation, status=OperationStatus.READY, step_details=new_step_details
242260
)
243261

244-
# Build updated operation
245-
updated_operation = replace(
246-
operation, status=OperationStatus.READY, step_details=new_step_details
247-
)
248-
249-
# Assign
250-
self.operations[index] = updated_operation
251-
return updated_operation
262+
# Assign
263+
self.operations[index] = updated_operation
264+
return updated_operation

src/aws_durable_execution_sdk_python_testing/invoker.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import json
4-
import time
54
from typing import TYPE_CHECKING, Any, Protocol
65

76
import boto3 # type: ignore
@@ -11,11 +10,11 @@
1110
DurableExecutionInvocationOutput,
1211
InitialExecutionState,
1312
)
14-
from aws_durable_execution_sdk_python.lambda_context import LambdaContext
1513

1614
from aws_durable_execution_sdk_python_testing.exceptions import (
1715
DurableFunctionsTestError,
1816
)
17+
from aws_durable_execution_sdk_python_testing.model import LambdaContext
1918

2019

2120
if TYPE_CHECKING:
@@ -46,12 +45,9 @@ def create_test_lambda_context() -> LambdaContext:
4645
}
4746

4847
return LambdaContext(
49-
invoke_id="test-invoke-12345",
48+
aws_request_id="test-invoke-12345",
5049
client_context=client_context_dict,
51-
cognito_identity=cognito_identity_dict,
52-
epoch_deadline_time_in_ms=int(
53-
(time.time() + 900) * 1000
54-
), # 15 minutes from now
50+
identity=cognito_identity_dict,
5551
invoked_function_arn="arn:aws:lambda:us-west-2:123456789012:function:test-function",
5652
tenant_id="test-tenant-789",
5753
)

src/aws_durable_execution_sdk_python_testing/model.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,48 @@
2525
)
2626

2727

28+
@dataclass(frozen=True)
29+
class LambdaContext:
30+
"""Lambda context for testing."""
31+
32+
aws_request_id: str
33+
log_group_name: str | None = None
34+
log_stream_name: str | None = None
35+
function_name: str | None = None
36+
memory_limit_in_mb: str | None = None
37+
function_version: str | None = None
38+
invoked_function_arn: str | None = None
39+
tenant_id: str | None = None
40+
client_context: dict | None = None
41+
identity: dict | None = None
42+
43+
def get_remaining_time_in_millis(self) -> int:
44+
return 900000 # 15 minutes default
45+
46+
def log(self, msg) -> None:
47+
pass # No-op for testing
48+
49+
@classmethod
50+
def from_dict(cls, data: dict[str, Any]):
51+
required_fields = ["aws_request_id"]
52+
for field in required_fields:
53+
if field not in data:
54+
msg: str = f"Missing required field: {field}"
55+
raise InvalidParameterValueException(msg)
56+
return cls(
57+
aws_request_id=data["aws_request_id"],
58+
log_group_name=data.get("log_group_name"),
59+
log_stream_name=data.get("log_stream_name"),
60+
function_name=data.get("function_name"),
61+
memory_limit_in_mb=data.get("memory_limit_in_mb"),
62+
function_version=data.get("function_version"),
63+
invoked_function_arn=data.get("invoked_function_arn"),
64+
tenant_id=data.get("tenant_id"),
65+
client_context=data.get("client_context"),
66+
identity=data.get("identity"),
67+
)
68+
69+
2870
# Web API specific models (not in Smithy but needed for web interface)
2971
@dataclass(frozen=True)
3072
class StartDurableExecutionInput:

src/aws_durable_execution_sdk_python_testing/stores/memory.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
from typing import TYPE_CHECKING
66

7+
from aws_durable_execution_sdk_python.threading import OrderedLock
78

89
if TYPE_CHECKING:
910
from aws_durable_execution_sdk_python_testing.execution import Execution
@@ -14,15 +15,20 @@ class InMemoryExecutionStore:
1415

1516
def __init__(self) -> None:
1617
self._store: dict[str, Execution] = {}
18+
self._lock: OrderedLock = OrderedLock()
1719

1820
def save(self, execution: Execution) -> None:
19-
self._store[execution.durable_execution_arn] = execution
21+
with self._lock:
22+
self._store[execution.durable_execution_arn] = execution
2023

2124
def load(self, execution_arn: str) -> Execution:
22-
return self._store[execution_arn]
25+
with self._lock:
26+
return self._store[execution_arn]
2327

2428
def update(self, execution: Execution) -> None:
25-
self._store[execution.durable_execution_arn] = execution
29+
with self._lock:
30+
self._store[execution.durable_execution_arn] = execution
2631

2732
def list_all(self) -> list[Execution]:
28-
return list(self._store.values())
33+
with self._lock:
34+
return list(self._store.values())

0 commit comments

Comments
 (0)