Skip to content

Commit aba5247

Browse files
authored
add redis publisher/consumer settings (#18160)
1 parent 857102f commit aba5247

File tree

2 files changed

+156
-18
lines changed

2 files changed

+156
-18
lines changed

src/integrations/prefect-redis/prefect_redis/messaging.py

Lines changed: 118 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import json
35
import socket
@@ -8,6 +10,7 @@
810
from functools import partial
911
from types import TracebackType
1012
from typing import (
13+
Annotated,
1114
Any,
1215
AsyncGenerator,
1316
Awaitable,
@@ -20,6 +23,7 @@
2023
)
2124

2225
import orjson
26+
from pydantic import BeforeValidator, Field
2327
from redis.asyncio import Redis
2428
from redis.exceptions import ResponseError
2529
from typing_extensions import Self
@@ -33,12 +37,84 @@
3337
StopConsumer,
3438
)
3539
from prefect.server.utilities.messaging import Publisher as _Publisher
40+
from prefect.settings.base import PrefectBaseSettings, build_settings_config
3641
from prefect_redis.client import get_async_redis_client
3742

3843
logger = get_logger(__name__)
3944

4045
M = TypeVar("M", bound=Message)
4146

47+
48+
def _interpret_string_as_timedelta_seconds(value: timedelta | str) -> timedelta:
49+
"""Interpret a string as a timedelta in seconds."""
50+
if isinstance(value, str):
51+
return timedelta(seconds=int(value))
52+
return value
53+
54+
55+
TimeDelta = Annotated[
56+
Union[str, timedelta],
57+
BeforeValidator(_interpret_string_as_timedelta_seconds),
58+
]
59+
60+
61+
class RedisMessagingPublisherSettings(PrefectBaseSettings):
62+
"""Settings for the Redis messaging publisher.
63+
64+
No settings are required to be set by the user but any of the settings can be
65+
overridden by the user using environment variables.
66+
67+
Example:
68+
```
69+
PREFECT_REDIS_MESSAGING_PUBLISHER_BATCH_SIZE=10
70+
PREFECT_REDIS_MESSAGING_PUBLISHER_PUBLISH_EVERY=10
71+
PREFECT_REDIS_MESSAGING_PUBLISHER_DEDUPLICATE_BY=message_id
72+
```
73+
"""
74+
75+
model_config = build_settings_config(
76+
(
77+
"redis",
78+
"messaging",
79+
"publisher",
80+
),
81+
)
82+
batch_size: int = Field(default=5)
83+
publish_every: TimeDelta = Field(default=timedelta(seconds=10))
84+
deduplicate_by: Optional[str] = Field(default=None)
85+
86+
87+
class RedisMessagingConsumerSettings(PrefectBaseSettings):
88+
"""Settings for the Redis messaging consumer.
89+
90+
No settings are required to be set by the user but any of the settings can be
91+
overridden by the user using environment variables.
92+
93+
Example:
94+
```
95+
PREFECT_REDIS_MESSAGING_CONSUMER_BLOCK=10
96+
PREFECT_REDIS_MESSAGING_CONSUMER_MIN_IDLE_TIME=10
97+
PREFECT_REDIS_MESSAGING_CONSUMER_MAX_RETRIES=3
98+
PREFECT_REDIS_MESSAGING_CONSUMER_TRIM_EVERY=60
99+
```
100+
"""
101+
102+
model_config = build_settings_config(
103+
(
104+
"redis",
105+
"messaging",
106+
"consumer",
107+
),
108+
)
109+
block: TimeDelta = Field(default=timedelta(seconds=1))
110+
min_idle_time: TimeDelta = Field(default=timedelta(seconds=0))
111+
max_retries: int = Field(default=3)
112+
trim_every: TimeDelta = Field(default=timedelta(seconds=60))
113+
should_process_pending_messages: bool = Field(default=True)
114+
starting_message_id: str = Field(default="0")
115+
automatically_acknowledge: bool = Field(default=True)
116+
117+
42118
MESSAGE_DEDUPLICATION_LOOKBACK = timedelta(minutes=5)
43119

44120

@@ -130,14 +206,20 @@ def __init__(
130206
topic: str,
131207
cache: _Cache,
132208
deduplicate_by: Optional[str] = None,
133-
batch_size: int = 5,
209+
batch_size: Optional[int] = None,
134210
publish_every: Optional[timedelta] = None,
135211
):
212+
settings = RedisMessagingPublisherSettings()
213+
136214
self.stream = topic # Use topic as stream name
137215
self.cache = cache
138-
self.deduplicate_by = deduplicate_by
139-
self.batch_size = batch_size
140-
self.publish_every = publish_every
216+
self.deduplicate_by = (
217+
deduplicate_by if deduplicate_by is not None else settings.deduplicate_by
218+
)
219+
self.batch_size = batch_size if batch_size is not None else settings.batch_size
220+
self.publish_every = (
221+
publish_every if publish_every is not None else settings.publish_every
222+
)
141223
self._periodic_task: Optional[asyncio.Task[None]] = None
142224

