From b29b2d0b24d98775a0725fb17c7f3a9067d42315 Mon Sep 17 00:00:00 2001 From: Nate Mortensen Date: Fri, 29 Aug 2025 14:45:35 -0700 Subject: [PATCH] Abstract Channel Initialization behind Client Our server being yarpc means that there are mandatory headers needed on every request. To include these we need to add mandatory interceptors. For additional things like retries and error mapping we'll also want additional interceptors. GRPC's sync implementation allows for adding interceptors to an existing channel, while the async implementation does not. As a result, our client needs to be responsible for Channel creation. Add GRPC channel options to ClientOptions and create the Channel within Client. This approach largely matches how the Java client approaches it, although it does allow for overriding the Channel still. --- .../_internal/rpc/{metadata.py => yarpc.py} | 15 ++++- cadence/client.py | 64 +++++++++++++++---- cadence/sample/client_example.py | 16 ++--- 3 files changed, 68 insertions(+), 27 deletions(-) rename cadence/_internal/rpc/{metadata.py => yarpc.py} (75%) diff --git a/cadence/_internal/rpc/metadata.py b/cadence/_internal/rpc/yarpc.py similarity index 75% rename from cadence/_internal/rpc/metadata.py rename to cadence/_internal/rpc/yarpc.py index c46b909..266df13 100644 --- a/cadence/_internal/rpc/metadata.py +++ b/cadence/_internal/rpc/yarpc.py @@ -13,9 +13,18 @@ class _ClientCallDetails( ): pass -class MetadataInterceptor(UnaryUnaryClientInterceptor): - def __init__(self, metadata: Metadata): - self._metadata = metadata +SERVICE_KEY = "rpc-service" +CALLER_KEY = "rpc-caller" +ENCODING_KEY = "rpc-encoding" +ENCODING_PROTO = "proto" + +class YarpcMetadataInterceptor(UnaryUnaryClientInterceptor): + def __init__(self, service: str, caller: str): + self._metadata = Metadata( + (SERVICE_KEY, service), + (CALLER_KEY, caller), + (ENCODING_KEY, ENCODING_PROTO), + ) async def intercept_unary_unary( self, diff --git a/cadence/client.py b/cadence/client.py index ef5f542..7feb242 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -1,24 +1,43 @@ import os import socket -from typing import TypedDict +from typing import TypedDict, Unpack, Any, cast -from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub -from grpc.aio import Channel +from grpc import ChannelCredentials, Compression -from cadence.data_converter import DataConverter +from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor +from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub +from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel +from cadence.data_converter import DataConverter, DefaultDataConverter class ClientOptions(TypedDict, total=False): domain: str - identity: str + target: str data_converter: DataConverter + identity: str + service_name: str + caller_name: str + channel_arguments: dict[str, Any] + credentials: ChannelCredentials | None + compression: Compression + interceptors: list[ClientInterceptor] + +_DEFAULT_OPTIONS: ClientOptions = { + "data_converter": DefaultDataConverter(), + "identity": f"{os.getpid()}@{socket.gethostname()}", + "service_name": "cadence-frontend", + "caller_name": "cadence-client", + "channel_arguments": {}, + "credentials": None, + "compression": Compression.NoCompression, + "interceptors": [], +} class Client: - def __init__(self, channel: Channel, options: ClientOptions) -> None: - self._channel = channel - self._worker_stub = WorkerAPIStub(channel) - self._options = options - self._identity = options["identity"] if "identity" in options else f"{os.getpid()}@{socket.gethostname()}" + def __init__(self, **kwargs: Unpack[ClientOptions]) -> None: + self._options = _validate_and_copy_defaults(ClientOptions(**kwargs)) + self._channel = _create_channel(self._options) + self._worker_stub = WorkerAPIStub(self._channel) @property def data_converter(self) -> DataConverter: @@ -30,14 +49,35 @@ def domain(self) -> str: @property def identity(self) -> str: - return self._identity + return self._options["identity"] @property def worker_stub(self) -> WorkerAPIStub: return self._worker_stub - async def close(self) -> None: await self._channel.close() +def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: + if "target" not in options: + raise ValueError("target must be specified") + + if "domain" not in options: + raise ValueError("domain must be specified") + + # Set default values for missing options + for key, value in _DEFAULT_OPTIONS.items(): + if key not in options: + cast(dict, options)[key] = value + + return options + + +def _create_channel(options: ClientOptions) -> Channel: + interceptors = list(options["interceptors"]) + interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"])) + if options["credentials"]: + return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors) + else: + return insecure_channel(options["target"], options["channel_arguments"], options["compression"], interceptors) \ No newline at end of file diff --git a/cadence/sample/client_example.py b/cadence/sample/client_example.py index 556691c..ece4346 100644 --- a/cadence/sample/client_example.py +++ b/cadence/sample/client_example.py @@ -1,22 +1,14 @@ import asyncio -from grpc.aio import insecure_channel, Metadata -from cadence.client import Client, ClientOptions -from cadence._internal.rpc.metadata import MetadataInterceptor +from cadence.client import Client from cadence.worker import Worker, Registry async def main(): - # TODO - Hide all this - metadata = Metadata() - metadata["rpc-service"] = "cadence-frontend" - metadata["rpc-encoding"] = "proto" - metadata["rpc-caller"] = "nate" - async with insecure_channel("localhost:7833", interceptors=[MetadataInterceptor(metadata)]) as channel: - client = Client(channel, ClientOptions(domain="foo")) - worker = Worker(client, "task_list", Registry()) - await worker.run() + client = Client(target="localhost:7833", domain="foo") + worker = Worker(client, "task_list", Registry()) + await worker.run() if __name__ == '__main__': asyncio.run(main())