@@ -272,21 +272,21 @@ def create_callback(
272272 if not config :
273273 config = CallbackConfig ()
274274 operation_id : str = self ._create_step_id ()
275- self .state .track_replay (operation_id = operation_id )
276275 callback_id : str = create_callback_handler (
277276 state = self .state ,
278277 operation_identifier = OperationIdentifier (
279278 operation_id = operation_id , parent_id = self ._parent_id , name = name
280279 ),
281280 config = config ,
282281 )
283-
284- return Callback (
282+ result : Callback = Callback (
285283 callback_id = callback_id ,
286284 operation_id = operation_id ,
287285 state = self .state ,
288286 serdes = config .serdes ,
289287 )
288+ self .state .track_replay (operation_id = operation_id )
289+ return result
290290
291291 def invoke (
292292 self ,
@@ -307,8 +307,7 @@ def invoke(
307307 The result of the invoked function
308308 """
309309 operation_id = self ._create_step_id ()
310- self .state .track_replay (operation_id = operation_id )
311- return invoke_handler (
310+ result : R = invoke_handler (
312311 function_name = function_name ,
313312 payload = payload ,
314313 state = self .state ,
@@ -319,6 +318,8 @@ def invoke(
319318 ),
320319 config = config ,
321320 )
321+ self .state .track_replay (operation_id = operation_id )
322+ return result
322323
323324 def map (
324325 self ,
@@ -331,7 +332,6 @@ def map(
331332 map_name : str | None = self ._resolve_step_name (name , func )
332333
333334 operation_id = self ._create_step_id ()
334- self .state .track_replay (operation_id = operation_id )
335335 operation_identifier = OperationIdentifier (
336336 operation_id = operation_id , parent_id = self ._parent_id , name = map_name
337337 )
@@ -351,7 +351,7 @@ def map_in_child_context() -> BatchResult[R]:
351351 operation_identifier = operation_identifier ,
352352 )
353353
354- return child_handler (
354+ result : BatchResult [ R ] = child_handler (
355355 func = map_in_child_context ,
356356 state = self .state ,
357357 operation_identifier = operation_identifier ,
@@ -364,6 +364,8 @@ def map_in_child_context() -> BatchResult[R]:
364364 item_serdes = None ,
365365 ),
366366 )
367+ self .state .track_replay (operation_id = operation_id )
368+ return result
367369
368370 def parallel (
369371 self ,
@@ -374,7 +376,6 @@ def parallel(
374376 """Execute multiple callables in parallel."""
375377 # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
376378 operation_id = self ._create_step_id ()
377- self .state .track_replay (operation_id = operation_id )
378379 parallel_context = self .create_child_context (parent_id = operation_id )
379380 operation_identifier = OperationIdentifier (
380381 operation_id = operation_id , parent_id = self ._parent_id , name = name
@@ -393,7 +394,7 @@ def parallel_in_child_context() -> BatchResult[T]:
393394 operation_identifier = operation_identifier ,
394395 )
395396
396- return child_handler (
397+ result : BatchResult [ T ] = child_handler (
397398 func = parallel_in_child_context ,
398399 state = self .state ,
399400 operation_identifier = operation_identifier ,
@@ -406,6 +407,8 @@ def parallel_in_child_context() -> BatchResult[T]:
406407 item_serdes = None ,
407408 ),
408409 )
410+ self .state .track_replay (operation_id = operation_id )
411+ return result
409412
410413 def run_in_child_context (
411414 self ,
@@ -428,19 +431,20 @@ def run_in_child_context(
428431 step_name : str | None = self ._resolve_step_name (name , func )
429432 # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
430433 operation_id = self ._create_step_id ()
431- self .state .track_replay (operation_id = operation_id )
432434
433435 def callable_with_child_context ():
434436 return func (self .create_child_context (parent_id = operation_id ))
435437
436- return child_handler (
438+ result : T = child_handler (
437439 func = callable_with_child_context ,
438440 state = self .state ,
439441 operation_identifier = OperationIdentifier (
440442 operation_id = operation_id , parent_id = self ._parent_id , name = step_name
441443 ),
442444 config = config ,
443445 )
446+ self .state .track_replay (operation_id = operation_id )
447+ return result
444448
445449 def step (
446450 self ,
@@ -451,9 +455,7 @@ def step(
451455 step_name = self ._resolve_step_name (name , func )
452456 logger .debug ("Step name: %s" , step_name )
453457 operation_id = self ._create_step_id ()
454- self .state .track_replay (operation_id = operation_id )
455-
456- return step_handler (
458+ result : T = step_handler (
457459 func = func ,
458460 config = config ,
459461 state = self .state ,
@@ -464,6 +466,8 @@ def step(
464466 ),
465467 context_logger = self .logger ,
466468 )
469+ self .state .track_replay (operation_id = operation_id )
470+ return result
467471
468472 def wait (self , duration : Duration , name : str | None = None ) -> None :
469473 """Wait for a specified amount of time.
@@ -477,7 +481,6 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
477481 msg = "duration must be at least 1 second"
478482 raise ValidationError (msg )
479483 operation_id = self ._create_step_id ()
480- self .state .track_replay (operation_id = operation_id )
481484 wait_handler (
482485 seconds = seconds ,
483486 state = self .state ,
@@ -487,6 +490,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
487490 name = name ,
488491 ),
489492 )
493+ self .state .track_replay (operation_id = operation_id )
490494
491495 def wait_for_callback (
492496 self ,
@@ -529,8 +533,7 @@ def wait_for_condition(
529533 raise ValidationError (msg )
530534
531535 operation_id = self ._create_step_id ()
532- self .state .track_replay (operation_id = operation_id )
533- return wait_for_condition_handler (
536+ result : T = wait_for_condition_handler (
534537 check = check ,
535538 config = config ,
536539 state = self .state ,
@@ -541,6 +544,8 @@ def wait_for_condition(
541544 ),
542545 context_logger = self .logger ,
543546 )
547+ self .state .track_replay (operation_id = operation_id )
548+ return result
544549
545550
546551# endregion Operations
0 commit comments