1
1
import os
2
2
import socket
3
- from typing import TypedDict
3
+ from typing import TypedDict , Unpack , Any , cast
4
4
5
- from cadence .api .v1 .service_worker_pb2_grpc import WorkerAPIStub
6
- from grpc .aio import Channel
5
+ from grpc import ChannelCredentials , Compression
7
6
8
- from cadence .data_converter import DataConverter
7
+ from cadence ._internal .rpc .yarpc import YarpcMetadataInterceptor
8
+ from cadence .api .v1 .service_worker_pb2_grpc import WorkerAPIStub
9
+ from grpc .aio import Channel , ClientInterceptor , secure_channel , insecure_channel
10
+ from cadence .data_converter import DataConverter , DefaultDataConverter
9
11
10
12
11
13
class ClientOptions (TypedDict , total = False ):
12
14
domain : str
13
- identity : str
15
+ target : str
14
16
data_converter : DataConverter
17
+ identity : str
18
+ service_name : str
19
+ caller_name : str
20
+ channel_arguments : dict [str , Any ]
21
+ credentials : ChannelCredentials | None
22
+ compression : Compression
23
+ interceptors : list [ClientInterceptor ]
24
+
25
+ _DEFAULT_OPTIONS : ClientOptions = {
26
+ "data_converter" : DefaultDataConverter (),
27
+ "identity" : f"{ os .getpid ()} @{ socket .gethostname ()} " ,
28
+ "service_name" : "cadence-frontend" ,
29
+ "caller_name" : "cadence-client" ,
30
+ "channel_arguments" : {},
31
+ "credentials" : None ,
32
+ "compression" : Compression .NoCompression ,
33
+ "interceptors" : [],
34
+ }
15
35
16
36
class Client :
17
- def __init__ (self , channel : Channel , options : ClientOptions ) -> None :
18
- self ._channel = channel
19
- self ._worker_stub = WorkerAPIStub (channel )
20
- self ._options = options
21
- self ._identity = options ["identity" ] if "identity" in options else f"{ os .getpid ()} @{ socket .gethostname ()} "
37
+ def __init__ (self , ** kwargs : Unpack [ClientOptions ]) -> None :
38
+ self ._options = _validate_and_copy_defaults (ClientOptions (** kwargs ))
39
+ self ._channel = _create_channel (self ._options )
40
+ self ._worker_stub = WorkerAPIStub (self ._channel )
22
41
23
42
@property
24
43
def data_converter (self ) -> DataConverter :
@@ -30,14 +49,35 @@ def domain(self) -> str:
30
49
31
50
@property
32
51
def identity (self ) -> str :
33
- return self ._identity
52
+ return self ._options [ "identity" ]
34
53
35
54
@property
36
55
def worker_stub (self ) -> WorkerAPIStub :
37
56
return self ._worker_stub
38
57
39
-
40
58
async def close (self ) -> None :
41
59
await self ._channel .close ()
42
60
61
+ def _validate_and_copy_defaults (options : ClientOptions ) -> ClientOptions :
62
+ if "target" not in options :
63
+ raise ValueError ("target must be specified" )
64
+
65
+ if "domain" not in options :
66
+ raise ValueError ("domain must be specified" )
67
+
68
+ # Set default values for missing options
69
+ for key , value in _DEFAULT_OPTIONS .items ():
70
+ if key not in options :
71
+ cast (dict , options )[key ] = value
72
+
73
+ return options
74
+
75
+
76
+ def _create_channel (options : ClientOptions ) -> Channel :
77
+ interceptors = list (options ["interceptors" ])
78
+ interceptors .append (YarpcMetadataInterceptor (options ["service_name" ], options ["caller_name" ]))
43
79
80
+ if options ["credentials" ]:
81
+ return secure_channel (options ["target" ], options ["credentials" ], options ["channel_arguments" ], options ["compression" ], interceptors )
82
+ else :
83
+ return insecure_channel (options ["target" ], options ["channel_arguments" ], options ["compression" ], interceptors )
0 commit comments