Skip to content

Commit 13876ab

Browse files
author
Alex Wang
committed
feat: Implement callback for web runner
- Implement async run, send callback for web runner - Unit tests
1 parent a3cda91 commit 13876ab

File tree

5 files changed

+755
-23
lines changed

5 files changed

+755
-23
lines changed

examples/src/wait_for_callback/wait_for_callback.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from aws_durable_execution_sdk_python.config import WaitForCallbackConfig
44
from aws_durable_execution_sdk_python.context import DurableContext
55
from aws_durable_execution_sdk_python.execution import durable_execution
6+
from aws_durable_execution_sdk_python.config import Duration
67

78

89
def external_system_call(_callback_id: str) -> None:
@@ -13,7 +14,9 @@ def external_system_call(_callback_id: str) -> None:
1314

1415
@durable_execution
1516
def handler(_event: Any, context: DurableContext) -> str:
16-
config = WaitForCallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=60)
17+
config = WaitForCallbackConfig(
18+
timeout=Duration.from_seconds(120), heartbeat_timeout=Duration.from_seconds(60)
19+
)
1720

1821
result = context.wait_for_callback(
1922
external_system_call, name="external_call", config=config

examples/test/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,36 @@ def run(
105105
"""Execute the durable function and return results."""
106106
return self._runner.run(input=input, timeout=timeout)
107107

108+
def run_async(
    self,
    input: str | None = None,  # noqa: A002
    timeout: int = 60,
) -> str:
    """Delegate to the wrapped runner's ``run_async`` and return its result.

    Args:
        input: Forwarded unchanged as the runner's ``input`` keyword.
        timeout: Forwarded unchanged as the runner's ``timeout`` keyword.

    Returns:
        Whatever the wrapped runner's ``run_async`` returns.
    """
    runner = self._runner
    return runner.run_async(input=input, timeout=timeout)
114+
115+
def send_callback_success(self, callback_id: str) -> None:
    """Forward a callback-success notification for ``callback_id`` to the wrapped runner."""
    runner = self._runner
    runner.send_callback_success(callback_id=callback_id)
117+
118+
def send_callback_failure(self, callback_id: str) -> None:
    """Forward a callback-failure notification for ``callback_id`` to the wrapped runner."""
    runner = self._runner
    runner.send_callback_failure(callback_id=callback_id)
120+
121+
def send_callback_heartbeat(self, callback_id: str) -> None:
    """Forward a callback heartbeat for ``callback_id`` to the wrapped runner."""
    runner = self._runner
    runner.send_callback_heartbeat(callback_id=callback_id)
123+
124+
def wait_for_result(
    self, execution_arn: str, timeout: int = 60
) -> DurableFunctionTestResult:
    """Delegate to the wrapped runner's ``wait_for_result`` and return its result.

    Args:
        execution_arn: Forwarded unchanged as ``execution_arn``.
        timeout: Forwarded unchanged as ``timeout``.

    Returns:
        Whatever the wrapped runner's ``wait_for_result`` returns.
    """
    runner = self._runner
    result = runner.wait_for_result(execution_arn=execution_arn, timeout=timeout)
    return result
130+
131+
def wait_for_callback(
    self, execution_arn: str, name: str | None = None, timeout: int = 60
) -> str:
    """Delegate to the wrapped runner's ``wait_for_callback`` and return its result.

    Args:
        execution_arn: Forwarded unchanged as ``execution_arn``.
        name: Forwarded unchanged as ``name``.
        timeout: Forwarded unchanged as ``timeout``.

    Returns:
        Whatever the wrapped runner's ``wait_for_callback`` returns.
    """
    runner = self._runner
    callback_id = runner.wait_for_callback(
        execution_arn=execution_arn, name=name, timeout=timeout
    )
    return callback_id
137+
108138
@property
109139
def mode(self) -> str:
110140
"""Get the runner mode (local or cloud)."""

src/aws_durable_execution_sdk_python_testing/runner.py

Lines changed: 205 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import aws_durable_execution_sdk_python
2020
import boto3 # type: ignore
21+
from botocore.exceptions import ClientError # type: ignore
2122
from aws_durable_execution_sdk_python.execution import (
2223
InvocationStatus,
2324
durable_execution,
@@ -75,6 +76,8 @@
7576

7677
from aws_durable_execution_sdk_python_testing.execution import Execution
7778
from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig
79+
from aws_durable_execution_sdk_python_testing.model import Event
80+
7881

7982
logger = logging.getLogger(__name__)
8083

@@ -792,9 +795,9 @@ def run(
792795
msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
793796
raise DurableFunctionsTestError(msg) from e
794797

795-
# Check HTTP status code (200 for RequestResponse, 202 for Event, 204 for DryRun)
798+
# Check HTTP status code, 200 for RequestResponse
796799
status_code = response.get("StatusCode")
797-
if status_code not in (200, 202, 204):
800+
if status_code != 200:
798801
error_payload = response["Payload"].read().decode("utf-8")
799802
msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
800803
raise DurableFunctionsTestError(msg)
@@ -819,17 +822,129 @@ def run(
819822
)
820823
raise DurableFunctionsTestError(msg)
821824

822-
# Poll for completion
823-
execution_response = self._wait_for_completion(execution_arn, timeout)
825+
return self.wait_for_result(execution_arn=execution_arn, timeout=timeout)
824826

825-
# Get execution history
826-
history_response = self._get_execution_history(execution_arn)
827+
def run_async(
    self,
    input: str | None = None,  # noqa: A002
    timeout: int = 60,
) -> str:
    """Invoke the Lambda function asynchronously and return its execution ARN.

    The function is invoked with ``InvocationType="Event"``, so this method
    returns once the invocation has been accepted rather than waiting for the
    execution to finish.

    Args:
        input: Value serialized with ``json.dumps`` and sent as the payload.
        timeout: Only included in the log message; this method does not wait
            on the execution.

    Returns:
        The ``DurableExecutionArn`` value from the invoke response.

    Raises:
        DurableFunctionsTestError: If the invoke call raises, the response
            status code is not 202, or the response carries no
            ``DurableExecutionArn``.
    """
    logger.info(
        "Invoking Lambda function: %s (timeout: %ds)", self.function_name, timeout
    )
    payload = json.dumps(input)
    try:
        response = self.lambda_client.invoke(
            FunctionName=self.function_name,
            InvocationType="Event",
            Payload=payload,
        )
    except Exception as e:
        msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
        raise DurableFunctionsTestError(msg) from e

    # Check HTTP status code, 202 for Event
    status_code = response.get("StatusCode")
    if status_code != 202:
        error_payload = response["Payload"].read().decode("utf-8")
        msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
        raise DurableFunctionsTestError(msg)

    # Fail here instead of returning None: the declared return type is str and
    # a missing ARN would otherwise only surface later as a confusing error
    # from whichever polling call receives it.
    execution_arn = response.get("DurableExecutionArn")
    if not execution_arn:
        msg = (
            f"No DurableExecutionArn in invoke response for function "
            f"{self.function_name}"
        )
        raise DurableFunctionsTestError(msg)

    return execution_arn
855+
856+
def _get_callback_id_from_events(
    self, events: list[Event], name: str | None = None
) -> str | None:
    """Return the callback ID of a callback that has started but not completed.

    Args:
        events: Execution history events to inspect.
        name: Optional callback name. When given, only the first
            ``CallbackStarted`` event with that name is considered; otherwise
            the last non-completed ``CallbackStarted`` event is used.

    Returns:
        The callback ID taken from the selected event's
        ``callback_started_details``, or None when no matching started
        callback exists or the selected event carries no details.

    Raises:
        DurableFunctionsTestError: If the named callback has already
            succeeded, failed, or timed out.
    """
    terminal_types = ("CallbackSucceeded", "CallbackFailed", "CallbackTimedOut")

    def callback_id_of(event):
        # Events without started-details yield None rather than raising.
        details = event.callback_started_details
        return details.callback_id if details else None

    started = [event for event in events if event.event_type == "CallbackStarted"]
    if not started:
        return None

    # NOTE(review): a start event is treated as completed when a terminal event
    # has the same event_id; confirm that terminal events really share the
    # event_id of their CallbackStarted event.
    completed_ids = {
        event.event_id for event in events if event.event_type in terminal_types
    }

    if name is not None:
        match = next((event for event in started if event.name == name), None)
        if match is None:
            return None
        if match.event_id in completed_ids:
            raise DurableFunctionsTestError(
                f"Callback {name} has already completed (succeeded/failed/timed out)"
            )
        return callback_id_of(match)

    # No name given: use the latest callback that has not completed yet.
    active = [event for event in started if event.event_id not in completed_ids]
    return callback_id_of(active[-1]) if active else None
917+
918+
def send_callback_success(self, callback_id: str) -> None:
    """Send a success result for ``callback_id`` through the Lambda client."""
    api_call = self.lambda_client.send_durable_execution_callback_success
    self._send_callback("success", callback_id, api_call)
832924

925+
def send_callback_failure(self, callback_id: str) -> None:
    """Send a failure result for ``callback_id`` through the Lambda client."""
    api_call = self.lambda_client.send_durable_execution_callback_failure
    self._send_callback("failure", callback_id, api_call)
931+
932+
def send_callback_heartbeat(self, callback_id: str) -> None:
    """Send a heartbeat for ``callback_id`` through the Lambda client."""
    api_call = self.lambda_client.send_durable_execution_callback_heartbeat
    self._send_callback("heartbeat", callback_id, api_call)
938+
939+
def _send_callback(self, operation: str, callback_id: str, api_method) -> None:
    """Call a callback API method and wrap any failure in a test error.

    Args:
        operation: Label for the operation (for example ``"success"``); used
            only in the error message.
        callback_id: Passed to ``api_method`` as the ``CallbackId`` keyword.
        api_method: Callable invoked as ``api_method(CallbackId=callback_id)``.

    Raises:
        DurableFunctionsTestError: If ``api_method`` raises any exception.
    """
    try:
        api_method(CallbackId=callback_id)
    except Exception as e:
        msg = f"Failed to send callback {operation} for {self.function_name}, callback_id {callback_id}: {e}"
        raise DurableFunctionsTestError(msg) from e
947+
833948
def _wait_for_completion(
834949
self, execution_arn: str, timeout: int
835950
) -> GetDurableExecutionResponse:
@@ -886,7 +1001,81 @@ def _wait_for_completion(
8861001
)
8871002
raise TimeoutError(msg)
8881003

889-
def _get_execution_history(
1004+
def wait_for_result(
    self, execution_arn: str, timeout: int = 60
) -> DurableFunctionTestResult:
    """Wait for an execution to complete and build a test result from its history.

    Args:
        execution_arn: ARN of the execution to wait on.
        timeout: Passed to ``_wait_for_completion`` as its timeout.

    Returns:
        The result of ``DurableFunctionTestResult.from_execution_history`` for
        the completed execution and its fetched history.

    Raises:
        DurableFunctionsTestError: If fetching the execution history raises.
    """
    # Block until the execution reaches completion.
    completed_execution = self._wait_for_completion(execution_arn, timeout)

    try:
        history = self._fetch_execution_history(execution_arn)
    except Exception as exc:
        raise DurableFunctionsTestError(
            f"Failed to fetch execution history: {exc}"
        ) from exc

    return DurableFunctionTestResult.from_execution_history(
        completed_execution, history
    )
1020+
1021+
def wait_for_callback(
    self, execution_arn: str, name: str | None = None, timeout: int = 60
) -> str:
    """Poll a durable execution's history until a pending callback ID appears.

    The history is fetched every ``self.poll_interval`` seconds until
    ``_get_callback_id_from_events`` yields a callback ID or the timeout
    elapses.

    Args:
        execution_arn: ARN of the execution whose history is polled.
        name: Specific callback name to look for; None accepts any pending
            callback.
        timeout: Maximum time in seconds to wait for the callback.
            Defaults to 60.

    Returns:
        The callback ID retrieved from the execution history.

    Raises:
        TimeoutError: If no callback ID is found within ``timeout`` seconds.
        DurableFunctionsTestError: If fetching or inspecting the history fails
            with anything other than a retryable ``ResourceNotFoundException``.
    """
    start_time = time.time()

    while time.time() - start_time < timeout:
        try:
            history_response = self._fetch_execution_history(execution_arn)
            callback_id = self._get_callback_id_from_events(
                events=history_response.events, name=name
            )
        except ClientError as e:
            # Right after an async invoke the execution may not exist yet, so a
            # missing resource is retried; every other client error is fatal.
            if e.response["Error"]["Code"] != "ResourceNotFoundException":
                msg = f"Failed to fetch execution history: {e}"
                raise DurableFunctionsTestError(msg) from e
        except DurableFunctionsTestError:
            # Already the right error type (e.g. callback already completed).
            raise
        except Exception as e:
            msg = f"Failed to fetch execution history: {e}"
            raise DurableFunctionsTestError(msg) from e
        else:
            if callback_id:
                return callback_id

        # Wait before next poll
        time.sleep(self.poll_interval)

    # Timeout reached
    elapsed = time.time() - start_time
    msg = (
        f"Callback did not become available within {timeout}s "
        f"(elapsed: {elapsed:.1f}s)"
    )
    raise TimeoutError(msg)
1077+
1078+
def _fetch_execution_history(
8901079
self, execution_arn: str
8911080
) -> GetDurableExecutionHistoryResponse:
8921081
"""Retrieve execution history from Lambda service.
@@ -898,19 +1087,13 @@ def _get_execution_history(
8981087
GetDurableExecutionHistoryResponse with typed Event objects
8991088
9001089
Raises:
901-
DurableFunctionsTestError: If history retrieval fails
1090+
ClientError: If lambda client encounter error
9021091
"""
903-
try:
904-
history_dict = self.lambda_client.get_durable_execution_history(
905-
DurableExecutionArn=execution_arn,
906-
IncludeExecutionData=True,
907-
)
908-
history_response = GetDurableExecutionHistoryResponse.from_dict(
909-
history_dict
910-
)
911-
except Exception as e:
912-
msg = f"Failed to get execution history: {e}"
913-
raise DurableFunctionsTestError(msg) from e
1092+
history_dict = self.lambda_client.get_durable_execution_history(
1093+
DurableExecutionArn=execution_arn,
1094+
IncludeExecutionData=True,
1095+
)
1096+
history_response = GetDurableExecutionHistoryResponse.from_dict(history_dict)
9141097

9151098
logger.info("Retrieved %d events from history", len(history_response.events))
9161099

0 commit comments

Comments
 (0)