Skip to content

Commit 6a15466

Browse files
author
Alex Wang
committed
fix: update track_replay logic
- Move track_replay after each operation, instead of before
1 parent 4e28a5e commit 6a15466

File tree

5 files changed

+56
-34
lines changed

5 files changed

+56
-34
lines changed

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -381,12 +381,11 @@ def _execute_item_in_child_context(
381381
executor_context._parent_id, # noqa: SLF001
382382
name,
383383
)
384-
child_context.state.track_replay(operation_id=operation_id)
385384

386385
def run_in_child_handler():
387386
return self.execute_item(child_context, executable)
388387

389-
return child_handler(
388+
result: ResultType = child_handler(
390389
run_in_child_handler,
391390
child_context.state,
392391
operation_identifier=operation_identifier,
@@ -396,6 +395,8 @@ def run_in_child_handler():
396395
summary_generator=self.summary_generator,
397396
),
398397
)
398+
child_context.state.track_replay(operation_id=operation_id)
399+
return result
399400

400401
def replay(self, execution_state: ExecutionState, executor_context: DurableContext):
401402
"""

src/aws_durable_execution_sdk_python/context.py

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -271,21 +271,21 @@ def create_callback(
271271
if not config:
272272
config = CallbackConfig()
273273
operation_id: str = self._create_step_id()
274-
self.state.track_replay(operation_id=operation_id)
275274
callback_id: str = create_callback_handler(
276275
state=self.state,
277276
operation_identifier=OperationIdentifier(
278277
operation_id=operation_id, parent_id=self._parent_id, name=name
279278
),
280279
config=config,
281280
)
282-
283-
return Callback(
281+
result: Callback = Callback(
284282
callback_id=callback_id,
285283
operation_id=operation_id,
286284
state=self.state,
287285
serdes=config.serdes,
288286
)
287+
self.state.track_replay(operation_id=operation_id)
288+
return result
289289

290290
def invoke(
291291
self,
@@ -306,8 +306,7 @@ def invoke(
306306
The result of the invoked function
307307
"""
308308
operation_id = self._create_step_id()
309-
self.state.track_replay(operation_id=operation_id)
310-
return invoke_handler(
309+
result: R = invoke_handler(
311310
function_name=function_name,
312311
payload=payload,
313312
state=self.state,
@@ -318,6 +317,8 @@ def invoke(
318317
),
319318
config=config,
320319
)
320+
self.state.track_replay(operation_id=operation_id)
321+
return result
321322

322323
def map(
323324
self,
@@ -330,7 +331,6 @@ def map(
330331
map_name: str | None = self._resolve_step_name(name, func)
331332

332333
operation_id = self._create_step_id()
333-
self.state.track_replay(operation_id=operation_id)
334334
operation_identifier = OperationIdentifier(
335335
operation_id=operation_id, parent_id=self._parent_id, name=map_name
336336
)
@@ -350,7 +350,7 @@ def map_in_child_context() -> BatchResult[R]:
350350
operation_identifier=operation_identifier,
351351
)
352352

353-
return child_handler(
353+
result: BatchResult[R] = child_handler(
354354
func=map_in_child_context,
355355
state=self.state,
356356
operation_identifier=operation_identifier,
@@ -363,6 +363,8 @@ def map_in_child_context() -> BatchResult[R]:
363363
item_serdes=None,
364364
),
365365
)
366+
self.state.track_replay(operation_id=operation_id)
367+
return result
366368

367369
def parallel(
368370
self,
@@ -373,7 +375,6 @@ def parallel(
373375
"""Execute multiple callables in parallel."""
374376
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
375377
operation_id = self._create_step_id()
376-
self.state.track_replay(operation_id=operation_id)
377378
parallel_context = self.create_child_context(parent_id=operation_id)
378379
operation_identifier = OperationIdentifier(
379380
operation_id=operation_id, parent_id=self._parent_id, name=name
@@ -392,7 +393,7 @@ def parallel_in_child_context() -> BatchResult[T]:
392393
operation_identifier=operation_identifier,
393394
)
394395

395-
return child_handler(
396+
result: BatchResult[T] = child_handler(
396397
func=parallel_in_child_context,
397398
state=self.state,
398399
operation_identifier=operation_identifier,
@@ -405,6 +406,8 @@ def parallel_in_child_context() -> BatchResult[T]:
405406
item_serdes=None,
406407
),
407408
)
409+
self.state.track_replay(operation_id=operation_id)
410+
return result
408411

409412
def run_in_child_context(
410413
self,
@@ -427,19 +430,20 @@ def run_in_child_context(
427430
step_name: str | None = self._resolve_step_name(name, func)
428431
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
429432
operation_id = self._create_step_id()
430-
self.state.track_replay(operation_id=operation_id)
431433

432434
def callable_with_child_context():
433435
return func(self.create_child_context(parent_id=operation_id))
434436

435-
return child_handler(
437+
result: T = child_handler(
436438
func=callable_with_child_context,
437439
state=self.state,
438440
operation_identifier=OperationIdentifier(
439441
operation_id=operation_id, parent_id=self._parent_id, name=step_name
440442
),
441443
config=config,
442444
)
445+
self.state.track_replay(operation_id=operation_id)
446+
return result
443447

444448
def step(
445449
self,
@@ -450,9 +454,7 @@ def step(
450454
step_name = self._resolve_step_name(name, func)
451455
logger.debug("Step name: %s", step_name)
452456
operation_id = self._create_step_id()
453-
self.state.track_replay(operation_id=operation_id)
454-
455-
return step_handler(
457+
result: T = step_handler(
456458
func=func,
457459
config=config,
458460
state=self.state,
@@ -463,6 +465,8 @@ def step(
463465
),
464466
context_logger=self.logger,
465467
)
468+
self.state.track_replay(operation_id=operation_id)
469+
return result
466470

467471
def wait(self, duration: Duration, name: str | None = None) -> None:
468472
"""Wait for a specified amount of time.
@@ -476,7 +480,6 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
476480
msg = "duration must be at least 1 second"
477481
raise ValidationError(msg)
478482
operation_id = self._create_step_id()
479-
self.state.track_replay(operation_id=operation_id)
480483
wait_handler(
481484
seconds=seconds,
482485
state=self.state,
@@ -486,6 +489,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
486489
name=name,
487490
),
488491
)
492+
self.state.track_replay(operation_id=operation_id)
489493

490494
def wait_for_callback(
491495
self,
@@ -528,8 +532,7 @@ def wait_for_condition(
528532
raise ValidationError(msg)
529533

530534
operation_id = self._create_step_id()
531-
self.state.track_replay(operation_id=operation_id)
532-
return wait_for_condition_handler(
535+
result: T = wait_for_condition_handler(
533536
check=check,
534537
config=config,
535538
state=self.state,
@@ -540,6 +543,8 @@ def wait_for_condition(
540543
),
541544
context_logger=self.logger,
542545
)
546+
self.state.track_replay(operation_id=operation_id)
547+
return result
543548

544549

545550
# endregion Operations

src/aws_durable_execution_sdk_python/state.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -258,6 +258,7 @@ def __init__(
258258
self._parent_done_lock: Lock = Lock()
259259
self._replay_status: ReplayStatus = replay_status
260260
self._replay_status_lock: Lock = Lock()
261+
self._visited_operations: set[str] = set()
261262

262263
def fetch_paginated_operations(
263264
self,
@@ -301,14 +302,20 @@ def track_replay(self, operation_id: str) -> None:
301302
"""
302303
with self._replay_status_lock:
303304
if self._replay_status == ReplayStatus.REPLAY:
304-
operation = self.operations.get(operation_id)
305-
# Transition if operation doesn't exist OR isn't in a completed state
306-
if not operation or operation.status not in {
307-
OperationStatus.SUCCEEDED,
308-
OperationStatus.FAILED,
309-
OperationStatus.CANCELLED,
310-
OperationStatus.STOPPED,
311-
}:
305+
self._visited_operations.add(operation_id)
306+
completed_ops = {
307+
op_id
308+
for op_id, op in self.operations.items()
309+
if op.operation_type != OperationType.EXECUTION
310+
and op.status
311+
in {
312+
OperationStatus.SUCCEEDED,
313+
OperationStatus.FAILED,
314+
OperationStatus.CANCELLED,
315+
OperationStatus.STOPPED,
316+
}
317+
}
318+
if completed_ops.issubset(self._visited_operations):
312319
logger.debug(
313320
"Transitioning from REPLAY to NEW status at operation %s",
314321
operation_id,

tests/logger_test.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -381,22 +381,27 @@ def test_logger_replay_no_logging():
381381
log_info = LogInfo(replay_execution_state, "parent123", "test_name", 5)
382382
mock_logger = Mock()
383383
logger = Logger.from_log_info(mock_logger, log_info)
384-
replay_execution_state.track_replay(operation_id="op1")
385384
logger.info("logging info")
385+
replay_execution_state.track_replay(operation_id="op1")
386386

387387
mock_logger.info.assert_not_called()
388388

389389

390390
def test_logger_replay_then_new_logging():
391-
operation = Operation(
391+
operation1 = Operation(
392392
operation_id="op1",
393393
operation_type=OperationType.STEP,
394394
status=OperationStatus.SUCCEEDED,
395395
)
396+
operation2 = Operation(
397+
operation_id="op2",
398+
operation_type=OperationType.STEP,
399+
status=OperationStatus.SUCCEEDED,
400+
)
396401
execution_state = ExecutionState(
397402
durable_execution_arn="arn:aws:test",
398403
initial_checkpoint_token="test_token", # noqa: S106
399-
operations={"op1": operation},
404+
operations={"op1": operation1, "op2": operation2},
400405
service_client=Mock(),
401406
replay_status=ReplayStatus.REPLAY,
402407
)

tests/state_test.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3246,21 +3246,25 @@ def test_create_checkpoint_sync_always_synchronous():
32463246

32473247

32483248
def test_state_replay_mode():
3249-
operation = Operation(
3249+
operation1 = Operation(
32503250
operation_id="op1",
32513251
operation_type=OperationType.STEP,
32523252
status=OperationStatus.SUCCEEDED,
32533253
)
3254+
operation2 = Operation(
3255+
operation_id="op2",
3256+
operation_type=OperationType.STEP,
3257+
status=OperationStatus.SUCCEEDED,
3258+
)
32543259
execution_state = ExecutionState(
32553260
durable_execution_arn="arn:aws:test",
32563261
initial_checkpoint_token="test_token", # noqa: S106
3257-
operations={"op1": operation},
3262+
operations={"op1": operation1, "op2": operation2},
32583263
service_client=Mock(),
32593264
replay_status=ReplayStatus.REPLAY,
32603265
)
3261-
3266+
assert execution_state.is_replaying() is True
32623267
execution_state.track_replay(operation_id="op1")
32633268
assert execution_state.is_replaying() is True
3264-
32653269
execution_state.track_replay(operation_id="op2")
32663270
assert execution_state.is_replaying() is False

0 commit comments

Comments
 (0)