Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
from .auth.token import TokenInterface
from .backoff import NoBackoff
from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent
from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'OnMaintenanceNotificationEvent' is not used.

Suggested change
from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \
from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, \

Copilot uses AI. Check for mistakes.
AfterConnectionCreatedEvent
from .exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
Expand All @@ -53,6 +54,8 @@
MaintNotificationsConnectionHandler,
MaintNotificationsPoolHandler, MaintenanceNotification,
)
from .observability.attributes import AttributeBuilder, DB_CLIENT_CONNECTION_STATE, ConnectionState, \
DB_CLIENT_CONNECTION_POOL_NAME
from .retry import Retry
from .utils import (
CRYPTOGRAPHY_AVAILABLE,
Expand Down Expand Up @@ -2060,6 +2063,13 @@ def set_retry(self, retry: Retry):
def re_auth_callback(self, token: TokenInterface):
pass

@abstractmethod
def get_connection_count(self) -> list[tuple[int, dict]]:
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type annotation uses the newer pipe union syntax list[tuple[int, dict]] which is only available in Python 3.9+. For better compatibility with older Python versions, consider using List[Tuple[int, Dict]] from the typing module instead, or verify that the project's minimum supported Python version is 3.9 or higher.

Copilot uses AI. Check for mistakes.
"""
Returns a connection count (both idle and in use).
"""
pass


class MaintNotificationsAbstractConnectionPool:
"""
Expand Down Expand Up @@ -2635,11 +2645,17 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
"Get a connection from the pool"

self._checkpid()
is_created = False

with self._lock:
try:
connection = self._available_connections.pop()
except IndexError:
# Start timing for observability
start_time = time.monotonic()

connection = self.make_connection()
is_created = True
self._in_use_connections.add(connection)

try:
Expand All @@ -2666,6 +2682,14 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
# leak it
self.release(connection)
raise

if is_created:
self._event_dispatcher.dispatch(
AfterConnectionCreatedEvent(
connection_pool=self,
duration_seconds=time.monotonic() - start_time,
)
)
return connection

def get_encoder(self) -> Encoder:
Expand Down Expand Up @@ -2785,6 +2809,20 @@ async def _mock(self, error: RedisError):
"""
pass

def get_connection_count(self) -> list[tuple[int, dict]]:
attributes = AttributeBuilder.build_base_attributes()
attributes[DB_CLIENT_CONNECTION_POOL_NAME] = repr(self)
free_connections_attributes = attributes.copy()
in_use_connections_attributes = attributes.copy()

free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ConnectionState.IDLE.value
in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ConnectionState.USED.value

return [
(len(self._available_connections), free_connections_attributes),
(len(self._in_use_connections), in_use_connections_attributes),
Comment on lines +2821 to +2823
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The get_connection_count method accesses self._available_connections and self._in_use_connections without holding the self._lock. This creates a potential race condition where the connection lists could be modified by another thread during the length calculation, leading to inconsistent or incorrect counts. Consider acquiring the lock before accessing these collections, similar to how other methods like _get_free_connections and _get_in_use_connections do.

Suggested change
return [
(len(self._available_connections), free_connections_attributes),
(len(self._in_use_connections), in_use_connections_attributes),
with self._lock:
free_count = len(self._available_connections)
in_use_count = len(self._in_use_connections)
return [
(free_count, free_connections_attributes),
(in_use_count, in_use_connections_attributes),

Copilot uses AI. Check for mistakes.
]


class BlockingConnectionPool(ConnectionPool):
"""
Expand Down Expand Up @@ -2917,6 +2955,7 @@ def get_connection(self, command_name=None, *keys, **options):
"""
# Make sure we haven't changed process.
self._checkpid()
is_created = False

# Try and get a connection from the pool. If one isn't available within
# self.timeout then raise a ``ConnectionError``.
Expand All @@ -2935,7 +2974,10 @@ def get_connection(self, command_name=None, *keys, **options):
# If the ``connection`` is actually ``None`` then that's a cue to make
# a new connection to add to the pool.
if connection is None:
# Start timing for observability
start_time = time.monotonic()
connection = self.make_connection()
is_created = True
finally:
if self._locked:
try:
Expand Down Expand Up @@ -2964,6 +3006,14 @@ def get_connection(self, command_name=None, *keys, **options):
self.release(connection)
raise

if is_created:
self._event_dispatcher.dispatch(
AfterConnectionCreatedEvent(
connection_pool=self,
duration_seconds=time.monotonic() - start_time,
)
)

return connection

def release(self, connection):
Expand Down
81 changes: 77 additions & 4 deletions redis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Type, Union
from typing import Dict, List, Optional, Type, Union, Callable
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'Callable' is not used.

Suggested change
from typing import Dict, List, Optional, Type, Union, Callable
from typing import Dict, List, Optional, Type, Union

Copilot uses AI. Check for mistakes.

from redis.auth.token import TokenInterface
from redis.credentials import CredentialProvider, StreamingCredentialProvider
from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count
from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count, \
record_connection_create_time, init_connection_count, record_connection_relaxed_timeout, record_connection_handoff


class EventListenerInterface(ABC):
Expand Down Expand Up @@ -85,7 +86,8 @@ def __init__(
ReAuthConnectionListener(),
],
AfterPooledConnectionsInstantiationEvent: [
RegisterReAuthForPooledConnections()
RegisterReAuthForPooledConnections(),
InitializeConnectionCountObservability()
],
AfterSingleConnectionInstantiationEvent: [
RegisterReAuthForSingleConnection()
Expand All @@ -97,6 +99,16 @@ def __init__(
AsyncReAuthConnectionListener(),
],
OnErrorEvent: [ExportErrorCountMetric()],
OnMaintenanceNotificationEvent: [
ExportMaintenanceNotificationCountMetric(),
],
AfterConnectionCreatedEvent: [ExportConnectionCreateTimeMetric()],
AfterConnectionTimeoutRelaxedEvent: [
ExportConnectionRelaxedTimeoutMetric(),
],
AfterConnectionHandoffEvent: [
ExportConnectionHandoffMetric(),
],
}

self._lock = threading.Lock()
Expand Down Expand Up @@ -333,6 +345,30 @@ class OnMaintenanceNotificationEvent:
notification: "MaintenanceNotification"
connection: "MaintNotificationsAbstractConnection"

@dataclass
class AfterConnectionCreatedEvent:
"""
Event fired after connection is created in pool.
"""
connection_pool: "ConnectionPoolInterface"
duration_seconds: float

@dataclass
class AfterConnectionTimeoutRelaxedEvent:
"""
Event fired after connection timeout is relaxed.
"""
connection: "MaintNotificationsAbstractConnection"
notification: "MaintenanceNotification"
relaxed: bool

@dataclass
class AfterConnectionHandoffEvent:
"""
Event fired after connection is handed off.
"""
connection_pool: "ConnectionPoolInterface"

class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
pass

Expand Down Expand Up @@ -547,4 +583,41 @@ def listen(self, event: OnMaintenanceNotificationEvent):
network_peer_address=event.connection.host,
network_peer_port=event.connection.port,
maint_notification=repr(event.notification),
)
)

class ExportConnectionCreateTimeMetric(EventListenerInterface):
"""
Listener that exports connection create time metric.
"""
def listen(self, event: AfterConnectionCreatedEvent):
record_connection_create_time(
connection_pool=event.connection_pool,
duration_seconds=event.duration_seconds,
)

class InitializeConnectionCountObservability(EventListenerInterface):
"""
Listener that initializes connection count observability.
"""
def listen(self, event: AfterPooledConnectionsInstantiationEvent):
init_connection_count(event.connection_pools)

class ExportConnectionRelaxedTimeoutMetric(EventListenerInterface):
"""
Listener that exports connection relaxed timeout metric.
"""
def listen(self, event: AfterConnectionTimeoutRelaxedEvent):
record_connection_relaxed_timeout(
connection_name=repr(event.connection),
maint_notification=repr(event.notification),
relaxed=event.relaxed,
)

class ExportConnectionHandoffMetric(EventListenerInterface):
"""
Listener that exports connection handoff metric.
"""
def listen(self, event: AfterConnectionHandoffEvent):
record_connection_handoff(
pool_name=repr(event.connection_pool),
)
41 changes: 36 additions & 5 deletions redis/maint_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Optional, Union

from redis.event import OnMaintenanceNotificationEvent
from redis.event import OnMaintenanceNotificationEvent, EventDispatcherInterface, EventDispatcher, \
AfterConnectionTimeoutRelaxedEvent, AfterConnectionHandoffEvent
from redis.typing import Number


Expand Down Expand Up @@ -560,13 +561,19 @@ def __init__(
self,
pool: "MaintNotificationsAbstractConnectionPool",
config: MaintNotificationsConfig,
event_dispatcher: Optional[EventDispatcherInterface] = None,
) -> None:
self.pool = pool
self.config = config
self._processed_notifications = set()
self._lock = threading.RLock()
self.connection = None

if event_dispatcher is not None:
self.event_dispatcher = event_dispatcher
else:
self.event_dispatcher = EventDispatcher()

def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
self.connection = connection

Expand Down Expand Up @@ -683,6 +690,12 @@ def handle_node_moving_notification(self, notification: NodeMovingNotification):
args=(notification,),
).start()