143225
async def __aenter__(self) -> Self:
@@ -220,27 +302,45 @@ def __init__(
220302
topic: str,
221303
name: Optional[str] = None,
222304
group: Optional[str] = None,
223-
block: timedelta = timedelta(seconds=1),
224-
min_idle_time: timedelta = timedelta(seconds=0),
225-
should_process_pending_messages: bool = True,
226-
starting_message_id: str = "0",
227-
automatically_acknowledge: bool = True,
228-
max_retries: int = 3,
229-
trim_every: timedelta = timedelta(seconds=60),
305+
block: Optional[timedelta] = None,
306+
min_idle_time: Optional[timedelta] = None,
307+
should_process_pending_messages: Optional[bool] = None,
308+
starting_message_id: Optional[str] = None,
309+
automatically_acknowledge: Optional[bool] = None,
310+
max_retries: Optional[int] = None,
311+
trim_every: Optional[timedelta] = None,
230312
):
313+
settings = RedisMessagingConsumerSettings()
314+
231315
self.name = name or topic
232316
self.stream = topic # Use topic as stream name
233317
self.group = group or topic # Use topic as default group name
234-
self.block = block
235-
self.min_idle_time = min_idle_time
236-
self.should_process_pending_messages = should_process_pending_messages
237-
self.starting_message_id = starting_message_id
238-
self.automatically_acknowledge = automatically_acknowledge
318+
self.block = block if block is not None else settings.block
319+
self.min_idle_time = (
320+
min_idle_time if min_idle_time is not None else settings.min_idle_time
321+
)
322+
self.should_process_pending_messages = (
323+
should_process_pending_messages
324+
if should_process_pending_messages is not None
325+
else settings.should_process_pending_messages
326+
)
327+
self.starting_message_id = (
328+
starting_message_id
329+
if starting_message_id is not None
330+
else settings.starting_message_id
331+
)
332+
self.automatically_acknowledge = (
333+
automatically_acknowledge
334+
if automatically_acknowledge is not None
335+
else settings.automatically_acknowledge
336+
)
337+
self.trim_every = trim_every if trim_every is not None else settings.trim_every
239338

240-
self.subscription = Subscription(max_retries=max_retries)
339+
self.subscription = Subscription(
340+
max_retries=max_retries if max_retries is not None else settings.max_retries
341+
)
241342
self._retry_counts: dict[str, int] = {}
242343

243-
self.trim_every = trim_every
244344
self._last_trimmed: Optional[float] = None
245345

246346
async def _ensure_stream_and_group(self, redis_client: Redis) -> None:

src/integrations/prefect-redis/tests/test_messaging.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
Consumer,
1313
Message,
1414
Publisher,
15+
RedisMessagingConsumerSettings,
16+
RedisMessagingPublisherSettings,
1517
StopConsumer,
1618
)
1719
from redis.asyncio import Redis
@@ -442,3 +444,39 @@ async def handler(message: Message):
442444
await consumer_task
443445

444446
assert captured_events == [emitted_event]
447+
448+
449+
class TestRedisMessagingSettings:
450+
"""Test the Redis messaging settings."""
451+
452+
def test_publisher_settings(self):
453+
"""Test Redis publisher settings."""
454+
settings = RedisMessagingPublisherSettings()
455+
assert settings.batch_size == 5
456+
assert settings.publish_every == timedelta(seconds=10)
457+
assert settings.deduplicate_by is None
458+
459+
def test_consumer_settings(self):
460+
"""Test Redis consumer settings."""
461+
settings = RedisMessagingConsumerSettings()
462+
assert settings.block == timedelta(seconds=1)
463+
assert settings.min_idle_time == timedelta(seconds=0)
464+
assert settings.max_retries == 3
465+
assert settings.trim_every == timedelta(seconds=60)
466+
assert settings.should_process_pending_messages is True
467+
assert settings.starting_message_id == "0"
468+
assert settings.automatically_acknowledge is True
469+
470+
def test_publisher_settings_can_be_overridden(
471+
self, monkeypatch: pytest.MonkeyPatch
472+
):
473+
"""Test that Redis publisher settings can be overridden."""
474+
monkeypatch.setenv("PREFECT_REDIS_MESSAGING_PUBLISHER_BATCH_SIZE", "10")
475+
settings = RedisMessagingPublisherSettings()
476+
assert settings.batch_size == 10
477+
478+
def test_consumer_settings_can_be_overridden(self, monkeypatch: pytest.MonkeyPatch):
479+
"""Test that Redis consumer settings can be overridden."""
480+
monkeypatch.setenv("PREFECT_REDIS_MESSAGING_CONSUMER_BLOCK", "10")
481+
settings = RedisMessagingConsumerSettings()
482+
assert settings.block == timedelta(seconds=10)

0 commit comments

Comments
 (0)