Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/aws_durable_execution_sdk_python_testing/web/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from dataclasses import dataclass
from urllib.parse import unquote

from aws_durable_execution_sdk_python_testing.exceptions import (
UnknownRouteError,
Expand Down Expand Up @@ -444,7 +445,7 @@ def from_route(cls, route: Route) -> CallbackSuccessRoute:
return cls(
raw_path=route.raw_path,
segments=route.segments,
callback_id=route.segments[2],
callback_id=unquote(route.segments[2]),
)


Expand Down Expand Up @@ -487,7 +488,7 @@ def from_route(cls, route: Route) -> CallbackFailureRoute:
return cls(
raw_path=route.raw_path,
segments=route.segments,
callback_id=route.segments[2],
callback_id=unquote(route.segments[2]),
)


Expand Down Expand Up @@ -530,7 +531,7 @@ def from_route(cls, route: Route) -> CallbackHeartbeatRoute:
return cls(
raw_path=route.raw_path,
segments=route.segments,
callback_id=route.segments[2],
callback_id=unquote(route.segments[2]),
)


Expand Down
58 changes: 58 additions & 0 deletions tests/web/routes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import threading
import time
from urllib.parse import quote

import pytest

Expand Down Expand Up @@ -1112,3 +1113,60 @@ def worker(worker_id: int):
# Check results
assert len(errors) == 0, f"Thread safety test failed with errors: {errors}"
assert len(results) == 5, f"Expected 5 successful workers, got {len(results)}"


def test_callback_routes_url_decoding():
"""Test that callback routes properly URL-decode callback IDs."""
# Test callback ID with special characters that need URL encoding
callback_id = "eyJhcm4iOiJhcm4iLCJvcCI6ImVhNjZjMDZjMWUxYzA1ZmEifQ=="
encoded_callback_id = quote(callback_id, safe="")

# Test CallbackSuccessRoute
base_route = Route.from_string(
f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/succeed"
)
success_route = CallbackSuccessRoute.from_route(base_route)
assert success_route.callback_id == callback_id # Should be decoded

# Test CallbackFailureRoute
base_route = Route.from_string(
f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/fail"
)
failure_route = CallbackFailureRoute.from_route(base_route)
assert failure_route.callback_id == callback_id # Should be decoded

# Test CallbackHeartbeatRoute
base_route = Route.from_string(
f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/heartbeat"
)
heartbeat_route = CallbackHeartbeatRoute.from_route(base_route)
assert heartbeat_route.callback_id == callback_id # Should be decoded


def test_router_callback_routes_url_decoding():
"""Test Router properly handles URL-encoded callback IDs."""
router = Router()
callback_id = "eyJhcm4iOiJhcm4iLCJvcCI6ImVhNjZjMDZjMWUxYzA1ZmEifQ=="
encoded_callback_id = quote(callback_id, safe="")

# Test success route
route = router.find_route(
f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/succeed", "POST"
)
assert isinstance(route, CallbackSuccessRoute)
assert route.callback_id == callback_id # Should be decoded

# Test failure route
route = router.find_route(
f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/fail", "POST"
)
assert isinstance(route, CallbackFailureRoute)
assert route.callback_id == callback_id # Should be decoded

# Test heartbeat route
route = router.find_route(
f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/heartbeat",
"POST",
)
assert isinstance(route, CallbackHeartbeatRoute)
assert route.callback_id == callback_id # Should be decoded