Skip to content

Commit b758457

Browse files
authored
fix: add InvocationCompleted event support (#168)
1 parent 2b03bfb commit b758457

File tree

6 files changed

+212
-36
lines changed

6 files changed

+212
-36
lines changed

src/aws_durable_execution_sdk_python_testing/execution.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
# Import AWS exceptions
3030
from aws_durable_execution_sdk_python_testing.model import (
31+
InvocationCompletedDetails,
3132
StartDurableExecutionInput,
3233
)
3334
from aws_durable_execution_sdk_python_testing.token import (
@@ -60,6 +61,7 @@ def __init__(
6061
self.start_input: StartDurableExecutionInput = start_input
6162
self.operations: list[Operation] = operations
6263
self.updates: list[OperationUpdate] = []
64+
self.invocation_completions: list[InvocationCompletedDetails] = []
6365
self.used_tokens: set[str] = set()
6466
# TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
6567
self._token_sequence: int = 0
@@ -101,6 +103,9 @@ def to_dict(self) -> dict[str, Any]:
101103
"StartInput": self.start_input.to_dict(),
102104
"Operations": [op.to_dict() for op in self.operations],
103105
"Updates": [update.to_dict() for update in self.updates],
106+
"InvocationCompletions": [
107+
completion.to_dict() for completion in self.invocation_completions
108+
],
104109
"UsedTokens": list(self.used_tokens),
105110
"TokenSequence": self._token_sequence,
106111
"IsComplete": self.is_complete,
@@ -129,6 +134,10 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
129134
execution.updates = [
130135
OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
131136
]
137+
execution.invocation_completions = [
138+
InvocationCompletedDetails.from_dict(item)
139+
for item in data.get("InvocationCompletions", [])
140+
]
132141
execution.used_tokens = set(data["UsedTokens"])
133142
execution._token_sequence = data["TokenSequence"] # noqa: SLF001
134143
execution.is_complete = data["IsComplete"]
@@ -215,6 +224,18 @@ def has_pending_operations(self, execution: Execution) -> bool:
215224
return True
216225
return False
217226

227+
def record_invocation_completion(
228+
self, start_timestamp: datetime, end_timestamp: datetime, request_id: str
229+
) -> None:
230+
"""Record an invocation completion event."""
231+
self.invocation_completions.append(
232+
InvocationCompletedDetails(
233+
start_timestamp=start_timestamp,
234+
end_timestamp=end_timestamp,
235+
request_id=request_id,
236+
)
237+
)
238+
218239
def complete_success(self, result: str | None) -> None:
219240
"""Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION)."""
220241
self.result = DurableExecutionInvocationOutput(

src/aws_durable_execution_sdk_python_testing/executor.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6+
import time
67
import uuid
78
from datetime import UTC, datetime
89
from typing import TYPE_CHECKING
@@ -32,6 +33,8 @@
3233
from aws_durable_execution_sdk_python_testing.model import (
3334
CheckpointDurableExecutionResponse,
3435
CheckpointUpdatedExecutionState,
36+
EventCreationContext,
37+
EventType,
3538
GetDurableExecutionHistoryResponse,
3639
GetDurableExecutionResponse,
3740
GetDurableExecutionStateResponse,
@@ -44,7 +47,6 @@
4447
StartDurableExecutionOutput,
4548
StopDurableExecutionResponse,
4649
TERMINAL_STATUSES,
47-
EventCreationContext,
4850
)
4951
from aws_durable_execution_sdk_python_testing.model import (
5052
Event as HistoryEvent,
@@ -413,6 +415,17 @@ def get_execution_history(
413415
updates_dict: dict[str, OperationUpdate] = {u.operation_id: u for u in updates}
414416
durable_execution_arn: str = execution.durable_execution_arn
415417

418+
# Add InvocationCompleted events
419+
for completion in execution.invocation_completions:
420+
invocation_event = HistoryEvent.create_invocation_completed(
421+
event_id=0, # Temporary, will be reassigned
422+
event_timestamp=completion.end_timestamp,
423+
start_timestamp=completion.start_timestamp,
424+
end_timestamp=completion.end_timestamp,
425+
request_id=completion.request_id,
426+
)
427+
all_events.append(invocation_event)
428+
416429
# Generate all events first (without final event IDs)
417430
for op in ops:
418431
operation_update: OperationUpdate | None = updates_dict.get(
@@ -769,14 +782,23 @@ async def invoke() -> None:
769782

770783
self._store.save(execution)
771784

772-
response: DurableExecutionInvocationOutput = self._invoker.invoke(
785+
invocation_start = datetime.now(UTC)
786+
invoke_response = self._invoker.invoke(
773787
execution.start_input.function_name,
774788
invocation_input,
775789
execution.start_input.lambda_endpoint,
776790
)
791+
invocation_end = datetime.now(UTC)
777792

778793
# Reload execution after invocation in case it was completed via checkpoint
779794
execution = self._store.load(execution_arn)
795+
796+
# Record invocation completion and save immediately
797+
execution.record_invocation_completion(
798+
invocation_start, invocation_end, invoke_response.request_id
799+
)
800+
self._store.save(execution)
801+
780802
if execution.is_complete:
781803
logger.info(
782804
"[%s] Execution completed during invocation, ignoring result",
@@ -785,6 +807,7 @@ async def invoke() -> None:
785807
return
786808

787809
# Process successful received response - validate status and handle accordingly
810+
response = invoke_response.invocation_output
788811
try:
789812
self._validate_invocation_response_and_store(
790813
execution_arn, response, execution

src/aws_durable_execution_sdk_python_testing/invoker.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import json
4+
from dataclasses import dataclass
45
from threading import Lock
56
from typing import TYPE_CHECKING, Any, Protocol
7+
from uuid import uuid4
68

79
import boto3 # type: ignore
810
from aws_durable_execution_sdk_python.execution import (
@@ -26,6 +28,14 @@
2628
from aws_durable_execution_sdk_python_testing.execution import Execution
2729

2830

31+
@dataclass(frozen=True)
32+
class InvokeResponse:
33+
"""Response from invoking a durable function."""
34+
35+
invocation_output: DurableExecutionInvocationOutput
36+
request_id: str
37+
38+
2939
def create_test_lambda_context() -> LambdaContext:
3040
# Create client context as a dictionary, not as objects
3141
# LambdaContext.__init__ expects dictionaries and will create the objects internally
@@ -65,7 +75,7 @@ def invoke(
6575
function_name: str,
6676
input: DurableExecutionInvocationInput,
6777
endpoint_url: str | None = None,
68-
) -> DurableExecutionInvocationOutput: ... # pragma: no cover
78+
) -> InvokeResponse: ... # pragma: no cover
6979

7080
def update_endpoint(
7181
self, endpoint_url: str, region_name: str
@@ -96,14 +106,17 @@ def invoke(
96106
function_name: str, # noqa: ARG002
97107
input: DurableExecutionInvocationInput,
98108
endpoint_url: str | None = None, # noqa: ARG002
99-
) -> DurableExecutionInvocationOutput:
109+
) -> InvokeResponse:
100110
# TODO: reasses if function_name will be used in future
101111
input_with_client = DurableExecutionInvocationInputWithClient.from_durable_execution_invocation_input(
102112
input, self.service_client
103113
)
104114
context = create_test_lambda_context()
105115
response_dict = self.handler(input_with_client, context)
106-
return DurableExecutionInvocationOutput.from_dict(response_dict)
116+
output = DurableExecutionInvocationOutput.from_dict(response_dict)
117+
return InvokeResponse(
118+
invocation_output=output, request_id=context.aws_request_id
119+
)
107120

108121
def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
109122
"""No-op for in-process invoker."""
@@ -192,7 +205,7 @@ def invoke(
192205
function_name: str,
193206
input: DurableExecutionInvocationInput,
194207
endpoint_url: str | None = None,
195-
) -> DurableExecutionInvocationOutput:
208+
) -> InvokeResponse:
196209
"""Invoke AWS Lambda function and return durable execution result.
197210
198211
Args:
@@ -201,7 +214,7 @@ def invoke(
201214
endpoint_url: Lambda endpoint url
202215
203216
Returns:
204-
DurableExecutionInvocationOutput: Result of the function execution
217+
InvokeResponse: Response containing invocation output and request ID
205218
206219
Raises:
207220
ResourceNotFoundException: If function does not exist
@@ -247,8 +260,17 @@ def invoke(
247260
response_payload = response["Payload"].read().decode("utf-8")
248261
response_dict = json.loads(response_payload)
249262

263+
# Extract request ID from response headers (x-amzn-RequestId or x-amzn-request-id)
264+
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
265+
request_id = (
266+
headers.get("x-amzn-RequestId")
267+
or headers.get("x-amzn-request-id")
268+
or f"local-{uuid4()}"
269+
)
270+
250271
# Convert to DurableExecutionInvocationOutput
251-
return DurableExecutionInvocationOutput.from_dict(response_dict)
272+
output = DurableExecutionInvocationOutput.from_dict(response_dict)
273+
return InvokeResponse(invocation_output=output, request_id=request_id)
252274

253275
except client.exceptions.ResourceNotFoundException as e:
254276
msg = f"Function not found: {function_name}"

src/aws_durable_execution_sdk_python_testing/model.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class EventType(Enum):
6868
CALLBACK_SUCCEEDED = "CallbackSucceeded"
6969
CALLBACK_FAILED = "CallbackFailed"
7070
CALLBACK_TIMED_OUT = "CallbackTimedOut"
71+
INVOCATION_COMPLETED = "InvocationCompleted"
7172

7273

7374
TERMINAL_STATUSES: set[OperationStatus] = {
@@ -1222,6 +1223,30 @@ def to_dict(self) -> dict[str, Any]:
12221223
return result
12231224

12241225

1226+
@dataclass(frozen=True)
1227+
class InvocationCompletedDetails:
1228+
"""Invocation completed event details."""
1229+
1230+
start_timestamp: datetime.datetime
1231+
end_timestamp: datetime.datetime
1232+
request_id: str
1233+
1234+
@classmethod
1235+
def from_dict(cls, data: dict) -> InvocationCompletedDetails:
1236+
return cls(
1237+
start_timestamp=data["StartTimestamp"],
1238+
end_timestamp=data["EndTimestamp"],
1239+
request_id=data["RequestId"],
1240+
)
1241+
1242+
def to_dict(self) -> dict[str, Any]:
1243+
return {
1244+
"StartTimestamp": self.start_timestamp,
1245+
"EndTimestamp": self.end_timestamp,
1246+
"RequestId": self.request_id,
1247+
}
1248+
1249+
12251250
# endregion event_structures
12261251

12271252

@@ -1329,6 +1354,7 @@ class Event:
13291354
callback_succeeded_details: CallbackSucceededDetails | None = None
13301355
callback_failed_details: CallbackFailedDetails | None = None
13311356
callback_timed_out_details: CallbackTimedOutDetails | None = None
1357+
invocation_completed_details: InvocationCompletedDetails | None = None
13321358

13331359
@classmethod
13341360
def from_dict(cls, data: dict) -> Event:
@@ -1447,6 +1473,12 @@ def from_dict(cls, data: dict) -> Event:
14471473
if details_data := data.get("CallbackTimedOutDetails"):
14481474
callback_timed_out_details = CallbackTimedOutDetails.from_dict(details_data)
14491475

1476+
invocation_completed_details = None
1477+
if details_data := data.get("InvocationCompletedDetails"):
1478+
invocation_completed_details = InvocationCompletedDetails.from_dict(
1479+
details_data
1480+
)
1481+
14501482
return cls(
14511483
event_type=data["EventType"],
14521484
event_timestamp=data["EventTimestamp"],
@@ -1479,6 +1511,7 @@ def from_dict(cls, data: dict) -> Event:
14791511
callback_succeeded_details=callback_succeeded_details,
14801512
callback_failed_details=callback_failed_details,
14811513
callback_timed_out_details=callback_timed_out_details,
1514+
invocation_completed_details=invocation_completed_details,
14821515
)
14831516

14841517
def to_dict(self) -> dict[str, Any]:
@@ -1563,6 +1596,10 @@ def to_dict(self) -> dict[str, Any]:
15631596
result["CallbackTimedOutDetails"] = (
15641597
self.callback_timed_out_details.to_dict()
15651598
)
1599+
if self.invocation_completed_details is not None:
1600+
result["InvocationCompletedDetails"] = (
1601+
self.invocation_completed_details.to_dict()
1602+
)
15661603
return result
15671604

15681605
# region execution
@@ -2218,6 +2255,30 @@ def create_callback_event(cls, context: EventCreationContext) -> Event:
22182255

22192256
# endregion callback
22202257

2258+
# region invocation_completed
2259+
@classmethod
2260+
def create_invocation_completed(
2261+
cls,
2262+
event_id: int,
2263+
event_timestamp: datetime.datetime,
2264+
start_timestamp: datetime.datetime,
2265+
end_timestamp: datetime.datetime,
2266+
request_id: str,
2267+
) -> Event:
2268+
"""Create invocation completed event."""
2269+
return cls(
2270+
event_type=EventType.INVOCATION_COMPLETED.value,
2271+
event_timestamp=event_timestamp,
2272+
event_id=event_id,
2273+
invocation_completed_details=InvocationCompletedDetails(
2274+
start_timestamp=start_timestamp,
2275+
end_timestamp=end_timestamp,
2276+
request_id=request_id,
2277+
),
2278+
)
2279+
2280+
# endregion invocation_completed
2281+
22212282
@classmethod
22222283
def create_event_started(cls, context: EventCreationContext) -> Event:
22232284
"""Convert operation to started event."""

0 commit comments

Comments
 (0)