Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
python-version: ${{ matrix.python-version }}

- name: Install Hatch
run: python -m pip install --upgrade hatch
run: python -m pip install hatch==1.15.0

- name: Setup and run Testing SDK
working-directory: testing-sdk
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,11 @@ def _execute_item_in_child_context(
executor_context._parent_id, # noqa: SLF001
name,
)
child_context.state.track_replay(operation_id=operation_id)

def run_in_child_handler():
return self.execute_item(child_context, executable)

return child_handler(
result: ResultType = child_handler(
run_in_child_handler,
child_context.state,
operation_identifier=operation_identifier,
Expand All @@ -396,6 +395,8 @@ def run_in_child_handler():
summary_generator=self.summary_generator,
),
)
child_context.state.track_replay(operation_id=operation_id)
return result

def replay(self, execution_state: ExecutionState, executor_context: DurableContext):
"""
Expand Down
39 changes: 22 additions & 17 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,21 +271,21 @@ def create_callback(
if not config:
config = CallbackConfig()
operation_id: str = self._create_step_id()
self.state.track_replay(operation_id=operation_id)
callback_id: str = create_callback_handler(
state=self.state,
operation_identifier=OperationIdentifier(
operation_id=operation_id, parent_id=self._parent_id, name=name
),
config=config,
)

return Callback(
result: Callback = Callback(
callback_id=callback_id,
operation_id=operation_id,
state=self.state,
serdes=config.serdes,
)
self.state.track_replay(operation_id=operation_id)
return result

def invoke(
self,
Expand All @@ -306,8 +306,7 @@ def invoke(
The result of the invoked function
"""
operation_id = self._create_step_id()
self.state.track_replay(operation_id=operation_id)
return invoke_handler(
result: R = invoke_handler(
function_name=function_name,
payload=payload,
state=self.state,
Expand All @@ -318,6 +317,8 @@ def invoke(
),
config=config,
)
self.state.track_replay(operation_id=operation_id)
return result

def map(
self,
Expand All @@ -330,7 +331,6 @@ def map(
map_name: str | None = self._resolve_step_name(name, func)

operation_id = self._create_step_id()
self.state.track_replay(operation_id=operation_id)
operation_identifier = OperationIdentifier(
operation_id=operation_id, parent_id=self._parent_id, name=map_name
)
Expand All @@ -350,7 +350,7 @@ def map_in_child_context() -> BatchResult[R]:
operation_identifier=operation_identifier,
)

return child_handler(
result: BatchResult[R] = child_handler(
func=map_in_child_context,
state=self.state,
operation_identifier=operation_identifier,
Expand All @@ -363,6 +363,8 @@ def map_in_child_context() -> BatchResult[R]:
item_serdes=None,
),
)
self.state.track_replay(operation_id=operation_id)
return result

def parallel(
self,
Expand All @@ -373,7 +375,6 @@ def parallel(
"""Execute multiple callables in parallel."""
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
operation_id = self._create_step_id()
self.state.track_replay(operation_id=operation_id)
parallel_context = self.create_child_context(parent_id=operation_id)
operation_identifier = OperationIdentifier(
operation_id=operation_id, parent_id=self._parent_id, name=name
Expand All @@ -392,7 +393,7 @@ def parallel_in_child_context() -> BatchResult[T]:
operation_identifier=operation_identifier,
)

return child_handler(
result: BatchResult[T] = child_handler(
func=parallel_in_child_context,
state=self.state,
operation_identifier=operation_identifier,
Expand All @@ -405,6 +406,8 @@ def parallel_in_child_context() -> BatchResult[T]:
item_serdes=None,
),
)
self.state.track_replay(operation_id=operation_id)
return result

def run_in_child_context(
self,
Expand All @@ -427,19 +430,20 @@ def run_in_child_context(
step_name: str | None = self._resolve_step_name(name, func)
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
operation_id = self._create_step_id()
self.state.track_replay(operation_id=operation_id)

def callable_with_child_context():
return func(self.create_child_context(parent_id=operation_id))

return child_handler(
result: T = child_handler(
func=callable_with_child_context,
state=self.state,
operation_identifier=OperationIdentifier(
operation_id=operation_id, parent_id=self._parent_id, name=step_name
),
config=config,
)
self.state.track_replay(operation_id=operation_id)
return result

def step(
self,
Expand All @@ -450,9 +454,7 @@ def step(
step_name = self._resolve_step_name(name, func)
logger.debug("Step name: %s", step_name)
operation_id = self._create_step_id()
self.state.track_replay(operation_id=operation_id)

return step_handler(
result: T = step_handler(
func=func,
config=config,
state=self.state,
Expand All @@ -463,6 +465,8 @@ def step(
),
context_logger=self.logger,
)
self.state.track_replay(operation_id=operation_id)
return result

def wait(self, duration: Duration, name: str | None = None) -> None:
"""Wait for a specified amount of time.
Expand All @@ -476,7 +480,6 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
msg = "duration must be at least 1 second"
raise ValidationError(msg)
operation_id = self._create_step_id()
self.state.track_replay(operation_id=operation_id)
wait_handler(
seconds=seconds,
state=self.state,
Expand All @@ -486,6 +489,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
name=name,
),
)
self.state.track_replay(operation_id=operation_id)

def wait_for_callback(
self,
Expand Down Expand Up @@ -528,8 +532,7 @@ def wait_for_condition(
raise ValidationError(msg)

operation_id = self._create_step_id()
self.state.track_replay(operation_id=operation_id)
return wait_for_condition_handler(
result: T = wait_for_condition_handler(
check=check,
config=config,
state=self.state,
Expand All @@ -540,6 +543,8 @@ def wait_for_condition(
),
context_logger=self.logger,
)
self.state.track_replay(operation_id=operation_id)
return result


# endregion Operations
23 changes: 15 additions & 8 deletions src/aws_durable_execution_sdk_python/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def __init__(
self._parent_done_lock: Lock = Lock()
self._replay_status: ReplayStatus = replay_status
self._replay_status_lock: Lock = Lock()
self._visited_operations: set[str] = set()

def fetch_paginated_operations(
self,
Expand Down Expand Up @@ -301,14 +302,20 @@ def track_replay(self, operation_id: str) -> None:
"""
with self._replay_status_lock:
if self._replay_status == ReplayStatus.REPLAY:
operation = self.operations.get(operation_id)
# Transition if operation doesn't exist OR isn't in a completed state
if not operation or operation.status not in {
OperationStatus.SUCCEEDED,
OperationStatus.FAILED,
OperationStatus.CANCELLED,
OperationStatus.STOPPED,
}:
self._visited_operations.add(operation_id)
completed_ops = {
op_id
for op_id, op in self.operations.items()
if op.operation_type != OperationType.EXECUTION
and op.status
in {
OperationStatus.SUCCEEDED,
OperationStatus.FAILED,
OperationStatus.CANCELLED,
OperationStatus.STOPPED,
}
}
if completed_ops.issubset(self._visited_operations):
logger.debug(
"Transitioning from REPLAY to NEW status at operation %s",
operation_id,
Expand Down
11 changes: 8 additions & 3 deletions tests/logger_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,22 +381,27 @@ def test_logger_replay_no_logging():
log_info = LogInfo(replay_execution_state, "parent123", "test_name", 5)
mock_logger = Mock()
logger = Logger.from_log_info(mock_logger, log_info)
replay_execution_state.track_replay(operation_id="op1")
logger.info("logging info")
replay_execution_state.track_replay(operation_id="op1")

mock_logger.info.assert_not_called()


def test_logger_replay_then_new_logging():
operation = Operation(
operation1 = Operation(
operation_id="op1",
operation_type=OperationType.STEP,
status=OperationStatus.SUCCEEDED,
)
operation2 = Operation(
operation_id="op2",
operation_type=OperationType.STEP,
status=OperationStatus.SUCCEEDED,
)
execution_state = ExecutionState(
durable_execution_arn="arn:aws:test",
initial_checkpoint_token="test_token", # noqa: S106
operations={"op1": operation},
operations={"op1": operation1, "op2": operation2},
service_client=Mock(),
replay_status=ReplayStatus.REPLAY,
)
Expand Down
12 changes: 8 additions & 4 deletions tests/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3246,21 +3246,25 @@ def test_create_checkpoint_sync_always_synchronous():


def test_state_replay_mode():
operation = Operation(
operation1 = Operation(
operation_id="op1",
operation_type=OperationType.STEP,
status=OperationStatus.SUCCEEDED,
)
operation2 = Operation(
operation_id="op2",
operation_type=OperationType.STEP,
status=OperationStatus.SUCCEEDED,
)
execution_state = ExecutionState(
durable_execution_arn="arn:aws:test",
initial_checkpoint_token="test_token", # noqa: S106
operations={"op1": operation},
operations={"op1": operation1, "op2": operation2},
service_client=Mock(),
replay_status=ReplayStatus.REPLAY,
)

assert execution_state.is_replaying() is True
execution_state.track_replay(operation_id="op1")
assert execution_state.is_replaying() is True

execution_state.track_replay(operation_id="op2")
assert execution_state.is_replaying() is False