Skip to content

make RedisLockManager picklable for ray/dask compat #18018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 70 additions & 13 deletions src/integrations/prefect-redis/prefect_redis/locking.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional

from redis import Redis
from redis.asyncio import Redis as AsyncRedis
Expand Down Expand Up @@ -65,21 +65,43 @@ def __init__(
self.username = username
self.password = password
self.ssl = ssl
# Clients are initialized by _init_clients
self.client: Redis
self.async_client: AsyncRedis
self._init_clients() # Initialize clients here
self._locks: dict[str, Lock | AsyncLock] = {}

# ---------- pickling ----------
def __getstate__(self) -> dict[str, Any]:
    """Capture only the plain connection settings for pickling.

    The live Redis client objects and the cached lock instances are
    deliberately left out: they hold sockets/threads that cannot be
    pickled and are rebuilt on unpickle by ``__setstate__``.
    """
    return {
        "host": self.host,
        "port": self.port,
        "db": self.db,
        "username": self.username,
        "password": self.password,
        "ssl": self.ssl,
    }

def __setstate__(self, state: dict[str, Any]) -> None:
    """Restore connection settings and rebuild the non-picklable state.

    Fresh sync/async Redis clients are created from the restored
    settings, and the per-process lock cache starts out empty.
    """
    vars(self).update(state)
    # Clients were dropped by __getstate__; recreate them from settings.
    self._init_clients()
    self._locks = {}

# ------------------------------------

def _init_clients(self) -> None:
    """Create the synchronous and asynchronous Redis clients.

    Both clients are built from the same connection settings held on the
    instance, so this can be called from ``__init__`` as well as from
    ``__setstate__`` after unpickling.
    """
    connection_kwargs: dict[str, Any] = {
        "host": self.host,
        "port": self.port,
        "db": self.db,
        "username": self.username,
        "password": self.password,
        "ssl": self.ssl,
    }
    self.client = Redis(**connection_kwargs)
    self.async_client = AsyncRedis(**connection_kwargs)
self._locks: dict[str, Lock | AsyncLock] = {}

@staticmethod
def _lock_name_for_key(key: str) -> str:
Expand All @@ -92,8 +114,21 @@ def acquire_lock(
acquire_timeout: Optional[float] = None,
hold_timeout: Optional[float] = None,
) -> bool:
"""
Acquires a lock synchronously.

Args:
key: Unique identifier for the transaction record.
holder: Unique identifier for the holder of the lock.
acquire_timeout: Maximum time to wait for the lock to be acquired.
hold_timeout: Maximum time to hold the lock.

Returns:
True if the lock was acquired, False otherwise.
"""
lock_name = self._lock_name_for_key(key)
lock = self._locks.get(lock_name)

if lock is not None and self.is_lock_holder(key, holder):
return True
else:
Expand All @@ -112,20 +147,43 @@ async def aacquire_lock(
acquire_timeout: Optional[float] = None,
hold_timeout: Optional[float] = None,
) -> bool:
"""
Acquires a lock asynchronously.

Args:
key: Unique identifier for the transaction record.
holder: Unique identifier for the holder of the lock. Must match the
holder provided when acquiring the lock.
acquire_timeout: Maximum time to wait for the lock to be acquired.
hold_timeout: Maximum time to hold the lock.

Returns:
True if the lock was acquired, False otherwise.
"""
lock_name = self._lock_name_for_key(key)
lock = self._locks.get(lock_name)
if lock is not None and self.is_lock_holder(key, holder):
return True
else:
lock = AsyncLock(

if lock is not None and isinstance(lock, AsyncLock):
if await lock.owned() and lock.local.token == holder.encode():
return True
else:
lock = None

# Handles the case where a lock might have been released during a task retry
# If the lock doesn't exist in Redis at all, this method will succeed even if
# the holder ID doesn't match the original holder.
if lock is None:
new_lock = AsyncLock(
self.async_client, lock_name, timeout=hold_timeout, thread_local=False
)
lock_acquired = await lock.acquire(
token=holder, blocking_timeout=acquire_timeout
)
if lock_acquired:
self._locks[lock_name] = lock
return lock_acquired
lock_acquired = await new_lock.acquire(
token=holder, blocking_timeout=acquire_timeout
)
if lock_acquired:
self._locks[lock_name] = new_lock
return lock_acquired

return False

def release_lock(self, key: str, holder: str) -> None:
"""
Expand All @@ -146,7 +204,6 @@ def release_lock(self, key: str, holder: str) -> None:
lock_name = self._lock_name_for_key(key)
lock = self._locks.get(lock_name)

# If we have a lock object and we're the holder, release it
if lock is not None and self.is_lock_holder(key, holder):
lock.release()
del self._locks[lock_name]
Expand Down
100 changes: 100 additions & 0 deletions src/integrations/prefect-redis/tests/test_locking.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import pickle
import queue
import threading
import time
Expand Down Expand Up @@ -254,3 +256,101 @@ async def test_transaction_retry_lock_behavior(self, lock_manager):
# 4b. Now simulate the "release after rollback" case - SHOULD NOT raise
# This is what our fix specifically addresses
store.release_lock(key, holder="transaction1")

async def test_pickle_unpickle_and_use_lock_manager(
    self, lock_manager: RedisLockManager
):
    """
    Tests that RedisLockManager can be pickled, unpickled, and then used successfully,
    ensuring clients are re-initialized correctly on unpickle.
    """
    # With the new __init__, clients are initialized immediately.
    # So, no initial check for them being None.

    # Store original client IDs for comparison after unpickling
    original_client_id = id(lock_manager.client)
    original_async_client_id = id(lock_manager.async_client)

    # Pickle and unpickle
    pickled_manager = pickle.dumps(lock_manager)
    unpickled_manager: RedisLockManager = pickle.loads(pickled_manager)

    # Verify state after unpickling: clients should be NEW instances due to __setstate__ calling _init_clients()
    assert unpickled_manager.client is not None, (
        "Client should be re-initialized after unpickling, not None"
    )
    assert unpickled_manager.async_client is not None, (
        "Async client should be re-initialized after unpickling, not None"
    )

    # Identity comparison proves the clients were rebuilt rather than
    # carried through the pickle round-trip.
    assert id(unpickled_manager.client) != original_client_id, (
        "Client should be a NEW instance after unpickling"
    )
    assert id(unpickled_manager.async_client) != original_async_client_id, (
        "Async client should be a NEW instance after unpickling"
    )

    # _locks should be an empty dict after unpickling due to __setstate__
    assert (
        hasattr(unpickled_manager, "_locks")
        and isinstance(getattr(unpickled_manager, "_locks"), dict)
        and not getattr(unpickled_manager, "_locks")
    ), "_locks should be an empty dict after unpickling"

    # Test synchronous operations (clients are already initialized)
    sync_key = "test_sync_pickle_key"
    sync_holder = "sync_pickle_holder"

    acquired_sync = unpickled_manager.acquire_lock(
        sync_key, holder=sync_holder, acquire_timeout=1, hold_timeout=5
    )
    assert acquired_sync, "Should acquire sync lock after unpickling"
    assert unpickled_manager.client is not None, (
        "Sync client should be initialized after use"
    )
    assert unpickled_manager.is_lock_holder(sync_key, sync_holder), (
        "Should be sync lock holder"
    )
    unpickled_manager.release_lock(sync_key, sync_holder)
    assert not unpickled_manager.is_locked(sync_key), "Sync lock should be released"

    # Test asynchronous operations. NOTE(review): clients are created
    # eagerly by _init_clients on unpickle — there is no _ensure_clients
    # lazy path in this version; comment kept in sync with the code.
    async_key = "test_async_pickle_key"
    async_holder = "async_pickle_holder"
    hold_timeout_seconds = 0.2  # Use a short hold timeout for testing expiration

    acquired_async = await unpickled_manager.aacquire_lock(
        async_key,
        holder=async_holder,
        acquire_timeout=1,
        hold_timeout=hold_timeout_seconds,
    )
    assert acquired_async, "Should acquire async lock after unpickling"
    assert unpickled_manager.async_client is not None, (
        "Async client should be initialized after async use"
    )

    # Verify holder by re-acquiring (should succeed as it's the same holder and lock is fresh)
    assert await unpickled_manager.aacquire_lock(
        async_key,
        holder=async_holder,
        acquire_timeout=1,
        hold_timeout=hold_timeout_seconds,
    ), "Re-acquiring same async lock should succeed"

    # Wait for the lock to expire based on hold_timeout
    # NOTE(review): real-time sleep makes this assertion timing-sensitive;
    # the 0.1s margin may be tight on a heavily loaded CI machine.
    await asyncio.sleep(
        hold_timeout_seconds + 0.1
    )  # Wait a bit longer than the timeout

    # Verify it's released (expired) by trying to acquire with a different holder.
    new_async_holder = "new_async_pickle_holder"
    acquired_by_new = await unpickled_manager.aacquire_lock(
        async_key,
        holder=new_async_holder,
        acquire_timeout=1,
        hold_timeout=hold_timeout_seconds,
    )
    assert acquired_by_new, (
        "Should acquire async lock with new holder after original lock expires"
    )