Skip to content

Commit 3338948

Browse files
committed
feat: add WaitForCallbackContext to submitter
BREAKING CHANGE: wait_for_callback submitter signature changed from submitter(callback_id: str) to submitter(callback_id: str, context: WaitForCallbackContext) The WaitForCallbackContext provides access to a logger, enabling submitter functions to log operations consistently with other SDK operations like step and wait_for_condition. This change aligns the wait_for_callback API with other context-aware operations in the SDK, improving consistency and extensibility. - Add WaitForCallbackContext type with logger field - Update wait_for_callback_handler to pass context to submitter - Update all callback tests to use new submitter signature - Add test coverage for context parameter validation
1 parent 4e28a5e commit 3338948

File tree

4 files changed

+43
-11
lines changed

4 files changed

+43
-11
lines changed

src/aws_durable_execution_sdk_python/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
BatchResult,
4848
LoggerInterface,
4949
StepContext,
50+
WaitForCallbackContext,
5051
WaitForConditionCheckContext,
5152
)
5253
from aws_durable_execution_sdk_python.types import Callback as CallbackProtocol
@@ -489,7 +490,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
489490

490491
def wait_for_callback(
491492
self,
492-
submitter: Callable[[str], None],
493+
submitter: Callable[[str, WaitForCallbackContext], None],
493494
name: str | None = None,
494495
config: WaitForCallbackConfig | None = None,
495496
) -> Any:

src/aws_durable_execution_sdk_python/operation/callback.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
CallbackOptions,
1111
OperationUpdate,
1212
)
13+
from aws_durable_execution_sdk_python.types import WaitForCallbackContext
1314

1415
if TYPE_CHECKING:
1516
from collections.abc import Callable
@@ -23,7 +24,11 @@
2324
CheckpointedResult,
2425
ExecutionState,
2526
)
26-
from aws_durable_execution_sdk_python.types import Callback, DurableContext
27+
from aws_durable_execution_sdk_python.types import (
28+
Callback,
29+
DurableContext,
30+
StepContext,
31+
)
2732

2833

