Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/aws_durable_execution_sdk_python_testing/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def __init__(
self.start_input: StartDurableExecutionInput = start_input
self.operations: list[Operation] = operations
self.updates: list[OperationUpdate] = []
self.invocation_completions: list[
tuple[float, float, str]
] = [] # (start_ts, end_ts, request_id)
self.used_tokens: set[str] = set()
# TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
self._token_sequence: int = 0
Expand Down Expand Up @@ -101,6 +104,10 @@ def to_dict(self) -> dict[str, Any]:
"StartInput": self.start_input.to_dict(),
"Operations": [op.to_dict() for op in self.operations],
"Updates": [update.to_dict() for update in self.updates],
"InvocationCompletions": [
[start, end, req_id]
for start, end, req_id in self.invocation_completions
],
"UsedTokens": list(self.used_tokens),
"TokenSequence": self._token_sequence,
"IsComplete": self.is_complete,
Expand Down Expand Up @@ -129,6 +136,9 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
execution.updates = [
OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
]
execution.invocation_completions = [
tuple(item) for item in data.get("InvocationCompletions", [])
]
execution.used_tokens = set(data["UsedTokens"])
execution._token_sequence = data["TokenSequence"] # noqa: SLF001
execution.is_complete = data["IsComplete"]
Expand Down Expand Up @@ -215,6 +225,12 @@ def has_pending_operations(self, execution: Execution) -> bool:
return True
return False

def record_invocation_completion(
self, start_timestamp: float, end_timestamp: float, request_id: str
) -> None:
"""Record an invocation completion event."""
self.invocation_completions.append((start_timestamp, end_timestamp, request_id))

def complete_success(self, result: str | None) -> None:
"""Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION)."""
self.result = DurableExecutionInvocationOutput(
Expand Down
26 changes: 25 additions & 1 deletion src/aws_durable_execution_sdk_python_testing/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import time
import uuid
from datetime import UTC, datetime
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -413,6 +414,20 @@ def get_execution_history(
updates_dict: dict[str, OperationUpdate] = {u.operation_id: u for u in updates}
durable_execution_arn: str = execution.durable_execution_arn

# Add InvocationCompleted events
for start_ts, end_ts, request_id in execution.invocation_completions:
invocation_event = HistoryEvent(
event_id=0, # Temporary, will be reassigned
event_type="InvocationCompleted",
event_timestamp=datetime.fromtimestamp(end_ts, tz=UTC),
invocation_completed_details={
"StartTimestamp": start_ts,
"EndTimestamp": end_ts,
"RequestId": request_id,
},
)
all_events.append(invocation_event)

# Generate all events first (without final event IDs)
for op in ops:
operation_update: OperationUpdate | None = updates_dict.get(
Expand Down Expand Up @@ -769,14 +784,23 @@ async def invoke() -> None:

self._store.save(execution)

response: DurableExecutionInvocationOutput = self._invoker.invoke(
invocation_start = time.time()
response, request_id = self._invoker.invoke(
execution.start_input.function_name,
invocation_input,
execution.start_input.lambda_endpoint,
)
invocation_end = time.time()

# Reload execution after invocation in case it was completed via checkpoint
execution = self._store.load(execution_arn)

# Record invocation completion and save immediately
execution.record_invocation_completion(
invocation_start, invocation_end, request_id
)
self._store.save(execution)

if execution.is_complete:
logger.info(
"[%s] Execution completed during invocation, ignoring result",
Expand Down
23 changes: 17 additions & 6 deletions src/aws_durable_execution_sdk_python_testing/invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from threading import Lock
from typing import TYPE_CHECKING, Any, Protocol
from uuid import uuid4

import boto3 # type: ignore
from aws_durable_execution_sdk_python.execution import (
Expand Down Expand Up @@ -65,7 +66,7 @@ def invoke(
function_name: str,
input: DurableExecutionInvocationInput,
endpoint_url: str | None = None,
) -> DurableExecutionInvocationOutput: ... # pragma: no cover
) -> tuple[DurableExecutionInvocationOutput, str]: ... # pragma: no cover

def update_endpoint(
self, endpoint_url: str, region_name: str
Expand Down Expand Up @@ -96,14 +97,15 @@ def invoke(
function_name: str, # noqa: ARG002
input: DurableExecutionInvocationInput,
endpoint_url: str | None = None, # noqa: ARG002
) -> DurableExecutionInvocationOutput:
) -> tuple[DurableExecutionInvocationOutput, str]:
# TODO: reasses if function_name will be used in future
input_with_client = DurableExecutionInvocationInputWithClient.from_durable_execution_invocation_input(
input, self.service_client
)
context = create_test_lambda_context()
response_dict = self.handler(input_with_client, context)
return DurableExecutionInvocationOutput.from_dict(response_dict)
output = DurableExecutionInvocationOutput.from_dict(response_dict)
return output, context.aws_request_id

def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
"""No-op for in-process invoker."""
Expand Down Expand Up @@ -192,7 +194,7 @@ def invoke(
function_name: str,
input: DurableExecutionInvocationInput,
endpoint_url: str | None = None,
) -> DurableExecutionInvocationOutput:
) -> tuple[DurableExecutionInvocationOutput, str]:
"""Invoke AWS Lambda function and return durable execution result.

Args:
Expand All @@ -201,7 +203,7 @@ def invoke(
endpoint_url: Lambda endpoint url

Returns:
DurableExecutionInvocationOutput: Result of the function execution
tuple: (DurableExecutionInvocationOutput, request_id)

Raises:
ResourceNotFoundException: If function does not exist
Expand Down Expand Up @@ -247,8 +249,17 @@ def invoke(
response_payload = response["Payload"].read().decode("utf-8")
response_dict = json.loads(response_payload)

# Extract request ID from response headers (x-amzn-RequestId or x-amzn-request-id)
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
request_id = (
headers.get("x-amzn-RequestId")
or headers.get("x-amzn-request-id")
or f"local-{uuid4()}"
)

# Convert to DurableExecutionInvocationOutput
return DurableExecutionInvocationOutput.from_dict(response_dict)
output = DurableExecutionInvocationOutput.from_dict(response_dict)
return output, request_id

except client.exceptions.ResourceNotFoundException as e:
msg = f"Function not found: {function_name}"
Expand Down
6 changes: 6 additions & 0 deletions src/aws_durable_execution_sdk_python_testing/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,7 @@ class Event:
callback_succeeded_details: CallbackSucceededDetails | None = None
callback_failed_details: CallbackFailedDetails | None = None
callback_timed_out_details: CallbackTimedOutDetails | None = None
invocation_completed_details: dict[str, Any] | None = None

@classmethod
def from_dict(cls, data: dict) -> Event:
Expand Down Expand Up @@ -1447,6 +1448,8 @@ def from_dict(cls, data: dict) -> Event:
if details_data := data.get("CallbackTimedOutDetails"):
callback_timed_out_details = CallbackTimedOutDetails.from_dict(details_data)

invocation_completed_details = data.get("InvocationCompletedDetails")

return cls(
event_type=data["EventType"],
event_timestamp=data["EventTimestamp"],
Expand Down Expand Up @@ -1479,6 +1482,7 @@ def from_dict(cls, data: dict) -> Event:
callback_succeeded_details=callback_succeeded_details,
callback_failed_details=callback_failed_details,
callback_timed_out_details=callback_timed_out_details,
invocation_completed_details=invocation_completed_details,
)

def to_dict(self) -> dict[str, Any]:
Expand Down Expand Up @@ -1563,6 +1567,8 @@ def to_dict(self) -> dict[str, Any]:
result["CallbackTimedOutDetails"] = (
self.callback_timed_out_details.to_dict()
)
if self.invocation_completed_details is not None:
result["InvocationCompletedDetails"] = self.invocation_completed_details
return result

# region execution
Expand Down
42 changes: 27 additions & 15 deletions tests/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def test_should_complete_workflow_with_error_when_invocation_fails(
failed_response = DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED, error=ErrorObject.from_message("Test error")
)
mock_invoker.invoke.return_value = failed_response
mock_invoker.invoke.return_value = (failed_response, "test-request-id")

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -329,7 +329,7 @@ def test_should_complete_workflow_with_result_when_invocation_succeeds(
success_response = DurableExecutionInvocationOutput(
status=InvocationStatus.SUCCEEDED, result="success result"
)
mock_invoker.invoke.return_value = success_response
mock_invoker.invoke.return_value = (success_response, "test-request-id")

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -372,7 +372,7 @@ def test_should_handle_pending_status_when_operations_exist(
mock_invocation_input = Mock()
mock_invoker.create_invocation_input.return_value = mock_invocation_input
pending_response = DurableExecutionInvocationOutput(status=InvocationStatus.PENDING)
mock_invoker.invoke.return_value = pending_response
mock_invoker.invoke.return_value = (pending_response, "test-request-id")

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -409,8 +409,9 @@ def test_should_ignore_response_when_execution_already_complete(

# Mock invoker - this shouldn't be called since execution is complete
mock_invoker.create_invocation_input.return_value = Mock()
mock_invoker.invoke.return_value = DurableExecutionInvocationOutput(
status=InvocationStatus.SUCCEEDED
mock_invoker.invoke.return_value = (
DurableExecutionInvocationOutput(status=InvocationStatus.SUCCEEDED),
"test-request-id",
)

# Mock execution creation and store behavior
Expand Down Expand Up @@ -452,7 +453,7 @@ def test_should_retry_when_response_has_no_status(
mock_invocation_input = Mock()
mock_invoker.create_invocation_input.return_value = mock_invocation_input
no_status_response = DurableExecutionInvocationOutput(status=None)
mock_invoker.invoke.return_value = no_status_response
mock_invoker.invoke.return_value = (no_status_response, "test-request-id")

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -495,7 +496,7 @@ def test_should_retry_when_failed_response_has_result(
invalid_response = DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED, result="should not have result"
)
mock_invoker.invoke.return_value = invalid_response
mock_invoker.invoke.return_value = (invalid_response, "test-request-id")

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -539,7 +540,7 @@ def test_should_retry_when_success_response_has_error(
status=InvocationStatus.SUCCEEDED,
error=ErrorObject.from_message("should not have error"),
)
mock_invoker.invoke.return_value = invalid_response
mock_invoker.invoke.return_value = (invalid_response, "test-request-id")

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -581,7 +582,7 @@ def test_should_retry_when_pending_response_has_no_operations(
mock_invocation_input = Mock()
mock_invoker.create_invocation_input.return_value = mock_invocation_input
pending_response = DurableExecutionInvocationOutput(status=InvocationStatus.PENDING)
mock_invoker.invoke.return_value = pending_response
mock_invoker.invoke.return_value = (pending_response, "test-request-id")

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -622,7 +623,7 @@ def test_invoke_handler_success(
mock_response = DurableExecutionInvocationOutput(
status=InvocationStatus.SUCCEEDED, result="test"
)
mock_invoker.invoke.return_value = mock_response
mock_invoker.invoke.return_value = (mock_response, "test-request-id")

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -694,7 +695,7 @@ def test_invoke_handler_execution_completed_during_invocation(
mock_invocation_input = Mock()
mock_invoker.create_invocation_input.return_value = mock_invocation_input
mock_response = Mock()
mock_invoker.invoke.return_value = mock_response
mock_invoker.invoke.return_value = (mock_response, "test-request-id")

# Create a completed execution mock
completed_execution = Mock()
Expand Down Expand Up @@ -1037,7 +1038,10 @@ def test_should_retry_invocation_when_under_limit_through_public_api(
success_response = DurableExecutionInvocationOutput(
status=InvocationStatus.SUCCEEDED, result="final success"
)
mock_invoker.invoke.side_effect = [invalid_response, success_response]
mock_invoker.invoke.side_effect = [
(invalid_response, "test-request-id-1"),
(success_response, "test-request-id-2"),
]

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -1435,7 +1439,7 @@ def test_should_retry_when_response_has_unexpected_status(
mock_invoker.create_invocation_input.return_value = mock_invocation_input
unexpected_response = Mock()
unexpected_response.status = "UNKNOWN_STATUS"
mock_invoker.invoke.return_value = unexpected_response
mock_invoker.invoke.return_value = (unexpected_response, "test-request-id")

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -1480,7 +1484,7 @@ def test_invoke_handler_execution_completed_during_invocation_async(
mock_invocation_input = Mock()
mock_invoker.create_invocation_input.return_value = mock_invocation_input
mock_response = Mock()
mock_invoker.invoke.return_value = mock_response
mock_invoker.invoke.return_value = (mock_response, "test-request-id")

# Mock execution creation
with patch(
Expand Down Expand Up @@ -1566,7 +1570,7 @@ def test_invoke_handler_general_exception_async(
success_response = DurableExecutionInvocationOutput(
status=InvocationStatus.SUCCEEDED, result="success"
)
mock_invoker.invoke.return_value = success_response
mock_invoker.invoke.return_value = (success_response, "test-request-id")

# Mock execution creation and store behavior
with patch(
Expand Down Expand Up @@ -2094,6 +2098,7 @@ def test_get_execution_history(executor, mock_store):
mock_execution = Mock()
mock_execution.operations = [] # Empty operations list
mock_execution.updates = []
mock_execution.invocation_completions = []
mock_execution.durable_execution_arn = ""
mock_execution.start_input = Mock()
mock_execution.result = Mock()
Expand Down Expand Up @@ -2123,6 +2128,7 @@ def test_get_execution_history_with_events(executor, mock_store):
mock_execution = Mock()
mock_execution.operations = [op1]
mock_execution.updates = []
mock_execution.invocation_completions = []
mock_execution.durable_execution_arn = ""
mock_execution.start_input = Mock()
mock_execution.result = Mock()
Expand All @@ -2148,6 +2154,7 @@ def test_get_execution_history_reverse_order(executor, mock_store):
mock_execution = Mock()
mock_execution.operations = [op1]
mock_execution.updates = []
mock_execution.invocation_completions = []
mock_execution.durable_execution_arn = ""
mock_execution.start_input = Mock()
mock_execution.result = Mock()
Expand Down Expand Up @@ -2178,6 +2185,7 @@ def test_get_execution_history_pagination(executor, mock_store):
mock_execution = Mock()
mock_execution.operations = operations
mock_execution.updates = []
mock_execution.invocation_completions = []
mock_execution.durable_execution_arn = ""
mock_execution.start_input = Mock()
mock_execution.result = Mock()
Expand Down Expand Up @@ -2206,6 +2214,7 @@ def test_get_execution_history_pagination_with_marker(executor, mock_store):
mock_execution = Mock()
mock_execution.operations = operations
mock_execution.updates = []
mock_execution.invocation_completions = []
mock_execution.durable_execution_arn = ""
mock_execution.start_input = Mock()
mock_execution.result = Mock()
Expand All @@ -2223,6 +2232,7 @@ def test_get_execution_history_invalid_marker(executor, mock_store):
mock_execution = Mock()
mock_execution.operations = []
mock_execution.updates = []
mock_execution.invocation_completions = []
mock_execution.durable_execution_arn = ""
mock_execution.start_input = Mock()
mock_execution.result = Mock()
Expand Down Expand Up @@ -2399,6 +2409,7 @@ def test_send_callback_heartbeat(executor, mock_store):
mock_operation.status = OperationStatus.STARTED
mock_execution.find_callback_operation.return_value = (0, mock_operation)
mock_execution.updates = [] # No callback options to reset timeout
mock_execution.invocation_completions = []
mock_store.load.return_value = mock_execution

result = executor.send_callback_heartbeat(callback_id)
Expand Down Expand Up @@ -2651,6 +2662,7 @@ def test_schedule_callback_timeouts_no_callback_options(executor, mock_store):
mock_execution = Mock()
mock_execution.find_operation.return_value = (0, operation)
mock_execution.updates = [] # No updates with callback options
mock_execution.invocation_completions = []
mock_store.load.return_value = mock_execution

# Should return early without scheduling
Expand Down
Loading
Loading