Skip to content

Commit ecc8daf

Browse files
author
Alex Wang
committed
feat: Implement callback for web runner
- Implement async run, send callback for web runner - Unit tests
1 parent a3cda91 commit ecc8daf

File tree

5 files changed

+748
-23
lines changed

5 files changed

+748
-23
lines changed

examples/src/wait_for_callback/wait_for_callback.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from aws_durable_execution_sdk_python.config import WaitForCallbackConfig
44
from aws_durable_execution_sdk_python.context import DurableContext
55
from aws_durable_execution_sdk_python.execution import durable_execution
6+
from aws_durable_execution_sdk_python.config import Duration
67

78

89
def external_system_call(_callback_id: str) -> None:
@@ -13,7 +14,9 @@ def external_system_call(_callback_id: str) -> None:
1314

1415
@durable_execution
1516
def handler(_event: Any, context: DurableContext) -> str:
16-
config = WaitForCallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=60)
17+
config = WaitForCallbackConfig(
18+
timeout=Duration.from_seconds(120), heartbeat_timeout=Duration.from_seconds(60)
19+
)
1720

1821
result = context.wait_for_callback(
1922
external_system_call, name="external_call", config=config

examples/test/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,36 @@ def run(
105105
"""Execute the durable function and return results."""
106106
return self._runner.run(input=input, timeout=timeout)
107107

108+
def run_async(
109+
self,
110+
input: str | None = None, # noqa: A002
111+
timeout: int = 60,
112+
) -> str:
113+
return self._runner.run_async(input=input, timeout=timeout)
114+
115+
def send_callback_success(self, callback_id: str) -> None:
116+
self._runner.send_callback_success(callback_id=callback_id)
117+
118+
def send_callback_failure(self, callback_id: str) -> None:
119+
self._runner.send_callback_failure(callback_id=callback_id)
120+
121+
def send_callback_heartbeat(self, callback_id: str) -> None:
122+
self._runner.send_callback_heartbeat(callback_id=callback_id)
123+
124+
def wait_for_result(
125+
self, execution_arn: str, timeout: int = 60
126+
) -> DurableFunctionTestResult:
127+
return self._runner.wait_for_result(
128+
execution_arn=execution_arn, timeout=timeout
129+
)
130+
131+
def wait_for_callback(
132+
self, execution_arn: str, name: str | None = None, timeout: int = 60
133+
) -> str:
134+
return self._runner.wait_for_callback(
135+
execution_arn=execution_arn, name=name, timeout=timeout
136+
)
137+
108138
@property
109139
def mode(self) -> str:
110140
"""Get the runner mode (local or cloud)."""

src/aws_durable_execution_sdk_python_testing/runner.py

Lines changed: 202 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import aws_durable_execution_sdk_python
2020
import boto3 # type: ignore
21+
from botocore.exceptions import ClientError # type: ignore
2122
from aws_durable_execution_sdk_python.execution import (
2223
InvocationStatus,
2324
durable_execution,
@@ -75,6 +76,8 @@
7576

7677
from aws_durable_execution_sdk_python_testing.execution import Execution
7778
from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig
79+
from aws_durable_execution_sdk_python_testing.model import Event
80+
7881

7982
logger = logging.getLogger(__name__)
8083

@@ -792,9 +795,9 @@ def run(
792795
msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
793796
raise DurableFunctionsTestError(msg) from e
794797

795-
# Check HTTP status code (200 for RequestResponse, 202 for Event, 204 for DryRun)
798+
# Check HTTP status code, 200 for RequestResponse
796799
status_code = response.get("StatusCode")
797-
if status_code not in (200, 202, 204):
800+
if status_code != 200:
798801
error_payload = response["Payload"].read().decode("utf-8")
799802
msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
800803
raise DurableFunctionsTestError(msg)
@@ -819,17 +822,126 @@ def run(
819822
)
820823
raise DurableFunctionsTestError(msg)
821824

822-
# Poll for completion
823-
execution_response = self._wait_for_completion(execution_arn, timeout)
825+
return self.wait_for_result(execution_arn=execution_arn, timeout=timeout)
824826

825-
# Get execution history
826-
history_response = self._get_execution_history(execution_arn)
827+
def run_async(
828+
self,
829+
input: str | None = None, # noqa: A002
830+
timeout: int = 60,
831+
) -> str:
832+
"""Execute function on AWS Lambda asynchronously"""
833+
logger.info(
834+
"Invoking Lambda function: %s (timeout: %ds)", self.function_name, timeout
835+
)
836+
payload = json.dumps(input)
837+
try:
838+
response = self.lambda_client.invoke(
839+
FunctionName=self.function_name,
840+
InvocationType="Event",
841+
Payload=payload,
842+
)
843+
except Exception as e:
844+
msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
845+
raise DurableFunctionsTestError(msg) from e
827846

828-
# Build test result from execution history
829-
return DurableFunctionTestResult.from_execution_history(
830-
execution_response, history_response
847+
# Check HTTP status code, 202 for Event
848+
status_code = response.get("StatusCode")
849+
if status_code != 202:
850+
error_payload = response["Payload"].read().decode("utf-8")
851+
msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
852+
raise DurableFunctionsTestError(msg)
853+
854+
return response.get("DurableExecutionArn")
855+
856+
def _get_callback_id_from_events(
857+
self, events: list[Event], name: str | None = None
858+
) -> str | None:
859+
"""
860+
Get callback ID from execution history for callbacks that haven't completed.
861+
862+
Args:
863+
execution_arn: The ARN of the execution to query.
864+
name: Optional callback name to search for. If not provided, returns the latest callback.
865+
866+
Returns:
867+
The callback ID string for a non-completed callback, or None if not found.
868+
869+
Raises:
870+
DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
871+
"""
872+
callback_started_events = [
873+
event for event in events if event.event_type == "CallbackStarted"
874+
]
875+
876+
if not callback_started_events:
877+
return None
878+
879+
completed_callback_ids = {
880+
event.event_id
881+
for event in events
882+
if event.event_type
883+
in ["CallbackSucceeded", "CallbackFailed", "CallbackTimedOut"]
884+
}
885+
886+
if name is not None:
887+
for event in callback_started_events:
888+
if event.name == name:
889+
callback_id = event.event_id
890+
if callback_id in completed_callback_ids:
891+
raise DurableFunctionsTestError(
892+
f"Callback {name} has already completed (succeeded/failed/timed out)"
893+
)
894+
return (
895+
event.callback_started_details.callback_id
896+
if event.callback_started_details
897+
else None
898+
)
899+
return None
900+
901+
# If name is not provided, find the latest non-completed callback event
902+
active_callbacks = [
903+
event
904+
for event in callback_started_events
905+
if event.event_id not in completed_callback_ids
906+
]
907+
908+
if not active_callbacks:
909+
return None
910+
911+
latest_event = active_callbacks[-1]
912+
return (
913+
latest_event.callback_started_details.callback_id
914+
if latest_event.callback_started_details
915+
else None
831916
)
832917

918+
def send_callback_success(self, callback_id: str) -> None:
919+
try:
920+
self.lambda_client.send_durable_execution_callback_success(
921+
CallbackId=callback_id
922+
)
923+
except Exception as e:
924+
msg = f"Failed to send callback success for {self.function_name}, callback_id {callback_id}: {e}"
925+
raise DurableFunctionsTestError(msg) from e
926+
927+
def send_callback_failure(self, callback_id: str) -> None:
928+
try:
929+
self.lambda_client.send_durable_execution_callback_failure(
930+
CallbackId=callback_id
931+
)
932+
except Exception as e:
933+
msg = f"Failed to send callback failure for {self.function_name}, callback_id {callback_id}: {e}"
934+
raise DurableFunctionsTestError(msg) from e
935+
936+
def send_callback_heartbeat(self, callback_id: str) -> None:
937+
try:
938+
self.lambda_client.send_durable_execution_callback_heartbeat(
939+
CallbackId=callback_id
940+
)
941+
except Exception as e:
942+
msg = f"Failed to send callback heartbeat for {self.function_name}, callback_id {callback_id}: {e}"
943+
raise DurableFunctionsTestError(msg) from e
944+
833945
def _wait_for_completion(
834946
self, execution_arn: str, timeout: int
835947
) -> GetDurableExecutionResponse:
@@ -886,7 +998,81 @@ def _wait_for_completion(
886998
)
887999
raise TimeoutError(msg)
8881000

889-
def _get_execution_history(
1001+
def wait_for_result(
1002+
self, execution_arn: str, timeout: int = 60
1003+
) -> DurableFunctionTestResult:
1004+
# Poll for completion
1005+
execution_response = self._wait_for_completion(execution_arn, timeout)
1006+
1007+
try:
1008+
history_response = self._fetch_execution_history(execution_arn)
1009+
except Exception as e:
1010+
msg = f"Failed to fetch execution history: {e}"
1011+
raise DurableFunctionsTestError(msg) from e
1012+
1013+
# Build test result from execution history
1014+
return DurableFunctionTestResult.from_execution_history(
1015+
execution_response, history_response
1016+
)
1017+
1018+
def wait_for_callback(
1019+
self, execution_arn: str, name: str | None = None, timeout: int = 60
1020+
) -> str:
1021+
"""
1022+
Wait for and retrieve a callback ID from a Step Functions execution.
1023+
1024+
Polls the execution history at regular intervals until a callback ID is found
1025+
or the timeout is reached.
1026+
1027+
Args:
1028+
execution_arn: Execution Arn
1029+
name: Specific callback name, default to None
1030+
timeout: Maximum time in seconds to wait for callback. Defaults to 60.
1031+
1032+
Returns:
1033+
str: The callback ID/token retrieved from the execution history
1034+
1035+
Raises:
1036+
TimeoutError: If callback is not found within the specified timeout period
1037+
DurableFunctionsTestError: If there's an error fetching execution history
1038+
(excluding retryable errors)
1039+
"""
1040+
start_time = time.time()
1041+
1042+
while time.time() - start_time < timeout:
1043+
try:
1044+
history_response = self._fetch_execution_history(execution_arn)
1045+
callback_id = self._get_callback_id_from_events(
1046+
events=history_response.events, name=name
1047+
)
1048+
if callback_id:
1049+
return callback_id
1050+
except ClientError as e:
1051+
error_code = e.response["Error"]["Code"]
1052+
# retryable error, the execution may not start yet in async invoke situation
1053+
if error_code in ["ResourceNotFoundException"]:
1054+
pass
1055+
else:
1056+
msg = f"Failed to fetch execution history: {e}"
1057+
raise DurableFunctionsTestError(msg) from e
1058+
except DurableFunctionsTestError as e:
1059+
raise e
1060+
except Exception as e:
1061+
msg = f"Failed to fetch execution history: {e}"
1062+
raise DurableFunctionsTestError(msg) from e
1063+
1064+
# Wait before next poll
1065+
time.sleep(self.poll_interval)
1066+
1067+
# Timeout reached
1068+
elapsed = time.time() - start_time
1069+
msg = (
1070+
f"Callback did not available within {timeout}s "
1071+
f"(elapsed: {elapsed:.1f}s."
1072+
)
1073+
raise TimeoutError(msg)
1074+
1075+
def _fetch_execution_history(
8901076
self, execution_arn: str
8911077
) -> GetDurableExecutionHistoryResponse:
8921078
"""Retrieve execution history from Lambda service.
@@ -898,19 +1084,13 @@ def _get_execution_history(
8981084
GetDurableExecutionHistoryResponse with typed Event objects
8991085
9001086
Raises:
901-
DurableFunctionsTestError: If history retrieval fails
1087+
ClientError: If lambda client encounter error
9021088
"""
903-
try:
904-
history_dict = self.lambda_client.get_durable_execution_history(
905-
DurableExecutionArn=execution_arn,
906-
IncludeExecutionData=True,
907-
)
908-
history_response = GetDurableExecutionHistoryResponse.from_dict(
909-
history_dict
910-
)
911-
except Exception as e:
912-
msg = f"Failed to get execution history: {e}"
913-
raise DurableFunctionsTestError(msg) from e
1089+
history_dict = self.lambda_client.get_durable_execution_history(
1090+
DurableExecutionArn=execution_arn,
1091+
IncludeExecutionData=True,
1092+
)
1093+
history_response = GetDurableExecutionHistoryResponse.from_dict(history_dict)
9141094

9151095
logger.info("Retrieved %d events from history", len(history_response.events))
9161096

0 commit comments

Comments
 (0)