Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions examples/src/wait_for_callback/external_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import threading
import queue
from enum import StrEnum
from time import sleep
from typing import Callable, Optional


class RunnerMode(StrEnum):
"""Runner mode for local or cloud execution."""

LOCAL = "local"
CLOUD = "cloud"


class ExternalSystem:
_instance = None
_lock = threading.Lock()

def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance

def __init__(self):
if self._initialized:
return
self._call_queue = queue.Queue()
self._worker_thread = None
self._shutdown_flag = threading.Event()

self._mode = RunnerMode.CLOUD
self._success_handler = self._cloud_success_handler
self._failure_handler = self._cloud_failure_handler
self._heartbeat_handler = self._cloud_heartbeat_handler
self._initialized = True

@property
def mode(self) -> RunnerMode:
return self._mode

def activate_local_mode(
self,
success_handler: Optional[Callable[[str, bytes], None]] = None,
failure_handler: Optional[Callable[[str, Exception], None]] = None,
heartbeat_handler: Optional[Callable[[str], None]] = None,
):
"""Activate local mode with custom handlers."""
self._mode = RunnerMode.LOCAL
self._success_handler = success_handler
self._failure_handler = failure_handler
self._heartbeat_handler = heartbeat_handler

def activate_cloud_mode(self):
"""Activate cloud mode with boto3 handlers."""
self._mode = RunnerMode.CLOUD
self._success_handler = self._cloud_success_handler
self._failure_handler = self._cloud_failure_handler
self._heartbeat_handler = self._cloud_heartbeat_handler

def send_success(self, callback_id: str, msg: bytes):
"""Send success callback."""
self._call_queue.put(("success", callback_id, msg), timeout=0.5)

def send_failure(self, callback_id: str, error: Exception):
"""Send failure callback."""
self._call_queue.put(("failure", callback_id, error), timeout=0.5)

def send_heartbeat(self, callback_id: str):
"""Send heartbeat callback."""
self._call_queue.put(("heartbeat", callback_id, None), timeout=0.5)

def start(self):
if self._worker_thread is None or not self._worker_thread.is_alive():
self._worker_thread = threading.Thread(target=self._worker, daemon=True)
self._worker_thread.start()

def _worker(self):
"""Background worker that processes callbacks."""
while not self._shutdown_flag.is_set():
try:
operation_type, callback_id, data = self._call_queue.get(timeout=0.5)

if operation_type == "success" and self._success_handler:
self._success_handler(callback_id, data)
elif operation_type == "failure" and self._failure_handler:
self._failure_handler(callback_id, data)
elif operation_type == "heartbeat" and self._heartbeat_handler:
self._heartbeat_handler(callback_id)

self._call_queue.task_done()
except queue.Empty:
continue

def reset(self):
"""Reset the external system state."""
# Clear the queue
while not self._call_queue.empty():
try:
self._call_queue.get_nowait()
self._call_queue.task_done()
except queue.Empty:
break

def shutdown(self):
"""Shutdown the worker thread."""
self._shutdown_flag.set()

# Clear the queue
while not self._call_queue.empty():
try:
self._call_queue.get_nowait()
self._call_queue.task_done()
except queue.Empty:
break

# Wait for thread to finish
if self._worker_thread and self._worker_thread.is_alive():
self._worker_thread.join(timeout=1)

# Reset for next test
self._worker_thread = None
self._shutdown_flag.clear()

@classmethod
def reset_instance(cls):
"""Reset the singleton instance."""
with cls._lock:
if cls._instance:
cls._instance.shutdown()
cls._instance = None

def _cloud_success_handler(self, callback_id: str, msg: bytes):
"""Default cloud success handler using boto3."""
try:
import boto3
import os

client = boto3.client(
"lambdainternal",
endpoint_url=os.environ.get("LAMBDA_ENDPOINT"),
region_name=os.environ.get("AWS_REGION", "us-west-2"),
)

