1
- import asyncio
2
1
import importlib .metadata
3
2
import os
4
- from email . utils import parsedate_to_datetime
3
+ from collections . abc import Awaitable , Callable
5
4
from typing import (
6
5
Any ,
7
6
AsyncIterator ,
25
24
RunTaskStreamChunk ,
26
25
TaskRunResponse ,
27
26
)
27
+ from workflowai .core .client .utils import build_retryable_wait
28
28
from workflowai .core .domain .cache_usage import CacheUsage
29
29
from workflowai .core .domain .errors import BaseError , WorkflowAIError
30
30
from workflowai .core .domain .task import Task , TaskInput , TaskOutput
@@ -77,9 +77,8 @@ async def run(
77
77
use_cache : CacheUsage = "when_available" ,
78
78
labels : Optional [set [str ]] = None ,
79
79
metadata : Optional [dict [str , Any ]] = None ,
80
- retry_delay : int = 5000 ,
81
- max_retry_delay : int = 60000 ,
82
- max_retry_count : int = 1 ,
80
+ max_retry_delay : float = 60 ,
81
+ max_retry_count : float = 1 ,
83
82
) -> TaskRun [TaskInput , TaskOutput ]: ...
84
83
85
84
@overload
@@ -94,12 +93,11 @@ async def run(
94
93
use_cache : CacheUsage = "when_available" ,
95
94
labels : Optional [set [str ]] = None ,
96
95
metadata : Optional [dict [str , Any ]] = None ,
97
- retry_delay : int = 5000 ,
98
- max_retry_delay : int = 60000 ,
99
- max_retry_count : int = 1 ,
96
+ max_retry_delay : float = 60 ,
97
+ max_retry_count : float = 1 ,
100
98
) -> AsyncIterator [TaskOutput ]: ...
101
99
102
- async def run ( # noqa: C901
100
+ async def run (
103
101
self ,
104
102
task : Task [TaskInput , TaskOutput ],
105
103
task_input : TaskInput ,
@@ -110,9 +108,8 @@ async def run( # noqa: C901
110
108
use_cache : CacheUsage = "when_available" ,
111
109
labels : Optional [set [str ]] = None ,
112
110
metadata : Optional [dict [str , Any ]] = None ,
113
- retry_delay : int = 5000 ,
114
- max_retry_delay : int = 60000 ,
115
- max_retry_count : int = 1 ,
111
+ max_retry_delay : float = 60 ,
112
+ max_retry_count : float = 1 ,
116
113
) -> Union [TaskRun [TaskInput , TaskOutput ], AsyncIterator [TaskOutput ]]:
117
114
await self ._auto_register (task )
118
115
@@ -135,76 +132,62 @@ async def run( # noqa: C901
135
132
)
136
133
137
134
route = f"/tasks/{ task .id } /schemas/{ task .schema_id } /run"
135
+ should_retry , wait_for_exception = build_retryable_wait (max_retry_delay , max_retry_count )
138
136
139
137
if not stream :
140
- res = None
141
- delay = retry_delay / 1000
142
- retry_count = 0
143
- while retry_count < max_retry_count :
144
- try :
145
- res = await self .api .post (route , request , returns = TaskRunResponse )
146
- return res .to_domain (task )
147
- except HTTPStatusError as e :
148
- if e .response .status_code == 404 :
149
- raise WorkflowAIError (
150
- error = BaseError (
151
- status_code = 404 ,
152
- code = "not_found" ,
153
- message = "Task not found" ,
154
- ),
155
- ) from e
156
- retry_after = e .response .headers .get ("Retry-After" )
157
- if retry_after :
158
- try :
159
- # for 429 errors this is non-negative decimal
160
- delay = float (retry_after )
161
- except ValueError :
162
- try :
163
- retry_after_date = parsedate_to_datetime (retry_after )
164
- current_time = asyncio .get_event_loop ().time ()
165
- delay = retry_after_date .timestamp () - current_time
166
- except (TypeError , ValueError , OverflowError ):
167
- delay = min (delay * 2 , max_retry_delay / 1000 )
168
- await asyncio .sleep (delay )
169
- elif e .response .status_code == 429 :
170
- if delay < max_retry_delay / 1000 :
171
- delay = min (delay * 2 , max_retry_delay / 1000 )
172
- await asyncio .sleep (delay )
173
- retry_count += 1
174
-
175
- async def _stream ():
176
- delay = retry_delay / 1000
177
- retry_count = 0
178
- while retry_count < max_retry_count :
179
- try :
180
- async for chunk in self .api .stream (
181
- method = "POST" ,
182
- path = route ,
183
- data = request ,
184
- returns = RunTaskStreamChunk ,
185
- ):
186
- yield task .output_class .model_construct (None , ** chunk .task_output )
187
- except HTTPStatusError as e :
188
- if e .response .status_code == 404 :
189
- raise WorkflowAIError (error = BaseError (message = "Task not found" )) from e
190
- retry_after = e .response .headers .get ("Retry-After" )
191
-
192
- if retry_after :
193
- try :
194
- delay = float (retry_after )
195
- except ValueError :
196
- try :
197
- retry_after_date = parsedate_to_datetime (retry_after )
198
- current_time = asyncio .get_event_loop ().time ()
199
- delay = retry_after_date .timestamp () - current_time
200
- except (TypeError , ValueError , OverflowError ):
201
- delay = min (delay * 2 , max_retry_delay / 1000 )
202
- elif e .response .status_code == 429 and delay < max_retry_delay / 1000 :
203
- delay = min (delay * 2 , max_retry_delay / 1000 )
204
- await asyncio .sleep (delay )
205
- retry_count += 1
206
-
207
- return _stream ()
138
+ return await self ._retriable_run (
139
+ route ,
140
+ request ,
141
+ task ,
142
+ should_retry = should_retry ,
143
+ wait_for_exception = wait_for_exception ,
144
+ )
145
+
146
+ return self ._retriable_stream (
147
+ route ,
148
+ request ,
149
+ task ,
150
+ should_retry = should_retry ,
151
+ wait_for_exception = wait_for_exception ,
152
+ )
153
+
154
+ async def _retriable_run (
155
+ self ,
156
+ route : str ,
157
+ request : RunRequest ,
158
+ task : Task [TaskInput , TaskOutput ],
159
+ should_retry : Callable [[], bool ],
160
+ wait_for_exception : Callable [[HTTPStatusError ], Awaitable [None ]],
161
+ ):
162
+ while should_retry ():
163
+ try :
164
+ res = await self .api .post (route , request , returns = TaskRunResponse )
165
+ return res .to_domain (task )
166
+ except HTTPStatusError as e : # noqa: PERF203
167
+ await wait_for_exception (e )
168
+
169
+ raise WorkflowAIError (error = BaseError (message = "max retries reached" ))
170
+
171
+ async def _retriable_stream (
172
+ self ,
173
+ route : str ,
174
+ request : RunRequest ,
175
+ task : Task [TaskInput , TaskOutput ],
176
+ should_retry : Callable [[], bool ],
177
+ wait_for_exception : Callable [[HTTPStatusError ], Awaitable [None ]],
178
+ ):
179
+ while should_retry ():
180
+ try :
181
+ async for chunk in self .api .stream (
182
+ method = "POST" ,
183
+ path = route ,
184
+ data = request ,
185
+ returns = RunTaskStreamChunk ,
186
+ ):
187
+ yield task .output_class .model_construct (None , ** chunk .task_output )
188
+ return
189
+ except HTTPStatusError as e : # noqa: PERF203
190
+ await wait_for_exception (e )
208
191
209
192
async def import_run (
210
193
self ,
0 commit comments