diff --git a/findmy/accessory.py b/findmy/accessory.py index f859568..78b4b3c 100644 --- a/findmy/accessory.py +++ b/findmy/accessory.py @@ -377,6 +377,11 @@ def from_json( class _AccessoryKeyGenerator(KeyGenerator[KeyPair]): """KeyPair generator. Uses the same algorithm internally as FindMy accessories do.""" + # cache enough keys for an entire week. + # every interval'th key is cached. + _CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour + _CACHE_INTERVAL = 10 + def __init__( self, master_key: bytes, @@ -401,8 +406,7 @@ def __init__( self._initial_sk = initial_sk self._key_type = key_type - self._cur_sk = initial_sk - self._cur_sk_ind = 0 + self._sk_cache: dict[int, bytes] = {} self._iter_ind = 0 @@ -426,14 +430,33 @@ def _get_sk(self, ind: int) -> bytes: msg = "The key index must be non-negative" raise ValueError(msg) - if ind < self._cur_sk_ind: # behind us; need to reset :( - self._cur_sk = self._initial_sk - self._cur_sk_ind = 0 + # retrieve from cache + cached_sk = self._sk_cache.get(ind) + if cached_sk is not None: + return cached_sk + + # not in cache: find largest cached index smaller than ind (if exists) + start_ind: int = 0 + cur_sk: bytes = self._initial_sk + for cached_ind in self._sk_cache: + if cached_ind < ind and cached_ind > start_ind: + start_ind = cached_ind + cur_sk = self._sk_cache[cached_ind] + + # compute and update cache + for cur_ind in range(start_ind, ind): + cur_sk = crypto.x963_kdf(cur_sk, b"update", 32) + + # insert intermediate result into cache and evict oldest entry if necessary + if cur_ind % self._CACHE_INTERVAL == 0: + self._sk_cache[cur_ind] = cur_sk + + if len(self._sk_cache) > self._CACHE_SIZE: + # evict oldest entry + oldest_ind = min(self._sk_cache.keys()) + del self._sk_cache[oldest_ind] - for _ in range(self._cur_sk_ind, ind): - self._cur_sk = crypto.x963_kdf(self._cur_sk, b"update", 32) - self._cur_sk_ind += 1 - return self._cur_sk + return cur_sk def _get_keypair(self, ind: int) -> KeyPair: sk = self._get_sk(ind) @@ -449,14 +472,14 @@ def _generate_keys(self, start: int, stop: int | None) -> Generator[KeyPair, Non @override def __iter__(self) -> KeyGenerator: - self._iter_ind = -1 return self @override def __next__(self) -> KeyPair: + key = self._get_keypair(self._iter_ind) self._iter_ind += 1 - return self._get_keypair(self._iter_ind) + return key @overload def __getitem__(self, val: int) -> KeyPair: ...