@@ -263,6 +263,7 @@ def __init__(
263263 if self .local :
264264 self ._keys = self ._do_local (kid )
265265
266+
266267 def _set_source (self , source , fileformat ):
267268 if source .startswith ("file://" ):
268269 self .source = source [7 :]
@@ -284,10 +285,10 @@ def _set_source(self, source, fileformat):
284285
285286 def _do_local (self , kid ):
286287 if self .fileformat in ["jwks" , "jwk" ]:
287- updated , res = self ._do_local_jwk (self .source )
288+ updated , keys = self ._do_local_jwk (self .source )
288289 elif self .fileformat == "der" :
289- updated , res = self ._do_local_der (self .source , self .keytype , self .keyusage , kid )
290- return res
290+ updated , keys = self ._do_local_der (self .source , self .keytype , self .keyusage , kid )
291+ return keys
291292
292293 def _local_update_required (self ) -> bool :
293294 stat = os .stat (self .source )
@@ -311,14 +312,9 @@ def add_jwk_dicts(self, keys):
311312 :param keys: List of JWK dictionaries
312313 :return:
313314 """
314- self ._add_jwk_dicts ( keys )
315+ self ._keys . extend ( self . jwk_dicts_as_keys ( keys ) )
315316 self .last_updated = time .time ()
316317
317- def _add_jwk_dicts (self , keys ):
318- _new_keys = self .jwk_dicts_as_keys (keys )
319- if _new_keys :
320- self ._keys .extend (_new_keys )
321-
322318 def jwk_dicts_as_keys (self , keys ):
323319 """
324320 Return JWK dictionaries as list of JWK objects
@@ -392,13 +388,13 @@ def _do_local_jwk(self, filename):
392388 with open (filename ) as input_file :
393389 _info = json .load (input_file )
394390 if "keys" in _info :
395- res = self .jwk_dicts_as_keys (_info ["keys" ])
391+ new_keys = self .jwk_dicts_as_keys (_info ["keys" ])
396392 else :
397- res = self .jwk_dicts_as_keys ([_info ])
393+ new_keys = self .jwk_dicts_as_keys ([_info ])
398394
399395 self .last_local = time .time ()
400396 self .time_out = self .last_local + self .cache_time
401- return True , res
397+ return True , new_keys
402398
403399 def _do_local_der (self , filename , keytype , keyusage = None , kid = "" ):
404400 """
@@ -431,12 +427,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
431427 if kid :
432428 key_args ["kid" ] = kid
433429
434- res = self .jwk_dicts_as_keys ([key_args ])
430+ new_keys = self .jwk_dicts_as_keys ([key_args ])
435431 self .last_local = time .time ()
436432 self .time_out = self .last_local + self .cache_time
437- return True , res
433+ return True , new_keys
438434
439- def _do_remote (self ):
435+ def _do_remote (self , set_keys = True ):
440436 """
441437 Load a JWKS from a webpage.
442438
@@ -451,7 +447,7 @@ def _do_remote(self):
451447 self .source ,
452448 datetime .fromtimestamp (self .ignore_errors_until ),
453449 )
454- return False
450+ return False , None
455451
456452 LOGGER .info ("Reading remote JWKS from %s" , self .source )
457453 try :
@@ -500,11 +496,12 @@ def _do_remote(self):
500496 self .ignore_errors_until = time .time () + self .ignore_errors_period
501497 raise UpdateFailed (REMOTE_FAILED .format (self .source , _http_resp .status_code ))
502498
503- if new_keys is not None :
499+ if set_keys and new_keys :
504500 self ._keys = new_keys
501+
505502 self .last_updated = time .time ()
506503 self .ignore_errors_until = None
507- return load_successful
504+ return load_successful , new_keys
508505
509506 def _parse_remote_response (self , response ):
510507 """
@@ -545,38 +542,31 @@ def update(self):
545542 :return: True if update was ok or False if we encountered an error during update.
546543 """
547544 if self .source :
548- _old_keys = self ._keys # just in case
549-
550- # reread everything
551- self ._keys = []
545+ new_keys = []
552546 updated = None
553547
554548 try :
555549 if self .local :
556550 if self .fileformat in ["jwks" , "jwk" ]:
557551 updated , k = self ._do_local_jwk (self .source )
558- if k :
559- self ._keys .extend (k )
560552 elif self .fileformat == "der" :
561553 updated , k = self ._do_local_der (self .source , self .keytype , self .keyusage )
562- if k :
563- self ._keys .extend (k )
564554 elif self .remote :
565- updated = self ._do_remote ()
555+ updated , k = self ._do_remote (set_keys = False )
556+ if k :
557+ new_keys .extend (k )
566558 except Exception as err :
567559 LOGGER .error ("Key bundle update failed: %s" , err )
568- self ._keys = _old_keys # restore
569560 return False
570561
571562 if updated :
572563 now = time .time ()
573- for _key in _old_keys :
574- if _key not in self . _keys :
564+ for _key in self . _keys :
565+ if _key not in new_keys :
575566 if not _key .inactive_since : # If already marked don't mess
576567 _key .inactive_since = now
577- self ._keys .append (_key )
578- else :
579- self ._keys = _old_keys
568+ new_keys .append (_key )
569+ self ._keys = new_keys
580570
581571 return True
582572
@@ -592,9 +582,9 @@ def get(self, typ="", only_active=True):
592582
593583 if typ :
594584 _typs = [typ .lower (), typ .upper ()]
595- _keys = [k for k in self ._keys [:] if k .kty in _typs ]
585+ _keys = [k for k in self ._keys if k .kty in _typs ]
596586 else :
597- _keys = self ._keys [:]
587+ _keys = self ._keys
598588
599589 if only_active :
600590 return [k for k in _keys if not k .inactive_since ]
@@ -609,7 +599,7 @@ def keys(self, update: bool = True):
609599 """
610600 if update :
611601 self ._uptodate ()
612- return self ._keys [:]
602+ return self ._keys
613603
614604 def active_keys (self ):
615605 """Return the set of active keys."""
@@ -836,9 +826,11 @@ def load(self, spec):
836826 :param spec: Dictionary with attributes and value to populate the instance with
837827 :return: The instance itself
838828 """
829+
839830 _keys = spec .get ("keys" , [])
840831 if _keys :
841- self ._add_jwk_dicts (_keys )
832+ self ._keys .extend (self .jwk_dicts_as_keys (_keys ))
833+ self .last_updated = time .time ()
842834
843835 for attr , default in self .params .items ():
844836 val = spec .get (attr )
0 commit comments