|
41 | 41 | MaintenanceEventConnectionHandler,
|
42 | 42 | MaintenanceEventPoolHandler,
|
43 | 43 | MaintenanceEventsConfig,
|
| 44 | + MaintenanceState, |
44 | 45 | )
|
45 | 46 | from .retry import Retry
|
46 | 47 | from .utils import (
|
@@ -284,6 +285,7 @@ def __init__(
|
284 | 285 | maintenance_events_config: Optional[MaintenanceEventsConfig] = None,
|
285 | 286 | tmp_host_address: Optional[str] = None,
|
286 | 287 | tmp_relax_timeout: Optional[float] = -1,
|
| 288 | + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, |
287 | 289 | ):
|
288 | 290 | """
|
289 | 291 | Initialize a new Connection.
|
@@ -373,6 +375,7 @@ def __init__(
|
373 | 375 | self._should_reconnect = False
|
374 | 376 | self.tmp_host_address = tmp_host_address
|
375 | 377 | self.tmp_relax_timeout = tmp_relax_timeout
|
| 378 | + self.maintenance_state = maintenance_state |
376 | 379 |
|
377 | 380 | def __repr__(self):
|
378 | 381 | repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
|
@@ -829,6 +832,9 @@ def update_tmp_settings(
|
829 | 832 | if tmp_relax_timeout is not SENTINEL:
|
830 | 833 | self.tmp_relax_timeout = tmp_relax_timeout
|
831 | 834 |
|
| 835 | + def set_maintenance_state(self, state: "MaintenanceState"): |
| 836 | + self.maintenance_state = state |
| 837 | + |
832 | 838 |
|
833 | 839 | class Connection(AbstractConnection):
|
834 | 840 | "Manages TCP communication to and from a Redis server"
|
@@ -1717,11 +1723,18 @@ def make_connection(self) -> "ConnectionInterface":
|
1717 | 1723 | raise ConnectionError("Too many connections")
|
1718 | 1724 | self._created_connections += 1
|
1719 | 1725 |
|
| 1726 | + # Pass current maintenance_state to new connections |
| 1727 | + maintenance_state = self.connection_kwargs.get( |
| 1728 | + "maintenance_state", MaintenanceState.NONE |
| 1729 | + ) |
| 1730 | + kwargs = dict(self.connection_kwargs) |
| 1731 | + kwargs["maintenance_state"] = maintenance_state |
| 1732 | + |
1720 | 1733 | if self.cache is not None:
|
1721 | 1734 | return CacheProxyConnection(
|
1722 |
| - self.connection_class(**self.connection_kwargs), self.cache, self._lock |
| 1735 | + self.connection_class(**kwargs), self.cache, self._lock |
1723 | 1736 | )
|
1724 |
| - return self.connection_class(**self.connection_kwargs) |
| 1737 | + return self.connection_class(**kwargs) |
1725 | 1738 |
|
1726 | 1739 | def release(self, connection: "Connection") -> None:
|
1727 | 1740 | "Releases the connection back to the pool"
|
@@ -1946,6 +1959,16 @@ async def _mock(self, error: RedisError):
|
1946 | 1959 | """
|
1947 | 1960 | pass
|
1948 | 1961 |
|
| 1962 | + def set_maintenance_state_for_all(self, state: "MaintenanceState"): |
| 1963 | + with self._lock: |
| 1964 | + for conn in self._available_connections: |
| 1965 | + conn.set_maintenance_state(state) |
| 1966 | + for conn in self._in_use_connections: |
| 1967 | + conn.set_maintenance_state(state) |
| 1968 | + |
| 1969 | + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): |
| 1970 | + self.connection_kwargs["maintenance_state"] = state |
| 1971 | + |
1949 | 1972 |
|
1950 | 1973 | class BlockingConnectionPool(ConnectionPool):
|
1951 | 1974 | """
|
@@ -2040,15 +2063,20 @@ def make_connection(self):
|
2040 | 2063 | if self._in_maintenance:
|
2041 | 2064 | self._lock.acquire()
|
2042 | 2065 | self._locked = True
|
| 2066 | + # Pass current maintenance_state to new connections |
| 2067 | + maintenance_state = self.connection_kwargs.get( |
| 2068 | + "maintenance_state", MaintenanceState.NONE |
| 2069 | + ) |
| 2070 | + kwargs = dict(self.connection_kwargs) |
| 2071 | + kwargs["maintenance_state"] = maintenance_state |
2043 | 2072 | if self.cache is not None:
|
2044 | 2073 | connection = CacheProxyConnection(
|
2045 |
| - self.connection_class(**self.connection_kwargs), |
| 2074 | + self.connection_class(**kwargs), |
2046 | 2075 | self.cache,
|
2047 | 2076 | self._lock,
|
2048 | 2077 | )
|
2049 | 2078 | else:
|
2050 |
| - connection = self.connection_class(**self.connection_kwargs) |
2051 |
| - |
| 2079 | + connection = self.connection_class(**kwargs) |
2052 | 2080 | self._connections.append(connection)
|
2053 | 2081 | return connection
|
2054 | 2082 | finally:
|
@@ -2259,3 +2287,12 @@ def _update_maintenance_events_configs_for_connections(
|
2259 | 2287 | def set_in_maintenance(self, in_maintenance: bool):
|
2260 | 2288 | """Set the maintenance mode for the connection pool."""
|
2261 | 2289 | self._in_maintenance = in_maintenance
|
| 2290 | + |
| 2291 | + def set_maintenance_state_for_all(self, state: "MaintenanceState"): |
| 2292 | + with self._lock: |
| 2293 | + for conn in getattr(self, "_connections", []): |
| 2294 | + if conn: |
| 2295 | + conn.set_maintenance_state(state) |
| 2296 | + |
| 2297 | + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): |
| 2298 | + self.connection_kwargs["maintenance_state"] = state |
0 commit comments