From 93cb06302a014d80452f6ad2f6f11994f6196f8f Mon Sep 17 00:00:00 2001 From: yaythomas Date: Tue, 9 Dec 2025 12:06:29 -0800 Subject: [PATCH] feat: add early exit for concurrency - Add OrphanedChildException (BaseException) to terminate orphaned children when parent completes early - Modify ThreadPoolExecutor to shutdown without waiting (wait=False) when completion criteria met - Raise exception when orphaned children attempt to checkpoint, preventing subsequent operations from executing - Update state.py to reject orphaned child checkpoints with exception instead of silent return - Add comprehensive tests for early exit behavior and orphaned child handling When min_successful or error threshold is reached in parallel/map operations, the parent now returns immediately without waiting for remaining branches to complete. Orphaned branches are terminated on their next checkpoint attempt, preventing wasted work and ensuring correct semantics for completion criteria. --- .../concurrency/executor.py | 75 ++++--- .../exceptions.py | 29 +++ src/aws_durable_execution_sdk_python/state.py | 9 +- tests/concurrency_test.py | 195 +++++++++++++++++- tests/exceptions_test.py | 42 ++++ tests/state_test.py | 25 +-- 6 files changed, 329 insertions(+), 46 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/concurrency/executor.py b/src/aws_durable_execution_sdk_python/concurrency/executor.py index 77e8529..da1a5cd 100644 --- a/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -22,6 +22,7 @@ ) from aws_durable_execution_sdk_python.config import ChildConfig from aws_durable_execution_sdk_python.exceptions import ( + OrphanedChildException, SuspendExecution, TimedSuspendExecution, ) @@ -198,42 +199,47 @@ def resubmitter(executable_with_state: ExecutableWithState) -> None: execution_state.create_checkpoint() submit_task(executable_with_state) - with ( - TimerScheduler(resubmitter) as scheduler, - ThreadPoolExecutor(max_workers=max_workers) as thread_executor, - ): - - def submit_task(executable_with_state: ExecutableWithState) -> Future: - """Submit task to the thread executor and mark its state as started.""" - future = thread_executor.submit( - self._execute_item_in_child_context, - executor_context, - executable_with_state.executable, - ) - executable_with_state.run(future) + thread_executor = ThreadPoolExecutor(max_workers=max_workers) + try: + with TimerScheduler(resubmitter) as scheduler: + + def submit_task(executable_with_state: ExecutableWithState) -> Future: + """Submit task to the thread executor and mark its state as started.""" + future = thread_executor.submit( + self._execute_item_in_child_context, + executor_context, + executable_with_state.executable, + ) + executable_with_state.run(future) - def on_done(future: Future) -> None: - self._on_task_complete(executable_with_state, future, scheduler) + def on_done(future: Future) -> None: + self._on_task_complete(executable_with_state, future, scheduler) - future.add_done_callback(on_done) - return future + future.add_done_callback(on_done) + return future - # Submit initial tasks - futures = [ - submit_task(exe_state) for exe_state in self.executables_with_state - ] + # Submit initial tasks + futures = [ + submit_task(exe_state) for exe_state in self.executables_with_state + ] - # Wait for completion - self._completion_event.wait() + # Wait for completion + self._completion_event.wait() - # Cancel remaining futures so - # that we don't wait for them to join. - for future in futures: - future.cancel() + # Cancel futures that haven't started yet + for future in futures: + future.cancel() - # Suspend execution if everything done and at least one of the tasks raised a suspend exception. - if self._suspend_exception: - raise self._suspend_exception + # Suspend execution if everything done and at least one of the tasks raised a suspend exception. + if self._suspend_exception: + raise self._suspend_exception + + finally: + # Shutdown without waiting for running threads for early return when + # completion criteria are met (e.g., min_successful). + # Running threads will continue in background but they raise OrphanedChildException + # on the next attempt to checkpoint. + thread_executor.shutdown(wait=False, cancel_futures=True) # Build final result return self._create_result() @@ -291,6 +297,15 @@ def _on_task_complete( result = future.result() exe_state.complete(result) self.counters.complete_task() + except OrphanedChildException: + # Parent already completed and returned. + # State is already RUNNING, which _create_result() marked as STARTED + # Just log and exit - no state change needed + logger.debug( + "Terminating orphaned branch %s without error because parent has completed already", + exe_state.index, + ) + return except TimedSuspendExecution as tse: exe_state.suspend_with_timeout(tse.scheduled_timestamp) scheduler.schedule_resume(exe_state, tse.scheduled_timestamp) diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index 9b37db6..72f0aa0 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -372,3 +372,32 @@ def __str__(self) -> str: class SerDesError(DurableExecutionsError): """Raised when serialization fails.""" + + +class OrphanedChildException(BaseException): + """Raised when a child operation attempts to checkpoint after its parent context has completed. + + This exception inherits from BaseException (not Exception) so that user-space doesn't + accidentally catch it with broad exception handlers like 'except Exception'. + + This exception will happen when a parallel branch or map item tries to create a checkpoint + after its parent context (i.e the parallel/map operation) has already completed due to meeting + completion criteria (e.g., min_successful reached, failure tolerance exceeded). + + Although you cannot cancel running futures in user-space, this will at least terminate the + child operation on the next checkpoint attempt, preventing subsequent operations in the + child scope from executing. + + Attributes: + operation_id: Operation ID of the orphaned child + """ + + def __init__(self, message: str, operation_id: str): + """Initialize OrphanedChildException. + + Args: + message: Human-readable error message + operation_id: Operation ID of the orphaned child (required) + """ + super().__init__(message) + self.operation_id = operation_id diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 5174ce6..a6fc0c7 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -16,6 +16,7 @@ BackgroundThreadError, CallableRuntimeError, DurableExecutionsError, + OrphanedChildException, ) from aws_durable_execution_sdk_python.lambda_service import ( CheckpointOutput, @@ -449,7 +450,13 @@ def create_checkpoint( "Rejecting checkpoint for operation %s - parent is done", operation_update.operation_id, ) - return + error_msg = ( + "Parent context completed, child operation cannot checkpoint" + ) + raise OrphanedChildException( + error_msg, + operation_id=operation_update.operation_id, + ) # Check if background checkpointing has failed if self._checkpointing_failed.is_set(): diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py index 563c143..cb2f0ba 100644 --- a/tests/concurrency_test.py +++ b/tests/concurrency_test.py @@ -32,7 +32,9 @@ SuspendExecution, TimedSuspendExecution, ) -from aws_durable_execution_sdk_python.lambda_service import ErrorObject +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, +) from aws_durable_execution_sdk_python.operation.map import MapExecutor @@ -2838,3 +2840,194 @@ def task_func(ctx, item, idx, items): assert ( sum(1 for item in result.all if item.status == BatchItemStatus.SUCCEEDED) < 98 ) + + +def test_executor_exits_early_with_min_successful(): + """Test that parallel exits immediately when min_successful is reached without waiting for other branches.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return executable.func() + + execution_times = [] + + def fast_branch(): + execution_times.append(("fast", time.time())) + return "fast_result" + + def slow_branch(): + execution_times.append(("slow_start", time.time())) + time.sleep(2) # Long sleep + execution_times.append(("slow_end", time.time())) + return "slow_result" + + executables = [ + Executable(0, fast_branch), + Executable(1, slow_branch), + ] + + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=2, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + execution_state = Mock() + execution_state.create_checkpoint = Mock() + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}" # noqa: SLF001 + executor_context._parent_id = "parent" # noqa: SLF001 + + def create_child_context(op_id): + child = Mock() + child.state = execution_state + return child + + executor_context.create_child_context = create_child_context + + start_time = time.time() + result = executor.execute(execution_state, executor_context) + elapsed_time = time.time() - start_time + + # Should complete in less than 1.5 second (not wait for 2-second sleep) + assert elapsed_time < 1.5, f"Took {elapsed_time}s, expected < 1.5s" + + # Result should show MIN_SUCCESSFUL_REACHED + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Fast branch should succeed + assert result.all[0].status == BatchItemStatus.SUCCEEDED + assert result.all[0].result == "fast_result" + + # Slow branch should be marked as STARTED (incomplete) + assert result.all[1].status == BatchItemStatus.STARTED + + # Verify counts + assert result.success_count == 1 + assert result.failure_count == 0 + assert result.started_count == 1 + assert result.total_count == 2 + + +def test_executor_returns_with_incomplete_branches(): + """Test that executor returns when min_successful is reached, leaving other branches incomplete.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return executable.func() + + operation_tracker = Mock() + + def fast_branch(): + operation_tracker.fast_executed() + return "fast_result" + + def slow_branch(): + operation_tracker.slow_started() + time.sleep(2) # Long sleep + operation_tracker.slow_completed() + return "slow_result" + + executables = [ + Executable(0, fast_branch), + Executable(1, slow_branch), + ] + + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=2, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + execution_state = Mock() + execution_state.create_checkpoint = Mock() + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}" # noqa: SLF001 + executor_context._parent_id = "parent" # noqa: SLF001 + executor_context.create_child_context = lambda op_id: Mock(state=execution_state) + + result = executor.execute(execution_state, executor_context) + + # Verify fast branch executed + assert operation_tracker.fast_executed.call_count == 1 + + # Slow branch may or may not have started (depends on thread scheduling) + # but it definitely should not have completed + assert ( + operation_tracker.slow_completed.call_count == 0 + ), "Executor should return before slow branch completes" + + # Result should show MIN_SUCCESSFUL_REACHED + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Verify counts - one succeeded, one incomplete + assert result.success_count == 1 + assert result.failure_count == 0 + assert result.started_count == 1 + assert result.total_count == 2 + + +def test_executor_returns_before_slow_branch_completes(): + """Test that executor returns immediately when min_successful is reached, not waiting for slow branches.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return executable.func() + + slow_branch_mock = Mock() + + def fast_func(): + return "fast" + + def slow_func(): + time.sleep(3) # Sleep + slow_branch_mock.completed() # Should not be called before executor returns + return "slow" + + executables = [Executable(0, fast_func), Executable(1, slow_func)] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=2, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + execution_state = Mock() + execution_state.create_checkpoint = Mock() + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}" # noqa: SLF001 + executor_context._parent_id = "parent" # noqa: SLF001 + executor_context.create_child_context = lambda op_id: Mock(state=execution_state) + + result = executor.execute(execution_state, executor_context) + + # Executor should have returned before slow branch completed + assert ( + not slow_branch_mock.completed.called + ), "Executor should return before slow branch completes" + + # Result should show MIN_SUCCESSFUL_REACHED + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Verify counts + assert result.success_count == 1 + assert result.failure_count == 0 + assert result.started_count == 1 + assert result.total_count == 2 diff --git a/tests/exceptions_test.py b/tests/exceptions_test.py index 410022c..f3ed213 100644 --- a/tests/exceptions_test.py +++ b/tests/exceptions_test.py @@ -15,6 +15,7 @@ ExecutionError, InvocationError, OrderedLockError, + OrphanedChildException, StepInterruptedError, SuspendExecution, TerminationReason, @@ -332,3 +333,44 @@ def test_execution_error_with_custom_termination_reason(): error = ExecutionError("custom error", TerminationReason.SERIALIZATION_ERROR) assert str(error) == "custom error" assert error.termination_reason == TerminationReason.SERIALIZATION_ERROR + + +def test_orphaned_child_exception_is_base_exception(): + """Test that OrphanedChildException is a BaseException, not Exception.""" + assert issubclass(OrphanedChildException, BaseException) + assert not issubclass(OrphanedChildException, Exception) + + +def test_orphaned_child_exception_bypasses_user_exception_handler(): + """Test that OrphanedChildException cannot be caught by user's except Exception handler.""" + caught_by_exception = False + caught_by_base_exception = False + exception_instance = None + + try: + msg = "test message" + raise OrphanedChildException(msg, operation_id="test_op_123") + except Exception: # noqa: BLE001 + caught_by_exception = True + except BaseException as e: # noqa: BLE001 + caught_by_base_exception = True + exception_instance = e + + expected_msg = "OrphanedChildException should not be caught by except Exception" + assert not caught_by_exception, expected_msg + expected_base_msg = ( + "OrphanedChildException should be caught by except BaseException" + ) + assert caught_by_base_exception, expected_base_msg + + # Verify operation_id is preserved + assert isinstance(exception_instance, OrphanedChildException) + assert exception_instance.operation_id == "test_op_123" + assert str(exception_instance) == "test message" + + +def test_orphaned_child_exception_with_operation_id(): + """Test OrphanedChildException stores operation_id correctly.""" + exception = OrphanedChildException("parent completed", operation_id="child_op_456") + assert exception.operation_id == "child_op_456" + assert str(exception) == "parent completed" diff --git a/tests/state_test.py b/tests/state_test.py index 1e016d1..d997abf 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -16,6 +16,7 @@ from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, CallableRuntimeError, + OrphanedChildException, ) from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( @@ -1091,20 +1092,18 @@ def test_rejection_of_operations_from_completed_parents(): ) state.create_checkpoint(parent_complete, is_sync=False) - # Get initial queue size - initial_queue_size = state._checkpoint_queue.qsize() - - # Try to checkpoint child operation (should be rejected) + # Try to checkpoint child operation (should raise OrphanedChildException) child_checkpoint = OperationUpdate( operation_id="child_1", operation_type=OperationType.STEP, action=OperationAction.SUCCEED, parent_id="parent_1", ) - state.create_checkpoint(child_checkpoint, is_sync=False) + with pytest.raises(OrphanedChildException) as exc_info: + state.create_checkpoint(child_checkpoint, is_sync=False) - # Verify operation was rejected (queue size unchanged) - assert state._checkpoint_queue.qsize() == initial_queue_size + # Verify exception contains operation_id + assert exc_info.value.operation_id == "child_1" def test_nested_parallel_operations_deep_hierarchy(): @@ -1474,20 +1473,18 @@ def process_sync_checkpoint(): state.create_checkpoint(parent_complete, is_sync=True) processor.join(timeout=1.0) - # Get queue size before attempting to checkpoint orphaned child - initial_queue_size = state._checkpoint_queue.qsize() - - # Try to checkpoint child (should be rejected) + # Try to checkpoint child (should raise OrphanedChildException) child_checkpoint = OperationUpdate( operation_id="child_1", operation_type=OperationType.STEP, action=OperationAction.SUCCEED, parent_id="parent_1", ) - state.create_checkpoint(child_checkpoint, is_sync=True) + with pytest.raises(OrphanedChildException) as exc_info: + state.create_checkpoint(child_checkpoint, is_sync=True) - # Verify operation was rejected (queue size unchanged) - assert state._checkpoint_queue.qsize() == initial_queue_size + # Verify exception contains operation_id + assert exc_info.value.operation_id == "child_1" def test_mark_orphans_handles_cycles():