Skip to content

Commit 1d603c5

Browse files
Alex Wangwangyb-A
authored andcommitted
fix: fix callback serdes
- Add a new passthrough serdes - If the customer does not provide customized serdes for callback handler, use passthrough serdes for callback result because they are not created by sdk, instead, they are created by backend with customer data.
1 parent a04015e commit 1d603c5

File tree

4 files changed

+51
-2
lines changed

4 files changed

+51
-2
lines changed

src/aws_durable_execution_sdk_python/context.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@
3636
from aws_durable_execution_sdk_python.operation.wait_for_condition import (
3737
wait_for_condition_handler,
3838
)
39-
from aws_durable_execution_sdk_python.serdes import SerDes, deserialize
39+
from aws_durable_execution_sdk_python.serdes import (
40+
PassThroughSerDes,
41+
SerDes,
42+
deserialize,
43+
)
4044
from aws_durable_execution_sdk_python.state import ExecutionState # noqa: TCH001
4145
from aws_durable_execution_sdk_python.threading import OrderedCounter
4246
from aws_durable_execution_sdk_python.types import (
@@ -66,6 +70,8 @@
6670

6771
logger = logging.getLogger(__name__)
6872

73+
PASS_THROUGH_SERDES: SerDes[Any] = PassThroughSerDes()
74+
6975

7076
def durable_step(
7177
func: Callable[Concatenate[StepContext, Params], T],
@@ -144,7 +150,7 @@ def result(self) -> T | None:
144150
return None # type: ignore
145151

146152
return deserialize(
147-
serdes=self.serdes,
153+
serdes=self.serdes if self.serdes is not None else PASS_THROUGH_SERDES,
148154
data=checkpointed_result.result,
149155
operation_id=self.operation_id,
150156
durable_execution_arn=self.state.durable_execution_arn,

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,14 @@ def is_primitive(obj: Any) -> bool:
372372
return False
373373

374374

375+
class PassThroughSerDes(SerDes[T]):
376+
def serialize(self, value: T, _: SerDesContext) -> str: # noqa: PLR6301
377+
return value # type: ignore
378+
379+
def deserialize(self, data: str, _: SerDesContext) -> T: # noqa: PLR6301
380+
return data # type: ignore
381+
382+
375383
class JsonSerDes(SerDes[T]):
376384
def serialize(self, value: T, _: SerDesContext) -> str: # noqa: PLR6301
377385
return json.dumps(value)

tests/context_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,28 @@ def test_callback_result_succeeded():
7575
callback = Callback("callback1", "op1", mock_state)
7676
result = callback.result()
7777

78+
assert result == '"success_result"'
79+
mock_state.get_checkpoint_result.assert_called_once_with("op1")
80+
81+
82+
def test_callback_result_succeeded_with_plain_str():
83+
"""Test Callback.result() when operation succeeded."""
84+
mock_state = Mock(spec=ExecutionState)
85+
mock_state.durable_execution_arn = "test_arn"
86+
operation = Operation(
87+
operation_id="op1",
88+
operation_type=OperationType.CALLBACK,
89+
status=OperationStatus.SUCCEEDED,
90+
callback_details=CallbackDetails(
91+
callback_id="callback1", result="success_result"
92+
),
93+
)
94+
mock_result = CheckpointedResult.create_from_operation(operation)
95+
mock_state.get_checkpoint_result.return_value = mock_result
96+
97+
callback = Callback("callback1", "op1", mock_state)
98+
result = callback.result()
99+
78100
assert result == "success_result"
79101
mock_state.get_checkpoint_result.assert_called_once_with("op1")
80102

tests/serdes_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
EncodedValue,
2929
ExtendedTypeSerDes,
3030
JsonSerDes,
31+
PassThroughSerDes,
3132
PrimitiveCodec,
3233
SerDes,
3334
SerDesContext,
@@ -737,6 +738,18 @@ def test_extended_serdes_errors():
737738
# endregion
738739

739740

741+
def test_pass_through_serdes():
742+
serdes = PassThroughSerDes()
743+
744+
data = '"name": "test", "value": 123'
745+
serialized = serialize(serdes, data, "test-op", "test-arn")
746+
assert isinstance(serialized, str)
747+
assert serialized == '"name": "test", "value": 123'
748+
# Dict uses envelope format, so roundtrip through deserialize
749+
deserialized = deserialize(serdes, serialized, "test-op", "test-arn")
750+
assert deserialized == data
751+
752+
740753
# region EnvelopeSerDes Performance and Edge Cases
741754
def test_envelope_large_data_structure():
742755
"""Test with reasonably large data."""

0 commit comments

Comments
 (0)