client.send_durable_execution_callback_success(
CallbackId=callback_id, Result=msg.decode("utf-8") if msg else None
)
except Exception:
pass # Fail silently in cloud mode

def _cloud_failure_handler(self, callback_id: str, error: Exception):
"""Default cloud failure handler using boto3."""
try:
import boto3
import os

client = boto3.client(
"lambdainternal",
endpoint_url=os.environ.get("LAMBDA_ENDPOINT"),
region_name=os.environ.get("AWS_REGION", "us-west-2"),
)

client.send_durable_execution_callback_failure(
CallbackId=callback_id, Error=str(error)
)
except Exception:
pass # Fail silently in cloud mode

def _cloud_heartbeat_handler(self, callback_id: str):
"""Default cloud heartbeat handler using boto3."""
try:
import boto3
import os

client = boto3.client(
"lambdainternal",
endpoint_url=os.environ.get("LAMBDA_ENDPOINT"),
region_name=os.environ.get("AWS_REGION", "us-west-2"),
)

client.send_durable_execution_callback_heartbeat(CallbackId=callback_id)
except Exception:
pass # Fail silently in cloud mode
28 changes: 28 additions & 0 deletions examples/src/wait_for_callback/wait_for_callback_failure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any

from aws_durable_execution_sdk_python import DurableContext, durable_execution
from aws_durable_execution_sdk_python.config import WaitForCallbackConfig, Duration
from .external_system import ExternalSystem # noqa: TID252

external_system = ExternalSystem() # Singleton instance


@durable_execution
def handler(_event: Any, context: DurableContext) -> str:
name = "Callback Failure"
config = WaitForCallbackConfig(timeout=Duration(10), retry_strategy=None)

def submitter(callback_id: str) -> None:
"""Submitter function that triggers failure."""
try:
raise Exception("Callback failed")
except Exception as e:
external_system.send_failure(callback_id, e)
external_system.start()

try:
context.wait_for_callback(submitter, name=name, config=config)
return "OK"
except Exception as e:
result = str(e)
return result
22 changes: 22 additions & 0 deletions examples/src/wait_for_callback/wait_for_callback_heartbeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any

from aws_durable_execution_sdk_python import DurableContext, durable_execution
from aws_durable_execution_sdk_python.config import WaitForCallbackConfig, Duration
from .external_system import ExternalSystem # noqa: TID252

external_system = ExternalSystem() # Singleton instance


@durable_execution
def handler(_event: Any, context: DurableContext) -> str:
name = "Callback Heartbeat"
config = WaitForCallbackConfig(timeout=Duration(30), retry_strategy=None)

def submitter(callback_id: str) -> None:
"""Submitter function that sends heartbeat then succeeds."""
external_system.send_heartbeat(callback_id)
external_system.send_success(callback_id, b"")
external_system.start()

context.wait_for_callback(submitter, name=name, config=config)
return "OK"
22 changes: 22 additions & 0 deletions examples/src/wait_for_callback/wait_for_callback_success.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any

from aws_durable_execution_sdk_python import DurableContext, durable_execution
from aws_durable_execution_sdk_python.config import WaitForCallbackConfig, Duration
from .external_system import ExternalSystem # noqa: TID252

external_system = ExternalSystem() # Singleton instance


@durable_execution
def handler(_event: Any, context: DurableContext) -> str:
name = "Callback Waiting"
config = WaitForCallbackConfig(timeout=Duration(30), retry_strategy=None)

def submitter(callback_id: str) -> None:
"""Submitter function."""
external_system.send_success(callback_id, b"")
external_system.start()

context.wait_for_callback(submitter, name=name, config=config)

return "OK"
28 changes: 27 additions & 1 deletion examples/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from typing import Any

import pytest
from aws_durable_execution_sdk_python.lambda_service import OperationPayload
from aws_durable_execution_sdk_python.lambda_service import (
ErrorObject,
OperationPayload,
)
from aws_durable_execution_sdk_python.serdes import ExtendedTypeSerDes

