Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions ydb/_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,10 @@ def __next__(self):


class SyncResponseIterator(object):
def __init__(self, it, wrapper):
def __init__(self, it, wrapper, error_converter=None):
self.it = it
self.wrapper = wrapper
self.error_converter = error_converter

def cancel(self):
self.it.cancel()
Expand All @@ -161,9 +162,16 @@ def __iter__(self):
return self

def _next(self):
res = self.wrapper(next(self.it))
try:
res = self.wrapper(next(self.it))
except BaseException as e:
if self.error_converter:
raise self.error_converter(e) from e
raise e

if res is not None:
return res

return self._next()

def next(self):
Expand Down
6 changes: 4 additions & 2 deletions ydb/query/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .transaction import QueryTxContext

from .._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT, DEFAULT_LONG_STREAM_TIMEOUT
from .._errors import stream_error_converter


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -362,12 +363,13 @@ def execute(
)

return base.SyncResponseContextIterator(
stream_it,
lambda resp: base.wrap_execute_query_response(
it=stream_it,
wrapper=lambda resp: base.wrap_execute_query_response(
rpc_state=None,
response_pb=resp,
session_state=self._state,
session=self,
settings=self._settings,
),
error_converter=stream_error_converter,
)
6 changes: 4 additions & 2 deletions ydb/query/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..connection import _RpcState as RpcState

from . import base
from .._errors import stream_error_converter
from ..settings import BaseRequestSettings

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -500,14 +501,15 @@ def execute(
)

self._prev_stream = base.SyncResponseContextIterator(
stream_it,
lambda resp: base.wrap_execute_query_response(
it=stream_it,
wrapper=lambda resp: base.wrap_execute_query_response(
rpc_state=None,
response_pb=resp,
session_state=self._session_state,
tx=self,
commit_tx=commit_tx,
settings=self.session._settings,
),
error_converter=stream_error_converter,
)
return self._prev_stream
Loading