Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 45 additions & 30 deletions src/aws_durable_execution_sdk_python/concurrency/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from aws_durable_execution_sdk_python.config import ChildConfig
from aws_durable_execution_sdk_python.exceptions import (
OrphanedChildException,
SuspendExecution,
TimedSuspendExecution,
)
Expand Down Expand Up @@ -198,42 +199,47 @@ def resubmitter(executable_with_state: ExecutableWithState) -> None:
execution_state.create_checkpoint()
submit_task(executable_with_state)

with (
TimerScheduler(resubmitter) as scheduler,
ThreadPoolExecutor(max_workers=max_workers) as thread_executor,
):

def submit_task(executable_with_state: ExecutableWithState) -> Future:
"""Submit task to the thread executor and mark its state as started."""
future = thread_executor.submit(
self._execute_item_in_child_context,
executor_context,
executable_with_state.executable,
)
executable_with_state.run(future)
thread_executor = ThreadPoolExecutor(max_workers=max_workers)
try:
with TimerScheduler(resubmitter) as scheduler:

def submit_task(executable_with_state: ExecutableWithState) -> Future:
"""Submit task to the thread executor and mark its state as started."""
future = thread_executor.submit(
self._execute_item_in_child_context,
executor_context,
executable_with_state.executable,
)
executable_with_state.run(future)

def on_done(future: Future) -> None:
self._on_task_complete(executable_with_state, future, scheduler)
def on_done(future: Future) -> None:
self._on_task_complete(executable_with_state, future, scheduler)

future.add_done_callback(on_done)
return future
future.add_done_callback(on_done)
return future

# Submit initial tasks
futures = [
submit_task(exe_state) for exe_state in self.executables_with_state
]
# Submit initial tasks
futures = [
submit_task(exe_state) for exe_state in self.executables_with_state
]

# Wait for completion
self._completion_event.wait()
# Wait for completion
self._completion_event.wait()

# Cancel remaining futures so
# that we don't wait for them to join.
for future in futures:
future.cancel()
# Cancel futures that haven't started yet
for future in futures:
future.cancel()

# Suspend execution if everything done and at least one of the tasks raised a suspend exception.
if self._suspend_exception:
raise self._suspend_exception
# Suspend execution if everything done and at least one of the tasks raised a suspend exception.
if self._suspend_exception:
raise self._suspend_exception

finally:
# Shutdown without waiting for running threads for early return when
# completion criteria are met (e.g., min_successful).
# Running threads will continue in background but they raise OrphanedChildException
# on the next attempt to checkpoint.
thread_executor.shutdown(wait=False, cancel_futures=True)

# Build final result
return self._create_result()
Expand Down Expand Up @@ -291,6 +297,15 @@ def _on_task_complete(
result = future.result()
exe_state.complete(result)
self.counters.complete_task()
except OrphanedChildException:
# Parent already completed and returned.
# State is already RUNNING, which _create_result() marked as STARTED
# Just log and exit - no state change needed
logger.debug(
"Terminating orphaned branch %s without error because parent has completed already",
exe_state.index,
)
return
except TimedSuspendExecution as tse:
exe_state.suspend_with_timeout(tse.scheduled_timestamp)
scheduler.schedule_resume(exe_state, tse.scheduled_timestamp)
Expand Down
29 changes: 29 additions & 0 deletions src/aws_durable_execution_sdk_python/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,32 @@ def __str__(self) -> str:

class SerDesError(DurableExecutionsError):
"""Raised when serialization fails."""


class OrphanedChildException(BaseException):
"""Raised when a child operation attempts to checkpoint after its parent context has completed.

This exception inherits from BaseException (not Exception) so that user-space doesn't
accidentally catch it with broad exception handlers like 'except Exception'.

This exception will happen when a parallel branch or map item tries to create a checkpoint
after its parent context (i.e the parallel/map operation) has already completed due to meeting
completion criteria (e.g., min_successful reached, failure tolerance exceeded).

Although you cannot cancel running futures in user-space, this will at least terminate the
child operation on the next checkpoint attempt, preventing subsequent operations in the
child scope from executing.

Attributes:
operation_id: Operation ID of the orphaned child
"""

def __init__(self, message: str, operation_id: str):
"""Initialize OrphanedChildException.

Args:
message: Human-readable error message
operation_id: Operation ID of the orphaned child (required)
"""
super().__init__(message)
self.operation_id = operation_id
9 changes: 8 additions & 1 deletion src/aws_durable_execution_sdk_python/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
BackgroundThreadError,
CallableRuntimeError,
DurableExecutionsError,
OrphanedChildException,
)
from aws_durable_execution_sdk_python.lambda_service import (
CheckpointOutput,
Expand Down Expand Up @@ -449,7 +450,13 @@ def create_checkpoint(
"Rejecting checkpoint for operation %s - parent is done",
operation_update.operation_id,
)
return
error_msg = (
"Parent context completed, child operation cannot checkpoint"
)
raise OrphanedChildException(
error_msg,
operation_id=operation_update.operation_id,
)

# Check if background checkpointing has failed
if self._checkpointing_failed.is_set():
Expand Down
195 changes: 194 additions & 1 deletion tests/concurrency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
SuspendExecution,
TimedSuspendExecution,
)
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
from aws_durable_execution_sdk_python.lambda_service import (
ErrorObject,
)
from aws_durable_execution_sdk_python.operation.map import MapExecutor


