|
1 | 1 | import asyncio
|
2 | 2 | import importlib.metadata
|
3 | 3 | import os
|
| 4 | +from email.utils import parsedate_to_datetime |
4 | 5 | from typing import (
|
5 | 6 | Any,
|
6 | 7 | AsyncIterator,
|
@@ -138,24 +139,62 @@ async def run(
|
138 | 139 | if not stream:
|
139 | 140 | res = None
|
140 | 141 | delay = retry_delay / 1000
|
141 |
| - for _ in range(max_retry_count): |
| 142 | + retry_count = 0 |
| 143 | + while retry_count <= max_retry_count: |
142 | 144 | try:
|
143 | 145 | res = await self.api.post(route, request, returns=TaskRunResponse)
|
144 | 146 | return res.to_domain(task)
|
145 | 147 | except HTTPStatusError as e:
|
146 | 148 | if e.response.status_code == 404:
|
147 | 149 | raise NotFoundError("Task not found")
|
148 |
| - if e.response.status_code == 429: |
| 150 | + retry_after = e.response.headers.get("Retry-After") |
| 151 | + if retry_after: |
| 152 | + try: |
| 153 | + #for 429 errors this is non-negative decimal |
| 154 | + delay = float(retry_after) |
| 155 | + except ValueError: |
| 156 | + try: |
| 157 | + retry_after_date = parsedate_to_datetime(retry_after) |
| 158 | + current_time = asyncio.get_event_loop().time() |
| 159 | + delay = (retry_after_date.timestamp()- current_time) |
| 160 | + except (TypeError, ValueError, OverflowError): |
| 161 | + delay = min(delay * 2, max_retry_delay / 1000) |
| 162 | + await asyncio.sleep(delay) |
| 163 | + elif e.response.status_code == 429: |
149 | 164 | if delay < max_retry_delay / 1000:
|
150 | 165 | delay = min(delay * 2, max_retry_delay / 1000)
|
151 | 166 | await asyncio.sleep(delay)
|
152 |
| - |
| 167 | + retry_count += 1 |
| 168 | + |
153 | 169 | async def _stream():
|
154 |
| - async for chunk in self.api.stream( |
155 |
| - method="POST", path=route, data=request, returns=RunTaskStreamChunk |
156 |
| - ): |
157 |
| - yield task.output_class.model_construct(None, **chunk.task_output) |
158 |
| - |
| 170 | + delay = retry_delay / 1000 |
| 171 | + retry_count = 0 |
| 172 | + while retry_count <= max_retry_count: |
| 173 | + try: |
| 174 | + async for chunk in self.api.stream( |
| 175 | + method="POST", path=route, data=request, returns=RunTaskStreamChunk |
| 176 | + ): |
| 177 | + yield task.output_class.model_construct(None, **chunk.task_output) |
| 178 | + except HTTPStatusError as e: |
| 179 | + if e.response.status_code == 404: |
| 180 | + raise NotFoundError("Task not found") |
| 181 | + retry_after = e.response.headers.get("Retry-After") |
| 182 | + |
| 183 | + if retry_after: |
| 184 | + try: |
| 185 | + delay = float(retry_after) |
| 186 | + except ValueError: |
| 187 | + try: |
| 188 | + retry_after_date = parsedate_to_datetime(retry_after) |
| 189 | + current_time = asyncio.get_event_loop().time() |
| 190 | + delay = (retry_after_date.timestamp() - current_time) |
| 191 | + except (TypeError, ValueError, OverflowError): |
| 192 | + delay = min(delay * 2, max_retry_delay / 1000) |
| 193 | + elif e.response.status_code == 429: |
| 194 | + if delay < max_retry_delay / 1000: |
| 195 | + delay = min(delay * 2, max_retry_delay / 1000) |
| 196 | + await asyncio.sleep(delay) |
| 197 | + retry_count += 1 |
159 | 198 | return _stream()
|
160 | 199 |
|
161 | 200 | async def import_run(
|
|
0 commit comments