Skip to content

Commit 71c01e4

Browse files
author
Alex Wang
committed
feat: Implement callback for web runner
- Implement async run, send callback for web runner - Unit tests
1 parent a3cda91 commit 71c01e4

File tree

4 files changed

+740
-22
lines changed

4 files changed

+740
-22
lines changed

examples/test/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,36 @@ def run(
105105
"""Execute the durable function and return results."""
106106
return self._runner.run(input=input, timeout=timeout)
107107

108+
def run_async(
109+
self,
110+
input: str | None = None, # noqa: A002
111+
timeout: int = 60,
112+
) -> str:
113+
return self._runner.run_async(input=input, timeout=timeout)
114+
115+
def send_callback_success(self, callback_id: str) -> None:
116+
self._runner.send_callback_success(callback_id=callback_id)
117+
118+
def send_callback_failure(self, callback_id: str) -> None:
119+
self._runner.send_callback_failure(callback_id=callback_id)
120+
121+
def send_callback_heartbeat(self, callback_id: str) -> None:
122+
self._runner.send_callback_heartbeat(callback_id=callback_id)
123+
124+
def wait_and_fetch_test_result(
125+
self, execution_arn: str, timeout: int = 60
126+
) -> DurableFunctionTestResult:
127+
return self._runner.wait_and_fetch_test_result(
128+
execution_arn=execution_arn, timeout=timeout
129+
)
130+
131+
def wait_and_fetch_callback_id(
132+
self, execution_arn: str, name: str | None = None, timeout: int = 60
133+
) -> str:
134+
return self._runner.wait_and_fetch_callback_id(
135+
execution_arn=execution_arn, name=name, timeout=timeout
136+
)
137+
108138
@property
109139
def mode(self) -> str:
110140
"""Get the runner mode (local or cloud)."""

src/aws_durable_execution_sdk_python_testing/runner.py

Lines changed: 196 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import aws_durable_execution_sdk_python
2020
import boto3 # type: ignore
21+
from botocore.exceptions import ClientError # type: ignore
2122
from aws_durable_execution_sdk_python.execution import (
2223
InvocationStatus,
2324
durable_execution,
@@ -75,6 +76,8 @@
7576

7677
from aws_durable_execution_sdk_python_testing.execution import Execution
7778
from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig
79+
from aws_durable_execution_sdk_python_testing.model import Event
80+
7881

7982
logger = logging.getLogger(__name__)
8083

@@ -792,9 +795,9 @@ def run(
792795
msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
793796
raise DurableFunctionsTestError(msg) from e
794797

795-
# Check HTTP status code (200 for RequestResponse, 202 for Event, 204 for DryRun)
798+
# Check HTTP status code, 200 for RequestResponse
796799
status_code = response.get("StatusCode")
797-
if status_code not in (200, 202, 204):
800+
if status_code != 200:
798801
error_payload = response["Payload"].read().decode("utf-8")
799802
msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
800803
raise DurableFunctionsTestError(msg)
@@ -819,17 +822,120 @@ def run(
819822
)
820823
raise DurableFunctionsTestError(msg)
821824

822-
# Poll for completion
823-
execution_response = self._wait_for_completion(execution_arn, timeout)
825+
return self.wait_and_fetch_test_result(
826+
execution_arn=execution_arn, timeout=timeout
827+
)
824828

825-
# Get execution history
826-
history_response = self._get_execution_history(execution_arn)
829+
def run_async(
830+
self,
831+
input: str | None = None, # noqa: A002
832+
timeout: int = 60,
833+
) -> str:
834+
"""Execute function on AWS Lambda asynchronously"""
835+
logger.info(
836+
"Invoking Lambda function: %s (timeout: %ds)", self.function_name, timeout
837+
)
838+
payload = json.dumps(input)
839+
try:
840+
response = self.lambda_client.invoke(
841+
FunctionName=self.function_name,
842+
InvocationType="Event",
843+
Payload=payload,
844+
)
845+
except Exception as e:
846+
msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
847+
raise DurableFunctionsTestError(msg) from e
827848

828-
# Build test result from execution history
829-
return DurableFunctionTestResult.from_execution_history(
830-
execution_response, history_response
849+
# Check HTTP status code, 202 for Event
850+
status_code = response.get("StatusCode")
851+
if status_code != 202:
852+
error_payload = response["Payload"].read().decode("utf-8")
853+
msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
854+
raise DurableFunctionsTestError(msg)
855+
856+
return response.get("DurableExecutionArn")
857+
858+
def _get_callback_id_from_events(
859+
self, events: list[Event], name: str | None = None
860+
) -> str | None:
861+
"""
862+
Get callback ID from execution history for callbacks that haven't completed.
863+
864+
Args:
865+
execution_arn: The ARN of the execution to query.
866+
name: Optional callback name to search for. If not provided, returns the latest callback.
867+
868+
Returns:
869+
The callback ID string for a non-completed callback, or None if not found.
870+
871+
Raises:
872+
DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
873+
"""
874+
callback_started_events = [
875+
event for event in events if event.event_type == "CallbackStarted"
876+
]
877+
878+
if not callback_started_events:
879+
return None
880+
881+
completed_callback_ids = {
882+
event.event_id
883+
for event in events
884+
if event.event_type
885+
in ["CallbackSucceeded", "CallbackFailed", "CallbackTimedOut"]
886+
}
887+
888+
if name is not None:
889+
for event in callback_started_events:
890+
if event.name == name:
891+
callback_id = event.event_id
892+
if callback_id in completed_callback_ids:
893+
raise DurableFunctionsTestError(
894+
f"Callback {name} has already completed (succeeded/failed/timed out)"
895+
)
896+
return (
897+
event.callback_started_details.callback_id
898+
if event.callback_started_details
899+
else None
900+
)
901+
return None
902+
903+
# If name is not provided, find the latest non-completed callback event
904+
active_callbacks = [
905+
event
906+
for event in callback_started_events
907+
if event.event_id not in completed_callback_ids
908+
]
909+
910+
if not active_callbacks:
911+
return None
912+
913+
latest_event = active_callbacks[-1]
914+
return (
915+
latest_event.callback_started_details.callback_id
916+
if latest_event.callback_started_details
917+
else None
831918
)
832919

920+
def send_callback_success(self, callback_id: str) -> None:
921+
self._send_callback("success", callback_id)
922+
923+
def send_callback_failure(self, callback_id: str) -> None:
924+
self._send_callback("failure", callback_id)
925+
926+
def send_callback_heartbeat(self, callback_id: str) -> None:
927+
self._send_callback("heartbeat", callback_id)
928+
929+
def _send_callback(self, operation: str, callback_id: str) -> None:
930+
"""Helper method to send callback operations."""
931+
method_name = f"send_durable_execution_callback_{operation}"
932+
try:
933+
api_method = getattr(self.lambda_client, method_name)
934+
api_method(CallbackId=callback_id)
935+
except Exception as e:
936+
msg = f"Failed to send callback {operation} for {self.function_name}, callback_id {callback_id}: {e}"
937+
raise DurableFunctionsTestError(msg) from e
938+
833939
def _wait_for_completion(
834940
self, execution_arn: str, timeout: int
835941
) -> GetDurableExecutionResponse:
@@ -886,7 +992,81 @@ def _wait_for_completion(
886992
)
887993
raise TimeoutError(msg)
888994

889-
def _get_execution_history(
995+
def wait_and_fetch_test_result(
996+
self, execution_arn: str, timeout: int = 60
997+
) -> DurableFunctionTestResult:
998+
# Poll for completion
999+
execution_response = self._wait_for_completion(execution_arn, timeout)
1000+
1001+
try:
1002+
history_response = self._fetch_execution_history(execution_arn)
1003+
except Exception as e:
1004+
msg = f"Failed to fetch execution history: {e}"
1005+
raise DurableFunctionsTestError(msg) from e
1006+
1007+
# Build test result from execution history
1008+
return DurableFunctionTestResult.from_execution_history(
1009+
execution_response, history_response
1010+
)
1011+
1012+
def wait_and_fetch_callback_id(
1013+
self, execution_arn: str, name: str | None = None, timeout: int = 60
1014+
) -> str:
1015+
"""
1016+
Wait for and retrieve a callback ID from a Step Functions execution.
1017+
1018+
Polls the execution history at regular intervals until a callback ID is found
1019+
or the timeout is reached.
1020+
1021+
Args:
1022+
execution_arn: Execution Arn
1023+
name: Specific callback name, default to None
1024+
timeout: Maximum time in seconds to wait for callback. Defaults to 60.
1025+
1026+
Returns:
1027+
str: The callback ID/token retrieved from the execution history
1028+
1029+
Raises:
1030+
TimeoutError: If callback is not found within the specified timeout period
1031+
DurableFunctionsTestError: If there's an error fetching execution history
1032+
(excluding retryable errors)
1033+
"""
1034+
start_time = time.time()
1035+
1036+
while time.time() - start_time < timeout:
1037+
try:
1038+
history_response = self._fetch_execution_history(execution_arn)
1039+
callback_id = self._get_callback_id_from_events(
1040+
events=history_response.events, name=name
1041+
)
1042+
if callback_id:
1043+
return callback_id
1044+
except ClientError as e:
1045+
error_code = e.response["Error"]["Code"]
1046+
# retryable error, the execution may not start yet in async invoke situation
1047+
if error_code in ["ResourceNotFoundException"]:
1048+
pass
1049+
else:
1050+
msg = f"Failed to fetch execution history: {e}"
1051+
raise DurableFunctionsTestError(msg) from e
1052+
except DurableFunctionsTestError as e:
1053+
raise e
1054+
except Exception as e:
1055+
msg = f"Failed to fetch execution history: {e}"
1056+
raise DurableFunctionsTestError(msg) from e
1057+
1058+
# Wait before next poll
1059+
time.sleep(self.poll_interval)
1060+
1061+
# Timeout reached
1062+
elapsed = time.time() - start_time
1063+
msg = (
1064+
f"Callback did not available within {timeout}s "
1065+
f"(elapsed: {elapsed:.1f}s."
1066+
)
1067+
raise TimeoutError(msg)
1068+
1069+
def _fetch_execution_history(
8901070
self, execution_arn: str
8911071
) -> GetDurableExecutionHistoryResponse:
8921072
"""Retrieve execution history from Lambda service.
@@ -898,19 +1078,13 @@ def _get_execution_history(
8981078
GetDurableExecutionHistoryResponse with typed Event objects
8991079
9001080
Raises:
901-
DurableFunctionsTestError: If history retrieval fails
1081+
ClientError: If lambda client encounter error
9021082
"""
903-
try:
904-
history_dict = self.lambda_client.get_durable_execution_history(
905-
DurableExecutionArn=execution_arn,
906-
IncludeExecutionData=True,
907-
)
908-
history_response = GetDurableExecutionHistoryResponse.from_dict(
909-
history_dict
910-
)
911-
except Exception as e:
912-
msg = f"Failed to get execution history: {e}"
913-
raise DurableFunctionsTestError(msg) from e
1083+
history_dict = self.lambda_client.get_durable_execution_history(
1084+
DurableExecutionArn=execution_arn,
1085+
IncludeExecutionData=True,
1086+
)
1087+
history_response = GetDurableExecutionHistoryResponse.from_dict(history_dict)
9141088

9151089
logger.info("Retrieved %d events from history", len(history_response.events))
9161090

0 commit comments

Comments
 (0)