Skip to content

Commit ced4570

Browse files
authored
Map AioRpcError to Cadence Error Types (#22)
1 parent 2fae543 commit ced4570

File tree

6 files changed

+339
-0
lines changed

6 files changed

+339
-0
lines changed

cadence/_internal/rpc/error.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from typing import Callable, Any, Optional, Generator, TypeVar
2+
3+
import grpc
4+
from google.rpc.status_pb2 import Status # type: ignore
5+
from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails, AioRpcError, UnaryUnaryCall, Metadata
6+
from grpc_status.rpc_status import from_call # type: ignore
7+
8+
from cadence.api.v1 import error_pb2
9+
from cadence import error
10+
11+
12+
RequestType = TypeVar("RequestType")
13+
ResponseType = TypeVar("ResponseType")
14+
DoneCallbackType = Callable[[Any], None]
15+
16+
17+
# A UnaryUnaryCall is an awaitable type returned by GRPC's aio support.
18+
# We need to take the UnaryUnaryCall we receive and return one that remaps the exception.
19+
# It doesn't have any functions to compose operations together, so our only option is to wrap it.
20+
# If the interceptor directly throws an exception other than AioRpcError it breaks GRPC
21+
class CadenceErrorUnaryUnaryCall(UnaryUnaryCall[RequestType, ResponseType]):
22+
23+
def __init__(self, wrapped: UnaryUnaryCall[RequestType, ResponseType]):
24+
super().__init__()
25+
self._wrapped = wrapped
26+
27+
def __await__(self) -> Generator[Any, None, ResponseType]:
28+
try:
29+
response = yield from self._wrapped.__await__() # type: ResponseType
30+
return response
31+
except AioRpcError as e:
32+
raise map_error(e)
33+
34+
async def initial_metadata(self) -> Metadata:
35+
return await self._wrapped.initial_metadata()
36+
37+
async def trailing_metadata(self) -> Metadata:
38+
return await self._wrapped.trailing_metadata()
39+
40+
async def code(self) -> grpc.StatusCode:
41+
return await self._wrapped.code()
42+
43+
async def details(self) -> str:
44+
return await self._wrapped.details() # type: ignore
45+
46+
async def wait_for_connection(self) -> None:
47+
await self._wrapped.wait_for_connection()
48+
49+
def cancelled(self) -> bool:
50+
return self._wrapped.cancelled() # type: ignore
51+
52+
def done(self) -> bool:
53+
return self._wrapped.done() # type: ignore
54+
55+
def time_remaining(self) -> Optional[float]:
56+
return self._wrapped.time_remaining() # type: ignore
57+
58+
def cancel(self) -> bool:
59+
return self._wrapped.cancel() # type: ignore
60+
61+
def add_done_callback(self, callback: DoneCallbackType) -> None:
62+
self._wrapped.add_done_callback(callback)
63+
64+
65+
class CadenceErrorInterceptor(UnaryUnaryClientInterceptor):
66+
67+
async def intercept_unary_unary(
68+
self,
69+
continuation: Callable[[ClientCallDetails, Any], Any],
70+
client_call_details: ClientCallDetails,
71+
request: Any
72+
) -> Any:
73+
rpc_call = await continuation(client_call_details, request)
74+
return CadenceErrorUnaryUnaryCall(rpc_call)
75+
76+
77+
78+
79+
def map_error(e: AioRpcError) -> error.CadenceError:
80+
status: Status | None = from_call(e)
81+
if not status or not status.details:
82+
return error.CadenceError(e.details(), e.code())
83+
84+
details = status.details[0]
85+
if details.Is(error_pb2.WorkflowExecutionAlreadyStartedError.DESCRIPTOR):
86+
already_started = error_pb2.WorkflowExecutionAlreadyStartedError()
87+
details.Unpack(already_started)
88+
return error.WorkflowExecutionAlreadyStartedError(e.details(), e.code(), already_started.start_request_id, already_started.run_id)
89+
elif details.Is(error_pb2.EntityNotExistsError.DESCRIPTOR):
90+
not_exists = error_pb2.EntityNotExistsError()
91+
details.Unpack(not_exists)
92+
return error.EntityNotExistsError(e.details(), e.code(), not_exists.current_cluster, not_exists.active_cluster, list(not_exists.active_clusters))
93+
elif details.Is(error_pb2.WorkflowExecutionAlreadyCompletedError.DESCRIPTOR):
94+
return error.WorkflowExecutionAlreadyCompletedError(e.details(), e.code())
95+
elif details.Is(error_pb2.DomainNotActiveError.DESCRIPTOR):
96+
not_active = error_pb2.DomainNotActiveError()
97+
details.Unpack(not_active)
98+
return error.DomainNotActiveError(e.details(), e.code(), not_active.domain, not_active.current_cluster, not_active.active_cluster, list(not_active.active_clusters))
99+
elif details.Is(error_pb2.ClientVersionNotSupportedError.DESCRIPTOR):
100+
not_supported = error_pb2.ClientVersionNotSupportedError()
101+
details.Unpack(not_supported)
102+
return error.ClientVersionNotSupportedError(e.details(), e.code(), not_supported.feature_version, not_supported.client_impl, not_supported.supported_versions)
103+
elif details.Is(error_pb2.FeatureNotEnabledError.DESCRIPTOR):
104+
not_enabled = error_pb2.FeatureNotEnabledError()
105+
details.Unpack(not_enabled)
106+
return error.FeatureNotEnabledError(e.details(), e.code(), not_enabled.feature_flag)
107+
elif details.Is(error_pb2.CancellationAlreadyRequestedError.DESCRIPTOR):
108+
return error.CancellationAlreadyRequestedError(e.details(), e.code())
109+
elif details.Is(error_pb2.DomainAlreadyExistsError.DESCRIPTOR):
110+
return error.DomainAlreadyExistsError(e.details(), e.code())
111+
elif details.Is(error_pb2.LimitExceededError.DESCRIPTOR):
112+
return error.LimitExceededError(e.details(), e.code())
113+
elif details.Is(error_pb2.QueryFailedError.DESCRIPTOR):
114+
return error.QueryFailedError(e.details(), e.code())
115+
elif details.Is(error_pb2.ServiceBusyError.DESCRIPTOR):
116+
service_busy = error_pb2.ServiceBusyError()
117+
details.Unpack(service_busy)
118+
return error.ServiceBusyError(e.details(), e.code(), service_busy.reason)
119+
else:
120+
return error.CadenceError(e.details(), e.code())
121+

cadence/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from grpc import ChannelCredentials, Compression
66

7+
from cadence._internal.rpc.error import CadenceErrorInterceptor
78
from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor
89
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
910
from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel
@@ -75,6 +76,7 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
7576

7677
def _create_channel(options: ClientOptions) -> Channel:
7778
interceptors = list(options["interceptors"])
79+
interceptors.append(CadenceErrorInterceptor())
7880
interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"]))
7981

8082
if options["credentials"]:

cadence/error.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import grpc
2+
3+
4+
class CadenceError(Exception):
5+
6+
def __init__(self, message: str, code: grpc.StatusCode, *args):
7+
super().__init__(message, code, *args)
8+
self.code = code
9+
pass
10+
11+
12+
class WorkflowExecutionAlreadyStartedError(CadenceError):
13+
14+
def __init__(self, message: str, code: grpc.StatusCode, start_request_id: str, run_id: str) -> None:
15+
super().__init__(message, code, start_request_id, run_id)
16+
self.start_request_id = start_request_id
17+
self.run_id = run_id
18+
19+
class EntityNotExistsError(CadenceError):
20+
21+
def __init__(self, message: str, code: grpc.StatusCode, current_cluster: str, active_cluster: str, active_clusters: list[str]) -> None:
22+
super().__init__(message, code, current_cluster, active_cluster, active_clusters)
23+
self.current_cluster = current_cluster
24+
self.active_cluster = active_cluster
25+
self.active_clusters = active_clusters
26+
27+
class WorkflowExecutionAlreadyCompletedError(CadenceError):
28+
pass
29+
30+
class DomainNotActiveError(CadenceError):
31+
def __init__(self, message: str, code: grpc.StatusCode, domain: str, current_cluster: str, active_cluster: str, active_clusters: list[str]) -> None:
32+
super().__init__(message, code, domain, current_cluster, active_cluster, active_clusters)
33+
self.domain = domain
34+
self.current_cluster = current_cluster
35+
self.active_cluster = active_cluster
36+
self.active_clusters = active_clusters
37+
38+
class ClientVersionNotSupportedError(CadenceError):
39+
def __init__(self, message: str, code: grpc.StatusCode, feature_version: str, client_impl: str, supported_versions: str) -> None:
40+
super().__init__(message, code, feature_version, client_impl, supported_versions)
41+
self.feature_version = feature_version
42+
self.client_impl = client_impl
43+
self.supported_versions = supported_versions
44+
45+
class FeatureNotEnabledError(CadenceError):
46+
def __init__(self, message: str, code: grpc.StatusCode, feature_flag: str) -> None:
47+
super().__init__(message, code, feature_flag)
48+
self.feature_flag = feature_flag
49+
50+
class CancellationAlreadyRequestedError(CadenceError):
51+
pass
52+
53+
class DomainAlreadyExistsError(CadenceError):
54+
pass
55+
56+
class LimitExceededError(CadenceError):
57+
pass
58+
59+
class QueryFailedError(CadenceError):
60+
pass
61+
62+
class ServiceBusyError(CadenceError):
63+
def __init__(self, message: str, code: grpc.StatusCode, reason: str) -> None:
64+
super().__init__(message, code, reason)
65+
self.reason = reason

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ classifiers = [
2727
requires-python = ">=3.11,<3.14"
2828
dependencies = [
2929
"grpcio==1.71.2",
30+
"grpcio-status>=1.71.2",
3031
"msgspec>=0.19.0",
3132
"protobuf==5.29.1",
3233
"typing-extensions>=4.0.0",
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from concurrent import futures
2+
3+
import pytest
4+
from google.protobuf import any_pb2
5+
from google.rpc import code_pb2, status_pb2
6+
from grpc import Status, StatusCode, server
7+
from grpc.aio import insecure_channel
8+
from grpc_status.rpc_status import to_status
9+
10+
from cadence._internal.rpc.error import CadenceErrorInterceptor
11+
from cadence.api.v1 import error_pb2, service_meta_pb2_grpc
12+
from cadence import error
13+
from google.protobuf.message import Message
14+
15+
from cadence.api.v1.service_meta_pb2 import HealthRequest, HealthResponse
16+
from cadence.error import CadenceError
17+
18+
19+
class FakeService(service_meta_pb2_grpc.MetaAPIServicer):
20+
def __init__(self) -> None:
21+
super().__init__()
22+
self.status: Status | None = None
23+
self.port: int | None = None
24+
25+
def Health(self, request, context):
26+
if temp := self.status:
27+
self.status = None
28+
context.abort_with_status(temp)
29+
return HealthResponse(ok=True)
30+
31+
32+
@pytest.fixture(scope="module")
33+
def fake_service():
34+
fake = FakeService()
35+
sync_server = server(futures.ThreadPoolExecutor(max_workers=1))
36+
service_meta_pb2_grpc.add_MetaAPIServicer_to_server(fake, sync_server)
37+
fake.port = sync_server.add_insecure_port("[::]:0")
38+
sync_server.start()
39+
yield fake
40+
sync_server.stop(grace=None)
41+
42+
@pytest.mark.usefixtures("fake_service")
43+
@pytest.mark.parametrize(
44+
"err,expected",
45+
[
46+
pytest.param(None, None,id="no error"),
47+
pytest.param(
48+
error_pb2.WorkflowExecutionAlreadyStartedError(start_request_id="start_request", run_id="run_id"),
49+
error.WorkflowExecutionAlreadyStartedError(message="message", code=StatusCode.INVALID_ARGUMENT, start_request_id="start_request", run_id="run_id"),
50+
id="WorkflowExecutionAlreadyStartedError"),
51+
pytest.param(
52+
error_pb2.EntityNotExistsError(current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]),
53+
error.EntityNotExistsError(message="message", code=StatusCode.INVALID_ARGUMENT, current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]),
54+
id="EntityNotExistsError"),
55+
pytest.param(
56+
error_pb2.WorkflowExecutionAlreadyCompletedError(),
57+
error.WorkflowExecutionAlreadyCompletedError(message="message", code=StatusCode.INVALID_ARGUMENT),
58+
id="WorkflowExecutionAlreadyCompletedError"),
59+
pytest.param(
60+
error_pb2.DomainNotActiveError(domain="domain", current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]),
61+
error.DomainNotActiveError(message="message", code=StatusCode.INVALID_ARGUMENT, domain="domain", current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]),
62+
id="DomainNotActiveError"),
63+
pytest.param(
64+
error_pb2.ClientVersionNotSupportedError(feature_version="feature_version", client_impl="client_impl", supported_versions="supported_versions"),
65+
error.ClientVersionNotSupportedError(message="message", code=StatusCode.INVALID_ARGUMENT, feature_version="feature_version", client_impl="client_impl", supported_versions="supported_versions"),
66+
id="ClientVersionNotSupportedError"),
67+
pytest.param(
68+
error_pb2.FeatureNotEnabledError(feature_flag="feature_flag"),
69+
error.FeatureNotEnabledError(message="message", code=StatusCode.INVALID_ARGUMENT,feature_flag="feature_flag"),
70+
id="FeatureNotEnabledError"),
71+
pytest.param(
72+
error_pb2.CancellationAlreadyRequestedError(),
73+
error.CancellationAlreadyRequestedError(message="message", code=StatusCode.INVALID_ARGUMENT),
74+
id="CancellationAlreadyRequestedError"),
75+
pytest.param(
76+
error_pb2.DomainAlreadyExistsError(),
77+
error.DomainAlreadyExistsError(message="message", code=StatusCode.INVALID_ARGUMENT),
78+
id="DomainAlreadyExistsError"),
79+
pytest.param(
80+
error_pb2.LimitExceededError(),
81+
error.LimitExceededError(message="message", code=StatusCode.INVALID_ARGUMENT),
82+
id="LimitExceededError"),
83+
pytest.param(
84+
error_pb2.QueryFailedError(),
85+
error.QueryFailedError(message="message", code=StatusCode.INVALID_ARGUMENT),
86+
id="QueryFailedError"),
87+
pytest.param(
88+
error_pb2.ServiceBusyError(reason="reason"),
89+
error.ServiceBusyError(message="message", code=StatusCode.INVALID_ARGUMENT, reason="reason"),
90+
id="ServiceBusyError"),
91+
pytest.param(
92+
to_status(status_pb2.Status(code=code_pb2.PERMISSION_DENIED, message="no permission")),
93+
error.CadenceError(message="no permission", code=StatusCode.PERMISSION_DENIED),
94+
id="unknown error type"),
95+
]
96+
)
97+
@pytest.mark.asyncio
98+
async def test_map_error(fake_service, err: Message | Status, expected: CadenceError):
99+
async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[CadenceErrorInterceptor()]) as channel:
100+
stub = service_meta_pb2_grpc.MetaAPIStub(channel)
101+
if expected is None:
102+
response = await stub.Health(HealthRequest(), timeout=1)
103+
assert response == HealthResponse(ok=True)
104+
else:
105+
if isinstance(err, Message):
106+
fake_service.status = details_to_status(err)
107+
else:
108+
fake_service.status = err
109+
with pytest.raises(type(expected)) as exc_info:
110+
await stub.Health(HealthRequest(), timeout=1)
111+
assert exc_info.value.args == expected.args
112+
113+
def details_to_status(message: Message) -> Status:
114+
detail = any_pb2.Any()
115+
detail.Pack(message)
116+
status_proto = status_pb2.Status(
117+
code=code_pb2.INVALID_ARGUMENT,
118+
message="message",
119+
details=[detail],
120+
)
121+
return to_status(status_proto)
122+

uv.lock

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)