Skip to content

Commit 900bb3d

Browse files
author
Rares Polenciuc
committed
feat: execution state pagination and token validation
- Add paging logic in checkpoint processor with next_marker support - Implement checkpoint token validation - Add token expiration checking with error responses - Handle missing token cases with context-appropriate validation - Add pagination metadata to responses with configurable max_items - Add test coverage for all validation scenarios
1 parent 6954160 commit 900bb3d

File tree

5 files changed

+222
-25
lines changed

5 files changed

+222
-25
lines changed

src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,35 @@ def process_checkpoint(
8888
def get_execution_state(
8989
self,
9090
checkpoint_token: str,
91-
next_marker: str, # noqa: ARG002
92-
max_items: int = 1000, # noqa: ARG002
91+
next_marker: str | None = None,
92+
max_items: int = 1000,
9393
) -> StateOutput:
94-
"""Get current execution state."""
94+
"""Get current execution state with batched checkpoint token validation and pagination."""
95+
if not checkpoint_token:
96+
msg: str = "Checkpoint token is required"
97+
raise InvalidParameterValueException(msg)
98+
9599
token: CheckpointToken = CheckpointToken.from_str(checkpoint_token)
96100
execution: Execution = self._store.load(token.execution_arn)
101+
token.validate_for_execution(execution)
102+
103+
# Get all operations
104+
all_operations = execution.get_navigable_operations()
105+
106+
# Apply pagination
107+
start_index = 0
108+
if next_marker:
109+
try:
110+
start_index = int(next_marker)
111+
except ValueError:
112+
start_index = 0
113+
114+
end_index = start_index + max_items
115+
paginated_operations = all_operations[start_index:end_index]
116+
117+
# Determine next marker
118+
next_marker_result = str(end_index) if end_index < len(all_operations) else None
97119

98-
# TODO: paging when size or max
99120
return StateOutput(
100-
operations=execution.get_navigable_operations(), next_marker=None
121+
operations=paginated_operations, next_marker=next_marker_result
101122
)

src/aws_durable_execution_sdk_python_testing/executor.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
Execution as ExecutionSummary,
4242
)
4343
from aws_durable_execution_sdk_python_testing.observer import ExecutionObserver
44+
from aws_durable_execution_sdk_python_testing.token import CheckpointToken
4445

4546

