Commit c2c7ce6

Merge pull request #10 from workflowai/guillaume/fix-retry-logic
fix: retry logic returns
2 parents d3d937b + a7eb2d5 commit c2c7ce6

7 files changed: +166 -92 lines changed

conftest.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+import pytest
+from freezegun import freeze_time
+
+
+@pytest.fixture()
+def frozen_time():
+    with freeze_time("2024-01-01T00:00:00Z") as frozen_time:
+        yield frozen_time
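
The fixture yields freezegun's frozen clock factory, so a test can advance time explicitly instead of sleeping. A minimal usage sketch, assuming only this fixture and freezegun's tick() API; the test name and duration are illustrative:

from datetime import timedelta
from time import time


def test_clock_is_frozen(frozen_time):
    # time() is patched by freezegun and starts at 2024-01-01T00:00:00Z
    start = time()
    # advance the frozen clock by exactly 10 seconds, with no real waiting
    frozen_time.tick(timedelta(seconds=10))
    assert time() - start == 10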

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "workflowai"
-version = "0.2.2"
+version = "0.2.3"
 description = ""
 authors = ["Guillaume Aquilina <guillaume@workflowai.com>"]
 readme = "README.md"

workflowai/core/client/__init__.py

Lines changed: 6 additions & 9 deletions
@@ -33,9 +33,8 @@ async def run(
         use_cache: "cache_usage.CacheUsage" = "when_available",
         labels: Optional[set[str]] = None,
         metadata: Optional[dict[str, Any]] = None,
-        retry_delay: int = 5000,
-        max_retry_delay: int = 60000,
-        max_retry_count: int = 1,
+        max_retry_delay: float = 60,
+        max_retry_count: float = 1,
     ) -> "task_run.TaskRun[task.TaskInput, task.TaskOutput]": ...

     @overload
@@ -50,9 +49,8 @@ async def run(
         use_cache: "cache_usage.CacheUsage" = "when_available",
         labels: Optional[set[str]] = None,
         metadata: Optional[dict[str, Any]] = None,
-        retry_delay: int = 5000,
-        max_retry_delay: int = 60000,
-        max_retry_count: int = 1,
+        max_retry_delay: float = 60,
+        max_retry_count: float = 1,
     ) -> AsyncIterator["task.TaskOutput"]: ...

     async def run(
@@ -66,9 +64,8 @@ async def run(
         use_cache: "cache_usage.CacheUsage" = "when_available",
         labels: Optional[set[str]] = None,
         metadata: Optional[dict[str, Any]] = None,
-        retry_delay: int = 5000,
-        max_retry_delay: int = 60000,
-        max_retry_count: int = 1,
+        max_retry_delay: float = 60,
+        max_retry_count: float = 1,
     ) -> Union[
         "task_run.TaskRun[task.TaskInput, task.TaskOutput]",
         AsyncIterator["task.TaskOutput"],
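
With this change the retry knobs are expressed in seconds rather than milliseconds, and the fixed retry_delay argument is gone: the wait between attempts is now driven by the server's Retry-After header. A hedged caller-side sketch, reusing the HelloTask fixtures from client_test.py as stand-ins for a real task:

# Retry up to 5 times, stopping once 30 seconds have elapsed since the first attempt.
task_run = await client.run(
    HelloTask(id="123", schema_id=1),
    task_input=HelloTaskInput(name="Alice"),
    max_retry_count=5,
    max_retry_delay=30,  # seconds; the 0.2.2 API took milliseconds (default 60000)
)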

workflowai/core/client/client.py

Lines changed: 63 additions & 80 deletions
@@ -1,7 +1,6 @@
-import asyncio
 import importlib.metadata
 import os
-from email.utils import parsedate_to_datetime
+from collections.abc import Awaitable, Callable
 from typing import (
     Any,
     AsyncIterator,
@@ -25,6 +24,7 @@
     RunTaskStreamChunk,
     TaskRunResponse,
 )
+from workflowai.core.client.utils import build_retryable_wait
 from workflowai.core.domain.cache_usage import CacheUsage
 from workflowai.core.domain.errors import BaseError, WorkflowAIError
 from workflowai.core.domain.task import Task, TaskInput, TaskOutput
@@ -77,9 +77,8 @@ async def run(
         use_cache: CacheUsage = "when_available",
         labels: Optional[set[str]] = None,
         metadata: Optional[dict[str, Any]] = None,
-        retry_delay: int = 5000,
-        max_retry_delay: int = 60000,
-        max_retry_count: int = 1,
+        max_retry_delay: float = 60,
+        max_retry_count: float = 1,
     ) -> TaskRun[TaskInput, TaskOutput]: ...

     @overload
@@ -94,12 +93,11 @@ async def run(
         use_cache: CacheUsage = "when_available",
         labels: Optional[set[str]] = None,
         metadata: Optional[dict[str, Any]] = None,
-        retry_delay: int = 5000,
-        max_retry_delay: int = 60000,
-        max_retry_count: int = 1,
+        max_retry_delay: float = 60,
+        max_retry_count: float = 1,
     ) -> AsyncIterator[TaskOutput]: ...

-    async def run(  # noqa: C901
+    async def run(
         self,
         task: Task[TaskInput, TaskOutput],
         task_input: TaskInput,
@@ -110,9 +108,8 @@ async def run(  # noqa: C901
         use_cache: CacheUsage = "when_available",
         labels: Optional[set[str]] = None,
         metadata: Optional[dict[str, Any]] = None,
-        retry_delay: int = 5000,
-        max_retry_delay: int = 60000,
-        max_retry_count: int = 1,
+        max_retry_delay: float = 60,
+        max_retry_count: float = 1,
     ) -> Union[TaskRun[TaskInput, TaskOutput], AsyncIterator[TaskOutput]]:
         await self._auto_register(task)

@@ -135,76 +132,62 @@ async def run(  # noqa: C901
         )

         route = f"/tasks/{task.id}/schemas/{task.schema_id}/run"
+        should_retry, wait_for_exception = build_retryable_wait(max_retry_delay, max_retry_count)

         if not stream:
-            res = None
-            delay = retry_delay / 1000
-            retry_count = 0
-            while retry_count < max_retry_count:
-                try:
-                    res = await self.api.post(route, request, returns=TaskRunResponse)
-                    return res.to_domain(task)
-                except HTTPStatusError as e:
-                    if e.response.status_code == 404:
-                        raise WorkflowAIError(
-                            error=BaseError(
-                                status_code=404,
-                                code="not_found",
-                                message="Task not found",
-                            ),
-                        ) from e
-                    retry_after = e.response.headers.get("Retry-After")
-                    if retry_after:
-                        try:
-                            # for 429 errors this is non-negative decimal
-                            delay = float(retry_after)
-                        except ValueError:
-                            try:
-                                retry_after_date = parsedate_to_datetime(retry_after)
-                                current_time = asyncio.get_event_loop().time()
-                                delay = retry_after_date.timestamp() - current_time
-                            except (TypeError, ValueError, OverflowError):
-                                delay = min(delay * 2, max_retry_delay / 1000)
-                        await asyncio.sleep(delay)
-                    elif e.response.status_code == 429:
-                        if delay < max_retry_delay / 1000:
-                            delay = min(delay * 2, max_retry_delay / 1000)
-                        await asyncio.sleep(delay)
-                    retry_count += 1
-
-        async def _stream():
-            delay = retry_delay / 1000
-            retry_count = 0
-            while retry_count < max_retry_count:
-                try:
-                    async for chunk in self.api.stream(
-                        method="POST",
-                        path=route,
-                        data=request,
-                        returns=RunTaskStreamChunk,
-                    ):
-                        yield task.output_class.model_construct(None, **chunk.task_output)
-                except HTTPStatusError as e:
-                    if e.response.status_code == 404:
-                        raise WorkflowAIError(error=BaseError(message="Task not found")) from e
-                    retry_after = e.response.headers.get("Retry-After")
-
-                    if retry_after:
-                        try:
-                            delay = float(retry_after)
-                        except ValueError:
-                            try:
-                                retry_after_date = parsedate_to_datetime(retry_after)
-                                current_time = asyncio.get_event_loop().time()
-                                delay = retry_after_date.timestamp() - current_time
-                            except (TypeError, ValueError, OverflowError):
-                                delay = min(delay * 2, max_retry_delay / 1000)
-                    elif e.response.status_code == 429 and delay < max_retry_delay / 1000:
-                        delay = min(delay * 2, max_retry_delay / 1000)
-                    await asyncio.sleep(delay)
-                    retry_count += 1
-
-        return _stream()
+            return await self._retriable_run(
+                route,
+                request,
+                task,
+                should_retry=should_retry,
+                wait_for_exception=wait_for_exception,
+            )
+
+        return self._retriable_stream(
+            route,
+            request,
+            task,
+            should_retry=should_retry,
+            wait_for_exception=wait_for_exception,
+        )
+
+    async def _retriable_run(
+        self,
+        route: str,
+        request: RunRequest,
+        task: Task[TaskInput, TaskOutput],
+        should_retry: Callable[[], bool],
+        wait_for_exception: Callable[[HTTPStatusError], Awaitable[None]],
+    ):
+        while should_retry():
+            try:
+                res = await self.api.post(route, request, returns=TaskRunResponse)
+                return res.to_domain(task)
+            except HTTPStatusError as e:  # noqa: PERF203
+                await wait_for_exception(e)
+
+        raise WorkflowAIError(error=BaseError(message="max retries reached"))
+
+    async def _retriable_stream(
+        self,
+        route: str,
+        request: RunRequest,
+        task: Task[TaskInput, TaskOutput],
+        should_retry: Callable[[], bool],
+        wait_for_exception: Callable[[HTTPStatusError], Awaitable[None]],
+    ):
+        while should_retry():
+            try:
+                async for chunk in self.api.stream(
+                    method="POST",
+                    path=route,
+                    data=request,
+                    returns=RunTaskStreamChunk,
+                ):
+                    yield task.output_class.model_construct(None, **chunk.task_output)
+                return
+            except HTTPStatusError as e:  # noqa: PERF203
+                await wait_for_exception(e)

     async def import_run(
         self,
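
Both helpers share the retry budget built in run(): _retriable_run falls through to a WorkflowAIError("max retries reached") if the loop exits without a response, while _retriable_stream restarts the stream after wait_for_exception sleeps, so chunks yielded before a failure may be yielded again. A rough consumption sketch; HelloTask/HelloTaskInput come from the test fixtures, and the stream=True keyword is assumed from the `if not stream` branch rather than shown in this hunk:

task = HelloTask(id="123", schema_id=1)

# Non-streaming: a single TaskRun, with retries handled inside the client.
run = await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5)

# Streaming: partial outputs as they arrive; on HTTPStatusError the stream is retried from the start.
async for output in client.run(task, task_input=HelloTaskInput(name="Alice"), stream=True):
    print(output)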

workflowai/core/client/client_test.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ async def test_success_with_headers(self, httpx_mock: HTTPXMock, client: Client)
     async def test_run_retries_on_too_many_requests(self, httpx_mock: HTTPXMock, client: Client):
         task = HelloTask(id="123", schema_id=1)

-        httpx_mock.add_response(status_code=429)
+        httpx_mock.add_response(headers={"Retry-After": "0.01"}, status_code=429)
         httpx_mock.add_response(json=fixtures_json("task_run.json"))

         task_run = await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5)

workflowai/core/client/utils.py

Lines changed: 54 additions & 0 deletions
@@ -1,7 +1,13 @@
 # Sometimes, 2 payloads are sent in a single message.
 # By adding the " at the end we more or less guarantee that
 # the delimiter is not withing a quoted string
+import asyncio
 import re
+from email.utils import parsedate_to_datetime
+from time import time
+from typing import Any, Optional
+
+from httpx import HTTPStatusError

 delimiter = re.compile(r'\}\n\ndata: \{"')

@@ -13,3 +19,51 @@ def split_chunks(chunk: bytes):
         yield chunk_str[start : match.start() + 1]
         start = match.end() - 2
     yield chunk_str[start:]
+
+
+def retry_after_to_delay_seconds(retry_after: Any) -> Optional[float]:
+    if retry_after is None:
+        return None
+
+    try:
+        return float(retry_after)
+    except ValueError:
+        pass
+    try:
+        retry_after_date = parsedate_to_datetime(retry_after)
+        current_time = time()
+        return retry_after_date.timestamp() - current_time
+    except (TypeError, ValueError, OverflowError):
+        return None
+
+
+# Returns two functions:
+# - _should_retry: returns True if we should retry
+# - _wait_for_exception: waits after an exception only if we should retry, otherwise raises
+# This is a bit convoluted and would be better in a function wrapper, but since we are dealing
+# with both Awaitable and AsyncGenerator, a wrapper would just be too complex
+def build_retryable_wait(
+    max_retry_delay: float = 60,
+    max_retry_count: float = 1,
+):
+    now = time()
+    retry_count = 0
+
+    def _leftover_delay():
+        # Time remaining before we hit the max retry delay
+        return max_retry_delay - (time() - now)
+
+    def _should_retry():
+        return retry_count < max_retry_count and _leftover_delay() >= 0
+
+    async def _wait_for_exception(e: HTTPStatusError):
+        nonlocal retry_count
+        retry_after = retry_after_to_delay_seconds(e.response.headers.get("Retry-After"))
+        leftover_delay = _leftover_delay()
+        if not retry_after or leftover_delay < 0 or retry_count >= max_retry_count:
+            # TODO: convert error to WorkflowAIError
+            raise e
+        await asyncio.sleep(retry_after)
+        retry_count += 1
+
+    return _should_retry, _wait_for_exception
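
build_retryable_wait packages a shared count-and-deadline budget into two closures that any request loop can reuse. A standalone sketch of the intended call pattern; the fetch coroutine and URL are invented for illustration, only build_retryable_wait and its behavior come from this diff:

import httpx

from workflowai.core.client.utils import build_retryable_wait


async def fetch_with_retries() -> httpx.Response:
    should_retry, wait_for_exception = build_retryable_wait(max_retry_delay=30, max_retry_count=3)
    async with httpx.AsyncClient() as http:
        while should_retry():
            try:
                res = await http.get("https://example.com/flaky")
                res.raise_for_status()  # raises httpx.HTTPStatusError on 4xx/5xx
                return res
            except httpx.HTTPStatusError as e:
                # Sleeps for the server-provided Retry-After; re-raises when the header
                # is missing or the retry budget (count or elapsed time) is exhausted.
                await wait_for_exception(e)
    raise RuntimeError("retry budget exhausted")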

workflowai/core/client/utils_test.py

Lines changed: 33 additions & 1 deletion
@@ -1,6 +1,11 @@
+from typing import Optional
+from unittest.mock import Mock
+
 import pytest
+from freezegun import freeze_time
+from httpx import HTTPStatusError

-from workflowai.core.client.utils import split_chunks
+from workflowai.core.client.utils import build_retryable_wait, retry_after_to_delay_seconds, split_chunks


 @pytest.mark.parametrize(
@@ -15,3 +20,30 @@
 )
 def test_split_chunks(chunk: bytes, expected: list[bytes]):
     assert list(split_chunks(chunk)) == expected
+
+
+@freeze_time("2024-01-01T00:00:00Z")
+@pytest.mark.parametrize(
+    ("retry_after", "expected"),
+    [
+        (None, None),
+        ("10", 10),
+        ("Wed, 01 Jan 2024 00:00:10 UTC", 10),
+    ],
+)
+def test_retry_after_to_delay_seconds(retry_after: Optional[str], expected: Optional[float]):
+    assert retry_after_to_delay_seconds(retry_after) == expected
+
+
+class TestBuildRetryableWait:
+    @pytest.fixture()
+    def request_error(self):
+        response = Mock()
+        response.headers = {"Retry-After": "0.01"}
+        return HTTPStatusError(message="", request=Mock(), response=response)
+
+    async def test_should_retry_count(self, request_error: HTTPStatusError):
+        should_retry, wait_for_exception = build_retryable_wait(60, 1)
+        assert should_retry()
+        await wait_for_exception(request_error)
+        assert not should_retry()
