diff --git a/src/aws_durable_execution_sdk_python_testing/web/routes.py b/src/aws_durable_execution_sdk_python_testing/web/routes.py index 5a106b9..672f8bc 100644 --- a/src/aws_durable_execution_sdk_python_testing/web/routes.py +++ b/src/aws_durable_execution_sdk_python_testing/web/routes.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass +from urllib.parse import unquote from aws_durable_execution_sdk_python_testing.exceptions import ( UnknownRouteError, @@ -444,7 +445,7 @@ def from_route(cls, route: Route) -> CallbackSuccessRoute: return cls( raw_path=route.raw_path, segments=route.segments, - callback_id=route.segments[2], + callback_id=unquote(route.segments[2]), ) @@ -487,7 +488,7 @@ def from_route(cls, route: Route) -> CallbackFailureRoute: return cls( raw_path=route.raw_path, segments=route.segments, - callback_id=route.segments[2], + callback_id=unquote(route.segments[2]), ) @@ -530,7 +531,7 @@ def from_route(cls, route: Route) -> CallbackHeartbeatRoute: return cls( raw_path=route.raw_path, segments=route.segments, - callback_id=route.segments[2], + callback_id=unquote(route.segments[2]), ) diff --git a/tests/web/routes_test.py b/tests/web/routes_test.py index 05429c8..176a8a9 100644 --- a/tests/web/routes_test.py +++ b/tests/web/routes_test.py @@ -4,6 +4,7 @@ import threading import time +from urllib.parse import quote import pytest @@ -1112,3 +1113,60 @@ def worker(worker_id: int): # Check results assert len(errors) == 0, f"Thread safety test failed with errors: {errors}" assert len(results) == 5, f"Expected 5 successful workers, got {len(results)}" + + +def test_callback_routes_url_decoding(): + """Test that callback routes properly URL-decode callback IDs.""" + # Test callback ID with special characters that need URL encoding + callback_id = "eyJhcm4iOiJhcm4iLCJvcCI6ImVhNjZjMDZjMWUxYzA1ZmEifQ==" + encoded_callback_id = quote(callback_id, safe="") + + # Test CallbackSuccessRoute + base_route = Route.from_string( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/succeed" + ) + success_route = CallbackSuccessRoute.from_route(base_route) + assert success_route.callback_id == callback_id # Should be decoded + + # Test CallbackFailureRoute + base_route = Route.from_string( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/fail" + ) + failure_route = CallbackFailureRoute.from_route(base_route) + assert failure_route.callback_id == callback_id # Should be decoded + + # Test CallbackHeartbeatRoute + base_route = Route.from_string( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/heartbeat" + ) + heartbeat_route = CallbackHeartbeatRoute.from_route(base_route) + assert heartbeat_route.callback_id == callback_id # Should be decoded + + +def test_router_callback_routes_url_decoding(): + """Test Router properly handles URL-encoded callback IDs.""" + router = Router() + callback_id = "eyJhcm4iOiJhcm4iLCJvcCI6ImVhNjZjMDZjMWUxYzA1ZmEifQ==" + encoded_callback_id = quote(callback_id, safe="") + + # Test success route + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/succeed", "POST" + ) + assert isinstance(route, CallbackSuccessRoute) + assert route.callback_id == callback_id # Should be decoded + + # Test failure route + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/fail", "POST" + ) + assert isinstance(route, CallbackFailureRoute) + assert route.callback_id == callback_id # Should be decoded + + # Test heartbeat route + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/heartbeat", + "POST", + ) + assert isinstance(route, CallbackHeartbeatRoute) + assert route.callback_id == callback_id # Should be decoded