|
11 | 11 | from queue import Empty, Full, LifoQueue
|
12 | 12 | from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
|
13 | 13 | from urllib.parse import parse_qs, unquote, urlparse
|
| 14 | +import enum |
14 | 15 |
|
15 | 16 | from redis.cache import (
|
16 | 17 | CacheEntry,
|
|
41 | 42 | MaintenanceEventConnectionHandler,
|
42 | 43 | MaintenanceEventPoolHandler,
|
43 | 44 | MaintenanceEventsConfig,
|
| 45 | + MaintenanceState, |
44 | 46 | )
|
45 | 47 | from .retry import Retry
|
46 | 48 | from .utils import (
|
@@ -284,6 +286,7 @@ def __init__(
|
284 | 286 | maintenance_events_config: Optional[MaintenanceEventsConfig] = None,
|
285 | 287 | tmp_host_address: Optional[str] = None,
|
286 | 288 | tmp_relax_timeout: Optional[float] = -1,
|
| 289 | + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, |
287 | 290 | ):
|
288 | 291 | """
|
289 | 292 | Initialize a new Connection.
|
@@ -373,6 +376,7 @@ def __init__(
|
373 | 376 | self._should_reconnect = False
|
374 | 377 | self.tmp_host_address = tmp_host_address
|
375 | 378 | self.tmp_relax_timeout = tmp_relax_timeout
|
| 379 | + self.maintenance_state = maintenance_state |
376 | 380 |
|
377 | 381 | def __repr__(self):
|
378 | 382 | repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
|
@@ -829,6 +833,9 @@ def update_tmp_settings(
|
829 | 833 | if tmp_relax_timeout is not SENTINEL:
|
830 | 834 | self.tmp_relax_timeout = tmp_relax_timeout
|
831 | 835 |
|
| 836 | + def set_maintenance_state(self, state: "MaintenanceState"): |
| 837 | + self.maintenance_state = state |
| 838 | + |
832 | 839 |
|
833 | 840 | class Connection(AbstractConnection):
|
834 | 841 | "Manages TCP communication to and from a Redis server"
|
@@ -1717,11 +1724,18 @@ def make_connection(self) -> "ConnectionInterface":
|
1717 | 1724 | raise ConnectionError("Too many connections")
|
1718 | 1725 | self._created_connections += 1
|
1719 | 1726 |
|
| 1727 | + # Pass current maintenance_state to new connections |
| 1728 | + maintenance_state = self.connection_kwargs.get( |
| 1729 | + "maintenance_state", MaintenanceState.NONE |
| 1730 | + ) |
| 1731 | + kwargs = dict(self.connection_kwargs) |
| 1732 | + kwargs["maintenance_state"] = maintenance_state |
| 1733 | + |
1720 | 1734 | if self.cache is not None:
|
1721 | 1735 | return CacheProxyConnection(
|
1722 |
| - self.connection_class(**self.connection_kwargs), self.cache, self._lock |
| 1736 | + self.connection_class(**kwargs), self.cache, self._lock |
1723 | 1737 | )
|
1724 |
| - return self.connection_class(**self.connection_kwargs) |
| 1738 | + return self.connection_class(**kwargs) |
1725 | 1739 |
|
1726 | 1740 | def release(self, connection: "Connection") -> None:
|
1727 | 1741 | "Releases the connection back to the pool"
|
@@ -1946,6 +1960,16 @@ async def _mock(self, error: RedisError):
|
1946 | 1960 | """
|
1947 | 1961 | pass
|
1948 | 1962 |
|
| 1963 | + def set_maintenance_state_for_all(self, state: "MaintenanceState"): |
| 1964 | + with self._lock: |
| 1965 | + for conn in self._available_connections: |
| 1966 | + conn.set_maintenance_state(state) |
| 1967 | + for conn in self._in_use_connections: |
| 1968 | + conn.set_maintenance_state(state) |
| 1969 | + |
| 1970 | + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): |
| 1971 | + self.connection_kwargs["maintenance_state"] = state |
| 1972 | + |
1949 | 1973 |
|
1950 | 1974 | class BlockingConnectionPool(ConnectionPool):
|
1951 | 1975 | """
|
@@ -2040,15 +2064,20 @@ def make_connection(self):
|
2040 | 2064 | if self._in_maintenance:
|
2041 | 2065 | self._lock.acquire()
|
2042 | 2066 | self._locked = True
|
| 2067 | + # Pass current maintenance_state to new connections |
| 2068 | + maintenance_state = self.connection_kwargs.get( |
| 2069 | + "maintenance_state", MaintenanceState.NONE |
| 2070 | + ) |
| 2071 | + kwargs = dict(self.connection_kwargs) |
| 2072 | + kwargs["maintenance_state"] = maintenance_state |
2043 | 2073 | if self.cache is not None:
|
2044 | 2074 | connection = CacheProxyConnection(
|
2045 |
| - self.connection_class(**self.connection_kwargs), |
| 2075 | + self.connection_class(**kwargs), |
2046 | 2076 | self.cache,
|
2047 | 2077 | self._lock,
|
2048 | 2078 | )
|
2049 | 2079 | else:
|
2050 |
| - connection = self.connection_class(**self.connection_kwargs) |
2051 |
| - |
| 2080 | + connection = self.connection_class(**kwargs) |
2052 | 2081 | self._connections.append(connection)
|
2053 | 2082 | return connection
|
2054 | 2083 | finally:
|
@@ -2259,3 +2288,12 @@ def _update_maintenance_events_configs_for_connections(
|
2259 | 2288 | def set_in_maintenance(self, in_maintenance: bool):
|
2260 | 2289 | """Set the maintenance mode for the connection pool."""
|
2261 | 2290 | self._in_maintenance = in_maintenance
|
| 2291 | + |
| 2292 | + def set_maintenance_state_for_all(self, state: "MaintenanceState"): |
| 2293 | + with self._lock: |
| 2294 | + for conn in getattr(self, "_connections", []): |
| 2295 | + if conn: |
| 2296 | + conn.set_maintenance_state(state) |
| 2297 | + |
| 2298 | + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): |
| 2299 | + self.connection_kwargs["maintenance_state"] = state |
0 commit comments