4647
if TYPE_CHECKING:
@@ -359,9 +360,12 @@ def get_execution_state(
359360
"""
360361
execution = self.get_execution(execution_arn)
361362

362-
# TODO: Validate checkpoint token if provided
363-
if checkpoint_token and checkpoint_token not in execution.used_tokens:
364-
msg: str = f"Invalid checkpoint token: {checkpoint_token}"
363+
# Checkpoint token validation
364+
if checkpoint_token:
365+
token = CheckpointToken.from_str(checkpoint_token)
366+
token.validate_for_execution(execution)
367+
elif not execution.is_complete and marker is not None:
368+
msg = "Checkpoint token is required for paginated requests on active executions"
365369
raise InvalidParameterValueException(msg)
366370

367371
# Get operations (excluding the initial EXECUTION operation for state)
@@ -371,7 +375,6 @@ def get_execution_state(
371375
if max_items is None:
372376
max_items = 100
373377

374-
# Simple pagination - in real implementation would need proper marker handling
375378
start_index = 0
376379
if marker:
377380
try:
@@ -468,11 +471,16 @@ def checkpoint_execution(
468471
"""
469472
execution = self.get_execution(execution_arn)
470473

471-
# Validate checkpoint token
472-
if checkpoint_token not in execution.used_tokens:
473-
msg: str = f"Invalid checkpoint token: {checkpoint_token}"
474+
# Comprehensive checkpoint token validation
475+
from aws_durable_execution_sdk_python_testing.token import CheckpointToken
476+
477+
if not checkpoint_token:
478+
msg = "Checkpoint token is required for checkpoint operations"
474479
raise InvalidParameterValueException(msg)
475480

481+
token = CheckpointToken.from_str(checkpoint_token)
482+
token.validate_for_execution(execution)
483+
476484
# TODO: Process operation updates using the checkpoint processor
477485
# This would integrate with the existing checkpoint processing pipeline
478486

src/aws_durable_execution_sdk_python_testing/token.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
import json
77
from dataclasses import dataclass
88

9+
from aws_durable_execution_sdk_python_testing.exceptions import (
10+
InvalidParameterValueException,
11+
)
12+
913

1014
@dataclass(frozen=True)
1115
class CheckpointToken:
@@ -27,6 +31,23 @@ def from_str(cls, token: str) -> CheckpointToken:
2731
data = json.loads(decoded)
2832
return cls(execution_arn=data["arn"], token_sequence=data["seq"])
2933

34+
def validate_for_execution(self, execution) -> None:
35+
"""Validate token against execution"""
36+
37+
if self.execution_arn != execution.durable_execution_arn:
38+
msg = "Checkpoint token does not match execution ARN"
39+
raise InvalidParameterValueException(msg)
40+
41+
if execution.is_complete or self.token_sequence > execution.token_sequence:
42+
msg = "Invalid or expired checkpoint token"
43+
raise InvalidParameterValueException(msg)
44+
45+
# Check if token has been used
46+
token_str = self.to_str()
47+
if token_str not in execution.used_tokens:
48+
msg = f"Invalid checkpoint token: {token_str}"
49+
raise InvalidParameterValueException(msg)
50+
3051

3152
@dataclass(frozen=True)
3253
class CallbackToken:

tests/executor_test.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
StartDurableExecutionInput,
3434
)
3535
from aws_durable_execution_sdk_python_testing.observer import ExecutionObserver
36+
from aws_durable_execution_sdk_python_testing.token import CheckpointToken
3637

3738

3839
class MockExecutionObserver(ExecutionObserver):
@@ -1889,9 +1890,15 @@ def test_stop_execution_with_custom_error(executor, mock_store):
18891890

18901891
def test_get_execution_state(executor, mock_store):
18911892
"""Test get_execution_state method."""
1892-
18931893
mock_execution = Mock()
1894-
mock_execution.used_tokens = {"token1", "token2"}
1894+
mock_execution.durable_execution_arn = "test-arn"
1895+
mock_execution.token_sequence = 5
1896+
mock_execution.is_complete = False
1897+
1898+
# Create valid token and add to used_tokens
1899+
token = CheckpointToken("test-arn", 3)
1900+
valid_token = token.to_str()
1901+
mock_execution.used_tokens = {valid_token}
18951902

