diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 7fc91fa2a..549ab43b8 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -196,6 +196,8 @@ def remove(self, model, tokens): self._lru_checkpoints.remove((model, tokens)) def pop(self): + if not self._lru and not self._lru_checkpoints: + raise IndexError("pop from empty CacheOrder") if len(self._lru) >= len(self._lru_checkpoints): return self._lru.popleft() else: @@ -344,9 +346,13 @@ def trim_to( while len(self._lru) > n_sequences: model, tokens = self._lru.pop() self._delete(model, tokens) - while self._n_bytes > n_bytes: + while self._n_bytes > n_bytes and len(self._lru) > 0: model, tokens = self._lru.pop() self._delete(model, tokens) + if self._n_bytes > n_bytes: + raise RuntimeError( + "LRUPromptCache byte accounting drifted out of sync with cache order" + ) def log_cache_stats(self): ncaches, nbytes = len(self), self.nbytes diff --git a/tests/test_server.py b/tests/test_server.py index c5a815e4f..4474a6f39 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -560,6 +560,30 @@ def test_lru_bytes(self): self.assertEqual(c, None) self.assertEqual(t, [3, 4]) + def test_trim_to_zero_bytes_on_empty_cache(self): + cache = LRUPromptCache(max_size=10) + # Should not raise IndexError on empty cache + cache.trim_to(n_bytes=0) + self.assertEqual(len(cache), 0) + + def test_trim_to_raises_on_inconsistent_byte_accounting(self): + cache = LRUPromptCache(max_size=10) + cache._n_bytes = 1 + + with self.assertRaisesRegex(RuntimeError, "byte accounting"): + cache.trim_to(n_bytes=0) + + def test_trim_to_zero_bytes_evicts_all(self): + cache = LRUPromptCache(max_size=10) + model = ("test", None, None) + cache.insert_cache(model, [1, 2], [MockCache("aaa")]) + cache.insert_cache(model, [3, 4], [MockCache("bbb")]) + self.assertEqual(len(cache), 2) + + cache.trim_to(n_bytes=0) + self.assertEqual(len(cache), 0) + self.assertEqual(cache.nbytes, 0) + if __name__ == "__main__": unittest.main()