diff --git a/redis/connection.py b/redis/connection.py index 006fa50b33..d49a8dc45c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -35,7 +35,8 @@ from .auth.token import TokenInterface from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider -from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent +from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \ + AfterConnectionCreatedEvent from .exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -53,6 +54,8 @@ MaintNotificationsConnectionHandler, MaintNotificationsPoolHandler, MaintenanceNotification, ) +from .observability.attributes import AttributeBuilder, DB_CLIENT_CONNECTION_STATE, ConnectionState, \ + DB_CLIENT_CONNECTION_POOL_NAME from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, @@ -2060,6 +2063,13 @@ def set_retry(self, retry: Retry): def re_auth_callback(self, token: TokenInterface): pass + @abstractmethod + def get_connection_count(self) -> list[tuple[int, dict]]: + """ + Returns a connection count (both idle and in use). + """ + pass + class MaintNotificationsAbstractConnectionPool: """ @@ -2635,11 +2645,17 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": "Get a connection from the pool" self._checkpid() + is_created = False + with self._lock: try: connection = self._available_connections.pop() except IndexError: + # Start timing for observability + start_time = time.monotonic() + connection = self.make_connection() + is_created = True self._in_use_connections.add(connection) try: @@ -2666,6 +2682,14 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": # leak it self.release(connection) raise + + if is_created: + self._event_dispatcher.dispatch( + AfterConnectionCreatedEvent( + connection_pool=self, + duration_seconds=time.monotonic() - start_time, + ) + ) return connection def get_encoder(self) -> Encoder: @@ -2785,6 +2809,20 @@ async def _mock(self, error: RedisError): """ pass + def get_connection_count(self) -> list[tuple[int, dict]]: + attributes = AttributeBuilder.build_base_attributes() + attributes[DB_CLIENT_CONNECTION_POOL_NAME] = repr(self) + free_connections_attributes = attributes.copy() + in_use_connections_attributes = attributes.copy() + + free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ConnectionState.IDLE.value + in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ConnectionState.USED.value + + return [ + (len(self._available_connections), free_connections_attributes), + (len(self._in_use_connections), in_use_connections_attributes), + ] + class BlockingConnectionPool(ConnectionPool): """ @@ -2917,6 +2955,7 @@ def get_connection(self, command_name=None, *keys, **options): """ # Make sure we haven't changed process. self._checkpid() + is_created = False # Try and get a connection from the pool. If one isn't available within # self.timeout then raise a ``ConnectionError``. @@ -2935,7 +2974,10 @@ def get_connection(self, command_name=None, *keys, **options): # If the ``connection`` is actually ``None`` then that's a cue to make # a new connection to add to the pool. 
if connection is None: + # Start timing for observability + start_time = time.monotonic() connection = self.make_connection() + is_created = True finally: if self._locked: try: @@ -2964,6 +3006,14 @@ def get_connection(self, command_name=None, *keys, **options): self.release(connection) raise + if is_created: + self._event_dispatcher.dispatch( + AfterConnectionCreatedEvent( + connection_pool=self, + duration_seconds=time.monotonic() - start_time, + ) + ) + return connection def release(self, connection): diff --git a/redis/event.py b/redis/event.py index 8286e0066e..86b529aaf0 100644 --- a/redis/event.py +++ b/redis/event.py @@ -3,11 +3,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Type, Union +from typing import Dict, List, Optional, Type, Union, Callable from redis.auth.token import TokenInterface from redis.credentials import CredentialProvider, StreamingCredentialProvider -from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count +from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count, \ + record_connection_create_time, init_connection_count, record_connection_relaxed_timeout, record_connection_handoff class EventListenerInterface(ABC): @@ -85,7 +86,8 @@ def __init__( ReAuthConnectionListener(), ], AfterPooledConnectionsInstantiationEvent: [ - RegisterReAuthForPooledConnections() + RegisterReAuthForPooledConnections(), + InitializeConnectionCountObservability() ], AfterSingleConnectionInstantiationEvent: [ RegisterReAuthForSingleConnection() @@ -97,6 +99,16 @@ def __init__( AsyncReAuthConnectionListener(), ], OnErrorEvent: [ExportErrorCountMetric()], + OnMaintenanceNotificationEvent: [ + ExportMaintenanceNotificationCountMetric(), + ], + AfterConnectionCreatedEvent: [ExportConnectionCreateTimeMetric()], + AfterConnectionTimeoutRelaxedEvent: [ + ExportConnectionRelaxedTimeoutMetric(), + ], + AfterConnectionHandoffEvent: [ + ExportConnectionHandoffMetric(), + ], } self._lock = threading.Lock() @@ -333,6 +345,30 @@ class OnMaintenanceNotificationEvent: notification: "MaintenanceNotification" connection: "MaintNotificationsAbstractConnection" +@dataclass +class AfterConnectionCreatedEvent: + """ + Event fired after connection is created in pool. + """ + connection_pool: "ConnectionPoolInterface" + duration_seconds: float + +@dataclass +class AfterConnectionTimeoutRelaxedEvent: + """ + Event fired after connection timeout is relaxed. + """ + connection: "MaintNotificationsAbstractConnection" + notification: "MaintenanceNotification" + relaxed: bool + +@dataclass +class AfterConnectionHandoffEvent: + """ + Event fired after connection is handed off. + """ + connection_pool: "ConnectionPoolInterface" + class AsyncOnCommandsFailEvent(OnCommandsFailEvent): pass @@ -547,4 +583,41 @@ def listen(self, event: OnMaintenanceNotificationEvent): network_peer_address=event.connection.host, network_peer_port=event.connection.port, maint_notification=repr(event.notification), - ) \ No newline at end of file + ) + +class ExportConnectionCreateTimeMetric(EventListenerInterface): + """ + Listener that exports connection create time metric. 
+ """ + def listen(self, event: AfterConnectionCreatedEvent): + record_connection_create_time( + connection_pool=event.connection_pool, + duration_seconds=event.duration_seconds, + ) + +class InitializeConnectionCountObservability(EventListenerInterface): + """ + Listener that initializes connection count observability. + """ + def listen(self, event: AfterPooledConnectionsInstantiationEvent): + init_connection_count(event.connection_pools) + +class ExportConnectionRelaxedTimeoutMetric(EventListenerInterface): + """ + Listener that exports connection relaxed timeout metric. + """ + def listen(self, event: AfterConnectionTimeoutRelaxedEvent): + record_connection_relaxed_timeout( + connection_name=repr(event.connection), + maint_notification=repr(event.notification), + relaxed=event.relaxed, + ) + +class ExportConnectionHandoffMetric(EventListenerInterface): + """ + Listener that exports connection handoff metric. + """ + def listen(self, event: AfterConnectionHandoffEvent): + record_connection_handoff( + pool_name=repr(event.connection_pool), + ) diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index be6144352c..0c64118855 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -7,7 +7,8 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Literal, Optional, Union -from redis.event import OnMaintenanceNotificationEvent +from redis.event import OnMaintenanceNotificationEvent, EventDispatcherInterface, EventDispatcher, \ + AfterConnectionTimeoutRelaxedEvent, AfterConnectionHandoffEvent from redis.typing import Number @@ -560,6 +561,7 @@ def __init__( self, pool: "MaintNotificationsAbstractConnectionPool", config: MaintNotificationsConfig, + event_dispatcher: Optional[EventDispatcherInterface] = None, ) -> None: self.pool = pool self.config = config @@ -567,6 +569,11 @@ def __init__( self._lock = threading.RLock() self.connection = None + if event_dispatcher is not None: + self.event_dispatcher = event_dispatcher + else: + self.event_dispatcher = EventDispatcher() + def set_connection(self, connection: "MaintNotificationsAbstractConnection"): self.connection = connection @@ -683,6 +690,12 @@ def handle_node_moving_notification(self, notification: NodeMovingNotification): args=(notification,), ).start() + self.event_dispatcher.dispatch( + AfterConnectionHandoffEvent( + connection_pool=self.pool, + ) + ) + self._processed_notifications.add(notification) def run_proactive_reconnect(self, moving_address_src: Optional[str] = None): @@ -784,12 +797,12 @@ def handle_notification(self, notification: MaintenanceNotification): return if notification_type: - self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE) + self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE, notification=notification) else: - self.handle_maintenance_completed_notification() + self.handle_maintenance_completed_notification(notification=notification) def handle_maintenance_start_notification( - self, maintenance_state: MaintenanceState + self, maintenance_state: MaintenanceState, **kwargs ): if ( self.connection.maintenance_state == MaintenanceState.MOVING @@ -804,7 +817,16 @@ def handle_maintenance_start_notification( # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relaxed_timeout) - def handle_maintenance_completed_notification(self): + if kwargs.get('notification', None) is not None: + self.connection.event_dispatcher.dispatch( + AfterConnectionTimeoutRelaxedEvent( + 
connection=self.connection, + notification=kwargs.get('notification'), + relaxed=True, + ) + ) + + def handle_maintenance_completed_notification(self, **kwargs): # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled if ( self.connection.maintenance_state == MaintenanceState.MOVING @@ -816,3 +838,12 @@ def handle_maintenance_completed_notification(self): # timeouts by providing -1 as the relaxed timeout self.connection.update_current_socket_timeout(-1) self.connection.maintenance_state = MaintenanceState.NONE + + if kwargs.get('notification', None) is not None: + self.connection.event_dispatcher.dispatch( + AfterConnectionTimeoutRelaxedEvent( + connection=self.connection, + notification=kwargs.get('notification'), + relaxed=False, + ) + ) diff --git a/redis/observability/attributes.py b/redis/observability/attributes.py index fc966a80bc..9493ff73da 100644 --- a/redis/observability/attributes.py +++ b/redis/observability/attributes.py @@ -31,6 +31,7 @@ # Connection pool attributes DB_CLIENT_CONNECTION_POOL_NAME = "db.client.connection.pool.name" DB_CLIENT_CONNECTION_STATE = "db.client.connection.state" +DB_CLIENT_CONNECTION_NAME = "db.client.connection.name" # Redis-specific attributes REDIS_CLIENT_LIBRARY = "redis.client.library" @@ -43,6 +44,7 @@ REDIS_CLIENT_PUBSUB_CHANNEL = "redis.client.pubsub.channel" REDIS_CLIENT_PUBSUB_SHARDED = "redis.client.pubsub.sharded" REDIS_CLIENT_ERROR_INTERNAL = "redis.client.errors.internal" +REDIS_CLIENT_ERROR_CATEGORY = "redis.client.errors.category" REDIS_CLIENT_STREAM_NAME = "redis.client.stream.name" REDIS_CLIENT_CONSUMER_GROUP = "redis.client.consumer_group" REDIS_CLIENT_CONSUMER_NAME = "redis.client.consumer_name" @@ -146,8 +148,9 @@ def build_operation_attributes( @staticmethod def build_connection_attributes( - pool_name: str, + pool_name: Optional[str] = None, connection_state: Optional[ConnectionState] = None, + connection_name: Optional[str] = None, is_pubsub: Optional[bool] = None, ) -> Dict[str, Any]: """ @@ -157,12 +160,15 @@ def build_connection_attributes( pool_name: Unique connection pool name connection_state: Connection state ('idle' or 'used') is_pubsub: Whether this is a PubSub connection + connection_name: Unique connection name Returns: Dictionary of connection pool attributes """ attrs: Dict[str, Any] = AttributeBuilder.build_base_attributes() - attrs[DB_CLIENT_CONNECTION_POOL_NAME] = pool_name + + if pool_name is not None: + attrs[DB_CLIENT_CONNECTION_POOL_NAME] = pool_name if connection_state is not None: attrs[DB_CLIENT_CONNECTION_STATE] = connection_state.value @@ -170,6 +176,9 @@ def build_connection_attributes( if is_pubsub is not None: attrs[REDIS_CLIENT_CONNECTION_PUBSUB] = is_pubsub + if connection_name is not None: + attrs[DB_CLIENT_CONNECTION_NAME] = connection_name + return attrs @staticmethod @@ -190,13 +199,18 @@ def build_error_attributes( attrs: Dict[str, Any] = {} if error_type is not None: - attrs[ERROR_TYPE] = AttributeBuilder.extract_error_type(error_type) + attrs[ERROR_TYPE] = error_type.__class__.__name__ if hasattr(error_type, "status_code") and error_type.status_code is not None: attrs[DB_RESPONSE_STATUS_CODE] = error_type.status_code else: attrs[DB_RESPONSE_STATUS_CODE] = "error" + if hasattr(error_type, "error_type") and error_type.error_type is not None: + attrs[REDIS_CLIENT_ERROR_CATEGORY] = error_type.error_type.value + else: + attrs[REDIS_CLIENT_ERROR_CATEGORY] = 'other' + if is_internal is not None: attrs[REDIS_CLIENT_ERROR_INTERNAL] = is_internal @@ -260,24 +274,6 @@ def 
build_streaming_attributes( return attrs - - @staticmethod - def extract_error_type(exception: Exception) -> str: - """ - Extract error type from an exception. - - Args: - exception: The exception that occurred - - Returns: - Error type string (exception class name) - """ - - if hasattr(exception, "error_type"): - return repr(exception) - else: - return f"other:{type(exception).__name__}" - @staticmethod def build_pool_name( server_address: str, diff --git a/redis/observability/metrics.py b/redis/observability/metrics.py index cc54bca2ef..236947b99f 100644 --- a/redis/observability/metrics.py +++ b/redis/observability/metrics.py @@ -7,7 +7,7 @@ import logging import time -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Callable, List from redis.observability.attributes import AttributeBuilder, ConnectionState, REDIS_CLIENT_CONNECTION_NOTIFICATION, \ REDIS_CLIENT_CONNECTION_CLOSE_REASON, ERROR_TYPE, PubSubDirection @@ -54,6 +54,7 @@ def __init__(self, meter: Meter, config: OTelConfig): self.meter = meter self.config = config self.attr_builder = AttributeBuilder() + self.connection_count = None # Initialize enabled metric instruments @@ -93,12 +94,6 @@ def _init_resiliency_metrics(self) -> None: def _init_connection_basic_metrics(self) -> None: """Initialize basic connection metrics.""" - self.connection_count = self.meter.create_up_down_counter( - name="db.client.connection.count", - unit="{connections}", - description="Current connections by state (idle/used)", - ) - self.connection_create_time = self.meter.create_histogram( name="db.client.connection.create_time", unit="{seconds}", @@ -252,31 +247,25 @@ def record_maint_notification_count( attrs[REDIS_CLIENT_CONNECTION_NOTIFICATION] = maint_notification self.maintenance_notifications.add(1, attributes=attrs) - def record_connection_count( + def init_connection_count( self, - count: int, - pool_name: str, - state: ConnectionState, - is_pubsub: bool, + callback: Callable, ) -> None: """ - Record current connection count by state. + Initialize observable gauge for connection count metric. Args: - count: Increment/Decrement - pool_name: Connection pool name - state: Connection state ('idle' or 'used') - is_pubsub: Whether or not the connection is pubsub + callback: Callback function to retrieve connection count """ - if not hasattr(self, "connection_count"): + if not MetricGroup.CONNECTION_BASIC in self.config.metric_groups: return - attrs = self.attr_builder.build_connection_attributes( - pool_name=pool_name, - connection_state=state, - is_pubsub=is_pubsub, + self.connection_count = self.meter.create_observable_gauge( + name="db.client.connection.count", + unit="connections", + description="Number of connections in the pool", + callbacks=[callback], ) - self.connection_count.add(count, attributes=attrs) def record_connection_timeout(self, pool_name: str) -> None: """ @@ -293,20 +282,20 @@ def record_connection_timeout(self, pool_name: str) -> None: def record_connection_create_time( self, - pool_name: str, + connection_pool: "ConnectionPoolInterface", duration_seconds: float, ) -> None: """ Record time taken to create a new connection. 
Args: - pool_name: Connection pool name + connection_pool: Connection pool implementation duration_seconds: Creation time in seconds """ if not hasattr(self, "connection_create_time"): return - attrs = self.attr_builder.build_connection_attributes(pool_name=pool_name) + attrs = self.attr_builder.build_connection_attributes(pool_name=repr(connection_pool)) self.connection_create_time.record(duration_seconds, attributes=attrs) def record_connection_wait_time( @@ -430,13 +419,18 @@ def record_connection_closed( attrs = self.attr_builder.build_connection_attributes(pool_name=pool_name) if close_reason: attrs[REDIS_CLIENT_CONNECTION_CLOSE_REASON] = close_reason - if error_type: - attrs[ERROR_TYPE] = AttributeBuilder.extract_error_type(error_type) + + attrs.update( + self.attr_builder.build_error_attributes( + error_type=error_type, + ) + ) + self.connection_closed.add(1, attributes=attrs) def record_connection_relaxed_timeout( self, - pool_name: str, + connection_name: str, maint_notification: str, relaxed: bool, ) -> None: @@ -444,14 +438,14 @@ def record_connection_relaxed_timeout( Record a connection timeout relaxation event. Args: - pool_name: Connection pool name + connection_name: Connection pool name maint_notification: Maintenance notification type relaxed: True to count up (relaxed), False to count down (unrelaxed) """ if not hasattr(self, "connection_relaxed_timeout"): return - attrs = self.attr_builder.build_connection_attributes(pool_name=pool_name) + attrs = self.attr_builder.build_connection_attributes(connection_name=connection_name) attrs[REDIS_CLIENT_CONNECTION_NOTIFICATION] = maint_notification self.connection_relaxed_timeout.add(1 if relaxed else -1, attributes=attrs) diff --git a/redis/observability/providers.py b/redis/observability/providers.py index 04732ea806..8cc36c34d8 100644 --- a/redis/observability/providers.py +++ b/redis/observability/providers.py @@ -23,12 +23,19 @@ import logging from typing import Optional -from opentelemetry.sdk.metrics import MeterProvider - from redis.observability.config import OTelConfig logger = logging.getLogger(__name__) +# Optional imports - OTel SDK may not be installed +try: + from opentelemetry.sdk.metrics import MeterProvider + + OTEL_AVAILABLE = True +except ImportError: + OTEL_AVAILABLE = False + MeterProvider = None + # Global singleton instance _global_provider_manager: Optional["OTelProviderManager"] = None diff --git a/redis/observability/recorder.py b/redis/observability/recorder.py index bf60154611..f8ef690e51 100644 --- a/redis/observability/recorder.py +++ b/redis/observability/recorder.py @@ -20,7 +20,7 @@ """ import time -from typing import Optional +from typing import Optional, Callable from redis.observability.attributes import PubSubDirection, ConnectionState from redis.observability.metrics import RedisMetricsCollector @@ -93,14 +93,14 @@ def record_operation_duration( def record_connection_create_time( - pool_name: str, + connection_pool: "ConnectionPoolInterface", duration_seconds: float, ) -> None: """ Record connection creation time. 
Args: - pool_name: Connection pool identifier + connection_pool: Connection pool implementation duration_seconds: Time taken to create connection in seconds Example: @@ -118,30 +118,21 @@ def record_connection_create_time( # try: _metrics_collector.record_connection_create_time( - pool_name=pool_name, + connection_pool=connection_pool, duration_seconds=duration_seconds, ) # except Exception: # pass -def record_connection_count( - count: int, - pool_name: str, - state: ConnectionState, - is_pubsub: bool = False, +def init_connection_count( + connection_pools: list, ) -> None: """ - Record current connection count by state. + Initialize observable gauge for connection count metric. Args: - count: Increment/Decrement - pool_name: Connection pool identifier - state: Connection state ('idle' or 'used') - is_pubsub: Whether or not the connection is pubsub - - Example: - >>> record_connection_count(1, 'ConnectionPool', 'idle', False) + connection_pools: Connection pools to collect metrics from. """ global _metrics_collector @@ -150,14 +141,20 @@ def record_connection_count( if _metrics_collector is None: return + # Lazy import + from opentelemetry.metrics import Observation + + def connection_count_callback(__): + observations = [] + for pool in connection_pools: + for count, attributes in pool.get_connection_count(): + observations.append(Observation(count, attributes)) + return observations + # try: - from redis.observability.attributes import ConnectionState - _metrics_collector.record_connection_count( - count=count, - pool_name=pool_name, - state=state, - is_pubsub=is_pubsub, - ) + _metrics_collector.init_connection_count( + callback=connection_count_callback, + ) # except Exception: # pass @@ -287,7 +284,7 @@ def record_connection_closed( def record_connection_relaxed_timeout( - pool_name: str, + connection_name: str, maint_notification: str, relaxed: bool, ) -> None: @@ -295,12 +292,12 @@ def record_connection_relaxed_timeout( Record a connection timeout relaxation event. 
Args: - pool_name: Connection pool identifier + connection_name: Connection identifier maint_notification: Maintenance notification type relaxed: True to count up (relaxed), False to count down (unrelaxed) Example: - >>> record_connection_relaxed_timeout('ConnectionPool', 'MOVING', True) + >>> record_connection_relaxed_timeout('Connection', 'MOVING', True) """ global _metrics_collector @@ -310,11 +307,11 @@ def record_connection_relaxed_timeout( return # try: - _metrics_collector.record_connection_relaxed_timeout( - pool_name=pool_name, - maint_notification=maint_notification, - relaxed=relaxed, - ) + _metrics_collector.record_connection_relaxed_timeout( + connection_name=connection_name, + maint_notification=maint_notification, + relaxed=relaxed, + ) # except Exception: # pass @@ -339,9 +336,9 @@ def record_connection_handoff( return # try: - _metrics_collector.record_connection_handoff( - pool_name=pool_name, - ) + _metrics_collector.record_connection_handoff( + pool_name=pool_name, + ) # except Exception: # pass diff --git a/tests/test_client.py b/tests/test_client.py index 2ba17a452d..9e2c2ee789 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,17 @@ from unittest import mock +from unittest.mock import MagicMock import pytest import redis -from redis.event import EventDispatcher, OnErrorEvent, EventListenerInterface +from redis.event import ( + AfterPooledConnectionsInstantiationEvent, + ClientType, + EventDispatcher, + EventListenerInterface, + InitializeConnectionCountObservability, + OnErrorEvent, +) from redis.observability import recorder from redis.observability.config import OTelConfig, MetricGroup from redis.observability.metrics import RedisMetricsCollector @@ -632,4 +640,37 @@ def listen(self, event: object): assert error_events[0].retry_attempts == 1 # Second event is from final failure (is_internal=False) - assert error_events[1].is_internal is False \ No newline at end of file + assert error_events[1].is_internal is False + +class TestInitializeConnectionCountObservabilityListener: + """ + Unit tests that verify InitializeConnectionCountObservability listener + is correctly called when Redis client is instantiated, and that + the connection pools are passed to the OTel recorder. 
+ """ + + def test_redis_client_init_calls_init_connection_count_with_pools(self): + """Test that Redis.__init__ triggers init_connection_count with connection pools.""" + mock_pool = MagicMock() + mock_pool.get_protocol.return_value = 2 + + with mock.patch( + "redis.client.ConnectionPool", return_value=mock_pool + ), mock.patch( + "redis.event.init_connection_count" + ) as mock_init_connection_count: + redis.Redis(host="localhost", port=6379) + + mock_init_connection_count.assert_called_once_with([mock_pool]) + + def test_redis_client_with_external_pool_calls_init_connection_count(self): + """Test that Redis with external pool triggers init_connection_count.""" + mock_pool = MagicMock() + mock_pool.get_protocol.return_value = 2 + + with mock.patch( + "redis.event.init_connection_count" + ) as mock_init_connection_count: + redis.Redis(connection_pool=mock_pool) + + mock_init_connection_count.assert_called_once_with([mock_pool]) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 7365c6ff13..e78fc976d3 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -4,11 +4,13 @@ from contextlib import closing from threading import Thread from unittest import mock +from unittest.mock import MagicMock import pytest import redis from redis.cache import CacheConfig from redis.connection import CacheProxyConnection, Connection, to_bool +from redis.event import AfterConnectionCreatedEvent, EventDispatcher, EventListenerInterface from redis.utils import SSL_AVAILABLE from .conftest import ( @@ -43,6 +45,9 @@ def can_read(self): def should_reconnect(self): return False + def re_auth(self): + pass + class TestConnectionPool: def get_pool( @@ -955,3 +960,163 @@ def test_health_check_in_pubsub_poll(self, r): assert wait_for_message(p) is None m.assert_called_with("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) self.assert_interval_advanced(p.connection) + +class TestConnectionPoolEventEmission: + """Tests for event emission from ConnectionPool.""" + + def test_connection_created_event_emitted_on_new_connection(self): + """Test that AfterConnectionCreatedEvent is emitted when a new connection is created.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionCreatedEvent: [listener], + }) + + pool = redis.ConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + ) + + pool.get_connection() + + listener.listen.assert_called_once() + event = listener.listen.call_args[0][0] + assert isinstance(event, AfterConnectionCreatedEvent) + assert event.connection_pool is pool + assert event.duration_seconds >= 0 + + def test_connection_created_event_not_emitted_on_reused_connection(self): + """Test that AfterConnectionCreatedEvent is NOT emitted when reusing a connection.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionCreatedEvent: [listener], + }) + + pool = redis.ConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + ) + + conn = pool.get_connection() + pool.release(conn) + + # Reset the mock to clear the first call + listener.reset_mock() + + # Get the same connection again (reused) + pool.get_connection() + + # Event should NOT be emitted for reused connection + listener.listen.assert_not_called() + + def test_connection_created_event_emitted_multiple_times_for_new_connections(self): + """Test 
that AfterConnectionCreatedEvent is emitted for each new connection.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionCreatedEvent: [listener], + }) + + pool = redis.ConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + ) + + pool.get_connection() + pool.get_connection() + + assert listener.listen.call_count == 2 + + def test_connection_created_event_contains_duration(self): + """Test that AfterConnectionCreatedEvent contains a positive duration.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionCreatedEvent: [listener], + }) + + pool = redis.ConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + ) + + pool.get_connection() + + event = listener.listen.call_args[0][0] + assert isinstance(event.duration_seconds, float) + assert event.duration_seconds >= 0 + + +class TestBlockingConnectionPoolEventEmission: + """Tests for event emission from BlockingConnectionPool.""" + + def test_connection_created_event_emitted_on_new_connection(self): + """Test that AfterConnectionCreatedEvent is emitted when a new connection is created.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionCreatedEvent: [listener], + }) + + pool = redis.BlockingConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + max_connections=10, + timeout=5, + ) + + pool.get_connection() + + listener.listen.assert_called_once() + event = listener.listen.call_args[0][0] + assert isinstance(event, AfterConnectionCreatedEvent) + assert event.connection_pool is pool + assert event.duration_seconds >= 0 + + def test_connection_created_event_not_emitted_on_reused_connection(self): + """Test that AfterConnectionCreatedEvent is NOT emitted when reusing a connection.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionCreatedEvent: [listener], + }) + + pool = redis.BlockingConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + max_connections=10, + timeout=5, + ) + + conn = pool.get_connection() + pool.release(conn) + + # Reset the mock to clear the first call + listener.reset_mock() + + # Get the same connection again (reused) + pool.get_connection() + + # Event should NOT be emitted for reused connection + listener.listen.assert_not_called() + + def test_connection_created_event_emitted_multiple_times_for_new_connections(self): + """Test that AfterConnectionCreatedEvent is emitted for each new connection.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionCreatedEvent: [listener], + }) + + pool = redis.BlockingConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + max_connections=10, + timeout=5, + ) + + pool.get_connection() + pool.get_connection() + + assert listener.listen.call_count == 2 diff --git a/tests/test_maint_notifications.py b/tests/test_maint_notifications.py index 85aa671390..9b9e514d13 100644 --- a/tests/test_maint_notifications.py +++ b/tests/test_maint_notifications.py @@ -3,6 +3,7 @@ import pytest from redis.connection import ConnectionInterface, 
MaintNotificationsAbstractConnection +from redis.event import EventDispatcher from redis.maint_notifications import ( MaintenanceNotification, @@ -894,3 +895,190 @@ def test_endpoint_type_override(self): # Test with endpoint_type set to EXTERNAL_IP config = MaintNotificationsConfig(endpoint_type=EndpointType.EXTERNAL_IP) assert config.get_endpoint_type("localhost", conn) == EndpointType.EXTERNAL_IP + +class TestMaintNotificationsEventEmission: + """ + Tests for event emission from maintenance notification handlers. + These tests verify that events are properly dispatched through the event system + and that the actual OTel recorder functions are called with correct arguments. + """ + + @patch("redis.event.record_maint_notification_count") + def test_connection_handler_calls_record_maint_notification_count( + self, mock_record_maint_notification_count + ): + """Test that handle_notification calls record_maint_notification_count via listener.""" + event_dispatcher = EventDispatcher() + + mock_connection = Mock() + mock_connection.event_dispatcher = event_dispatcher + mock_connection.maintenance_state = MaintenanceState.NONE + mock_connection.host = "localhost" + mock_connection.port = 6379 + + config = MaintNotificationsConfig(enabled=True, relaxed_timeout=20) + handler = MaintNotificationsConnectionHandler(mock_connection, config) + + notification = NodeMigratingNotification(id=1, ttl=5) + handler.handle_notification(notification) + + mock_record_maint_notification_count.assert_called_once_with( + server_address="localhost", + server_port=6379, + network_peer_address="localhost", + network_peer_port=6379, + maint_notification=repr(notification), + ) + + @patch("redis.event.record_connection_relaxed_timeout") + def test_connection_handler_calls_record_connection_relaxed_timeout_on_start( + self, mock_record_connection_relaxed_timeout + ): + """Test that handle_notification calls record_connection_relaxed_timeout with relaxed=True.""" + event_dispatcher = EventDispatcher() + + mock_connection = Mock() + mock_connection.event_dispatcher = event_dispatcher + mock_connection.maintenance_state = MaintenanceState.NONE + + config = MaintNotificationsConfig(enabled=True, relaxed_timeout=20) + handler = MaintNotificationsConnectionHandler(mock_connection, config) + + notification = NodeMigratingNotification(id=1, ttl=5) + handler.handle_notification(notification) + + mock_record_connection_relaxed_timeout.assert_called_once_with( + connection_name=repr(mock_connection), + maint_notification=repr(notification), + relaxed=True, + ) + + @patch("redis.event.record_connection_relaxed_timeout") + def test_connection_handler_calls_record_connection_relaxed_timeout_on_complete( + self, mock_record_connection_relaxed_timeout + ): + """Test that handle_notification calls record_connection_relaxed_timeout with relaxed=False.""" + event_dispatcher = EventDispatcher() + + mock_connection = Mock() + mock_connection.event_dispatcher = event_dispatcher + mock_connection.maintenance_state = MaintenanceState.MAINTENANCE + + config = MaintNotificationsConfig(enabled=True, relaxed_timeout=20) + handler = MaintNotificationsConnectionHandler(mock_connection, config) + + notification = NodeMigratedNotification(id=1) + handler.handle_notification(notification) + + mock_record_connection_relaxed_timeout.assert_called_once_with( + connection_name=repr(mock_connection), + maint_notification=repr(notification), + relaxed=False, + ) + + @patch("redis.event.record_connection_relaxed_timeout") + def 
test_connection_handler_no_relaxed_timeout_call_when_disabled( + self, mock_record_connection_relaxed_timeout + ): + """Test that record_connection_relaxed_timeout is not called when relaxed_timeout is disabled.""" + event_dispatcher = EventDispatcher() + + mock_connection = Mock() + mock_connection.event_dispatcher = event_dispatcher + mock_connection.maintenance_state = MaintenanceState.NONE + mock_connection.host = "localhost" + mock_connection.port = 6379 + + config = MaintNotificationsConfig(enabled=True, relaxed_timeout=-1) + handler = MaintNotificationsConnectionHandler(mock_connection, config) + + notification = NodeMigratingNotification(id=1, ttl=5) + handler.handle_notification(notification) + + mock_record_connection_relaxed_timeout.assert_not_called() + + @patch("redis.event.record_connection_handoff") + def test_pool_handler_calls_record_connection_handoff( + self, mock_record_connection_handoff + ): + """Test that handle_node_moving_notification calls record_connection_handoff via listener.""" + event_dispatcher = EventDispatcher() + + mock_pool = Mock() + mock_pool._lock = MagicMock() + mock_pool._lock.__enter__ = Mock(return_value=None) + mock_pool._lock.__exit__ = Mock(return_value=None) + + config = MaintNotificationsConfig( + enabled=True, proactive_reconnect=True, relaxed_timeout=20 + ) + handler = MaintNotificationsPoolHandler( + mock_pool, config, event_dispatcher=event_dispatcher + ) + + notification = NodeMovingNotification( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("threading.Timer"): + handler.handle_node_moving_notification(notification) + + mock_record_connection_handoff.assert_called_once_with( + pool_name=repr(mock_pool), + ) + + @patch("redis.event.record_connection_handoff") + def test_pool_handler_no_handoff_call_when_already_processed( + self, mock_record_connection_handoff + ): + """Test that record_connection_handoff is not called for already processed notification.""" + event_dispatcher = EventDispatcher() + + mock_pool = Mock() + mock_pool._lock = MagicMock() + mock_pool._lock.__enter__ = Mock(return_value=None) + mock_pool._lock.__exit__ = Mock(return_value=None) + + config = MaintNotificationsConfig( + enabled=True, proactive_reconnect=True, relaxed_timeout=20 + ) + handler = MaintNotificationsPoolHandler( + mock_pool, config, event_dispatcher=event_dispatcher + ) + + notification = NodeMovingNotification( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + # Add notification to processed set + handler._processed_notifications.add(notification) + + handler.handle_node_moving_notification(notification) + + mock_record_connection_handoff.assert_not_called() + + @patch("redis.event.record_connection_handoff") + def test_pool_handler_no_handoff_call_when_disabled( + self, mock_record_connection_handoff + ): + """Test that record_connection_handoff is not called when both features are disabled.""" + event_dispatcher = EventDispatcher() + + mock_pool = Mock() + mock_pool._lock = MagicMock() + mock_pool._lock.__enter__ = Mock(return_value=None) + mock_pool._lock.__exit__ = Mock(return_value=None) + + config = MaintNotificationsConfig( + enabled=True, proactive_reconnect=False, relaxed_timeout=-1 + ) + handler = MaintNotificationsPoolHandler( + mock_pool, config, event_dispatcher=event_dispatcher + ) + + notification = NodeMovingNotification( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + handler.handle_node_moving_notification(notification) + + 
mock_record_connection_handoff.assert_not_called() diff --git a/tests/test_observability/test_config.py b/tests/test_observability/test_config.py index 6d49c803dd..222ad9c830 100644 --- a/tests/test_observability/test_config.py +++ b/tests/test_observability/test_config.py @@ -57,27 +57,6 @@ def test_single_telemetry_option(self): config = OTelConfig(enabled_telemetry=[TelemetryOption.METRICS]) assert config.enabled_telemetry == TelemetryOption.METRICS - def test_multiple_telemetry_options(self): - """Test setting multiple telemetry options.""" - config = OTelConfig( - enabled_telemetry=[TelemetryOption.METRICS, TelemetryOption.TRACES] - ) - assert TelemetryOption.METRICS in config.enabled_telemetry - assert TelemetryOption.TRACES in config.enabled_telemetry - - def test_all_telemetry_options(self): - """Test setting all telemetry options.""" - config = OTelConfig( - enabled_telemetry=[ - TelemetryOption.METRICS, - TelemetryOption.TRACES, - TelemetryOption.LOGS, - ] - ) - assert TelemetryOption.METRICS in config.enabled_telemetry - assert TelemetryOption.TRACES in config.enabled_telemetry - assert TelemetryOption.LOGS in config.enabled_telemetry - def test_empty_telemetry_list_disables_all(self): """Test that empty telemetry list disables all telemetry.""" config = OTelConfig(enabled_telemetry=[]) @@ -325,24 +304,3 @@ def test_metric_group_membership_check(self): assert bool(combined & MetricGroup.RESILIENCY) assert bool(combined & MetricGroup.CONNECTION_BASIC) assert not bool(combined & MetricGroup.COMMAND) - - -class TestTelemetryOptionEnum: - """Tests for TelemetryOption IntFlag enum.""" - - def test_telemetry_option_values_are_unique(self): - """Test that all TelemetryOption values are unique powers of 2.""" - values = [ - TelemetryOption.METRICS, - TelemetryOption.TRACES, - TelemetryOption.LOGS, - ] - for value in values: - assert value & (value - 1) == 0 # Power of 2 check - - def test_telemetry_option_can_be_combined(self): - """Test that TelemetryOption values can be combined with bitwise OR.""" - combined = TelemetryOption.METRICS | TelemetryOption.TRACES - assert TelemetryOption.METRICS in combined - assert TelemetryOption.TRACES in combined - assert TelemetryOption.LOGS not in combined diff --git a/tests/test_observability/test_recorder.py b/tests/test_observability/test_recorder.py index f39bedfc39..60926be12c 100644 --- a/tests/test_observability/test_recorder.py +++ b/tests/test_observability/test_recorder.py @@ -41,12 +41,12 @@ # Streaming attributes REDIS_CLIENT_STREAM_NAME, REDIS_CLIENT_CONSUMER_GROUP, - REDIS_CLIENT_CONSUMER_NAME, + REDIS_CLIENT_CONSUMER_NAME, DB_CLIENT_CONNECTION_NAME, ) from redis.observability.config import OTelConfig, MetricGroup from redis.observability.metrics import RedisMetricsCollector from redis.observability.recorder import record_operation_duration, record_connection_create_time, \ - record_connection_count, record_connection_timeout, record_connection_wait_time, record_connection_use_time, \ + record_connection_timeout, record_connection_wait_time, record_connection_use_time, \ record_connection_closed, record_connection_relaxed_timeout, record_connection_handoff, record_error_count, \ record_pubsub_message, reset_collector, record_streaming_lag @@ -63,8 +63,10 @@ def __init__(self): self.connection_handoff = MagicMock() self.pubsub_messages = MagicMock() - # UpDownCounters + # Gauges self.connection_count = MagicMock() + + # UpDownCounters self.connection_relaxed_timeout = MagicMock() # Histograms @@ -97,9 +99,14 @@ def 
create_counter_side_effect(name, **kwargs): } return instrument_map.get(name, MagicMock()) - def create_up_down_counter_side_effect(name, **kwargs): + def create_gauge_side_effect(name, **kwargs): instrument_map = { 'db.client.connection.count': mock_instruments.connection_count, + } + return instrument_map.get(name, MagicMock()) + + def create_up_down_counter_side_effect(name, **kwargs): + instrument_map = { 'redis.client.connection.relaxed_timeout': mock_instruments.connection_relaxed_timeout, } return instrument_map.get(name, MagicMock()) @@ -115,6 +122,7 @@ def create_histogram_side_effect(name, **kwargs): return instrument_map.get(name, MagicMock()) meter.create_counter.side_effect = create_counter_side_effect + meter.create_gauge.side_effect = create_gauge_side_effect meter.create_up_down_counter.side_effect = create_up_down_counter_side_effect meter.create_histogram.side_effect = create_histogram_side_effect @@ -198,7 +206,6 @@ def test_record_operation_duration_success(self, setup_recorder): assert attrs[SERVER_PORT] == 6379 assert attrs[DB_NAMESPACE] == '0' assert attrs[DB_OPERATION_NAME] == 'SET' - assert attrs[DB_RESPONSE_STATUS_CODE] == 'ok' def test_record_operation_duration_with_error(self, setup_recorder): """Test that error information is included in attributes.""" @@ -219,7 +226,7 @@ def test_record_operation_duration_with_error(self, setup_recorder): attrs = call_args[1]['attributes'] assert attrs[DB_OPERATION_NAME] == 'GET' - assert attrs[DB_RESPONSE_STATUS_CODE] is None + assert attrs[DB_RESPONSE_STATUS_CODE] == 'error' assert attrs[ERROR_TYPE] == 'ConnectionError' @@ -232,7 +239,7 @@ def test_record_connection_create_time(self, setup_recorder): instruments = setup_recorder record_connection_create_time( - pool_name='ConnectionPool', + connection_pool='ConnectionPool', duration_seconds=0.025, ) @@ -244,55 +251,7 @@ def test_record_connection_create_time(self, setup_recorder): # Verify attributes attrs = call_args[1]['attributes'] - assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' - - -class TestRecordConnectionCount: - """Tests for record_connection_count - verifies UpDownCounter.add() calls.""" - - def test_record_connection_count_idle_increment(self, setup_recorder): - """Test incrementing idle connection count.""" - - instruments = setup_recorder - - record_connection_count( - count=1, - pool_name='ConnectionPool', - state=ConnectionState.IDLE, - is_pubsub=False, - ) - - instruments.connection_count.add.assert_called_once() - call_args = instruments.connection_count.add.call_args - - # Verify increment value - assert call_args[0][0] == 1 - - # Verify attributes - attrs = call_args[1]['attributes'] - assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' - assert attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.IDLE.value - assert attrs[REDIS_CLIENT_CONNECTION_PUBSUB] is False - - def test_record_connection_count_used_decrement(self, setup_recorder): - """Test decrementing used connection count for pubsub.""" - - instruments = setup_recorder - - record_connection_count( - count=-1, - pool_name='ConnectionPool', - state=ConnectionState.USED, - is_pubsub=True, - ) - - instruments.connection_count.add.assert_called_once() - call_args = instruments.connection_count.add.call_args - - assert call_args[0][0] == -1 - attrs = call_args[1]['attributes'] - assert attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.USED.value - assert attrs[REDIS_CLIENT_CONNECTION_PUBSUB] is True + assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == "'ConnectionPool'" 
class TestRecordConnectionTimeout: @@ -407,7 +366,7 @@ def test_record_connection_relaxed_timeout_relaxed(self, setup_recorder): instruments = setup_recorder record_connection_relaxed_timeout( - pool_name='ConnectionPool', + connection_name='Connection', maint_notification='MOVING', relaxed=True, ) @@ -418,7 +377,7 @@ def test_record_connection_relaxed_timeout_relaxed(self, setup_recorder): # relaxed=True means count up (+1) assert call_args[0][0] == 1 attrs = call_args[1]['attributes'] - assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' + assert attrs[DB_CLIENT_CONNECTION_NAME] == 'Connection' assert attrs[REDIS_CLIENT_CONNECTION_NOTIFICATION] == 'MOVING' def test_record_connection_relaxed_timeout_unrelaxed(self, setup_recorder): @@ -427,7 +386,7 @@ def test_record_connection_relaxed_timeout_unrelaxed(self, setup_recorder): instruments = setup_recorder record_connection_relaxed_timeout( - pool_name='ConnectionPool', + connection_name='ConnectionPool', maint_notification='MIGRATING', relaxed=False, ) @@ -477,6 +436,7 @@ def test_record_error_count(self, setup_recorder): network_peer_port=6379, error_type=error, retry_attempts=3, + is_internal=True, ) instruments.client_errors.add.assert_called_once() @@ -491,6 +451,80 @@ def test_record_error_count(self, setup_recorder): assert attrs[ERROR_TYPE] == 'ConnectionError' assert attrs[REDIS_CLIENT_OPERATION_RETRY_ATTEMPTS] == 3 + def test_record_error_count_with_is_internal_false(self, setup_recorder): + """Test recording error count with is_internal=False.""" + + instruments = setup_recorder + + error = TimeoutError("Connection timed out") + record_error_count( + server_address='localhost', + server_port=6379, + network_peer_address='127.0.0.1', + network_peer_port=6379, + error_type=error, + retry_attempts=2, + is_internal=False, + ) + + instruments.client_errors.add.assert_called_once() + call_args = instruments.client_errors.add.call_args + + assert call_args[0][0] == 1 + attrs = call_args[1]['attributes'] + assert attrs[ERROR_TYPE] == 'TimeoutError' + assert attrs[REDIS_CLIENT_OPERATION_RETRY_ATTEMPTS] == 2 + + +class TestRecordMaintNotificationCount: + """Tests for record_maint_notification_count - verifies Counter.add() calls.""" + + def test_record_maint_notification_count(self, setup_recorder): + """Test recording maintenance notification count with all attributes.""" + + instruments = setup_recorder + + recorder.record_maint_notification_count( + server_address='localhost', + server_port=6379, + network_peer_address='127.0.0.1', + network_peer_port=6379, + maint_notification='MOVING', + ) + + instruments.maintenance_notifications.add.assert_called_once() + call_args = instruments.maintenance_notifications.add.call_args + + assert call_args[0][0] == 1 + attrs = call_args[1]['attributes'] + assert attrs[SERVER_ADDRESS] == 'localhost' + assert attrs[SERVER_PORT] == 6379 + assert attrs[NETWORK_PEER_ADDRESS] == '127.0.0.1' + assert attrs[NETWORK_PEER_PORT] == 6379 + assert attrs[REDIS_CLIENT_CONNECTION_NOTIFICATION] == 'MOVING' + + def test_record_maint_notification_count_migrating(self, setup_recorder): + """Test recording maintenance notification count with MIGRATING type.""" + + instruments = setup_recorder + + recorder.record_maint_notification_count( + server_address='redis-primary', + server_port=6380, + network_peer_address='10.0.0.1', + network_peer_port=6380, + maint_notification='MIGRATING', + ) + + instruments.maintenance_notifications.add.assert_called_once() + call_args = 
instruments.maintenance_notifications.add.call_args + + assert call_args[0][0] == 1 + attrs = call_args[1]['attributes'] + assert attrs[SERVER_ADDRESS] == 'redis-primary' + assert attrs[SERVER_PORT] == 6380 + assert attrs[REDIS_CLIENT_CONNECTION_NOTIFICATION] == 'MIGRATING' + class TestRecordPubsubMessage: """Tests for record_pubsub_message - verifies Counter.add() calls.""" @@ -626,7 +660,6 @@ def test_all_record_functions_safe_when_disabled(self): with patch.object(recorder, '_get_or_create_collector', return_value=None): # None of these should raise recorder.record_connection_create_time('pool', 0.1) - recorder.record_connection_count(1, 'pool', ConnectionState.IDLE, False) recorder.record_connection_timeout('pool') recorder.record_connection_wait_time('pool', 0.1) recorder.record_connection_use_time('pool', 0.1) @@ -634,6 +667,7 @@ def test_all_record_functions_safe_when_disabled(self): recorder.record_connection_relaxed_timeout('pool', 'MOVING', True) recorder.record_connection_handoff('pool') recorder.record_error_count('host', 6379, '127.0.0.1', 6379, Exception(), 0) + recorder.record_maint_notification_count('host', 6379, '127.0.0.1', 6379, 'MOVING') recorder.record_pubsub_message(PubSubDirection.PUBLISH) recorder.record_streaming_lag(0.1, 'stream', 'group', 'consumer') @@ -672,9 +706,14 @@ def create_counter_side_effect(name, **kwargs): } return instrument_map.get(name, MagicMock()) - def create_up_down_counter_side_effect(name, **kwargs): + def create_gauge_side_effect(name, **kwargs): instrument_map = { 'db.client.connection.count': mock_instruments.connection_count, + } + return instrument_map.get(name, MagicMock()) + + def create_up_down_counter_side_effect(name, **kwargs): + instrument_map = { 'redis.client.connection.relaxed_timeout': mock_instruments.connection_relaxed_timeout, } return instrument_map.get(name, MagicMock()) @@ -690,6 +729,7 @@ def create_histogram_side_effect(name, **kwargs): return instrument_map.get(name, MagicMock()) mock_meter.create_counter.side_effect = create_counter_side_effect + mock_meter.create_gauge.side_effect = create_gauge_side_effect mock_meter.create_up_down_counter.side_effect = create_up_down_counter_side_effect mock_meter.create_histogram.side_effect = create_histogram_side_effect @@ -718,26 +758,6 @@ def test_record_operation_duration_no_meter_call_when_command_disabled(self): # Verify no call to the histogram's record method instruments.operation_duration.record.assert_not_called() - def test_record_connection_count_no_meter_call_when_connection_basic_disabled(self): - """Test that record_connection_count makes no Meter calls when CONNECTION_BASIC is disabled.""" - instruments = MockInstruments() - collector = self._create_collector_with_disabled_groups( - instruments, - [MetricGroup.COMMAND] # No CONNECTION_BASIC - ) - - recorder.reset_collector() - with patch.object(recorder, '_get_or_create_collector', return_value=collector): - record_connection_count( - count=1, - pool_name='test-pool', - state=ConnectionState.IDLE, - is_pubsub=False, - ) - - # Verify no call to the up_down_counter's add method - instruments.connection_count.add.assert_not_called() - def test_record_connection_create_time_no_meter_call_when_connection_basic_disabled(self): """Test that record_connection_create_time makes no Meter calls when CONNECTION_BASIC is disabled.""" instruments = MockInstruments() @@ -749,7 +769,7 @@ def test_record_connection_create_time_no_meter_call_when_connection_basic_disab recorder.reset_collector() with patch.object(recorder, 
'_get_or_create_collector', return_value=collector): record_connection_create_time( - pool_name='test-pool', + connection_pool='test-pool', duration_seconds=0.050, ) @@ -821,7 +841,7 @@ def test_record_connection_relaxed_timeout_no_meter_call_when_connection_basic_d recorder.reset_collector() with patch.object(recorder, '_get_or_create_collector', return_value=collector): record_connection_relaxed_timeout( - pool_name='test-pool', + connection_name='test-pool', maint_notification='MOVING', relaxed=True, ) @@ -889,6 +909,27 @@ def test_record_error_count_no_meter_call_when_resiliency_disabled(self): # Verify no call to the counter's add method instruments.client_errors.add.assert_not_called() + def test_record_maint_notification_count_no_meter_call_when_resiliency_disabled(self): + """Test that record_maint_notification_count makes no Meter calls when RESILIENCY group is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND] # No RESILIENCY + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + recorder.record_maint_notification_count( + server_address='localhost', + server_port=6379, + network_peer_address='127.0.0.1', + network_peer_port=6379, + maint_notification='MOVING', + ) + + # Verify no call to the counter's add method + instruments.maintenance_notifications.add.assert_not_called() + def test_all_record_functions_no_meter_calls_when_all_groups_disabled(self): """Test that all record_* functions make no Meter calls when all groups are disabled.""" instruments = MockInstruments() @@ -902,7 +943,6 @@ def test_all_record_functions_no_meter_calls_when_all_groups_disabled(self): # Call all record functions record_operation_duration('GET', 0.001, 'localhost', 6379) record_connection_create_time('pool', 0.050) - record_connection_count(1, 'pool', ConnectionState.IDLE, False) record_connection_timeout('pool') record_connection_wait_time('pool', 0.010) record_connection_use_time('pool', 0.100) @@ -910,13 +950,14 @@ def test_all_record_functions_no_meter_calls_when_all_groups_disabled(self): record_connection_relaxed_timeout('pool', 'MOVING', True) record_connection_handoff('pool') record_error_count('localhost', 6379, '127.0.0.1', 6379, Exception('err'), 0) + recorder.record_maint_notification_count('localhost', 6379, '127.0.0.1', 6379, 'MOVING') record_pubsub_message(PubSubDirection.PUBLISH, 'channel') record_streaming_lag(0.150, 'stream', 'group', 'consumer') # Verify no Meter instrument methods were called instruments.operation_duration.record.assert_not_called() instruments.connection_create_time.record.assert_not_called() - instruments.connection_count.add.assert_not_called() + instruments.connection_count.set.assert_not_called() instruments.connection_timeouts.add.assert_not_called() instruments.connection_wait_time.record.assert_not_called() instruments.connection_use_time.record.assert_not_called() @@ -924,6 +965,7 @@ def test_all_record_functions_no_meter_calls_when_all_groups_disabled(self): instruments.connection_relaxed_timeout.add.assert_not_called() instruments.connection_handoff.add.assert_not_called() instruments.client_errors.add.assert_not_called() + instruments.maintenance_notifications.add.assert_not_called() instruments.pubsub_messages.add.assert_not_called() instruments.stream_lag.record.assert_not_called() @@ -942,8 +984,8 @@ def test_enabled_group_receives_meter_calls_disabled_group_does_not(self): 
         record_pubsub_message(PubSubDirection.PUBLISH, 'channel')
         # Call functions from disabled groups
-        record_connection_count(1, 'pool', ConnectionState.IDLE, False)
         record_error_count('localhost', 6379, '127.0.0.1', 6379, Exception('err'), 0)
+        recorder.record_maint_notification_count('localhost', 6379, '127.0.0.1', 6379, 'MOVING')
         record_streaming_lag(0.150, 'stream', 'group', 'consumer')
         # Enabled groups should have received Meter calls
         instruments.operation_duration.record.assert_called_once()
         instruments.pubsub_messages.add.assert_called_once()
         # Disabled groups should NOT have received Meter calls
-        instruments.connection_count.add.assert_not_called()
+        instruments.connection_count.set.assert_not_called()
         instruments.client_errors.add.assert_not_called()
+        instruments.maintenance_notifications.add.assert_not_called()
         instruments.stream_lag.record.assert_not_called()
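
Usage sketch (not part of the patch): the new events can be exercised in isolation, without a live server or an OpenTelemetry exporter. DemoPool and PrintListener below are hypothetical stand-ins; only the event classes and dispatcher API added above are assumed, and the default Export* listeners are expected to fall through to no-ops while observability is unconfigured, per the None-collector guards in redis/observability/recorder.py.

# Sketch: register a custom listener for the events introduced in this patch
# and dispatch them by hand. DemoPool and PrintListener are illustrative
# stand-ins; no Redis server or OTel SDK is required.
from redis.event import (
    AfterConnectionCreatedEvent,
    AfterConnectionHandoffEvent,
    EventDispatcher,
    EventListenerInterface,
)


class DemoPool:
    """Stand-in pool object; the listeners only need a usable repr()."""

    def __repr__(self):
        return "DemoPool(host=localhost,port=6379)"


class PrintListener(EventListenerInterface):
    def listen(self, event):
        if isinstance(event, AfterConnectionCreatedEvent):
            print(f"created in {event.duration_seconds:.6f}s on {event.connection_pool!r}")
        else:
            print(f"handoff on {event.connection_pool!r}")


dispatcher = EventDispatcher()
dispatcher.register_listeners(
    {
        AfterConnectionCreatedEvent: [PrintListener()],
        AfterConnectionHandoffEvent: [PrintListener()],
    }
)

pool = DemoPool()
dispatcher.dispatch(AfterConnectionCreatedEvent(connection_pool=pool, duration_seconds=0.004))
dispatcher.dispatch(AfterConnectionHandoffEvent(connection_pool=pool))

In normal operation the same events are produced by ConnectionPool.get_connection() / BlockingConnectionPool.get_connection() (creation timing) and by MaintNotificationsPoolHandler.handle_node_moving_notification() (handoff), as wired in the hunks above.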
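
The connection-count side works differently: instead of an UpDownCounter updated on every checkout, init_connection_count() registers an observable gauge whose callback polls each pool's get_connection_count() at collection time. A minimal sketch of that pattern, assuming opentelemetry-api is installed (an SDK MeterProvider is additionally needed for real export) and using a hypothetical FakePool in place of redis.ConnectionPool:

# Illustrative sketch of the observable-gauge pattern used by
# init_connection_count(): the SDK pulls pool sizes via a callback on each
# collection cycle. FakePool is a hypothetical stand-in for a pool that
# implements the new get_connection_count() method.
from opentelemetry import metrics
from opentelemetry.metrics import CallbackOptions, Observation


class FakePool:
    def get_connection_count(self):
        # Same shape as ConnectionPool.get_connection_count(): (count, attributes)
        return [
            (3, {"db.client.connection.state": "idle"}),
            (1, {"db.client.connection.state": "used"}),
        ]


pools = [FakePool()]


def connection_count_callback(options: CallbackOptions):
    # Invoked by the meter on every metric collection cycle.
    return [
        Observation(count, attributes)
        for pool in pools
        for count, attributes in pool.get_connection_count()
    ]


meter = metrics.get_meter("redis.observability.sketch")
meter.create_observable_gauge(
    name="db.client.connection.count",
    unit="connections",
    description="Number of connections in the pool",
    callbacks=[connection_count_callback],
)

Because the gauge is pulled rather than pushed, the pool's checkout and release paths stay free of per-operation metric calls; only connection creation, handoff, and timeout relaxation emit events.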