diff --git a/src/aws_durable_execution_sdk_python/concurrency/executor.py b/src/aws_durable_execution_sdk_python/concurrency/executor.py index 77e8529..da1a5cd 100644 --- a/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -22,6 +22,7 @@ ) from aws_durable_execution_sdk_python.config import ChildConfig from aws_durable_execution_sdk_python.exceptions import ( + OrphanedChildException, SuspendExecution, TimedSuspendExecution, ) @@ -198,42 +199,47 @@ def resubmitter(executable_with_state: ExecutableWithState) -> None: execution_state.create_checkpoint() submit_task(executable_with_state) - with ( - TimerScheduler(resubmitter) as scheduler, - ThreadPoolExecutor(max_workers=max_workers) as thread_executor, - ): - - def submit_task(executable_with_state: ExecutableWithState) -> Future: - """Submit task to the thread executor and mark its state as started.""" - future = thread_executor.submit( - self._execute_item_in_child_context, - executor_context, - executable_with_state.executable, - ) - executable_with_state.run(future) + thread_executor = ThreadPoolExecutor(max_workers=max_workers) + try: + with TimerScheduler(resubmitter) as scheduler: + + def submit_task(executable_with_state: ExecutableWithState) -> Future: + """Submit task to the thread executor and mark its state as started.""" + future = thread_executor.submit( + self._execute_item_in_child_context, + executor_context, + executable_with_state.executable, + ) + executable_with_state.run(future) - def on_done(future: Future) -> None: - self._on_task_complete(executable_with_state, future, scheduler) + def on_done(future: Future) -> None: + self._on_task_complete(executable_with_state, future, scheduler) - future.add_done_callback(on_done) - return future + future.add_done_callback(on_done) + return future - # Submit initial tasks - futures = [ - submit_task(exe_state) for exe_state in self.executables_with_state - ] + # Submit 
initial tasks + futures = [ + submit_task(exe_state) for exe_state in self.executables_with_state + ] - # Wait for completion - self._completion_event.wait() + # Wait for completion + self._completion_event.wait() - # Cancel remaining futures so - # that we don't wait for them to join. - for future in futures: - future.cancel() + # Cancel futures that haven't started yet + for future in futures: + future.cancel() - # Suspend execution if everything done and at least one of the tasks raised a suspend exception. - if self._suspend_exception: - raise self._suspend_exception + # Suspend execution if everything done and at least one of the tasks raised a suspend exception. + if self._suspend_exception: + raise self._suspend_exception + + finally: + # Shutdown without waiting for running threads for early return when + # completion criteria are met (e.g., min_successful). + # Running threads will continue in background but they raise OrphanedChildException + # on the next attempt to checkpoint. + thread_executor.shutdown(wait=False, cancel_futures=True) # Build final result return self._create_result() @@ -291,6 +297,15 @@ def _on_task_complete( result = future.result() exe_state.complete(result) self.counters.complete_task() + except OrphanedChildException: + # Parent already completed and returned. 
+ # State is already RUNNING, which _create_result() marked as STARTED + # Just log and exit - no state change needed + logger.debug( + "Terminating orphaned branch %s without error because parent has completed already", + exe_state.index, + ) + return except TimedSuspendExecution as tse: exe_state.suspend_with_timeout(tse.scheduled_timestamp) scheduler.schedule_resume(exe_state, tse.scheduled_timestamp) diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index 9b37db6..72f0aa0 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -372,3 +372,32 @@ def __str__(self) -> str: class SerDesError(DurableExecutionsError): """Raised when serialization fails.""" + + +class OrphanedChildException(BaseException): + """Raised when a child operation attempts to checkpoint after its parent context has completed. + + This exception inherits from BaseException (not Exception) so that user-space doesn't + accidentally catch it with broad exception handlers like 'except Exception'. + + This exception will happen when a parallel branch or map item tries to create a checkpoint + after its parent context (i.e., the parallel/map operation) has already completed due to meeting + completion criteria (e.g., min_successful reached, failure tolerance exceeded). + + Although you cannot cancel running futures in user-space, this will at least terminate the + child operation on the next checkpoint attempt, preventing subsequent operations in the + child scope from executing. + + Attributes: + operation_id: Operation ID of the orphaned child + """ + + def __init__(self, message: str, operation_id: str): + """Initialize OrphanedChildException. 
+ + Args: + message: Human-readable error message + operation_id: Operation ID of the orphaned child (required) + """ + super().__init__(message) + self.operation_id = operation_id diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 5174ce6..a6fc0c7 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -16,6 +16,7 @@ BackgroundThreadError, CallableRuntimeError, DurableExecutionsError, + OrphanedChildException, ) from aws_durable_execution_sdk_python.lambda_service import ( CheckpointOutput, @@ -449,7 +450,13 @@ def create_checkpoint( "Rejecting checkpoint for operation %s - parent is done", operation_update.operation_id, ) - return + error_msg = ( + "Parent context completed, child operation cannot checkpoint" + ) + raise OrphanedChildException( + error_msg, + operation_id=operation_update.operation_id, + ) # Check if background checkpointing has failed if self._checkpointing_failed.is_set(): diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py index 563c143..cb2f0ba 100644 --- a/tests/concurrency_test.py +++ b/tests/concurrency_test.py @@ -32,7 +32,9 @@ SuspendExecution, TimedSuspendExecution, ) -from aws_durable_execution_sdk_python.lambda_service import ErrorObject +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, +) from aws_durable_execution_sdk_python.operation.map import MapExecutor @@ -2838,3 +2840,194 @@ def task_func(ctx, item, idx, items): assert ( sum(1 for item in result.all if item.status == BatchItemStatus.SUCCEEDED) < 98 ) + + +def test_executor_exits_early_with_min_successful(): + """Test that parallel exits immediately when min_successful is reached without waiting for other branches.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return executable.func() + + execution_times = [] + + def fast_branch(): + 
+ execution_times.append(("fast", time.time())) + return "fast_result" + + def slow_branch(): + execution_times.append(("slow_start", time.time())) + time.sleep(2) # Long sleep + execution_times.append(("slow_end", time.time())) + return "slow_result" + + executables = [ + Executable(0, fast_branch), + Executable(1, slow_branch), + ] + + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=2, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + execution_state = Mock() + execution_state.create_checkpoint = Mock() + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}" # noqa: SLF001 + executor_context._parent_id = "parent" # noqa: SLF001 + + def create_child_context(op_id): + child = Mock() + child.state = execution_state + return child + + executor_context.create_child_context = create_child_context + + start_time = time.time() + result = executor.execute(execution_state, executor_context) + elapsed_time = time.time() - start_time + + # Should complete in less than 1.5 seconds (not wait for the 2-second sleep) + assert elapsed_time < 1.5, f"Took {elapsed_time}s, expected < 1.5s" + + # Result should show MIN_SUCCESSFUL_REACHED + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Fast branch should succeed + assert result.all[0].status == BatchItemStatus.SUCCEEDED + assert result.all[0].result == "fast_result" + + # Slow branch should be marked as STARTED (incomplete) + assert result.all[1].status == BatchItemStatus.STARTED + + # Verify counts + assert result.success_count == 1 + assert result.failure_count == 0 + assert result.started_count == 1 + assert result.total_count == 2 + + +def test_executor_returns_with_incomplete_branches(): + """Test that executor returns when min_successful is reached, leaving other branches incomplete.""" 
+ + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return executable.func() + + operation_tracker = Mock() + + def fast_branch(): + operation_tracker.fast_executed() + return "fast_result" + + def slow_branch(): + operation_tracker.slow_started() + time.sleep(2) # Long sleep + operation_tracker.slow_completed() + return "slow_result" + + executables = [ + Executable(0, fast_branch), + Executable(1, slow_branch), + ] + + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=2, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + execution_state = Mock() + execution_state.create_checkpoint = Mock() + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}" # noqa: SLF001 + executor_context._parent_id = "parent" # noqa: SLF001 + executor_context.create_child_context = lambda op_id: Mock(state=execution_state) + + result = executor.execute(execution_state, executor_context) + + # Verify fast branch executed + assert operation_tracker.fast_executed.call_count == 1 + + # Slow branch may or may not have started (depends on thread scheduling) + # but it definitely should not have completed + assert ( + operation_tracker.slow_completed.call_count == 0 + ), "Executor should return before slow branch completes" + + # Result should show MIN_SUCCESSFUL_REACHED + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Verify counts - one succeeded, one incomplete + assert result.success_count == 1 + assert result.failure_count == 0 + assert result.started_count == 1 + assert result.total_count == 2 + + +def test_executor_returns_before_slow_branch_completes(): + """Test that executor returns immediately when min_successful is reached, not waiting for slow branches.""" + + class 
TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return executable.func() + + slow_branch_mock = Mock() + + def fast_func(): + return "fast" + + def slow_func(): + time.sleep(3) # Sleep + slow_branch_mock.completed() # Should not be called before executor returns + return "slow" + + executables = [Executable(0, fast_func), Executable(1, slow_func)] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=2, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + execution_state = Mock() + execution_state.create_checkpoint = Mock() + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}" # noqa: SLF001 + executor_context._parent_id = "parent" # noqa: SLF001 + executor_context.create_child_context = lambda op_id: Mock(state=execution_state) + + result = executor.execute(execution_state, executor_context) + + # Executor should have returned before slow branch completed + assert ( + not slow_branch_mock.completed.called + ), "Executor should return before slow branch completes" + + # Result should show MIN_SUCCESSFUL_REACHED + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Verify counts + assert result.success_count == 1 + assert result.failure_count == 0 + assert result.started_count == 1 + assert result.total_count == 2 diff --git a/tests/exceptions_test.py b/tests/exceptions_test.py index 410022c..f3ed213 100644 --- a/tests/exceptions_test.py +++ b/tests/exceptions_test.py @@ -15,6 +15,7 @@ ExecutionError, InvocationError, OrderedLockError, + OrphanedChildException, StepInterruptedError, SuspendExecution, TerminationReason, @@ -332,3 +333,44 @@ def test_execution_error_with_custom_termination_reason(): error = ExecutionError("custom error", TerminationReason.SERIALIZATION_ERROR) assert 
str(error) == "custom error" assert error.termination_reason == TerminationReason.SERIALIZATION_ERROR + + +def test_orphaned_child_exception_is_base_exception(): + """Test that OrphanedChildException is a BaseException, not Exception.""" + assert issubclass(OrphanedChildException, BaseException) + assert not issubclass(OrphanedChildException, Exception) + + +def test_orphaned_child_exception_bypasses_user_exception_handler(): + """Test that OrphanedChildException cannot be caught by user's except Exception handler.""" + caught_by_exception = False + caught_by_base_exception = False + exception_instance = None + + try: + msg = "test message" + raise OrphanedChildException(msg, operation_id="test_op_123") + except Exception: # noqa: BLE001 + caught_by_exception = True + except BaseException as e: # noqa: BLE001 + caught_by_base_exception = True + exception_instance = e + + expected_msg = "OrphanedChildException should not be caught by except Exception" + assert not caught_by_exception, expected_msg + expected_base_msg = ( + "OrphanedChildException should be caught by except BaseException" + ) + assert caught_by_base_exception, expected_base_msg + + # Verify operation_id is preserved + assert isinstance(exception_instance, OrphanedChildException) + assert exception_instance.operation_id == "test_op_123" + assert str(exception_instance) == "test message" + + +def test_orphaned_child_exception_with_operation_id(): + """Test OrphanedChildException stores operation_id correctly.""" + exception = OrphanedChildException("parent completed", operation_id="child_op_456") + assert exception.operation_id == "child_op_456" + assert str(exception) == "parent completed" diff --git a/tests/state_test.py b/tests/state_test.py index 1e016d1..d997abf 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -16,6 +16,7 @@ from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, CallableRuntimeError, + OrphanedChildException, ) from 
aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( @@ -1091,20 +1092,18 @@ def test_rejection_of_operations_from_completed_parents(): ) state.create_checkpoint(parent_complete, is_sync=False) - # Get initial queue size - initial_queue_size = state._checkpoint_queue.qsize() - - # Try to checkpoint child operation (should be rejected) + # Try to checkpoint child operation (should raise OrphanedChildException) child_checkpoint = OperationUpdate( operation_id="child_1", operation_type=OperationType.STEP, action=OperationAction.SUCCEED, parent_id="parent_1", ) - state.create_checkpoint(child_checkpoint, is_sync=False) + with pytest.raises(OrphanedChildException) as exc_info: + state.create_checkpoint(child_checkpoint, is_sync=False) - # Verify operation was rejected (queue size unchanged) - assert state._checkpoint_queue.qsize() == initial_queue_size + # Verify exception contains operation_id + assert exc_info.value.operation_id == "child_1" def test_nested_parallel_operations_deep_hierarchy(): @@ -1474,20 +1473,18 @@ def process_sync_checkpoint(): state.create_checkpoint(parent_complete, is_sync=True) processor.join(timeout=1.0) - # Get queue size before attempting to checkpoint orphaned child - initial_queue_size = state._checkpoint_queue.qsize() - - # Try to checkpoint child (should be rejected) + # Try to checkpoint child (should raise OrphanedChildException) child_checkpoint = OperationUpdate( operation_id="child_1", operation_type=OperationType.STEP, action=OperationAction.SUCCEED, parent_id="parent_1", ) - state.create_checkpoint(child_checkpoint, is_sync=True) + with pytest.raises(OrphanedChildException) as exc_info: + state.create_checkpoint(child_checkpoint, is_sync=True) - # Verify operation was rejected (queue size unchanged) - assert state._checkpoint_queue.qsize() == initial_queue_size + # Verify exception contains operation_id + assert exc_info.value.operation_id == 
"child_1" def test_mark_orphans_handles_cycles():