18961903
# Create mock operations
18971904
operations = [
@@ -1916,7 +1923,7 @@ def test_get_execution_state(executor, mock_store):
19161923

19171924
mock_store.load.return_value = mock_execution
19181925

1919-
result = executor.get_execution_state("test-arn", checkpoint_token="token1") # noqa: S106
1926+
result = executor.get_execution_state("test-arn", checkpoint_token=valid_token)
19201927

19211928
assert len(result.operations) == 2
19221929
assert result.next_marker is None
@@ -1928,11 +1935,13 @@ def test_get_execution_state_invalid_token(executor, mock_store):
19281935
mock_execution = Mock()
19291936
mock_execution.used_tokens = {"token1", "token2"}
19301937
mock_store.load.return_value = mock_execution
1931-
1938+
token = CheckpointToken("invalid-token", 3)
1939+
invalid_token = token.to_str()
19321940
with pytest.raises(
1933-
InvalidParameterValueException, match="Invalid checkpoint token"
1941+
InvalidParameterValueException,
1942+
match="Checkpoint token does not match execution ARN",
19341943
):
1935-
executor.get_execution_state("test-arn", checkpoint_token="invalid-token") # noqa: S106
1944+
executor.get_execution_state("test-arn", checkpoint_token=invalid_token)
19361945

19371946

19381947
def test_get_execution_history(executor, mock_store):
@@ -1950,13 +1959,22 @@ def test_get_execution_history(executor, mock_store):
19501959
def test_checkpoint_execution(executor, mock_store):
19511960
"""Test checkpoint_execution method."""
19521961
mock_execution = Mock()
1953-
mock_execution.used_tokens = {"token1", "token2"}
1954-
mock_execution.get_new_checkpoint_token.return_value = "new-token"
1962+
mock_execution.durable_execution_arn = "test-arn"
1963+
mock_execution.token_sequence = 5
1964+
mock_execution.is_complete = False
1965+
new_token = "new-token" # noqa:S105
1966+
mock_execution.get_new_checkpoint_token.return_value = new_token
1967+
1968+
# Create valid token and add to used_tokens
1969+
token = CheckpointToken("test-arn", 3)
1970+
valid_token = token.to_str()
1971+
mock_execution.used_tokens = {valid_token}
1972+
19551973
mock_store.load.return_value = mock_execution
19561974

1957-
result = executor.checkpoint_execution("test-arn", "token1")
1975+
result = executor.checkpoint_execution("test-arn", valid_token)
19581976

1959-
assert result.checkpoint_token == "new-token" # noqa: S105
1977+
assert result.checkpoint_token == new_token
19601978
assert result.new_execution_state is None
19611979
mock_store.load.assert_called_once_with("test-arn")
19621980
mock_execution.get_new_checkpoint_token.assert_called_once()
@@ -1967,11 +1985,13 @@ def test_checkpoint_execution_invalid_token(executor, mock_store):
19671985
mock_execution = Mock()
19681986
mock_execution.used_tokens = {"token1", "token2"}
19691987
mock_store.load.return_value = mock_execution
1970-
1988+
token = CheckpointToken("invalid-token", 3)
1989+
invalid_token = token.to_str()
19711990
with pytest.raises(
1972-
InvalidParameterValueException, match="Invalid checkpoint token"
1991+
InvalidParameterValueException,
1992+
match="Checkpoint token does not match execution ARN",
19731993
):
1974-
executor.checkpoint_execution("test-arn", "invalid-token")
1994+
executor.checkpoint_execution("test-arn", invalid_token)
19751995

19761996

19771997
# Callback method tests
@@ -2050,3 +2070,35 @@ def test_send_callback_heartbeat_none_callback_id(executor):
20502070
"""Test send_callback_heartbeat with None callback_id."""
20512071
with pytest.raises(InvalidParameterValueException, match="callback_id is required"):
20522072
executor.send_callback_heartbeat(None)
2073+
2074+
2075+
def test_get_execution_state_no_token_with_marker_active_execution(
2076+
mock_store, mock_scheduler, mock_invoker
2077+
):
2078+
"""Test get_execution_state fails when no token provided with marker on active execution."""
2079+
executor = Executor(mock_store, mock_scheduler, mock_invoker)
2080+
execution_arn = "test-arn"
2081+
2082+
# Create an active execution
2083+
execution = Execution(execution_arn, "test-function", {})
2084+
execution.is_complete = False
2085+
mock_store.load.return_value = execution
2086+
2087+
with pytest.raises(
2088+
InvalidParameterValueException, match="Checkpoint token is required"
2089+
):
2090+
executor.get_execution_state(execution_arn, marker="some-marker")
2091+
2092+
2093+
def test_checkpoint_execution_no_token(mock_store, mock_scheduler, mock_invoker):
2094+
"""Test checkpoint_execution fails when no token provided."""
2095+
executor = Executor(mock_store, mock_scheduler, mock_invoker)
2096+
execution_arn = "test-arn"
2097+
2098+
execution = Execution(execution_arn, "test-function", {})
2099+
mock_store.load.return_value = execution
2100+
2101+
with pytest.raises(
2102+
InvalidParameterValueException, match="Checkpoint token is required"
2103+
):
2104+
executor.checkpoint_execution(execution_arn, "", [], "client-token")

tests/token_test.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22

33
import base64
44
import json
5+
from unittest.mock import Mock
56

67
import pytest
78

9+
from aws_durable_execution_sdk_python_testing.exceptions import (
10+
InvalidParameterValueException,
11+
)
812
from aws_durable_execution_sdk_python_testing.token import (
913
CallbackToken,
1014
CheckpointToken,
@@ -130,3 +134,94 @@ def test_callback_token_frozen_dataclass():
130134

131135
with pytest.raises(AttributeError):
132136
token.operation_id = "new-op"
137+
138+
139+
def test_checkpoint_token_validate_for_execution_success():
140+
"""Test successful token validation."""
141+
token = CheckpointToken("test-arn", 5)
142+
execution = Mock()
143+
execution.durable_execution_arn = "test-arn"
144+
execution.token_sequence = 10
145+
execution.is_complete = False
146+
execution.used_tokens = {token.to_str()}
147+
148+
token.validate_for_execution(execution)
149+
150+
151+
def test_checkpoint_token_validate_for_execution_arn_mismatch():
152+
"""Test token validation fails when ARN doesn't match."""
153+
token = CheckpointToken("test-arn", 5)
154+
execution = Mock()
155+
execution.durable_execution_arn = "different-arn"
156+
execution.token_sequence = 10
157+
execution.is_complete = False
158+
159+
with pytest.raises(
160+
InvalidParameterValueException, match="does not match execution ARN"
161+
):
162+
token.validate_for_execution(execution)
163+
164+
165+
def test_checkpoint_token_validate_for_execution_completed():
166+
"""Test token validation fails when execution is complete."""
167+
token = CheckpointToken("test-arn", 5)
168+
execution = Mock()
169+
execution.durable_execution_arn = "test-arn"
170+
execution.token_sequence = 10
171+
execution.is_complete = True
172+
173+
with pytest.raises(InvalidParameterValueException, match="Invalid or expired"):
174+
token.validate_for_execution(execution)
175+
176+
177+
def test_checkpoint_token_validate_for_execution_future_sequence():
178+
"""Test token validation fails when token sequence is from future."""
179+
token = CheckpointToken("test-arn", 15)
180+
execution = Mock()
181+
execution.durable_execution_arn = "test-arn"
182+
execution.token_sequence = 10
183+
execution.is_complete = False
184+
185+
with pytest.raises(InvalidParameterValueException, match="Invalid or expired"):
186+
token.validate_for_execution(execution)
187+
188+
189+
def test_checkpoint_token_validate_for_execution_equal_sequence():
190+
"""Test token validation succeeds when sequences are equal."""
191+
token = CheckpointToken("test-arn", 10)
192+
execution = Mock()
193+
execution.durable_execution_arn = "test-arn"
194+
execution.token_sequence = 10
195+
execution.is_complete = False
196+
execution.used_tokens = {token.to_str()}
197+
198+
token.validate_for_execution(execution)
199+
200+
201+
def test_checkpoint_token_validate_for_execution_not_in_used_tokens():
202+
"""Test token validation fails when token not in used_tokens."""
203+
token = CheckpointToken("test-arn", 5)
204+
execution = Mock()
205+
execution.durable_execution_arn = "test-arn"
206+
execution.token_sequence = 10
207+
execution.is_complete = False
208+
execution.used_tokens = {"other-token"}
209+
210+
with pytest.raises(
211+
InvalidParameterValueException, match="Invalid checkpoint token"
212+
):
213+
token.validate_for_execution(execution)
214+
215+
216+
def test_checkpoint_token_validate_for_execution_in_used_tokens():
217+
"""Test token validation succeeds when token is in used_tokens."""
218+
token = CheckpointToken("test-arn", 5)
219+
execution = Mock()
220+
execution.durable_execution_arn = "test-arn"
221+
execution.token_sequence = 10
222+
execution.is_complete = False
223+
# Mock the token string that would be generated
224+
token_str = token.to_str()
225+
execution.used_tokens = {token_str}
226+
227+
token.validate_for_execution(execution)

0 commit comments

Comments
 (0)