@@ -19,7 +19,8 @@ class RedisLockManager(LockManager):
19
19
username: The username to use when connecting to the Redis server
20
20
password: The password to use when connecting to the Redis server
21
21
ssl: Whether to use SSL when connecting to the Redis server
22
- # client and async_client are initialized lazily
22
+ client: The Redis client used to communicate with the Redis server
23
+ async_client: The asynchronous Redis client used to communicate with the Redis server
23
24
24
25
Example:
25
26
Use with a cache policy:
@@ -64,11 +65,13 @@ def __init__(
64
65
self .username = username
65
66
self .password = password
66
67
self .ssl = ssl
67
- self .client : Optional [Redis ] = None
68
- self .async_client : Optional [AsyncRedis ] = None
68
+ # Clients are initialized by _init_clients
69
+ self .client : Redis
70
+ self .async_client : AsyncRedis
71
+ self ._init_clients () # Initialize clients here
69
72
self ._locks : dict [str , Lock | AsyncLock ] = {}
70
73
71
- # ---------- pickle helpers ----------
74
+ # ---------- pickling ----------
72
75
def __getstate__ (self ) -> dict [str , Any ]:
73
76
return {
74
77
k : getattr (self , k )
@@ -77,31 +80,28 @@ def __getstate__(self) -> dict[str, Any]:
77
80
78
81
def __setstate__ (self , state : dict [str , Any ]) -> None :
79
82
self .__dict__ .update (state )
80
- self .client = None
81
- self .async_client = None
83
+ self ._init_clients () # Re-initialize clients here
82
84
self ._locks = {}
83
85
84
86
# ------------------------------------
85
87
86
- def _ensure_clients (self ) -> None :
87
- if self .client is None :
88
- self .client = Redis (
89
- host = self .host ,
90
- port = self .port ,
91
- db = self .db ,
92
- username = self .username ,
93
- password = self .password ,
94
- ssl = self .ssl ,
95
- )
96
- if self .async_client is None :
97
- self .async_client = AsyncRedis (
98
- host = self .host ,
99
- port = self .port ,
100
- db = self .db ,
101
- username = self .username ,
102
- password = self .password ,
103
- ssl = self .ssl ,
104
- )
88
+ def _init_clients (self ) -> None :
89
+ self .client = Redis (
90
+ host = self .host ,
91
+ port = self .port ,
92
+ db = self .db ,
93
+ username = self .username ,
94
+ password = self .password ,
95
+ ssl = self .ssl ,
96
+ )
97
+ self .async_client = AsyncRedis (
98
+ host = self .host ,
99
+ port = self .port ,
100
+ db = self .db ,
101
+ username = self .username ,
102
+ password = self .password ,
103
+ ssl = self .ssl ,
104
+ )
105
105
106
106
@staticmethod
107
107
def _lock_name_for_key (key : str ) -> str :
@@ -114,16 +114,24 @@ def acquire_lock(
114
114
acquire_timeout : Optional [float ] = None ,
115
115
hold_timeout : Optional [float ] = None ,
116
116
) -> bool :
117
- self ._ensure_clients ()
117
+ """
118
+ Acquires a lock synchronously.
119
+
120
+ Args:
121
+ key: Unique identifier for the transaction record.
122
+ holder: Unique identifier for the holder of the lock.
123
+ acquire_timeout: Maximum time to wait for the lock to be acquired.
124
+ hold_timeout: Maximum time to hold the lock.
125
+
126
+ Returns:
127
+ True if the lock was acquired, False otherwise.
128
+ """
118
129
lock_name = self ._lock_name_for_key (key )
119
130
lock = self ._locks .get (lock_name )
120
131
121
- if lock is not None and self .is_lock_holder (
122
- key , holder
123
- ): # is_lock_holder will also call _ensure_clients
132
+ if lock is not None and self .is_lock_holder (key , holder ):
124
133
return True
125
134
else :
126
- # If lock is None, or not held by current holder, create/acquire new one.
127
135
lock = Lock (
128
136
self .client , lock_name , timeout = hold_timeout , thread_local = False
129
137
)
@@ -139,7 +147,19 @@ async def aacquire_lock(
139
147
acquire_timeout : Optional [float ] = None ,
140
148
hold_timeout : Optional [float ] = None ,
141
149
) -> bool :
142
- self ._ensure_clients ()
150
+ """
151
+ Acquires a lock asynchronously.
152
+
153
+ Args:
154
+ key: Unique identifier for the transaction record.
155
+ holder: Unique identifier for the holder of the lock. Must match the
156
+ holder provided when acquiring the lock.
157
+ acquire_timeout: Maximum time to wait for the lock to be acquired.
158
+ hold_timeout: Maximum time to hold the lock.
159
+
160
+ Returns:
161
+ True if the lock was acquired, False otherwise.
162
+ """
143
163
lock_name = self ._lock_name_for_key (key )
144
164
lock = self ._locks .get (lock_name )
145
165
@@ -149,8 +169,10 @@ async def aacquire_lock(
149
169
else :
150
170
lock = None
151
171
172
+ # Handles the case where a lock might have been released during a task retry
173
+ # If the lock doesn't exist in Redis at all, this method will succeed even if
174
+ # the holder ID doesn't match the original holder.
152
175
if lock is None :
153
- assert self .async_client is not None , "Async client should be initialized"
154
176
new_lock = AsyncLock (
155
177
self .async_client , lock_name , timeout = hold_timeout , thread_local = False
156
178
)
@@ -164,7 +186,21 @@ async def aacquire_lock(
164
186
return False
165
187
166
188
def release_lock (self , key : str , holder : str ) -> None :
167
- self ._ensure_clients ()
189
+ """
190
+ Releases the lock on the corresponding transaction record.
191
+
192
+ Handles the case where a lock might have been released during a task retry
193
+ If the lock doesn't exist in Redis at all, this method will succeed even if
194
+ the holder ID doesn't match the original holder.
195
+
196
+ Args:
197
+ key: Unique identifier for the transaction record.
198
+ holder: Unique identifier for the holder of the lock. Must match the
199
+ holder provided when acquiring the lock.
200
+
201
+ Raises:
202
+ ValueError: If the lock is held by a different holder.
203
+ """
168
204
lock_name = self ._lock_name_for_key (key )
169
205
lock = self ._locks .get (lock_name )
170
206
@@ -173,67 +209,37 @@ def release_lock(self, key: str, holder: str) -> None:
173
209
del self ._locks [lock_name ]
174
210
return
175
211
176
- if not self .is_locked (key ): # is_locked calls _ensure_clients
212
+ # If the lock doesn't exist in Redis at all, it's already been released
213
+ if not self .is_locked (key ):
177
214
if lock_name in self ._locks :
178
215
del self ._locks [lock_name ]
179
216
return
180
217
218
+ # We have a real conflict - lock exists in Redis but with a different holder
181
219
raise ValueError (f"No lock held by { holder } for transaction with key { key } " )
182
220
183
- async def arelease_lock (self , key : str , holder : str ) -> None : # Added async version
184
- self ._ensure_clients ()
185
- lock_name = self ._lock_name_for_key (key )
186
- lock = self ._locks .get (lock_name )
187
-
188
- if lock is not None and isinstance (
189
- lock , AsyncLock
190
- ): # Still need to check if it *is* an AsyncLock to call await .owned()
191
- if await lock .owned () and lock .local .token == holder .encode ():
192
- await lock .release ()
193
- del self ._locks [lock_name ]
194
- return
195
-
196
- # Check if the lock key exists on the server at all.
197
- if not AsyncLock (self .async_client , lock_name ).locked ():
198
- # If the lock doesn't exist on the server, it's already effectively released.
199
- # Clean up from self._locks if it was there but holder didn't match.
200
- if lock_name in self ._locks :
201
- del self ._locks [lock_name ]
202
- return
203
-
204
- raise ValueError (
205
- f"No lock held by { holder } for transaction with key { key } (async)"
206
- )
207
-
208
221
def wait_for_lock (self , key : str , timeout : Optional [float ] = None ) -> bool :
209
- self ._ensure_clients ()
210
222
lock_name = self ._lock_name_for_key (key )
211
- lock = Lock (self .client , lock_name ) # Create a temporary lock for waiting
223
+ lock = Lock (self .client , lock_name )
212
224
lock_freed = lock .acquire (blocking_timeout = timeout )
213
225
if lock_freed :
214
226
lock .release ()
215
227
return lock_freed
216
228
217
229
async def await_for_lock (self , key : str , timeout : Optional [float ] = None ) -> bool :
218
- self ._ensure_clients ()
219
230
lock_name = self ._lock_name_for_key (key )
220
- assert self .async_client is not None , "Async client should be initialized"
221
- lock = AsyncLock (
222
- self .async_client , lock_name
223
- ) # Create a temporary lock for waiting
231
+ lock = AsyncLock (self .async_client , lock_name )
224
232
lock_freed = await lock .acquire (blocking_timeout = timeout )
225
233
if lock_freed :
226
234
await lock .release ()
227
235
return lock_freed
228
236
229
237
def is_locked (self , key : str ) -> bool :
230
- self ._ensure_clients ()
231
238
lock_name = self ._lock_name_for_key (key )
232
- lock = Lock (self .client , lock_name ) # Create a temporary lock for checking
239
+ lock = Lock (self .client , lock_name )
233
240
return lock .locked ()
234
241
235
242
def is_lock_holder (self , key : str , holder : str ) -> bool :
236
- self ._ensure_clients () # Ensures self.client is available if needed by _locks access logic
237
243
lock_name = self ._lock_name_for_key (key )
238
244
lock = self ._locks .get (lock_name )
239
245
if lock is None :
0 commit comments