33import json
44from dataclasses import replace
55from datetime import UTC , datetime
6+ from aws_durable_execution_sdk_python .threading import OrderedCounter , OrderedLock
67from typing import Any
78from uuid import uuid4
89
@@ -46,11 +47,24 @@ def __init__(
4647 self .updates : list [OperationUpdate ] = []
4748 self .used_tokens : set [str ] = set ()
4849 # TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
49- self .token_sequence : int = 0
50+
51+ self ._token_sequence : int = 0
52+ self ._state_lock : OrderedLock = OrderedLock ()
5053 self .is_complete : bool = False
5154 self .result : DurableExecutionInvocationOutput | None = None
5255 self .consecutive_failed_invocation_attempts : int = 0
5356
57+ @property
58+ def token_sequence (self ) -> int :
59+ """Get current token sequence value."""
60+ return self ._token_sequence
61+
62+ @token_sequence .setter
63+ def token_sequence (self , value : int ) -> None :
64+ """Set token sequence value."""
65+ with self ._state_lock :
66+ self ._token_sequence = value
67+
5468 @staticmethod
5569 def new (input : StartDurableExecutionInput ) -> Execution : # noqa: A002
5670 # make a nicer arn
@@ -68,7 +82,7 @@ def to_dict(self) -> dict[str, Any]:
6882 "Operations" : [op .to_dict () for op in self .operations ],
6983 "Updates" : [update .to_dict () for update in self .updates ],
7084 "UsedTokens" : list (self .used_tokens ),
71- "TokenSequence" : self .token_sequence ,
85+ "TokenSequence" : self ._token_sequence ,
7286 "IsComplete" : self .is_complete ,
7387 "Result" : self .result .to_dict () if self .result else None ,
7488 "ConsecutiveFailedInvocationAttempts" : self .consecutive_failed_invocation_attempts ,
@@ -95,7 +109,7 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
95109 OperationUpdate .from_dict (update_data ) for update_data in data ["Updates" ]
96110 ]
97111 execution .used_tokens = set (data ["UsedTokens" ])
98- execution .token_sequence = data ["TokenSequence" ]
112+ execution ._token_sequence = data ["TokenSequence" ]
99113 execution .is_complete = data ["IsComplete" ]
100114 execution .result = (
101115 DurableExecutionInvocationOutput .from_dict (data ["Result" ])
@@ -109,23 +123,23 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
109123 return execution
110124
111125 def start (self ) -> None :
112- # not thread safe, prob should be
113126 if self .start_input .invocation_id is None :
114127 msg : str = "invocation_id is required"
115128 raise InvalidParameterValueException (msg )
116- self .operations .append (
117- Operation (
118- operation_id = self .start_input .invocation_id ,
119- parent_id = None ,
120- name = self .start_input .execution_name ,
121- start_timestamp = datetime .now (UTC ),
122- operation_type = OperationType .EXECUTION ,
123- status = OperationStatus .STARTED ,
124- execution_details = ExecutionDetails (
125- input_payload = json .dumps (self .start_input .input )
126- ),
129+ with self ._state_lock :
130+ self .operations .append (
131+ Operation (
132+ operation_id = self .start_input .invocation_id ,
133+ parent_id = None ,
134+ name = self .start_input .execution_name ,
135+ start_timestamp = datetime .now (UTC ),
136+ operation_type = OperationType .EXECUTION ,
137+ status = OperationStatus .STARTED ,
138+ execution_details = ExecutionDetails (
139+ input_payload = json .dumps (self .start_input .input )
140+ ),
141+ )
127142 )
128- )
129143
130144 def get_operation_execution_started (self ) -> Operation :
131145 if not self .operations :
@@ -137,15 +151,15 @@ def get_operation_execution_started(self) -> Operation:
137151
138152 def get_new_checkpoint_token (self ) -> str :
139153 """Generate a new checkpoint token with incremented sequence"""
140- # TODO: not thread safe and it should be
141- self .token_sequence += 1
142- new_token_sequence = self .token_sequence
143- token = CheckpointToken (
144- execution_arn = self .durable_execution_arn , token_sequence = new_token_sequence
145- )
146- token_str = token .to_str ()
147- self .used_tokens .add (token_str )
148- return token_str
154+ with self . _state_lock :
155+ self ._token_sequence += 1
156+ new_token_sequence = self ._token_sequence
157+ token = CheckpointToken (
158+ execution_arn = self .durable_execution_arn , token_sequence = new_token_sequence
159+ )
160+ token_str = token .to_str ()
161+ self .used_tokens .add (token_str )
162+ return token_str
149163
150164 def get_navigable_operations (self ) -> list [Operation ]:
151165 """Get list of operations, but exclude child operations where the parent has already completed."""
@@ -205,17 +219,16 @@ def complete_wait(self, operation_id: str) -> Operation:
205219 )
206220 raise IllegalStateException (msg_not_wait )
207221
208- # TODO: make thread-safe. Increment sequence
209- self .token_sequence += 1
210-
211- # Build and assign updated operation
212- self .operations [index ] = replace (
213- operation ,
214- status = OperationStatus .SUCCEEDED ,
215- end_timestamp = datetime .now (UTC ),
216- )
217-
218- return self .operations [index ]
222+ # Thread-safe increment sequence and operation update
223+ with self ._state_lock :
224+ self ._token_sequence += 1
225+ # Build and assign updated operation
226+ self .operations [index ] = replace (
227+ operation ,
228+ status = OperationStatus .SUCCEEDED ,
229+ end_timestamp = datetime .now (UTC ),
230+ )
231+ return self .operations [index ]
219232
220233 def complete_retry (self , operation_id : str ) -> Operation :
221234 """Complete STEP retry when timer fires."""
@@ -231,21 +244,21 @@ def complete_retry(self, operation_id: str) -> Operation:
231244 )
232245 raise IllegalStateException (msg_not_step )
233246
234- # TODO: make thread-safe. Increment sequence
235- self .token_sequence += 1
236-
237- # Build updated step_details with cleared next_attempt_timestamp
238- new_step_details = None
239- if operation .step_details :
240- new_step_details = replace (
241- operation .step_details , next_attempt_timestamp = None
247+ # Thread-safe increment sequence and operation update
248+ with self ._state_lock :
249+ self ._token_sequence += 1
250+ # Build updated step_details with cleared next_attempt_timestamp
251+ new_step_details = None
252+ if operation .step_details :
253+ new_step_details = replace (
254+ operation .step_details , next_attempt_timestamp = None
255+ )
256+
257+ # Build updated operation
258+ updated_operation = replace (
259+ operation , status = OperationStatus .READY , step_details = new_step_details
242260 )
243261
244- # Build updated operation
245- updated_operation = replace (
246- operation , status = OperationStatus .READY , step_details = new_step_details
247- )
248-
249- # Assign
250- self .operations [index ] = updated_operation
251- return updated_operation
262+ # Assign
263+ self .operations [index ] = updated_operation
264+ return updated_operation
0 commit comments