Skip to content

Commit bca5796

Browse files
author
Rares Polenciuc
committed
feat: add callback examples
1 parent 4e8ecb8 commit bca5796

File tree

9 files changed

+369
-1
lines changed

9 files changed

+369
-1
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import threading
2+
import queue
3+
from enum import StrEnum
4+
from time import sleep
5+
from typing import Callable, Optional
6+
7+
8+
class RunnerMode(StrEnum):
9+
"""Runner mode for local or cloud execution."""
10+
11+
LOCAL = "local"
12+
CLOUD = "cloud"
13+
14+
15+
class ExternalSystem:
16+
_instance = None
17+
_lock = threading.Lock()
18+
19+
def __new__(cls):
20+
if cls._instance is None:
21+
with cls._lock:
22+
if cls._instance is None:
23+
cls._instance = super().__new__(cls)
24+
cls._instance._initialized = False
25+
return cls._instance
26+
27+
def __init__(self):
28+
if self._initialized:
29+
return
30+
self._call_queue = queue.Queue()
31+
self._worker_thread = None
32+
self._mode = RunnerMode.CLOUD
33+
self._success_handler = self._cloud_success_handler
34+
self._failure_handler = self._cloud_failure_handler
35+
self._heartbeat_handler = self._cloud_heartbeat_handler
36+
self._initialized = True
37+
self._start_worker()
38+
39+
@property
40+
def mode(self) -> RunnerMode:
41+
return self._mode
42+
43+
def activate_local_mode(
44+
self,
45+
success_handler: Optional[Callable[[str, bytes], None]] = None,
46+
failure_handler: Optional[Callable[[str, Exception], None]] = None,
47+
heartbeat_handler: Optional[Callable[[str], None]] = None,
48+
):
49+
"""Activate local mode with custom handlers."""
50+
self._mode = RunnerMode.LOCAL
51+
self._success_handler = success_handler
52+
self._failure_handler = failure_handler
53+
self._heartbeat_handler = heartbeat_handler
54+
55+
def activate_cloud_mode(self):
56+
"""Activate cloud mode with boto3 handlers."""
57+
self._mode = RunnerMode.CLOUD
58+
self._success_handler = self._cloud_success_handler
59+
self._failure_handler = self._cloud_failure_handler
60+
self._heartbeat_handler = self._cloud_heartbeat_handler
61+
62+
def send_success(self, callback_id: str, msg: bytes):
63+
"""Send success callback."""
64+
self._call_queue.put(("success", callback_id, msg), timeout=0.5)
65+
66+
def send_failure(self, callback_id: str, error: Exception):
67+
"""Send failure callback."""
68+
self._call_queue.put(("failure", callback_id, error), timeout=0.5)
69+
70+
def send_heartbeat(self, callback_id: str):
71+
"""Send heartbeat callback."""
72+
self._call_queue.put(("heartbeat", callback_id, None), timeout=0.5)
73+
74+
def _start_worker(self):
75+
if self._worker_thread is None:
76+
self._worker_thread = threading.Thread(target=self._worker, daemon=True)
77+
self._worker_thread.start()
78+
79+
def _worker(self):
80+
"""Background worker that processes callbacks."""
81+
while True:
82+
try:
83+
operation_type, callback_id, data = self._call_queue.get(timeout=2)
84+
85+
if operation_type == "success" and self._success_handler:
86+
self._success_handler(callback_id, data)
87+
elif operation_type == "failure" and self._failure_handler:
88+
self._failure_handler(callback_id, data)
89+
elif operation_type == "heartbeat" and self._heartbeat_handler:
90+
self._heartbeat_handler(callback_id)
91+
92+
self._call_queue.task_done()
93+
except queue.Empty:
94+
continue
95+
96+
def _cloud_success_handler(self, callback_id: str, msg: bytes):
97+
"""Default cloud success handler using boto3."""
98+
try:
99+
import boto3
100+
import os
101+
102+
client = boto3.client(
103+
"lambdainternal",
104+
endpoint_url=os.environ.get("LAMBDA_ENDPOINT"),
105+
region_name=os.environ.get("AWS_REGION", "us-west-2"),
106+
)
107+
108+
client.send_durable_execution_callback_success(
109+
CallbackId=callback_id, Result=msg.decode("utf-8") if msg else None
110+
)
111+
except Exception:
112+
pass # Fail silently in cloud mode
113+
114+
def _cloud_failure_handler(self, callback_id: str, error: Exception):
115+
"""Default cloud failure handler using boto3."""
116+
try:
117+
import boto3
118+
import os
119+
120+
client = boto3.client(
121+
"lambdainternal",
122+
endpoint_url=os.environ.get("LAMBDA_ENDPOINT"),
123+
region_name=os.environ.get("AWS_REGION", "us-west-2"),
124+
)
125+
126+
client.send_durable_execution_callback_failure(
127+
CallbackId=callback_id, Error=str(error)
128+
)
129+
except Exception:
130+
pass # Fail silently in cloud mode
131+
132+
def _cloud_heartbeat_handler(self, callback_id: str):
133+
"""Default cloud heartbeat handler using boto3."""
134+
try:
135+
import boto3
136+
import os
137+
138+
client = boto3.client(
139+
"lambdainternal",
140+
endpoint_url=os.environ.get("LAMBDA_ENDPOINT"),
141+
region_name=os.environ.get("AWS_REGION", "us-west-2"),
142+
)
143+
144+
client.send_durable_execution_callback_heartbeat(CallbackId=callback_id)
145+
except Exception:
146+
pass # Fail silently in cloud mode
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Any
2+
3+
from aws_durable_execution_sdk_python import DurableContext, durable_execution
4+
from aws_durable_execution_sdk_python.config import WaitForCallbackConfig, Duration
5+
from .external_system import ExternalSystem # noqa: TID252
6+
7+
external_system = ExternalSystem() # Singleton instance
8+
9+
10+
@durable_execution
11+
def handler(_event: Any, context: DurableContext) -> str:
12+
name = "Callback Failure"
13+
config = WaitForCallbackConfig(timeout=Duration(10), retry_strategy=None)
14+
15+
def submitter(callback_id: str) -> None:
16+
"""Submitter function that triggers failure."""
17+
try:
18+
raise Exception("Callback failed")
19+
except Exception as e:
20+
external_system.send_failure(callback_id, e)
21+
22+
try:
23+
context.wait_for_callback(submitter, name=name, config=config)
24+
return "OK"
25+
except Exception as e:
26+
result = str(e)
27+
return result
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Any
2+
3+
from aws_durable_execution_sdk_python import DurableContext, durable_execution
4+
from aws_durable_execution_sdk_python.config import WaitForCallbackConfig, Duration
5+
from .external_system import ExternalSystem # noqa: TID252
6+
7+
external_system = ExternalSystem() # Singleton instance
8+
9+
10+
@durable_execution
11+
def handler(_event: Any, context: DurableContext) -> str:
12+
name = "Callback Heartbeat"
13+
config = WaitForCallbackConfig(timeout=Duration(30), retry_strategy=None)
14+
15+
def submitter(callback_id: str) -> None:
16+
"""Submitter function that sends heartbeat then succeeds."""
17+
external_system.send_heartbeat(callback_id)
18+
external_system.send_success(callback_id, b"")
19+
20+
context.wait_for_callback(submitter, name=name, config=config)
21+
return "OK"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Any
2+
3+
from aws_durable_execution_sdk_python import DurableContext, durable_execution
4+
from aws_durable_execution_sdk_python.config import WaitForCallbackConfig, Duration
5+
from .external_system import ExternalSystem # noqa: TID252
6+
7+
external_system = ExternalSystem() # Singleton instance
8+
9+
10+
@durable_execution
11+
def handler(_event: Any, context: DurableContext) -> str:
12+
name = "Callback Waiting"
13+
config = WaitForCallbackConfig(timeout=Duration(10), retry_strategy=None)
14+
15+
def submitter(callback_id: str) -> None:
16+
"""Submitter function."""
17+
external_system.send_success(callback_id, b"")
18+
19+
context.wait_for_callback(submitter, name=name, config=config)
20+
21+
return "OK"

