Skip to content

Fix async client safety #3512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ using `invoke standalone-tests`; similarly, RedisCluster tests can be run by usi
Each run of tests starts and stops the various dockers required. Sometimes
things get stuck, an `invoke clean` can help.

## Linting and Formatting

Call `invoke linters` to run linters without also running tests. This command will
only report issues, not fix them automatically. Run `invoke formatters` to
automatically format your code.

## Documentation

If relevant, update the code documentation, via docstrings, or in `/docs`.
Expand Down
47 changes: 45 additions & 2 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,12 @@ def __init__(
# on a set of redis commands
self._single_conn_lock = asyncio.Lock()

# When used as an async context manager, we need to increment and decrement
# a usage counter so that we can close the connection pool when no one is
# using the client.
self._usage_counter = 0
self._usage_lock = asyncio.Lock()

def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
Expand Down Expand Up @@ -594,10 +600,47 @@ def client(self) -> "Redis":
)

async def __aenter__(self: _RedisT) -> _RedisT:
return await self.initialize()
"""
Async context manager entry. Increments a usage counter so that the
connection pool is only closed (via aclose()) when no context is using
the client.
"""
await self._increment_usage()
try:
# Initialize the client (i.e. establish connection, etc.)
return await self.initialize()
except Exception:
# If initialization fails, decrement the counter to keep it in sync
await self._decrement_usage()
raise

async def _increment_usage(self) -> int:
"""
Helper coroutine to increment the usage counter while holding the lock.
Returns the new value of the usage counter.
"""
async with self._usage_lock:
self._usage_counter += 1
return self._usage_counter

async def _decrement_usage(self) -> int:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A helper method is required so we can use it in the shield().

"""
Helper coroutine to decrement the usage counter while holding the lock.
Returns the new value of the usage counter.
"""
async with self._usage_lock:
self._usage_counter -= 1
return self._usage_counter

async def __aexit__(self, exc_type, exc_value, traceback):
await self.aclose()
"""
Async context manager exit. Decrements a usage counter. If this is the
last exit (counter becomes zero), the client closes its connection pool.
"""
current_usage = await asyncio.shield(self._decrement_usage())
if current_usage == 0:
# This was the last active context, so disconnect the pool.
await asyncio.shield(self.aclose())

_DEL_MESSAGE = "Unclosed Redis client"

Expand Down
49 changes: 46 additions & 3 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,12 @@ def __init__(
self._initialize = True
self._lock: Optional[asyncio.Lock] = None

# When used as an async context manager, we need to increment and decrement
# a usage counter so that we can close the connection pool when no one is
# using the client.
self._usage_counter = 0
self._usage_lock = asyncio.Lock()

async def initialize(self) -> "RedisCluster":
"""Get all nodes from startup nodes & creates connections if not initialized."""
if self._initialize:
Expand Down Expand Up @@ -467,10 +473,47 @@ async def close(self) -> None:
await self.aclose()

async def __aenter__(self) -> "RedisCluster":
return await self.initialize()
"""
Async context manager entry. Increments a usage counter so that the
connection pool is only closed (via aclose()) when no context is using
the client.
"""
await self._increment_usage()
try:
# Initialize the client (i.e. establish connection, etc.)
return await self.initialize()
except Exception:
# If initialization fails, decrement the counter to keep it in sync
await self._decrement_usage()
raise

async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
await self.aclose()
async def _increment_usage(self) -> int:
"""
Helper coroutine to increment the usage counter while holding the lock.
Returns the new value of the usage counter.
"""
async with self._usage_lock:
self._usage_counter += 1
return self._usage_counter

async def _decrement_usage(self) -> int:
"""
Helper coroutine to decrement the usage counter while holding the lock.
Returns the new value of the usage counter.
"""
async with self._usage_lock:
self._usage_counter -= 1
return self._usage_counter

async def __aexit__(self, exc_type, exc_value, traceback):
"""
Async context manager exit. Decrements a usage counter. If this is the
last exit (counter becomes zero), the client closes its connection pool.
"""
current_usage = await asyncio.shield(self._decrement_usage())
if current_usage == 0:
# This was the last active context, so disconnect the pool.
await asyncio.shield(self.aclose())

def __await__(self) -> Generator[Any, None, "RedisCluster"]:
return self.initialize().__await__()
Expand Down
6 changes: 6 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ def linters(c):
run("ruff format --check --diff tests redis")
run("vulture redis whitelist.py --min-confidence 80")

@task
def formatters(c):
"""Format code"""
run("black --target-version py37 tests redis")
run("isort tests redis")


@task
def all_tests(c):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_asyncio/test_usage_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import asyncio

import pytest


@pytest.mark.asyncio
async def test_usage_counter(r):
async def dummy_task():
async with r:
await asyncio.sleep(0.01)

tasks = [dummy_task() for _ in range(20)]
await asyncio.gather(*tasks)

# After all tasks have completed, the usage counter should be back to zero.
assert r._usage_counter == 0
Loading