diff --git a/cadence/_internal/rpc/metadata.py b/cadence/_internal/rpc/yarpc.py similarity index 75% rename from cadence/_internal/rpc/metadata.py rename to cadence/_internal/rpc/yarpc.py index c46b909..266df13 100644 --- a/cadence/_internal/rpc/metadata.py +++ b/cadence/_internal/rpc/yarpc.py @@ -13,9 +13,18 @@ class _ClientCallDetails( ): pass -class MetadataInterceptor(UnaryUnaryClientInterceptor): - def __init__(self, metadata: Metadata): - self._metadata = metadata +SERVICE_KEY = "rpc-service" +CALLER_KEY = "rpc-caller" +ENCODING_KEY = "rpc-encoding" +ENCODING_PROTO = "proto" + +class YarpcMetadataInterceptor(UnaryUnaryClientInterceptor): + def __init__(self, service: str, caller: str): + self._metadata = Metadata( + (SERVICE_KEY, service), + (CALLER_KEY, caller), + (ENCODING_KEY, ENCODING_PROTO), + ) async def intercept_unary_unary( self, diff --git a/cadence/client.py b/cadence/client.py index ef5f542..7feb242 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -1,24 +1,43 @@ import os import socket -from typing import TypedDict +from typing import TypedDict, Unpack, Any, cast -from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub -from grpc.aio import Channel +from grpc import ChannelCredentials, Compression -from cadence.data_converter import DataConverter +from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor +from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub +from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel +from cadence.data_converter import DataConverter, DefaultDataConverter class ClientOptions(TypedDict, total=False): domain: str - identity: str + target: str data_converter: DataConverter + identity: str + service_name: str + caller_name: str + channel_arguments: dict[str, Any] + credentials: ChannelCredentials | None + compression: Compression + interceptors: list[ClientInterceptor] + +_DEFAULT_OPTIONS: ClientOptions = { + "data_converter": DefaultDataConverter(), + "identity": f"{os.getpid()}@{socket.gethostname()}", + "service_name": "cadence-frontend", + "caller_name": "cadence-client", + "channel_arguments": {}, + "credentials": None, + "compression": Compression.NoCompression, + "interceptors": [], +} class Client: - def __init__(self, channel: Channel, options: ClientOptions) -> None: - self._channel = channel - self._worker_stub = WorkerAPIStub(channel) - self._options = options - self._identity = options["identity"] if "identity" in options else f"{os.getpid()}@{socket.gethostname()}" + def __init__(self, **kwargs: Unpack[ClientOptions]) -> None: + self._options = _validate_and_copy_defaults(ClientOptions(**kwargs)) + self._channel = _create_channel(self._options) + self._worker_stub = WorkerAPIStub(self._channel) @property def data_converter(self) -> DataConverter: @@ -30,14 +49,35 @@ def domain(self) -> str: @property def identity(self) -> str: - return self._identity + return self._options["identity"] @property def worker_stub(self) -> WorkerAPIStub: return self._worker_stub - async def close(self) -> None: await self._channel.close() +def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: + if "target" not in options: + raise ValueError("target must be specified") + + if "domain" not in options: + raise ValueError("domain must be specified") + + # Set default values for missing options + for key, value in _DEFAULT_OPTIONS.items(): + if key not in options: + cast(dict, options)[key] = value + + return options + + +def _create_channel(options: ClientOptions) -> Channel: + interceptors = list(options["interceptors"]) + interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"])) + if options["credentials"]: + return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors) + else: + return insecure_channel(options["target"], options["channel_arguments"], options["compression"], interceptors) \ No newline at end of file diff --git a/cadence/sample/client_example.py b/cadence/sample/client_example.py index 556691c..ece4346 100644 --- a/cadence/sample/client_example.py +++ b/cadence/sample/client_example.py @@ -1,22 +1,14 @@ import asyncio -from grpc.aio import insecure_channel, Metadata -from cadence.client import Client, ClientOptions -from cadence._internal.rpc.metadata import MetadataInterceptor +from cadence.client import Client from cadence.worker import Worker, Registry async def main(): - # TODO - Hide all this - metadata = Metadata() - metadata["rpc-service"] = "cadence-frontend" - metadata["rpc-encoding"] = "proto" - metadata["rpc-caller"] = "nate" - async with insecure_channel("localhost:7833", interceptors=[MetadataInterceptor(metadata)]) as channel: - client = Client(channel, ClientOptions(domain="foo")) - worker = Worker(client, "task_list", Registry()) - await worker.run() + client = Client(target="localhost:7833", domain="foo") + worker = Worker(client, "task_list", Registry()) + await worker.run() if __name__ == '__main__': asyncio.run(main())