Skip to content

Commit f01f6eb

Browse files
FullyTypedAstraea Quinn S
authored andcommitted
fix: make map/parallel child exec order invariant
Context: - Currently the operation_id for children in map and parallel is contingent on execution order due to internal method that always increments a counter for safety. - In map and parallel, depending on when the executor decides to run a particular branch, we may observe different ordering. Changes: - we skip using executor context's run_in_child_context to run a concurrent branch, and instead call child_handler directly. - we generate the appropriate operation_id before we call child_handler and use executable.index to make it deterministic.
1 parent 48a1c41 commit f01f6eb

File tree

7 files changed

+320
-232
lines changed

7 files changed

+320
-232
lines changed

src/aws_durable_execution_sdk_python/concurrency.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,20 @@
1919
SuspendExecution,
2020
TimedSuspendExecution,
2121
)
22+
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
2223
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
24+
from aws_durable_execution_sdk_python.operation.child import child_handler
2325
from aws_durable_execution_sdk_python.types import BatchResult as BatchResultProtocol
2426

2527
if TYPE_CHECKING:
2628
from collections.abc import Callable
2729

2830
from aws_durable_execution_sdk_python.config import CompletionConfig
31+
from aws_durable_execution_sdk_python.context import DurableContext
2932
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
3033
from aws_durable_execution_sdk_python.serdes import SerDes
3134
from aws_durable_execution_sdk_python.state import ExecutionState
32-
from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator
35+
from aws_durable_execution_sdk_python.types import SummaryGenerator
3336

3437

