@@ -96,6 +96,32 @@ async def test_commit_offset_with_session_id_works(self, driver, topic_with_mess
9696 msg2 = await reader .receive_message ()
9797 assert msg2 .seqno == 2
9898
99+ async def test_commit_offset_retry_on_ydb_errors (self , driver , topic_with_messages , topic_consumer , monkeypatch ):
100+ async with driver .topic_client .reader (topic_with_messages , topic_consumer ) as reader :
101+ message = await reader .receive_message ()
102+
103+ call_count = 0
104+ original_driver_call = driver .topic_client ._driver
105+
106+ async def mock_driver_call (* args , ** kwargs ):
107+ nonlocal call_count
108+ call_count += 1
109+
110+ if call_count == 1 :
111+ raise ydb .Unavailable ("Service temporarily unavailable" )
112+ elif call_count == 2 :
113+ raise ydb .Cancelled ("Operation was cancelled" )
114+ else :
115+ return await original_driver_call (* args , ** kwargs )
116+
117+ monkeypatch .setattr (driver .topic_client , "_driver" , mock_driver_call )
118+
119+ await driver .topic_client .commit_offset (
120+ topic_with_messages , topic_consumer , message .partition_id , message .offset + 1
121+ )
122+
123+ assert call_count == 3
124+
99125 async def test_reader_reconnect_after_commit_offset (self , driver , topic_with_messages , topic_consumer ):
100126 async with driver .topic_client .reader (topic_with_messages , topic_consumer ) as reader :
101127 for out in ["123" , "456" , "789" , "0" ]:
@@ -257,6 +283,33 @@ def test_commit_offset_with_session_id_works(self, driver_sync, topic_with_messa
257283 msg2 = reader .receive_message ()
258284 assert msg2 .seqno == 2
259285
286+ def test_commit_offset_retry_on_ydb_errors (self , driver_sync , topic_with_messages , topic_consumer , monkeypatch ):
287+ with driver_sync .topic_client .reader (topic_with_messages , topic_consumer ) as reader :
288+ message = reader .receive_message ()
289+
290+ # Counter to track retry attempts
291+ call_count = 0
292+ original_driver_call = driver_sync .topic_client ._driver
293+
294+ def mock_driver_call (* args , ** kwargs ):
295+ nonlocal call_count
296+ call_count += 1
297+
298+ if call_count == 1 :
299+ raise ydb .Unavailable ("Service temporarily unavailable" )
300+ elif call_count == 2 :
301+ raise ydb .Cancelled ("Operation was cancelled" )
302+ else :
303+ return original_driver_call (* args , ** kwargs )
304+
305+ monkeypatch .setattr (driver_sync .topic_client , "_driver" , mock_driver_call )
306+
307+ driver_sync .topic_client .commit_offset (
308+ topic_with_messages , topic_consumer , message .partition_id , message .offset + 1
309+ )
310+
311+ assert call_count == 3
312+
260313 def test_reader_reconnect_after_commit_offset (self , driver_sync , topic_with_messages , topic_consumer ):
261314 with driver_sync .topic_client .reader (topic_with_messages , topic_consumer ) as reader :
262315 for out in ["123" , "456" , "789" , "0" ]:
0 commit comments