53 changes: 53 additions & 0 deletions tests/topics/test_topic_reader.py
@@ -96,6 +96,32 @@ async def test_commit_offset_with_session_id_works(self, driver, topic_with_messages,
msg2 = await reader.receive_message()
assert msg2.seqno == 2

async def test_commit_offset_retry_on_ydb_errors(self, driver, topic_with_messages, topic_consumer, monkeypatch):
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
message = await reader.receive_message()

call_count = 0
original_driver_call = driver.topic_client._driver

async def mock_driver_call(*args, **kwargs):
nonlocal call_count
call_count += 1

if call_count == 1:
raise ydb.Unavailable("Service temporarily unavailable")
elif call_count == 2:
raise ydb.Cancelled("Operation was cancelled")
else:
return await original_driver_call(*args, **kwargs)

monkeypatch.setattr(driver.topic_client, "_driver", mock_driver_call)

await driver.topic_client.commit_offset(
topic_with_messages, topic_consumer, message.partition_id, message.offset + 1
)

assert call_count == 3

async def test_reader_reconnect_after_commit_offset(self, driver, topic_with_messages, topic_consumer):
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
for out in ["123", "456", "789", "0"]:
@@ -257,6 +283,33 @@ def test_commit_offset_with_session_id_works(self, driver_sync, topic_with_messages,
msg2 = reader.receive_message()
assert msg2.seqno == 2

def test_commit_offset_retry_on_ydb_errors(self, driver_sync, topic_with_messages, topic_consumer, monkeypatch):
with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader:
message = reader.receive_message()

# Counter to track retry attempts
call_count = 0
original_driver_call = driver_sync.topic_client._driver

def mock_driver_call(*args, **kwargs):
nonlocal call_count
call_count += 1

if call_count == 1:
raise ydb.Unavailable("Service temporarily unavailable")
elif call_count == 2:
raise ydb.Cancelled("Operation was cancelled")
else:
return original_driver_call(*args, **kwargs)

monkeypatch.setattr(driver_sync.topic_client, "_driver", mock_driver_call)

driver_sync.topic_client.commit_offset(
topic_with_messages, topic_consumer, message.partition_id, message.offset + 1
)

assert call_count == 3

def test_reader_reconnect_after_commit_offset(self, driver_sync, topic_with_messages, topic_consumer):
with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader:
for out in ["123", "456", "789", "0"]:
63 changes: 63 additions & 0 deletions ydb/retries.py
@@ -1,4 +1,6 @@
import asyncio
import functools
import inspect
import random
import time

@@ -161,3 +163,64 @@ async def retry_operation_async(callee, retry_settings=None, *args, **kwargs):
return await next_opt.result
except BaseException as e: # pylint: disable=W0703
next_opt.set_exception(e)


def ydb_retry(
Member: What about importing it into the ydb namespace for common use as a retry decorator for all operations, independent from topic?

Collaborator (Author): We already import it in the __init__ file:
from .retries import *  # noqa

max_retries=10,
Member: What about using a timeout as the retry limit by default?

max_session_acquire_timeout=None,
on_ydb_error_callback=None,
backoff_ceiling=6,
backoff_slot_duration=1,
get_session_client_timeout=5,
fast_backoff_settings=None,
slow_backoff_settings=None,
idempotent=False,
retry_cancelled=False,
):
"""
Decorator for automatic function retry in case of YDB errors.

Supports both synchronous and asynchronous functions.

:param max_retries: Maximum number of retries (default: 10)
:param max_session_acquire_timeout: Maximum session acquisition timeout (default: None)
:param on_ydb_error_callback: Callback for handling YDB errors (default: None)
:param backoff_ceiling: Ceiling for backoff algorithm (default: 6)
:param backoff_slot_duration: Slot duration for backoff (default: 1)
:param get_session_client_timeout: Session client timeout (default: 5)
:param fast_backoff_settings: Fast backoff settings (default: None)
:param slow_backoff_settings: Slow backoff settings (default: None)
:param idempotent: Whether the operation is idempotent (default: False)
:param retry_cancelled: Whether to retry cancelled operations (default: False)
"""

def decorator(func):
retry_settings = RetrySettings(
max_retries=max_retries,
max_session_acquire_timeout=max_session_acquire_timeout,
on_ydb_error_callback=on_ydb_error_callback,
backoff_ceiling=backoff_ceiling,
backoff_slot_duration=backoff_slot_duration,
get_session_client_timeout=get_session_client_timeout,
fast_backoff_settings=fast_backoff_settings,
slow_backoff_settings=slow_backoff_settings,
idempotent=idempotent,
retry_cancelled=retry_cancelled,
)

if inspect.iscoroutinefunction(func):

@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
return await retry_operation_async(func, retry_settings, *args, **kwargs)

return async_wrapper
else:

@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
return retry_operation_sync(func, retry_settings, *args, **kwargs)

return sync_wrapper

return decorator
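
As a usage illustration (not part of this diff), and assuming ydb_retry is re-exported at the package root via the `from .retries import *` mentioned above, the decorator could wrap an arbitrary user operation roughly as sketched below; the function name, topic path, and consumer name are hypothetical:

import ydb

# Hypothetical user code: the decorator dispatches to retry_operation_async for
# coroutine functions and to retry_operation_sync otherwise, re-invoking the
# wrapped callable on retriable YDB errors such as ydb.Unavailable.
@ydb.ydb_retry(max_retries=5, idempotent=True)
async def read_one_message(driver, topic, consumer):
    async with driver.topic_client.reader(topic, consumer) as reader:
        return await reader.receive_message()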
4 changes: 4 additions & 0 deletions ydb/topic.py
@@ -98,6 +98,8 @@
PublicAlterAutoPartitioningSettings as TopicAlterAutoPartitioningSettings,
)

from .retries import ydb_retry

logger = logging.getLogger(__name__)


@@ -348,6 +350,7 @@ def tx_writer(

return TopicTxWriterAsyncIO(tx=tx, driver=self._driver, settings=settings, _client=self)

@ydb_retry(retry_cancelled=True, idempotent=True)
async def commit_offset(
self, path: str, consumer: str, partition_id: int, offset: int, read_session_id: Optional[str] = None
) -> None:
@@ -645,6 +648,7 @@ def tx_writer(

return TopicTxWriter(tx, self._driver, settings, _parent=self)

@ydb_retry(retry_cancelled=True, idempotent=True)
def commit_offset(
self, path: str, consumer: str, partition_id: int, offset: int, read_session_id: Optional[str] = None
) -> None:
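
With the decorator applied, a caller no longer needs an explicit retry loop around commit_offset. A minimal sketch of the sync usage, with a hypothetical topic path and consumer name:

# Sketch only: commit_offset now retries internally on retriable YDB errors
# (including ydb.Cancelled, since retry_cancelled=True), so a single call suffices.
with driver_sync.topic_client.reader("my/topic", "my-consumer") as reader:
    message = reader.receive_message()
    driver_sync.topic_client.commit_offset(
        "my/topic", "my-consumer", message.partition_id, message.offset + 1
    )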