@@ -271,21 +271,21 @@ def create_callback(
271271 if not config :
272272 config = CallbackConfig ()
273273 operation_id : str = self ._create_step_id ()
274- self .state .track_replay (operation_id = operation_id )
275274 callback_id : str = create_callback_handler (
276275 state = self .state ,
277276 operation_identifier = OperationIdentifier (
278277 operation_id = operation_id , parent_id = self ._parent_id , name = name
279278 ),
280279 config = config ,
281280 )
282-
283- return Callback (
281+ result : Callback = Callback (
284282 callback_id = callback_id ,
285283 operation_id = operation_id ,
286284 state = self .state ,
287285 serdes = config .serdes ,
288286 )
287+ self .state .track_replay (operation_id = operation_id )
288+ return result
289289
290290 def invoke (
291291 self ,
@@ -306,8 +306,7 @@ def invoke(
306306 The result of the invoked function
307307 """
308308 operation_id = self ._create_step_id ()
309- self .state .track_replay (operation_id = operation_id )
310- return invoke_handler (
309+ result : R = invoke_handler (
311310 function_name = function_name ,
312311 payload = payload ,
313312 state = self .state ,
@@ -318,6 +317,8 @@ def invoke(
318317 ),
319318 config = config ,
320319 )
320+ self .state .track_replay (operation_id = operation_id )
321+ return result
321322
322323 def map (
323324 self ,
@@ -330,7 +331,6 @@ def map(
330331 map_name : str | None = self ._resolve_step_name (name , func )
331332
332333 operation_id = self ._create_step_id ()
333- self .state .track_replay (operation_id = operation_id )
334334 operation_identifier = OperationIdentifier (
335335 operation_id = operation_id , parent_id = self ._parent_id , name = map_name
336336 )
@@ -350,7 +350,7 @@ def map_in_child_context() -> BatchResult[R]:
350350 operation_identifier = operation_identifier ,
351351 )
352352
353- return child_handler (
353+ result : BatchResult [ R ] = child_handler (
354354 func = map_in_child_context ,
355355 state = self .state ,
356356 operation_identifier = operation_identifier ,
@@ -363,6 +363,8 @@ def map_in_child_context() -> BatchResult[R]:
363363 item_serdes = None ,
364364 ),
365365 )
366+ self .state .track_replay (operation_id = operation_id )
367+ return result
366368
367369 def parallel (
368370 self ,
@@ -373,7 +375,6 @@ def parallel(
373375 """Execute multiple callables in parallel."""
374376 # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
375377 operation_id = self ._create_step_id ()
376- self .state .track_replay (operation_id = operation_id )
377378 parallel_context = self .create_child_context (parent_id = operation_id )
378379 operation_identifier = OperationIdentifier (
379380 operation_id = operation_id , parent_id = self ._parent_id , name = name
@@ -392,7 +393,7 @@ def parallel_in_child_context() -> BatchResult[T]:
392393 operation_identifier = operation_identifier ,
393394 )
394395
395- return child_handler (
396+ result : BatchResult [ T ] = child_handler (
396397 func = parallel_in_child_context ,
397398 state = self .state ,
398399 operation_identifier = operation_identifier ,
@@ -405,6 +406,8 @@ def parallel_in_child_context() -> BatchResult[T]:
405406 item_serdes = None ,
406407 ),
407408 )
409+ self .state .track_replay (operation_id = operation_id )
410+ return result
408411
409412 def run_in_child_context (
410413 self ,
@@ -427,19 +430,20 @@ def run_in_child_context(
427430 step_name : str | None = self ._resolve_step_name (name , func )
428431 # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
429432 operation_id = self ._create_step_id ()
430- self .state .track_replay (operation_id = operation_id )
431433
432434 def callable_with_child_context ():
433435 return func (self .create_child_context (parent_id = operation_id ))
434436
435- return child_handler (
437+ result : T = child_handler (
436438 func = callable_with_child_context ,
437439 state = self .state ,
438440 operation_identifier = OperationIdentifier (
439441 operation_id = operation_id , parent_id = self ._parent_id , name = step_name
440442 ),
441443 config = config ,
442444 )
445+ self .state .track_replay (operation_id = operation_id )
446+ return result
443447
444448 def step (
445449 self ,
@@ -450,9 +454,7 @@ def step(
450454 step_name = self ._resolve_step_name (name , func )
451455 logger .debug ("Step name: %s" , step_name )
452456 operation_id = self ._create_step_id ()
453- self .state .track_replay (operation_id = operation_id )
454-
455- return step_handler (
457+ result : T = step_handler (
456458 func = func ,
457459 config = config ,
458460 state = self .state ,
@@ -463,6 +465,8 @@ def step(
463465 ),
464466 context_logger = self .logger ,
465467 )
468+ self .state .track_replay (operation_id = operation_id )
469+ return result
466470
467471 def wait (self , duration : Duration , name : str | None = None ) -> None :
468472 """Wait for a specified amount of time.
@@ -476,7 +480,6 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
476480 msg = "duration must be at least 1 second"
477481 raise ValidationError (msg )
478482 operation_id = self ._create_step_id ()
479- self .state .track_replay (operation_id = operation_id )
480483 wait_handler (
481484 seconds = seconds ,
482485 state = self .state ,
@@ -486,6 +489,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
486489 name = name ,
487490 ),
488491 )
492+ self .state .track_replay (operation_id = operation_id )
489493
490494 def wait_for_callback (
491495 self ,
@@ -528,8 +532,7 @@ def wait_for_condition(
528532 raise ValidationError (msg )
529533
530534 operation_id = self ._create_step_id ()
531- self .state .track_replay (operation_id = operation_id )
532- return wait_for_condition_handler (
535+ result : T = wait_for_condition_handler (
533536 check = check ,
534537 config = config ,
535538 state = self .state ,
@@ -540,6 +543,8 @@ def wait_for_condition(
540543 ),
541544 context_logger = self .logger ,
542545 )
546+ self .state .track_replay (operation_id = operation_id )
547+ return result
543548
544549
545550# endregion Operations
0 commit comments