Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@ class _ClientCallDetails(
):
pass

class MetadataInterceptor(UnaryUnaryClientInterceptor):
def __init__(self, metadata: Metadata):
self._metadata = metadata
SERVICE_KEY = "rpc-service"
CALLER_KEY = "rpc-caller"
ENCODING_KEY = "rpc-encoding"
ENCODING_PROTO = "proto"

class YarpcMetadataInterceptor(UnaryUnaryClientInterceptor):
def __init__(self, service: str, caller: str):
self._metadata = Metadata(
(SERVICE_KEY, service),
(CALLER_KEY, caller),
(ENCODING_KEY, ENCODING_PROTO),
)

async def intercept_unary_unary(
self,
Expand Down
64 changes: 52 additions & 12 deletions cadence/client.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,43 @@
import os
import socket
from typing import TypedDict
from typing import TypedDict, Unpack, Any, cast

from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
from grpc.aio import Channel
from grpc import ChannelCredentials, Compression

from cadence.data_converter import DataConverter
from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel
from cadence.data_converter import DataConverter, DefaultDataConverter


class ClientOptions(TypedDict, total=False):
domain: str
identity: str
target: str
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: service_endpoint? target is too general, I guess this is for socket address or a yarpc dispatcher for internal services. How are we going to integrate this internally?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, target is rather weird. I picked this name because it matches the naming from the GRPC api: https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc.aio.secure_channel .

In the Java client we specify host and port, I don't know the equivalent in Go.

Internally I think this will work similar to Java, where we use the GRPC transport.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding some comment on this or maybe change to grpc_target then?

data_converter: DataConverter
identity: str
service_name: str
caller_name: str
channel_arguments: dict[str, Any]
credentials: ChannelCredentials | None
compression: Compression
interceptors: list[ClientInterceptor]

_DEFAULT_OPTIONS: ClientOptions = {
"data_converter": DefaultDataConverter(),
"identity": f"{os.getpid()}@{socket.gethostname()}",
"service_name": "cadence-frontend",
"caller_name": "cadence-client",
"channel_arguments": {},
"credentials": None,
"compression": Compression.NoCompression,
"interceptors": [],
}

class Client:
def __init__(self, channel: Channel, options: ClientOptions) -> None:
self._channel = channel
self._worker_stub = WorkerAPIStub(channel)
self._options = options
self._identity = options["identity"] if "identity" in options else f"{os.getpid()}@{socket.gethostname()}"
def __init__(self, **kwargs: Unpack[ClientOptions]) -> None:
self._options = _validate_and_copy_defaults(ClientOptions(**kwargs))
self._channel = _create_channel(self._options)
self._worker_stub = WorkerAPIStub(self._channel)

@property
def data_converter(self) -> DataConverter:
Expand All @@ -30,14 +49,35 @@ def domain(self) -> str:

@property
def identity(self) -> str:
return self._identity
return self._options["identity"]

@property
def worker_stub(self) -> WorkerAPIStub:
return self._worker_stub


async def close(self) -> None:
await self._channel.close()

def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move to ClientOptions class method so it's less likely someone implements something similar in other packages

if "target" not in options:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe options.target is not None so it's discoverable on how this field is used through IDE?

raise ValueError("target must be specified")

if "domain" not in options:
raise ValueError("domain must be specified")

# Set default values for missing options
for key, value in _DEFAULT_OPTIONS.items():
if key not in options:
cast(dict, options)[key] = value

return options


def _create_channel(options: ClientOptions) -> Channel:
interceptors = list(options["interceptors"])
interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"]))

if options["credentials"]:
return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors)
else:
return insecure_channel(options["target"], options["channel_arguments"], options["compression"], interceptors)
16 changes: 4 additions & 12 deletions cadence/sample/client_example.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
import asyncio

from grpc.aio import insecure_channel, Metadata

from cadence.client import Client, ClientOptions
from cadence._internal.rpc.metadata import MetadataInterceptor
from cadence.client import Client
from cadence.worker import Worker, Registry


async def main():
# TODO - Hide all this
metadata = Metadata()
metadata["rpc-service"] = "cadence-frontend"
metadata["rpc-encoding"] = "proto"
metadata["rpc-caller"] = "nate"
async with insecure_channel("localhost:7833", interceptors=[MetadataInterceptor(metadata)]) as channel:
client = Client(channel, ClientOptions(domain="foo"))
worker = Worker(client, "task_list", Registry())
await worker.run()
client = Client(target="localhost:7833", domain="foo")
worker = Worker(client, "task_list", Registry())
await worker.run()

if __name__ == '__main__':
asyncio.run(main())