3538
logger = logging.getLogger(__name__)
@@ -615,12 +618,7 @@ def execute_item(
615618
raise NotImplementedError
616619

617620
def execute(
618-
self,
619-
execution_state: ExecutionState,
620-
run_in_child_context: Callable[
621-
[Callable[[DurableContext], ResultType], str | None, ChildConfig | None],
622-
ResultType,
623-
],
621+
self, execution_state: ExecutionState, executor_context: DurableContext
624622
) -> BatchResult[ResultType]:
625623
"""Execute items concurrently with event-driven state management."""
626624
logger.debug(
@@ -649,7 +647,7 @@ def submit_task(executable_with_state: ExecutableWithState) -> None:
649647
"""Submit task to the thread executor and mark its state as started."""
650648
future = thread_executor.submit(
651649
self._execute_item_in_child_context,
652-
run_in_child_context,
650+
executor_context,
653651
executable_with_state.executable,
654652
)
655653
executable_with_state.run(future)
@@ -784,21 +782,42 @@ def _create_result(self) -> BatchResult[ResultType]:
784782

785783
def _execute_item_in_child_context(
786784
self,
787-
run_in_child_context: Callable[
788-
[Callable[[DurableContext], ResultType], str | None, ChildConfig | None],
789-
ResultType,
790-
],
785+
executor_context: DurableContext,
791786
executable: Executable[CallableType],
792787
) -> ResultType:
793-
"""Execute a single item in a child context."""
788+
"""
789+
Execute a single item in a derived child context.
790+
791+
instead of relying on `executor_context.run_in_child_context`
792+
we generate an operation_id for the child, and then call `child_handler`
793+
directly. This avoids the hidden mutation of the context's internal counter.
794+
we can do this because we explicitly control the generation of step_id and do it
795+
using executable.index.
796+
797+
798+
invariant: `operation_id` for a given executable is deterministic,
799+
and execution order invariant.
800+
"""
801+
802+
operation_id = executor_context._create_step_id_for_logical_step( # noqa: SLF001
803+
executable.index
804+
)
805+
name = f"{self.name_prefix}{executable.index}"
806+
child_context = executor_context.create_child_context(operation_id)
807+
operation_identifier = OperationIdentifier(
808+
operation_id,
809+
executor_context._parent_id, # noqa: SLF001
810+
name,
811+
)
794812

795-
def execute_in_child_context(child_context: DurableContext) -> ResultType:
813+
def run_in_child_handler():
796814
return self.execute_item(child_context, executable)
797815

798-
return run_in_child_context(
799-
execute_in_child_context,
800-
f"{self.name_prefix}{executable.index}",
801-
ChildConfig(
816+
return child_handler(
817+
run_in_child_handler,
818+
child_context.state,
819+
operation_identifier=operation_identifier,
820+
config=ChildConfig(
802821
serdes=self.item_serdes or self.serdes,
803822
sub_type=self.sub_type_iteration,
804823
summary_generator=self.summary_generator,

src/aws_durable_execution_sdk_python/context.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -222,17 +222,23 @@ def set_logger(self, new_logger: LoggerInterface):
222222
info=self._log_info,
223223
)
224224

225+
def _create_step_id_for_logical_step(self, step: int) -> str:
226+
"""
227+
Generate a step_id based on the given logical step.
228+
This allows us to recover operation ids or even look
229+
forward without changing the internal state of this context.
230+
"""
231+
step_id = f"{self._parent_id}-{step}" if self._parent_id else str(step)
232+
return hashlib.blake2b(step_id.encode()).hexdigest()[:64]
233+
225234
def _create_step_id(self) -> str:
226235
"""Generate a thread-safe step id, incrementing in order of invocation.
227236
228237
This method is an internal implementation detail. Do not rely the exact format of
229238
the id generated by this method. It is subject to change without notice.
230239
"""
231240
new_counter: int = self._step_counter.increment()
232-
step_id = (
233-
f"{self._parent_id}-{new_counter}" if self._parent_id else str(new_counter)
234-
)
235-
return hashlib.blake2b(step_id.encode()).hexdigest()[:64]
241+
return self._create_step_id_for_logical_step(new_counter)
236242

237243
# region Operations
238244

@@ -311,13 +317,17 @@ def map(
311317
"""Execute a callable for each item in parallel."""
312318
map_name: str | None = self._resolve_step_name(name, func)
313319

314-
def map_in_child_context(child_context) -> BatchResult[R]:
320+
def map_in_child_context(map_context) -> BatchResult[R]:
321+
# map_context is a child_context of the context upon which `.map`
322+
# was called. We are calling it `map_context` to make it explicit
323+
# that any operations happening from hereon are done on the context
324+
# that owns the branches
315325
return map_handler(
316326
items=inputs,
317327
func=func,
318328
config=config,
319329
execution_state=self.state,
320-
run_in_child_context=child_context.run_in_child_context,
330+
map_context=map_context,
321331
)
322332

323333
return self.run_in_child_context(
@@ -337,12 +347,16 @@ def parallel(
337347
) -> BatchResult[T]:
338348
"""Execute multiple callables in parallel."""
339349

340-
def parallel_in_child_context(child_context) -> BatchResult[T]:
350+
def parallel_in_child_context(parallel_context) -> BatchResult[T]:
351+
# parallel_context is a child_context of the context upon which `.map`
352+
# was called. We are calling it `parallel_context` to make it explicit
353+
# that any operations happening from hereon are done on the context
354+
# that owns the branches
341355
return parallel_handler(
342356
callables=functions,
343357
config=config,
344358
execution_state=self.state,
345-
run_in_child_context=child_context.run_in_child_context,
359+
parallel_context=parallel_context,
346360
)
347361

348362
return self.run_in_child_context(

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
1717

1818
if TYPE_CHECKING:
19-
from aws_durable_execution_sdk_python.config import ChildConfig
19+
from aws_durable_execution_sdk_python.context import DurableContext
2020
from aws_durable_execution_sdk_python.serdes import SerDes
2121
from aws_durable_execution_sdk_python.state import ExecutionState
22-
from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator
22+
from aws_durable_execution_sdk_python.types import SummaryGenerator
2323

2424
logger = logging.getLogger(__name__)
2525

@@ -93,9 +93,7 @@ def map_handler(
9393
func: Callable,
9494
config: MapConfig | None,
9595
execution_state: ExecutionState,
96-
run_in_child_context: Callable[
97-
[Callable[[DurableContext], R], str | None, ChildConfig | None], R
98-
],
96+
map_context: DurableContext,
9997
) -> BatchResult[R]:
10098
"""Execute a callable for each item in parallel."""
10199
# Summary Generator Construction (matches TypeScript implementation):
@@ -109,7 +107,8 @@ def map_handler(
109107
func=func,
110108
config=config or MapConfig(summary_generator=MapSummaryGenerator()),
111109
)
112-
return executor.execute(execution_state, run_in_child_context)
110+
# we are making it explicit that we are now executing within the map_context
111+
return executor.execute(execution_state, executor_context=map_context)
113112

114113

115114
class MapSummaryGenerator:

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313

1414
if TYPE_CHECKING:
1515
from aws_durable_execution_sdk_python.concurrency import BatchResult
16-
from aws_durable_execution_sdk_python.config import ChildConfig
16+
from aws_durable_execution_sdk_python.context import DurableContext
1717
from aws_durable_execution_sdk_python.serdes import SerDes
1818
from aws_durable_execution_sdk_python.state import ExecutionState
19-
from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator
19+
from aws_durable_execution_sdk_python.types import SummaryGenerator
2020

2121
logger = logging.getLogger(__name__)
2222

@@ -81,9 +81,7 @@ def parallel_handler(
8181
callables: Sequence[Callable],
8282
config: ParallelConfig | None,
8383
execution_state: ExecutionState,
84-
run_in_child_context: Callable[
85-
[Callable[[DurableContext], R], str | None, ChildConfig | None], R
86-
],
84+
parallel_context: DurableContext,
8785
) -> BatchResult[R]:
8886
"""Execute multiple operations in parallel."""
8987
# Summary Generator Construction (matches TypeScript implementation):
@@ -96,7 +94,7 @@ def parallel_handler(
9694
callables,
9795
config or ParallelConfig(summary_generator=ParallelSummaryGenerator()),
9896
)
99-
return executor.execute(execution_state, run_in_child_context)
97+
return executor.execute(execution_state, executor_context=parallel_context)
10098

10199

102100
class ParallelSummaryGenerator:

0 commit comments

Comments
 (0)