Skip to content

Commit b8c6349

Browse files
committed
Adding maintenance state to connections. Migrating and Migrated are not processed in in Moving state. Tests are updated
1 parent b0db983 commit b8c6349

File tree

4 files changed

+237
-50
lines changed

4 files changed

+237
-50
lines changed

redis/connection.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from queue import Empty, Full, LifoQueue
1212
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
1313
from urllib.parse import parse_qs, unquote, urlparse
14+
import enum
1415

1516
from redis.cache import (
1617
CacheEntry,
@@ -41,6 +42,7 @@
4142
MaintenanceEventConnectionHandler,
4243
MaintenanceEventPoolHandler,
4344
MaintenanceEventsConfig,
45+
MaintenanceState,
4446
)
4547
from .retry import Retry
4648
from .utils import (
@@ -284,6 +286,7 @@ def __init__(
284286
maintenance_events_config: Optional[MaintenanceEventsConfig] = None,
285287
tmp_host_address: Optional[str] = None,
286288
tmp_relax_timeout: Optional[float] = -1,
289+
maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
287290
):
288291
"""
289292
Initialize a new Connection.
@@ -373,6 +376,7 @@ def __init__(
373376
self._should_reconnect = False
374377
self.tmp_host_address = tmp_host_address
375378
self.tmp_relax_timeout = tmp_relax_timeout
379+
self.maintenance_state = maintenance_state
376380

377381
def __repr__(self):
378382
repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
@@ -829,6 +833,9 @@ def update_tmp_settings(
829833
if tmp_relax_timeout is not SENTINEL:
830834
self.tmp_relax_timeout = tmp_relax_timeout
831835

836+
def set_maintenance_state(self, state: "MaintenanceState"):
837+
self.maintenance_state = state
838+
832839

833840
class Connection(AbstractConnection):
834841
"Manages TCP communication to and from a Redis server"
@@ -1717,11 +1724,18 @@ def make_connection(self) -> "ConnectionInterface":
17171724
raise ConnectionError("Too many connections")
17181725
self._created_connections += 1
17191726

1727+
# Pass current maintenance_state to new connections
1728+
maintenance_state = self.connection_kwargs.get(
1729+
"maintenance_state", MaintenanceState.NONE
1730+
)
1731+
kwargs = dict(self.connection_kwargs)
1732+
kwargs["maintenance_state"] = maintenance_state
1733+
17201734
if self.cache is not None:
17211735
return CacheProxyConnection(
1722-
self.connection_class(**self.connection_kwargs), self.cache, self._lock
1736+
self.connection_class(**kwargs), self.cache, self._lock
17231737
)
1724-
return self.connection_class(**self.connection_kwargs)
1738+
return self.connection_class(**kwargs)
17251739

17261740
def release(self, connection: "Connection") -> None:
17271741
"Releases the connection back to the pool"
@@ -1946,6 +1960,16 @@ async def _mock(self, error: RedisError):
19461960
"""
19471961
pass
19481962

1963+
def set_maintenance_state_for_all(self, state: "MaintenanceState"):
1964+
with self._lock:
1965+
for conn in self._available_connections:
1966+
conn.set_maintenance_state(state)
1967+
for conn in self._in_use_connections:
1968+
conn.set_maintenance_state(state)
1969+
1970+
def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"):
1971+
self.connection_kwargs["maintenance_state"] = state
1972+
19491973

19501974
class BlockingConnectionPool(ConnectionPool):
19511975
"""
@@ -2040,15 +2064,20 @@ def make_connection(self):
20402064
if self._in_maintenance:
20412065
self._lock.acquire()
20422066
self._locked = True
2067+
# Pass current maintenance_state to new connections
2068+
maintenance_state = self.connection_kwargs.get(
2069+
"maintenance_state", MaintenanceState.NONE
2070+
)
2071+
kwargs = dict(self.connection_kwargs)
2072+
kwargs["maintenance_state"] = maintenance_state
20432073
if self.cache is not None:
20442074
connection = CacheProxyConnection(
2045-
self.connection_class(**self.connection_kwargs),
2075+
self.connection_class(**kwargs),
20462076
self.cache,
20472077
self._lock,
20482078
)
20492079
else:
2050-
connection = self.connection_class(**self.connection_kwargs)
2051-
2080+
connection = self.connection_class(**kwargs)
20522081
self._connections.append(connection)
20532082
return connection
20542083
finally:
@@ -2259,3 +2288,12 @@ def _update_maintenance_events_configs_for_connections(
22592288
def set_in_maintenance(self, in_maintenance: bool):
22602289
"""Set the maintenance mode for the connection pool."""
22612290
self._in_maintenance = in_maintenance
2291+
2292+
def set_maintenance_state_for_all(self, state: "MaintenanceState"):
2293+
with self._lock:
2294+
for conn in getattr(self, "_connections", []):
2295+
if conn:
2296+
conn.set_maintenance_state(state)
2297+
2298+
def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"):
2299+
self.connection_kwargs["maintenance_state"] = state

redis/maintenance_events.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
import time
44
from abc import ABC, abstractmethod
55
from typing import TYPE_CHECKING, Optional, Union
6+
import enum
7+
8+
9+
class MaintenanceState(enum.Enum):
10+
NONE = "none"
11+
MOVING = "moving"
12+
MIGRATING = "migrating"
13+
614

715
from redis.typing import Number
816

@@ -351,6 +359,9 @@ def handle_node_moving_event(self, event: NodeMovingEvent):
351359
):
352360
if getattr(self.pool, "set_in_maintenance", False):
353361
self.pool.set_in_maintenance(True)
362+
# Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance)
363+
self.pool.set_maintenance_state_for_all(MaintenanceState.MOVING)
364+
self.pool.set_maintenance_state_in_kwargs(MaintenanceState.MOVING)
354365
# edit the config for new connections until the notification expires
355366
self.pool.update_connection_kwargs_with_tmp_settings(
356367
tmp_host_address=event.new_node_host,
@@ -368,7 +379,6 @@ def handle_node_moving_event(self, event: NodeMovingEvent):
368379
tmp_host_address=event.new_node_host,
369380
tmp_relax_timeout=self.config.relax_timeout,
370381
)
371-
372382
# take care for the inactive connections in the pool
373383
# delete them and create new ones
374384
self.pool.disconnect_and_reconfigure_free_connections(
@@ -388,16 +398,19 @@ def handle_node_moved_event(self):
388398
tmp_host_address=None,
389399
tmp_relax_timeout=-1,
390400
)
401+
# Clear state to NONE in kwargs immediately after updating tmp kwargs
402+
self.pool.set_maintenance_state_in_kwargs(MaintenanceState.NONE)
391403
with self.pool._lock:
392404
if self.config.is_relax_timeouts_enabled():
393405
# reset the timeout for existing connections
394406
self.pool.update_connections_current_timeout(
395407
relax_timeout=-1, include_free_connections=True
396408
)
397-
398409
self.pool.update_connections_tmp_settings(
399410
tmp_host_address=None, tmp_relax_timeout=-1
400411
)
412+
# Clear state to NONE for all connections
413+
self.pool.set_maintenance_state_for_all(MaintenanceState.NONE)
401414

402415

403416
class MaintenanceEventConnectionHandler:
@@ -416,17 +429,24 @@ def handle_event(self, event: MaintenanceEvent):
416429
logging.error(f"Unhandled event type: {event}")
417430

418431
def handle_migrating_event(self, notification: NodeMigratingEvent):
419-
if not self.config.is_relax_timeouts_enabled():
432+
if (
433+
self.connection.maintenance_state == MaintenanceState.MOVING
434+
or not self.config.is_relax_timeouts_enabled()
435+
):
420436
return
421-
437+
self.connection.set_maintenance_state(MaintenanceState.MIGRATING)
422438
# extend the timeout for all created connections
423439
self.connection.update_current_socket_timeout(self.config.relax_timeout)
424440
self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout)
425441

426442
def handle_migration_completed_event(self, notification: "NodeMigratedEvent"):
427-
if not self.config.is_relax_timeouts_enabled():
443+
# Only reset timeouts if state is not MOVING and relax timeouts are enabled
444+
if (
445+
self.connection.maintenance_state == MaintenanceState.MOVING
446+
or not self.config.is_relax_timeouts_enabled()
447+
):
428448
return
429-
449+
self.connection.set_maintenance_state(MaintenanceState.NONE)
430450
# Node migration completed - reset the connection
431451
# timeouts by providing -1 as the relax timeout
432452
self.connection.update_current_socket_timeout(-1)

tests/test_connection_pool.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import redis
1010
from redis.cache import CacheConfig
1111
from redis.connection import CacheProxyConnection, Connection, to_bool
12+
from redis.maintenance_events import MaintenanceState
1213
from redis.utils import SSL_AVAILABLE
1314

1415
from .conftest import (
@@ -53,10 +54,15 @@ def get_pool(
5354
return pool
5455

5556
def test_connection_creation(self):
56-
connection_kwargs = {"foo": "bar", "biz": "baz"}
57+
connection_kwargs = {
58+
"foo": "bar",
59+
"biz": "baz",
60+
"maintenance_state": MaintenanceState.NONE,
61+
}
5762
pool = self.get_pool(
5863
connection_kwargs=connection_kwargs, connection_class=DummyConnection
5964
)
65+
6066
connection = pool.get_connection()
6167
assert isinstance(connection, DummyConnection)
6268
assert connection.kwargs == connection_kwargs
@@ -150,7 +156,9 @@ def test_connection_creation(self, master_host):
150156
"host": master_host[0],
151157
"port": master_host[1],
152158
}
159+
153160
pool = self.get_pool(connection_kwargs=connection_kwargs)
161+
connection_kwargs["maintenance_state"] = MaintenanceState.NONE
154162
connection = pool.get_connection()
155163
assert isinstance(connection, DummyConnection)
156164
assert connection.kwargs == connection_kwargs

0 commit comments

Comments
 (0)