examples/test/conftest.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from typing import Any
1111

1212
import pytest
13-
from aws_durable_execution_sdk_python.lambda_service import OperationPayload
13+
from aws_durable_execution_sdk_python.lambda_service import (
14+
ErrorObject,
15+
OperationPayload,
16+
)
1417
from aws_durable_execution_sdk_python.serdes import ExtendedTypeSerDes
1518

1619
from aws_durable_execution_sdk_python_testing.runner import (
@@ -122,6 +125,29 @@ def __exit__(self, exc_type, exc_val, exc_tb):
122125
return self._runner.__exit__(exc_type, exc_val, exc_tb)
123126
return None
124127

128+
def succeed_callback(self, callback_id: str, result: bytes) -> None:
129+
"""Send callback success response."""
130+
if self.mode == RunnerMode.LOCAL:
131+
self._runner.send_callback_success(callback_id=callback_id, result=result)
132+
else:
133+
logger.warning("Current runner does not support callback success")
134+
135+
def fail_callback(self, callback_id: str, error: Exception | None = None) -> None:
136+
"""Send callback failure response."""
137+
if self.mode == RunnerMode.LOCAL:
138+
error_obj = ErrorObject.from_exception(error) if error else None
139+
self._runner.send_callback_failure(callback_id=callback_id, error=error_obj)
140+
else:
141+
logger.warning("Current runner does not support callback failure")
142+
143+
def heartbeat_callback(self, callback_id: str) -> None:
144+
"""Send callback heartbeat to keep callback alive."""
145+
146+
if self.mode == RunnerMode.LOCAL:
147+
self._runner.send_callback_heartbeat(callback_id=callback_id)
148+
else:
149+
logger.warning("Current runner does not support callback heartbeat")
150+
125151

126152
@pytest.fixture
127153
def durable_runner(request):
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Tests for callback failure example."""
2+
3+
from asyncio import sleep
4+
5+
import pytest
6+
from aws_durable_execution_sdk_python.execution import InvocationStatus
7+
8+
from wait_for_callback import wait_for_callback_failure
9+
from test.conftest import deserialize_operation_payload
10+
11+
12+
@pytest.mark.example
13+
@pytest.mark.durable_execution(
14+
handler=wait_for_callback_failure.handler,
15+
lambda_function_name="wait for callback failure",
16+
)
17+
def test_callback_failure(durable_runner):
18+
"""Test callback failure handling."""
19+
20+
with durable_runner:
21+
# Configure external system for local mode if needed
22+
if durable_runner.mode == "local":
23+
24+
def failure_handler(callback_id: str, error: Exception):
25+
sleep(0.5) # Simulate async work
26+
durable_runner.fail_callback(callback_id, str(error))
27+
28+
def success_handler(callback_id: str, msg: bytes):
29+
durable_runner.succeed_callback(callback_id, msg)
30+
31+
wait_for_callback_failure.external_system.activate_local_mode(
32+
success_handler=success_handler, failure_handler=failure_handler
33+
)
34+
result = durable_runner.run(input="test", timeout=10)
35+
36+
# Should handle the failure gracefully
37+
assert result.status is InvocationStatus.SUCCEEDED
38+
assert result.result != "OK"
39+
assert result.result == '"Callback failed"'
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Tests for callback heartbeat example."""
2+
3+
from asyncio import sleep
4+
5+
import pytest
6+
from aws_durable_execution_sdk_python.execution import InvocationStatus
7+
8+
from wait_for_callback import wait_for_callback_heartbeat
9+
from test.conftest import deserialize_operation_payload
10+
11+
12+
@pytest.mark.example
13+
@pytest.mark.durable_execution(
14+
handler=wait_for_callback_heartbeat.handler,
15+
lambda_function_name="wait for callback heartbeat",
16+
)
17+
def test_callback_heartbeat(durable_runner):
18+
"""Test callback heartbeat functionality."""
19+
20+
with durable_runner:
21+
# Configure external system for local mode if needed
22+
if durable_runner.mode == "local":
23+
24+
def heartbeat_handler(callback_id: str):
25+
sleep(0.1) # Simulate async work
26+
durable_runner.heartbeat_callback(callback_id)
27+
28+
def success_handler(callback_id: str, msg: bytes):
29+
sleep(0.5)
30+
durable_runner.succeed_callback(callback_id, msg)
31+
32+
wait_for_callback_heartbeat.external_system.activate_local_mode(
33+
success_handler=success_handler, heartbeat_handler=heartbeat_handler
34+
)
35+
36+
result = durable_runner.run(input="test", timeout=30)
37+
38+
assert result.status is InvocationStatus.SUCCEEDED
39+
assert deserialize_operation_payload(result.result) == "OK"
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Tests for run_in_child_context example."""
2+
3+
from asyncio import sleep
4+
5+
import pytest
6+
from aws_durable_execution_sdk_python.execution import InvocationStatus
7+
8+
from wait_for_callback import wait_for_callback_success
9+
from test.conftest import deserialize_operation_payload
10+
11+
12+
@pytest.mark.example
13+
@pytest.mark.durable_execution(
14+
handler=wait_for_callback_success.handler,
15+
lambda_function_name="wait for callback success",
16+
)
17+
def test_callback_success(durable_runner):
18+
"""Test run_in_child_context example."""
19+
20+
with durable_runner:
21+
# Configure external system for local mode if needed
22+
if durable_runner.mode == "local":
23+
24+
def success_handler(callback_id: str, msg: bytes):
25+
durable_runner.succeed_callback(callback_id, msg)
26+
27+
wait_for_callback_success.external_system.activate_local_mode(
28+
success_handler=success_handler
29+
)
30+
31+
result = durable_runner.run(input="test", timeout=10)
32+
33+
assert result.status is InvocationStatus.SUCCEEDED
34+
assert deserialize_operation_payload(result.result) == "OK"
35+
36+
# Verify child context operation exists
37+
context_ops = [
38+
op for op in result.operations if op.operation_type.value == "CONTEXT"
39+
]
40+
assert len(context_ops) >= 1

src/aws_durable_execution_sdk_python_testing/runner.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,15 @@ def run(
558558
execution: Execution = self._store.load(output.execution_arn)
559559
return DurableFunctionTestResult.create(execution=execution)
560560

561+
def send_callback_success(self, callback_id: str, result: bytes):
562+
self._executor.send_callback_success(callback_id=callback_id, result=result)
563+
564+
def send_callback_failure(self, callback_id: str, error: ErrorObject | None):
565+
self._executor.send_callback_failure(callback_id=callback_id, error=error)
566+
567+
def send_callback_heartbeat(self, callback_id: str):
568+
self._executor.send_callback_heartbeat(callback_id=callback_id)
569+
561570

562571
class DurableChildContextTestRunner(DurableFunctionTestRunner):
563572
"""Test a durable block, annotated with @durable_with_child_context, in isolation."""

0 commit comments

Comments
 (0)