11import asyncio
2+ import contextlib
23from typing import Optional
34
4- from ydb import issues , StatusCode
5+ from ydb import issues
56from ydb ._grpc .grpcwrapper .ydb_coordination import (
67 AcquireSemaphore ,
78 ReleaseSemaphore ,
1718
1819
1920class CoordinationLock :
20- def __init__ (
21- self ,
22- client ,
23- name : str ,
24- node_path : Optional [str ] = None ,
25- ):
21+ def __init__ (self , client , name : str , node_path : str = None ):
2622 self ._client = client
2723 self ._driver = client ._driver
2824 self ._name = name
2925 self ._node_path = node_path
3026
31- self ._req_id : Optional [int ] = None
3227 self ._count : int = 1
3328 self ._timeout_millis : int = 30000
34- self ._next_req_id : int = 1
35-
36- self ._request_queue : asyncio .Queue = asyncio .Queue ()
37- self ._stream : Optional [CoordinationStream ] = None
3829
30+ self ._stream = None
3931 self ._reconnector = CoordinationReconnector (
4032 driver = self ._driver ,
41- request_queue = self . _request_queue ,
33+ request_queue = asyncio . Queue () ,
4234 node_path = self ._node_path ,
4335 timeout_millis = self ._timeout_millis ,
4436 )
4537
4638 self ._wait_timeout : float = self ._timeout_millis / 1000.0
4739
48- def next_req_id (self ) -> int :
49- r = self ._next_req_id
50- self ._next_req_id += 1
51- return r
52-
53- async def send (self , req ):
54- if self ._stream is None :
55- raise issues .Error ("Stream is not started yet" )
56- await self ._stream .send (req )
40+ self ._pending_futures = {
41+ "acquire" : [],
42+ "release" : [],
43+ "create" : [],
44+ "delete" : [],
45+ "describe" : [],
46+ "update" : [],
47+ }
5748
5849 async def _ensure_session (self ):
5950 if self ._stream is not None and self ._stream .session_id is not None :
@@ -64,75 +55,89 @@ async def _ensure_session(self):
6455
6556 self ._reconnector .start ()
6657 await self ._reconnector .wait_ready ()
67-
6858 self ._stream = self ._reconnector .get_stream ()
6959
70- async def _wait_for_response (self , req_id : int , * , kind : str ):
60+ if not hasattr (self ._stream , "_dispatch_task" ) or self ._stream ._dispatch_task is None :
61+ self ._stream ._dispatch_task = asyncio .create_task (self ._stream_dispatch_loop ())
62+
63+ async def _stream_dispatch_loop (self ):
7164 try :
7265 while True :
7366 resp = await self ._stream .receive (self ._wait_timeout )
74-
67+ print ( "[RECV RAW]" , resp )
7568 fs = FromServer .from_proto (resp )
76-
77- if kind == "acquire" :
78- r = fs .acquire_semaphore_result
79- elif kind == "describe" :
80- r = fs .describe_semaphore_result
81- elif kind == "create" :
82- r = fs .create_semaphore_result
83- elif kind == "update" :
84- r = fs .update_semaphore_result
85- elif kind == "delete" :
86- r = fs .delete_semaphore_result
69+ print ("[RECV PARSED]" , fs )
70+
71+ raw = fs .raw
72+
73+ if raw .HasField ("acquire_semaphore_result" ):
74+ op_type = "acquire"
75+ payload = fs .acquire_semaphore_result
76+ elif raw .HasField ("describe_semaphore_result" ):
77+ op_type = "describe"
78+ payload = fs .describe_semaphore_result
79+ elif raw .HasField ("create_semaphore_result" ):
80+ op_type = "create"
81+ payload = fs .create_semaphore_result
82+ elif raw .HasField ("update_semaphore_result" ):
83+ op_type = "update"
84+ payload = fs .update_semaphore_result
85+ elif raw .HasField ("delete_semaphore_result" ):
86+ op_type = "delete"
87+ payload = fs .delete_semaphore_result
8788 else :
88- r = None
89-
90- if r and r .req_id == req_id :
91- return r
92-
93- except asyncio .TimeoutError :
94- action = {
95- "acquire" : "acquisition" ,
96- "describe" : "describe" ,
97- "update" : "update" ,
98- "delete" : "delete" ,
99- "create" : "create" ,
100- }.get (kind , "operation" )
101-
102- raise issues .Error (f"Timeout waiting for lock { self ._name } { action } " )
89+ continue
90+
91+ futures = self ._pending_futures .get (op_type , [])
92+ for fut in futures :
93+ if not fut .done ():
94+ print ("[RESOLVE FUTURE]" , fut )
95+ fut .set_result (payload )
96+ self ._pending_futures [op_type ] = []
97+
98+ except asyncio .CancelledError :
99+ for futs in self ._pending_futures .values ():
100+ for fut in futs :
101+ if not fut .done ():
102+ fut .set_exception (asyncio .CancelledError ())
103+ futs .clear ()
104+ raise
105+ except Exception as exc :
106+ for futs in self ._pending_futures .values ():
107+ for fut in futs :
108+ if not fut .done ():
109+ fut .set_exception (exc )
110+ futs .clear ()
111+ with contextlib .suppress (Exception ):
112+ await self ._stream .close ()
113+ return
103114
104- async def __aenter__ (self ):
115+ async def _send_and_wait (self , req , op_type : str ):
105116 await self ._ensure_session ()
117+ loop = asyncio .get_running_loop ()
118+ fut = loop .create_future ()
119+ self ._pending_futures [op_type ].append (fut )
120+ await self ._stream .send (req )
121+ return await asyncio .wait_for (fut , timeout = self ._wait_timeout )
106122
107- req_id = self .next_req_id ()
108- self ._req_id = req_id
109-
123+ async def __aenter__ (self ):
110124 req = AcquireSemaphore (
111- req_id = req_id ,
125+ req_id = 0 ,
112126 name = self ._name ,
113127 count = self ._count ,
114128 ephemeral = False ,
115129 timeout_millis = self ._timeout_millis ,
116130 )
117-
118- await self .send (req )
119- await self ._wait_for_response (req_id , kind = "acquire" )
120-
131+ await self ._send_and_wait (req , "acquire" )
121132 return self
122133
123-
124134 async def __aexit__ (self , exc_type , exc , tb ):
125- if self ._req_id is not None :
126- try :
127- req = ReleaseSemaphore (
128- req_id = self ._req_id ,
129- name = self ._name ,
130- )
131- await self .send (req )
132- except issues .Error :
133- pass
134-
135- self ._req_id = None
135+ try :
136+ req = ReleaseSemaphore (req_id = 0 , name = self ._name )
137+ if self ._stream is not None :
138+ await self ._stream .send (req )
139+ except issues .Error :
140+ pass
136141
137142 async def acquire (self ):
138143 return await self .__aenter__ ()
@@ -141,67 +146,42 @@ async def release(self):
141146 await self .__aexit__ (None , None , None )
142147
143148 async def create (self , init_limit , init_data ):
144- await self ._ensure_session ()
145-
146- req_id = self .next_req_id ()
147-
148- req = CreateSemaphore (req_id = req_id , name = self ._name , limit = init_limit , data = init_data )
149-
150- await self .send (req )
151-
152- resp = await self ._wait_for_response (req_id , kind = "create" )
149+ req = CreateSemaphore (req_id = 0 , name = self ._name , limit = init_limit , data = init_data )
150+ resp = await self ._send_and_wait (req , "create" )
153151 return CreateSemaphoreResult .from_proto (resp )
154152
155153 async def delete (self ):
156- await self ._ensure_session ()
157- req_id = self .next_req_id ()
158- req = DeleteSemaphore (req_id = req_id , name = self ._name )
159- await self .send (req )
160- resp = await self ._wait_for_response (req_id , kind = "delete" )
154+ req = DeleteSemaphore (req_id = 0 , name = self ._name )
155+ resp = await self ._send_and_wait (req , "delete" )
161156 return resp
162157
163158 async def describe (self ):
164- await self ._ensure_session ()
165-
166- req_id = self .next_req_id ()
167-
168159 req = DescribeSemaphore (
169- req_id = req_id ,
160+ req_id = 0 ,
170161 name = self ._name ,
171162 include_owners = True ,
172163 include_waiters = True ,
173164 watch_data = False ,
174165 watch_owners = False ,
175166 )
176-
177- await self .send (req )
178-
179- resp = await self ._wait_for_response (req_id , kind = "describe" )
167+ resp = await self ._send_and_wait (req , "describe" )
180168 return DescribeLockResult .from_proto (resp )
181169
182170 async def update (self , new_data ):
183- await self ._ensure_session ()
184-
185- req_id = self .next_req_id ()
186- req = UpdateSemaphore (req_id = req_id , name = self ._name , data = new_data )
187-
188- await self .send (req )
189-
190- resp = await self ._wait_for_response (req_id , kind = "update" )
171+ req = UpdateSemaphore (req_id = 0 , name = self ._name , data = new_data )
172+ resp = await self ._send_and_wait (req , "update" )
191173 return resp
192174
193175 async def close (self , flush : bool = True ):
194176 try :
195- if self ._req_id is not None :
196- req = ReleaseSemaphore (req_id = self ._req_id , name = self ._name )
197- if self ._stream is not None :
198- await self .send (req )
177+ req = ReleaseSemaphore (req_id = 0 , name = self ._name )
178+ if self ._stream is not None :
179+ await self ._stream .send (req )
199180 except issues .Error :
200181 pass
201182
202183 if self ._reconnector is not None :
203184 await self ._reconnector .stop (flush )
204185
205186 self ._stream = None
206- self ._req_id = None
207187 self ._node_path = None
0 commit comments