@@ -69,6 +69,9 @@ class UnknownGrpcMessageError(issues.Error):
69
69
pass
70
70
71
71
72
+ _stop_grpc_connection_marker = object ()
73
+
74
+
72
75
class QueueToIteratorAsyncIO :
73
76
__slots__ = ("_queue" ,)
74
77
@@ -79,10 +82,10 @@ def __aiter__(self):
79
82
return self
80
83
81
84
async def __anext__ (self ):
82
- try :
83
- return await self ._queue .get ()
84
- except asyncio .QueueEmpty :
85
+ item = await self ._queue .get ()
86
+ if item is _stop_grpc_connection_marker :
85
87
raise StopAsyncIteration ()
88
+ return item
86
89
87
90
88
91
class AsyncQueueToSyncIteratorAsyncIO :
@@ -100,13 +103,10 @@ def __iter__(self):
100
103
return self
101
104
102
105
def __next__ (self ):
103
- try :
104
- res = asyncio .run_coroutine_threadsafe (
105
- self ._queue .get (), self ._loop
106
- ).result ()
107
- return res
108
- except asyncio .QueueEmpty :
106
+ item = asyncio .run_coroutine_threadsafe (self ._queue .get (), self ._loop ).result ()
107
+ if item is _stop_grpc_connection_marker :
109
108
raise StopIteration ()
109
+ return item
110
110
111
111
112
112
class SyncIteratorToAsyncIterator :
@@ -133,6 +133,10 @@ async def receive(self) -> Any:
133
133
def write (self , wrap_message : IToProto ):
134
134
...
135
135
136
+ @abc .abstractmethod
137
+ def close (self ):
138
+ ...
139
+
136
140
137
141
SupportedDriverType = Union [ydb .Driver , ydb .aio .Driver ]
138
142
@@ -142,11 +146,15 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
142
146
from_server_grpc : AsyncIterator
143
147
convert_server_grpc_to_wrapper : Callable [[Any ], Any ]
144
148
_connection_state : str
149
+ _stream_call : Optional [
150
+ Union [grpc .aio .StreamStreamCall , "grpc._channel._MultiThreadedRendezvous" ]
151
+ ]
145
152
146
153
def __init__ (self , convert_server_grpc_to_wrapper ):
147
154
self .from_client_grpc = asyncio .Queue ()
148
155
self .convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
149
156
self ._connection_state = "new"
157
+ self ._stream_call = None
150
158
151
159
async def start (self , driver : SupportedDriverType , stub , method ):
152
160
if asyncio .iscoroutinefunction (driver .__call__ ):
@@ -155,13 +163,19 @@ async def start(self, driver: SupportedDriverType, stub, method):
155
163
await self ._start_sync_driver (driver , stub , method )
156
164
self ._connection_state = "started"
157
165
166
+ def close (self ):
167
+ self .from_client_grpc .put_nowait (_stop_grpc_connection_marker )
168
+ if self ._stream_call :
169
+ self ._stream_call .cancel ()
170
+
158
171
async def _start_asyncio_driver (self , driver : ydb .aio .Driver , stub , method ):
159
172
requests_iterator = QueueToIteratorAsyncIO (self .from_client_grpc )
160
173
stream_call = await driver (
161
174
requests_iterator ,
162
175
stub ,
163
176
method ,
164
177
)
178
+ self ._stream_call = stream_call
165
179
self .from_server_grpc = stream_call .__aiter__ ()
166
180
167
181
async def _start_sync_driver (self , driver : ydb .Driver , stub , method ):
@@ -172,6 +186,7 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
172
186
stub ,
173
187
method ,
174
188
)
189
+ self ._stream_call = stream_call
175
190
self .from_server_grpc = SyncIteratorToAsyncIterator (stream_call .__iter__ ())
176
191
177
192
async def receive (self ) -> Any :
0 commit comments