Skip to content

Commit 4653f32

Browse files
committed
Adding maintenance state to connections. Migrating and Migrated are not processed in in Moving state. Tests are updated
1 parent b0db983 commit 4653f32

File tree

4 files changed

+236
-50
lines changed

4 files changed

+236
-50
lines changed

redis/connection.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
MaintenanceEventConnectionHandler,
4242
MaintenanceEventPoolHandler,
4343
MaintenanceEventsConfig,
44+
MaintenanceState,
4445
)
4546
from .retry import Retry
4647
from .utils import (
@@ -284,6 +285,7 @@ def __init__(
284285
maintenance_events_config: Optional[MaintenanceEventsConfig] = None,
285286
tmp_host_address: Optional[str] = None,
286287
tmp_relax_timeout: Optional[float] = -1,
288+
maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
287289
):
288290
"""
289291
Initialize a new Connection.
@@ -373,6 +375,7 @@ def __init__(
373375
self._should_reconnect = False
374376
self.tmp_host_address = tmp_host_address
375377
self.tmp_relax_timeout = tmp_relax_timeout
378+
self.maintenance_state = maintenance_state
376379

377380
def __repr__(self):
378381
repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
@@ -829,6 +832,9 @@ def update_tmp_settings(
829832
if tmp_relax_timeout is not SENTINEL:
830833
self.tmp_relax_timeout = tmp_relax_timeout
831834

835+
def set_maintenance_state(self, state: "MaintenanceState"):
836+
self.maintenance_state = state
837+
832838

833839
class Connection(AbstractConnection):
834840
"Manages TCP communication to and from a Redis server"
@@ -1717,11 +1723,18 @@ def make_connection(self) -> "ConnectionInterface":
17171723
raise ConnectionError("Too many connections")
17181724
self._created_connections += 1
17191725

1726+
# Pass current maintenance_state to new connections
1727+
maintenance_state = self.connection_kwargs.get(
1728+
"maintenance_state", MaintenanceState.NONE
1729+
)
1730+
kwargs = dict(self.connection_kwargs)
1731+
kwargs["maintenance_state"] = maintenance_state
1732+
17201733
if self.cache is not None:
17211734
return CacheProxyConnection(
1722-
self.connection_class(**self.connection_kwargs), self.cache, self._lock
1735+
self.connection_class(**kwargs), self.cache, self._lock
17231736
)
1724-
return self.connection_class(**self.connection_kwargs)
1737+
return self.connection_class(**kwargs)
17251738

17261739
def release(self, connection: "Connection") -> None:
17271740
"Releases the connection back to the pool"
@@ -1946,6 +1959,16 @@ async def _mock(self, error: RedisError):
19461959
"""
19471960
pass
19481961

1962+
def set_maintenance_state_for_all(self, state: "MaintenanceState"):
1963+
with self._lock:
1964+
for conn in self._available_connections:
1965+
conn.set_maintenance_state(state)
1966+
for conn in self._in_use_connections:
1967+
conn.set_maintenance_state(state)
1968+
1969+
def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"):
1970+
self.connection_kwargs["maintenance_state"] = state
1971+
19491972

19501973
class BlockingConnectionPool(ConnectionPool):
19511974
"""
@@ -2040,15 +2063,20 @@ def make_connection(self):
20402063
if self._in_maintenance:
20412064
self._lock.acquire()
20422065
self._locked = True
2066+
# Pass current maintenance_state to new connections
2067+
maintenance_state = self.connection_kwargs.get(
2068+
"maintenance_state", MaintenanceState.NONE
2069+
)
2070+
kwargs = dict(self.connection_kwargs)
2071+
kwargs["maintenance_state"] = maintenance_state
20432072
if self.cache is not None:
20442073
connection = CacheProxyConnection(
2045-
self.connection_class(**self.connection_kwargs),
2074+
self.connection_class(**kwargs),
20462075
self.cache,
20472076
self._lock,
20482077
)
20492078
else:
2050-
connection = self.connection_class(**self.connection_kwargs)
2051-
2079+
connection = self.connection_class(**kwargs)
20522080
self._connections.append(connection)
20532081
return connection
20542082
finally:
@@ -2259,3 +2287,12 @@ def _update_maintenance_events_configs_for_connections(
22592287
def set_in_maintenance(self, in_maintenance: bool):
22602288
"""Set the maintenance mode for the connection pool."""
22612289
self._in_maintenance = in_maintenance
2290+
2291+
def set_maintenance_state_for_all(self, state: "MaintenanceState"):
2292+
with self._lock:
2293+
for conn in getattr(self, "_connections", []):
2294+
if conn:
2295+
conn.set_maintenance_state(state)
2296+
2297+
def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"):
2298+
self.connection_kwargs["maintenance_state"] = state

redis/maintenance_events.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import logging
23
import threading
34
import time
@@ -6,6 +7,13 @@
67

78
from redis.typing import Number
89

10+
11+
class MaintenanceState(enum.Enum):
12+
NONE = "none"
13+
MOVING = "moving"
14+
MIGRATING = "migrating"
15+
16+
917
if TYPE_CHECKING:
1018
from redis.connection import (
1119
BlockingConnectionPool,
@@ -351,6 +359,9 @@ def handle_node_moving_event(self, event: NodeMovingEvent):
351359
):
352360
if getattr(self.pool, "set_in_maintenance", False):
353361
self.pool.set_in_maintenance(True)
362+
# Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance)
363+
self.pool.set_maintenance_state_for_all(MaintenanceState.MOVING)
364+
self.pool.set_maintenance_state_in_kwargs(MaintenanceState.MOVING)
354365
# edit the config for new connections until the notification expires
355366
self.pool.update_connection_kwargs_with_tmp_settings(
356367
tmp_host_address=event.new_node_host,
@@ -368,7 +379,6 @@ def handle_node_moving_event(self, event: NodeMovingEvent):
368379
tmp_host_address=event.new_node_host,
369380
tmp_relax_timeout=self.config.relax_timeout,
370381
)
371-
372382
# take care for the inactive connections in the pool
373383
# delete them and create new ones
374384
self.pool.disconnect_and_reconfigure_free_connections(
@@ -388,16 +398,19 @@ def handle_node_moved_event(self):
388398
tmp_host_address=None,
389399
tmp_relax_timeout=-1,
390400
)
401+
# Clear state to NONE in kwargs immediately after updating tmp kwargs
402+
self.pool.set_maintenance_state_in_kwargs(MaintenanceState.NONE)
391403
with self.pool._lock:
392404
if self.config.is_relax_timeouts_enabled():
393405
# reset the timeout for existing connections
394406
self.pool.update_connections_current_timeout(
395407
relax_timeout=-1, include_free_connections=True
396408
)
397-
398409
self.pool.update_connections_tmp_settings(
399410
tmp_host_address=None, tmp_relax_timeout=-1
400411
)
412+
# Clear state to NONE for all connections
413+
self.pool.set_maintenance_state_for_all(MaintenanceState.NONE)
401414

402415

403416
class MaintenanceEventConnectionHandler:
@@ -416,17 +429,24 @@ def handle_event(self, event: MaintenanceEvent):
416429
logging.error(f"Unhandled event type: {event}")
417430

418431
def handle_migrating_event(self, notification: NodeMigratingEvent):
419-
if not self.config.is_relax_timeouts_enabled():
432+
if (
433+
self.connection.maintenance_state == MaintenanceState.MOVING
434+
or not self.config.is_relax_timeouts_enabled()
435+
):
420436
return
421-
437+
self.connection.set_maintenance_state(MaintenanceState.MIGRATING)
422438
# extend the timeout for all created connections
423439
self.connection.update_current_socket_timeout(self.config.relax_timeout)
424440
self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout)
425441

426442
def handle_migration_completed_event(self, notification: "NodeMigratedEvent"):
427-
if not self.config.is_relax_timeouts_enabled():
443+
# Only reset timeouts if state is not MOVING and relax timeouts are enabled
444+
if (
445+
self.connection.maintenance_state == MaintenanceState.MOVING
446+
or not self.config.is_relax_timeouts_enabled()
447+
):
428448
return
429-
449+
self.connection.set_maintenance_state(MaintenanceState.NONE)
430450
# Node migration completed - reset the connection
431451
# timeouts by providing -1 as the relax timeout
432452
self.connection.update_current_socket_timeout(-1)

tests/test_connection_pool.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import redis
1010
from redis.cache import CacheConfig
1111
from redis.connection import CacheProxyConnection, Connection, to_bool
12+
from redis.maintenance_events import MaintenanceState
1213
from redis.utils import SSL_AVAILABLE
1314

1415
from .conftest import (
@@ -53,10 +54,15 @@ def get_pool(
5354
return pool
5455

5556
def test_connection_creation(self):
56-
connection_kwargs = {"foo": "bar", "biz": "baz"}
57+
connection_kwargs = {
58+
"foo": "bar",
59+
"biz": "baz",
60+
"maintenance_state": MaintenanceState.NONE,
61+
}
5762
pool = self.get_pool(
5863
connection_kwargs=connection_kwargs, connection_class=DummyConnection
5964
)
65+
6066
connection = pool.get_connection()
6167
assert isinstance(connection, DummyConnection)
6268
assert connection.kwargs == connection_kwargs
@@ -150,7 +156,9 @@ def test_connection_creation(self, master_host):
150156
"host": master_host[0],
151157
"port": master_host[1],
152158
}
159+
153160
pool = self.get_pool(connection_kwargs=connection_kwargs)
161+
connection_kwargs["maintenance_state"] = MaintenanceState.NONE
154162
connection = pool.get_connection()
155163
assert isinstance(connection, DummyConnection)
156164
assert connection.kwargs == connection_kwargs

0 commit comments

Comments
 (0)