Expand Down Expand Up @@ -2838,3 +2840,194 @@ def task_func(ctx, item, idx, items):
assert (
sum(1 for item in result.all if item.status == BatchItemStatus.SUCCEEDED) < 98
)


def test_executor_exits_early_with_min_successful():
"""Test that parallel exits immediately when min_successful is reached without waiting for other branches."""

class TestExecutor(ConcurrentExecutor):
def execute_item(self, child_context, executable):
return executable.func()

execution_times = []

def fast_branch():
execution_times.append(("fast", time.time()))
return "fast_result"

def slow_branch():
execution_times.append(("slow_start", time.time()))
time.sleep(2) # Long sleep
execution_times.append(("slow_end", time.time()))
return "slow_result"

executables = [
Executable(0, fast_branch),
Executable(1, slow_branch),
]

completion_config = CompletionConfig(min_successful=1)

executor = TestExecutor(
executables=executables,
max_concurrency=2,
completion_config=completion_config,
sub_type_top="TOP",
sub_type_iteration="ITER",
name_prefix="test_",
serdes=None,
)

execution_state = Mock()
execution_state.create_checkpoint = Mock()
executor_context = Mock()
executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}" # noqa: SLF001
executor_context._parent_id = "parent" # noqa: SLF001

def create_child_context(op_id):
child = Mock()
child.state = execution_state
return child

executor_context.create_child_context = create_child_context

start_time = time.time()
result = executor.execute(execution_state, executor_context)
elapsed_time = time.time() - start_time

# Should complete in less than 1.5 second (not wait for 2-second sleep)
assert elapsed_time < 1.5, f"Took {elapsed_time}s, expected < 1.5s"

# Result should show MIN_SUCCESSFUL_REACHED
assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

# Fast branch should succeed
assert result.all[0].status == BatchItemStatus.SUCCEEDED
assert result.all[0].result == "fast_result"

# Slow branch should be marked as STARTED (incomplete)
assert result.all[1].status == BatchItemStatus.STARTED

# Verify counts
assert result.success_count == 1
assert result.failure_count == 0
assert result.started_count == 1
assert result.total_count == 2


def test_executor_returns_with_incomplete_branches():
"""Test that executor returns when min_successful is reached, leaving other branches incomplete."""

class TestExecutor(ConcurrentExecutor):
def execute_item(self, child_context, executable):
return executable.func()

operation_tracker = Mock()

def fast_branch():
operation_tracker.fast_executed()
return "fast_result"

def slow_branch():
operation_tracker.slow_started()
time.sleep(2) # Long sleep
operation_tracker.slow_completed()
return "slow_result"

executables = [
Executable(0, fast_branch),
Executable(1, slow_branch),
]

completion_config = CompletionConfig(min_successful=1)

executor = TestExecutor(
executables=executables,
max_concurrency=2,
completion_config=completion_config,
sub_type_top="TOP",
sub_type_iteration="ITER",
name_prefix="test_",
serdes=None,
)

execution_state = Mock()
execution_state.create_checkpoint = Mock()
executor_context = Mock()
executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}" # noqa: SLF001
executor_context._parent_id = "parent" # noqa: SLF001
executor_context.create_child_context = lambda op_id: Mock(state=execution_state)

result = executor.execute(execution_state, executor_context)

# Verify fast branch executed
assert operation_tracker.fast_executed.call_count == 1

# Slow branch may or may not have started (depends on thread scheduling)
# but it definitely should not have completed
assert (
operation_tracker.slow_completed.call_count == 0
), "Executor should return before slow branch completes"

# Result should show MIN_SUCCESSFUL_REACHED
assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

# Verify counts - one succeeded, one incomplete
assert result.success_count == 1
assert result.failure_count == 0
assert result.started_count == 1
assert result.total_count == 2


def test_executor_returns_before_slow_branch_completes():
"""Test that executor returns immediately when min_successful is reached, not waiting for slow branches."""

class TestExecutor(ConcurrentExecutor):
def execute_item(self, child_context, executable):
return executable.func()

slow_branch_mock = Mock()

def fast_func():
return "fast"

def slow_func():
time.sleep(3) # Sleep
slow_branch_mock.completed() # Should not be called before executor returns
return "slow"

executables = [Executable(0, fast_func), Executable(1, slow_func)]
completion_config = CompletionConfig(min_successful=1)

executor = TestExecutor(
executables=executables,
max_concurrency=2,
completion_config=completion_config,
sub_type_top="TOP",
sub_type_iteration="ITER",
name_prefix="test_",
serdes=None,
)

execution_state = Mock()
execution_state.create_checkpoint = Mock()
executor_context = Mock()
executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}" # noqa: SLF001
executor_context._parent_id = "parent" # noqa: SLF001
executor_context.create_child_context = lambda op_id: Mock(state=execution_state)

result = executor.execute(execution_state, executor_context)

# Executor should have returned before slow branch completed
assert (
not slow_branch_mock.completed.called
), "Executor should return before slow branch completes"

# Result should show MIN_SUCCESSFUL_REACHED
assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

# Verify counts
assert result.success_count == 1
assert result.failure_count == 0
assert result.started_count == 1
assert result.total_count == 2
Loading