diff --git a/redis/event.py b/redis/event.py index 03480364db..bb40c32f34 100644 --- a/redis/event.py +++ b/redis/event.py @@ -251,21 +251,21 @@ def nodes(self) -> dict: def credential_provider(self) -> Union[CredentialProvider, None]: return self._credential_provider -class OnCommandFailEvent: +class OnCommandsFailEvent: """ Event fired whenever a command fails during the execution. """ def __init__( self, - command: tuple, + commands: tuple, exception: Exception, ): - self._command = command + self._commands = commands self._exception = exception @property def command(self) -> tuple: - return self._command + return self._commands @property def exception(self) -> Exception: diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 3002651486..85c719fc1a 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,6 +1,6 @@ import threading import socket -from typing import Callable +from typing import List, Any, Callable from redis.background import BackgroundScheduler from redis.exceptions import ConnectionError, TimeoutError @@ -30,23 +30,22 @@ def __init__(self, config: MultiDbConfig): self._failover_strategy.set_databases(self._databases) self._auto_fallback_interval = config.auto_fallback_interval self._event_dispatcher = config.event_dispatcher - self._command_executor = DefaultCommandExecutor( + self._command_retry = config.command_retry + self._command_retry.update_supported_errors((ConnectionRefusedError,)) + self.command_executor = DefaultCommandExecutor( failure_detectors=self._failure_detectors, databases=self._databases, - command_retry=config.command_retry, + command_retry=self._command_retry, failover_strategy=self._failover_strategy, event_dispatcher=self._event_dispatcher, auto_fallback_interval=self._auto_fallback_interval, ) - - for fd in self._failure_detectors: - fd.set_command_executor(command_executor=self._command_executor) - - self._initialized = False + self.initialized = False self._hc_lock = threading.RLock() 
self._bg_scheduler = BackgroundScheduler() + self._config = config - def _initialize(self): + def initialize(self): """ Perform initialization of databases to define their initial state. """ @@ -72,7 +71,7 @@ def raise_exception_on_failed_hc(error): # Set states according to a weights and circuit state if database.circuit.state == CBState.CLOSED and not is_active_db_found: database.state = DBState.ACTIVE - self._command_executor.active_database = database + self.command_executor.active_database = database is_active_db_found = True elif database.circuit.state == CBState.CLOSED and is_active_db_found: database.state = DBState.PASSIVE @@ -82,7 +81,7 @@ def raise_exception_on_failed_hc(error): if not is_active_db_found: raise NoValidDatabaseException('Initial connection failed - no active database found') - self._initialized = True + self.initialized = True def get_databases(self) -> Databases: """ @@ -110,7 +109,7 @@ def set_active_database(self, database: AbstractDatabase) -> None: highest_weighted_db, _ = self._databases.get_top_n(1)[0] highest_weighted_db.state = DBState.PASSIVE database.state = DBState.ACTIVE - self._command_executor.active_database = database + self.command_executor.active_database = database return raise NoValidDatabaseException('Cannot set active database, database is unhealthy') @@ -132,7 +131,7 @@ def add_database(self, database: AbstractDatabase): def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: new_database.state = DBState.ACTIVE - self._command_executor.active_database = new_database + self.command_executor.active_database = new_database highest_weight_database.state = DBState.PASSIVE def remove_database(self, database: Database): @@ -144,7 +143,7 @@ def remove_database(self, database: Database): if highest_weight <= weight and highest_weighted_db.circuit.state == 
CBState.CLOSED: highest_weighted_db.state = DBState.ACTIVE - self._command_executor.active_database = highest_weighted_db + self.command_executor.active_database = highest_weighted_db def update_database_weight(self, database: AbstractDatabase, weight: float): """ @@ -182,10 +181,25 @@ def execute_command(self, *args, **options): """ Executes a single command and return its result. """ - if not self._initialized: - self._initialize() + if not self.initialized: + self.initialize() + + return self.command_executor.execute_command(*args, **options) + + def pipeline(self): + """ + Enters into pipeline mode of the client. + """ + return Pipeline(self) - return self._command_executor.execute_command(*args, **options) + def transaction(self, func: Callable[["Pipeline"], None], *watches, **options): + """ + Executes callable as transaction. + """ + if not self.initialized: + self.initialize() + + return self.command_executor.execute_transaction(func, *watches, **options) def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None: """ @@ -207,7 +221,7 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep database.circuit.state = CBState.OPEN elif is_healthy and database.circuit.state != CBState.CLOSED: database.circuit.state = CBState.CLOSED - except (ConnectionError, TimeoutError, socket.timeout) as e: + except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError) as e: if database.circuit.state != CBState.OPEN: database.circuit.state = CBState.OPEN is_healthy = False @@ -219,7 +233,9 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep def _check_databases_health(self, on_error: Callable[[Exception], None] = None): """ Runs health checks as a recurring task. + Runs health checks against all databases. 
""" + for database, _ in self._databases: self._check_db_health(database, on_error) @@ -232,4 +248,66 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) def _half_open_circuit(circuit: CircuitBreaker): - circuit.state = CBState.HALF_OPEN \ No newline at end of file + circuit.state = CBState.HALF_OPEN + + +class Pipeline(RedisModuleCommands, CoreCommands, SentinelCommands): + """ + Pipeline implementation for multiple logical Redis databases. + """ + def __init__(self, client: MultiDBClient): + self._command_stack = [] + self._client = client + + def __enter__(self) -> "Pipeline": + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.reset() + + def __del__(self): + try: + self.reset() + except Exception: + pass + + def __len__(self) -> int: + return len(self._command_stack) + + def __bool__(self) -> bool: + """Pipeline instances should always evaluate to True""" + return True + + def reset(self) -> None: + self._command_stack = [] + + def close(self) -> None: + """Close the pipeline""" + self.reset() + + def pipeline_execute_command(self, *args, **options) -> "Pipeline": + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. 
+ """ + self._command_stack.append((args, options)) + return self + + def execute_command(self, *args, **kwargs): + return self.pipeline_execute_command(*args, **kwargs) + + def execute(self) -> List[Any]: + if not self._client.initialized: + self._client.initialize() + + try: + return self._client.command_executor.execute_pipeline(tuple(self._command_stack)) + finally: + self.reset() diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 0783f6da82..690ea49a5c 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Union, Optional +from typing import List, Union, Optional, Callable -from redis.event import EventDispatcherInterface, OnCommandFailEvent +from redis.client import Pipeline +from redis.event import EventDispatcherInterface, OnCommandsFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, AbstractDatabase, Databases from redis.multidb.circuit import State as CBState @@ -92,6 +93,9 @@ def __init__( :param auto_fallback_interval: Interval between fallback attempts. Fallback to a new database according to failover_strategy. """ + for fd in failure_detectors: + fd.set_command_executor(command_executor=self) + self._failure_detectors = failure_detectors self._databases = databases self._command_retry = command_retry @@ -139,19 +143,49 @@ def auto_fallback_interval(self, auto_fallback_interval: int) -> None: self._auto_fallback_interval = auto_fallback_interval def execute_command(self, *args, **options): - self._check_active_database() + def callback(): + return self._active_database.client.execute_command(*args, **options) + + return self._execute_with_failure_detection(callback, args) + + def execute_pipeline(self, command_stack: tuple): + """ + Executes a stack of commands in pipeline. 
+ """ + def callback(): + with self._active_database.client.pipeline() as pipe: + for command, options in command_stack: + pipe.execute_command(*command, **options) + + return pipe.execute() + + return self._execute_with_failure_detection(callback, command_stack) + + def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + """ + Executes a transaction block wrapped in callback. + """ + def callback(): + return self._active_database.client.transaction(transaction, *watches, **options) + + return self._execute_with_failure_detection(callback) + + def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): + """ + Execute a commands execution callback with failure detection. + """ + def wrapper(): + # On each retry we need to check active database as it might change. + self._check_active_database() + return callback() return self._command_retry.call_with_retry( - lambda: self._execute_command(*args, **options), - lambda error: self._on_command_fail(error, *args), + lambda: wrapper(), + lambda error: self._on_command_fail(error, *cmds), ) - def _execute_command(self, *args, **options): - self._check_active_database() - return self._active_database.client.execute_command(*args, **options) - def _on_command_fail(self, error, *args): - self._event_dispatcher.dispatch(OnCommandFailEvent(args, error)) + self._event_dispatcher.dispatch(OnCommandsFailEvent(args, error)) def _check_active_database(self): """ @@ -180,5 +214,5 @@ def _setup_event_dispatcher(self): """ event_listener = RegisterCommandFailure(self._failure_detectors) self._event_dispatcher.register_listeners({ - OnCommandFailEvent: [event_listener], + OnCommandsFailEvent: [event_listener], }) \ No newline at end of file diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 3a5ed3ec24..e86ee15358 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,6 +1,6 @@ from typing import List -from redis.event import 
EventListenerInterface, OnCommandFailEvent +from redis.event import EventListenerInterface, OnCommandsFailEvent from redis.multidb.failure_detector import FailureDetector @@ -11,6 +11,6 @@ class RegisterCommandFailure(EventListenerInterface): def __init__(self, failure_detectors: List[FailureDetector]): self._failure_detectors = failure_detectors - def listen(self, event: OnCommandFailEvent) -> None: + def listen(self, event: OnCommandsFailEvent) -> None: for failure_detector in self._failure_detectors: failure_detector.register_failure(event.exception, event.command) diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index b94c4ce61e..cf3877957f 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -4,7 +4,7 @@ import pybreaker import pytest -from redis.event import EventDispatcher, OnCommandFailEvent +from redis.event import EventDispatcher, OnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF @@ -455,8 +455,8 @@ def test_add_new_failure_detector( mock_fd = mock_multi_db_config.failure_detectors[0] # Event fired if command against mock_db1 would fail - command_fail_event = OnCommandFailEvent( - command=('SET', 'key', 'value'), + command_fail_event = OnCommandsFailEvent( + commands=('SET', 'key', 'value'), exception=Exception(), ) diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py index 675f9d442f..3661294966 100644 --- a/tests/test_multidb/test_command_executor.py +++ b/tests/test_multidb/test_command_executor.py @@ -152,7 +152,6 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( auto_fallback_interval=0.1, command_retry=Retry(NoBackoff(), threshold), ) - fd.set_command_executor(command_executor=executor) assert 
executor.execute_command('SET', 'key', 'value') == 'OK1' assert executor.execute_command('SET', 'key', 'value') == 'OK2' diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py new file mode 100644 index 0000000000..9caad235df --- /dev/null +++ b/tests/test_multidb/test_pipeline.py @@ -0,0 +1,352 @@ +from time import sleep +from unittest.mock import patch, Mock + +import pybreaker +import pytest + +from redis.event import EventDispatcher +from redis.exceptions import ConnectionError +from redis.client import Pipeline +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.client import MultiDBClient +from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ + DEFAULT_FAILOVER_BACKOFF, DEFAULT_FAILURES_THRESHOLD +from redis.multidb.failover import WeightBasedFailoverStrategy +from redis.multidb.healthcheck import EchoHealthCheck +from redis.retry import Retry +from tests.test_multidb.conftest import create_weighted_list + +def mock_pipe() -> Pipeline: + mock_pipe = Mock(spec=Pipeline) + mock_pipe.__enter__ = Mock(return_value=mock_pipe) + mock_pipe.__exit__ = Mock(return_value=None) + return mock_pipe + +class TestPipeline: + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_executes_pipeline_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + pipe = mock_pipe() + pipe.execute.return_value = ['OK1', 'value1'] + mock_db1.client.pipeline.return_value = pipe + + for hc in 
mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + pipe = client.pipeline() + pipe.set('key1', 'value1') + pipe.get('key1') + + assert pipe.execute() == ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_pipeline_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + pipe = mock_pipe() + pipe.execute.return_value = ['OK1', 'value1'] + mock_db1.client.pipeline.return_value = pipe + + for hc in mock_multi_db_config.health_checks: + hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with client.pipeline() as pipe: + pipe.set('key1', 'value1') + pipe.get('key1') + + assert pipe.execute() == ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + 
indirect=True, + ) + def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + + pipe = mock_pipe() + pipe.execute.return_value = ['OK', 'value'] + mock_db.client.pipeline.return_value = pipe + + pipe1 = mock_pipe() + pipe1.execute.return_value = ['OK1', 'value'] + mock_db1.client.pipeline.return_value = pipe1 + + pipe2 = mock_pipe() + pipe2.execute.return_value = ['OK2', 'value'] + mock_db2.client.pipeline.return_value = pipe2 + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + + with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert pipe.execute() == ['OK1', 'value'] + + sleep(0.15) + + with client.pipeline() as pipe: + pipe.set('key1', 'value') 
+ pipe.get('key1') + + assert pipe.execute() == ['OK2', 'value'] + + sleep(0.1) + + with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert pipe.execute() == ['OK', 'value'] + + sleep(0.1) + + with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert pipe.execute() == ['OK1', 'value'] + +class TestTransaction: + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_executes_transaction_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.transaction.return_value = ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert client.transaction(callback) == ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_transaction_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + 
mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.transaction.return_value = ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert client.transaction(callback) == ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] + mock_db2.client.execute_command.side_effect = 
['healthcheck', 'healthcheck', 'error', 'error'] + + mock_db.client.transaction.return_value = ['OK', 'value'] + mock_db1.client.transaction.return_value = ['OK1', 'value'] + mock_db2.client.transaction.return_value = ['OK2', 'value'] + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + + def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert client.transaction(callback) == ['OK1', 'value'] + sleep(0.15) + assert client.transaction(callback) == ['OK2', 'value'] + sleep(0.1) + assert client.transaction(callback) == ['OK', 'value'] + sleep(0.1) + assert client.transaction(callback) == ['OK1', 'value'] \ No newline at end of file diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 93e251ed4b..2b9bfc7e74 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -5,6 +5,7 @@ import pytest from redis.backoff import NoBackoff +from redis.client import Pipeline from redis.exceptions import ConnectionError from redis.retry import Retry from tests.test_scenario.conftest import get_endpoint_config @@ -16,7 +17,7 @@ def trigger_network_failure_action(fault_injector_client, event: threading.Event endpoint_config = get_endpoint_config('re-active-active') action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 3, "cluster_index": 0} + parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 1, "cluster_index": 0} ) result = fault_injector_client.trigger_action(action_request) @@ -33,6 +34,11 @@ def 
trigger_network_failure_action(fault_injector_client, event: threading.Event logger.info(f"Action completed. Status: {status_result['status']}") class TestActiveActiveStandalone: + + def teardown_method(self, method): + # Timeout so the cluster could recover from network failure. + sleep(3) + @pytest.mark.parametrize( "r_multi_db", [ @@ -47,19 +53,16 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector daemon=True, args=(fault_injector_client,event) ) - thread.start() + # Client initialized on the first command. r_multi_db.set('key', 'value') - current_active_db = r_multi_db._command_executor.active_database + thread.start() # Execute commands before network failure while not event.is_set(): assert r_multi_db.get('key') == 'value' sleep(0.1) - # Active db has been changed. - assert current_active_db != r_multi_db._command_executor.active_database - # Execute commands after network failure for _ in range(3): assert r_multi_db.get('key') == 'value' @@ -68,25 +71,136 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector @pytest.mark.parametrize( "r_multi_db", [ - { - "failure_threshold": 15, - "command_retry": Retry(NoBackoff(), retries=5), - "health_check_interval": 100, - } + {"failure_threshold": 2} + ], + indirect=True + ) + def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,event) + ) + + # Client initialized on first pipe execution. 
+ with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + sleep(0.1) + + # Execute pipeline after network failure + for _ in range(3): + with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + sleep(0.1) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"failure_threshold": 2} + ], + indirect=True + ) + def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,event) + ) + + # Client initialized on first pipe execution. 
+ pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + sleep(0.1) + + # Execute pipeline after network failure + for _ in range(3): + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + sleep(0.1) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"failure_threshold": 2} ], indirect=True ) - def test_multi_db_client_throws_error_on_retry_exceed(self, r_multi_db, fault_injector_client): + def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, args=(fault_injector_client,event) ) + + def callback(pipe: Pipeline): + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + + # Client initialized on first transaction execution. + r_multi_db.transaction(callback) thread.start() - with pytest.raises(ConnectionError): - # Retries count > failure threshold, so a client gives up earlier. 
- while not event.is_set(): - assert r_multi_db.get('key') == 'value' - sleep(0.1) \ No newline at end of file + # Execute pipeline before network failure + while not event.is_set(): + r_multi_db.transaction(callback) + sleep(0.1) + + # Execute pipeline after network failure + for _ in range(3): + r_multi_db.transaction(callback) + sleep(0.1) \ No newline at end of file