|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | import asyncio
|
2 | 4 | import json
|
3 | 5 | import socket
|
|
8 | 10 | from functools import partial
|
9 | 11 | from types import TracebackType
|
10 | 12 | from typing import (
|
| 13 | + Annotated, |
11 | 14 | Any,
|
12 | 15 | AsyncGenerator,
|
13 | 16 | Awaitable,
|
|
20 | 23 | )
|
21 | 24 |
|
22 | 25 | import orjson
|
| 26 | +from pydantic import BeforeValidator, Field |
23 | 27 | from redis.asyncio import Redis
|
24 | 28 | from redis.exceptions import ResponseError
|
25 | 29 | from typing_extensions import Self
|
|
33 | 37 | StopConsumer,
|
34 | 38 | )
|
35 | 39 | from prefect.server.utilities.messaging import Publisher as _Publisher
|
| 40 | +from prefect.settings.base import PrefectBaseSettings, build_settings_config |
36 | 41 | from prefect_redis.client import get_async_redis_client
|
37 | 42 |
|
38 | 43 | logger = get_logger(__name__)
|
39 | 44 |
|
40 | 45 | M = TypeVar("M", bound=Message)
|
41 | 46 |
|
| 47 | + |
| 48 | +def _interpret_string_as_timedelta_seconds(value: timedelta | str) -> timedelta: |
| 49 | + """Interpret a string as a timedelta in seconds.""" |
| 50 | + if isinstance(value, str): |
| 51 | + return timedelta(seconds=int(value)) |
| 52 | + return value |
| 53 | + |
| 54 | + |
| 55 | +TimeDelta = Annotated[ |
| 56 | + Union[str, timedelta], |
| 57 | + BeforeValidator(_interpret_string_as_timedelta_seconds), |
| 58 | +] |
| 59 | + |
| 60 | + |
| 61 | +class RedisMessagingPublisherSettings(PrefectBaseSettings): |
| 62 | + """Settings for the Redis messaging publisher. |
| 63 | +
|
| 64 | + No settings are required to be set by the user but any of the settings can be |
| 65 | + overridden by the user using environment variables. |
| 66 | +
|
| 67 | + Example: |
| 68 | + ``` |
| 69 | + PREFECT_REDIS_MESSAGING_PUBLISHER_BATCH_SIZE=10 |
| 70 | + PREFECT_REDIS_MESSAGING_PUBLISHER_PUBLISH_EVERY=10 |
| 71 | + PREFECT_REDIS_MESSAGING_PUBLISHER_DEDUPLICATE_BY=message_id |
| 72 | + ``` |
| 73 | + """ |
| 74 | + |
| 75 | + model_config = build_settings_config( |
| 76 | + ( |
| 77 | + "redis", |
| 78 | + "messaging", |
| 79 | + "publisher", |
| 80 | + ), |
| 81 | + ) |
| 82 | + batch_size: int = Field(default=5) |
| 83 | + publish_every: TimeDelta = Field(default=timedelta(seconds=10)) |
| 84 | + deduplicate_by: Optional[str] = Field(default=None) |
| 85 | + |
| 86 | + |
| 87 | +class RedisMessagingConsumerSettings(PrefectBaseSettings): |
| 88 | + """Settings for the Redis messaging consumer. |
| 89 | +
|
| 90 | + No settings are required to be set by the user but any of the settings can be |
| 91 | + overridden by the user using environment variables. |
| 92 | +
|
| 93 | + Example: |
| 94 | + ``` |
| 95 | + PREFECT_REDIS_MESSAGING_CONSUMER_BLOCK=10 |
| 96 | + PREFECT_REDIS_MESSAGING_CONSUMER_MIN_IDLE_TIME=10 |
| 97 | + PREFECT_REDIS_MESSAGING_CONSUMER_MAX_RETRIES=3 |
| 98 | + PREFECT_REDIS_MESSAGING_CONSUMER_TRIM_EVERY=60 |
| 99 | + ``` |
| 100 | + """ |
| 101 | + |
| 102 | + model_config = build_settings_config( |
| 103 | + ( |
| 104 | + "redis", |
| 105 | + "messaging", |
| 106 | + "consumer", |
| 107 | + ), |
| 108 | + ) |
| 109 | + block: TimeDelta = Field(default=timedelta(seconds=1)) |
| 110 | + min_idle_time: TimeDelta = Field(default=timedelta(seconds=0)) |
| 111 | + max_retries: int = Field(default=3) |
| 112 | + trim_every: TimeDelta = Field(default=timedelta(seconds=60)) |
| 113 | + should_process_pending_messages: bool = Field(default=True) |
| 114 | + starting_message_id: str = Field(default="0") |
| 115 | + automatically_acknowledge: bool = Field(default=True) |
| 116 | + |
| 117 | + |
42 | 118 | MESSAGE_DEDUPLICATION_LOOKBACK = timedelta(minutes=5)
|
43 | 119 |
|
44 | 120 |
|
@@ -130,14 +206,20 @@ def __init__(
|
130 | 206 | topic: str,
|
131 | 207 | cache: _Cache,
|
132 | 208 | deduplicate_by: Optional[str] = None,
|
133 |
| - batch_size: int = 5, |
| 209 | + batch_size: Optional[int] = None, |
134 | 210 | publish_every: Optional[timedelta] = None,
|
135 | 211 | ):
|
| 212 | + settings = RedisMessagingPublisherSettings() |
| 213 | + |
136 | 214 | self.stream = topic # Use topic as stream name
|
137 | 215 | self.cache = cache
|
138 |
| - self.deduplicate_by = deduplicate_by |
139 |
| - self.batch_size = batch_size |
140 |
| - self.publish_every = publish_every |
| 216 | + self.deduplicate_by = ( |
| 217 | + deduplicate_by if deduplicate_by is not None else settings.deduplicate_by |
| 218 | + ) |
| 219 | + self.batch_size = batch_size if batch_size is not None else settings.batch_size |
| 220 | + self.publish_every = ( |
| 221 | + publish_every if publish_every is not None else settings.publish_every |
| 222 | + ) |
141 | 223 | self._periodic_task: Optional[asyncio.Task[None]] = None
|
142 | 224 |
|
143 | 225 | async def __aenter__(self) -> Self:
|
@@ -220,27 +302,45 @@ def __init__(
|
220 | 302 | topic: str,
|
221 | 303 | name: Optional[str] = None,
|
222 | 304 | group: Optional[str] = None,
|
223 |
| - block: timedelta = timedelta(seconds=1), |
224 |
| - min_idle_time: timedelta = timedelta(seconds=0), |
225 |
| - should_process_pending_messages: bool = True, |
226 |
| - starting_message_id: str = "0", |
227 |
| - automatically_acknowledge: bool = True, |
228 |
| - max_retries: int = 3, |
229 |
| - trim_every: timedelta = timedelta(seconds=60), |
| 305 | + block: Optional[timedelta] = None, |
| 306 | + min_idle_time: Optional[timedelta] = None, |
| 307 | + should_process_pending_messages: Optional[bool] = None, |
| 308 | + starting_message_id: Optional[str] = None, |
| 309 | + automatically_acknowledge: Optional[bool] = None, |
| 310 | + max_retries: Optional[int] = None, |
| 311 | + trim_every: Optional[timedelta] = None, |
230 | 312 | ):
|
| 313 | + settings = RedisMessagingConsumerSettings() |
| 314 | + |
231 | 315 | self.name = name or topic
|
232 | 316 | self.stream = topic # Use topic as stream name
|
233 | 317 | self.group = group or topic # Use topic as default group name
|
234 |
| - self.block = block |
235 |
| - self.min_idle_time = min_idle_time |
236 |
| - self.should_process_pending_messages = should_process_pending_messages |
237 |
| - self.starting_message_id = starting_message_id |
238 |
| - self.automatically_acknowledge = automatically_acknowledge |
| 318 | + self.block = block if block is not None else settings.block |
| 319 | + self.min_idle_time = ( |
| 320 | + min_idle_time if min_idle_time is not None else settings.min_idle_time |
| 321 | + ) |
| 322 | + self.should_process_pending_messages = ( |
| 323 | + should_process_pending_messages |
| 324 | + if should_process_pending_messages is not None |
| 325 | + else settings.should_process_pending_messages |
| 326 | + ) |
| 327 | + self.starting_message_id = ( |
| 328 | + starting_message_id |
| 329 | + if starting_message_id is not None |
| 330 | + else settings.starting_message_id |
| 331 | + ) |
| 332 | + self.automatically_acknowledge = ( |
| 333 | + automatically_acknowledge |
| 334 | + if automatically_acknowledge is not None |
| 335 | + else settings.automatically_acknowledge |
| 336 | + ) |
| 337 | + self.trim_every = trim_every if trim_every is not None else settings.trim_every |
239 | 338 |
|
240 |
| - self.subscription = Subscription(max_retries=max_retries) |
| 339 | + self.subscription = Subscription( |
| 340 | + max_retries=max_retries if max_retries is not None else settings.max_retries |
| 341 | + ) |
241 | 342 | self._retry_counts: dict[str, int] = {}
|
242 | 343 |
|
243 |
| - self.trim_every = trim_every |
244 | 344 | self._last_trimmed: Optional[float] = None
|
245 | 345 |
|
246 | 346 | async def _ensure_stream_and_group(self, redis_client: Redis) -> None:
|
|
0 commit comments