diff --git a/.gitignore b/.gitignore index 7d2b20a..bfc52e9 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,6 @@ dist/ .kiro/ /examples/build/* -/examples/*.zip \ No newline at end of file +/examples/*.zip + +.env \ No newline at end of file diff --git a/examples/examples-catalog.json b/examples/examples-catalog.json index e80e3ba..fb9ab78 100644 --- a/examples/examples-catalog.json +++ b/examples/examples-catalog.json @@ -580,6 +580,28 @@ "ApplicationLogLevel": "DEBUG", "LogFormat": "JSON" } - } + }, + { + "name": "Map with Item Namer", + "description": "Map operation with custom item_namer for iteration naming", + "handler": "map_with_item_namer.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "path": "./src/map/map_with_item_namer.py" + }, + { + "name": "Parallel with Named Branches", + "description": "Parallel operation with named branches using ParallelBranch", + "handler": "parallel_with_named_branches.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "path": "./src/parallel/parallel_with_named_branches.py" + } ] } diff --git a/examples/src/map/map_with_item_namer.py b/examples/src/map/map_with_item_namer.py new file mode 100644 index 0000000..331faab --- /dev/null +++ b/examples/src/map/map_with_item_namer.py @@ -0,0 +1,30 @@ +"""Example demonstrating map operations with custom iteration naming.""" + +from typing import Any + +from aws_durable_execution_sdk_python.config import MapConfig +from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.execution import durable_execution + + +@durable_execution +def handler(_event: Any, context: DurableContext) -> list[str]: + """Process orders using context.map() with custom iteration names.""" + orders = [ + {"id": "order-101", "amount": 25}, + {"id": "order-102", "amount": 50}, + {"id": "order-103", "amount": 75}, + ] + + return context.map( + inputs=orders, + func=lambda ctx, order, index, _: ctx.step( + lambda _: f"processed-{order['id']}-${order['amount']}", + name=f"process_{order['id']}", + ), + name="process_orders", + config=MapConfig( + max_concurrency=2, + item_namer=lambda order, index: f"order-{order['id']}", + ), + ).get_results() diff --git a/examples/src/parallel/parallel_with_named_branches.py b/examples/src/parallel/parallel_with_named_branches.py new file mode 100644 index 0000000..78340a1 --- /dev/null +++ b/examples/src/parallel/parallel_with_named_branches.py @@ -0,0 +1,35 @@ +"""Example demonstrating parallel operations with named branches.""" + +from typing import Any + +from aws_durable_execution_sdk_python.config import ParallelBranch, ParallelConfig +from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.execution import durable_execution + + +@durable_execution +def handler(_event: Any, context: DurableContext) -> list[str]: + """Execute named parallel branches using ParallelBranch.""" + + return context.parallel( + functions=[ + ParallelBranch( + func=lambda ctx: ctx.step( + lambda _: "user-data-loaded", name="load_user" + ), + name="fetch-user-data", + ), + ParallelBranch( + func=lambda ctx: ctx.step( + lambda _: "orders-loaded", name="load_orders" + ), + name="fetch-order-history", + ), + ParallelBranch( + func=lambda ctx: ctx.step(lambda _: "prefs-loaded", name="load_prefs"), + name="fetch-preferences", + ), + ], + name="load_all_data", + config=ParallelConfig(max_concurrency=3), + ).get_results() diff --git a/examples/template.yaml b/examples/template.yaml index 0a9dcb9..2854e72 100644 --- a/examples/template.yaml +++ b/examples/template.yaml @@ -941,6 +941,42 @@ "ExecutionTimeout": 300 } } + }, + "MapWithItemNamer": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "build/", + "Handler": "map_with_item_namer.handler", + "Description": "Map operation with custom item_namer for iteration naming", + "Role": { + "Fn::GetAtt": [ + "DurableFunctionRole", + "Arn" + ] + }, + "DurableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + } + } + }, + "ParallelWithNamedBranches": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "build/", + "Handler": "parallel_with_named_branches.handler", + "Description": "Parallel operation with named branches using ParallelBranch", + "Role": { + "Fn::GetAtt": [ + "DurableFunctionRole", + "Arn" + ] + }, + "DurableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + } + } } } } \ No newline at end of file diff --git a/examples/test/map/test_map_with_item_namer.py b/examples/test/map/test_map_with_item_namer.py new file mode 100644 index 0000000..11997d8 --- /dev/null +++ b/examples/test/map/test_map_with_item_namer.py @@ -0,0 +1,39 @@ +"""Tests for map_with_item_namer example.""" + +import pytest +from src.map import map_with_item_namer +from test.conftest import deserialize_operation_payload + +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, +) + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=map_with_item_namer.handler, + lambda_function_name="map with item namer", +) +def test_map_with_item_namer(durable_runner): + """Test map example with custom item_namer for iteration naming.""" + with durable_runner: + result = durable_runner.run(input="test", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert deserialize_operation_payload(result.result) == [ + "processed-order-101-$25", + "processed-order-102-$50", + "processed-order-103-$75", + ] + + # Get the map operation + map_op = result.get_context("process_orders") + assert map_op is not None + assert map_op.status is OperationStatus.SUCCEEDED + + # Verify custom iteration names from item_namer + assert len(map_op.child_operations) == 3 + child_names = {op.name for op in map_op.child_operations} + expected_names = {"order-order-101", "order-order-102", "order-order-103"} + assert child_names == expected_names diff --git a/examples/test/parallel/test_parallel_with_named_branches.py b/examples/test/parallel/test_parallel_with_named_branches.py new file mode 100644 index 0000000..61a8ee9 --- /dev/null +++ b/examples/test/parallel/test_parallel_with_named_branches.py @@ -0,0 +1,45 @@ +"""Tests for parallel_with_named_branches example.""" + +import pytest +from src.parallel import parallel_with_named_branches +from test.conftest import deserialize_operation_payload + +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, + OperationType, +) + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=parallel_with_named_branches.handler, + lambda_function_name="parallel with named branches", +) +def test_parallel_with_named_branches(durable_runner): + """Test parallel example with named branches using ParallelBranch.""" + with durable_runner: + result = durable_runner.run(input="test", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert deserialize_operation_payload(result.result) == [ + "user-data-loaded", + "orders-loaded", + "prefs-loaded", + ] + + # Get the parallel operation + parallel_op = result.get_context("load_all_data") + assert parallel_op is not None + assert parallel_op.status is OperationStatus.SUCCEEDED + + # Verify custom branch names from ParallelBranch + assert len(parallel_op.child_operations) == 3 + child_names = {op.name for op in parallel_op.child_operations} + expected_names = {"fetch-user-data", "fetch-order-history", "fetch-preferences"} + assert child_names == expected_names + + # Verify all children succeeded + for child in parallel_op.child_operations: + assert child.operation_type == OperationType.CONTEXT + assert child.status is OperationStatus.SUCCEEDED diff --git a/src/aws_durable_execution_sdk_python/__init__.py b/src/aws_durable_execution_sdk_python/__init__.py index 23a85cd..fc2a657 100644 --- a/src/aws_durable_execution_sdk_python/__init__.py +++ b/src/aws_durable_execution_sdk_python/__init__.py @@ -7,6 +7,7 @@ # Helper decorators - commonly used for step functions # Concurrency from aws_durable_execution_sdk_python.concurrency.models import BatchResult +from aws_durable_execution_sdk_python.config import ParallelBranch from aws_durable_execution_sdk_python.context import ( DurableContext, durable_step, @@ -27,11 +28,13 @@ # Essential context types - passed to user functions from aws_durable_execution_sdk_python.types import StepContext + __all__ = [ "BatchResult", "DurableContext", "DurableExecutionsError", "InvocationError", + "ParallelBranch", "StepContext", "ValidationError", "__version__", diff --git a/src/aws_durable_execution_sdk_python/concurrency/executor.py b/src/aws_durable_execution_sdk_python/concurrency/executor.py index 24c7657..3a7ab13 100644 --- a/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -194,6 +194,14 @@ def execute_item( """Execute a single executable in a child context and return the result.""" raise NotImplementedError + def get_iteration_name(self, index: int) -> str: + """Get the display name for an iteration/branch at the given index. + + Subclasses can override this to provide custom naming (e.g., from item_namer + or branch names). The default returns "{name_prefix}{index}". + """ + return f"{self.name_prefix}{index}" + def execute( self, execution_state: ExecutionState, executor_context: DurableContext ) -> BatchResult[ResultType]: @@ -410,7 +418,7 @@ def _execute_item_in_child_context( operation_id: str = executor_context._create_step_id_for_logical_step( # noqa: SLF001 executable.index ) - name: str = f"{self.name_prefix}{executable.index}" + name: str = self.get_iteration_name(executable.index) is_virtual: bool = self.nesting_type is NestingType.FLAT child_context: DurableContext = executor_context.create_child_context( diff --git a/src/aws_durable_execution_sdk_python/config.py b/src/aws_durable_execution_sdk_python/config.py index 980786d..e8c0eb4 100644 --- a/src/aws_durable_execution_sdk_python/config.py +++ b/src/aws_durable_execution_sdk_python/config.py @@ -9,6 +9,7 @@ from aws_durable_execution_sdk_python.exceptions import ValidationError + P = TypeVar("P") # Payload type R = TypeVar("R") # Result type T = TypeVar("T") @@ -245,6 +246,41 @@ class ParallelConfig: nesting_type: NestingType = NestingType.NESTED +@dataclass(frozen=True) +class ParallelBranch(Generic[T]): + """A named branch for parallel execution. + + Use this to provide custom names for parallel branches, improving + observability in execution history. + + Type Parameters: + T: The return type of the branch function. + + Args: + func: The callable to execute in this branch. Receives a DurableContext. + name: Optional custom name for this branch. When provided, replaces + the default "parallel-branch-{index}" naming in execution history. + This affects observability but not replay determinism. + + Example: + context.parallel( + functions=[ + ParallelBranch(func=lambda ctx: fetch_user(ctx), name="fetch-user-data"), + ParallelBranch(func=lambda ctx: fetch_orders(ctx), name="fetch-order-history"), + ], + name="load-data", + config=ParallelConfig(max_concurrency=2), + ) + """ + + func: Callable + name: str | None = None + + def __call__(self, *args, **kwargs): + """Delegate to the wrapped function, making ParallelBranch itself callable.""" + return self.func(*args, **kwargs) + + class StepSemantics(Enum): AT_MOST_ONCE_PER_RETRY = "AT_MOST_ONCE_PER_RETRY" AT_LEAST_ONCE_PER_RETRY = "AT_LEAST_ONCE_PER_RETRY" @@ -354,12 +390,15 @@ class ItemBatcher(Generic[T]): @dataclass(frozen=True) -class MapConfig: +class MapConfig(Generic[T]): """Configuration options for map operations over collections. This class configures how map operations process collections of items, including concurrency, batching, completion criteria, and serialization. + Type Parameters: + T: The type of items being processed in the map operation. + Args: max_concurrency: Maximum number of items to process concurrently. If None, no limit is imposed and all items are processed concurrently. @@ -402,6 +441,12 @@ class MapConfig: - NESTED: Each item runs in its own isolated context (default) - FLAT: All items share the same parent context + item_namer: Optional callable to generate custom names for each map iteration. + When provided, replaces the default "map-item-{index}" naming scheme. + Receives the item and its index, and returns a string name for that iteration. + This affects observability (execution history names) but not replay determinism. + If None, uses the default naming: "map-item-{index}". + Example: # Process 5 items at a time, batch by count, require all to succeed config = MapConfig( @@ -409,6 +454,12 @@ class MapConfig: item_batcher=ItemBatcher(max_items_per_batch=10), completion_config=CompletionConfig.all_successful() ) + + # With custom iteration names + config = MapConfig( + max_concurrency=5, + item_namer=lambda item, index: f"process-order-{item.id}" + ) """ max_concurrency: int | None = None @@ -418,6 +469,7 @@ class MapConfig: item_serdes: SerDes | None = None summary_generator: SummaryGenerator | None = None nesting_type: NestingType = NestingType.NESTED + item_namer: Callable[[T, int], str] | None = None @dataclass(frozen=True) diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index bfded98..370e0d4 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -12,6 +12,7 @@ Duration, InvokeConfig, MapConfig, + ParallelBranch, ParallelConfig, StepConfig, WaitForCallbackConfig, @@ -55,6 +56,7 @@ WaitForConditionCheckContext, ) + if TYPE_CHECKING: from collections.abc import Callable, Sequence @@ -496,7 +498,7 @@ def map_in_child_context() -> BatchResult[R]: def parallel( self, - functions: Sequence[Callable[[DurableContext], T]], + functions: Sequence[Callable[[DurableContext], T] | ParallelBranch[T]], name: str | None = None, config: ParallelConfig | None = None, ) -> BatchResult[T]: diff --git a/src/aws_durable_execution_sdk_python/operation/map.py b/src/aws_durable_execution_sdk_python/operation/map.py index 8e9fb6a..f201efc 100644 --- a/src/aws_durable_execution_sdk_python/operation/map.py +++ b/src/aws_durable_execution_sdk_python/operation/map.py @@ -15,6 +15,7 @@ from aws_durable_execution_sdk_python.config import MapConfig, NestingType from aws_durable_execution_sdk_python.lambda_service import OperationSubType + if TYPE_CHECKING: from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier @@ -47,6 +48,7 @@ def __init__( summary_generator: SummaryGenerator | None = None, item_serdes: SerDes | None = None, nesting_type: NestingType = NestingType.NESTED, + item_namer: Callable[[T, int], str] | None = None, ): super().__init__( executables=executables, @@ -61,13 +63,14 @@ def __init__( nesting_type=nesting_type, ) self.items = items + self._item_namer = item_namer @classmethod def from_items( cls, items: Sequence[T], func: Callable, - config: MapConfig, + config: MapConfig[T], ) -> MapExecutor[T, R]: """Create MapExecutor from items and a callable.""" executables: list[Executable[Callable]] = [ @@ -86,8 +89,15 @@ def from_items( summary_generator=config.summary_generator, item_serdes=config.item_serdes, nesting_type=config.nesting_type, + item_namer=config.item_namer, ) + def get_iteration_name(self, index: int) -> str: + """Return custom item name if item_namer is provided, otherwise default.""" + if self._item_namer is not None: + return self._item_namer(self.items[index], index) + return super().get_iteration_name(index) + def execute_item(self, child_context, executable: Executable[Callable]) -> R: logger.debug("πŸ—ΊοΈ Processing map item: %s", executable.index) item = self.items[executable.index] diff --git a/src/aws_durable_execution_sdk_python/operation/parallel.py b/src/aws_durable_execution_sdk_python/operation/parallel.py index 4d7094a..76fc16f 100644 --- a/src/aws_durable_execution_sdk_python/operation/parallel.py +++ b/src/aws_durable_execution_sdk_python/operation/parallel.py @@ -9,9 +9,14 @@ from aws_durable_execution_sdk_python.concurrency.executor import ConcurrentExecutor from aws_durable_execution_sdk_python.concurrency.models import Executable -from aws_durable_execution_sdk_python.config import ParallelConfig, NestingType +from aws_durable_execution_sdk_python.config import ( + NestingType, + ParallelBranch, + ParallelConfig, +) from aws_durable_execution_sdk_python.lambda_service import OperationSubType + if TYPE_CHECKING: from aws_durable_execution_sdk_python.concurrency.models import BatchResult from aws_durable_execution_sdk_python.context import DurableContext @@ -56,13 +61,19 @@ def __init__( @classmethod def from_callables( cls, - callables: Sequence[Callable], + callables: Sequence[Callable | ParallelBranch], config: ParallelConfig, ) -> ParallelExecutor: - """Create ParallelExecutor from a sequence of callables.""" + """Create ParallelExecutor from a sequence of callables or ParallelBranch instances. + + Since ParallelBranch is callable, it is stored directly as the func in + each Executable. The get_iteration_name method inspects the func to + extract the branch name when available. + """ executables: list[Executable[Callable]] = [ Executable(index=i, func=func) for i, func in enumerate(callables) ] + return cls( executables=executables, max_concurrency=config.max_concurrency, @@ -76,6 +87,13 @@ def from_callables( nesting_type=config.nesting_type, ) + def get_iteration_name(self, index: int) -> str: + """Return custom branch name if the callable is a ParallelBranch with a name.""" + func = self.executables[index].func + if isinstance(func, ParallelBranch) and func.name is not None: + return func.name + return super().get_iteration_name(index) + def execute_item(self, child_context, executable: Executable[Callable]) -> R: # noqa: PLR6301 logger.debug("πŸ”€ Processing parallel branch: %s", executable.index) result: R = executable.func(child_context) @@ -84,7 +102,7 @@ def execute_item(self, child_context, executable: Executable[Callable]) -> R: # def parallel_handler( - callables: Sequence[Callable], + callables: Sequence[Callable | ParallelBranch], config: ParallelConfig | None, execution_state: ExecutionState, parallel_context: DurableContext, diff --git a/src/aws_durable_execution_sdk_python/types.py b/src/aws_durable_execution_sdk_python/types.py index 9181be9..90080b0 100644 --- a/src/aws_durable_execution_sdk_python/types.py +++ b/src/aws_durable_execution_sdk_python/types.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar + if TYPE_CHECKING: from collections.abc import Callable, Mapping, Sequence @@ -15,6 +16,7 @@ ChildConfig, Duration, MapConfig, + ParallelBranch, ParallelConfig, StepConfig, ) @@ -124,7 +126,7 @@ def map( @abstractmethod def parallel( self, - functions: Sequence[Callable[[DurableContext], T]], + functions: Sequence[Callable[[DurableContext], T] | ParallelBranch[T]], name: str | None = None, config: ParallelConfig | None = None, ) -> BatchResult[T]: diff --git a/tests/operation/map_test.py b/tests/operation/map_test.py index b3c979d..c7a653f 100644 --- a/tests/operation/map_test.py +++ b/tests/operation/map_test.py @@ -1178,3 +1178,122 @@ def map_func(ctx, item, idx, items): assert result.total_count == 0 assert result.success_count == 0 assert result.failure_count == 0 + + +# region item_namer tests + + +def test_map_executor_get_iteration_name_default(): + """Without item_namer, iterations use default 'map-item-{index}' naming.""" + items = ["a", "b", "c"] + config = MapConfig(max_concurrency=2) + + executor = MapExecutor.from_items( + items=items, + func=lambda ctx, item, idx, items: item, + config=config, + ) + + assert executor.get_iteration_name(0) == "map-item-0" + assert executor.get_iteration_name(1) == "map-item-1" + assert executor.get_iteration_name(2) == "map-item-2" + + +def test_map_executor_get_iteration_name_with_item_namer(): + """With item_namer, iterations use custom names.""" + items = [{"id": "order-1"}, {"id": "order-2"}, {"id": "order-3"}] + config = MapConfig( + max_concurrency=2, + item_namer=lambda item, index: f"process-{item['id']}", + ) + + executor = MapExecutor.from_items( + items=items, + func=lambda ctx, item, idx, items: item, + config=config, + ) + + assert executor.get_iteration_name(0) == "process-order-1" + assert executor.get_iteration_name(1) == "process-order-2" + assert executor.get_iteration_name(2) == "process-order-3" + + +def test_map_executor_item_namer_receives_item_and_index(): + """item_namer receives both the item and its index.""" + items = ["alpha", "beta", "gamma"] + received_args: list[tuple] = [] + + def namer(item, index): + received_args.append((item, index)) + return f"item-{index}-{item}" + + config = MapConfig(item_namer=namer) + + executor = MapExecutor.from_items( + items=items, + func=lambda ctx, item, idx, items: item, + config=config, + ) + + executor.get_iteration_name(0) + executor.get_iteration_name(2) + + assert received_args == [("alpha", 0), ("gamma", 2)] + + +def test_map_executor_item_namer_uses_index(): + """item_namer can use the index to generate names.""" + items = [10, 20, 30] + config = MapConfig(item_namer=lambda item, index: f"step-{index + 1}") + + executor = MapExecutor.from_items( + items=items, + func=lambda ctx, item, idx, items: item, + config=config, + ) + + assert executor.get_iteration_name(0) == "step-1" + assert executor.get_iteration_name(1) == "step-2" + assert executor.get_iteration_name(2) == "step-3" + + +def test_map_executor_item_namer_none_falls_back_to_default(): + """Explicitly passing item_namer=None uses default naming.""" + items = ["x", "y"] + config = MapConfig(item_namer=None) + + executor = MapExecutor.from_items( + items=items, + func=lambda ctx, item, idx, items: item, + config=config, + ) + + assert executor.get_iteration_name(0) == "map-item-0" + assert executor.get_iteration_name(1) == "map-item-1" + + +def test_map_executor_from_items_passes_item_namer(): + """MapExecutor.from_items correctly passes item_namer from config.""" + namer = lambda item, index: f"custom-{index}" # noqa: E731 + config = MapConfig(item_namer=namer) + + executor = MapExecutor.from_items( + items=["a"], + func=lambda ctx, item, idx, items: item, + config=config, + ) + + assert executor._item_namer is namer + + +def test_map_config_generic_with_item_namer(): + """MapConfig can be parameterized with a type and use item_namer.""" + config: MapConfig[dict] = MapConfig( + item_namer=lambda item, index: f"item-{item['name']}", + ) + + assert config.item_namer is not None + assert config.item_namer({"name": "test"}, 0) == "item-test" + + +# endregion diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index 1922207..cf7c736 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -20,8 +20,8 @@ ) from aws_durable_execution_sdk_python.config import ( CompletionConfig, - ParallelConfig, NestingType, + ParallelConfig, ) from aws_durable_execution_sdk_python.context import DurableContext, ExecutionContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier @@ -1118,3 +1118,123 @@ def create_id(self, i): assert parent_call[1]["serdes"] is custom_serdes assert isinstance(parent_call[1]["value"], BatchResult) assert parent_call[1]["value"] is result + + +# region ParallelBranch and branch naming tests + + +def test_parallel_branch_is_callable(): + """ParallelBranch instances are callable.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branch = ParallelBranch(func=lambda x: x * 2, name="double") + assert callable(branch) + + +def test_parallel_branch_delegates_to_func(): + """Calling ParallelBranch delegates to the wrapped func.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branch = ParallelBranch(func=lambda x, y: x + y, name="add") + assert branch(3, 4) == 7 + + +def test_parallel_branch_passes_kwargs(): + """ParallelBranch passes keyword arguments to func.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branch = ParallelBranch(func=lambda ctx, flag=False: flag, name="test") + assert branch("ctx", flag=True) is True + + +def test_parallel_branch_frozen(): + """ParallelBranch is immutable (frozen dataclass).""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branch = ParallelBranch(func=lambda: None, name="test") + with pytest.raises(AttributeError): + branch.name = "changed" # type: ignore[misc] + + +def test_parallel_executor_get_iteration_name_default(): + """Plain callables use default 'parallel-branch-{index}' naming.""" + callables = [lambda ctx: "a", lambda ctx: "b", lambda ctx: "c"] + config = ParallelConfig() + + executor = ParallelExecutor.from_callables(callables, config) + + assert executor.get_iteration_name(0) == "parallel-branch-0" + assert executor.get_iteration_name(1) == "parallel-branch-1" + assert executor.get_iteration_name(2) == "parallel-branch-2" + + +def test_parallel_executor_get_iteration_name_with_named_branches(): + """ParallelBranch with name uses the custom name.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branches = [ + ParallelBranch(func=lambda ctx: "user", name="fetch-user-data"), + ParallelBranch(func=lambda ctx: "orders", name="fetch-order-history"), + ] + config = ParallelConfig() + + executor = ParallelExecutor.from_callables(branches, config) + + assert executor.get_iteration_name(0) == "fetch-user-data" + assert executor.get_iteration_name(1) == "fetch-order-history" + + +def test_parallel_executor_get_iteration_name_mixed(): + """Mix of ParallelBranch (with/without name) and plain callables.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branches = [ + ParallelBranch(func=lambda ctx: "a", name="named-branch"), + lambda ctx: "b", + ParallelBranch(func=lambda ctx: "c"), + ] + config = ParallelConfig() + + executor = ParallelExecutor.from_callables(branches, config) + + assert executor.get_iteration_name(0) == "named-branch" + assert executor.get_iteration_name(1) == "parallel-branch-1" + assert executor.get_iteration_name(2) == "parallel-branch-2" + + +def test_parallel_executor_get_iteration_name_none_name(): + """ParallelBranch with name=None falls back to default naming.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branches = [ + ParallelBranch(func=lambda ctx: "x", name=None), + ] + config = ParallelConfig() + + executor = ParallelExecutor.from_callables(branches, config) + + assert executor.get_iteration_name(0) == "parallel-branch-0" + + +def test_parallel_branch_execute_item(): + """ParallelBranch works correctly in execute_item.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branch = ParallelBranch(func=lambda ctx: f"result-{ctx}", name="my-branch") + executable = Executable(index=0, func=branch) + + executor = ParallelExecutor( + executables=[executable], + max_concurrency=None, + completion_config=CompletionConfig.all_successful(), + top_level_sub_type=OperationSubType.PARALLEL, + iteration_sub_type=OperationSubType.PARALLEL_BRANCH, + name_prefix="parallel-branch-", + serdes=None, + ) + + result = executor.execute_item("test-ctx", executable) + assert result == "result-test-ctx" + + +# endregion