Skip to content

Commit 4db696c

Browse files
authored
Merge branch 'v3' into topic-writer-flush-on-close
2 parents 6f0b65a + 896a677 commit 4db696c

File tree

7 files changed

+67
-9
lines changed

7 files changed

+67
-9
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Close grpc streams while closing readers/writers
12
* Add control plane operations for topic api: create, drop
23

34
## 3.0.1b4 ##

tests/topics/test_topic_reader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ async def test_read_message(
99
reader = driver.topic_client.topic_reader(topic_consumer, topic_path)
1010

1111
assert await reader.receive_batch() is not None
12+
await reader.close()

ydb/_grpc/grpcwrapper/common_utils.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ class UnknownGrpcMessageError(issues.Error):
6969
pass
7070

7171

72+
_stop_grpc_connection_marker = object()
73+
74+
7275
class QueueToIteratorAsyncIO:
7376
__slots__ = ("_queue",)
7477

@@ -79,10 +82,10 @@ def __aiter__(self):
7982
return self
8083

8184
async def __anext__(self):
82-
try:
83-
return await self._queue.get()
84-
except asyncio.QueueEmpty:
85+
item = await self._queue.get()
86+
if item is _stop_grpc_connection_marker:
8587
raise StopAsyncIteration()
88+
return item
8689

8790

8891
class AsyncQueueToSyncIteratorAsyncIO:
@@ -100,13 +103,10 @@ def __iter__(self):
100103
return self
101104

102105
def __next__(self):
103-
try:
104-
res = asyncio.run_coroutine_threadsafe(
105-
self._queue.get(), self._loop
106-
).result()
107-
return res
108-
except asyncio.QueueEmpty:
106+
item = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result()
107+
if item is _stop_grpc_connection_marker:
109108
raise StopIteration()
109+
return item
110110

111111

112112
class SyncIteratorToAsyncIterator:
@@ -133,6 +133,10 @@ async def receive(self) -> Any:
133133
def write(self, wrap_message: IToProto):
134134
...
135135

136+
@abc.abstractmethod
137+
def close(self):
138+
...
139+
136140

137141
SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver]
138142

@@ -142,11 +146,15 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
142146
from_server_grpc: AsyncIterator
143147
convert_server_grpc_to_wrapper: Callable[[Any], Any]
144148
_connection_state: str
149+
_stream_call: Optional[
150+
Union[grpc.aio.StreamStreamCall, "grpc._channel._MultiThreadedRendezvous"]
151+
]
145152

146153
def __init__(self, convert_server_grpc_to_wrapper):
147154
self.from_client_grpc = asyncio.Queue()
148155
self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
149156
self._connection_state = "new"
157+
self._stream_call = None
150158

151159
async def start(self, driver: SupportedDriverType, stub, method):
152160
if asyncio.iscoroutinefunction(driver.__call__):
@@ -155,13 +163,19 @@ async def start(self, driver: SupportedDriverType, stub, method):
155163
await self._start_sync_driver(driver, stub, method)
156164
self._connection_state = "started"
157165

166+
def close(self):
167+
self.from_client_grpc.put_nowait(_stop_grpc_connection_marker)
168+
if self._stream_call:
169+
self._stream_call.cancel()
170+
158171
async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
159172
requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc)
160173
stream_call = await driver(
161174
requests_iterator,
162175
stub,
163176
method,
164177
)
178+
self._stream_call = stream_call
165179
self.from_server_grpc = stream_call.__aiter__()
166180

167181
async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
@@ -172,6 +186,7 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
172186
stub,
173187
method,
174188
)
189+
self._stream_call = stream_call
175190
self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__())
176191

177192
async def receive(self) -> Any:

ydb/_topic_common/test_helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,36 @@
88
class StreamMock(IGrpcWrapperAsyncIO):
99
from_server: asyncio.Queue
1010
from_client: asyncio.Queue
11+
_closed: bool
1112

1213
def __init__(self):
1314
self.from_server = asyncio.Queue()
1415
self.from_client = asyncio.Queue()
16+
self._closed = False
1517

1618
async def receive(self) -> typing.Any:
19+
if self._closed:
20+
raise Exception("read from closed StreamMock")
21+
1722
item = await self.from_server.get()
23+
if item is None:
24+
raise StopAsyncIteration()
1825
if isinstance(item, Exception):
1926
raise item
2027
return item
2128

2229
def write(self, wrap_message: IToProto):
30+
if self._closed:
31+
raise Exception("write to closed StreamMock")
2332
self.from_client.put_nowait(wrap_message)
2433

34+
def close(self):
35+
if self._closed:
36+
return
37+
38+
self._closed = True
39+
self.from_server.put_nowait(None)
40+
2541

2642
async def wait_condition(f: typing.Callable[[], bool], timeout=1):
2743
start = time.monotonic()

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ async def close(self):
496496
self._closed = True
497497
self._set_first_error(TopicReaderStreamClosedError())
498498
self._state_changed.set()
499+
self._stream.close()
499500

500501
for task in self._background_tasks:
501502
task.cancel()

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ async def _connection_loop(self):
296296
pending = []
297297

298298
# noinspection PyBroadException
299+
stream_writer = None
299300
try:
300301
stream_writer = await WriterAsyncIOStream.create(
301302
self._driver, self._init_message, self._get_token
@@ -322,6 +323,7 @@ async def _connection_loop(self):
322323
done, pending = await asyncio.wait(
323324
[send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED
324325
)
326+
stream_writer.close()
325327
done.pop().result()
326328
except issues.Error as err:
327329
# todo log error
@@ -338,6 +340,8 @@ async def _connection_loop(self):
338340
self._stop(err)
339341
return
340342
finally:
343+
if stream_writer:
344+
stream_writer.close()
341345
if len(pending) > 0:
342346
for task in pending:
343347
task.cancel()
@@ -417,6 +421,9 @@ def __init__(
417421
):
418422
self._token_getter = token_getter
419423

424+
def close(self):
425+
self._stream.close()
426+
420427
@staticmethod
421428
async def create(
422429
driver: SupportedDriverType,

ydb/_topic_writer/topic_writer_asyncio_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,20 +158,37 @@ class StreamWriterMock:
158158
from_client: asyncio.Queue
159159
from_server: asyncio.Queue
160160

161+
_closed: bool
162+
161163
def __init__(self):
162164
self.last_seqno = 0
163165
self.from_server = asyncio.Queue()
164166
self.from_client = asyncio.Queue()
167+
self._closed = False
165168

166169
def write(self, messages: typing.List[InternalMessage]):
170+
if self._closed:
171+
raise Exception("write to closed StreamWriterMock")
172+
167173
self.from_client.put_nowait(messages)
168174

169175
async def receive(self) -> StreamWriteMessage.WriteResponse:
176+
if self._closed:
177+
raise Exception("read from closed StreamWriterMock")
178+
170179
item = await self.from_server.get()
171180
if isinstance(item, Exception):
172181
raise item
173182
return item
174183

184+
def close(self):
185+
if self._closed:
186+
return
187+
188+
self.from_server.put_nowait(
189+
Exception("waited message while StreamWriterMock closed")
190+
)
191+
175192
@pytest.fixture(autouse=True)
176193
async def stream_writer_double_queue(self, monkeypatch):
177194
class DoubleQueueWriters:

0 commit comments

Comments
 (0)