self.event_dispatcher.dispatch(
AfterConnectionHandoffEvent(
connection_pool=self.pool,
)
)

self._processed_notifications.add(notification)

def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
Expand Down Expand Up @@ -784,12 +797,12 @@ def handle_notification(self, notification: MaintenanceNotification):
return

if notification_type:
self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)
self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE, notification=notification)
else:
self.handle_maintenance_completed_notification()
self.handle_maintenance_completed_notification(notification=notification)

def handle_maintenance_start_notification(
self, maintenance_state: MaintenanceState
self, maintenance_state: MaintenanceState, **kwargs
):
if (
self.connection.maintenance_state == MaintenanceState.MOVING
Expand All @@ -804,7 +817,16 @@ def handle_maintenance_start_notification(
# extend the timeout for all created connections
self.connection.update_current_socket_timeout(self.config.relaxed_timeout)

def handle_maintenance_completed_notification(self):
if kwargs.get('notification', None) is not None:
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check kwargs.get('notification', None) is not None is redundant. The get method with a default of None already returns None if the key is missing, so checking is not None achieves the same result as just checking the truthiness. Consider simplifying to if kwargs.get('notification'): or better yet, use a more explicit parameter instead of **kwargs.

Copilot uses AI. Check for mistakes.
self.connection.event_dispatcher.dispatch(
AfterConnectionTimeoutRelaxedEvent(
connection=self.connection,
notification=kwargs.get('notification'),
relaxed=True,
)
)

def handle_maintenance_completed_notification(self, **kwargs):
# Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
if (
self.connection.maintenance_state == MaintenanceState.MOVING
Expand All @@ -816,3 +838,12 @@ def handle_maintenance_completed_notification(self):
# timeouts by providing -1 as the relaxed timeout
self.connection.update_current_socket_timeout(-1)
self.connection.maintenance_state = MaintenanceState.NONE

if kwargs.get('notification', None) is not None:
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check kwargs.get('notification', None) is not None is redundant. The get method with a default of None already returns None if the key is missing, so checking is not None achieves the same result as just checking the truthiness. Consider simplifying to if kwargs.get('notification'): or better yet, use a more explicit parameter instead of **kwargs.

Copilot uses AI. Check for mistakes.
self.connection.event_dispatcher.dispatch(
AfterConnectionTimeoutRelaxedEvent(
connection=self.connection,
notification=kwargs.get('notification'),
relaxed=False,
)
)
Loading