From ac862803d77b2c2ec6305e6032807e8f0ca32291 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 13 Jun 2025 14:32:15 +0300 Subject: [PATCH 01/22] Added Database, Healthcheck, CircuitBreaker, FailureDetector --- redis/client.py | 7 +- redis/event.py | 50 +++++++-- redis/multidb/__init__.py | 0 redis/multidb/circuit.py | 79 ++++++++++++++ redis/multidb/database.py | 112 ++++++++++++++++++++ redis/multidb/event.py | 0 redis/multidb/failure_detector.py | 77 ++++++++++++++ redis/multidb/healthcheck.py | 92 ++++++++++++++++ tests/test_event.py | 55 ++++++++++ tests/test_multidb/__init__.py | 0 tests/test_multidb/conftest.py | 22 ++++ tests/test_multidb/test_circuit.py | 41 +++++++ tests/test_multidb/test_database.py | 57 ++++++++++ tests/test_multidb/test_failure_detector.py | 94 ++++++++++++++++ 14 files changed, 677 insertions(+), 9 deletions(-) create mode 100644 redis/multidb/__init__.py create mode 100644 redis/multidb/circuit.py create mode 100644 redis/multidb/database.py create mode 100644 redis/multidb/event.py create mode 100644 redis/multidb/failure_detector.py create mode 100644 redis/multidb/healthcheck.py create mode 100644 tests/test_event.py create mode 100644 tests/test_multidb/__init__.py create mode 100644 tests/test_multidb/conftest.py create mode 100644 tests/test_multidb/test_circuit.py create mode 100644 tests/test_multidb/test_database.py create mode 100644 tests/test_multidb/test_failure_detector.py diff --git a/redis/client.py b/redis/client.py index 28e9a82f76..5fa441b4e8 100755 --- a/redis/client.py +++ b/redis/client.py @@ -45,7 +45,7 @@ AfterPubSubConnectionInstantiationEvent, AfterSingleConnectionInstantiationEvent, ClientType, - EventDispatcher, + EventDispatcher, OnCommandFailEvent, ) from redis.exceptions import ( ConnectionError, @@ -605,7 +605,7 @@ def _send_command_parse_response(self, conn, command_name, *args, **options): conn.send_command(*args, **options) return self.parse_response(conn, command_name, **options) - def 
_close_connection(self, conn) -> None: + def _close_connection(self, conn, error, *args) -> None: """ Close the connection before retrying. @@ -616,6 +616,7 @@ def _close_connection(self, conn) -> None: do a health check as part of the send_command logic(on connection level). """ + self._event_dispatcher.dispatch(OnCommandFailEvent(args, error)) conn.disconnect() # COMMAND EXECUTION AND PROTOCOL PARSING @@ -635,7 +636,7 @@ def _execute_command(self, *args, **options): lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda _: self._close_connection(conn), + lambda error: self._close_connection(conn, error, *args), ) finally: if self._single_connection_client: diff --git a/redis/event.py b/redis/event.py index 5cc6c0017c..4cf6022f7e 100644 --- a/redis/event.py +++ b/redis/event.py @@ -2,7 +2,7 @@ import threading from abc import ABC, abstractmethod from enum import Enum -from typing import List, Optional, Union +from typing import List, Optional, Union, Dict, Type from redis.auth.token import TokenInterface from redis.credentials import CredentialProvider, StreamingCredentialProvider @@ -42,6 +42,11 @@ def dispatch(self, event: object): async def dispatch_async(self, event: object): pass + @abstractmethod + def register_listeners(self, mappings: Dict[Type[object], List[EventListenerInterface]]): + """Register additional listeners.""" + pass + class EventException(Exception): """ @@ -56,11 +61,14 @@ def __init__(self, exception: Exception, event: object): class EventDispatcher(EventDispatcherInterface): # TODO: Make dispatcher to accept external mappings. - def __init__(self): + def __init__( + self, + event_listeners: Optional[Dict[Type[object], List[EventListenerInterface]]] = None, + ): """ - Mapping should be extended for any new events or listeners to be added. + Dispatcher that dispatches events to listeners associated with given event. 
""" - self._event_listeners_mapping = { + self._event_listeners_mapping: Dict[Type[object], List[EventListenerInterface]]= { AfterConnectionReleasedEvent: [ ReAuthConnectionListener(), ], @@ -77,18 +85,28 @@ def __init__(self): ], } + if event_listeners: + self.register_listeners(event_listeners) + def dispatch(self, event: object): - listeners = self._event_listeners_mapping.get(type(event)) + listeners = self._event_listeners_mapping.get(type(event), []) for listener in listeners: listener.listen(event) async def dispatch_async(self, event: object): - listeners = self._event_listeners_mapping.get(type(event)) + listeners = self._event_listeners_mapping.get(type(event), []) for listener in listeners: await listener.listen(event) + def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]): + for event in event_listeners: + if event in self._event_listeners_mapping: + self._event_listeners_mapping[event] = list(set(self._event_listeners_mapping[event] + event_listeners[event])) + else: + self._event_listeners_mapping[event] = event_listeners[event] + class AfterConnectionReleasedEvent: """ @@ -225,6 +243,26 @@ def nodes(self) -> dict: def credential_provider(self) -> Union[CredentialProvider, None]: return self._credential_provider +class OnCommandFailEvent: + """ + Event fired whenever a command fails during the execution. 
+ """ + def __init__( + self, + command: tuple, + exception: Exception, + ): + self._command = command + self._exception = exception + + @property + def command(self) -> tuple: + return self._command + + @property + def exception(self) -> Exception: + return self._exception + class ReAuthConnectionListener(EventListenerInterface): """ diff --git a/redis/multidb/__init__.py b/redis/multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py new file mode 100644 index 0000000000..0672485e59 --- /dev/null +++ b/redis/multidb/circuit.py @@ -0,0 +1,79 @@ +from abc import abstractmethod, ABC +from enum import Enum +from typing import Callable + +import pybreaker + +class State(Enum): + CLOSED = 'closed' + OPEN = 'open' + HALF_OPEN = 'half-open' + +class CircuitBreaker(ABC): + @property + @abstractmethod + def grace_period(self) -> float: + """The grace period in seconds when the circle should be kept open.""" + pass + + @property + @abstractmethod + def state(self) -> State: + """The current state of the circuit.""" + pass + + @state.setter + @abstractmethod + def state(self, state: State): + """Set current state of the circuit.""" + pass + + @abstractmethod + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + """Callback called when the state of the circuit changes.""" + pass + +class PBListener(pybreaker.CircuitBreakerListener): + def __init__( + self, + cb: Callable[[CircuitBreaker, State, State], None] + ): + """Wrapper for callback to be compatible with pybreaker implementation.""" + self._cb = cb + + def state_change(self, cb, old_state, new_state): + cb = PBCircuitBreakerAdapter(cb) + old_state = State(value=old_state.name) + new_state = State(value=new_state.name) + self._cb(cb, old_state, new_state) + + +class PBCircuitBreakerAdapter(CircuitBreaker): + def __init__(self, cb: pybreaker.CircuitBreaker): + """Adapter for pybreaker CircuitBreaker.""" + 
self._cb = cb + self._state_pb_mapper = { + State.CLOSED: self._cb.close, + State.OPEN: self._cb.open, + State.HALF_OPEN: self._cb.half_open, + } + + @property + def grace_period(self) -> float: + return self._cb.reset_timeout + + @grace_period.setter + def grace_period(self, grace_period: float): + self._cb.reset_timeout = grace_period + + @property + def state(self) -> State: + return State(value=self._cb.state.name) + + @state.setter + def state(self, state: State): + self._state_pb_mapper[state]() + + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + listener = PBListener(cb) + self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/multidb/database.py b/redis/multidb/database.py new file mode 100644 index 0000000000..794000673d --- /dev/null +++ b/redis/multidb/database.py @@ -0,0 +1,112 @@ +import redis +from abc import ABC, abstractmethod +from enum import Enum +from typing import Union, List + +from typing_extensions import Optional + +from redis import RedisCluster, Sentinel +from redis.multidb.circuit import CircuitBreaker, State as CBState +from redis.multidb.healthcheck import HealthCheck, AbstractHealthCheck + + +class State(Enum): + ACTIVE = 0 + PASSIVE = 1 + DISCONNECTED = 2 + +class AbstractDatabase(ABC): + @property + @abstractmethod + def client(self) -> Union[redis.Redis, RedisCluster, Sentinel]: + """The underlying redis client.""" + pass + + @property + @abstractmethod + def weight(self) -> float: + """The weight of this database in compare to others. 
Used to determine the database failover to.""" + pass + + @property + @abstractmethod + def state(self) -> State: + """The state of the current database.""" + pass + + @property + @abstractmethod + def circuit(self) -> CircuitBreaker: + """Circuit breaker for the current database.""" + pass + + @abstractmethod + def add_health_check(self, health_check: HealthCheck) -> None: + """Adds a new healthcheck to the current database.""" + pass + + @abstractmethod + def is_healthy(self) -> bool: + """Checks if the current database is healthy.""" + pass + +class Database(AbstractDatabase): + def __init__( + self, + client: Union[redis.Redis, RedisCluster, Sentinel], + cb: CircuitBreaker, + weight: float, + state: State, + health_checks: Optional[List[HealthCheck]] = None, + ): + """ + param: client: Client instance for communication with the database. + param: cb: Circuit breaker for the current database. + param: weight: Weight of current database. Database with the highest weight becomes Active. + param: state: State of the current database. + param: health_checks: List of health cheks to determine if the current database is healthy. 
+ """ + self._client = client + self._cb = cb + self._weight = weight + self._state = state + self._health_checks = health_checks or [] + + @property + def client(self) -> Union[redis.Redis, RedisCluster, Sentinel]: + return self._client + + @property + def weight(self) -> float: + return self._weight + + @weight.setter + def weight(self, weight: float): + self._weight = weight + + @property + def state(self) -> State: + return self._state + + @state.setter + def state(self, state: State): + self._state = state + + @property + def circuit(self) -> CircuitBreaker: + return self._cb + + def add_health_check(self, health_check: HealthCheck) -> None: + self._health_checks.append(health_check) + + def is_healthy(self) -> bool: + is_healthy = True + + for health_check in self._health_checks: + is_healthy = health_check.check_health(self) + + if not is_healthy: + self._cb.state = CBState.OPEN + break + + return is_healthy diff --git a/redis/multidb/event.py b/redis/multidb/event.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py new file mode 100644 index 0000000000..a73ace764a --- /dev/null +++ b/redis/multidb/failure_detector.py @@ -0,0 +1,77 @@ +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import List, Dict, Type + +from typing_extensions import Optional + +from redis.multidb.circuit import State as CBState +from redis.multidb.database import Database + + +class FailureDetector(ABC): + """ + Detects failure based on organic traffic between client and database. 
+ """ + @property + @abstractmethod + def database(self) -> Database: + pass + + @abstractmethod + def register_failure(self, exception: Exception, cmd: tuple) -> None: + """Register a failure that occurred during command execution.""" + pass + +class CommandFailureDetector(FailureDetector): + """ + Detects a failure based on a threshold of failed commands during a specific period of time. + """ + + def __init__( + self, + threshold: int, + duration: float, + database: Database, + error_types: Optional[List[Type[Exception]]] = None, + ) -> None: + """ + param: threshold: Threshold of failed commands over the duration after which database will be marked as failed. + param: duration: Interval in seconds after which database will be marked as failed if threshold was exceeded. + param: database: Database instance associated with failure detection. + param: error_types: List of exception that has to be registered. By default, all exceptions are registered. + """ + self._threshold = threshold + self._duration = duration + self._database = database + self._error_types = error_types + self._start_time: datetime = datetime.now() + self._end_time: datetime = self._start_time + timedelta(seconds=self._duration) + self._failures_within_duration: Dict[Exception, Dict[datetime, tuple]] = {} + + @property + def database(self) -> Database: + return self._database + + def register_failure(self, exception: Exception, cmd: tuple) -> None: + failure_time = datetime.now() + + if not self._start_time < failure_time < self._end_time: + self._reset() + + if self._error_types: + if type(exception) in self._error_types: + self._failures_within_duration[exception] = {datetime.now(): cmd} + else: + self._failures_within_duration[exception] = {datetime.now(): cmd} + + self._check_threshold() + + def _check_threshold(self): + if len(self._failures_within_duration.keys()) >= self._threshold: + self._database.circuit.state = CBState.OPEN + self._reset() + + def _reset(self) -> None: + 
self._start_time = datetime.now() + self._end_time = self._start_time + timedelta(seconds=self._duration) + self._failures_within_duration = {} \ No newline at end of file diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py new file mode 100644 index 0000000000..4999a7d251 --- /dev/null +++ b/redis/multidb/healthcheck.py @@ -0,0 +1,92 @@ +from abc import abstractmethod, ABC + +from redis.backoff import AbstractBackoff +from redis.retry import Retry +from redis.multidb.circuit import State as CBState + + +class HealthCheck(ABC): + @property + @abstractmethod + def check_interval(self) -> float: + """The health check interval in seconds.""" + pass + + @property + @abstractmethod + def num_retries(self) -> int: + """The number of times to retry the health check.""" + pass + + @property + @abstractmethod + def backoff(self) -> AbstractBackoff: + """The backoff strategy for the health check.""" + pass + + @abstractmethod + def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + +class AbstractHealthCheck(HealthCheck): + def __init__( + self, + check_interval: float, + num_retries: int, + backoff: AbstractBackoff + ) -> None: + self._check_interval = check_interval + self._num_retries = num_retries + self._backoff = backoff + self._retry = Retry(self._backoff, self._num_retries) + + @property + def check_interval(self) -> float: + return self._check_interval + + @property + def num_retries(self) -> int: + return self._num_retries + + @property + def backoff(self) -> AbstractBackoff: + return self._backoff + + @abstractmethod + def check_health(self, database) -> bool: + pass + + +class EchoHealthCheck(AbstractHealthCheck): + def __init__( + self, + check_interval: float, + num_retries: int, + backoff: AbstractBackoff, + ) -> None: + """ + Check database healthiness by sending an echo request. 
+ """ + super().__init__( + check_interval=check_interval, + num_retries=num_retries, + backoff=backoff, + ) + def check_health(self, database) -> bool: + try: + return self._retry.call_with_retry( + lambda : self._returns_echoed_message(database), + lambda _: self.dummy_fail() + ) + except Exception: + database.circuit.state = CBState.OPEN + return False + + def _returns_echoed_message(self, database) -> bool: + expected_message = "healthcheck" + actual_message = database.client.execute_command('ECHO', expected_message) + return actual_message == expected_message + + def dummy_fail(self): + pass \ No newline at end of file diff --git a/tests/test_event.py b/tests/test_event.py new file mode 100644 index 0000000000..27526abeaf --- /dev/null +++ b/tests/test_event.py @@ -0,0 +1,55 @@ +from unittest.mock import Mock, AsyncMock + +from redis.event import EventListenerInterface, EventDispatcher, AsyncEventListenerInterface + + +class TestEventDispatcher: + def test_register_listeners(self): + mock_event = Mock(spec=object) + mock_event_listener = Mock(spec=EventListenerInterface) + listener_called = 0 + + def callback(event): + nonlocal listener_called + listener_called += 1 + + mock_event_listener.listen = callback + + # Register via constructor + dispatcher = EventDispatcher(event_listeners={type(mock_event): [mock_event_listener]}) + dispatcher.dispatch(mock_event) + + assert listener_called == 1 + + # Register additional listener for the same event + mock_another_event_listener = Mock(spec=EventListenerInterface) + mock_another_event_listener.listen = callback + dispatcher.register_listeners(event_listeners={type(mock_event): [mock_another_event_listener]}) + dispatcher.dispatch(mock_event) + + assert listener_called == 3 + + async def test_register_listeners_async(self): + mock_event = Mock(spec=object) + mock_event_listener = AsyncMock(spec=AsyncEventListenerInterface) + listener_called = 0 + + async def callback(event): + nonlocal listener_called + 
listener_called += 1 + + mock_event_listener.listen = callback + + # Register via constructor + dispatcher = EventDispatcher(event_listeners={type(mock_event): [mock_event_listener]}) + await dispatcher.dispatch_async(mock_event) + + assert listener_called == 1 + + # Register additional listener for the same event + mock_another_event_listener = Mock(spec=AsyncEventListenerInterface) + mock_another_event_listener.listen = callback + dispatcher.register_listeners(event_listeners={type(mock_event): [mock_another_event_listener]}) + await dispatcher.dispatch_async(mock_event) + + assert listener_called == 3 \ No newline at end of file diff --git a/tests/test_multidb/__init__.py b/tests/test_multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py new file mode 100644 index 0000000000..0a4d4f099a --- /dev/null +++ b/tests/test_multidb/conftest.py @@ -0,0 +1,22 @@ +from unittest.mock import Mock + +import pytest + +from redis import Redis +from redis.multidb.circuit import CircuitBreaker, State as CBState +from redis.multidb.database import Database, State + + +@pytest.fixture() +def mock_client() -> Redis: + return Mock(spec=Redis) + +@pytest.fixture() +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) + +@pytest.fixture() +def mock_db() -> Database: + db = Mock(spec=Database) + db.circuit.state = CBState.CLOSED + return db \ No newline at end of file diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py new file mode 100644 index 0000000000..5ddeacfea7 --- /dev/null +++ b/tests/test_multidb/test_circuit.py @@ -0,0 +1,41 @@ +import pybreaker + +from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker + + +class TestPBCircuitBreaker: + def test_cb_correctly_configured(self): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = PBCircuitBreakerAdapter(cb=pb_circuit) + assert 
adapter.state == CbState.CLOSED + + adapter.state = CbState.OPEN + assert adapter.state == CbState.OPEN + + adapter.state = CbState.HALF_OPEN + assert adapter.state == CbState.HALF_OPEN + + adapter.state = CbState.CLOSED + assert adapter.state == CbState.CLOSED + + assert adapter.grace_period == 5 + adapter.grace_period = 10 + + assert adapter.grace_period == 10 + + def test_cb_executes_callback_on_state_changed(self): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = PBCircuitBreakerAdapter(cb=pb_circuit) + called_count = 0 + + def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): + nonlocal called_count + assert old_state == CbState.CLOSED + assert new_state == CbState.HALF_OPEN + assert isinstance(cb, PBCircuitBreakerAdapter) + called_count += 1 + + adapter.on_state_changed(callback) + adapter.state = CbState.HALF_OPEN + + assert called_count == 1 \ No newline at end of file diff --git a/tests/test_multidb/test_database.py b/tests/test_multidb/test_database.py new file mode 100644 index 0000000000..39380236f9 --- /dev/null +++ b/tests/test_multidb/test_database.py @@ -0,0 +1,57 @@ +from redis.backoff import ExponentialBackoff +from redis.multidb.database import Database, State +from redis.multidb.healthcheck import EchoHealthCheck +from redis.multidb.circuit import State as CBState +from redis.exceptions import ConnectionError + + +class TestDatabase: + def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): + """ + Mocking responses to mix error and actual responses to ensure that health check retry + according to given configuration. 
+ """ + mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] + mock_cb.state = CBState.CLOSED + hc = EchoHealthCheck( + check_interval=1.0, + num_retries=3, + backoff=ExponentialBackoff(), + ) + db = Database(mock_client, mock_cb, 0.9, State.ACTIVE, health_checks=[hc]) + + assert db.is_healthy() == True + assert mock_client.execute_command.call_count == 3 + assert db.circuit.state == CBState.CLOSED + + def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): + """ + Mocking responses to mix error and actual responses to ensure that health check retry + according to given configuration. + """ + mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'wrong'] + mock_cb.state = CBState.CLOSED + hc = EchoHealthCheck( + check_interval=1.0, + num_retries=3, + backoff=ExponentialBackoff(), + ) + db = Database(mock_client, mock_cb, 0.9, State.ACTIVE, health_checks=[hc]) + + assert db.is_healthy() == False + assert mock_client.execute_command.call_count == 3 + assert db.circuit.state == CBState.OPEN + + def test_database_is_unhealthy_on_exceeded_healthcheck_retries(self, mock_client, mock_cb): + mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, ConnectionError, ConnectionError] + mock_cb.state = CBState.CLOSED + hc = EchoHealthCheck( + check_interval=1.0, + num_retries=3, + backoff=ExponentialBackoff(), + ) + db = Database(mock_client, mock_cb, 0.9, State.ACTIVE, health_checks=[hc]) + + assert db.is_healthy() == False + assert mock_client.execute_command.call_count == 4 + assert db.circuit.state == CBState.OPEN \ No newline at end of file diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py new file mode 100644 index 0000000000..c1d0fb793e --- /dev/null +++ b/tests/test_multidb/test_failure_detector.py @@ -0,0 +1,94 @@ +from time import sleep + +from redis.multidb.failure_detector import CommandFailureDetector 
+from redis.multidb.circuit import State as CBState +from redis.exceptions import ConnectionError + + +class TestCommandFailureDetector: + def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): + fd = CommandFailureDetector(5, 1, mock_db) + assert fd.database.circuit.state == CBState.CLOSED + + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert fd.database.circuit.state == CBState.OPEN + + def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): + fd = CommandFailureDetector(5, 1, mock_db) + assert fd.database.circuit.state == CBState.CLOSED + + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert fd.database.circuit.state == CBState.CLOSED + + def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): + fd = CommandFailureDetector(5, 0.3, mock_db) + assert fd.database.circuit.state == CBState.CLOSED + + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + sleep(0.1) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + sleep(0.1) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + sleep(0.1) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + sleep(0.1) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert fd.database.circuit.state == CBState.CLOSED + + # 4 more failure as last one already refreshed timer + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), 
('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert fd.database.circuit.state == CBState.OPEN + + def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): + fd = CommandFailureDetector(5, 0.3, mock_db) + assert fd.database.circuit.state == CBState.CLOSED + + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + sleep(0.4) + + assert fd.database.circuit.state == CBState.CLOSED + + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert fd.database.circuit.state == CBState.CLOSED + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert fd.database.circuit.state == CBState.OPEN + + def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): + fd = CommandFailureDetector(5, 1, mock_db, error_types=[ConnectionError]) + assert fd.database.circuit.state == CBState.CLOSED + + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert fd.database.circuit.state == CBState.CLOSED + + fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + + assert fd.database.circuit.state == CBState.OPEN \ No newline 
at end of file From 4f4a53c5d0d46b06dc935e34ad070d8a9e80c60f Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 17 Jun 2025 11:41:23 +0300 Subject: [PATCH 02/22] Added DatabaseSelector, exceptions, refactored existing entities --- redis/client.py | 2 +- redis/data_structure.py | 57 +++++++ redis/event.py | 5 + redis/multidb/circuit.py | 5 + redis/multidb/database.py | 59 +++---- redis/multidb/exception.py | 2 + redis/multidb/failure_detector.py | 26 +-- redis/multidb/healthcheck.py | 56 ++---- redis/multidb/selector.py | 61 +++++++ tests/test_multidb/conftest.py | 45 ++++- tests/test_multidb/test_data_structure.py | 47 +++++ tests/test_multidb/test_failure_detector.py | 161 +++++++++++------- .../{test_database.py => test_healthcheck.py} | 33 ++-- tests/test_multidb/test_selector.py | 101 +++++++++++ 14 files changed, 482 insertions(+), 178 deletions(-) create mode 100644 redis/data_structure.py create mode 100644 redis/multidb/exception.py create mode 100644 redis/multidb/selector.py create mode 100644 tests/test_multidb/test_data_structure.py rename tests/test_multidb/{test_database.py => test_healthcheck.py} (66%) create mode 100644 tests/test_multidb/test_selector.py diff --git a/redis/client.py b/redis/client.py index 5fa441b4e8..8ebe2e38b3 100755 --- a/redis/client.py +++ b/redis/client.py @@ -616,7 +616,7 @@ def _close_connection(self, conn, error, *args) -> None: do a health check as part of the send_command logic(on connection level). 
""" - self._event_dispatcher.dispatch(OnCommandFailEvent(args, error)) + self._event_dispatcher.dispatch(OnCommandFailEvent(args, error, self)) conn.disconnect() # COMMAND EXECUTION AND PROTOCOL PARSING diff --git a/redis/data_structure.py b/redis/data_structure.py new file mode 100644 index 0000000000..06ed1814e9 --- /dev/null +++ b/redis/data_structure.py @@ -0,0 +1,57 @@ +from typing import List + + +class WeightedList: + def __init__(self): + self._items = [] + + def add(self, item, weight: float) -> None: + """Add item with weight, maintaining sorted order""" + # Find insertion point using binary search + left, right = 0, len(self._items) + while left < right: + mid = (left + right) // 2 + if self._items[mid][0] < weight: + right = mid + else: + left = mid + 1 + + self._items.insert(left, (weight, item)) + + def remove(self, item): + """Remove first occurrence of item""" + for i, (weight, stored_item) in enumerate(self._items): + if stored_item == item: + self._items.pop(i) + return weight + raise ValueError("Item not found") + + def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple]: + """Get all items within weight range""" + result = [] + for weight, item in self._items: + if min_weight <= weight <= max_weight: + result.append((item, weight)) + return result + + def get_top_n(self, n: int) -> List[tuple]: + """Get top N the highest weighted items""" + return [(item, weight) for weight, item in self._items[:n]] + + def update_weight(self, item, new_weight: float): + """Update weight of an item""" + old_weight = self.remove(item) + self.add(item, new_weight) + return old_weight + + def __iter__(self): + """Iterate in descending weight order""" + for weight, item in self._items: + yield item, weight + + def __len__(self): + return len(self._items) + + def __getitem__(self, index): + weight, item = self._items[index] + return item, weight \ No newline at end of file diff --git a/redis/event.py b/redis/event.py index 
4cf6022f7e..39a924ed16 100644 --- a/redis/event.py +++ b/redis/event.py @@ -251,9 +251,11 @@ def __init__( self, command: tuple, exception: Exception, + client, ): self._command = command self._exception = exception + self._client = client @property def command(self) -> tuple: @@ -263,6 +265,9 @@ def command(self) -> tuple: def exception(self) -> Exception: return self._exception + @property + def client(self): + return self._client class ReAuthConnectionListener(EventListenerInterface): """ diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 0672485e59..1f3d00e81c 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -16,6 +16,11 @@ def grace_period(self) -> float: """The grace period in seconds when the circle should be kept open.""" pass + @grace_period.setter + @abstractmethod + def grace_period(self, grace_period: float): + """Set the grace period in seconds.""" + @property @abstractmethod def state(self) -> State: diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 794000673d..a992818774 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -1,14 +1,10 @@ import redis from abc import ABC, abstractmethod from enum import Enum -from typing import Union, List - -from typing_extensions import Optional +from typing import Union from redis import RedisCluster, Sentinel -from redis.multidb.circuit import CircuitBreaker, State as CBState -from redis.multidb.healthcheck import HealthCheck, AbstractHealthCheck - +from redis.multidb.circuit import CircuitBreaker class State(Enum): ACTIVE = 0 @@ -22,32 +18,46 @@ def client(self) -> Union[redis.Redis, RedisCluster, Sentinel]: """The underlying redis client.""" pass + @client.setter + @abstractmethod + def client(self, client: Union[redis.Redis, RedisCluster]): + """Set the underlying redis client.""" + pass + @property @abstractmethod def weight(self) -> float: """The weight of this database in compare to others. 
Used to determine the database failover to.""" pass + @weight.setter + @abstractmethod + def weight(self, weight: float): + """Set the weight of this database in compare to others.""" + pass + @property @abstractmethod def state(self) -> State: """The state of the current database.""" pass - @property + @state.setter @abstractmethod - def circuit(self) -> CircuitBreaker: - """Circuit breaker for the current database.""" + def state(self, state: State): + """Set the state of the current database.""" pass + @property @abstractmethod - def add_health_check(self, health_check: HealthCheck) -> None: - """Adds a new healthcheck to the current database.""" + def circuit(self) -> CircuitBreaker: + """Circuit breaker for the current database.""" pass + @circuit.setter @abstractmethod - def is_healthy(self) -> bool: - """Checks if the current database is healthy.""" + def circuit(self, circuit: CircuitBreaker): + """Set the circuit breaker for the current database.""" pass class Database(AbstractDatabase): @@ -57,7 +67,6 @@ def __init__( cb: CircuitBreaker, weight: float, state: State, - health_checks: Optional[List[HealthCheck]] = None, ): """ param: client: Client instance for communication with the database. 
@@ -70,12 +79,15 @@ def __init__( self._cb = cb self._weight = weight self._state = state - self._health_checks = health_checks or [] @property def client(self) -> Union[redis.Redis, RedisCluster, Sentinel]: return self._client + @client.setter + def client(self, client: Union[redis.Redis, RedisCluster, Sentinel]): + self._client = client + @property def weight(self) -> float: return self._weight @@ -96,17 +108,6 @@ def state(self, state: State): def circuit(self) -> CircuitBreaker: return self._cb - def add_health_check(self, health_check: HealthCheck) -> None: - self._health_checks.append(health_check) - - def is_healthy(self) -> bool: - is_healthy = True - - for health_check in self._health_checks: - is_healthy = health_check.check_health(self) - - if not is_healthy: - self._cb.state = CBState.OPEN - break - - return is_healthy + @circuit.setter + def circuit(self, circuit: CircuitBreaker): + self._cb = circuit diff --git a/redis/multidb/exception.py b/redis/multidb/exception.py new file mode 100644 index 0000000000..80fdb9409a --- /dev/null +++ b/redis/multidb/exception.py @@ -0,0 +1,2 @@ +class NoValidDatabaseException(Exception): + pass \ No newline at end of file diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index a73ace764a..262e79be1a 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -5,20 +5,11 @@ from typing_extensions import Optional from redis.multidb.circuit import State as CBState -from redis.multidb.database import Database - class FailureDetector(ABC): - """ - Detects failure based on organic traffic between client and database. 
- """ - @property - @abstractmethod - def database(self) -> Database: - pass @abstractmethod - def register_failure(self, exception: Exception, cmd: tuple) -> None: + def register_failure(self, database, exception: Exception, cmd: tuple) -> None: """Register a failure that occurred during command execution.""" pass @@ -31,28 +22,21 @@ def __init__( self, threshold: int, duration: float, - database: Database, error_types: Optional[List[Type[Exception]]] = None, ) -> None: """ param: threshold: Threshold of failed commands over the duration after which database will be marked as failed. param: duration: Interval in seconds after which database will be marked as failed if threshold was exceeded. - param: database: Database instance associated with failure detection. param: error_types: List of exception that has to be registered. By default, all exceptions are registered. """ self._threshold = threshold self._duration = duration - self._database = database self._error_types = error_types self._start_time: datetime = datetime.now() self._end_time: datetime = self._start_time + timedelta(seconds=self._duration) self._failures_within_duration: Dict[Exception, Dict[datetime, tuple]] = {} - @property - def database(self) -> Database: - return self._database - - def register_failure(self, exception: Exception, cmd: tuple) -> None: + def register_failure(self, database, exception: Exception, cmd: tuple) -> None: failure_time = datetime.now() if not self._start_time < failure_time < self._end_time: @@ -64,11 +48,11 @@ def register_failure(self, exception: Exception, cmd: tuple) -> None: else: self._failures_within_duration[exception] = {datetime.now(): cmd} - self._check_threshold() + self._check_threshold(database) - def _check_threshold(self): + def _check_threshold(self, database): if len(self._failures_within_duration.keys()) >= self._threshold: - self._database.circuit.state = CBState.OPEN + database.circuit.state = CBState.OPEN self._reset() def _reset(self) -> None: 
diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 4999a7d251..152b095b09 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,27 +1,15 @@ from abc import abstractmethod, ABC -from redis.backoff import AbstractBackoff from redis.retry import Retry from redis.multidb.circuit import State as CBState class HealthCheck(ABC): - @property - @abstractmethod - def check_interval(self) -> float: - """The health check interval in seconds.""" - pass - - @property - @abstractmethod - def num_retries(self) -> int: - """The number of times to retry the health check.""" - pass @property @abstractmethod - def backoff(self) -> AbstractBackoff: - """The backoff strategy for the health check.""" + def retry(self) -> Retry: + """The retry object to use for health checks.""" pass @abstractmethod @@ -32,26 +20,13 @@ def check_health(self, database) -> bool: class AbstractHealthCheck(HealthCheck): def __init__( self, - check_interval: float, - num_retries: int, - backoff: AbstractBackoff + retry: Retry, ) -> None: - self._check_interval = check_interval - self._num_retries = num_retries - self._backoff = backoff - self._retry = Retry(self._backoff, self._num_retries) + self._retry = retry @property - def check_interval(self) -> float: - return self._check_interval - - @property - def num_retries(self) -> int: - return self._num_retries - - @property - def backoff(self) -> AbstractBackoff: - return self._backoff + def retry(self) -> Retry: + return self._retry @abstractmethod def check_health(self, database) -> bool: @@ -61,24 +36,25 @@ def check_health(self, database) -> bool: class EchoHealthCheck(AbstractHealthCheck): def __init__( self, - check_interval: float, - num_retries: int, - backoff: AbstractBackoff, + retry: Retry, ) -> None: """ Check database healthiness by sending an echo request. 
""" super().__init__( - check_interval=check_interval, - num_retries=num_retries, - backoff=backoff, + retry=retry, ) def check_health(self, database) -> bool: try: - return self._retry.call_with_retry( + is_healthy = self._retry.call_with_retry( lambda : self._returns_echoed_message(database), - lambda _: self.dummy_fail() + lambda _: self._dummy_fail() ) + + if not is_healthy: + database.circuit.state = CBState.OPEN + + return is_healthy except Exception: database.circuit.state = CBState.OPEN return False @@ -88,5 +64,5 @@ def _returns_echoed_message(self, database) -> bool: actual_message = database.client.execute_command('ECHO', expected_message) return actual_message == expected_message - def dummy_fail(self): + def _dummy_fail(self): pass \ No newline at end of file diff --git a/redis/multidb/selector.py b/redis/multidb/selector.py new file mode 100644 index 0000000000..d2c8468835 --- /dev/null +++ b/redis/multidb/selector.py @@ -0,0 +1,61 @@ +from abc import ABC, abstractmethod +from typing import List + +from redis.data_structure import WeightedList +from redis.multidb.database import AbstractDatabase +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException +from redis.retry import Retry + + +class DatabaseSelector(ABC): + + @property + @abstractmethod + def database(self) -> AbstractDatabase: + """Select the database.""" + pass + + @abstractmethod + def add_database(self, database: AbstractDatabase) -> None: + """Add the database.""" + pass + + +class WeightBasedDatabaseSelector(DatabaseSelector): + """ + Choose the active database with the highest weight. 
+ """ + def __init__( + self, + databases: List[AbstractDatabase], + retry: Retry, + ): + self._retry = retry + self._retry.update_supported_errors([NoValidDatabaseException]) + self._databases = WeightedList() + + for database in databases: + self._databases.add(database, database.weight) + + @property + def database(self) -> AbstractDatabase: + return self._retry.call_with_retry( + lambda: self._get_active_database(), + lambda _: self._dummy_fail() + ) + + def add_database(self, database: AbstractDatabase) -> None: + self._databases.add(database, database.weight) + + def _get_active_database(self) -> AbstractDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + else: + continue + + raise NoValidDatabaseException('No valid database available for communication') + + def _dummy_fail(self): + pass diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 0a4d4f099a..197d74fa12 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -11,12 +11,51 @@ def mock_client() -> Redis: return Mock(spec=Redis) -@pytest.fixture() +@pytest.fixture(scope='function') def mock_cb() -> CircuitBreaker: return Mock(spec=CircuitBreaker) @pytest.fixture() -def mock_db() -> Database: +def mock_db(request) -> Database: db = Mock(spec=Database) - db.circuit.state = CBState.CLOSED + db.weight = request.param.get("weight", 1.0) + db.state = request.param.get("state", State.ACTIVE) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db1(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.state = request.param.get("state", State.ACTIVE) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) 
+ mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db2(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.state = request.param.get("state", State.ACTIVE) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb return db \ No newline at end of file diff --git a/tests/test_multidb/test_data_structure.py b/tests/test_multidb/test_data_structure.py new file mode 100644 index 0000000000..832c661ac0 --- /dev/null +++ b/tests/test_multidb/test_data_structure.py @@ -0,0 +1,47 @@ +from redis.data_structure import WeightedList + + +class TestWeightedList: + def test_add_items(self): + wlist = WeightedList() + + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.get_top_n(4) == [('item3', 4.0), ('item4', 4.0), ('item1', 3.0), ('item2', 2.0)] + + def test_remove_items(self): + wlist = WeightedList() + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + wlist.remove('item2') + wlist.remove('item4') + + assert wlist.get_top_n(4) == [('item3', 4.0), ('item1', 3.0)] + + def test_get_by_weight_range(self): + wlist = WeightedList() + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.get_by_weight_range(2.0, 3.0) == [('item1', 3.0), ('item2', 2.0)] + + def test_update_weights(self): + wlist = WeightedList() + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.get_top_n(4) == [('item3', 4.0), ('item4', 4.0), ('item1', 3.0), ('item2', 2.0)] + + 
wlist.update_weight('item2', 5.0) + + assert wlist.get_top_n(4) == [('item2', 5.0), ('item3', 4.0), ('item4', 4.0), ('item1', 3.0)] \ No newline at end of file diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py index c1d0fb793e..8e0c1bcbad 100644 --- a/tests/test_multidb/test_failure_detector.py +++ b/tests/test_multidb/test_failure_detector.py @@ -1,94 +1,131 @@ from time import sleep +import pytest + from redis.multidb.failure_detector import CommandFailureDetector from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError class TestCommandFailureDetector: + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): - fd = CommandFailureDetector(5, 1, mock_db) - assert fd.database.circuit.state == CBState.CLOSED - - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert fd.database.circuit.state == CBState.OPEN - + fd = CommandFailureDetector(5, 1) + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) def 
test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): - fd = CommandFailureDetector(5, 1, mock_db) - assert fd.database.circuit.state == CBState.CLOSED - - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert fd.database.circuit.state == CBState.CLOSED - + fd = CommandFailureDetector(5, 1) + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): - fd = CommandFailureDetector(5, 0.3, mock_db) - assert fd.database.circuit.state == CBState.CLOSED + fd = CommandFailureDetector(5, 0.3) + assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) sleep(0.1) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) sleep(0.1) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) sleep(0.1) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) sleep(0.1) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + 
fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - assert fd.database.circuit.state == CBState.CLOSED + assert mock_db.circuit.state == CBState.CLOSED # 4 more failure as last one already refreshed timer - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert fd.database.circuit.state == CBState.OPEN - + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): - fd = CommandFailureDetector(5, 0.3, mock_db) - assert fd.database.circuit.state == CBState.CLOSED + fd = CommandFailureDetector(5, 0.3) + assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) sleep(0.4) - assert fd.database.circuit.state == CBState.CLOSED + assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), 
('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - assert fd.database.circuit.state == CBState.CLOSED - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + assert mock_db.circuit.state == CBState.CLOSED + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - assert fd.database.circuit.state == CBState.OPEN + assert mock_db.circuit.state == CBState.OPEN + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): - fd = CommandFailureDetector(5, 1, mock_db, error_types=[ConnectionError]) - assert fd.database.circuit.state == CBState.CLOSED + fd = CommandFailureDetector(5, 1, error_types=[ConnectionError]) + assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - assert fd.database.circuit.state == CBState.CLOSED + assert 
mock_db.circuit.state == CBState.CLOSED - fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) - assert fd.database.circuit.state == CBState.OPEN \ No newline at end of file + assert mock_db.circuit.state == CBState.OPEN \ No newline at end of file diff --git a/tests/test_multidb/test_database.py b/tests/test_multidb/test_healthcheck.py similarity index 66% rename from tests/test_multidb/test_database.py rename to tests/test_multidb/test_healthcheck.py index 39380236f9..50a16be73a 100644 --- a/tests/test_multidb/test_database.py +++ b/tests/test_multidb/test_healthcheck.py @@ -3,9 +3,10 @@ from redis.multidb.healthcheck import EchoHealthCheck from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError +from redis.retry import Retry -class TestDatabase: +class TestEchoHealthCheck: def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): """ Mocking responses to mix error and actual responses to ensure that health check retry @@ -13,14 +14,10 @@ def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): """ mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] mock_cb.state = CBState.CLOSED - hc = EchoHealthCheck( - check_interval=1.0, - num_retries=3, - backoff=ExponentialBackoff(), - ) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE, health_checks=[hc]) + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) - assert db.is_healthy() == True + assert hc.check_health(db) == True assert 
mock_client.execute_command.call_count == 3 assert db.circuit.state == CBState.CLOSED @@ -31,27 +28,19 @@ def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, moc """ mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'wrong'] mock_cb.state = CBState.CLOSED - hc = EchoHealthCheck( - check_interval=1.0, - num_retries=3, - backoff=ExponentialBackoff(), - ) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE, health_checks=[hc]) + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) - assert db.is_healthy() == False + assert hc.check_health(db) == False assert mock_client.execute_command.call_count == 3 assert db.circuit.state == CBState.OPEN def test_database_is_unhealthy_on_exceeded_healthcheck_retries(self, mock_client, mock_cb): mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, ConnectionError, ConnectionError] mock_cb.state = CBState.CLOSED - hc = EchoHealthCheck( - check_interval=1.0, - num_retries=3, - backoff=ExponentialBackoff(), - ) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE, health_checks=[hc]) + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) - assert db.is_healthy() == False + assert hc.check_health(db) == False assert mock_client.execute_command.call_count == 4 assert db.circuit.state == CBState.OPEN \ No newline at end of file diff --git a/tests/test_multidb/test_selector.py b/tests/test_multidb/test_selector.py new file mode 100644 index 0000000000..a40b706f40 --- /dev/null +++ b/tests/test_multidb/test_selector.py @@ -0,0 +1,101 @@ +from unittest.mock import PropertyMock + +import pytest + +from redis.backoff import NoBackoff, ExponentialBackoff +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.selector import 
WeightBasedDatabaseSelector +from redis.retry import Retry + + +class TestWeightBasedDatabaseSelector: + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + ids=['all closed - highest weight', 'highest weight - open'], + indirect=True, + ) + def test_get_valid_database(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + selector = WeightBasedDatabaseSelector([mock_db, mock_db1, mock_db2], retry=retry) + + assert selector.database == mock_db1 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + selector = WeightBasedDatabaseSelector([mock_db, mock_db1, mock_db2], retry=retry) + + assert selector.database == mock_db + assert state_mock.call_count == 4 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, 
CBState.OPEN] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + selector = WeightBasedDatabaseSelector([mock_db, mock_db1, mock_db2], retry=retry) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert selector.database + + assert state_mock.call_count == 4 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_database_return_valid_database(self, mock_db, mock_db1, mock_db2): + retry = Retry(ExponentialBackoff(cap=1), 3) + selector = WeightBasedDatabaseSelector([mock_db, mock_db2], retry=retry) + assert selector.database == mock_db2 + + selector.add_database(mock_db1) + assert selector.database == mock_db1 \ No newline at end of file From acc68efb3031917061ea2210e6a4247cfe460672 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 17 Jun 2025 12:56:35 +0300 Subject: [PATCH 03/22] Added MultiDbConfig --- redis/multidb/config.py | 54 ++++++++++++++++++++++++++ redis/multidb/failure_detector.py | 6 +-- redis/multidb/selector.py | 5 --- tests/test_multidb/test_config.py | 59 +++++++++++++++++++++++++++++ tests/test_multidb/test_selector.py | 37 ++++++++++++++++-- 5 files changed, 149 insertions(+), 12 deletions(-) create mode 100644 redis/multidb/config.py create mode 100644 tests/test_multidb/test_config.py diff --git a/redis/multidb/config.py b/redis/multidb/config.py new file mode 100644 index 0000000000..0be1cff587 --- /dev/null +++ b/redis/multidb/config.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass, field +from typing import List, Type, Union + +from redis import Redis, Sentinel +from redis.asyncio import RedisCluster +from redis.backoff import ExponentialWithJitterBackoff +from redis.multidb.failure_detector import 
FailureDetector, CommandFailureDetector +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.multidb.selector import DatabaseSelector, WeightBasedDatabaseSelector +from redis.retry import Retry + +DEFAULT_GRACE_PERIOD = 1 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_HEALTH_CHECK_RETRIES = 3 +DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) +DEFAULT_FAILURES_THRESHOLD = 100 +DEFAULT_FAILURES_DURATION = 2 +DEFAULT_DATABASE_SELECTOR_RETRIES = 3 +DEFAULT_DATABASE_SELECTOR_BACKOFF = ExponentialWithJitterBackoff(cap=3) +DEFAULT_AUTO_FALLBACK_INTERVAL = -1 + +def default_health_checks() -> List[HealthCheck]: + return [ + EchoHealthCheck(retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF)), + ] + +def default_failure_detectors() -> List[FailureDetector]: + return [ + CommandFailureDetector(threshold=DEFAULT_FAILURES_THRESHOLD, duration=DEFAULT_FAILURES_DURATION), + ] + +def default_database_selector() -> DatabaseSelector: + return WeightBasedDatabaseSelector( + retry=Retry(retries=DEFAULT_DATABASE_SELECTOR_RETRIES, backoff=DEFAULT_DATABASE_SELECTOR_BACKOFF) + ) + +@dataclass +class MultiDbConfig: + client_class: Type[Union[Redis, RedisCluster, Sentinel]] = Redis + client_kwargs: dict = field(default_factory=dict) + grace_period: int = DEFAULT_GRACE_PERIOD + failure_detectors: List[FailureDetector] = field(default_factory=default_failure_detectors) + health_checks: List[HealthCheck] = field(default_factory=default_health_checks) + health_check_interval: int = DEFAULT_HEALTH_CHECK_INTERVAL + database_selector: DatabaseSelector = field(default_factory=default_database_selector) + auto_fallback_interval: int = DEFAULT_AUTO_FALLBACK_INTERVAL + + def client(self) -> Union[Redis, RedisCluster, Sentinel]: + if len(self.client_kwargs) > 0: + return self.client_class(**self.client_kwargs) + + return self.client_class() + + diff --git a/redis/multidb/failure_detector.py 
b/redis/multidb/failure_detector.py index 262e79be1a..1d3075935b 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -25,9 +25,9 @@ def __init__( error_types: Optional[List[Type[Exception]]] = None, ) -> None: """ - param: threshold: Threshold of failed commands over the duration after which database will be marked as failed. - param: duration: Interval in seconds after which database will be marked as failed if threshold was exceeded. - param: error_types: List of exception that has to be registered. By default, all exceptions are registered. + :param threshold: Threshold of failed commands over the duration after which database will be marked as failed. + :param duration: Interval in seconds after which database will be marked as failed if threshold was exceeded. + :param error_types: List of exception that has to be registered. By default, all exceptions are registered. """ self._threshold = threshold self._duration = duration diff --git a/redis/multidb/selector.py b/redis/multidb/selector.py index d2c8468835..619cadd2d4 100644 --- a/redis/multidb/selector.py +++ b/redis/multidb/selector.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import List from redis.data_structure import WeightedList from redis.multidb.database import AbstractDatabase @@ -28,16 +27,12 @@ class WeightBasedDatabaseSelector(DatabaseSelector): """ def __init__( self, - databases: List[AbstractDatabase], retry: Retry, ): self._retry = retry self._retry.update_supported_errors([NoValidDatabaseException]) self._databases = WeightedList() - for database in databases: - self._databases.add(database, database.weight) - @property def database(self) -> AbstractDatabase: return self._retry.call_with_retry( diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py new file mode 100644 index 0000000000..2102abe255 --- /dev/null +++ b/tests/test_multidb/test_config.py @@ -0,0 +1,59 @@ +from unittest.mock import Mock + 
+from redis.connection import ConnectionPool + +from redis import Redis +from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD, DEFAULT_HEALTH_CHECK_INTERVAL, \ + DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.multidb.failure_detector import CommandFailureDetector, FailureDetector +from redis.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.multidb.selector import WeightBasedDatabaseSelector, DatabaseSelector + + +class TestMultiDbConfig: + def test_default_config(self): + config = MultiDbConfig() + + assert isinstance(config.client(), Redis) + assert config.grace_period == DEFAULT_GRACE_PERIOD + assert len(config.failure_detectors) == 1 + assert isinstance(config.failure_detectors[0], CommandFailureDetector) + assert len(config.health_checks) == 1 + assert isinstance(config.health_checks[0], EchoHealthCheck) + assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL + assert isinstance(config.database_selector, WeightBasedDatabaseSelector) + assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL + + def test_overridden_config(self): + mock_connection_pool = Mock(spec=ConnectionPool) + mock_connection_pool.connection_kwargs = {} + grace_period = 2 + mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] + mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] + health_check_interval = 10 + mock_database_selector = Mock(spec=DatabaseSelector) + auto_fallback_interval = 10 + + config = MultiDbConfig( + client_kwargs={"connection_pool": mock_connection_pool}, + grace_period=grace_period, + failure_detectors=mock_failure_detectors, + health_checks=mock_health_checks, + health_check_interval=health_check_interval, + database_selector=mock_database_selector, + auto_fallback_interval=auto_fallback_interval, + ) + + client = config.client() + assert isinstance(client, Redis) + assert client.connection_pool == mock_connection_pool + assert config.grace_period == grace_period 
+ assert len(config.failure_detectors) == 2 + assert config.failure_detectors[0] == mock_failure_detectors[0] + assert config.failure_detectors[1] == mock_failure_detectors[1] + assert len(config.health_checks) == 2 + assert config.health_checks[0] == mock_health_checks[0] + assert config.health_checks[1] == mock_health_checks[1] + assert config.health_check_interval == health_check_interval + assert config.database_selector == mock_database_selector + assert config.auto_fallback_interval == auto_fallback_interval \ No newline at end of file diff --git a/tests/test_multidb/test_selector.py b/tests/test_multidb/test_selector.py index a40b706f40..9e8906db02 100644 --- a/tests/test_multidb/test_selector.py +++ b/tests/test_multidb/test_selector.py @@ -29,7 +29,10 @@ class TestWeightBasedDatabaseSelector: ) def test_get_valid_database(self, mock_db, mock_db1, mock_db2): retry = Retry(NoBackoff(), 0) - selector = WeightBasedDatabaseSelector([mock_db, mock_db1, mock_db2], retry=retry) + selector = WeightBasedDatabaseSelector(retry=retry) + selector.add_database(mock_db) + selector.add_database(mock_db1) + selector.add_database(mock_db2) assert selector.database == mock_db1 @@ -51,7 +54,10 @@ def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): type(mock_db.circuit).state = state_mock retry = Retry(ExponentialBackoff(cap=1), 3) - selector = WeightBasedDatabaseSelector([mock_db, mock_db1, mock_db2], retry=retry) + selector = WeightBasedDatabaseSelector(retry=retry) + selector.add_database(mock_db) + selector.add_database(mock_db1) + selector.add_database(mock_db2) assert selector.database == mock_db assert state_mock.call_count == 4 @@ -74,13 +80,34 @@ def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db type(mock_db.circuit).state = state_mock retry = Retry(ExponentialBackoff(cap=1), 3) - selector = WeightBasedDatabaseSelector([mock_db, mock_db1, mock_db2], retry=retry) + selector = 
WeightBasedDatabaseSelector(retry=retry) + selector.add_database(mock_db) + selector.add_database(mock_db1) + selector.add_database(mock_db2) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): assert selector.database assert state_mock.call_count == 4 + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + selector = WeightBasedDatabaseSelector(retry=retry) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert selector.database + @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', [ @@ -94,7 +121,9 @@ def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db ) def test_add_database_return_valid_database(self, mock_db, mock_db1, mock_db2): retry = Retry(ExponentialBackoff(cap=1), 3) - selector = WeightBasedDatabaseSelector([mock_db, mock_db2], retry=retry) + selector = WeightBasedDatabaseSelector(retry=retry) + selector.add_database(mock_db) + selector.add_database(mock_db2) assert selector.database == mock_db2 selector.add_database(mock_db1) From 255bb0e00dc35c8e1f183a91b3b6c854d756ffe8 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 17 Jun 2025 13:50:17 +0300 Subject: [PATCH 04/22] Added DatabaseConfig --- redis/multidb/config.py | 29 ++++++++++--- redis/multidb/database.py | 6 +-- tests/test_multidb/test_config.py | 71 +++++++++++++++++++++++++------ 3 files changed, 84 insertions(+), 22 deletions(-) diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 0be1cff587..98fb7d3d89 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -1,9 +1,13 @@ from dataclasses import 
dataclass, field from typing import List, Type, Union +import pybreaker + from redis import Redis, Sentinel from redis.asyncio import RedisCluster from redis.backoff import ExponentialWithJitterBackoff +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.database import Database, AbstractDatabase from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck from redis.multidb.selector import DatabaseSelector, WeightBasedDatabaseSelector @@ -34,21 +38,34 @@ def default_database_selector() -> DatabaseSelector: retry=Retry(retries=DEFAULT_DATABASE_SELECTOR_RETRIES, backoff=DEFAULT_DATABASE_SELECTOR_BACKOFF) ) +def default_circuit_breaker() -> CircuitBreaker: + circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=DEFAULT_GRACE_PERIOD) + return PBCircuitBreakerAdapter(circuit_breaker) + +@dataclass +class DatabaseConfig: + client_kwargs: dict + weight: float + circuit: CircuitBreaker = field(default_factory=default_circuit_breaker) + @dataclass class MultiDbConfig: + databases_config: List[DatabaseConfig] client_class: Type[Union[Redis, RedisCluster, Sentinel]] = Redis - client_kwargs: dict = field(default_factory=dict) - grace_period: int = DEFAULT_GRACE_PERIOD failure_detectors: List[FailureDetector] = field(default_factory=default_failure_detectors) health_checks: List[HealthCheck] = field(default_factory=default_health_checks) health_check_interval: int = DEFAULT_HEALTH_CHECK_INTERVAL database_selector: DatabaseSelector = field(default_factory=default_database_selector) auto_fallback_interval: int = DEFAULT_AUTO_FALLBACK_INTERVAL - def client(self) -> Union[Redis, RedisCluster, Sentinel]: - if len(self.client_kwargs) > 0: - return self.client_class(**self.client_kwargs) + def databases(self) -> List[AbstractDatabase]: + databases = [] - return self.client_class() + for database_config in self.databases_config: + client = 
self.client_class(**database_config.client_kwargs) + databases.append( + Database(client=client, circuit=database_config.circuit, weight=database_config.weight) + ) + return databases diff --git a/redis/multidb/database.py b/redis/multidb/database.py index a992818774..6e9cf8eb84 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -64,9 +64,9 @@ class Database(AbstractDatabase): def __init__( self, client: Union[redis.Redis, RedisCluster, Sentinel], - cb: CircuitBreaker, + circuit: CircuitBreaker, weight: float, - state: State, + state: State = State.DISCONNECTED, ): """ param: client: Client instance for communication with the database. @@ -76,7 +76,7 @@ def __init__( param: health_checks: List of health cheks to determine if the current database is healthy. """ self._client = client - self._cb = cb + self._cb = circuit self._weight = weight self._state = state diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 2102abe255..e96befd0c3 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -3,8 +3,10 @@ from redis.connection import ConnectionPool from redis import Redis -from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.multidb.circuit import CircuitBreaker +from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ + DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD +from redis.multidb.database import Database from redis.multidb.failure_detector import CommandFailureDetector, FailureDetector from redis.multidb.healthcheck import EchoHealthCheck, HealthCheck from redis.multidb.selector import WeightBasedDatabaseSelector, DatabaseSelector @@ -12,10 +14,27 @@ class TestMultiDbConfig: def test_default_config(self): - config = MultiDbConfig() + db_configs = [ + DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0), + 
DatabaseConfig(client_kwargs={'host': 'host2', 'port': 'port2'}, weight=0.9), + DatabaseConfig(client_kwargs={'host': 'host3', 'port': 'port3'}, weight=0.8), + ] + + config = MultiDbConfig( + databases_config=db_configs + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db in databases: + assert isinstance(db, Database) + assert db.weight == db_configs[i].weight + assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD + i+=1 - assert isinstance(config.client(), Redis) - assert config.grace_period == DEFAULT_GRACE_PERIOD assert len(config.failure_detectors) == 1 assert isinstance(config.failure_detectors[0], CommandFailureDetector) assert len(config.health_checks) == 1 @@ -25,18 +44,36 @@ def test_default_config(self): assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL def test_overridden_config(self): - mock_connection_pool = Mock(spec=ConnectionPool) - mock_connection_pool.connection_kwargs = {} grace_period = 2 + mock_connection_pools = [Mock(spec=ConnectionPool), Mock(spec=ConnectionPool), Mock(spec=ConnectionPool)] + mock_connection_pools[0].connection_kwargs = {} + mock_connection_pools[1].connection_kwargs = {} + mock_connection_pools[2].connection_kwargs = {} + mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1.grace_period = grace_period + mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2.grace_period = grace_period + mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3.grace_period = grace_period mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] health_check_interval = 10 mock_database_selector = Mock(spec=DatabaseSelector) auto_fallback_interval = 10 + db_configs = [ + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, weight=1.0, circuit=mock_cb1 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, 
weight=0.9, circuit=mock_cb2 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, weight=0.8, circuit=mock_cb3 + ), + ] config = MultiDbConfig( - client_kwargs={"connection_pool": mock_connection_pool}, - grace_period=grace_period, + databases_config=db_configs, failure_detectors=mock_failure_detectors, health_checks=mock_health_checks, health_check_interval=health_check_interval, @@ -44,10 +81,18 @@ def test_overridden_config(self): auto_fallback_interval=auto_fallback_interval, ) - client = config.client() - assert isinstance(client, Redis) - assert client.connection_pool == mock_connection_pool - assert config.grace_period == grace_period + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db in databases: + assert isinstance(db, Database) + assert db.weight == db_configs[i].weight + assert db.client.connection_pool == mock_connection_pools[i] + assert db.circuit.grace_period == grace_period + i+=1 + assert len(config.failure_detectors) == 2 assert config.failure_detectors[0] == mock_failure_detectors[0] assert config.failure_detectors[1] == mock_failure_detectors[1] From 79db257d99cf85a5358344758f02a0e78e203e66 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 17 Jun 2025 13:55:37 +0300 Subject: [PATCH 05/22] Added DatabaseConfig test coverage --- tests/test_multidb/test_config.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index e96befd0c3..50fe830322 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,9 +1,6 @@ from unittest.mock import Mock - from redis.connection import ConnectionPool - -from redis import Redis -from redis.multidb.circuit import CircuitBreaker +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter from redis.multidb.config import MultiDbConfig, 
DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database @@ -101,4 +98,24 @@ def test_overridden_config(self): assert config.health_checks[1] == mock_health_checks[1] assert config.health_check_interval == health_check_interval assert config.database_selector == mock_database_selector - assert config.auto_fallback_interval == auto_fallback_interval \ No newline at end of file + assert config.auto_fallback_interval == auto_fallback_interval + +class TestDatabaseConfig: + def test_default_config(self): + config = DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0) + + assert config.client_kwargs == {'host': 'host1', 'port': 'port1'} + assert config.weight == 1.0 + assert isinstance(config.circuit, PBCircuitBreakerAdapter) + + def test_overridden_config(self): + mock_connection_pool = Mock(spec=ConnectionPool) + mock_circuit = Mock(spec=CircuitBreaker) + + config = DatabaseConfig( + client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit + ) + + assert config.client_kwargs == {'connection_pool': mock_connection_pool} + assert config.weight == 1.0 + assert config.circuit == mock_circuit \ No newline at end of file From 8790db1fabed1b73483b097b998e7b8e2f9528dd Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 18 Jun 2025 11:05:06 +0300 Subject: [PATCH 06/22] Renamed DatabaseSelector into FailoverStrategy --- redis/multidb/config.py | 14 +++--- redis/multidb/database.py | 3 +- redis/multidb/{selector.py => failover.py} | 4 +- tests/test_multidb/test_config.py | 10 ++-- .../{test_selector.py => test_failover.py} | 50 +++++++++---------- 5 files changed, 40 insertions(+), 41 deletions(-) rename redis/multidb/{selector.py => failover.py} (94%) rename tests/test_multidb/{test_selector.py => test_failover.py} (73%) diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 98fb7d3d89..2606dad291 100644 --- 
a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -10,7 +10,7 @@ from redis.multidb.database import Database, AbstractDatabase from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck -from redis.multidb.selector import DatabaseSelector, WeightBasedDatabaseSelector +from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy from redis.retry import Retry DEFAULT_GRACE_PERIOD = 1 @@ -19,8 +19,8 @@ DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) DEFAULT_FAILURES_THRESHOLD = 100 DEFAULT_FAILURES_DURATION = 2 -DEFAULT_DATABASE_SELECTOR_RETRIES = 3 -DEFAULT_DATABASE_SELECTOR_BACKOFF = ExponentialWithJitterBackoff(cap=3) +DEFAULT_FAILOVER_RETRIES = 3 +DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) DEFAULT_AUTO_FALLBACK_INTERVAL = -1 def default_health_checks() -> List[HealthCheck]: @@ -33,9 +33,9 @@ def default_failure_detectors() -> List[FailureDetector]: CommandFailureDetector(threshold=DEFAULT_FAILURES_THRESHOLD, duration=DEFAULT_FAILURES_DURATION), ] -def default_database_selector() -> DatabaseSelector: - return WeightBasedDatabaseSelector( - retry=Retry(retries=DEFAULT_DATABASE_SELECTOR_RETRIES, backoff=DEFAULT_DATABASE_SELECTOR_BACKOFF) +def default_failover_strategy() -> FailoverStrategy: + return WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) def default_circuit_breaker() -> CircuitBreaker: @@ -55,7 +55,7 @@ class MultiDbConfig: failure_detectors: List[FailureDetector] = field(default_factory=default_failure_detectors) health_checks: List[HealthCheck] = field(default_factory=default_health_checks) health_check_interval: int = DEFAULT_HEALTH_CHECK_INTERVAL - database_selector: DatabaseSelector = field(default_factory=default_database_selector) + failover_strategy: FailoverStrategy = field(default_factory=default_failover_strategy) 
auto_fallback_interval: int = DEFAULT_AUTO_FALLBACK_INTERVAL def databases(self) -> List[AbstractDatabase]: diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 6e9cf8eb84..d4b3be7c78 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -70,10 +70,9 @@ def __init__( ): """ param: client: Client instance for communication with the database. - param: cb: Circuit breaker for the current database. + param: circuit: Circuit breaker for the current database. param: weight: Weight of current database. Database with the highest weight becomes Active. param: state: State of the current database. - param: health_checks: List of health cheks to determine if the current database is healthy. """ self._client = client self._cb = circuit diff --git a/redis/multidb/selector.py b/redis/multidb/failover.py similarity index 94% rename from redis/multidb/selector.py rename to redis/multidb/failover.py index 619cadd2d4..bef810e1ae 100644 --- a/redis/multidb/selector.py +++ b/redis/multidb/failover.py @@ -7,7 +7,7 @@ from redis.retry import Retry -class DatabaseSelector(ABC): +class FailoverStrategy(ABC): @property @abstractmethod @@ -21,7 +21,7 @@ def add_database(self, database: AbstractDatabase) -> None: pass -class WeightBasedDatabaseSelector(DatabaseSelector): +class WeightBasedFailoverStrategy(FailoverStrategy): """ Choose the active database with the highest weight. 
""" diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 50fe830322..52076fb055 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -6,7 +6,7 @@ from redis.multidb.database import Database from redis.multidb.failure_detector import CommandFailureDetector, FailureDetector from redis.multidb.healthcheck import EchoHealthCheck, HealthCheck -from redis.multidb.selector import WeightBasedDatabaseSelector, DatabaseSelector +from redis.multidb.failover import WeightBasedFailoverStrategy, FailoverStrategy class TestMultiDbConfig: @@ -37,7 +37,7 @@ def test_default_config(self): assert len(config.health_checks) == 1 assert isinstance(config.health_checks[0], EchoHealthCheck) assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL - assert isinstance(config.database_selector, WeightBasedDatabaseSelector) + assert isinstance(config.failover_strategy, WeightBasedFailoverStrategy) assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL def test_overridden_config(self): @@ -55,7 +55,7 @@ def test_overridden_config(self): mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] health_check_interval = 10 - mock_database_selector = Mock(spec=DatabaseSelector) + mock_failover_strategy = Mock(spec=FailoverStrategy) auto_fallback_interval = 10 db_configs = [ DatabaseConfig( @@ -74,7 +74,7 @@ def test_overridden_config(self): failure_detectors=mock_failure_detectors, health_checks=mock_health_checks, health_check_interval=health_check_interval, - database_selector=mock_database_selector, + failover_strategy=mock_failover_strategy, auto_fallback_interval=auto_fallback_interval, ) @@ -97,7 +97,7 @@ def test_overridden_config(self): assert config.health_checks[0] == mock_health_checks[0] assert config.health_checks[1] == mock_health_checks[1] assert config.health_check_interval == 
health_check_interval - assert config.database_selector == mock_database_selector + assert config.failover_strategy == mock_failover_strategy assert config.auto_fallback_interval == auto_fallback_interval class TestDatabaseConfig: diff --git a/tests/test_multidb/test_selector.py b/tests/test_multidb/test_failover.py similarity index 73% rename from tests/test_multidb/test_selector.py rename to tests/test_multidb/test_failover.py index 9e8906db02..a4239d943a 100644 --- a/tests/test_multidb/test_selector.py +++ b/tests/test_multidb/test_failover.py @@ -5,11 +5,11 @@ from redis.backoff import NoBackoff, ExponentialBackoff from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException -from redis.multidb.selector import WeightBasedDatabaseSelector +from redis.multidb.failover import WeightBasedFailoverStrategy from redis.retry import Retry -class TestWeightBasedDatabaseSelector: +class TestWeightBasedFailoverStrategy: @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', [ @@ -29,12 +29,12 @@ class TestWeightBasedDatabaseSelector: ) def test_get_valid_database(self, mock_db, mock_db1, mock_db2): retry = Retry(NoBackoff(), 0) - selector = WeightBasedDatabaseSelector(retry=retry) - selector.add_database(mock_db) - selector.add_database(mock_db1) - selector.add_database(mock_db2) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.add_database(mock_db) + failover_strategy.add_database(mock_db1) + failover_strategy.add_database(mock_db2) - assert selector.database == mock_db1 + assert failover_strategy.database == mock_db1 @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -54,12 +54,12 @@ def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): type(mock_db.circuit).state = state_mock retry = Retry(ExponentialBackoff(cap=1), 3) - selector = WeightBasedDatabaseSelector(retry=retry) - selector.add_database(mock_db) - selector.add_database(mock_db1) - 
selector.add_database(mock_db2) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.add_database(mock_db) + failover_strategy.add_database(mock_db1) + failover_strategy.add_database(mock_db2) - assert selector.database == mock_db + assert failover_strategy.database == mock_db assert state_mock.call_count == 4 @pytest.mark.parametrize( @@ -80,13 +80,13 @@ def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db type(mock_db.circuit).state = state_mock retry = Retry(ExponentialBackoff(cap=1), 3) - selector = WeightBasedDatabaseSelector(retry=retry) - selector.add_database(mock_db) - selector.add_database(mock_db1) - selector.add_database(mock_db2) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.add_database(mock_db) + failover_strategy.add_database(mock_db1) + failover_strategy.add_database(mock_db2) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert selector.database + assert failover_strategy.database assert state_mock.call_count == 4 @@ -103,10 +103,10 @@ def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db ) def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): retry = Retry(NoBackoff(), 0) - selector = WeightBasedDatabaseSelector(retry=retry) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert selector.database + assert failover_strategy.database @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -121,10 +121,10 @@ def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): ) def test_add_database_return_valid_database(self, mock_db, mock_db1, mock_db2): retry = Retry(ExponentialBackoff(cap=1), 3) - selector = WeightBasedDatabaseSelector(retry=retry) - selector.add_database(mock_db) - selector.add_database(mock_db2) 
- assert selector.database == mock_db2 + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.add_database(mock_db) + failover_strategy.add_database(mock_db2) + assert failover_strategy.database == mock_db2 - selector.add_database(mock_db1) - assert selector.database == mock_db1 \ No newline at end of file + failover_strategy.add_database(mock_db1) + assert failover_strategy.database == mock_db1 \ No newline at end of file From b3ad8da6335955fe5c0f1a59b9b39aecb0c119a3 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 18 Jun 2025 16:34:11 +0300 Subject: [PATCH 07/22] Added CommandExecutor --- redis/multidb/command_executor.py | 163 +++++++++++++++++++ redis/multidb/config.py | 2 +- redis/multidb/event.py | 28 ++++ redis/multidb/failure_detector.py | 10 +- tests/conftest.py | 6 + tests/test_multidb/conftest.py | 12 +- tests/test_multidb/test_command_executor.py | 169 ++++++++++++++++++++ 7 files changed, 383 insertions(+), 7 deletions(-) create mode 100644 redis/multidb/command_executor.py create mode 100644 tests/test_multidb/test_command_executor.py diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py new file mode 100644 index 0000000000..cfebf10492 --- /dev/null +++ b/redis/multidb/command_executor.py @@ -0,0 +1,163 @@ +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import List, Union + +from redis.event import EventDispatcherInterface, OnCommandFailEvent +from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.multidb.database import Database +from redis.multidb.circuit import State as CBState +from redis.multidb.event import RegisterCommandFailure +from redis.multidb.failover import FailoverStrategy +from redis.multidb.failure_detector import FailureDetector + + +class CommandExecutor(ABC): + + @property + @abstractmethod + def failure_detectors(self) -> List[FailureDetector]: + """Returns a list of failure detectors.""" + pass + + 
@abstractmethod + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + """Adds new failure detector to the list of failure detectors.""" + pass + + @property + @abstractmethod + def databases(self) -> List[Database]: + """Returns a list of databases.""" + pass + + @abstractmethod + def add_database(self, database: Database) -> None: + """Adds new database to the list of databases.""" + pass + + @property + @abstractmethod + def active_database(self) -> Union[Database, None]: + """Returns currently active database.""" + pass + + @active_database.setter + @abstractmethod + def active_database(self, database: Database) -> None: + """Sets currently active database.""" + pass + + @property + @abstractmethod + def failover_strategy(self) -> FailoverStrategy: + """Returns failover strategy.""" + pass + + @property + @abstractmethod + def auto_fallback_interval(self) -> float: + """Returns auto-fallback interval.""" + pass + + @auto_fallback_interval.setter + @abstractmethod + def auto_fallback_interval(self, auto_fallback_interval: float) -> None: + """Sets auto-fallback interval.""" + pass + + @abstractmethod + def execute_command(self, *args, **options): + """Executes a command and returns the result.""" + pass + + +class DefaultCommandExecutor(CommandExecutor): + + def __init__( + self, + failure_detectors: List[FailureDetector], + databases: List[Database], + failover_strategy: FailoverStrategy, + event_dispatcher: EventDispatcherInterface, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + """ + :param failure_detectors: List of failure detectors. + :param databases: List of databases. + :param failover_strategy: Strategy that defines the failover logic. + :param event_dispatcher: Event dispatcher. + :param auto_fallback_interval: Interval between fallback attempts. Fallback to a new database according to + failover_strategy. 
+ """ + self._failure_detectors = failure_detectors + self._databases = databases + self._failover_strategy = failover_strategy + self._event_dispatcher = event_dispatcher + self._auto_fallback_interval = auto_fallback_interval + self._next_fallback_attempt: datetime + self._active_database: Union[Database, None] = None + self._setup_event_dispatcher() + + @property + def failure_detectors(self) -> List[FailureDetector]: + return self._failure_detectors + + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + self._failure_detectors.append(failure_detector) + + @property + def databases(self) -> List[Database]: + return self._databases + + def add_database(self, database: Database) -> None: + self._failover_strategy.add_database(database) + self._databases.append(database) + + @property + def active_database(self) -> Union[Database, None]: + return self._active_database + + @active_database.setter + def active_database(self, database: Database) -> None: + self._active_database = database + + @property + def failover_strategy(self) -> FailoverStrategy: + return self._failover_strategy + + @property + def auto_fallback_interval(self) -> float: + return self._auto_fallback_interval + + @auto_fallback_interval.setter + def auto_fallback_interval(self, auto_fallback_interval: int) -> None: + self._auto_fallback_interval = auto_fallback_interval + + def execute_command(self, *args, **options): + if ( + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) + ): + self._active_database = self._failover_strategy.database + self._schedule_next_fallback() + + return self._active_database.client.execute_command(*args, **options) + + def _schedule_next_fallback(self) -> None: + if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: + return + + self._next_fallback_attempt = 
datetime.now() + timedelta(seconds=self._auto_fallback_interval) + + def _setup_event_dispatcher(self): + """ + Registers command failure event listener. + """ + event_listener = RegisterCommandFailure(self._failure_detectors, self._databases) + self._event_dispatcher.register_listeners({ + OnCommandFailEvent: [event_listener], + }) \ No newline at end of file diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 2606dad291..97dd1d8c95 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -56,7 +56,7 @@ class MultiDbConfig: health_checks: List[HealthCheck] = field(default_factory=default_health_checks) health_check_interval: int = DEFAULT_HEALTH_CHECK_INTERVAL failover_strategy: FailoverStrategy = field(default_factory=default_failover_strategy) - auto_fallback_interval: int = DEFAULT_AUTO_FALLBACK_INTERVAL + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL def databases(self) -> List[AbstractDatabase]: databases = [] diff --git a/redis/multidb/event.py b/redis/multidb/event.py index e69de29bb2..ad6a275bb5 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -0,0 +1,28 @@ +from typing import List + +from redis.event import EventListenerInterface, OnCommandFailEvent +from redis.multidb.database import Database +from redis.multidb.failure_detector import FailureDetector + + +class RegisterCommandFailure(EventListenerInterface): + """ + Event listener that registers command failures and passing it to the failure detectors. 
+ """ + def __init__(self, failure_detectors: List[FailureDetector], databases: List[Database]): + self._failure_detectors = failure_detectors + self._databases = databases + + def listen(self, event: OnCommandFailEvent) -> None: + matching_database = None + + for database in self._databases: + if event.client == database.client: + matching_database = database + break + + if matching_database is None: + return + + for failure_detector in self._failure_detectors: + failure_detector.register_failure(matching_database, event.exception, event.command) diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index 1d3075935b..649bd3a915 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -34,7 +34,7 @@ def __init__( self._error_types = error_types self._start_time: datetime = datetime.now() self._end_time: datetime = self._start_time + timedelta(seconds=self._duration) - self._failures_within_duration: Dict[Exception, Dict[datetime, tuple]] = {} + self._failures_within_duration: List[tuple[datetime, tuple]] = [] def register_failure(self, database, exception: Exception, cmd: tuple) -> None: failure_time = datetime.now() @@ -44,18 +44,18 @@ def register_failure(self, database, exception: Exception, cmd: tuple) -> None: if self._error_types: if type(exception) in self._error_types: - self._failures_within_duration[exception] = {datetime.now(): cmd} + self._failures_within_duration.append((datetime.now(), cmd)) else: - self._failures_within_duration[exception] = {datetime.now(): cmd} + self._failures_within_duration.append((datetime.now(), cmd)) self._check_threshold(database) def _check_threshold(self, database): - if len(self._failures_within_duration.keys()) >= self._threshold: + if len(self._failures_within_duration) >= self._threshold: database.circuit.state = CBState.OPEN self._reset() def _reset(self) -> None: self._start_time = datetime.now() self._end_time = self._start_time + 
timedelta(seconds=self._duration) - self._failures_within_duration = {} \ No newline at end of file + self._failures_within_duration = [] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 7eaccb1acb..fc316ea720 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,7 @@ ) from redis.connection import Connection, ConnectionInterface, SSLConnection, parse_url from redis.credentials import CredentialProvider +from redis.event import EventDispatcherInterface from redis.exceptions import RedisClusterException from redis.retry import Retry from tests.ssl_utils import get_tls_certificates @@ -581,6 +582,11 @@ def mock_connection() -> ConnectionInterface: mock_connection = Mock(spec=ConnectionInterface) return mock_connection +@pytest.fixture() +def mock_ed() -> EventDispatcherInterface: + mock_ed = Mock(spec=EventDispatcherInterface) + return mock_ed + @pytest.fixture() def cache_key(request) -> CacheKey: diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 197d74fa12..2cfc22d37f 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -5,16 +5,26 @@ from redis import Redis from redis.multidb.circuit import CircuitBreaker, State as CBState from redis.multidb.database import Database, State +from redis.multidb.failover import FailoverStrategy +from redis.multidb.failure_detector import FailureDetector @pytest.fixture() def mock_client() -> Redis: return Mock(spec=Redis) -@pytest.fixture(scope='function') +@pytest.fixture() def mock_cb() -> CircuitBreaker: return Mock(spec=CircuitBreaker) +@pytest.fixture() +def mock_fd() -> FailureDetector: + return Mock(spec=FailureDetector) + +@pytest.fixture() +def mock_fs() -> FailoverStrategy: + return Mock(spec=FailoverStrategy) + @pytest.fixture() def mock_db(request) -> Database: db = Mock(spec=Database) diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py new file mode 100644 
index 0000000000..2907132235 --- /dev/null +++ b/tests/test_multidb/test_command_executor.py @@ -0,0 +1,169 @@ +from time import sleep +from unittest.mock import PropertyMock + +import pytest + +from redis.event import EventDispatcher, OnCommandFailEvent +from redis.multidb.circuit import State as CBState +from redis.multidb.command_executor import DefaultCommandExecutor +from redis.multidb.failure_detector import CommandFailureDetector + + +class TestDefaultCommandExecutor: + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=[mock_db, mock_db1, mock_db2], + failover_strategy=mock_fs, + event_dispatcher=mock_ed + ) + + executor.active_database = mock_db1 + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + + executor.active_database = mock_db2 + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_automatically_select_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2]) + 
type(mock_fs).database = mock_selector + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=[mock_db, mock_db1, mock_db2], + failover_strategy=mock_fs, + event_dispatcher=mock_ed + ) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.circuit.state = CBState.OPEN + + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 2 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_fallback_to_another_db_after_fallback_interval( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=[mock_db, mock_db1, mock_db2], + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + auto_fallback_interval=0.1, + ) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.weight = 0.1 + sleep(0.15) + + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + mock_db1.weight = 0.7 + sleep(0.15) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 3 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def 
test_execute_command_fallback_to_another_db_after_failure_detection( + self, mock_db, mock_db1, mock_db2, mock_fs + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + threshold = 5 + fd = CommandFailureDetector(threshold, 1) + ed = EventDispatcher() + + # Event fired if command against mock_db1 would fail + command_fail_event = OnCommandFailEvent( + command=('SET', 'key', 'value'), + exception=Exception(), + client=mock_db1.client + ) + + executor = DefaultCommandExecutor( + failure_detectors=[fd], + databases=[mock_db, mock_db1, mock_db2], + failover_strategy=mock_fs, + event_dispatcher=ed, + auto_fallback_interval=0.1, + ) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + + # Simulate failing command events that lead to a failure detection + for i in range(threshold): + ed.dispatch(command_fail_event) + + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + + command_fail_event = OnCommandFailEvent( + command=('SET', 'key', 'value'), + exception=Exception(), + client=mock_db2.client + ) + + for i in range(threshold): + ed.dispatch(command_fail_event) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_selector.call_count == 3 \ No newline at end of file From 3a1dc9cbc09e956f7705e19c01df86fa3a5c203a Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 18 Jun 2025 16:49:10 +0300 Subject: [PATCH 08/22] Updated healthcheck to close circuit on success --- redis/multidb/healthcheck.py | 2 ++ tests/test_multidb/test_healthcheck.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 152b095b09..13bdfd6399 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -53,6 +53,8 @@ def check_health(self, database) -> 
bool: if not is_healthy: database.circuit.state = CBState.OPEN + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED return is_healthy except Exception: diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 50a16be73a..22b033ae35 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -43,4 +43,14 @@ def test_database_is_unhealthy_on_exceeded_healthcheck_retries(self, mock_client assert hc.check_health(db) == False assert mock_client.execute_command.call_count == 4 - assert db.circuit.state == CBState.OPEN \ No newline at end of file + assert db.circuit.state == CBState.OPEN + + def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): + mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] + mock_cb.state = CBState.HALF_OPEN + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + + assert hc.check_health(db) == True + assert mock_client.execute_command.call_count == 3 + assert db.circuit.state == CBState.CLOSED \ No newline at end of file From 9bb92350a406543e3fbf83e795db7b812e9d099c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Jun 2025 11:30:07 +0300 Subject: [PATCH 09/22] Added thread-safeness --- redis/data_structure.py | 71 +++++++++++-------- redis/multidb/command_executor.py | 6 +- redis/multidb/failover.py | 1 + redis/multidb/failure_detector.py | 13 ++-- .../{test_multidb => }/test_data_structure.py | 34 ++++++++- 5 files changed, 90 insertions(+), 35 deletions(-) rename tests/{test_multidb => }/test_data_structure.py (55%) diff --git a/redis/data_structure.py b/redis/data_structure.py index 06ed1814e9..cc7202f6c9 100644 --- a/redis/data_structure.py +++ b/redis/data_structure.py @@ -1,57 +1,72 @@ +import threading from typing import List class WeightedList: + 
""" + Thread-safe weighted list. + """ def __init__(self): self._items = [] + self._lock = threading.RLock() def add(self, item, weight: float) -> None: """Add item with weight, maintaining sorted order""" - # Find insertion point using binary search - left, right = 0, len(self._items) - while left < right: - mid = (left + right) // 2 - if self._items[mid][0] < weight: - right = mid - else: - left = mid + 1 + with self._lock: + # Find insertion point using binary search + left, right = 0, len(self._items) + while left < right: + mid = (left + right) // 2 + if self._items[mid][0] < weight: + right = mid + else: + left = mid + 1 - self._items.insert(left, (weight, item)) + self._items.insert(left, (weight, item)) def remove(self, item): """Remove first occurrence of item""" - for i, (weight, stored_item) in enumerate(self._items): - if stored_item == item: - self._items.pop(i) - return weight - raise ValueError("Item not found") + with self._lock: + for i, (weight, stored_item) in enumerate(self._items): + if stored_item == item: + self._items.pop(i) + return weight + raise ValueError("Item not found") def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple]: """Get all items within weight range""" - result = [] - for weight, item in self._items: - if min_weight <= weight <= max_weight: - result.append((item, weight)) - return result + with self._lock: + result = [] + for weight, item in self._items: + if min_weight <= weight <= max_weight: + result.append((item, weight)) + return result def get_top_n(self, n: int) -> List[tuple]: """Get top N the highest weighted items""" - return [(item, weight) for weight, item in self._items[:n]] + with self._lock: + return [(item, weight) for weight, item in self._items[:n]] def update_weight(self, item, new_weight: float): - """Update weight of an item""" - old_weight = self.remove(item) - self.add(item, new_weight) - return old_weight + with self._lock: + """Update weight of an item""" + old_weight = 
self.remove(item) + self.add(item, new_weight) + return old_weight def __iter__(self): """Iterate in descending weight order""" - for weight, item in self._items: + with self._lock: + items_copy = self._items.copy() # Create snapshot as lock released after each 'yield' + + for weight, item in items_copy: yield item, weight def __len__(self): - return len(self._items) + with self._lock: + return len(self._items) def __getitem__(self, index): - weight, item = self._items[index] - return item, weight \ No newline at end of file + with self._lock: + weight, item = self._items[index] + return item, weight \ No newline at end of file diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index cfebf10492..dff40d625e 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -145,7 +145,11 @@ def execute_command(self, *args, **options): self._active_database = self._failover_strategy.database self._schedule_next_fallback() - return self._active_database.client.execute_command(*args, **options) + try: + return self._active_database.client.execute_command(*args, **options) + except Exception: + # Retry until failure detector will trigger opening of circuit + return self.execute_command(*args, **options) def _schedule_next_fallback(self) -> None: if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index bef810e1ae..07bbe618af 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -1,3 +1,4 @@ +import threading from abc import ABC, abstractmethod from redis.data_structure import WeightedList diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index 649bd3a915..df949152c0 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -1,6 +1,7 @@ +import threading from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import 
List, Dict, Type +from typing import List, Type from typing_extensions import Optional @@ -35,6 +36,7 @@ def __init__( self._start_time: datetime = datetime.now() self._end_time: datetime = self._start_time + timedelta(seconds=self._duration) self._failures_within_duration: List[tuple[datetime, tuple]] = [] + self._lock = threading.RLock() def register_failure(self, database, exception: Exception, cmd: tuple) -> None: failure_time = datetime.now() @@ -42,11 +44,12 @@ def register_failure(self, database, exception: Exception, cmd: tuple) -> None: if not self._start_time < failure_time < self._end_time: self._reset() - if self._error_types: - if type(exception) in self._error_types: + with self._lock: + if self._error_types: + if type(exception) in self._error_types: + self._failures_within_duration.append((datetime.now(), cmd)) + else: self._failures_within_duration.append((datetime.now(), cmd)) - else: - self._failures_within_duration.append((datetime.now(), cmd)) self._check_threshold(database) diff --git a/tests/test_multidb/test_data_structure.py b/tests/test_data_structure.py similarity index 55% rename from tests/test_multidb/test_data_structure.py rename to tests/test_data_structure.py index 832c661ac0..9e8058c85e 100644 --- a/tests/test_multidb/test_data_structure.py +++ b/tests/test_data_structure.py @@ -1,3 +1,8 @@ +import concurrent +import random +from concurrent.futures import ThreadPoolExecutor +from time import sleep + from redis.data_structure import WeightedList @@ -44,4 +49,31 @@ def test_update_weights(self): wlist.update_weight('item2', 5.0) - assert wlist.get_top_n(4) == [('item2', 5.0), ('item3', 4.0), ('item4', 4.0), ('item1', 3.0)] \ No newline at end of file + assert wlist.get_top_n(4) == [('item2', 5.0), ('item3', 4.0), ('item4', 4.0), ('item1', 3.0)] + + def test_thread_safety(self) -> None: + """Test thread safety with concurrent operations""" + wl = WeightedList() + + def worker(worker_id): + for i in range(100): + # Add items + 
wl.add(f"item_{worker_id}_{i}", random.uniform(0, 100)) + + # Read operations + try: + length = len(wl) + if length > 0: + top_items = wl.get_top_n(min(5, length)) + items_in_range = wl.get_by_weight_range(20, 80) + except Exception as e: + print(f"Error in worker {worker_id}: {e}") + + sleep(0.001) # Small delay + + # Run multiple workers concurrently + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + concurrent.futures.wait(futures) + + assert len(wl) == 500 \ No newline at end of file From 3218e36498486270910a7c0bb1c2db535969e563 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Jun 2025 11:41:37 +0300 Subject: [PATCH 10/22] Added missing thread-safeness --- redis/multidb/failure_detector.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index df949152c0..7cb5d5db07 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -54,11 +54,13 @@ def register_failure(self, database, exception: Exception, cmd: tuple) -> None: self._check_threshold(database) def _check_threshold(self, database): - if len(self._failures_within_duration) >= self._threshold: - database.circuit.state = CBState.OPEN - self._reset() + with self._lock: + if len(self._failures_within_duration) >= self._threshold: + database.circuit.state = CBState.OPEN + self._reset() def _reset(self) -> None: - self._start_time = datetime.now() - self._end_time = self._start_time + timedelta(seconds=self._duration) - self._failures_within_duration = [] \ No newline at end of file + with self._lock: + self._start_time = datetime.now() + self._end_time = self._start_time + timedelta(seconds=self._duration) + self._failures_within_duration = [] \ No newline at end of file From 4cdb6f4c6d79cf3fcc9df944db7ccf5b1d3022b6 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Jun 2025 11:46:59 +0300 Subject: 
[PATCH 11/22] Added missing thread-safenes for dispatcher --- redis/event.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/redis/event.py b/redis/event.py index 39a924ed16..ebbd6b5553 100644 --- a/redis/event.py +++ b/redis/event.py @@ -85,27 +85,33 @@ def __init__( ], } + self._lock = threading.Lock() + self._async_lock = asyncio.Lock() + if event_listeners: self.register_listeners(event_listeners) def dispatch(self, event: object): - listeners = self._event_listeners_mapping.get(type(event), []) + with self._lock: + listeners = self._event_listeners_mapping.get(type(event), []) - for listener in listeners: - listener.listen(event) + for listener in listeners: + listener.listen(event) async def dispatch_async(self, event: object): - listeners = self._event_listeners_mapping.get(type(event), []) + with self._async_lock: + listeners = self._event_listeners_mapping.get(type(event), []) - for listener in listeners: - await listener.listen(event) + for listener in listeners: + await listener.listen(event) def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]): - for event in event_listeners: - if event in self._event_listeners_mapping: - self._event_listeners_mapping[event] = list(set(self._event_listeners_mapping[event] + event_listeners[event])) - else: - self._event_listeners_mapping[event] = event_listeners[event] + with self._lock: + for event in event_listeners: + if event in self._event_listeners_mapping: + self._event_listeners_mapping[event] = list(set(self._event_listeners_mapping[event] + event_listeners[event])) + else: + self._event_listeners_mapping[event] = event_listeners[event] class AfterConnectionReleasedEvent: From 6914467d71e91916c3523d9616cbe8f5816a0ebd Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Jun 2025 17:31:31 +0300 Subject: [PATCH 12/22] Refactored client to keep databases in WeightedList --- redis/data_structure.py | 7 +- 
redis/multidb/client.py | 92 +++++++++++++++++++++++++ redis/multidb/command_executor.py | 25 +++---- redis/multidb/config.py | 19 ++++-- redis/multidb/database.py | 3 + redis/multidb/event.py | 6 +- redis/multidb/failover.py | 13 ++-- tests/test_multidb/conftest.py | 37 +++++++++- tests/test_multidb/test_client.py | 100 ++++++++++++++++++++++++++++ tests/test_multidb/test_config.py | 8 +-- tests/test_multidb/test_failover.py | 49 +++++--------- 11 files changed, 286 insertions(+), 73 deletions(-) create mode 100644 redis/multidb/client.py create mode 100644 tests/test_multidb/test_client.py diff --git a/redis/data_structure.py b/redis/data_structure.py index cc7202f6c9..d40d5bb71d 100644 --- a/redis/data_structure.py +++ b/redis/data_structure.py @@ -1,8 +1,9 @@ import threading -from typing import List +from typing import List, Any, TypeVar, Generic +T = TypeVar('T') -class WeightedList: +class WeightedList(Generic[T]): """ Thread-safe weighted list. """ @@ -66,7 +67,7 @@ def __len__(self): with self._lock: return len(self._items) - def __getitem__(self, index): + def __getitem__(self, index) -> tuple[Any, int]: with self._lock: weight, item = self._items[index] return item, weight \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py new file mode 100644 index 0000000000..bd2d56fe28 --- /dev/null +++ b/redis/multidb/client.py @@ -0,0 +1,92 @@ +import threading + +from redis.commands import RedisModuleCommands, CoreCommands, SentinelCommands +from redis.multidb.command_executor import DefaultCommandExecutor +from redis.multidb.config import MultiDbConfig +from redis.multidb.circuit import State as CBState +from redis.multidb.database import State as DBState, Database, AbstractDatabase +from redis.multidb.exception import NoValidDatabaseException + + +class MultiDBClient(RedisModuleCommands, CoreCommands, SentinelCommands): + """ + Client that operates on multiple logical Redis databases. 
+ Should be used in Active-Active database setups. + """ + def __init__(self, config: MultiDbConfig): + self._databases = config.databases() + self._health_checks = config.health_checks + self._health_check_interval = config.health_check_interval + self._failure_detectors = config.failure_detectors + self._failover_strategy = config.failover_strategy + self._failover_strategy.set_databases(self._databases) + self._auto_fallback_interval = config.auto_fallback_interval + self._event_dispatcher = config.event_dispatcher + self._command_executor = DefaultCommandExecutor( + failure_detectors=self._failure_detectors, + databases=self._databases, + failover_strategy=self._failover_strategy, + event_dispatcher=self._event_dispatcher, + auto_fallback_interval=self._auto_fallback_interval, + ) + self._initialized = False + self._lock = threading.RLock() + + def _initialize(self): + """ + Perform initialization of databases to define their initial state. + """ + + is_active_db = False + + for database, weight in self._databases: + self._check_db_health(database) + + # Set states according to a weights and circuit state + if database.circuit.state == CBState.CLOSED and not is_active_db: + database.state = DBState.ACTIVE + self._command_executor.active_database = database + is_active_db = True + elif database.circuit.state == CBState.CLOSED and is_active_db: + database.state = DBState.PASSIVE + else: + database.state = DBState.DISCONNECTED + + if not is_active_db: + raise NoValidDatabaseException('Initial connection failed - no active database found') + + self._initialized = True + + def add_database(self, database: Database): + """ + Adds a new database to the database list. + """ + with self._lock: + if database in self._databases: + raise ValueError('Given database already exists') + + + def execute_command(self, *args, **options): + """ + Executes a single command and return its result. 
+ """ + if not self._initialized: + self._initialize() + + with self._lock: + return self._command_executor.execute_command(*args, **options) + + def _check_db_health(self, database: AbstractDatabase) -> None: + """ + Runs health checks on the given database until first failure. + """ + is_healthy = True + + # Health check will setup circuit state + for health_check in self._health_checks: + if not is_healthy: + # If one of the health checks failed, it's considered unhealthy + break + + is_healthy = health_check.check_health(database) + diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index dff40d625e..5c28e03f23 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Union +from typing import List, Union, Optional from redis.event import EventDispatcherInterface, OnCommandFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL -from redis.multidb.database import Database +from redis.multidb.database import Database, AbstractDatabase, Databases from redis.multidb.circuit import State as CBState from redis.multidb.event import RegisterCommandFailure from redis.multidb.failover import FailoverStrategy @@ -26,15 +26,10 @@ def add_failure_detector(self, failure_detector: FailureDetector) -> None: @property @abstractmethod - def databases(self) -> List[Database]: + def databases(self) -> Databases: """Returns a list of databases.""" pass - @abstractmethod - def add_database(self, database: Database) -> None: - """Adds new database to the list of databases.""" - pass - @property @abstractmethod def active_database(self) -> Union[Database, None]: @@ -43,7 +38,7 @@ def active_database(self) -> Union[Database, None]: @active_database.setter @abstractmethod - def active_database(self, database: Database) -> None: + def active_database(self, database: AbstractDatabase) -> 
None: """Sets currently active database.""" pass @@ -76,7 +71,7 @@ class DefaultCommandExecutor(CommandExecutor): def __init__( self, failure_detectors: List[FailureDetector], - databases: List[Database], + databases: Databases, failover_strategy: FailoverStrategy, event_dispatcher: EventDispatcherInterface, auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, @@ -106,19 +101,15 @@ def add_failure_detector(self, failure_detector: FailureDetector) -> None: self._failure_detectors.append(failure_detector) @property - def databases(self) -> List[Database]: + def databases(self) -> Databases: return self._databases - def add_database(self, database: Database) -> None: - self._failover_strategy.add_database(database) - self._databases.append(database) - @property - def active_database(self) -> Union[Database, None]: + def active_database(self) -> Optional[AbstractDatabase]: return self._active_database @active_database.setter - def active_database(self, database: Database) -> None: + def active_database(self, database: AbstractDatabase) -> None: self._active_database = database @property diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 97dd1d8c95..755c6d68e1 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -1,13 +1,15 @@ from dataclasses import dataclass, field -from typing import List, Type, Union +from typing import List, Type, Union, Set import pybreaker from redis import Redis, Sentinel from redis.asyncio import RedisCluster from redis.backoff import ExponentialWithJitterBackoff +from redis.data_structure import WeightedList +from redis.event import EventDispatcher, EventDispatcherInterface from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter -from redis.multidb.database import Database, AbstractDatabase +from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, 
EchoHealthCheck from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy @@ -42,6 +44,9 @@ def default_circuit_breaker() -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=DEFAULT_GRACE_PERIOD) return PBCircuitBreakerAdapter(circuit_breaker) +def default_event_dispatcher() -> EventDispatcherInterface: + return EventDispatcher() + @dataclass class DatabaseConfig: client_kwargs: dict @@ -57,14 +62,16 @@ class MultiDbConfig: health_check_interval: int = DEFAULT_HEALTH_CHECK_INTERVAL failover_strategy: FailoverStrategy = field(default_factory=default_failover_strategy) auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL + event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) - def databases(self) -> List[AbstractDatabase]: - databases = [] + def databases(self) -> Databases: + databases = WeightedList() for database_config in self.databases_config: client = self.client_class(**database_config.client_kwargs) - databases.append( - Database(client=client, circuit=database_config.circuit, weight=database_config.weight) + databases.add( + Database(client=client, circuit=database_config.circuit, weight=database_config.weight), + database_config.weight ) return databases diff --git a/redis/multidb/database.py b/redis/multidb/database.py index d4b3be7c78..469c53c309 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -4,6 +4,7 @@ from typing import Union from redis import RedisCluster, Sentinel +from redis.data_structure import WeightedList from redis.multidb.circuit import CircuitBreaker class State(Enum): @@ -60,6 +61,8 @@ def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass +Databases = WeightedList[tuple[AbstractDatabase, int]] + class Database(AbstractDatabase): def __init__( self, diff --git a/redis/multidb/event.py b/redis/multidb/event.py index ad6a275bb5..6e8482acb6 100644 --- 
a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,7 +1,7 @@ -from typing import List +from typing import List, Set from redis.event import EventListenerInterface, OnCommandFailEvent -from redis.multidb.database import Database +from redis.multidb.config import Databases from redis.multidb.failure_detector import FailureDetector @@ -9,7 +9,7 @@ class RegisterCommandFailure(EventListenerInterface): """ Event listener that registers command failures and passing it to the failure detectors. """ - def __init__(self, failure_detectors: List[FailureDetector], databases: List[Database]): + def __init__(self, failure_detectors: List[FailureDetector], databases: Databases): self._failure_detectors = failure_detectors self._databases = databases diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index 07bbe618af..8beb953cbe 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -1,7 +1,7 @@ -import threading from abc import ABC, abstractmethod from redis.data_structure import WeightedList +from redis.multidb.database import Databases from redis.multidb.database import AbstractDatabase from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException @@ -17,18 +17,17 @@ def database(self) -> AbstractDatabase: pass @abstractmethod - def add_database(self, database: AbstractDatabase) -> None: - """Add the database.""" + def set_databases(self, databases: Databases) -> None: + """Set the databases.""" pass - class WeightBasedFailoverStrategy(FailoverStrategy): """ Choose the active database with the highest weight. 
""" def __init__( self, - retry: Retry, + retry: Retry ): self._retry = retry self._retry.update_supported_errors([NoValidDatabaseException]) @@ -41,8 +40,8 @@ def database(self) -> AbstractDatabase: lambda _: self._dummy_fail() ) - def add_database(self, database: AbstractDatabase) -> None: - self._databases.add(database, database.weight) + def set_databases(self, databases: Databases) -> None: + self._databases = databases def _get_active_database(self) -> AbstractDatabase: for database, _ in self._databases: diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 2cfc22d37f..3326a29486 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -3,10 +3,15 @@ import pytest from redis import Redis +from redis.data_structure import WeightedList from redis.multidb.circuit import CircuitBreaker, State as CBState -from redis.multidb.database import Database, State +from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ + DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.multidb.database import Database, State, Databases from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck +from tests.conftest import mock_ed @pytest.fixture() @@ -25,6 +30,10 @@ def mock_fd() -> FailureDetector: def mock_fs() -> FailoverStrategy: return Mock(spec=FailoverStrategy) +@pytest.fixture() +def mock_hc() -> HealthCheck: + return Mock(spec=HealthCheck) + @pytest.fixture() def mock_db(request) -> Database: db = Mock(spec=Database) @@ -68,4 +77,28 @@ def mock_db2(request) -> Database: mock_cb.state = cb.get("state", CBState.CLOSED) db.circuit = mock_cb - return db \ No newline at end of file + return db + +@pytest.fixture() +def mock_multi_db_config( + request, mock_fd, mock_fs, mock_hc, mock_ed +) -> MultiDbConfig: + hc_interval = request.param.get('hc_interval', None) + if hc_interval is None: + 
hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL + + auto_fallback_interval = request.param.get('auto_fallback_interval', None) + if auto_fallback_interval is None: + auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_checks=[mock_hc], + health_check_interval=hc_interval, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed + ) + + return config \ No newline at end of file diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py new file mode 100644 index 0000000000..3c55f8f0d7 --- /dev/null +++ b/tests/test_multidb/test_client.py @@ -0,0 +1,100 @@ +from unittest.mock import patch + +import pytest + +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.database import State as DBState +from redis.multidb.client import MultiDBClient +from redis.multidb.exception import NoValidDatabaseException + + +class TestMultiDbClient: + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + ids=['all closed - highest weight', 'highest weight - open'], + indirect=True, + ) + def test_execute_command_against_correct_db_on_successful_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + + with patch.object( + mock_multi_db_config, + 'databases', + 
return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE or mock_db2.state == DBState.DISCONNECTED + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_command_throws_exception_on_failed_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): + client.set('key', 'value') + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.DISCONNECTED + assert mock_db1.state == DBState.DISCONNECTED + assert mock_db2.state == DBState.DISCONNECTED \ No newline at end of file diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 
52076fb055..a810eea676 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -26,9 +26,9 @@ def test_default_config(self): assert len(databases) == 3 i = 0 - for db in databases: + for db, weight in databases: assert isinstance(db, Database) - assert db.weight == db_configs[i].weight + assert weight == db_configs[i].weight assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD i+=1 @@ -83,9 +83,9 @@ def test_overridden_config(self): assert len(databases) == 3 i = 0 - for db in databases: + for db, weight in databases: assert isinstance(db, Database) - assert db.weight == db_configs[i].weight + assert weight == db_configs[i].weight assert db.client.connection_pool == mock_connection_pools[i] assert db.circuit.grace_period == grace_period i+=1 diff --git a/tests/test_multidb/test_failover.py b/tests/test_multidb/test_failover.py index a4239d943a..06390c4e2e 100644 --- a/tests/test_multidb/test_failover.py +++ b/tests/test_multidb/test_failover.py @@ -3,6 +3,7 @@ import pytest from redis.backoff import NoBackoff, ExponentialBackoff +from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy @@ -29,10 +30,13 @@ class TestWeightBasedFailoverStrategy: ) def test_get_valid_database(self, mock_db, mock_db1, mock_db2): retry = Retry(NoBackoff(), 0) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) - failover_strategy.add_database(mock_db) - failover_strategy.add_database(mock_db1) - failover_strategy.add_database(mock_db2) + failover_strategy.set_databases(databases) assert failover_strategy.database == mock_db1 @@ -54,10 +58,12 @@ def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): 
type(mock_db.circuit).state = state_mock retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) failover_strategy = WeightBasedFailoverStrategy(retry=retry) - failover_strategy.add_database(mock_db) - failover_strategy.add_database(mock_db1) - failover_strategy.add_database(mock_db2) + failover_strategy.set_databases(databases) assert failover_strategy.database == mock_db assert state_mock.call_count == 4 @@ -80,10 +86,12 @@ def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db type(mock_db.circuit).state = state_mock retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) failover_strategy = WeightBasedFailoverStrategy(retry=retry) - failover_strategy.add_database(mock_db) - failover_strategy.add_database(mock_db1) - failover_strategy.add_database(mock_db2) + failover_strategy.set_databases(databases) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): assert failover_strategy.database @@ -106,25 +114,4 @@ def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): failover_strategy = WeightBasedFailoverStrategy(retry=retry) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert failover_strategy.database - - @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', - [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, - ), - ], - indirect=True, - ) - def test_add_database_return_valid_database(self, mock_db, mock_db1, mock_db2): - retry = Retry(ExponentialBackoff(cap=1), 3) - failover_strategy = 
WeightBasedFailoverStrategy(retry=retry) - failover_strategy.add_database(mock_db) - failover_strategy.add_database(mock_db2) - assert failover_strategy.database == mock_db2 - - failover_strategy.add_database(mock_db1) - assert failover_strategy.database == mock_db1 \ No newline at end of file + assert failover_strategy.database \ No newline at end of file From 5b94757c4fe839368d15dbe3ceb74b50d4021065 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 26 Jun 2025 12:32:56 +0300 Subject: [PATCH 13/22] Added database CRUD operations --- redis/data_structure.py | 12 +- redis/multidb/client.py | 60 ++++++++-- redis/multidb/failover.py | 4 +- tests/test_data_structure.py | 4 +- tests/test_multidb/conftest.py | 10 +- tests/test_multidb/test_client.py | 193 ++++++++++++++++++++++++++++-- 6 files changed, 253 insertions(+), 30 deletions(-) diff --git a/redis/data_structure.py b/redis/data_structure.py index d40d5bb71d..0c0959499b 100644 --- a/redis/data_structure.py +++ b/redis/data_structure.py @@ -1,5 +1,5 @@ import threading -from typing import List, Any, TypeVar, Generic +from typing import List, Any, TypeVar, Generic, Union T = TypeVar('T') @@ -8,10 +8,10 @@ class WeightedList(Generic[T]): Thread-safe weighted list. 
""" def __init__(self): - self._items = [] + self._items: List[tuple[Any, Union[int, float]]] = [] self._lock = threading.RLock() - def add(self, item, weight: float) -> None: + def add(self, item: Any, weight: float) -> None: """Add item with weight, maintaining sorted order""" with self._lock: # Find insertion point using binary search @@ -34,7 +34,7 @@ def remove(self, item): return weight raise ValueError("Item not found") - def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple]: + def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple[Any, Union[int, float]]]: """Get all items within weight range""" with self._lock: result = [] @@ -43,7 +43,7 @@ def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tupl result.append((item, weight)) return result - def get_top_n(self, n: int) -> List[tuple]: + def get_top_n(self, n: int) -> List[tuple[Any, Union[int, float]]]: """Get top N the highest weighted items""" with self._lock: return [(item, weight) for weight, item in self._items[:n]] @@ -67,7 +67,7 @@ def __len__(self): with self._lock: return len(self._items) - def __getitem__(self, index) -> tuple[Any, int]: + def __getitem__(self, index) -> tuple[Any, Union[int, float]]: with self._lock: weight, item = self._items[index] return item, weight \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index bd2d56fe28..4a121d1a21 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,10 +1,8 @@ -import threading - from redis.commands import RedisModuleCommands, CoreCommands, SentinelCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig from redis.multidb.circuit import State as CBState -from redis.multidb.database import State as DBState, Database, AbstractDatabase +from redis.multidb.database import State as DBState, Database, AbstractDatabase, Databases from 
redis.multidb.exception import NoValidDatabaseException @@ -30,13 +28,11 @@ def __init__(self, config: MultiDbConfig): auto_fallback_interval=self._auto_fallback_interval, ) self._initialized = False - self._lock = threading.RLock() def _initialize(self): """ Perform initialization of databases to define their initial state. """ - is_active_db = False for database, weight in self._databases: @@ -57,14 +53,61 @@ def _initialize(self): self._initialized = True + def get_databases(self) -> Databases: + """ + Returns a sorted (by weight) list of all databases. + """ + return self._databases + def add_database(self, database: Database): """ Adds a new database to the database list. """ - with self._lock: - if database in self._databases: + for existing_db, _ in self._databases: + if existing_db == database: raise ValueError('Given database already exists') + self._check_db_health(database) + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.add(database, database.weight) + + if database.weight > highest_weight and database.circuit.state == CBState.CLOSED: + database.state = DBState.ACTIVE + self._command_executor.active_database = database + highest_weighted_db.state = DBState.PASSIVE + + def remove_database(self, database: Database): + """ + Removes a database from the database list. + """ + weight = self._databases.remove(database) + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + + if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: + highest_weighted_db.state = DBState.ACTIVE + self._command_executor.active_database = highest_weighted_db + + def update_database_weight(self, database: Database, weight: float): + """ + Updates a database from the database list. 
+ """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + + if not exists: + raise ValueError('Given database is not a member of database list') + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.update_weight(database, weight) + + if weight > highest_weight and database.circuit.state == CBState.CLOSED: + database.state = DBState.ACTIVE + self._command_executor.active_database = database + highest_weighted_db.state = DBState.PASSIVE def execute_command(self, *args, **options): """ @@ -73,8 +116,7 @@ def execute_command(self, *args, **options): if not self._initialized: self._initialize() - with self._lock: - return self._command_executor.execute_command(*args, **options) + return self._command_executor.execute_command(*args, **options) def _check_db_health(self, database: AbstractDatabase) -> None: """ diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index 8beb953cbe..f370d25952 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -13,12 +13,12 @@ class FailoverStrategy(ABC): @property @abstractmethod def database(self) -> AbstractDatabase: - """Select the database.""" + """Select the database according to the strategy.""" pass @abstractmethod def set_databases(self, databases: Databases) -> None: - """Set the databases.""" + """Set the databases strategy operates on.""" pass class WeightBasedFailoverStrategy(FailoverStrategy): diff --git a/tests/test_data_structure.py b/tests/test_data_structure.py index 9e8058c85e..31ac5c4316 100644 --- a/tests/test_data_structure.py +++ b/tests/test_data_structure.py @@ -24,8 +24,8 @@ def test_remove_items(self): wlist.add('item3', 4.0) wlist.add('item4', 4.0) - wlist.remove('item2') - wlist.remove('item4') + assert wlist.remove('item2') == 2.0 + assert wlist.remove('item4') == 4.0 assert wlist.get_top_n(4) == [('item3', 4.0), ('item1', 3.0)] diff --git a/tests/test_multidb/conftest.py 
b/tests/test_multidb/conftest.py index 3326a29486..ad2057a118 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -101,4 +101,12 @@ def mock_multi_db_config( event_dispatcher=mock_ed ) - return config \ No newline at end of file + return config + +def create_weighted_list(*databases) -> Databases: + dbs = WeightedList() + + for db in databases: + dbs.add(db, db.weight) + + return dbs \ No newline at end of file diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 3c55f8f0d7..8f5ef228f9 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -2,11 +2,11 @@ import pytest -from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState from redis.multidb.database import State as DBState from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException +from tests.test_multidb.conftest import create_weighted_list class TestMultiDbClient: @@ -32,10 +32,7 @@ class TestMultiDbClient: def test_execute_command_against_correct_db_on_successful_initialization( self, mock_multi_db_config, mock_db, mock_db1, mock_db2 ): - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) + databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object( mock_multi_db_config, @@ -73,10 +70,7 @@ def test_execute_command_against_correct_db_on_successful_initialization( def test_execute_command_throws_exception_on_failed_initialization( self, mock_multi_db_config, mock_db, mock_db1, mock_db2 ): - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) + databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object( mock_multi_db_config, @@ -97,4 +91,183 @@ def 
test_execute_command_throws_exception_on_failed_initialization( assert mock_db.state == DBState.DISCONNECTED assert mock_db1.state == DBState.DISCONNECTED - assert mock_db2.state == DBState.DISCONNECTED \ No newline at end of file + assert mock_db2.state == DBState.DISCONNECTED + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_database_throws_exception_on_same_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(ValueError, match='Given database already exists'): + client.add_database(mock_db) + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_database_makes_new_database_active( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + for hc in mock_multi_db_config.health_checks: + 
hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set('key', 'value') == 'OK2' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 2 + + assert mock_db.state == DBState.PASSIVE + assert mock_db2.state == DBState.ACTIVE + + client.add_database(mock_db1) + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_remove_highest_weighted_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + client.remove_database(mock_db1) + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db2.state 
== DBState.ACTIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_update_database_weight_to_be_highest( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + client.update_database_weight(mock_db2, 0.8) + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.PASSIVE + assert mock_db2.state == DBState.ACTIVE \ No newline at end of file From daba501fb0db525a8a5f3301eb628715f094a760 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 26 Jun 2025 15:18:17 +0300 Subject: [PATCH 14/22] Added on-fly configuration --- redis/multidb/client.py | 59 ++++++- redis/multidb/event.py | 2 +- tests/test_multidb/test_client.py | 164 +++++++++++++++++++- tests/test_multidb/test_command_executor.py | 13 +- 4 files changed, 222 insertions(+), 16 deletions(-) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 4a121d1a21..9cd0adbbd4 100644 --- a/redis/multidb/client.py +++ 
b/redis/multidb/client.py @@ -1,9 +1,13 @@ +import threading + from redis.commands import RedisModuleCommands, CoreCommands, SentinelCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig from redis.multidb.circuit import State as CBState from redis.multidb.database import State as DBState, Database, AbstractDatabase, Databases from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck class MultiDBClient(RedisModuleCommands, CoreCommands, SentinelCommands): @@ -28,6 +32,7 @@ def __init__(self, config: MultiDbConfig): auto_fallback_interval=self._auto_fallback_interval, ) self._initialized = False + self._hc_lock = threading.RLock() def _initialize(self): """ @@ -59,7 +64,31 @@ def get_databases(self) -> Databases: """ return self._databases - def add_database(self, database: Database): + def set_active_database(self, database: AbstractDatabase) -> None: + """ + Promote one of the existing databases to become an active. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + + if not exists: + raise ValueError('Given database is not a member of database list') + + self._check_db_health(database) + + if database.circuit.state == CBState.CLOSED: + highest_weighted_db, _ = self._databases.get_top_n(1)[0] + highest_weighted_db.state = DBState.PASSIVE + database.state = DBState.ACTIVE + self._command_executor.active_database = database + return + + raise NoValidDatabaseException('Cannot set active database, database is unhealthy') + + def add_database(self, database: AbstractDatabase): """ Adds a new database to the database list. 
""" @@ -88,7 +117,7 @@ def remove_database(self, database: Database): highest_weighted_db.state = DBState.ACTIVE self._command_executor.active_database = highest_weighted_db - def update_database_weight(self, database: Database, weight: float): + def update_database_weight(self, database: AbstractDatabase, weight: float): """ Updates a database from the database list. """ @@ -109,6 +138,19 @@ def update_database_weight(self, database: Database, weight: float): self._command_executor.active_database = database highest_weighted_db.state = DBState.PASSIVE + def add_failure_detector(self, failure_detector: FailureDetector): + """ + Adds a new failure detector to the database. + """ + self._failure_detectors.append(failure_detector) + + def add_health_check(self, healthcheck: HealthCheck): + """ + Adds a new health check to the database. + """ + with self._hc_lock: + self._health_checks.append(healthcheck) + def execute_command(self, *args, **options): """ Executes a single command and return its result. 
@@ -124,11 +166,12 @@ def _check_db_health(self, database: AbstractDatabase) -> None: """ is_healthy = True - # Health check will setup circuit state - for health_check in self._health_checks: - if not is_healthy: - # If one of the health checks failed, it's considered unhealthy - break + with self._hc_lock: + # Health check will setup circuit state + for health_check in self._health_checks: + if not is_healthy: + # If one of the health checks failed, it's considered unhealthy + break - is_healthy = health_check.check_health(database) + is_healthy = health_check.check_health(database) diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 6e8482acb6..fea6f3b7c5 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -16,7 +16,7 @@ def __init__(self, failure_detectors: List[FailureDetector], databases: Database def listen(self, event: OnCommandFailEvent) -> None: matching_database = None - for database in self._databases: + for database, _ in self._databases: if event.client == database.client: matching_database = database break diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 8f5ef228f9..5b8e20036a 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -1,11 +1,14 @@ -from unittest.mock import patch +from unittest.mock import patch, Mock import pytest +from redis.event import EventDispatcher, OnCommandFailEvent from redis.multidb.circuit import State as CBState -from redis.multidb.database import State as DBState +from redis.multidb.database import State as DBState, AbstractDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck from tests.test_multidb.conftest import create_weighted_list @@ -270,4 +273,159 @@ def test_update_database_weight_to_be_highest( assert mock_db.state == DBState.PASSIVE 
assert mock_db1.state == DBState.PASSIVE - assert mock_db2.state == DBState.ACTIVE \ No newline at end of file + assert mock_db2.state == DBState.ACTIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_new_failure_detector( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_multi_db_config.event_dispatcher = EventDispatcher() + mock_fd = mock_multi_db_config.failure_detectors[0] + + # Event fired if command against mock_db1 would fail + command_fail_event = OnCommandFailEvent( + command=('SET', 'key', 'value'), + exception=Exception(), + client=mock_db1.client + ) + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + # Simulate failing command events that lead to a failure detection + for i in range(5): + mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) + + assert mock_fd.register_failure.call_count == 5 + + another_fd = Mock(spec=FailureDetector) + client.add_failure_detector(another_fd) + + # Simulate failing command events that lead to a failure detection + for i in range(5): + mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) + + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 + + 
@pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_new_health_check( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True + + client.add_health_check(another_hc) + client._check_db_health(mock_db1) + + assert another_hc.check_health.call_count == 1 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_set_active_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db.client.execute_command.return_value = 'OK' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert 
mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + client.set_active_database(mock_db) + assert client.set('key', 'value') == 'OK' + + assert mock_db.state == DBState.ACTIVE + assert mock_db1.state == DBState.PASSIVE + assert mock_db2.state == DBState.PASSIVE + + with pytest.raises(ValueError, match='Given database is not a member of database list'): + client.set_active_database(Mock(spec=AbstractDatabase)) + + mock_db1.circuit.state = CBState.OPEN + + with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): + client.set_active_database(mock_db1) \ No newline at end of file diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py index 2907132235..54c6d38e1d 100644 --- a/tests/test_multidb/test_command_executor.py +++ b/tests/test_multidb/test_command_executor.py @@ -7,6 +7,7 @@ from redis.multidb.circuit import State as CBState from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.failure_detector import CommandFailureDetector +from tests.test_multidb.conftest import create_weighted_list class TestDefaultCommandExecutor: @@ -24,10 +25,11 @@ class TestDefaultCommandExecutor: def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' + databases = create_weighted_list(mock_db, mock_db1, mock_db2) executor = DefaultCommandExecutor( failure_detectors=[mock_fd], - databases=[mock_db, mock_db1, mock_db2], + databases=databases, failover_strategy=mock_fs, event_dispatcher=mock_ed ) @@ -57,10 +59,11 @@ def 
test_execute_command_automatically_select_active_database( mock_db2.client.execute_command.return_value = 'OK2' mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2]) type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) executor = DefaultCommandExecutor( failure_detectors=[mock_fd], - databases=[mock_db, mock_db1, mock_db2], + databases=databases, failover_strategy=mock_fs, event_dispatcher=mock_ed ) @@ -90,10 +93,11 @@ def test_execute_command_fallback_to_another_db_after_fallback_interval( mock_db2.client.execute_command.return_value = 'OK2' mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) executor = DefaultCommandExecutor( failure_detectors=[mock_fd], - databases=[mock_db, mock_db1, mock_db2], + databases=databases, failover_strategy=mock_fs, event_dispatcher=mock_ed, auto_fallback_interval=0.1, @@ -132,6 +136,7 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( threshold = 5 fd = CommandFailureDetector(threshold, 1) ed = EventDispatcher() + databases = create_weighted_list(mock_db, mock_db1, mock_db2) # Event fired if command against mock_db1 would fail command_fail_event = OnCommandFailEvent( @@ -142,7 +147,7 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( executor = DefaultCommandExecutor( failure_detectors=[fd], - databases=[mock_db, mock_db1, mock_db2], + databases=databases, failover_strategy=mock_fs, event_dispatcher=ed, auto_fallback_interval=0.1, From 061e5184853d042e9e279748664311a183af7803 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 27 Jun 2025 12:56:25 +0300 Subject: [PATCH 15/22] Added background health checks --- redis/multidb/client.py | 61 ++++++++- redis/multidb/command_executor.py | 5 +- redis/multidb/config.py | 2 +- redis/multidb/healthcheck.py | 4 +- tests/test_multidb/test_client.py | 198 
+++++++++++++++++++++++++++++- 5 files changed, 264 insertions(+), 6 deletions(-) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 9cd0adbbd4..845dbd529d 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,3 +1,4 @@ +import asyncio import threading from redis.commands import RedisModuleCommands, CoreCommands, SentinelCommands @@ -33,16 +34,43 @@ def __init__(self, config: MultiDbConfig): ) self._initialized = False self._hc_lock = threading.RLock() + self._init_timer = None + self._next_timer = None + + def __del__(self): + if self._init_timer is not None: + self._init_timer.cancel() + if self._next_timer is not None: + self._next_timer.cancel() def _initialize(self): """ Perform initialization of databases to define their initial state. """ + + # Starts recurring + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # Run loop in a separate thread to unblock main thread. + loop = asyncio.new_event_loop() + thread = threading.Thread( + target=_start_event_loop_in_thread, args=(loop,), daemon=True + ) + thread.start() + + # Event to block for initial execution. + init_event = asyncio.Event() + self._init_timer = loop.call_later( + 0, self._run_health_check_recurring, init_event + ) + + # Blocks in thread-safe manner. + asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result() + is_active_db = False for database, weight in self._databases: - self._check_db_health(database) - # Set states according to a weights and circuit state if database.circuit.state == CBState.CLOSED and not is_active_db: database.state = DBState.ACTIVE @@ -175,3 +203,32 @@ def _check_db_health(self, database: AbstractDatabase) -> None: is_healthy = health_check.check_health(database) + def _run_health_check_recurring(self, init_event: asyncio.Event = None): + """ + Runs health checks as recurring task. 
+ """ + try: + for database, _ in self._databases: + self._check_db_health(database) + + loop = asyncio.get_running_loop() + self._next_timer = loop.call_later( + self._health_check_interval, + self._run_health_check_recurring, + None + ) + finally: + if init_event: + init_event.set() + +def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop): + """ + Starts event loop in a thread. + Used to be able to schedule tasks using loop.call_later. + + :param event_loop: + :return: + """ + asyncio.set_event_loop(event_loop) + event_loop.run_forever() + diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 5c28e03f23..b17e3fcc3d 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,7 +1,9 @@ +import socket from abc import ABC, abstractmethod from datetime import datetime, timedelta from typing import List, Union, Optional +from redis.exceptions import ConnectionError, TimeoutError from redis.event import EventDispatcherInterface, OnCommandFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, AbstractDatabase, Databases @@ -92,6 +94,7 @@ def __init__( self._next_fallback_attempt: datetime self._active_database: Union[Database, None] = None self._setup_event_dispatcher() + self._schedule_next_fallback() @property def failure_detectors(self) -> List[FailureDetector]: @@ -138,7 +141,7 @@ def execute_command(self, *args, **options): try: return self._active_database.client.execute_command(*args, **options) - except Exception: + except (ConnectionError, TimeoutError, socket.timeout): # Retry until failure detector will trigger opening of circuit return self.execute_command(*args, **options) diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 755c6d68e1..84ea37db13 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -59,7 +59,7 @@ class MultiDbConfig: client_class: Type[Union[Redis, RedisCluster, 
Sentinel]] = Redis failure_detectors: List[FailureDetector] = field(default_factory=default_failure_detectors) health_checks: List[HealthCheck] = field(default_factory=default_health_checks) - health_check_interval: int = DEFAULT_HEALTH_CHECK_INTERVAL + health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL failover_strategy: FailoverStrategy = field(default_factory=default_failover_strategy) auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 13bdfd6399..ad2130f1e0 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,4 +1,6 @@ +import socket from abc import abstractmethod, ABC +from redis.exceptions import ConnectionError, TimeoutError from redis.retry import Retry from redis.multidb.circuit import State as CBState @@ -57,7 +59,7 @@ def check_health(self, database) -> bool: database.circuit.state = CBState.CLOSED return is_healthy - except Exception: + except (ConnectionError, TimeoutError, socket.timeout): database.circuit.state = CBState.OPEN return False diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 5b8e20036a..e90e072848 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -1,14 +1,19 @@ +from time import sleep from unittest.mock import patch, Mock import pytest from redis.event import EventDispatcher, OnCommandFailEvent from redis.multidb.circuit import State as CBState +from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ + DEFAULT_FAILOVER_BACKOFF from redis.multidb.database import State as DBState, AbstractDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.failover import WeightBasedFailoverStrategy from 
redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list @@ -58,6 +63,197 @@ def test_execute_command_against_correct_db_on_successful_initialization( assert mock_db1.state == DBState.ACTIVE assert mock_db2.state == DBState.PASSIVE or mock_db2.state == DBState.DISCONNECTED + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == 
DBState.PASSIVE + + sleep(0.15) + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.1) + + assert client.set('key', 'value') == 'OK' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.1) + + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert client.set('key', 'value') == 'OK1' + + 
assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.15) + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.1) + + assert client.set('key', 'value') == 'OK' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.1) + + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = 0.2 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + 
retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.15) + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.22) + + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ From a5627740ae14340ad2f8aba62a2407d27e20fb9d Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 2 Jul 2025 11:10:53 +0300 Subject: [PATCH 16/22] Added background healthcheck + half-open event --- redis/client.py | 3 +- redis/multidb/circuit.py | 28 ++++++++- redis/multidb/client.py | 10 ++- redis/multidb/command_executor.py | 5 +- redis/multidb/config.py | 6 +- redis/multidb/database.py | 3 +- redis/multidb/event.py | 2 +- redis/multidb/healthcheck.py | 11 ++-- tests/test_multidb/test_circuit.py | 13 +++- tests/test_multidb/test_client.py | 97 ++++-------------------------- 10 files changed, 75 insertions(+), 103 deletions(-) diff --git a/redis/client.py b/redis/client.py index 8ebe2e38b3..b821987b28 100755 --- a/redis/client.py +++ b/redis/client.py @@ -45,7 +45,7 @@ AfterPubSubConnectionInstantiationEvent, AfterSingleConnectionInstantiationEvent, ClientType, - EventDispatcher, OnCommandFailEvent, + EventDispatcher, ) from redis.exceptions import ( ConnectionError, @@ -616,7 +616,6 @@ def _close_connection(self, conn, error, *args) -> None: do a health check as part of the send_command logic(on connection level). 
""" - self._event_dispatcher.dispatch(OnCommandFailEvent(args, error, self)) conn.disconnect() # COMMAND EXECUTION AND PROTOCOL PARSING diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 1f3d00e81c..9211173c83 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -33,6 +33,18 @@ def state(self, state: State): """Set current state of the circuit.""" pass + @property + @abstractmethod + def database(self): + """Database associated with this circuit.""" + pass + + @database.setter + @abstractmethod + def database(self, database): + """Set database associated with this circuit.""" + pass + @abstractmethod def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" @@ -41,13 +53,16 @@ def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]) class PBListener(pybreaker.CircuitBreakerListener): def __init__( self, - cb: Callable[[CircuitBreaker, State, State], None] + cb: Callable[[CircuitBreaker, State, State], None], + database, ): """Wrapper for callback to be compatible with pybreaker implementation.""" self._cb = cb + self._database = database def state_change(self, cb, old_state, new_state): cb = PBCircuitBreakerAdapter(cb) + cb.database = self._database old_state = State(value=old_state.name) new_state = State(value=new_state.name) self._cb(cb, old_state, new_state) @@ -62,6 +77,7 @@ def __init__(self, cb: pybreaker.CircuitBreaker): State.OPEN: self._cb.open, State.HALF_OPEN: self._cb.half_open, } + self._database = None @property def grace_period(self) -> float: @@ -79,6 +95,14 @@ def state(self) -> State: def state(self, state: State): self._state_pb_mapper[state]() + @property + def database(self): + return self._database + + @database.setter + def database(self, database): + self._database = database + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): - listener = PBListener(cb) + 
listener = PBListener(cb, self.database) self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 845dbd529d..1c80426348 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -4,7 +4,7 @@ from redis.commands import RedisModuleCommands, CoreCommands, SentinelCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig -from redis.multidb.circuit import State as CBState +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.database import State as DBState, Database, AbstractDatabase, Databases from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector @@ -71,6 +71,9 @@ def _initialize(self): is_active_db = False for database, weight in self._databases: + # Set on state changed callback for each circuit. + database.circuit.on_state_changed(self._on_circuit_state_change_callback) + # Set states according to a weights and circuit state if database.circuit.state == CBState.CLOSED and not is_active_db: database.state = DBState.ACTIVE @@ -221,6 +224,11 @@ def _run_health_check_recurring(self, init_event: asyncio.Event = None): if init_event: init_event.set() + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + if new_state == CBState.HALF_OPEN: + self._check_db_health(circuit.database) + return + def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop): """ Starts event loop in a thread. 
diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index b17e3fcc3d..60dbeca36b 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -141,7 +141,10 @@ def execute_command(self, *args, **options): try: return self._active_database.client.execute_command(*args, **options) - except (ConnectionError, TimeoutError, socket.timeout): + except (ConnectionError, TimeoutError, socket.timeout) as e: + # Register command failure + self._event_dispatcher.dispatch(OnCommandFailEvent(args, e, self.active_database.client)) + # Retry until failure detector will trigger opening of circuit return self.execute_command(*args, **options) diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 84ea37db13..97dd3ab483 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Type, Union, Set +from typing import List, Type, Union, Set, Optional import pybreaker @@ -15,7 +15,7 @@ from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy from redis.retry import Retry -DEFAULT_GRACE_PERIOD = 1 +DEFAULT_GRACE_PERIOD = 5 DEFAULT_HEALTH_CHECK_INTERVAL = 5 DEFAULT_HEALTH_CHECK_RETRIES = 3 DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) @@ -49,8 +49,8 @@ def default_event_dispatcher() -> EventDispatcherInterface: @dataclass class DatabaseConfig: - client_kwargs: dict weight: float + client_kwargs: dict = field(default_factory=dict) circuit: CircuitBreaker = field(default_factory=default_circuit_breaker) @dataclass diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 469c53c309..a644956d37 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -61,7 +61,7 @@ def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass -Databases = WeightedList[tuple[AbstractDatabase, int]] +Databases = 
WeightedList[tuple[AbstractDatabase, Union[int, float]]] class Database(AbstractDatabase): def __init__( @@ -79,6 +79,7 @@ def __init__( """ self._client = client self._cb = circuit + self._cb.database = self self._weight = weight self._state = state diff --git a/redis/multidb/event.py b/redis/multidb/event.py index fea6f3b7c5..3d366dab77 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,4 +1,4 @@ -from typing import List, Set +from typing import List from redis.event import EventListenerInterface, OnCommandFailEvent from redis.multidb.config import Databases diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index ad2130f1e0..52d86b1c66 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -53,20 +53,21 @@ def check_health(self, database) -> bool: lambda _: self._dummy_fail() ) - if not is_healthy: + if not is_healthy and database.circuit.state != CBState.OPEN: database.circuit.state = CBState.OPEN elif is_healthy and database.circuit.state != CBState.CLOSED: database.circuit.state = CBState.CLOSED return is_healthy except (ConnectionError, TimeoutError, socket.timeout): - database.circuit.state = CBState.OPEN + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN return False def _returns_echoed_message(self, database) -> bool: - expected_message = "healthcheck" - actual_message = database.client.execute_command('ECHO', expected_message) - return actual_message == expected_message + expected_message = ["healthcheck", b"healthcheck"] + actual_message = database.client.execute_command('ECHO', "healthcheck") + return actual_message in expected_message def _dummy_fail(self): pass \ No newline at end of file diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index 5ddeacfea7..7dc642373b 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -1,10 +1,18 @@ import pybreaker +import pytest from 
redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker class TestPBCircuitBreaker: - def test_cb_correctly_configured(self): + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CbState.CLOSED}}, + ], + indirect=True, + ) + def test_cb_correctly_configured(self, mock_db): pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) adapter = PBCircuitBreakerAdapter(cb=pb_circuit) assert adapter.state == CbState.CLOSED @@ -23,6 +31,9 @@ def test_cb_correctly_configured(self): assert adapter.grace_period == 10 + adapter.database = mock_db + assert adapter.database == mock_db + def test_cb_executes_callback_on_state_changed(self): pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) adapter = PBCircuitBreakerAdapter(cb=pb_circuit) diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index e90e072848..96ab0c0a64 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -1,10 +1,11 @@ from time import sleep from unittest.mock import patch, Mock +import pybreaker import pytest from redis.event import EventDispatcher, OnCommandFailEvent -from redis.multidb.circuit import State as CBState +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF from redis.multidb.database import State as DBState, AbstractDatabase @@ -78,72 +79,18 @@ def test_execute_command_against_correct_db_on_successful_initialization( def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( self, mock_multi_db_config, mock_db, mock_db1, mock_db2 ): - databases = create_weighted_list(mock_db, mock_db1, mock_db2) + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb - with patch.object( - mock_multi_db_config, - 
'databases', - return_value=databases - ): - mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] - mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] - mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] - mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.health_checks = [ - EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) - ] - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 - client = MultiDBClient(mock_multi_db_config) - assert client.set('key', 'value') == 'OK1' + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - - sleep(0.15) - - assert client.set('key', 'value') == 'OK2' - - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - - sleep(0.1) - - assert client.set('key', 'value') == 'OK' - - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - - sleep(0.1) - - assert client.set('key', 'value') == 'OK1' - - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - - @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', - [ - ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - 
{'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, - ), - ], - indirect=True, - ) - def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 - ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object( @@ -166,35 +113,13 @@ def test_execute_command_against_correct_db_on_background_health_check_determine client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - sleep(0.15) - assert client.set('key', 'value') == 'OK2' - - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - sleep(0.1) - assert client.set('key', 'value') == 'OK' - - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - sleep(0.1) - assert client.set('key', 'value') == 'OK1' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ From 3ab13674661c2b29e768c716bb3b94c27c06e97b Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 3 Jul 2025 13:18:45 +0300 Subject: [PATCH 17/22] Refactored background scheduling --- redis/background.py | 89 ++++++++++++++++++++++++++++++++++++++++ redis/multidb/client.py | 68 ++++++++---------------------- redis/multidb/config.py | 4 +- tests/test_background.py | 60 +++++++++++++++++++++++++++ 4 files changed, 168 insertions(+), 53 deletions(-) create mode 100644 redis/background.py create mode 100644 tests/test_background.py diff --git a/redis/background.py b/redis/background.py new file mode 100644 index 0000000000..6466649859 --- /dev/null +++ b/redis/background.py @@ -0,0 +1,89 @@ 
+import asyncio +import threading +from typing import Callable + +class BackgroundScheduler: + """ + Schedules background tasks execution either in separate thread or in the running event loop. + """ + def __init__(self): + self._next_timer = None + + def __del__(self): + if self._next_timer: + self._next_timer.cancel() + + def run_once(self, delay: float, callback: Callable, *args): + """ + Runs callable task once after certain delay in seconds. + """ + # Run loop in a separate thread to unblock main thread. + loop = asyncio.new_event_loop() + thread = threading.Thread( + target=_start_event_loop_in_thread, + args=(loop, self._call_later, delay, callback, *args), + daemon=True + ) + thread.start() + + def run_recurring( + self, + interval: float, + callback: Callable, + *args + ): + """ + Runs recurring callable task with given interval in seconds. + """ + # Run loop in a separate thread to unblock main thread. + loop = asyncio.new_event_loop() + + thread = threading.Thread( + target=_start_event_loop_in_thread, + args=(loop, self._call_later_recurring, interval, callback, *args), + daemon=True + ) + thread.start() + + def _call_later(self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args): + self._next_timer = loop.call_later(delay, callback, *args) + + def _call_later_recurring( + self, + loop: asyncio.AbstractEventLoop, + interval: float, + callback: Callable, + *args + ): + self._call_later( + loop, interval, self._execute_recurring, loop, interval, callback, *args + ) + + def _execute_recurring( + self, + loop: asyncio.AbstractEventLoop, + interval: float, + callback: Callable, + *args + ): + """ + Executes recurring callable task with given interval in seconds. 
+ """ + callback(*args) + + self._call_later( + loop, interval, self._execute_recurring, loop, interval, callback, *args + ) + + +def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop, call_soon_cb: Callable, *args): + """ + Starts event loop in a thread and schedule callback as soon as event loop is ready. + Used to be able to schedule tasks using loop.call_later. + + :param event_loop: + :return: + """ + asyncio.set_event_loop(event_loop) + event_loop.call_soon(call_soon_cb, event_loop, *args) + event_loop.run_forever() \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 1c80426348..96433bead0 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,9 +1,9 @@ -import asyncio import threading +from redis.background import BackgroundScheduler from redis.commands import RedisModuleCommands, CoreCommands, SentinelCommands from redis.multidb.command_executor import DefaultCommandExecutor -from redis.multidb.config import MultiDbConfig +from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.database import State as DBState, Database, AbstractDatabase, Databases from redis.multidb.exception import NoValidDatabaseException @@ -34,39 +34,21 @@ def __init__(self, config: MultiDbConfig): ) self._initialized = False self._hc_lock = threading.RLock() - self._init_timer = None - self._next_timer = None - - def __del__(self): - if self._init_timer is not None: - self._init_timer.cancel() - if self._next_timer is not None: - self._next_timer.cancel() + self._bg_scheduler = BackgroundScheduler() def _initialize(self): """ Perform initialization of databases to define their initial state. """ - # Starts recurring - try: - loop = asyncio.get_running_loop() - except RuntimeError: - # Run loop in a separate thread to unblock main thread. 
- loop = asyncio.new_event_loop() - thread = threading.Thread( - target=_start_event_loop_in_thread, args=(loop,), daemon=True - ) - thread.start() - - # Event to block for initial execution. - init_event = asyncio.Event() - self._init_timer = loop.call_later( - 0, self._run_health_check_recurring, init_event - ) + # Initial databases check to define initial state + self._check_databases_health() - # Blocks in thread-safe manner. - asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result() + # Starts recurring health checks on the background. + self._bg_scheduler.run_recurring( + self._health_check_interval, + self._check_databases_health, + ) is_active_db = False @@ -206,37 +188,21 @@ def _check_db_health(self, database: AbstractDatabase) -> None: is_healthy = health_check.check_health(database) - def _run_health_check_recurring(self, init_event: asyncio.Event = None): + def _check_databases_health(self): """ Runs health checks as recurring task. """ - try: - for database, _ in self._databases: - self._check_db_health(database) - - loop = asyncio.get_running_loop() - self._next_timer = loop.call_later( - self._health_check_interval, - self._run_health_check_recurring, - None - ) - finally: - if init_event: - init_event.set() + for database, _ in self._databases: + self._check_db_health(database) def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: self._check_db_health(circuit.database) return -def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop): - """ - Starts event loop in a thread. - Used to be able to schedule tasks using loop.call_later. 
+ if old_state == CBState.CLOSED and new_state == CBState.OPEN: + self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) - :param event_loop: - :return: - """ - asyncio.set_event_loop(event_loop) - event_loop.run_forever() +def _half_open_circuit(circuit: CircuitBreaker): + circuit.state = CBState.HALF_OPEN diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 97dd3ab483..a349409e9f 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Type, Union, Set, Optional +from typing import List, Type, Union import pybreaker @@ -15,7 +15,7 @@ from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy from redis.retry import Retry -DEFAULT_GRACE_PERIOD = 5 +DEFAULT_GRACE_PERIOD = 5.0 DEFAULT_HEALTH_CHECK_INTERVAL = 5 DEFAULT_HEALTH_CHECK_RETRIES = 3 DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) diff --git a/tests/test_background.py b/tests/test_background.py new file mode 100644 index 0000000000..4b3a5377c1 --- /dev/null +++ b/tests/test_background.py @@ -0,0 +1,60 @@ +from time import sleep + +import pytest + +from redis.background import BackgroundScheduler + +class TestBackgroundScheduler: + def test_run_once(self): + execute_counter = 0 + one = 'arg1' + two = 9999 + + def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + scheduler.run_once(0.1, callback, one, two) + assert execute_counter == 0 + + sleep(0.15) + + assert execute_counter == 1 + + @pytest.mark.parametrize( + "interval,timeout,call_count", + [ + (0.012, 0.04, 3), + (0.035, 0.04, 1), + (0.045, 0.04, 0), + ] + ) + def test_run_recurring(self, interval, timeout, call_count): + execute_counter = 0 + one = 'arg1' + two = 9999 + + def callback(arg1: str, arg2: int): + nonlocal execute_counter 
+ nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + scheduler.run_recurring(interval, callback, one, two) + assert execute_counter == 0 + + sleep(timeout) + + assert execute_counter == call_count \ No newline at end of file From badef0e8592d81146ba68b0368f6e2a1a7b5d062 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 7 Jul 2025 12:12:34 +0300 Subject: [PATCH 18/22] Refactored healthchecks --- redis/multidb/client.py | 18 +++++++-- redis/multidb/healthcheck.py | 24 ++---------- tests/test_multidb/test_client.py | 52 +++++++++++++++++++++----- tests/test_multidb/test_healthcheck.py | 17 +-------- 4 files changed, 62 insertions(+), 49 deletions(-) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 96433bead0..2996b9153c 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,6 +1,8 @@ import threading +import socket from redis.background import BackgroundScheduler +from redis.exceptions import ConnectionError, TimeoutError from redis.commands import RedisModuleCommands, CoreCommands, SentinelCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD @@ -186,7 +188,18 @@ def _check_db_health(self, database: AbstractDatabase) -> None: # If one of the health checks failed, it's considered unhealthy break - is_healthy = health_check.check_health(database) + try: + is_healthy = health_check.check_health(database) + + if not is_healthy and database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + except (ConnectionError, TimeoutError, socket.timeout): + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + is_healthy = False + def _check_databases_health(self): """ @@ -204,5 +217,4 @@ def 
_on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) def _half_open_circuit(circuit: CircuitBreaker): - circuit.state = CBState.HALF_OPEN - + circuit.state = CBState.HALF_OPEN \ No newline at end of file diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 52d86b1c66..a96b9cf815 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,9 +1,5 @@ -import socket from abc import abstractmethod, ABC -from redis.exceptions import ConnectionError, TimeoutError - from redis.retry import Retry -from redis.multidb.circuit import State as CBState class HealthCheck(ABC): @@ -47,22 +43,10 @@ def __init__( retry=retry, ) def check_health(self, database) -> bool: - try: - is_healthy = self._retry.call_with_retry( - lambda : self._returns_echoed_message(database), - lambda _: self._dummy_fail() - ) - - if not is_healthy and database.circuit.state != CBState.OPEN: - database.circuit.state = CBState.OPEN - elif is_healthy and database.circuit.state != CBState.CLOSED: - database.circuit.state = CBState.CLOSED - - return is_healthy - except (ConnectionError, TimeoutError, socket.timeout): - if database.circuit.state != CBState.OPEN: - database.circuit.state = CBState.OPEN - return False + return self._retry.call_with_retry( + lambda: self._returns_echoed_message(database), + lambda _: self._dummy_fail() + ) def _returns_echoed_message(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 96ab0c0a64..8543d72b18 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -28,6 +28,38 @@ class TestMultiDbClient: {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, ), + ], + indirect=True, + ) + def 
test_execute_command_against_correct_db_on_successful_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ ( {}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, @@ -35,10 +67,9 @@ class TestMultiDbClient: {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, ), ], - ids=['all closed - highest weight', 'highest weight - open'], indirect=True, ) - def test_execute_command_against_correct_db_on_successful_initialization( + def test_execute_command_against_correct_db_and_closed_circuit( self, mock_multi_db_config, mock_db, mock_db1, mock_db2 ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) @@ -51,7 +82,7 @@ def test_execute_command_against_correct_db_on_successful_initialization( mock_db1.client.execute_command.return_value = 'OK1' for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + hc.check_health.side_effect = [False, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -60,9 +91,9 @@ def test_execute_command_against_correct_db_on_successful_initialization( for hc in mock_multi_db_config.health_checks: assert 
hc.check_health.call_count == 3 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE or mock_db2.state == DBState.DISCONNECTED + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -277,7 +308,7 @@ def test_add_database_makes_new_database_active( mock_db2.client.execute_command.return_value = 'OK2' for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = False + hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -327,7 +358,7 @@ def test_remove_highest_weighted_database( mock_db2.client.execute_command.return_value = 'OK2' for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = False + hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -374,7 +405,7 @@ def test_update_database_weight_to_be_highest( mock_db2.client.execute_command.return_value = 'OK2' for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = False + hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -546,7 +577,8 @@ def test_set_active_database( with pytest.raises(ValueError, match='Given database is not a member of database list'): client.set_active_database(Mock(spec=AbstractDatabase)) - mock_db1.circuit.state = CBState.OPEN + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = False with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): client.set_active_database(mock_db1) \ No newline at end 
of file diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 22b033ae35..9601638913 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -13,13 +13,11 @@ def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): according to given configuration. """ mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] - mock_cb.state = CBState.CLOSED hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) assert hc.check_health(db) == True assert mock_client.execute_command.call_count == 3 - assert db.circuit.state == CBState.CLOSED def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): """ @@ -27,23 +25,11 @@ def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, moc according to given configuration. """ mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'wrong'] - mock_cb.state = CBState.CLOSED hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) assert hc.check_health(db) == False assert mock_client.execute_command.call_count == 3 - assert db.circuit.state == CBState.OPEN - - def test_database_is_unhealthy_on_exceeded_healthcheck_retries(self, mock_client, mock_cb): - mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, ConnectionError, ConnectionError] - mock_cb.state = CBState.CLOSED - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) - - assert hc.check_health(db) == False - assert mock_client.execute_command.call_count == 4 - assert db.circuit.state == CBState.OPEN def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): mock_client.execute_command.side_effect = 
[ConnectionError, ConnectionError, 'healthcheck'] @@ -52,5 +38,4 @@ def test_database_close_circuit_on_successful_healthcheck(self, mock_client, moc db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) assert hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 - assert db.circuit.state == CBState.CLOSED \ No newline at end of file + assert mock_client.execute_command.call_count == 3 \ No newline at end of file From fcc60358ebbf1488eb22a1f13f5b424a9434a197 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 15 Jul 2025 10:20:55 +0300 Subject: [PATCH 19/22] Removed code repetitions, fixed weight assignment, added loops enhancement, fixed data structure --- redis/data_structure.py | 24 +++++++++++++----------- redis/event.py | 10 ++++++---- redis/multidb/client.py | 29 +++++++++++++++-------------- redis/multidb/database.py | 4 +++- redis/multidb/failover.py | 2 -- tests/test_multidb/test_client.py | 1 + 6 files changed, 38 insertions(+), 32 deletions(-) diff --git a/redis/data_structure.py b/redis/data_structure.py index 0c0959499b..5b0df7f017 100644 --- a/redis/data_structure.py +++ b/redis/data_structure.py @@ -1,6 +1,8 @@ import threading from typing import List, Any, TypeVar, Generic, Union +from redis.typing import Number + T = TypeVar('T') class WeightedList(Generic[T]): @@ -8,7 +10,7 @@ class WeightedList(Generic[T]): Thread-safe weighted list. 
""" def __init__(self): - self._items: List[tuple[Any, Union[int, float]]] = [] + self._items: List[tuple[Any, Number]] = [] self._lock = threading.RLock() def add(self, item: Any, weight: float) -> None: @@ -18,35 +20,35 @@ def add(self, item: Any, weight: float) -> None: left, right = 0, len(self._items) while left < right: mid = (left + right) // 2 - if self._items[mid][0] < weight: + if self._items[mid][1] < weight: right = mid else: left = mid + 1 - self._items.insert(left, (weight, item)) + self._items.insert(left, (item, weight)) def remove(self, item): """Remove first occurrence of item""" with self._lock: - for i, (weight, stored_item) in enumerate(self._items): + for i, (stored_item, weight) in enumerate(self._items): if stored_item == item: self._items.pop(i) return weight raise ValueError("Item not found") - def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple[Any, Union[int, float]]]: + def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple[Any, Number]]: """Get all items within weight range""" with self._lock: result = [] - for weight, item in self._items: + for item, weight in self._items: if min_weight <= weight <= max_weight: result.append((item, weight)) return result - def get_top_n(self, n: int) -> List[tuple[Any, Union[int, float]]]: + def get_top_n(self, n: int) -> List[tuple[Any, Number]]: """Get top N the highest weighted items""" with self._lock: - return [(item, weight) for weight, item in self._items[:n]] + return [(item, weight) for item, weight in self._items[:n]] def update_weight(self, item, new_weight: float): with self._lock: @@ -60,14 +62,14 @@ def __iter__(self): with self._lock: items_copy = self._items.copy() # Create snapshot as lock released after each 'yield' - for weight, item in items_copy: + for item, weight in items_copy: yield item, weight def __len__(self): with self._lock: return len(self._items) - def __getitem__(self, index) -> tuple[Any, Union[int, float]]: + 
def __getitem__(self, index) -> tuple[Any, Number]: with self._lock: - weight, item = self._items[index] + item, weight = self._items[index] return item, weight \ No newline at end of file diff --git a/redis/event.py b/redis/event.py index 8bc1bd4f41..fdb42a04d5 100644 --- a/redis/event.py +++ b/redis/event.py @@ -107,11 +107,13 @@ async def dispatch_async(self, event: object): def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]): with self._lock: - for event in event_listeners: - if event in self._event_listeners_mapping: - self._event_listeners_mapping[event] = list(set(self._event_listeners_mapping[event] + event_listeners[event])) + for event_type in event_listeners: + if event_type in self._event_listeners_mapping: + self._event_listeners_mapping[event_type] = list( + set(self._event_listeners_mapping[event_type] + event_listeners[event_type]) + ) else: - self._event_listeners_mapping[event] = event_listeners[event] + self._event_listeners_mapping[event_type] = event_listeners[event_type] class AfterConnectionReleasedEvent: diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 2996b9153c..78ce039868 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -52,23 +52,23 @@ def _initialize(self): self._check_databases_health, ) - is_active_db = False + is_active_db_found = False for database, weight in self._databases: # Set on state changed callback for each circuit. 
database.circuit.on_state_changed(self._on_circuit_state_change_callback) # Set states according to a weights and circuit state - if database.circuit.state == CBState.CLOSED and not is_active_db: + if database.circuit.state == CBState.CLOSED and not is_active_db_found: database.state = DBState.ACTIVE self._command_executor.active_database = database - is_active_db = True - elif database.circuit.state == CBState.CLOSED and is_active_db: + is_active_db_found = True + elif database.circuit.state == CBState.CLOSED and is_active_db_found: database.state = DBState.PASSIVE else: database.state = DBState.DISCONNECTED - if not is_active_db: + if not is_active_db_found: raise NoValidDatabaseException('Initial connection failed - no active database found') self._initialized = True @@ -88,6 +88,7 @@ def set_active_database(self, database: AbstractDatabase) -> None: for existing_db, _ in self._databases: if existing_db == database: exists = True + break if not exists: raise ValueError('Given database is not a member of database list') @@ -115,11 +116,13 @@ def add_database(self, database: AbstractDatabase): highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] self._databases.add(database, database.weight) + self._change_active_database(database, highest_weighted_db) - if database.weight > highest_weight and database.circuit.state == CBState.CLOSED: - database.state = DBState.ACTIVE - self._command_executor.active_database = database - highest_weighted_db.state = DBState.PASSIVE + def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): + if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: + new_database.state = DBState.ACTIVE + self._command_executor.active_database = new_database + highest_weight_database.state = DBState.PASSIVE def remove_database(self, database: Database): """ @@ -141,17 +144,15 @@ def update_database_weight(self, database: 
AbstractDatabase, weight: float): for existing_db, _ in self._databases: if existing_db == database: exists = True + break if not exists: raise ValueError('Given database is not a member of database list') highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] self._databases.update_weight(database, weight) - - if weight > highest_weight and database.circuit.state == CBState.CLOSED: - database.state = DBState.ACTIVE - self._command_executor.active_database = database - highest_weighted_db.state = DBState.PASSIVE + database.weight = weight + self._change_active_database(database, highest_weighted_db) def add_failure_detector(self, failure_detector: FailureDetector): """ diff --git a/redis/multidb/database.py b/redis/multidb/database.py index a644956d37..7a655b151f 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -6,6 +6,8 @@ from redis import RedisCluster, Sentinel from redis.data_structure import WeightedList from redis.multidb.circuit import CircuitBreaker +from redis.typing import Number + class State(Enum): ACTIVE = 0 @@ -61,7 +63,7 @@ def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass -Databases = WeightedList[tuple[AbstractDatabase, Union[int, float]]] +Databases = WeightedList[tuple[AbstractDatabase, Number]] class Database(AbstractDatabase): def __init__( diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index f370d25952..a4c825aac1 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -47,8 +47,6 @@ def _get_active_database(self) -> AbstractDatabase: for database, _ in self._databases: if database.circuit.state == CBState.CLOSED: return database - else: - continue raise NoValidDatabaseException('No valid database available for communication') diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 8543d72b18..c2ade264b5 100644 --- a/tests/test_multidb/test_client.py +++ 
b/tests/test_multidb/test_client.py @@ -420,6 +420,7 @@ def test_update_database_weight_to_be_highest( assert mock_db2.state == DBState.PASSIVE client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 assert client.set('key', 'value') == 'OK2' From d5dc65c6bae2e03088b2713ecc7fa8093d190d3d Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 17 Jul 2025 12:31:06 +0300 Subject: [PATCH 20/22] Refactored configuration --- redis/multidb/client.py | 8 +++-- redis/multidb/config.py | 59 ++++++++++++++++++------------- tests/test_multidb/test_config.py | 12 +++---- 3 files changed, 45 insertions(+), 34 deletions(-) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 78ce039868..b565a65c96 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -20,10 +20,12 @@ class MultiDBClient(RedisModuleCommands, CoreCommands, SentinelCommands): """ def __init__(self, config: MultiDbConfig): self._databases = config.databases() - self._health_checks = config.health_checks + self._health_checks = config.default_health_checks() if config.health_checks is None else config.health_checks self._health_check_interval = config.health_check_interval - self._failure_detectors = config.failure_detectors - self._failover_strategy = config.failover_strategy + self._failure_detectors = config.default_failure_detectors() \ + if config.failure_detectors is None else config.failure_detectors + self._failover_strategy = config.default_failover_strategy() \ + if config.failover_strategy is None else config.failover_strategy self._failover_strategy.set_databases(self._databases) self._auto_fallback_interval = config.auto_fallback_interval self._event_dispatcher = config.event_dispatcher diff --git a/redis/multidb/config.py b/redis/multidb/config.py index a349409e9f..06bd49aa44 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -2,10 +2,11 @@ from typing import List, Type, Union import pybreaker +from typing_extensions import Optional 
from redis import Redis, Sentinel from redis.asyncio import RedisCluster -from redis.backoff import ExponentialWithJitterBackoff +from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter @@ -25,25 +26,6 @@ DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) DEFAULT_AUTO_FALLBACK_INTERVAL = -1 -def default_health_checks() -> List[HealthCheck]: - return [ - EchoHealthCheck(retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF)), - ] - -def default_failure_detectors() -> List[FailureDetector]: - return [ - CommandFailureDetector(threshold=DEFAULT_FAILURES_THRESHOLD, duration=DEFAULT_FAILURES_DURATION), - ] - -def default_failover_strategy() -> FailoverStrategy: - return WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) - -def default_circuit_breaker() -> CircuitBreaker: - circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=DEFAULT_GRACE_PERIOD) - return PBCircuitBreakerAdapter(circuit_breaker) - def default_event_dispatcher() -> EventDispatcherInterface: return EventDispatcher() @@ -51,16 +33,27 @@ def default_event_dispatcher() -> EventDispatcherInterface: class DatabaseConfig: weight: float client_kwargs: dict = field(default_factory=dict) - circuit: CircuitBreaker = field(default_factory=default_circuit_breaker) + circuit: Optional[CircuitBreaker] = None + grace_period: float = DEFAULT_GRACE_PERIOD + + def default_circuit_breaker(self) -> CircuitBreaker: + circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) + return PBCircuitBreakerAdapter(circuit_breaker) @dataclass class MultiDbConfig: databases_config: List[DatabaseConfig] client_class: Type[Union[Redis, RedisCluster, Sentinel]] = Redis - failure_detectors: List[FailureDetector] 
= field(default_factory=default_failure_detectors) - health_checks: List[HealthCheck] = field(default_factory=default_health_checks) + failure_detectors: Optional[List[FailureDetector]] = None + failure_threshold: int = DEFAULT_FAILURES_THRESHOLD + failures_interval: float = DEFAULT_FAILURES_DURATION + health_checks: Optional[List[HealthCheck]] = None health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL - failover_strategy: FailoverStrategy = field(default_factory=default_failover_strategy) + health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES + health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF + failover_strategy: Optional[FailoverStrategy] = None + failover_retries: int = DEFAULT_FAILOVER_RETRIES + failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) @@ -69,10 +62,26 @@ def databases(self) -> Databases: for database_config in self.databases_config: client = self.client_class(**database_config.client_kwargs) + circuit = database_config.default_circuit_breaker() \ + if database_config.circuit is None else database_config.circuit databases.add( - Database(client=client, circuit=database_config.circuit, weight=database_config.weight), + Database(client=client, circuit=circuit, weight=database_config.weight), database_config.weight ) return databases + def default_failure_detectors(self) -> List[FailureDetector]: + return [ + CommandFailureDetector(threshold=self.failure_threshold, duration=self.failures_interval), + ] + + def default_health_checks(self) -> List[HealthCheck]: + return [ + EchoHealthCheck(retry=Retry(retries=self.health_check_retries, backoff=self.health_check_backoff)), + ] + + def default_failover_strategy(self) -> FailoverStrategy: + return WeightBasedFailoverStrategy( + retry=Retry(retries=self.failover_retries, backoff=self.failover_backoff), + ) diff 
--git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index a810eea676..f5e10591e7 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -32,12 +32,12 @@ def test_default_config(self): assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD i+=1 - assert len(config.failure_detectors) == 1 - assert isinstance(config.failure_detectors[0], CommandFailureDetector) - assert len(config.health_checks) == 1 - assert isinstance(config.health_checks[0], EchoHealthCheck) + assert len(config.default_failure_detectors()) == 1 + assert isinstance(config.default_failure_detectors()[0], CommandFailureDetector) + assert len(config.default_health_checks()) == 1 + assert isinstance(config.default_health_checks()[0], EchoHealthCheck) assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL - assert isinstance(config.failover_strategy, WeightBasedFailoverStrategy) + assert isinstance(config.default_failover_strategy(), WeightBasedFailoverStrategy) assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL def test_overridden_config(self): @@ -106,7 +106,7 @@ def test_default_config(self): assert config.client_kwargs == {'host': 'host1', 'port': 'port1'} assert config.weight == 1.0 - assert isinstance(config.circuit, PBCircuitBreakerAdapter) + assert isinstance(config.default_circuit_breaker(), PBCircuitBreakerAdapter) def test_overridden_config(self): mock_connection_pool = Mock(spec=ConnectionPool) From 708682227b29ba5d977fb45900944c44aaafe84b Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 18 Jul 2025 09:27:06 +0300 Subject: [PATCH 21/22] Refactored failure detector --- redis/event.py | 6 -- redis/multidb/client.py | 4 + redis/multidb/command_executor.py | 13 +-- redis/multidb/event.py | 16 +--- redis/multidb/failure_detector.py | 21 +++-- redis/multidb/healthcheck.py | 11 ++- tests/test_multidb/test_client.py | 1 - tests/test_multidb/test_command_executor.py | 3 +- 
tests/test_multidb/test_failure_detector.py | 87 ++++++++++++--------- 9 files changed, 84 insertions(+), 78 deletions(-) diff --git a/redis/event.py b/redis/event.py index fdb42a04d5..03480364db 100644 --- a/redis/event.py +++ b/redis/event.py @@ -259,11 +259,9 @@ def __init__( self, command: tuple, exception: Exception, - client, ): self._command = command self._exception = exception - self._client = client @property def command(self) -> tuple: @@ -273,10 +271,6 @@ def command(self) -> tuple: def exception(self) -> Exception: return self._exception - @property - def client(self): - return self._client - class ReAuthConnectionListener(EventListenerInterface): """ Listener that performs re-authentication of given connection. diff --git a/redis/multidb/client.py b/redis/multidb/client.py index b565a65c96..ebf4a52f4b 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -36,6 +36,10 @@ def __init__(self, config: MultiDbConfig): event_dispatcher=self._event_dispatcher, auto_fallback_interval=self._auto_fallback_interval, ) + + for fd in self._failure_detectors: + fd.set_command_executor(command_executor=self._command_executor) + self._initialized = False self._hc_lock = threading.RLock() self._bg_scheduler = BackgroundScheduler() diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 60dbeca36b..d2802f8528 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,9 +1,7 @@ -import socket from abc import ABC, abstractmethod from datetime import datetime, timedelta from typing import List, Union, Optional -from redis.exceptions import ConnectionError, TimeoutError from redis.event import EventDispatcherInterface, OnCommandFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, AbstractDatabase, Databases @@ -139,14 +137,7 @@ def execute_command(self, *args, **options): self._active_database = self._failover_strategy.database 
self._schedule_next_fallback() - try: - return self._active_database.client.execute_command(*args, **options) - except (ConnectionError, TimeoutError, socket.timeout) as e: - # Register command failure - self._event_dispatcher.dispatch(OnCommandFailEvent(args, e, self.active_database.client)) - - # Retry until failure detector will trigger opening of circuit - return self.execute_command(*args, **options) + return self._active_database.client.execute_command(*args, **options) def _schedule_next_fallback(self) -> None: if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: @@ -158,7 +149,7 @@ def _setup_event_dispatcher(self): """ Registers command failure event listener. """ - event_listener = RegisterCommandFailure(self._failure_detectors, self._databases) + event_listener = RegisterCommandFailure(self._failure_detectors) self._event_dispatcher.register_listeners({ OnCommandFailEvent: [event_listener], }) \ No newline at end of file diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 3d366dab77..3a5ed3ec24 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,7 +1,6 @@ from typing import List from redis.event import EventListenerInterface, OnCommandFailEvent -from redis.multidb.config import Databases from redis.multidb.failure_detector import FailureDetector @@ -9,20 +8,9 @@ class RegisterCommandFailure(EventListenerInterface): """ Event listener that registers command failures and passing it to the failure detectors. 
""" - def __init__(self, failure_detectors: List[FailureDetector], databases: Databases): + def __init__(self, failure_detectors: List[FailureDetector]): self._failure_detectors = failure_detectors - self._databases = databases def listen(self, event: OnCommandFailEvent) -> None: - matching_database = None - - for database, _ in self._databases: - if event.client == database.client: - matching_database = database - break - - if matching_database is None: - return - for failure_detector in self._failure_detectors: - failure_detector.register_failure(matching_database, event.exception, event.command) + failure_detector.register_failure(event.exception, event.command) diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index 7cb5d5db07..49b2bdd518 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -7,13 +7,19 @@ from redis.multidb.circuit import State as CBState + class FailureDetector(ABC): @abstractmethod - def register_failure(self, database, exception: Exception, cmd: tuple) -> None: + def register_failure(self, exception: Exception, cmd: tuple) -> None: """Register a failure that occurred during command execution.""" pass + @abstractmethod + def set_command_executor(self, command_executor) -> None: + """Set the command executor for this failure.""" + pass + class CommandFailureDetector(FailureDetector): """ Detects a failure based on a threshold of failed commands during a specific period of time. @@ -30,6 +36,7 @@ def __init__( :param duration: Interval in seconds after which database will be marked as failed if threshold was exceeded. :param error_types: List of exception that has to be registered. By default, all exceptions are registered. 
""" + self._command_executor = None self._threshold = threshold self._duration = duration self._error_types = error_types @@ -38,7 +45,7 @@ def __init__( self._failures_within_duration: List[tuple[datetime, tuple]] = [] self._lock = threading.RLock() - def register_failure(self, database, exception: Exception, cmd: tuple) -> None: + def register_failure(self, exception: Exception, cmd: tuple) -> None: failure_time = datetime.now() if not self._start_time < failure_time < self._end_time: @@ -51,12 +58,16 @@ def register_failure(self, database, exception: Exception, cmd: tuple) -> None: else: self._failures_within_duration.append((datetime.now(), cmd)) - self._check_threshold(database) + self._check_threshold() + + def set_command_executor(self, command_executor) -> None: + self._command_executor = command_executor - def _check_threshold(self, database): + def _check_threshold(self): with self._lock: if len(self._failures_within_duration) >= self._threshold: - database.circuit.state = CBState.OPEN + if self._command_executor and self._command_executor.active_database: + self._command_executor.active_database.circuit.state = CBState.OPEN self._reset() def _reset(self) -> None: diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index a96b9cf815..4d8ec5d395 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -43,10 +43,13 @@ def __init__( retry=retry, ) def check_health(self, database) -> bool: - return self._retry.call_with_retry( - lambda: self._returns_echoed_message(database), - lambda _: self._dummy_fail() - ) + try: + return self._retry.call_with_retry( + lambda: self._returns_echoed_message(database), + lambda _: self._dummy_fail() + ) + except Exception: + return False def _returns_echoed_message(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index c2ade264b5..b94c4ce61e 100644 --- 
a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -458,7 +458,6 @@ def test_add_new_failure_detector( command_fail_event = OnCommandFailEvent( command=('SET', 'key', 'value'), exception=Exception(), - client=mock_db1.client ) for hc in mock_multi_db_config.health_checks: diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py index 54c6d38e1d..e0790c3635 100644 --- a/tests/test_multidb/test_command_executor.py +++ b/tests/test_multidb/test_command_executor.py @@ -142,7 +142,6 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( command_fail_event = OnCommandFailEvent( command=('SET', 'key', 'value'), exception=Exception(), - client=mock_db1.client ) executor = DefaultCommandExecutor( @@ -152,6 +151,7 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( event_dispatcher=ed, auto_fallback_interval=0.1, ) + fd.set_command_executor(command_executor=executor) assert executor.execute_command('SET', 'key', 'value') == 'OK1' @@ -164,7 +164,6 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( command_fail_event = OnCommandFailEvent( command=('SET', 'key', 'value'), exception=Exception(), - client=mock_db2.client ) for i in range(threshold): diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py index 8e0c1bcbad..86d6e1cd82 100644 --- a/tests/test_multidb/test_failure_detector.py +++ b/tests/test_multidb/test_failure_detector.py @@ -1,7 +1,9 @@ from time import sleep +from unittest.mock import Mock import pytest +from redis.multidb.command_executor import CommandExecutor from redis.multidb.failure_detector import CommandFailureDetector from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -17,13 +19,16 @@ class TestCommandFailureDetector: ) def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): 
fd = CommandFailureDetector(5, 1) + mock_ce = Mock(spec=CommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) assert mock_db.circuit.state == CBState.OPEN @@ -36,12 +41,15 @@ def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exce ) def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): fd = CommandFailureDetector(5, 1) + mock_ce = Mock(spec=CommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) assert mock_db.circuit.state == CBState.CLOSED @@ -54,25 +62,28 @@ def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interv ) def 
test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): fd = CommandFailureDetector(5, 0.3) + mock_ce = Mock(spec=CommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) sleep(0.1) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) sleep(0.1) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) sleep(0.1) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) sleep(0.1) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) assert mock_db.circuit.state == CBState.CLOSED # 4 more failure as last one already refreshed timer - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) assert mock_db.circuit.state == CBState.OPEN @@ -85,23 +96,26 @@ def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_e ) def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): fd = CommandFailureDetector(5, 0.3) + mock_ce = Mock(spec=CommandExecutor) + mock_ce.active_database = mock_db + 
fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) sleep(0.4) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) assert mock_db.circuit.state == CBState.OPEN @@ -114,18 +128,21 @@ def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): ) def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): fd = CommandFailureDetector(5, 1, error_types=[ConnectionError]) + mock_ce = Mock(spec=CommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) - 
fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) assert mock_db.circuit.state == CBState.OPEN \ No newline at end of file From 2561d6f6ae655331a54b9056c9522e3bbad4eb0d Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 18 Jul 2025 10:19:45 +0300 Subject: [PATCH 22/22] Refactored retry logic --- redis/multidb/client.py | 1 + redis/multidb/command_executor.py | 34 ++++++++++++++++-- redis/multidb/config.py | 10 +++++- tests/test_multidb/test_command_executor.py | 39 +++++++-------------- tests/test_multidb/test_config.py | 3 ++ 5 files changed, 58 insertions(+), 29 deletions(-) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index ebf4a52f4b..d1e9438624 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -32,6 +32,7 @@ def __init__(self, config: MultiDbConfig): self._command_executor = DefaultCommandExecutor( failure_detectors=self._failure_detectors, databases=self._databases, + command_retry=config.command_retry, failover_strategy=self._failover_strategy, 
event_dispatcher=self._event_dispatcher, auto_fallback_interval=self._auto_fallback_interval, diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index d2802f8528..478bed39d3 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -2,6 +2,7 @@ from datetime import datetime, timedelta from typing import List, Union, Optional +from redis.backoff import NoBackoff from redis.event import EventDispatcherInterface, OnCommandFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, AbstractDatabase, Databases @@ -9,6 +10,7 @@ from redis.multidb.event import RegisterCommandFailure from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector +from redis.retry import Retry class CommandExecutor(ABC): @@ -60,6 +62,12 @@ def auto_fallback_interval(self, auto_fallback_interval: float) -> None: """Sets auto-fallback interval.""" pass + @property + @abstractmethod + def command_retry(self) -> Retry: + """Returns command retry object.""" + pass + @abstractmethod def execute_command(self, *args, **options): """Executes a command and returns the result.""" @@ -72,6 +80,7 @@ def __init__( self, failure_detectors: List[FailureDetector], databases: Databases, + command_retry: Retry, failover_strategy: FailoverStrategy, event_dispatcher: EventDispatcherInterface, auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, @@ -86,6 +95,7 @@ def __init__( """ self._failure_detectors = failure_detectors self._databases = databases + self._command_retry = command_retry self._failover_strategy = failover_strategy self._event_dispatcher = event_dispatcher self._auto_fallback_interval = auto_fallback_interval @@ -105,6 +115,10 @@ def add_failure_detector(self, failure_detector: FailureDetector) -> None: def databases(self) -> Databases: return self._databases + @property + def command_retry(self) -> Retry: + return 
self._command_retry + @property def active_database(self) -> Optional[AbstractDatabase]: return self._active_database @@ -126,6 +140,24 @@ def auto_fallback_interval(self, auto_fallback_interval: int) -> None: self._auto_fallback_interval = auto_fallback_interval def execute_command(self, *args, **options): + self._check_active_database() + + return self._command_retry.call_with_retry( + lambda: self._execute_command(*args, **options), + lambda error: self._on_command_fail(error, *args), + ) + + def _execute_command(self, *args, **options): + self._check_active_database() + return self._active_database.client.execute_command(*args, **options) + + def _on_command_fail(self, error, *args): + self._event_dispatcher.dispatch(OnCommandFailEvent(args, error)) + + def _check_active_database(self): + """ + Checks whether the active database needs to be updated. + """ if ( self._active_database is None or self._active_database.circuit.state != CBState.CLOSED @@ -137,8 +169,6 @@ def execute_command(self, *args, **options): self._active_database = self._failover_strategy.database self._schedule_next_fallback() - return self._active_database.client.execute_command(*args, **options) - def _schedule_next_fallback(self) -> None: if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: return diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 06bd49aa44..d5ba0f864a 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -6,7 +6,7 @@ from redis import Redis, Sentinel from redis.asyncio import RedisCluster -from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff +from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter @@ -44,6 +44,9 @@ def default_circuit_breaker(self) -> CircuitBreaker: class MultiDbConfig: databases_config: 
List[DatabaseConfig] client_class: Type[Union[Redis, RedisCluster, Sentinel]] = Redis + command_retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ) failure_detectors: Optional[List[FailureDetector]] = None failure_threshold: int = DEFAULT_FAILURES_THRESHOLD failures_interval: float = DEFAULT_FAILURES_DURATION @@ -61,6 +64,11 @@ def databases(self) -> Databases: databases = WeightedList() for database_config in self.databases_config: + if database_config.client_kwargs.get("retry", None) is not None: + # Retries are handled globally through command_retry, so a retry supplied in client_kwargs + # is replaced with a zero-retry policy before the lower-level client is created. + database_config.client_kwargs.update({"retry": Retry(retries=0, backoff=NoBackoff())}) + client = self.client_class(**database_config.client_kwargs) circuit = database_config.default_circuit_breaker() \ if database_config.circuit is None else database_config.circuit diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py index e0790c3635..675f9d442f 100644 --- a/tests/test_multidb/test_command_executor.py +++ b/tests/test_multidb/test_command_executor.py @@ -3,10 +3,13 @@ import pytest -from redis.event import EventDispatcher, OnCommandFailEvent +from redis.exceptions import ConnectionError +from redis.backoff import NoBackoff +from redis.event import EventDispatcher from redis.multidb.circuit import State as CBState from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.failure_detector import CommandFailureDetector +from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list @@ -31,7 +34,8 @@ def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, m failure_detectors=[mock_fd], databases=databases, failover_strategy=mock_fs, - event_dispatcher=mock_ed + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) ) executor.active_database = mock_db1 @@ 
-65,7 +69,8 @@ def test_execute_command_automatically_select_active_database( failure_detectors=[mock_fd], databases=databases, failover_strategy=mock_fs, - event_dispatcher=mock_ed + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) ) assert executor.execute_command('SET', 'key', 'value') == 'OK1' @@ -101,6 +106,7 @@ def test_execute_command_fallback_to_another_db_after_fallback_interval( failover_strategy=mock_fs, event_dispatcher=mock_ed, auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), 0) ) assert executor.execute_command('SET', 'key', 'value') == 'OK1' @@ -129,45 +135,26 @@ def test_execute_command_fallback_to_another_db_after_fallback_interval( def test_execute_command_fallback_to_another_db_after_failure_detection( self, mock_db, mock_db1, mock_db2, mock_fs ): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + mock_db1.client.execute_command.side_effect = ['OK1', ConnectionError, ConnectionError, ConnectionError, 'OK1'] + mock_db2.client.execute_command.side_effect = ['OK2', ConnectionError, ConnectionError, ConnectionError] mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) type(mock_fs).database = mock_selector - threshold = 5 + threshold = 3 fd = CommandFailureDetector(threshold, 1) ed = EventDispatcher() databases = create_weighted_list(mock_db, mock_db1, mock_db2) - # Event fired if command against mock_db1 would fail - command_fail_event = OnCommandFailEvent( - command=('SET', 'key', 'value'), - exception=Exception(), - ) - executor = DefaultCommandExecutor( failure_detectors=[fd], databases=databases, failover_strategy=mock_fs, event_dispatcher=ed, auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), threshold), ) fd.set_command_executor(command_executor=executor) assert executor.execute_command('SET', 'key', 'value') == 'OK1' - - # Simulate failing command events that lead to a failure detection - for i in range(threshold): - 
ed.dispatch(command_fail_event) - assert executor.execute_command('SET', 'key', 'value') == 'OK2' - - command_fail_event = OnCommandFailEvent( - command=('SET', 'key', 'value'), - exception=Exception(), - ) - - for i in range(threshold): - ed.dispatch(command_fail_event) - assert executor.execute_command('SET', 'key', 'value') == 'OK1' assert mock_selector.call_count == 3 \ No newline at end of file diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index f5e10591e7..87aae701a9 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -7,6 +7,7 @@ from redis.multidb.failure_detector import CommandFailureDetector, FailureDetector from redis.multidb.healthcheck import EchoHealthCheck, HealthCheck from redis.multidb.failover import WeightBasedFailoverStrategy, FailoverStrategy +from redis.retry import Retry class TestMultiDbConfig: @@ -30,6 +31,7 @@ def test_default_config(self): assert isinstance(db, Database) assert weight == db_configs[i].weight assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD + assert db.client.get_retry() is not config.command_retry i+=1 assert len(config.default_failure_detectors()) == 1 @@ -39,6 +41,7 @@ def test_default_config(self): assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL assert isinstance(config.default_failover_strategy(), WeightBasedFailoverStrategy) assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL + assert isinstance(config.command_retry, Retry) def test_overridden_config(self): grace_period = 2