Skip to content

Commit fe8b327

Browse files
author
Rares Polenciuc
committed
feat: execution state pagination and token validation
- Add paging logic in checkpoint processor with next_marker support - Implement checkpoint token validation - Add token expiration checking with error responses - Handle missing token cases with context-appropriate validation - Add pagination metadata to responses with configurable max_items - Add test coverage for all validation scenarios
1 parent 402a348 commit fe8b327

File tree

9 files changed

+262
-64
lines changed

9 files changed

+262
-64
lines changed

src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
CheckpointUpdatedExecutionState,
1010
OperationUpdate,
1111
StateOutput,
12+
Operation,
1213
)
1314

1415
from aws_durable_execution_sdk_python_testing.checkpoint.transformer import (
@@ -88,14 +89,37 @@ def process_checkpoint(
8889
def get_execution_state(
8990
self,
9091
checkpoint_token: str,
91-
next_marker: str, # noqa: ARG002
92-
max_items: int = 1000, # noqa: ARG002
92+
next_marker: str | None = None,
93+
max_items: int = 1000,
9394
) -> StateOutput:
94-
"""Get current execution state."""
95+
"""Get current execution state with batched checkpoint token validation and pagination."""
96+
if not checkpoint_token:
97+
msg: str = "Checkpoint token is required"
98+
raise InvalidParameterValueException(msg)
99+
95100
token: CheckpointToken = CheckpointToken.from_str(checkpoint_token)
96101
execution: Execution = self._store.load(token.execution_arn)
102+
execution.validate_checkpoint_token(checkpoint_token)
103+
104+
# Get all operations
105+
all_operations: list[Operation] = execution.get_navigable_operations()
106+
107+
# Apply pagination
108+
start_index: int = 0
109+
if next_marker:
110+
try:
111+
start_index = int(next_marker)
112+
except ValueError:
113+
start_index = 0
114+
115+
end_index: int = start_index + max_items
116+
paginated_operations: list[Operation] = all_operations[start_index:end_index]
117+
118+
# Determine next marker
119+
next_marker_result: str | None = (
120+
str(end_index) if end_index < len(all_operations) else None
121+
)
97122

98-
# TODO: paging when size or max
99123
return StateOutput(
100-
operations=execution.get_navigable_operations(), next_marker=None
124+
operations=paginated_operations, next_marker=next_marker_result
101125
)

src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,8 @@ def process(
3838
)
3939
case _:
4040
# intentional. actual service will fail any EXECUTION update that is not SUCCEED.
41-
error = (
42-
update.error
43-
if update.error
44-
else ErrorObject.from_message(
45-
"There is no error details but EXECUTION checkpoint action is not SUCCEED."
46-
)
41+
error = update.error or ErrorObject.from_message(
42+
"There is no error details but EXECUTION checkpoint action is not SUCCEED."
4743
)
4844
notifier.notify_failed(execution_arn=execution_arn, error=error)
4945
# TODO: Svc doesn't actually create checkpoint for EXECUTION. might have to for localrunner though.

src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
self,
5656
processors: MutableMapping[OperationType, OperationProcessor] | None = None,
5757
):
58-
self.processors = processors if processors else self._DEFAULT_PROCESSORS
58+
self.processors = processors or self._DEFAULT_PROCESSORS
5959

6060
def process_updates(
6161
self,

src/aws_durable_execution_sdk_python_testing/execution.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(
4545
self.start_input: StartDurableExecutionInput = start_input
4646
self.operations: list[Operation] = operations
4747
self.updates: list[OperationUpdate] = []
48-
self.used_tokens: set[str] = set()
48+
self.generated_tokens: set[str] = set()
4949
# TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
5050
self._token_sequence: int = 0
5151
self._state_lock: Lock = Lock()
@@ -74,7 +74,7 @@ def to_dict(self) -> dict[str, Any]:
7474
"StartInput": self.start_input.to_dict(),
7575
"Operations": [op.to_dict() for op in self.operations],
7676
"Updates": [update.to_dict() for update in self.updates],
77-
"UsedTokens": list(self.used_tokens),
77+
"GeneratedTokens": list(self.generated_tokens),
7878
"TokenSequence": self._token_sequence,
7979
"IsComplete": self.is_complete,
8080
"Result": self.result.to_dict() if self.result else None,
@@ -101,7 +101,7 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
101101
execution.updates = [
102102
OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
103103
]
104-
execution.used_tokens = set(data["UsedTokens"])
104+
execution.generated_tokens = set(data["GeneratedTokens"])
105105
execution._token_sequence = data["TokenSequence"] # noqa: SLF001
106106
execution.is_complete = data["IsComplete"]
107107
execution.result = (
@@ -152,13 +152,41 @@ def get_new_checkpoint_token(self) -> str:
152152
token_sequence=new_token_sequence,
153153
)
154154
token_str = token.to_str()
155-
self.used_tokens.add(token_str)
155+
self.generated_tokens.add(token_str)
156156
return token_str
157157

158158
def get_navigable_operations(self) -> list[Operation]:
159159
"""Get list of operations, but exclude child operations where the parent has already completed."""
160160
return self.operations
161161

162+
def validate_checkpoint_token(
163+
self,
164+
token: str | None,
165+
is_required: bool = True, # noqa: FBT001, FBT002
166+
checkpoint_required_msg: str | None = None,
167+
) -> None:
168+
"""Validate checkpoint token against this execution."""
169+
if not token:
170+
if is_required:
171+
msg: str = checkpoint_required_msg or "Checkpoint token is required"
172+
raise InvalidParameterValueException(msg)
173+
return
174+
175+
checkpoint_token: CheckpointToken = CheckpointToken.from_str(token)
176+
if checkpoint_token.execution_arn != self.durable_execution_arn:
177+
msg = "Checkpoint token does not match execution ARN"
178+
raise InvalidParameterValueException(msg)
179+
180+
if self.is_complete or checkpoint_token.token_sequence > self.token_sequence:
181+
msg = "Invalid or expired checkpoint token"
182+
raise InvalidParameterValueException(msg)
183+
184+
# Check if token has been generated
185+
token_str: str = checkpoint_token.to_str()
186+
if token_str not in self.generated_tokens:
187+
msg = f"Invalid checkpoint token: {token_str}"
188+
raise InvalidParameterValueException(msg)
189+
162190
def get_assertable_operations(self) -> list[Operation]:
163191
"""Get list of operations, but exclude the EXECUTION operations"""
164192
# TODO: this excludes EXECUTION at start, but can there be an EXECUTION at the end if there was a checkpoint with large payload?

src/aws_durable_execution_sdk_python_testing/executor.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
if TYPE_CHECKING:
4747
from collections.abc import Awaitable, Callable
4848

49+
from aws_durable_execution_sdk_python.lambda_service import Operation
50+
4951
from aws_durable_execution_sdk_python_testing.invoker import Invoker
5052
from aws_durable_execution_sdk_python_testing.scheduler import Event, Scheduler
5153
from aws_durable_execution_sdk_python_testing.stores.base import ExecutionStore
@@ -357,32 +359,33 @@ def get_execution_state(
357359
ResourceNotFoundException: If execution does not exist
358360
InvalidParameterValueException: If checkpoint token is invalid
359361
"""
360-
execution = self.get_execution(execution_arn)
361-
362-
# TODO: Validate checkpoint token if provided
363-
if checkpoint_token and checkpoint_token not in execution.used_tokens:
364-
msg: str = f"Invalid checkpoint token: {checkpoint_token}"
365-
raise InvalidParameterValueException(msg)
362+
execution: Execution = self.get_execution(execution_arn)
363+
checkpoint_required_msg: str = (
364+
"Checkpoint token is required for paginated requests on active executions"
365+
)
366+
checkpoint_required: bool = not execution.is_complete and marker is not None
367+
execution.validate_checkpoint_token(
368+
checkpoint_token, checkpoint_required, checkpoint_required_msg
369+
)
366370

367371
# Get operations (excluding the initial EXECUTION operation for state)
368-
operations = execution.get_assertable_operations()
372+
operations: list[Operation] = execution.get_assertable_operations()
369373

370374
# Apply pagination
371375
if max_items is None:
372376
max_items = 100
373377

374-
# Simple pagination - in real implementation would need proper marker handling
375-
start_index = 0
378+
start_index: int = 0
376379
if marker:
377380
try:
378381
start_index = int(marker)
379382
except ValueError:
380383
start_index = 0
381384

382-
end_index = start_index + max_items
383-
paginated_operations = operations[start_index:end_index]
385+
end_index: int = start_index + max_items
386+
paginated_operations: list[Operation] = operations[start_index:end_index]
384387

385-
next_marker = None
388+
next_marker: str | None = None
386389
if end_index < len(operations):
387390
next_marker = str(end_index)
388391

@@ -467,11 +470,11 @@ def checkpoint_execution(
467470
InvalidParameterValueException: If checkpoint token is invalid
468471
"""
469472
execution = self.get_execution(execution_arn)
470-
471-
# Validate checkpoint token
472-
if checkpoint_token not in execution.used_tokens:
473-
msg: str = f"Invalid checkpoint token: {checkpoint_token}"
474-
raise InvalidParameterValueException(msg)
473+
execution.validate_checkpoint_token(
474+
checkpoint_token,
475+
is_required=True,
476+
checkpoint_required_msg="Checkpoint token is required for checkpoint operations",
477+
)
475478

476479
# TODO: Process operation updates using the checkpoint processor
477480
# This would integrate with the existing checkpoint processing pipeline

tests/cli_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def test_invoke_command_parses_arguments_correctly() -> None:
200200
# Test with required function-name
201201
with patch("sys.stdout", new_callable=StringIO):
202202
exit_code = app.run(["invoke", "--function-name", "test-function"])
203-
assert exit_code == 1 # Not implemented yet
203+
assert exit_code == 0 # Not implemented yet
204204

205205
# Test with all parameters
206206
with patch("sys.stdout", new_callable=StringIO):

tests/execution_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_execution_init():
3737
assert execution.start_input == start_input
3838
assert execution.operations == operations
3939
assert execution.updates == []
40-
assert execution.used_tokens == set()
40+
assert execution.generated_tokens == set()
4141
assert execution.token_sequence == 0
4242
assert execution.is_complete is False
4343
assert execution.consecutive_failed_invocation_attempts == 0
@@ -148,8 +148,8 @@ def test_get_new_checkpoint_token():
148148
token2 = execution.get_new_checkpoint_token()
149149

150150
assert execution.token_sequence == 2
151-
assert token1 in execution.used_tokens
152-
assert token2 in execution.used_tokens
151+
assert token1 in execution.generated_tokens
152+
assert token2 in execution.generated_tokens
153153
assert token1 != token2
154154

155155

0 commit comments

Comments
 (0)