Skip to content

Commit 52f0874

Browse files
committed
get red of req_id logic and loop on every call
1 parent 633f2d4 commit 52f0874

File tree

1 file changed

+88
-108
lines changed

1 file changed

+88
-108
lines changed

ydb/aio/coordination/lock.py

Lines changed: 88 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
2+
import contextlib
23
from typing import Optional
34

4-
from ydb import issues, StatusCode
5+
from ydb import issues
56
from ydb._grpc.grpcwrapper.ydb_coordination import (
67
AcquireSemaphore,
78
ReleaseSemaphore,
@@ -17,43 +18,33 @@
1718

1819

1920
class CoordinationLock:
20-
def __init__(
21-
self,
22-
client,
23-
name: str,
24-
node_path: Optional[str] = None,
25-
):
21+
def __init__(self, client, name: str, node_path: str = None):
2622
self._client = client
2723
self._driver = client._driver
2824
self._name = name
2925
self._node_path = node_path
3026

31-
self._req_id: Optional[int] = None
3227
self._count: int = 1
3328
self._timeout_millis: int = 30000
34-
self._next_req_id: int = 1
35-
36-
self._request_queue: asyncio.Queue = asyncio.Queue()
37-
self._stream: Optional[CoordinationStream] = None
3829

30+
self._stream = None
3931
self._reconnector = CoordinationReconnector(
4032
driver=self._driver,
41-
request_queue=self._request_queue,
33+
request_queue=asyncio.Queue(),
4234
node_path=self._node_path,
4335
timeout_millis=self._timeout_millis,
4436
)
4537

4638
self._wait_timeout: float = self._timeout_millis / 1000.0
4739

48-
def next_req_id(self) -> int:
49-
r = self._next_req_id
50-
self._next_req_id += 1
51-
return r
52-
53-
async def send(self, req):
54-
if self._stream is None:
55-
raise issues.Error("Stream is not started yet")
56-
await self._stream.send(req)
40+
self._pending_futures = {
41+
"acquire": [],
42+
"release": [],
43+
"create": [],
44+
"delete": [],
45+
"describe": [],
46+
"update": [],
47+
}
5748

5849
async def _ensure_session(self):
5950
if self._stream is not None and self._stream.session_id is not None:
@@ -64,75 +55,89 @@ async def _ensure_session(self):
6455

6556
self._reconnector.start()
6657
await self._reconnector.wait_ready()
67-
6858
self._stream = self._reconnector.get_stream()
6959

70-
async def _wait_for_response(self, req_id: int, *, kind: str):
60+
if not hasattr(self._stream, "_dispatch_task") or self._stream._dispatch_task is None:
61+
self._stream._dispatch_task = asyncio.create_task(self._stream_dispatch_loop())
62+
63+
async def _stream_dispatch_loop(self):
7164
try:
7265
while True:
7366
resp = await self._stream.receive(self._wait_timeout)
74-
67+
print("[RECV RAW]", resp)
7568
fs = FromServer.from_proto(resp)
76-
77-
if kind == "acquire":
78-
r = fs.acquire_semaphore_result
79-
elif kind == "describe":
80-
r = fs.describe_semaphore_result
81-
elif kind == "create":
82-
r = fs.create_semaphore_result
83-
elif kind == "update":
84-
r = fs.update_semaphore_result
85-
elif kind == "delete":
86-
r = fs.delete_semaphore_result
69+
print("[RECV PARSED]", fs)
70+
71+
raw = fs.raw
72+
73+
if raw.HasField("acquire_semaphore_result"):
74+
op_type = "acquire"
75+
payload = fs.acquire_semaphore_result
76+
elif raw.HasField("describe_semaphore_result"):
77+
op_type = "describe"
78+
payload = fs.describe_semaphore_result
79+
elif raw.HasField("create_semaphore_result"):
80+
op_type = "create"
81+
payload = fs.create_semaphore_result
82+
elif raw.HasField("update_semaphore_result"):
83+
op_type = "update"
84+
payload = fs.update_semaphore_result
85+
elif raw.HasField("delete_semaphore_result"):
86+
op_type = "delete"
87+
payload = fs.delete_semaphore_result
8788
else:
88-
r = None
89-
90-
if r and r.req_id == req_id:
91-
return r
92-
93-
except asyncio.TimeoutError:
94-
action = {
95-
"acquire": "acquisition",
96-
"describe": "describe",
97-
"update": "update",
98-
"delete": "delete",
99-
"create": "create",
100-
}.get(kind, "operation")
101-
102-
raise issues.Error(f"Timeout waiting for lock {self._name} {action}")
89+
continue
90+
91+
futures = self._pending_futures.get(op_type, [])
92+
for fut in futures:
93+
if not fut.done():
94+
print("[RESOLVE FUTURE]", fut)
95+
fut.set_result(payload)
96+
self._pending_futures[op_type] = []
97+
98+
except asyncio.CancelledError:
99+
for futs in self._pending_futures.values():
100+
for fut in futs:
101+
if not fut.done():
102+
fut.set_exception(asyncio.CancelledError())
103+
futs.clear()
104+
raise
105+
except Exception as exc:
106+
for futs in self._pending_futures.values():
107+
for fut in futs:
108+
if not fut.done():
109+
fut.set_exception(exc)
110+
futs.clear()
111+
with contextlib.suppress(Exception):
112+
await self._stream.close()
113+
return
103114

104-
async def __aenter__(self):
115+
async def _send_and_wait(self, req, op_type: str):
105116
await self._ensure_session()
117+
loop = asyncio.get_running_loop()
118+
fut = loop.create_future()
119+
self._pending_futures[op_type].append(fut)
120+
await self._stream.send(req)
121+
return await asyncio.wait_for(fut, timeout=self._wait_timeout)
106122

107-
req_id = self.next_req_id()
108-
self._req_id = req_id
109-
123+
async def __aenter__(self):
110124
req = AcquireSemaphore(
111-
req_id=req_id,
125+
req_id=0,
112126
name=self._name,
113127
count=self._count,
114128
ephemeral=False,
115129
timeout_millis=self._timeout_millis,
116130
)
117-
118-
await self.send(req)
119-
await self._wait_for_response(req_id, kind="acquire")
120-
131+
await self._send_and_wait(req, "acquire")
121132
return self
122133

123-
124134
async def __aexit__(self, exc_type, exc, tb):
125-
if self._req_id is not None:
126-
try:
127-
req = ReleaseSemaphore(
128-
req_id=self._req_id,
129-
name=self._name,
130-
)
131-
await self.send(req)
132-
except issues.Error:
133-
pass
134-
135-
self._req_id = None
135+
try:
136+
req = ReleaseSemaphore(req_id=0, name=self._name)
137+
if self._stream is not None:
138+
await self._stream.send(req)
139+
except issues.Error:
140+
pass
136141

137142
async def acquire(self):
138143
return await self.__aenter__()
@@ -141,67 +146,42 @@ async def release(self):
141146
await self.__aexit__(None, None, None)
142147

143148
async def create(self, init_limit, init_data):
144-
await self._ensure_session()
145-
146-
req_id = self.next_req_id()
147-
148-
req = CreateSemaphore(req_id=req_id, name=self._name, limit=init_limit, data=init_data)
149-
150-
await self.send(req)
151-
152-
resp = await self._wait_for_response(req_id, kind="create")
149+
req = CreateSemaphore(req_id=0, name=self._name, limit=init_limit, data=init_data)
150+
resp = await self._send_and_wait(req, "create")
153151
return CreateSemaphoreResult.from_proto(resp)
154152

155153
async def delete(self):
156-
await self._ensure_session()
157-
req_id = self.next_req_id()
158-
req = DeleteSemaphore(req_id=req_id, name=self._name)
159-
await self.send(req)
160-
resp = await self._wait_for_response(req_id, kind="delete")
154+
req = DeleteSemaphore(req_id=0, name=self._name)
155+
resp = await self._send_and_wait(req, "delete")
161156
return resp
162157

163158
async def describe(self):
164-
await self._ensure_session()
165-
166-
req_id = self.next_req_id()
167-
168159
req = DescribeSemaphore(
169-
req_id=req_id,
160+
req_id=0,
170161
name=self._name,
171162
include_owners=True,
172163
include_waiters=True,
173164
watch_data=False,
174165
watch_owners=False,
175166
)
176-
177-
await self.send(req)
178-
179-
resp = await self._wait_for_response(req_id, kind="describe")
167+
resp = await self._send_and_wait(req, "describe")
180168
return DescribeLockResult.from_proto(resp)
181169

182170
async def update(self, new_data):
183-
await self._ensure_session()
184-
185-
req_id = self.next_req_id()
186-
req = UpdateSemaphore(req_id=req_id, name=self._name, data=new_data)
187-
188-
await self.send(req)
189-
190-
resp = await self._wait_for_response(req_id, kind="update")
171+
req = UpdateSemaphore(req_id=0, name=self._name, data=new_data)
172+
resp = await self._send_and_wait(req, "update")
191173
return resp
192174

193175
async def close(self, flush: bool = True):
194176
try:
195-
if self._req_id is not None:
196-
req = ReleaseSemaphore(req_id=self._req_id, name=self._name)
197-
if self._stream is not None:
198-
await self.send(req)
177+
req = ReleaseSemaphore(req_id=0, name=self._name)
178+
if self._stream is not None:
179+
await self._stream.send(req)
199180
except issues.Error:
200181
pass
201182

202183
if self._reconnector is not None:
203184
await self._reconnector.stop(flush)
204185

205186
self._stream = None
206-
self._req_id = None
207187
self._node_path = None

0 commit comments

Comments
 (0)