Skip to content

Commit 49f1fbb

Browse files
authored
fix fractional value handling in retry delay policy (#18200)
1 parent 18563e5 commit 49f1fbb

File tree

4 files changed

+38
-10
lines changed

4 files changed

+38
-10
lines changed

src/prefect/_internal/schemas/validators.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,18 @@ def validate_compressionlib(value: str) -> str:
402402

403403

404404
# TODO: if we use this elsewhere we can change the error message to be more generic
405+
@overload
406+
def list_length_50_or_less(v: int) -> int: ...
407+
408+
409+
@overload
410+
def list_length_50_or_less(v: float) -> float: ...
411+
412+
413+
@overload
414+
def list_length_50_or_less(v: list[int]) -> list[int]: ...
415+
416+
405417
@overload
406418
def list_length_50_or_less(v: list[float]) -> list[float]: ...
407419

@@ -410,7 +422,9 @@ def list_length_50_or_less(v: list[float]) -> list[float]: ...
410422
def list_length_50_or_less(v: None) -> None: ...
411423

412424

413-
def list_length_50_or_less(v: Optional[list[float]]) -> Optional[list[float]]:
425+
def list_length_50_or_less(
426+
v: Optional[int | float | list[int] | list[float]],
427+
) -> Optional[int | float | list[int] | list[float]]:
414428
if isinstance(v, list) and (len(v) > 50):
415429
raise ValueError("Can not configure more than 50 retry delays per task.")
416430
return v

src/prefect/client/schemas/objects.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ class TaskRunPolicy(PrefectBaseModel):
698698
deprecated=True,
699699
)
700700
retries: Optional[int] = Field(default=None, description="The number of retries.")
701-
retry_delay: Union[None, int, list[int]] = Field(
701+
retry_delay: Union[None, int, float, list[int], list[float]] = Field(
702702
default=None,
703703
description="A delay time or list of delay times between retries, in seconds.",
704704
)
@@ -728,8 +728,8 @@ def populate_deprecated_fields(self):
728728
@field_validator("retry_delay")
729729
@classmethod
730730
def validate_configured_retry_delays(
731-
cls, v: Optional[list[float]]
732-
) -> Optional[list[float]]:
731+
cls, v: Optional[int | float | list[int] | list[float]]
732+
) -> Optional[int | float | list[int] | list[float]]:
733733
return list_length_50_or_less(v)
734734

735735
@field_validator("retry_jitter_factor")

src/prefect/server/schemas/core.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
from prefect._internal.schemas.validators import (
3636
get_or_create_run_name,
37+
list_length_50_or_less,
3738
set_run_policy_deprecated_fields,
3839
validate_cache_key_length,
3940
validate_default_queue_id_not_none,
@@ -356,7 +357,7 @@ class TaskRunPolicy(PrefectBaseModel):
356357
deprecated=True,
357358
)
358359
retries: Optional[int] = Field(default=None, description="The number of retries.")
359-
retry_delay: Union[None, int, List[int]] = Field(
360+
retry_delay: Union[None, int, float, List[int], List[float]] = Field(
360361
default=None,
361362
description="A delay time or list of delay times between retries, in seconds.",
362363
)
@@ -371,11 +372,9 @@ def populate_deprecated_fields(cls, values: dict[str, Any]) -> dict[str, Any]:
371372
@field_validator("retry_delay")
372373
@classmethod
373374
def validate_configured_retry_delays(
374-
cls, v: int | list[int] | None
375-
) -> int | list[int] | None:
376-
if isinstance(v, list) and (len(v) > 50):
377-
raise ValueError("Can not configure more than 50 retry delays per task.")
378-
return v
375+
cls, v: int | float | list[int] | list[float] | None
376+
) -> int | float | list[int] | list[float] | None:
377+
return list_length_50_or_less(v)
379378

380379
@field_validator("retry_jitter_factor")
381380
@classmethod

tests/test_tasks.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4329,6 +4329,21 @@ async def test_task_cannot_configure_negative_relative_jitter(self):
43294329
async def insanity():
43304330
raise RuntimeError("try again!")
43314331

4332+
def test_task_accepts_fractional_retry_delay_seconds(self):
4333+
@task(retries=2, retry_delay_seconds=1.5)
4334+
def task_with_float_delay():
4335+
return "success"
4336+
4337+
@task(retries=3, retry_delay_seconds=[0.5, 1.1, 2.7])
4338+
def task_with_float_list_delay():
4339+
return "success"
4340+
4341+
assert task_with_float_delay.retries == 2
4342+
assert task_with_float_delay.retry_delay_seconds == 1.5
4343+
4344+
assert task_with_float_list_delay.retries == 3
4345+
assert task_with_float_list_delay.retry_delay_seconds == [0.5, 1.1, 2.7]
4346+
43324347

43334348
async def test_task_run_name_is_set(prefect_client, events_pipeline):
43344349
@task(task_run_name="fixed-name")

0 commit comments

Comments
 (0)