Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CheckpointUpdatedExecutionState,
OperationUpdate,
StateOutput,
Operation,
)

from aws_durable_execution_sdk_python_testing.checkpoint.transformer import (
Expand Down Expand Up @@ -88,14 +89,37 @@ def process_checkpoint(
def get_execution_state(
self,
checkpoint_token: str,
next_marker: str, # noqa: ARG002
max_items: int = 1000, # noqa: ARG002
next_marker: str | None = None,
max_items: int = 1000,
) -> StateOutput:
"""Get current execution state."""
"""Get current execution state with batched checkpoint token validation and pagination."""
if not checkpoint_token:
msg: str = "Checkpoint token is required"
raise InvalidParameterValueException(msg)

token: CheckpointToken = CheckpointToken.from_str(checkpoint_token)
execution: Execution = self._store.load(token.execution_arn)
execution.validate_checkpoint_token(checkpoint_token)

# Get all operations
all_operations: list[Operation] = execution.get_navigable_operations()

# Apply pagination
start_index: int = 0
if next_marker:
try:
start_index = int(next_marker)
except ValueError:
start_index = 0

end_index: int = start_index + max_items
paginated_operations: list[Operation] = all_operations[start_index:end_index]

# Determine next marker
next_marker_result: str | None = (
str(end_index) if end_index < len(all_operations) else None
)

# TODO: paging when size or max
return StateOutput(
operations=execution.get_navigable_operations(), next_marker=None
operations=paginated_operations, next_marker=next_marker_result
)
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,8 @@ def process(
)
case _:
# intentional. actual service will fail any EXECUTION update that is not SUCCEED.
error = (
update.error
if update.error
else ErrorObject.from_message(
"There is no error details but EXECUTION checkpoint action is not SUCCEED."
)
error = update.error or ErrorObject.from_message(
"There is no error details but EXECUTION checkpoint action is not SUCCEED."
)
# All EXECUTION failures go through normal fail path
# Timeout/Stop status is set by executor based on the operation that caused it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
self,
processors: MutableMapping[OperationType, OperationProcessor] | None = None,
):
self.processors = processors if processors else self._DEFAULT_PROCESSORS
self.processors = processors or self._DEFAULT_PROCESSORS

def process_updates(
self,
Expand Down
33 changes: 29 additions & 4 deletions src/aws_durable_execution_sdk_python_testing/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self.start_input: StartDurableExecutionInput = start_input
self.operations: list[Operation] = operations
self.updates: list[OperationUpdate] = []
self.used_tokens: set[str] = set()
self.generated_tokens: set[str] = set()
# TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
self._token_sequence: int = 0
self._state_lock: Lock = Lock()
Expand Down Expand Up @@ -101,7 +101,7 @@ def to_dict(self) -> dict[str, Any]:
"StartInput": self.start_input.to_dict(),
"Operations": [op.to_dict() for op in self.operations],
"Updates": [update.to_dict() for update in self.updates],
"UsedTokens": list(self.used_tokens),
"GeneratedTokens": list(self.generated_tokens),
"TokenSequence": self._token_sequence,
"IsComplete": self.is_complete,
"Result": self.result.to_dict() if self.result else None,
Expand Down Expand Up @@ -129,7 +129,7 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
execution.updates = [
OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
]
execution.used_tokens = set(data["UsedTokens"])
execution.generated_tokens = set(data["GeneratedTokens"])
execution._token_sequence = data["TokenSequence"] # noqa: SLF001
execution.is_complete = data["IsComplete"]
execution.result = (
Expand Down Expand Up @@ -184,13 +184,38 @@ def get_new_checkpoint_token(self) -> str:
token_sequence=new_token_sequence,
)
token_str = token.to_str()
self.used_tokens.add(token_str)
self.generated_tokens.add(token_str)
return token_str

def get_navigable_operations(self) -> list[Operation]:
"""Get list of operations, but exclude child operations where the parent has already completed."""
return self.operations

def validate_checkpoint_token(
self,
token: str | None,
checkpoint_required_msg: str | None = None,
) -> None:
"""Validate checkpoint token against this execution."""
if not token:
msg: str = checkpoint_required_msg or "Checkpoint token is required"
raise InvalidParameterValueException(msg)

checkpoint_token: CheckpointToken = CheckpointToken.from_str(token)
if checkpoint_token.execution_arn != self.durable_execution_arn:
msg = "Checkpoint token does not match execution ARN"
raise InvalidParameterValueException(msg)

if self.is_complete or checkpoint_token.token_sequence > self.token_sequence:
msg = "Invalid or expired checkpoint token"
raise InvalidParameterValueException(msg)

# Check if token has been generated
token_str: str = checkpoint_token.to_str()
if token_str not in self.generated_tokens:
msg = f"Invalid checkpoint token: {token_str}"
raise InvalidParameterValueException(msg)

def get_assertable_operations(self) -> list[Operation]:
"""Get list of operations, but exclude the EXECUTION operations"""
# TODO: this excludes EXECUTION at start, but can there be an EXECUTION at the end if there was a checkpoint with large payload?
Expand Down
34 changes: 18 additions & 16 deletions src/aws_durable_execution_sdk_python_testing/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable

from aws_durable_execution_sdk_python.lambda_service import Operation

from aws_durable_execution_sdk_python_testing.checkpoint.processor import (
CheckpointProcessor,
)
Expand Down Expand Up @@ -347,32 +349,33 @@ def get_execution_state(
ResourceNotFoundException: If execution does not exist
InvalidParameterValueException: If checkpoint token is invalid
"""
execution = self.get_execution(execution_arn)
execution: Execution = self.get_execution(execution_arn)
is_checkpoint_required: bool = not execution.is_complete and marker is not None

# TODO: Validate checkpoint token if provided
if checkpoint_token and checkpoint_token not in execution.used_tokens:
msg: str = f"Invalid checkpoint token: {checkpoint_token}"
raise InvalidParameterValueException(msg)
if is_checkpoint_required or checkpoint_token:
checkpoint_required_msg: str = "Checkpoint token is required for paginated requests on active executions"
execution.validate_checkpoint_token(
checkpoint_token, checkpoint_required_msg
)

# Get operations (excluding the initial EXECUTION operation for state)
operations = execution.get_assertable_operations()
operations: list[Operation] = execution.get_assertable_operations()

# Apply pagination
if max_items is None:
max_items = 100

# Simple pagination - in real implementation would need proper marker handling
start_index = 0
start_index: int = 0
if marker:
try:
start_index = int(marker)
except ValueError:
start_index = 0

end_index = start_index + max_items
paginated_operations = operations[start_index:end_index]
end_index: int = start_index + max_items
paginated_operations: list[Operation] = operations[start_index:end_index]

next_marker = None
next_marker: str | None = None
if end_index < len(operations):
next_marker = str(end_index)

Expand Down Expand Up @@ -541,11 +544,10 @@ def checkpoint_execution(
InvalidParameterValueException: If checkpoint token is invalid
"""
execution = self.get_execution(execution_arn)

# Validate checkpoint token
if checkpoint_token not in execution.used_tokens:
msg: str = f"Invalid checkpoint token: {checkpoint_token}"
raise InvalidParameterValueException(msg)
execution.validate_checkpoint_token(
checkpoint_token,
checkpoint_required_msg="Checkpoint token is required for checkpoint operations",
)

if updates:
checkpoint_output = self._checkpoint_processor.process_checkpoint(
Expand Down
8 changes: 4 additions & 4 deletions tests/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_execution_init():
assert execution.start_input == start_input
assert execution.operations == operations
assert execution.updates == []
assert execution.used_tokens == set()
assert execution.generated_tokens == set()
assert execution.token_sequence == 0
assert execution.is_complete is False
assert execution.consecutive_failed_invocation_attempts == 0
Expand Down Expand Up @@ -154,8 +154,8 @@ def test_get_new_checkpoint_token():
token2 = execution.get_new_checkpoint_token()

assert execution.token_sequence == 2
assert token1 in execution.used_tokens
assert token2 in execution.used_tokens
assert token1 in execution.generated_tokens
assert token2 in execution.generated_tokens
assert token1 != token2


Expand Down Expand Up @@ -801,7 +801,7 @@ def test_from_dict_with_none_result():
"StartInput": {"function_name": "test"},
"Operations": [],
"Updates": [],
"UsedTokens": [],
"GeneratedTokens": [],
"TokenSequence": 0,
"IsComplete": False,
"Result": None, # None result
Expand Down
Loading
Loading