11"""Implementation of a Key Bundle."""
2+
23import copy
34import json
45import logging
56import os
7+ import threading
68import time
79from datetime import datetime
810from functools import cmp_to_key
911from typing import List
1012from typing import Optional
1113
1214import requests
13- from readerwriterlock import rwlock
1415
1516from cryptojwt .jwk .ec import NIST2SEC
1617from cryptojwt .jwk .hmac import new_sym_key
@@ -152,14 +153,6 @@ def ec_init(spec):
152153 return _kb
153154
154155
155- def keys_reader (func ):
156- def wrapper (self , * args , ** kwargs ):
157- with self ._lock_reader :
158- return func (self , * args , ** kwargs )
159-
160- return wrapper
161-
162-
163156def keys_writer (func ):
164157 def wrapper (self , * args , ** kwargs ):
165158 with self ._lock_writer :
@@ -245,9 +238,7 @@ def __init__(
245238 self .source = None
246239 self .time_out = 0
247240
248- self ._lock = rwlock .RWLockFairD ()
249- self ._lock_reader = self ._lock .gen_rlock ()
250- self ._lock_writer = self ._lock .gen_wlock ()
241+ self ._lock_writer = threading .Lock ()
251242
252243 if httpc :
253244 self .httpc = httpc
@@ -260,11 +251,11 @@ def __init__(
260251 self .source = None
261252 if isinstance (keys , dict ):
262253 if "keys" in keys :
263- self ._do_keys (keys ["keys" ])
254+ self ._add_jwk_dicts (keys ["keys" ])
264255 else :
265- self ._do_keys ([keys ])
256+ self ._add_jwk_dicts ([keys ])
266257 else :
267- self ._do_keys (keys )
258+ self ._add_jwk_dicts (keys )
268259 else :
269260 self ._set_source (source , fileformat )
270261 if self .local :
@@ -305,18 +296,34 @@ def _local_update_required(self) -> bool:
305296 self .last_local = stat .st_mtime
306297 return True
307298
308- @keys_writer
309299 def do_keys (self , keys ):
310- return self ._do_keys (keys )
300+ """Compatibility function for add_jwk_dicts()"""
301+ self .add_jwk_dicts (keys )
311302
312- def _do_keys (self , keys ):
303+ @keys_writer
304+ def add_jwk_dicts (self , keys ):
313305 """
314- Go from JWK description to binary keys
306+ Add JWK dictionaries
315307
316- :param keys:
308+ :param keys: List of JWK dictionaries
317309 :return:
318310 """
319- _new_key = []
311+ self ._add_jwk_dicts (keys )
312+
313+ def _add_jwk_dicts (self , keys ):
314+ _new_keys = self .jwk_dicts_as_keys (keys )
315+ if _new_keys :
316+ self ._keys .extend (_new_keys )
317+ self .last_updated = time .time ()
318+
319+ def jwk_dicts_as_keys (self , keys ):
320+ """
321+ Return JWK dictionaries as list of JWK objects
322+
323+ :param keys: List of JWK dictionaries
324+ :return: List of JWK objects
325+ """
326+ _new_keys = []
320327
321328 for inst in keys :
322329 if inst ["kty" ].lower () in K2C :
@@ -360,16 +367,13 @@ def _do_keys(self, keys):
360367 if _key not in self ._keys :
361368 if not _key .kid :
362369 _key .add_kid ()
363- _new_key .append (_key )
370+ _new_keys .append (_key )
364371 _error = ""
365372
366373 if _error :
367374 LOGGER .warning ("While loading keys, %s" , _error )
368375
369- if _new_key :
370- self ._keys .extend (_new_key )
371-
372- self .last_updated = time .time ()
376+ return _new_keys
373377
374378 def _do_local_jwk (self , filename ):
375379 """
@@ -385,9 +389,9 @@ def _do_local_jwk(self, filename):
385389 with open (filename ) as input_file :
386390 _info = json .load (input_file )
387391 if "keys" in _info :
388- self ._do_keys (_info ["keys" ])
392+ self ._add_jwk_dicts (_info ["keys" ])
389393 else :
390- self ._do_keys ([_info ])
394+ self ._add_jwk_dicts ([_info ])
391395 self .last_local = time .time ()
392396 self .time_out = self .last_local + self .cache_time
393397 return True
@@ -423,12 +427,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
423427 if kid :
424428 key_args ["kid" ] = kid
425429
426- self ._do_keys ([key_args ])
430+ self ._add_jwk_dicts ([key_args ])
427431 self .last_local = time .time ()
428432 self .time_out = self .last_local + self .cache_time
429433 return True
430434
431- def do_remote (self ):
435+ def _do_remote (self ):
432436 """
433437 Load a JWKS from a webpage.
434438
@@ -458,6 +462,7 @@ def do_remote(self):
458462 LOGGER .error (err )
459463 raise UpdateFailed (REMOTE_FAILED .format (self .source , str (err )))
460464
465+ new_keys = None
461466 load_successful = _http_resp .status_code == 200
462467 not_modified = _http_resp .status_code == 304
463468
@@ -470,7 +475,7 @@ def do_remote(self):
470475
471476 LOGGER .debug ("Loaded JWKS: %s from %s" , _http_resp .text , self .source )
472477 try :
473- self ._do_keys (self .imp_jwks ["keys" ])
478+ new_keys = self .jwk_dicts_as_keys (self .imp_jwks ["keys" ])
474479 except KeyError :
475480 LOGGER .error ("No 'keys' keyword in JWKS" )
476481 self .ignore_errors_until = time .time () + self .ignore_errors_period
@@ -491,6 +496,8 @@ def do_remote(self):
491496 self .ignore_errors_until = time .time () + self .ignore_errors_period
492497 raise UpdateFailed (REMOTE_FAILED .format (self .source , _http_resp .status_code ))
493498
499+ if new_keys is not None :
500+ self ._keys = new_keys
494501 self .last_updated = time .time ()
495502 self .ignore_errors_until = None
496503 return load_successful
@@ -547,7 +554,7 @@ def update(self):
547554 elif self .fileformat == "der" :
548555 updated = self ._do_local_der (self .source , self .keytype , self .keyusage )
549556 elif self .remote :
550- updated = self .do_remote ()
557+ updated = self ._do_remote ()
551558 except Exception as err :
552559 LOGGER .error ("Key bundle update failed: %s" , err )
553560 self ._keys = _old_keys # restore
@@ -575,12 +582,11 @@ def get(self, typ="", only_active=True):
575582 """
576583 self ._uptodate ()
577584
578- with self ._lock_reader :
579- if typ :
580- _typs = [typ .lower (), typ .upper ()]
581- _keys = [k for k in self ._keys if k .kty in _typs ]
582- else :
583- _keys = copy .copy (self ._keys )
585+ if typ :
586+ _typs = [typ .lower (), typ .upper ()]
587+ _keys = [k for k in self ._keys [:] if k .kty in _typs ]
588+ else :
589+ _keys = self ._keys [:]
584590
585591 if only_active :
586592 return [k for k in _keys if not k .inactive_since ]
@@ -595,8 +601,7 @@ def keys(self, update: bool = True):
595601 """
596602 if update :
597603 self ._uptodate ()
598- with self ._lock_reader :
599- return copy .copy (self ._keys )
604+ return self ._keys [:]
600605
601606 def active_keys (self ):
602607 """Return the set of active keys."""
@@ -668,7 +673,6 @@ def remove(self, key):
668673 except ValueError :
669674 pass
670675
671- @keys_reader
672676 def __len__ (self ):
673677 """
674678 The number of keys.
@@ -690,18 +694,12 @@ def get_key_with_kid(self, kid):
690694 :return: The key or None
691695 """
692696 self ._uptodate ()
693- with self ._lock_reader :
694- return self ._get_key_with_kid (kid )
697+ return self ._get_key_with_kid (kid )
695698
696699 def _get_key_with_kid (self , kid ):
697700 for key in self ._keys :
698701 if key .kid == kid :
699702 return key
700-
701- for key in self ._keys :
702- if key .kid == kid :
703- return key
704-
705703 return None
706704
707705 def kids (self ):
@@ -723,9 +721,7 @@ def mark_as_inactive(self, kid):
723721 """
724722 k = self ._get_key_with_kid (kid )
725723 if k :
726- self ._keys .remove (k )
727724 k .inactive_since = time .time ()
728- self ._keys .append (k )
729725 return True
730726 else :
731727 return False
@@ -753,30 +749,18 @@ def remove_outdated(self, after, when=0):
753749 before it should be removed.
754750 :param when: To make it easier to test
755751 """
756- if when :
757- now = when
758- else :
759- now = time .time ()
752+ now = when or time .time ()
760753
761754 if not isinstance (after , float ):
762755 after = float (after )
763756
764- _kl = []
765- changed = False
766- for k in self ._keys :
767- if k .inactive_since and k .inactive_since + after < now :
768- changed = True
769- continue
770-
771- _kl .append (k )
772-
773- self ._keys = _kl
774- return changed
757+ self ._keys = [
758+ k for k in self ._keys if not k .inactive_since or k .inactive_since + after > now
759+ ]
775760
776761 def __contains__ (self , key ):
777762 return key in self .keys ()
778763
779- @keys_reader
780764 def copy (self ):
781765 """
782766 Make deep copy of this KeyBundle
@@ -846,7 +830,7 @@ def load(self, spec):
846830 """
847831 _keys = spec .get ("keys" , [])
848832 if _keys :
849- self ._do_keys (_keys )
833+ self ._add_jwk_dicts (_keys )
850834
851835 for attr , default in self .params .items ():
852836 val = spec .get (attr )
0 commit comments