Skip to content

Commit 2fae543

Browse files
authored
Abstract Channel Initialization behind Client (#21)
1 parent 1213216 commit 2fae543

File tree

3 files changed

+68
-27
lines changed

3 files changed

+68
-27
lines changed

cadence/_internal/rpc/metadata.py renamed to cadence/_internal/rpc/yarpc.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,18 @@ class _ClientCallDetails(
1313
):
1414
pass
1515

16-
class MetadataInterceptor(UnaryUnaryClientInterceptor):
17-
def __init__(self, metadata: Metadata):
18-
self._metadata = metadata
16+
SERVICE_KEY = "rpc-service"
17+
CALLER_KEY = "rpc-caller"
18+
ENCODING_KEY = "rpc-encoding"
19+
ENCODING_PROTO = "proto"
20+
21+
class YarpcMetadataInterceptor(UnaryUnaryClientInterceptor):
22+
def __init__(self, service: str, caller: str):
23+
self._metadata = Metadata(
24+
(SERVICE_KEY, service),
25+
(CALLER_KEY, caller),
26+
(ENCODING_KEY, ENCODING_PROTO),
27+
)
1928

2029
async def intercept_unary_unary(
2130
self,

cadence/client.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,43 @@
11
import os
22
import socket
3-
from typing import TypedDict
3+
from typing import TypedDict, Unpack, Any, cast
44

5-
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
6-
from grpc.aio import Channel
5+
from grpc import ChannelCredentials, Compression
76

8-
from cadence.data_converter import DataConverter
7+
from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor
8+
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
9+
from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel
10+
from cadence.data_converter import DataConverter, DefaultDataConverter
911

1012

1113
class ClientOptions(TypedDict, total=False):
1214
domain: str
13-
identity: str
15+
target: str
1416
data_converter: DataConverter
17+
identity: str
18+
service_name: str
19+
caller_name: str
20+
channel_arguments: dict[str, Any]
21+
credentials: ChannelCredentials | None
22+
compression: Compression
23+
interceptors: list[ClientInterceptor]
24+
25+
_DEFAULT_OPTIONS: ClientOptions = {
26+
"data_converter": DefaultDataConverter(),
27+
"identity": f"{os.getpid()}@{socket.gethostname()}",
28+
"service_name": "cadence-frontend",
29+
"caller_name": "cadence-client",
30+
"channel_arguments": {},
31+
"credentials": None,
32+
"compression": Compression.NoCompression,
33+
"interceptors": [],
34+
}
1535

1636
class Client:
17-
def __init__(self, channel: Channel, options: ClientOptions) -> None:
18-
self._channel = channel
19-
self._worker_stub = WorkerAPIStub(channel)
20-
self._options = options
21-
self._identity = options["identity"] if "identity" in options else f"{os.getpid()}@{socket.gethostname()}"
37+
def __init__(self, **kwargs: Unpack[ClientOptions]) -> None:
38+
self._options = _validate_and_copy_defaults(ClientOptions(**kwargs))
39+
self._channel = _create_channel(self._options)
40+
self._worker_stub = WorkerAPIStub(self._channel)
2241

2342
@property
2443
def data_converter(self) -> DataConverter:
@@ -30,14 +49,35 @@ def domain(self) -> str:
3049

3150
@property
3251
def identity(self) -> str:
33-
return self._identity
52+
return self._options["identity"]
3453

3554
@property
3655
def worker_stub(self) -> WorkerAPIStub:
3756
return self._worker_stub
3857

39-
4058
async def close(self) -> None:
4159
await self._channel.close()
4260

61+
def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
62+
if "target" not in options:
63+
raise ValueError("target must be specified")
64+
65+
if "domain" not in options:
66+
raise ValueError("domain must be specified")
67+
68+
# Set default values for missing options
69+
for key, value in _DEFAULT_OPTIONS.items():
70+
if key not in options:
71+
cast(dict, options)[key] = value
72+
73+
return options
74+
75+
76+
def _create_channel(options: ClientOptions) -> Channel:
77+
interceptors = list(options["interceptors"])
78+
interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"]))
4379

80+
if options["credentials"]:
81+
return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors)
82+
else:
83+
return insecure_channel(options["target"], options["channel_arguments"], options["compression"], interceptors)

cadence/sample/client_example.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,14 @@
11
import asyncio
22

3-
from grpc.aio import insecure_channel, Metadata
43

5-
from cadence.client import Client, ClientOptions
6-
from cadence._internal.rpc.metadata import MetadataInterceptor
4+
from cadence.client import Client
75
from cadence.worker import Worker, Registry
86

97

108
async def main():
11-
# TODO - Hide all this
12-
metadata = Metadata()
13-
metadata["rpc-service"] = "cadence-frontend"
14-
metadata["rpc-encoding"] = "proto"
15-
metadata["rpc-caller"] = "nate"
16-
async with insecure_channel("localhost:7833", interceptors=[MetadataInterceptor(metadata)]) as channel:
17-
client = Client(channel, ClientOptions(domain="foo"))
18-
worker = Worker(client, "task_list", Registry())
19-
await worker.run()
9+
client = Client(target="localhost:7833", domain="foo")
10+
worker = Worker(client, "task_list", Registry())
11+
await worker.run()
2012

2113
if __name__ == '__main__':
2214
asyncio.run(main())

0 commit comments

Comments
 (0)