33import enum
44import functools
55from typing import (
6+ Iterable ,
67 Optional ,
78)
89
1112 issues ,
1213)
1314from .._grpc .grpcwrapper import ydb_query as _ydb_query
15+ from ..connection import _RpcState as RpcState
1416
1517from . import base
1618
@@ -50,7 +52,7 @@ def terminal(cls, state: QueryTxStateEnum) -> bool:
5052
5153def reset_tx_id_handler (func ):
5254 @functools .wraps (func )
53- def decorator (rpc_state , response_pb , session_state , tx_state , * args , ** kwargs ):
55+ def decorator (rpc_state , response_pb , session_state : base . IQuerySessionState , tx_state : QueryTxState , * args , ** kwargs ):
5456 try :
5557 return func (rpc_state , response_pb , session_state , tx_state , * args , ** kwargs )
5658 except issues .Error :
@@ -87,35 +89,47 @@ def _already_in(self, target: QueryTxStateEnum) -> bool:
8789 return self ._state == target
8890
8991
90- def _construct_tx_settings (tx_state ) :
92+ def _construct_tx_settings (tx_state : QueryTxState ) -> _ydb_query . TransactionSettings :
9193 tx_settings = _ydb_query .TransactionSettings .from_public (tx_state .tx_mode )
9294 return tx_settings
9395
9496
95- def _create_begin_transaction_request (session_state , tx_state ):
97+ def _create_begin_transaction_request (
98+ session_state : base .IQuerySessionState , tx_state : QueryTxState
99+ ) -> _apis .ydb_query .BeginTransactionRequest :
96100 request = _ydb_query .BeginTransactionRequest (
97101 session_id = session_state .session_id ,
98102 tx_settings = _construct_tx_settings (tx_state ),
99103 ).to_proto ()
100104 return request
101105
102106
103- def _create_commit_transaction_request (session_state , tx_state ):
107+ def _create_commit_transaction_request (
108+ session_state : base .IQuerySessionState , tx_state : QueryTxState
109+ ) -> _apis .ydb_query .CommitTransactionRequest :
104110 request = _apis .ydb_query .CommitTransactionRequest ()
105111 request .tx_id = tx_state .tx_id
106112 request .session_id = session_state .session_id
107113 return request
108114
109115
110- def _create_rollback_transaction_request (session_state , tx_state ):
116+ def _create_rollback_transaction_request (
117+ session_state : base .IQuerySessionState , tx_state : QueryTxState
118+ ) -> _apis .ydb_query .RollbackTransactionRequest :
111119 request = _apis .ydb_query .RollbackTransactionRequest ()
112120 request .tx_id = tx_state .tx_id
113121 request .session_id = session_state .session_id
114122 return request
115123
116124
117125@base .bad_session_handler
118- def wrap_tx_begin_response (rpc_state , response_pb , session_state , tx_state , tx ):
126+ def wrap_tx_begin_response (
127+ rpc_state : RpcState ,
128+ response_pb : _apis .ydb_query .BeginTransactionResponse ,
129+ session_state : base .IQuerySessionState ,
130+ tx_state : QueryTxState ,
131+ tx : "BaseQueryTxContext" ,
132+ ) -> "BaseQueryTxContext" :
119133 message = _ydb_query .BeginTransactionResponse .from_proto (response_pb )
120134 issues ._process_response (message .status )
121135 tx_state ._change_state (QueryTxStateEnum .BEGINED )
@@ -125,7 +139,13 @@ def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx):
125139
126140@base .bad_session_handler
127141@reset_tx_id_handler
128- def wrap_tx_commit_response (rpc_state , response_pb , session_state , tx_state , tx ):
142+ def wrap_tx_commit_response (
143+ rpc_state : RpcState ,
144+ response_pb : _apis .ydb_query .CommitTransactionResponse ,
145+ session_state : base .IQuerySessionState ,
146+ tx_state : QueryTxState ,
147+ tx : "BaseQueryTxContext" ,
148+ ) -> "BaseQueryTxContext" :
129149 message = _ydb_query .CommitTransactionResponse .from_proto (response_pb )
130150 issues ._process_response (message .status )
131151 tx_state ._change_state (QueryTxStateEnum .COMMITTED )
@@ -134,7 +154,13 @@ def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx)
134154
135155@base .bad_session_handler
136156@reset_tx_id_handler
137- def wrap_tx_rollback_response (rpc_state , response_pb , session_state , tx_state , tx ):
157+ def wrap_tx_rollback_response (
158+ rpc_state : RpcState ,
159+ response_pb : _apis .ydb_query .RollbackTransactionResponse ,
160+ session_state : base .IQuerySessionState ,
161+ tx_state : QueryTxState ,
162+ tx : "BaseQueryTxContext" ,
163+ ) -> "BaseQueryTxContext" :
138164 message = _ydb_query .RollbackTransactionResponse .from_proto (response_pb )
139165 issues ._process_response (message .status )
140166 tx_state ._change_state (QueryTxStateEnum .ROLLBACKED )
@@ -211,7 +237,7 @@ def tx_id(self) -> Optional[str]:
211237 """
212238 return self ._tx_state .tx_id
213239
214- def _begin_call (self , settings : Optional [base .QueryClientSettings ]):
240+ def _begin_call (self , settings : Optional [base .QueryClientSettings ]) -> "BaseQueryTxContext" :
215241 return self ._driver (
216242 _create_begin_transaction_request (self ._session_state , self ._tx_state ),
217243 _apis .QueryService .Stub ,
@@ -221,7 +247,7 @@ def _begin_call(self, settings: Optional[base.QueryClientSettings]):
221247 (self ._session_state , self ._tx_state , self ),
222248 )
223249
224- def _commit_call (self , settings : Optional [base .QueryClientSettings ]):
250+ def _commit_call (self , settings : Optional [base .QueryClientSettings ]) -> "BaseQueryTxContext" :
225251 return self ._driver (
226252 _create_commit_transaction_request (self ._session_state , self ._tx_state ),
227253 _apis .QueryService .Stub ,
@@ -231,7 +257,7 @@ def _commit_call(self, settings: Optional[base.QueryClientSettings]):
231257 (self ._session_state , self ._tx_state , self ),
232258 )
233259
234- def _rollback_call (self , settings : Optional [base .QueryClientSettings ]):
260+ def _rollback_call (self , settings : Optional [base .QueryClientSettings ]) -> "BaseQueryTxContext" :
235261 return self ._driver (
236262 _create_rollback_transaction_request (self ._session_state , self ._tx_state ),
237263 _apis .QueryService .Stub ,
@@ -249,7 +275,7 @@ def _execute_call(
249275 exec_mode : base .QueryExecMode = None ,
250276 parameters : dict = None ,
251277 concurrent_result_sets : bool = False ,
252- ):
278+ ) -> Iterable [ _apis . ydb_query . ExecuteQueryResponsePart ] :
253279 request = base .create_execute_query_request (
254280 query = query ,
255281 session_id = self ._session_state .session_id ,
@@ -263,23 +289,24 @@ def _execute_call(
263289 )
264290
265291 return self ._driver (
266- request ,
292+ request . to_proto () ,
267293 _apis .QueryService .Stub ,
268294 _apis .QueryService .ExecuteQuery ,
269295 )
270296
271- def _ensure_prev_stream_finished (self ):
297+ def _ensure_prev_stream_finished (self ) -> None :
272298 if self ._prev_stream is not None :
273299 for _ in self ._prev_stream :
274300 pass
275301 self ._prev_stream = None
276302
277- def _handle_tx_meta (self , tx_meta = None ):
278- if not self .tx_id and tx_meta :
279- self ._tx_state ._change_state (QueryTxStateEnum .BEGINED )
280- self ._tx_state .tx_id = tx_meta .id
303+ def _move_to_beginned (self , tx_id : str ) -> None :
304+ if self ._tx_state ._already_in (QueryTxStateEnum .BEGINED ):
305+ return
306+ self ._tx_state ._change_state (QueryTxStateEnum .BEGINED )
307+ self ._tx_state .tx_id = tx_id
281308
282- def _move_to_commited (self ):
309+ def _move_to_commited (self ) -> None :
283310 if self ._tx_state ._already_in (QueryTxStateEnum .COMMITTED ):
284311 return
285312 self ._tx_state ._change_state (QueryTxStateEnum .COMMITTED )
0 commit comments