2934
def create_callback_handler(
@@ -85,7 +90,7 @@ def create_callback_handler(
8590

8691
def wait_for_callback_handler(
8792
context: DurableContext,
88-
submitter: Callable[[str], None],
93+
submitter: Callable[[str, WaitForCallbackContext], None],
8994
name: str | None = None,
9095
config: WaitForCallbackConfig | None = None,
9196
) -> Any:
@@ -98,8 +103,10 @@ def wait_for_callback_handler(
98103
name=f"{name_with_space}create callback id", config=config
99104
)
100105

101-
def submitter_step(step_context): # noqa: ARG001
102-
return submitter(callback.callback_id)
106+
def submitter_step(step_context: StepContext):
107+
return submitter(
108+
callback.callback_id, WaitForCallbackContext(logger=step_context.logger)
109+
)
103110

104111
step_config = (
105112
StepConfig(

src/aws_durable_execution_sdk_python/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ class StepContext(OperationContext):
5757
pass
5858

5959

60+
@dataclass(frozen=True)
61+
class WaitForCallbackContext(OperationContext):
62+
"""Context provided to waitForCallback submitter functions."""
63+
64+
6065
@dataclass(frozen=True)
6166
class WaitForConditionCheckContext(OperationContext):
6267
pass

tests/operation/callback_test.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,13 +303,18 @@ def test_wait_for_callback_handler_submitter_called_with_callback_id():
303303
def capture_step_call(func, name, config=None):
304304
# Execute the step callable to verify submitter is called correctly
305305
step_context = Mock(spec=StepContext)
306+
step_context.logger = Mock()
306307
func(step_context)
307308

308309
mock_context.step.side_effect = capture_step_call
309310

310311
wait_for_callback_handler(mock_context, mock_submitter, "test")
311312

312-
mock_submitter.assert_called_once_with("callback_test_id")
313+
# Verify submitter was called with callback_id and WaitForCallbackContext
314+
assert mock_submitter.call_count == 1
315+
call_args = mock_submitter.call_args[0]
316+
assert call_args[0] == "callback_test_id"
317+
assert hasattr(call_args[1], "logger")
313318

314319

315320
def test_create_callback_handler_with_none_operation_in_result():
@@ -350,14 +355,19 @@ def test_wait_for_callback_handler_with_none_callback_id():
350355

351356
def execute_step(func, name, config=None):
352357
step_context = Mock(spec=StepContext)
358+
step_context.logger = Mock()
353359
return func(step_context)
354360

355361
mock_context.step.side_effect = execute_step
356362

357363
result = wait_for_callback_handler(mock_context, mock_submitter, "test")
358364

359365
assert result == "result_with_none_id"
360-
mock_submitter.assert_called_once_with(None)
366+
# Verify submitter was called with None callback_id and WaitForCallbackContext
367+
assert mock_submitter.call_count == 1
368+
call_args = mock_submitter.call_args[0]
369+
assert call_args[0] is None
370+
assert hasattr(call_args[1], "logger")
361371

362372

363373
def test_wait_for_callback_handler_with_empty_string_callback_id():
@@ -371,14 +381,19 @@ def test_wait_for_callback_handler_with_empty_string_callback_id():
371381

372382
def execute_step(func, name, config=None):
373383
step_context = Mock(spec=StepContext)
384+
step_context.logger = Mock()
374385
return func(step_context)
375386

376387
mock_context.step.side_effect = execute_step
377388

378389
result = wait_for_callback_handler(mock_context, mock_submitter, "test")
379390

380391
assert result == "result_with_empty_id"
381-
mock_submitter.assert_called_once_with("")
392+
# Verify submitter was called with empty string callback_id and WaitForCallbackContext
393+
assert mock_submitter.call_count == 1
394+
call_args = mock_submitter.call_args[0]
395+
assert call_args[0] == "" # noqa: PLC1901 - explicitly testing empty string, not just falsey
396+
assert hasattr(call_args[1], "logger")
382397

383398

384399
def test_wait_for_callback_handler_with_large_data():
@@ -585,12 +600,13 @@ def test_wait_for_callback_handler_submitter_exception_handling():
585600
mock_callback.result.return_value = "exception_result"
586601
mock_context.create_callback.return_value = mock_callback
587602

588-
def failing_submitter(callback_id):
603+
def failing_submitter(callback_id, context):
589604
msg = "Submitter failed"
590605
raise ValueError(msg)
591606

592607
def step_side_effect(func, name, config=None):
593608
step_context = Mock(spec=StepContext)
609+
step_context.logger = Mock()
594610
func(step_context)
595611

596612
mock_context.step.side_effect = step_side_effect
@@ -775,12 +791,14 @@ def test_callback_lifecycle_complete_flow():
775791

776792
assert callback_id == "lifecycle_cb123"
777793

778-
def mock_submitter(cb_id):
794+
def mock_submitter(cb_id, context):
779795
assert cb_id == "lifecycle_cb123"
796+
assert hasattr(context, "logger")
780797
return "submitted"
781798

782799
def execute_step(func, name, config=None):
783800
step_context = Mock(spec=StepContext)
801+
step_context.logger = Mock()
784802
return func(step_context)
785803

786804
mock_context.step.side_effect = execute_step
@@ -889,7 +907,7 @@ def test_callback_with_complex_submitter():
889907

890908
submission_log = []
891909

892-
def complex_submitter(callback_id):
910+
def complex_submitter(callback_id, context):
893911
submission_log.append(f"received_id: {callback_id}")
894912
if callback_id == "complex_cb789":
895913
submission_log.append("api_call_success")
@@ -901,6 +919,7 @@ def complex_submitter(callback_id):
901919

902920
def execute_step(func, name, config):
903921
step_context = Mock(spec=StepContext)
922+
step_context.logger = Mock()
904923
return func(step_context)
905924

906925
mock_context.step.side_effect = execute_step

0 commit comments

Comments
 (0)