Skip to content

Commit 341336c

Browse files
Alex Wangyaythomas
authored andcommitted
fix: update track_replay logic
- Move track_replay after each operation, instead of before
1 parent 3338948 commit 341336c

File tree

6 files changed

+57
-35
lines changed

6 files changed

+57
-35
lines changed

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242
python-version: ${{ matrix.python-version }}
4343

4444
- name: Install Hatch
45-
run: python -m pip install --upgrade hatch
45+
run: python -m pip install hatch==1.15.0
4646

4747
- name: Setup and run Testing SDK
4848
working-directory: testing-sdk

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,11 @@ def _execute_item_in_child_context(
381381
executor_context._parent_id, # noqa: SLF001
382382
name,
383383
)
384-
child_context.state.track_replay(operation_id=operation_id)
385384

386385
def run_in_child_handler():
387386
return self.execute_item(child_context, executable)
388387

389-
return child_handler(
388+
result: ResultType = child_handler(
390389
run_in_child_handler,
391390
child_context.state,
392391
operation_identifier=operation_identifier,
@@ -396,6 +395,8 @@ def run_in_child_handler():
396395
summary_generator=self.summary_generator,
397396
),
398397
)
398+
child_context.state.track_replay(operation_id=operation_id)
399+
return result
399400

400401
def replay(self, execution_state: ExecutionState, executor_context: DurableContext):
401402
"""

src/aws_durable_execution_sdk_python/context.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -272,21 +272,21 @@ def create_callback(
272272
if not config:
273273
config = CallbackConfig()
274274
operation_id: str = self._create_step_id()
275-
self.state.track_replay(operation_id=operation_id)
276275
callback_id: str = create_callback_handler(
277276
state=self.state,
278277
operation_identifier=OperationIdentifier(
279278
operation_id=operation_id, parent_id=self._parent_id, name=name
280279
),
281280
config=config,
282281
)
283-
284-
return Callback(
282+
result: Callback = Callback(
285283
callback_id=callback_id,
286284
operation_id=operation_id,
287285
state=self.state,
288286
serdes=config.serdes,
289287
)
288+
self.state.track_replay(operation_id=operation_id)
289+
return result
290290

291291
def invoke(
292292
self,
@@ -307,8 +307,7 @@ def invoke(
307307
The result of the invoked function
308308
"""
309309
operation_id = self._create_step_id()
310-
self.state.track_replay(operation_id=operation_id)
311-
return invoke_handler(
310+
result: R = invoke_handler(
312311
function_name=function_name,
313312
payload=payload,
314313
state=self.state,
@@ -319,6 +318,8 @@ def invoke(
319318
),
320319
config=config,
321320
)
321+
self.state.track_replay(operation_id=operation_id)
322+
return result
322323

323324
def map(
324325
self,
@@ -331,7 +332,6 @@ def map(
331332
map_name: str | None = self._resolve_step_name(name, func)
332333

333334
operation_id = self._create_step_id()
334-
self.state.track_replay(operation_id=operation_id)
335335
operation_identifier = OperationIdentifier(
336336
operation_id=operation_id, parent_id=self._parent_id, name=map_name
337337
)
@@ -351,7 +351,7 @@ def map_in_child_context() -> BatchResult[R]:
351351
operation_identifier=operation_identifier,
352352
)
353353

354-
return child_handler(
354+
result: BatchResult[R] = child_handler(
355355
func=map_in_child_context,
356356
state=self.state,
357357
operation_identifier=operation_identifier,
@@ -364,6 +364,8 @@ def map_in_child_context() -> BatchResult[R]:
364364
item_serdes=None,
365365
),
366366
)
367+
self.state.track_replay(operation_id=operation_id)
368+
return result
367369

368370
def parallel(
369371
self,
@@ -374,7 +376,6 @@ def parallel(
374376
"""Execute multiple callables in parallel."""
375377
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
376378
operation_id = self._create_step_id()
377-
self.state.track_replay(operation_id=operation_id)
378379
parallel_context = self.create_child_context(parent_id=operation_id)
379380
operation_identifier = OperationIdentifier(
380381
operation_id=operation_id, parent_id=self._parent_id, name=name
@@ -393,7 +394,7 @@ def parallel_in_child_context() -> BatchResult[T]:
393394
operation_identifier=operation_identifier,
394395
)
395396

396-
return child_handler(
397+
result: BatchResult[T] = child_handler(
397398
func=parallel_in_child_context,
398399
state=self.state,
399400
operation_identifier=operation_identifier,
@@ -406,6 +407,8 @@ def parallel_in_child_context() -> BatchResult[T]:
406407
item_serdes=None,
407408
),
408409
)
410+
self.state.track_replay(operation_id=operation_id)
411+
return result
409412

410413
def run_in_child_context(
411414
self,
@@ -428,19 +431,20 @@ def run_in_child_context(
428431
step_name: str | None = self._resolve_step_name(name, func)
429432
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
430433
operation_id = self._create_step_id()
431-
self.state.track_replay(operation_id=operation_id)
432434

433435
def callable_with_child_context():
434436
return func(self.create_child_context(parent_id=operation_id))
435437

436-
return child_handler(
438+
result: T = child_handler(
437439
func=callable_with_child_context,
438440
state=self.state,
439441
operation_identifier=OperationIdentifier(
440442
operation_id=operation_id, parent_id=self._parent_id, name=step_name
441443
),
442444
config=config,
443445
)
446+
self.state.track_replay(operation_id=operation_id)
447+
return result
444448

445449
def step(
446450
self,
@@ -451,9 +455,7 @@ def step(
451455
step_name = self._resolve_step_name(name, func)
452456
logger.debug("Step name: %s", step_name)
453457
operation_id = self._create_step_id()
454-
self.state.track_replay(operation_id=operation_id)
455-
456-
return step_handler(
458+
result: T = step_handler(
457459
func=func,
458460
config=config,
459461
state=self.state,
@@ -464,6 +466,8 @@ def step(
464466
),
465467
context_logger=self.logger,
466468
)
469+
self.state.track_replay(operation_id=operation_id)
470+
return result
467471

468472
def wait(self, duration: Duration, name: str | None = None) -> None:
469473
"""Wait for a specified amount of time.
@@ -477,7 +481,6 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
477481
msg = "duration must be at least 1 second"
478482
raise ValidationError(msg)
479483
operation_id = self._create_step_id()
480-
self.state.track_replay(operation_id=operation_id)
481484
wait_handler(
482485
seconds=seconds,
483486
state=self.state,
@@ -487,6 +490,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
487490
name=name,
488491
),
489492
)
493+
self.state.track_replay(operation_id=operation_id)
490494

491495
def wait_for_callback(
492496
self,
@@ -529,8 +533,7 @@ def wait_for_condition(
529533
raise ValidationError(msg)
530534

531535
operation_id = self._create_step_id()
532-
self.state.track_replay(operation_id=operation_id)
533-
return wait_for_condition_handler(
536+
result: T = wait_for_condition_handler(
534537
check=check,
535538
config=config,
536539
state=self.state,
@@ -541,6 +544,8 @@ def wait_for_condition(
541544
),
542545
context_logger=self.logger,
543546
)
547+
self.state.track_replay(operation_id=operation_id)
548+
return result
544549

545550

546551
# endregion Operations

src/aws_durable_execution_sdk_python/state.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def __init__(
258258
self._parent_done_lock: Lock = Lock()
259259
self._replay_status: ReplayStatus = replay_status
260260
self._replay_status_lock: Lock = Lock()
261+
self._visited_operations: set[str] = set()
261262

262263
def fetch_paginated_operations(
263264
self,
@@ -301,14 +302,20 @@ def track_replay(self, operation_id: str) -> None:
301302
"""
302303
with self._replay_status_lock:
303304
if self._replay_status == ReplayStatus.REPLAY:
304-
operation = self.operations.get(operation_id)
305-
# Transition if operation doesn't exist OR isn't in a completed state
306-
if not operation or operation.status not in {
307-
OperationStatus.SUCCEEDED,
308-
OperationStatus.FAILED,
309-
OperationStatus.CANCELLED,
310-
OperationStatus.STOPPED,
311-
}:
305+
self._visited_operations.add(operation_id)
306+
completed_ops = {
307+
op_id
308+
for op_id, op in self.operations.items()
309+
if op.operation_type != OperationType.EXECUTION
310+
and op.status
311+
in {
312+
OperationStatus.SUCCEEDED,
313+
OperationStatus.FAILED,
314+
OperationStatus.CANCELLED,
315+
OperationStatus.STOPPED,
316+
}
317+
}
318+
if completed_ops.issubset(self._visited_operations):
312319
logger.debug(
313320
"Transitioning from REPLAY to NEW status at operation %s",
314321
operation_id,

tests/logger_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,22 +381,27 @@ def test_logger_replay_no_logging():
381381
log_info = LogInfo(replay_execution_state, "parent123", "test_name", 5)
382382
mock_logger = Mock()
383383
logger = Logger.from_log_info(mock_logger, log_info)
384-
replay_execution_state.track_replay(operation_id="op1")
385384
logger.info("logging info")
385+
replay_execution_state.track_replay(operation_id="op1")
386386

387387
mock_logger.info.assert_not_called()
388388

389389

390390
def test_logger_replay_then_new_logging():
391-
operation = Operation(
391+
operation1 = Operation(
392392
operation_id="op1",
393393
operation_type=OperationType.STEP,
394394
status=OperationStatus.SUCCEEDED,
395395
)
396+
operation2 = Operation(
397+
operation_id="op2",
398+
operation_type=OperationType.STEP,
399+
status=OperationStatus.SUCCEEDED,
400+
)
396401
execution_state = ExecutionState(
397402
durable_execution_arn="arn:aws:test",
398403
initial_checkpoint_token="test_token", # noqa: S106
399-
operations={"op1": operation},
404+
operations={"op1": operation1, "op2": operation2},
400405
service_client=Mock(),
401406
replay_status=ReplayStatus.REPLAY,
402407
)

tests/state_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3246,21 +3246,25 @@ def test_create_checkpoint_sync_always_synchronous():
32463246

32473247

32483248
def test_state_replay_mode():
3249-
operation = Operation(
3249+
operation1 = Operation(
32503250
operation_id="op1",
32513251
operation_type=OperationType.STEP,
32523252
status=OperationStatus.SUCCEEDED,
32533253
)
3254+
operation2 = Operation(
3255+
operation_id="op2",
3256+
operation_type=OperationType.STEP,
3257+
status=OperationStatus.SUCCEEDED,
3258+
)
32543259
execution_state = ExecutionState(
32553260
durable_execution_arn="arn:aws:test",
32563261
initial_checkpoint_token="test_token", # noqa: S106
3257-
operations={"op1": operation},
3262+
operations={"op1": operation1, "op2": operation2},
32583263
service_client=Mock(),
32593264
replay_status=ReplayStatus.REPLAY,
32603265
)
3261-
3266+
assert execution_state.is_replaying() is True
32623267
execution_state.track_replay(operation_id="op1")
32633268
assert execution_state.is_replaying() is True
3264-
32653269
execution_state.track_replay(operation_id="op2")
32663270
assert execution_state.is_replaying() is False

0 commit comments

Comments
 (0)