from aws_durable_execution_sdk_python_testing.runner import (
Expand Down Expand Up @@ -122,6 +125,29 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return self._runner.__exit__(exc_type, exc_val, exc_tb)
return None

def succeed_callback(self, callback_id: str, result: bytes) -> None:
"""Send callback success response."""
if self.mode == RunnerMode.LOCAL:
self._runner.send_callback_success(callback_id=callback_id, result=result)
else:
logger.warning("Current runner does not support callback success")

def fail_callback(self, callback_id: str, error: Exception | None = None) -> None:
"""Send callback failure response."""
if self.mode == RunnerMode.LOCAL:
error_obj = ErrorObject.from_exception(error) if error else None
self._runner.send_callback_failure(callback_id=callback_id, error=error_obj)
else:
logger.warning("Current runner does not support callback failure")

def heartbeat_callback(self, callback_id: str) -> None:
"""Send callback heartbeat to keep callback alive."""

if self.mode == RunnerMode.LOCAL:
self._runner.send_callback_heartbeat(callback_id=callback_id)
else:
logger.warning("Current runner does not support callback heartbeat")


@pytest.fixture
def durable_runner(request):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Tests for callback failure example."""

from asyncio import sleep

import pytest
from aws_durable_execution_sdk_python.execution import InvocationStatus

from wait_for_callback import wait_for_callback_failure
from test.conftest import deserialize_operation_payload
from wait_for_callback.external_system import ExternalSystem


@pytest.mark.example
@pytest.mark.durable_execution(
handler=wait_for_callback_failure.handler,
lambda_function_name="wait for callback failure",
)
def test_callback_failure(durable_runner):
"""Test callback failure handling."""

with durable_runner:
external_system = ExternalSystem()
# Configure external system for local mode if needed
if durable_runner.mode == "local":

def failure_handler(callback_id: str, error: Exception):
sleep(0.5) # Simulate async work
durable_runner.fail_callback(callback_id, str(error))

def success_handler(callback_id: str, msg: bytes):
durable_runner.succeed_callback(callback_id, msg)

external_system.activate_local_mode(
success_handler=success_handler, failure_handler=failure_handler
)

result = durable_runner.run(input="test", timeout=10)
external_system.shutdown()

# Should handle the failure gracefully
assert result.status is InvocationStatus.SUCCEEDED
assert result.result != "OK"
assert result.result == '"Callback failed"'
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Tests for callback heartbeat example."""

from asyncio import sleep

import pytest
from aws_durable_execution_sdk_python.execution import InvocationStatus

from wait_for_callback import wait_for_callback_heartbeat
from test.conftest import deserialize_operation_payload
from wait_for_callback.external_system import ExternalSystem


@pytest.mark.example
@pytest.mark.durable_execution(
handler=wait_for_callback_heartbeat.handler,
lambda_function_name="wait for callback heartbeat",
)
def test_callback_heartbeat(durable_runner):
"""Test callback heartbeat functionality."""

with durable_runner:
external_system = ExternalSystem()
# Configure external system for local mode if needed
if durable_runner.mode == "local":

def heartbeat_handler(callback_id: str):
sleep(0.1) # Simulate async work
# durable_runner.heartbeat_callback(callback_id)

def success_handler(callback_id: str, msg: bytes):
sleep(0.5)
durable_runner.succeed_callback(callback_id, msg)

external_system.activate_local_mode(
success_handler=success_handler, heartbeat_handler=heartbeat_handler
)
Comment on lines +34 to +36
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://github.com/aws/aws-durable-execution-sdk-js/blob/development/packages/aws-durable-execution-sdk-js-examples/src/examples/wait-for-callback/failures/wait-for-callback-failures.test.ts#L18-L27

I think we want to implement this in a similar way, where the test is able to capture the callback operation and respond to it directly.


result = durable_runner.run(input="test", timeout=30)
external_system.shutdown()

assert result.status is InvocationStatus.SUCCEEDED
assert deserialize_operation_payload(result.result) == "OK"
Loading