From 6809f05b28e7f29f87752549a78b40de24af6545 Mon Sep 17 00:00:00 2001
From: Christian Zhou-Zheng
Date: Fri, 11 Jul 2025 13:41:05 -0400
Subject: [PATCH 1/9] coef

---
 mlx_engine/cache.py                   | 62 +++++++++++++++++++++++++++
 mlx_engine/cache_wrapper.py           | 35 ++++++++++-----
 mlx_engine/generate.py                |  2 +
 mlx_engine/model_kit/model_kit.py     |  2 +
 mlx_engine/utils/prompt_processing.py |  2 +
 5 files changed, 93 insertions(+), 10 deletions(-)
 create mode 100644 mlx_engine/cache.py

diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py
new file mode 100644
index 00000000..96f89fa2
--- /dev/null
+++ b/mlx_engine/cache.py
@@ -0,0 +1,62 @@
+from typing import List, Optional, Any
+
+from mlx_lm.models.cache import RotatingKVCache, KVCache
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class ShiftingKVCache(RotatingKVCache):
+    def trim(self, n) -> int:
+        # trim must not respect keep
+        n = min(self.offset, n)
+        if n <= 0:
+            return 0
+
+        # put us back into the state before the circular buffer is full
+        self.keys = self._temporal_order(self.keys)
+        self.values = self._temporal_order(self.values)
+
+        new_length = max(self.keys.shape[2] - n, 0)
+        self.keys = self.keys[..., :new_length, :]
+        self.values = self.values[..., :new_length, :]
+
+        self.offset = new_length
+        self._idx = new_length
+        return n
+
+    def set_keep(self, keep):
+        # kv must be in temporal order, else we will keep the wrong thing
+        if self.keys is not None:
+            self.keys = self._temporal_order(self.keys)
+        if self.values is not None:
+            self.values = self._temporal_order(self.values)
+        self.keep = keep
+
+    def is_trimmable(self) -> bool:
+        return True
+
+
+def make_prompt_cache(
+    model: nn.Module,
+    max_kv_size: Optional[int] = None,
+    keep: int = 4,
+) -> List[Any]:
+    """
+    Construct the model's cache for use in generation.
+    This function will defer the cache construction to the model if it has a
+    ``make_cache`` method, otherwise it will make a default KV cache.
+    Args:
+        model (nn.Module): The language model.
+        max_kv_size (Optional[int]): If provided and the model does not have a
+            ``make_cache`` method, a ``ShiftingKVCache`` is used with a maximum
+            size of ``max_kv_size``
+    """
+    if hasattr(model, "make_cache"):
+        return model.make_cache()
+    num_layers = len(model.layers)
+    if max_kv_size is not None:
+        return [
+            ShiftingKVCache(max_size=max_kv_size, keep=keep) for _ in range(num_layers)
+        ]
+    else:
+        return [KVCache() for _ in range(num_layers)]
\ No newline at end of file
diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py
index 36498fab..66f07f9a 100644
--- a/mlx_engine/cache_wrapper.py
+++ b/mlx_engine/cache_wrapper.py
@@ -1,11 +1,8 @@
 from typing import List, Optional, Any
 
 from mlx_engine.logging import log_info, log_warn, log_error
-from mlx_lm.models.cache import (
-    make_prompt_cache,
-    trim_prompt_cache,
-    can_trim_prompt_cache,
-)
+from mlx_engine.cache import make_prompt_cache
+from mlx_lm.models.cache import trim_prompt_cache, can_trim_prompt_cache
 from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache
 import mlx.core as mx
 import mlx.nn as nn
@@ -26,6 +23,7 @@ def __init__(
         kv_bits: Optional[int] = None,
         kv_group_size: Optional[int] = None,
         quantized_kv_start: Optional[int] = None,
+        keep: int = 4,
     ):
         """
         Initialize the CacheWrapper.
@@ -36,7 +34,8 @@ def __init__(
         """
         # utilize a simple ordered list of tokens processed so far for cache invalidation checking
         self.tokens: Optional[mx.array] = None
-        self.cache: List[Any] = make_prompt_cache(model, max_kv_size)
+        self.keep = keep
+        self.cache: List[Any] = make_prompt_cache(model, max_kv_size, keep)
         self.model = model
         self.draft_model: Optional[nn.Module] = None
         self.max_kv_size = max_kv_size
@@ -115,7 +114,7 @@ def _get_unprocessed_tokens(
                 message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: "
                 "Cache is not trimmable. Clearing the cache instead.",
             )
-            self.cache = make_prompt_cache(self.model, self.max_kv_size)
+            self.cache = make_prompt_cache(self.model, self.max_kv_size, keep=self.keep)
             self.tokens = prompt_tokens
             return self.tokens
         tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim)
@@ -126,7 +125,7 @@ def _get_unprocessed_tokens(
                 message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected "
                 " ({num_tokens_to_trim}). Clearing the cache.",
             )
-            self.cache = make_prompt_cache(self.model, self.max_kv_size)
+            self.cache = make_prompt_cache(self.model, self.max_kv_size, keep=self.keep)
             self.tokens = prompt_tokens
             return self.tokens
         log_info(
@@ -221,9 +220,9 @@ def set_draft_model(self, draft_model: nn.Module):
             message="Clearing current prompt cache and adding draft model to the cache",
         )
         self.tokens = None
-        self.cache: List[Any] = make_prompt_cache(self.model)
+        self.cache: List[Any] = make_prompt_cache(self.model, keep=self.keep)
         if draft_model is not None:
-            self.cache += make_prompt_cache(draft_model)
+            self.cache += make_prompt_cache(draft_model, keep=self.keep)
         self.draft_model = draft_model
 
     def unset_draft_model(self):
@@ -239,6 +238,7 @@ def update_cache(
         prompt_progress_callback,
         *,
         num_tokens_to_exclude: int = 1,
+        keep: int = 4,
     ) -> mx.array:
         """
         Set up the KV cache for the next generation.
@@ -248,6 +248,7 @@ def update_cache(
             prompt_tokens (mx.array): The prompt tokens.
             prompt_progress_callback (Callable): A callback function to report prompt processing progress.
             num_tokens_to_exclude (int): The number of tokens that should not be added to the cache.
+            keep (int): The number of tokens to always keep in the prefix of the prompt cache.
 
         Returns:
             mx.array: The prompt tokens to be used for the next generation.
@@ -256,6 +257,12 @@ def prompt_progress_callback(x):
                 return None
+
+        # update keep tracking
+        self.keep = keep
+        for cache in self.cache:
+            if hasattr(cache, "set_keep"):
+                cache.set_keep(keep)
 
         num_tokens_to_exclude = max(num_tokens_to_exclude, 1)
         prompt_tokens = self._get_unprocessed_tokens(
@@ -296,5 +303,13 @@ def prompt_progress_callback(x):
     def record_generated_token(self, token):
         """
         Add the generated token to the token list, so that we can map the token to the KV cache.
+
+        Also loop when the cache does so that we accurately track what's in cache.
""" + # this behavior is common to rolling window (n_keep = 0) and truncate middle + # (n_keep > 0), and we should never get here with stop at max + if len(self.tokens) >= self.max_kv_size: + self.tokens = mx.concat( + [self.tokens[: self.keep], self.tokens[self.keep + 1 :]] + ) self.tokens = mx.concat([self.tokens, mx.array([token])]) diff --git a/mlx_engine/generate.py b/mlx_engine/generate.py index 6019cf16..1a80eec1 100644 --- a/mlx_engine/generate.py +++ b/mlx_engine/generate.py @@ -137,6 +137,7 @@ def create_generator( max_tokens: Optional[int] = 10000000, speculative_decoding_toggle: Optional[bool] = None, num_draft_tokens: Optional[int] = None, + keep: int = 4, ) -> Iterator[GenerationResult]: """ Create a generator that streams text generation results from the model. @@ -218,6 +219,7 @@ def create_generator( prompt_progress_callback, generate_args, speculative_decoding_toggle, + keep=keep, ) if draft_model is None: # input embeddings not yet supported for speculative decoding in mlx-lm diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index a389dc6d..1e364725 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -141,6 +141,7 @@ def process_prompt( prompt_progress_callback, generate_args, speculative_decoding_toggle: Optional[bool] = None, + keep: int = 4, ) -> Tuple[mx.array, Optional[mx.array]]: ### TEXT-ONLY PROCESS_PROMPT ### is_text_only_processing = images_b64 is None or len(images_b64) == 0 @@ -160,6 +161,7 @@ def process_prompt( self.draft_model, speculative_decoding_toggle, prompt_progress_callback, + keep=keep ), None ### WITH IMAGES PROMPT PROCESSING ###s if self.vision_add_on is None: diff --git a/mlx_engine/utils/prompt_processing.py b/mlx_engine/utils/prompt_processing.py index 78a687d1..380cf243 100644 --- a/mlx_engine/utils/prompt_processing.py +++ b/mlx_engine/utils/prompt_processing.py @@ -13,6 +13,7 @@ def process_prompt_text_only( draft_model: Optional[nn.Module] = None, speculative_decoding_toggle: Optional[bool] = None, prompt_progress_callback: Optional[Callable[[float], None]] = None, + keep: int = 4, ): if cache_wrapper is None: raise ValueError("Cache wrapper is not initialized, cannot process prompt") @@ -38,6 +39,7 @@ def process_prompt_text_only( prompt_tokens = cache_wrapper.update_cache( prompt_tokens, prompt_progress_callback, + keep=keep, ) generate_args["prompt_cache"] = cache_wrapper.cache return prompt_tokens From 5e7834d495c45c947b99d88298b7c71b81247e7b Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Fri, 11 Jul 2025 13:42:32 -0400 Subject: [PATCH 2/9] tests --- tests/test_cache_generic.py | 36 ++++++++ tests/test_cache_shift.py | 180 ++++++++++++++++++++++++++++++++++++ 2 files changed, 216 insertions(+) create mode 100644 tests/test_cache_generic.py create mode 100644 tests/test_cache_shift.py diff --git a/tests/test_cache_generic.py b/tests/test_cache_generic.py new file mode 100644 index 00000000..e836d7ec --- /dev/null +++ b/tests/test_cache_generic.py @@ -0,0 +1,36 @@ +import unittest +import mlx.core as mx +from copy import deepcopy +from mlx_engine.cache import ShiftingKVCache + + +class TestCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Set up test resources that will be shared across all test methods""" + cls.kv_head_dim = 4 + cls.bsz = 1 + cls.n_kv_heads = 1 + + @classmethod + def make_random_kv(cls, seqlen: int): + """Helper method to make a random key/value tensor of the right shape""" + return mx.random.normal( + (cls.bsz, 
+            scale=1.0,
+            dtype=mx.float32,
+        )
+
+    def assertArrEqual(self, a: mx.array, b: mx.array):
+        """Assert that two tensors are equal over the sequence length dimension"""
+        self.assertEqual(a.shape, b.shape)
+        self.assertTrue(mx.allclose(a, b), "Tensors are not equal")
+
+    def add_random_to_cache(self, cache: ShiftingKVCache, seqlen: int) -> mx.array:
+        """Add random values to the cache and return them"""
+        base_kv = self.make_random_kv(seqlen)
+        # base_kv is *assigned* to cache.keys/cache.values so returning base_kv
+        # would return a reference to cache.keys, which is pointless. so copy it
+        reference = deepcopy(base_kv)
+        cache.update_and_fetch(base_kv, base_kv)
+        return reference
\ No newline at end of file
diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py
new file mode 100644
index 00000000..6dc7ce06
--- /dev/null
+++ b/tests/test_cache_shift.py
@@ -0,0 +1,180 @@
+import unittest
+import mlx.core as mx
+from mlx_engine.cache import ShiftingKVCache
+from tests.test_cache_generic import TestCache
+
+
+def idx(v: mx.array, i: int):
+    """Helper function to index into a 4D tensor at the sequence length dimension"""
+    return v[:, :, i : i + 1, :]
+
+
+class TestShiftingKVCache(TestCache):
+    def test_overwriting(self):
+        """Test overwriting when the cache reaches max_size"""
+        cache = ShiftingKVCache(max_size=3, keep=1)
+
+        # fill cache -> 123
+        reference = self.add_random_to_cache(cache, 3)
+        self.assertEqual(cache.offset, 3)
+
+        # attempt to write another element 4 -> 143
+        overwrite = self.add_random_to_cache(cache, 1)
+        # access k/v as cache.state[0]/[1] due to possibly empty buffer
+        keys = cache.state[0]
+
+        self.assertArrEqual(idx(keys, 0), idx(reference, 0))
+        self.assertArrEqual(idx(keys, 1), overwrite)
+        self.assertArrEqual(idx(keys, 2), idx(reference, 2))
+        self.assertEqual(cache.offset, 4)
+
+    def test_ensure_update_increases_offset_indefinitely(self):
+        """Test single-token updates that should increase offset"""
+        cache = ShiftingKVCache(max_size=3, keep=1)
+
+        for i in range(10):
+            self.add_random_to_cache(cache, 1)
+            self.assertEqual(cache.offset - 1, i)
+
+    def test_ensure_reasonable_size_and_shift(self):
+        """Test behavior when the cache gets a KV batch-written that is much larger
+        than max_size. The default behavior of the cache is to write the entire thing,
+        then trim it back down when the next KV is written.
+ """ + cache = ShiftingKVCache(max_size=3, keep=1) + + # fill cache -> 0123456789 + reference = self.add_random_to_cache(cache, 10) + keys = cache.state[0] + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 10, self.kv_head_dim)) + self.assertEqual(cache.offset, 10) + + # trigger trim -> 0X9 -> (rope) 021 + overwrite = self.add_random_to_cache(cache, 1) + keys = cache.state[0] + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 3, self.kv_head_dim)) + self.assertEqual(cache.offset, 11) + + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), overwrite) + self.assertArrEqual(idx(keys, 2), idx(reference, 9)) + + # make sure pos embs are right + cache.keys = cache._temporal_order(cache.keys) + cache.values = cache._temporal_order(cache.values) + keys = cache.state[0] + + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), idx(reference, 9)) + self.assertArrEqual(idx(keys, 2), overwrite) + self.assertEqual(cache.offset, 11) + + # ensure offset keeps increasing + self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 12) + + self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 13) + + def test_update_keep_on_the_fly(self): + """Test changing the keep value on the fly""" + cache = ShiftingKVCache(max_size=4, keep=1) + + # fill cache -> 1234 + reference = self.add_random_to_cache(cache, 4) + + # attempt to write another element 5 -> 1534 + overwrite = self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 5) + + # update keep -> 1345 -> 1234 implicitly + # and attempt to write another element 5 -> 1254 + # offset updates after set_keep (anytime we reorder/rope shift) + cache.set_keep(2) + self.assertEqual(cache.offset, 5) + overwrite2 = self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 6) + keys = cache.state[0] + + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), idx(reference, 2)) + self.assertArrEqual(idx(keys, 2), overwrite2) + self.assertArrEqual(idx(keys, 3), overwrite) + + def test_trim_before_full(self): + """Test trimming from the end before the cache is full""" + cache = ShiftingKVCache(max_size=3, keep=1) + + # fill cache -> 12 + reference = self.add_random_to_cache(cache, 2) + + # trim 1 from end -> 1 + cache.trim(1) + keys = cache.state[0] + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 1, self.kv_head_dim)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertEqual(cache.offset, 1) + + # ensure adding another value works fine + new_kv = self.add_random_to_cache(cache, 1) + keys = cache.state[0] + self.assertEqual(cache.offset, 2) + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), new_kv) + self.assertEqual(cache.offset, 2) + + def test_trim_after_overwrite(self): + """Test trimming from the end when we've written past the cache max""" + cache = ShiftingKVCache(max_size=3, keep=1) + + # fill cache -> 123 + reference = self.add_random_to_cache(cache, 3) + self.assertEqual(cache.offset, 3) + + # overwrite so offset goes over max_size -> 143 + base_kv = self.make_random_kv(1) + cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 4) + + # trim 1 from end -> 13 -> 12 (rope), ideally + cache.trim(1) + keys = cache.state[0] + + should_be_kv = mx.concatenate( + [reference[:, :, :1, :], reference[:, :, 2:3, :]], axis=2 + ) + 
+        self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim))
+        self.assertArrEqual(keys, should_be_kv)
+        self.assertEqual(cache.offset, 2)
+
+    def test_trim_after_full(self):
+        """Test trimming from the end when the cache is oversize"""
+        cache = ShiftingKVCache(max_size=3, keep=1)
+
+        # fill cache oversize already -> 1234
+        reference = self.add_random_to_cache(cache, 4)
+        self.assertEqual(cache.offset, 4)
+
+        # trim 2 from end -> 12
+        cache.trim(2)
+        keys = cache.state[0]
+
+        self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim))
+        self.assertArrEqual(keys, reference[:, :, :2, :])
+        self.assertEqual(cache.offset, 2)
+
+        # ensure adding more values works fine
+        new_kv = self.add_random_to_cache(cache, 2)
+        keys = cache.state[0]
+        self.assertEqual(cache.offset, 4)
+
+        self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 4, self.kv_head_dim))
+        self.assertArrEqual(keys[:, :, :2, :], reference[:, :, :2, :])
+        self.assertArrEqual(keys[:, :, 2:, :], new_kv)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2, failfast=True)

From 27fae1c6a423ec152b4d9153054d83949978b69d Mon Sep 17 00:00:00 2001
From: Christian Zhou-Zheng
Date: Fri, 11 Jul 2025 13:58:19 -0400
Subject: [PATCH 3/9] please the linter

---
 mlx_engine/cache.py               |  3 +--
 mlx_engine/cache_wrapper.py       | 12 ++++++++----
 mlx_engine/model_kit/model_kit.py |  2 +-
 tests/test_cache_generic.py       |  2 +-
 4 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py
index 96f89fa2..c4ecc560 100644
--- a/mlx_engine/cache.py
+++ b/mlx_engine/cache.py
@@ -1,7 +1,6 @@
 from typing import List, Optional, Any
 
 from mlx_lm.models.cache import RotatingKVCache, KVCache
-import mlx.core as mx
 import mlx.nn as nn
 
 
@@ -59,4 +58,4 @@ def make_prompt_cache(
             ShiftingKVCache(max_size=max_kv_size, keep=keep) for _ in range(num_layers)
         ]
     else:
-        return [KVCache() for _ in range(num_layers)]
\ No newline at end of file
+        return [KVCache() for _ in range(num_layers)]
diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py
index 66f07f9a..e0427a6c 100644
--- a/mlx_engine/cache_wrapper.py
+++ b/mlx_engine/cache_wrapper.py
@@ -114,7 +114,9 @@ def _get_unprocessed_tokens(
                 message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: "
                 "Cache is not trimmable. Clearing the cache instead.",
             )
-            self.cache = make_prompt_cache(self.model, self.max_kv_size, keep=self.keep)
+            self.cache = make_prompt_cache(
+                self.model, self.max_kv_size, keep=self.keep
+            )
             self.tokens = prompt_tokens
             return self.tokens
         tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim)
@@ -125,7 +127,9 @@ def _get_unprocessed_tokens(
                 message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected "
                 " ({num_tokens_to_trim}). Clearing the cache.",
            )
-            self.cache = make_prompt_cache(self.model, self.max_kv_size, keep=self.keep)
+            self.cache = make_prompt_cache(
+                self.model, self.max_kv_size, keep=self.keep
+            )
             self.tokens = prompt_tokens
             return self.tokens
         log_info(
@@ -257,7 +261,7 @@ def prompt_progress_callback(x):
                 return None
-
+
         # update keep tracking
         self.keep = keep
         for cache in self.cache:
@@ -303,7 +307,7 @@ def prompt_progress_callback(x):
     def record_generated_token(self, token):
         """
         Add the generated token to the token list, so that we can map the token to the KV cache.
-
+
         Also loop when the cache does so that we accurately track what's in cache.
""" # this behavior is common to rolling window (n_keep = 0) and truncate middle diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 1e364725..4f93437b 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -161,7 +161,7 @@ def process_prompt( self.draft_model, speculative_decoding_toggle, prompt_progress_callback, - keep=keep + keep=keep, ), None ### WITH IMAGES PROMPT PROCESSING ###s if self.vision_add_on is None: diff --git a/tests/test_cache_generic.py b/tests/test_cache_generic.py index e836d7ec..4a7b21f4 100644 --- a/tests/test_cache_generic.py +++ b/tests/test_cache_generic.py @@ -33,4 +33,4 @@ def add_random_to_cache(self, cache: ShiftingKVCache, seqlen: int) -> mx.array: # would return a reference to cache.keys, which is pointless. so copy it reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) - return reference \ No newline at end of file + return reference From cce592762f52d60f83d8d2b4954eca6b028b2cda Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Fri, 11 Jul 2025 15:44:49 -0400 Subject: [PATCH 4/9] test that trim stops nuking --- tests/test_text_models.py | 74 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/tests/test_text_models.py b/tests/test_text_models.py index 67dba024..b116ab8f 100644 --- a/tests/test_text_models.py +++ b/tests/test_text_models.py @@ -209,6 +209,80 @@ def generate(text_accumulator: list) -> None: self.assertEqual(generated_text_1, generated_text_2) + def test_cache_nuke_qwen2_5(self): + model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit") + model_kit = load_model(model_path=model_path, max_kv_size=32) + prompt = f"""<|im_start|>user +Explain how the universe works. What was the Big Bang? What's redshifting? +<|im_end|> +<|im_start|>assistant +""" + prompt_tokens = tokenize(model_kit, prompt) + log_info( + prefix="test_cache_nuke", + message=f"Generation 1 number of prompt tokens: {len(prompt_tokens)}", + ) + generated_text_list_1 = [] + prompt_progress_callback_times_called = 0 + + def prompt_progress_callback(progress: float) -> None: + nonlocal prompt_progress_callback_times_called + prompt_progress_callback_times_called += 1 + print(f"Prompt Progress: {progress:.2f}") + + # accumulating to list allows pass by reference + def generate(text_accumulator: list) -> None: + for result in create_generator( + model_kit=model_kit, + prompt_tokens=prompt_tokens, + seed=0, + max_tokens=100, + temp=0.0, + prompt_progress_callback=prompt_progress_callback, + ): + print(result.text, end="", flush=True) + text_accumulator.append(result.text) + if result.stop_condition: + break + print("\n", flush=True) + + ### Generation 1 - fills cache + generate(text_accumulator=generated_text_list_1) + generated_text_1 = "".join(generated_text_list_1) + self.assertEqual(prompt_progress_callback_times_called, 2) + self.assertGreater( + len(generated_text_1), 0, "Model failed to generate any text" + ) + gen1_cache_layer0 = model_kit.cache_wrapper.cache[0] + + ### Generation 2 - trims cache + prompt = f"""<|im_start|>user +Explain how the universe works. What was the Big Bang? 
+<|im_end|>
+<|im_start|>assistant
+"""
+        prompt_tokens = tokenize(model_kit, prompt)
+        log_info(
+            prefix="test_cache_nuke",
+            message=f"Generation 2 number of prompt tokens: {len(prompt_tokens)}",
+        )
+        generated_text_list_2 = []
+        prompt_progress_callback_times_called = 0
+        generate(text_accumulator=generated_text_list_2)
+        generated_text_2 = "".join(generated_text_list_2)
+        # Expect the prompt cache to be intact for the first half of the prompt, so we should get 1
+        # intermediate callback this time
+        self.assertEqual(prompt_progress_callback_times_called, 2)
+        self.assertGreater(
+            len(generated_text_2), 0, "Model failed to generate any text"
+        )
+        gen2_cache_layer0 = model_kit.cache_wrapper.cache[0]
+
+        # if we nuked the cache, these will reference different locations in memory
+        # if we didn't, they'll refer to the same object
+        self.assertTrue(gen1_cache_layer0 is gen2_cache_layer0)
+
+
 class TestStructuredGen(unittest.TestCase):
     def setUp(self):
         self.prompt = "List three colors and their hex codes."

From 0a71d2a126756e54dc88b60effa514c9471920e8 Mon Sep 17 00:00:00 2001
From: Christian Zhou-Zheng
Date: Fri, 11 Jul 2025 15:51:30 -0400
Subject: [PATCH 5/9] address comments

---
 mlx_engine/cache.py                           | 28 ++++++++++++++++---
 tests/test_cache_generic.py                   |  6 ++--
 ...test_cache_shift.py => test_cache_trim.py} | 18 ++++++------
 tests/test_text_models.py                     |  5 ++--
 4 files changed, 39 insertions(+), 18 deletions(-)
 rename tests/{test_cache_shift.py => test_cache_trim.py} (92%)

diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py
index c4ecc560..cae49020 100644
--- a/mlx_engine/cache.py
+++ b/mlx_engine/cache.py
@@ -4,9 +4,23 @@
 import mlx.nn as nn
 
 
-class ShiftingKVCache(RotatingKVCache):
+class AlwaysTrimmableKVCache(RotatingKVCache):
+    """A KV cache that can always be trimmed.
+
+    The MLX-LM implementation of the RotatingKVCache does not allow trimming
+    the cache once the maximum KV size has been exceeded, which results in
+    the cache being nuked every time this happens. This forces the entire context
+    to be reprocessed regularly, which is not ideal for performance. This KV cache
+    allows trimming the cache at any time, which circumvents this issue.
+    See https://github.com/lmstudio-ai/mlx-engine/issues/177 for more details.
+    """
+
     def trim(self, n) -> int:
-        # trim must not respect keep
+        # trim must not respect keep: we always receive some value for keep, but
+        # when initially processing the prompt, it may be that the common prefix
+        # is shorter than keep. in that case we must trim to the common prefix length,
+        # which violates keep. keep is only used for the cache rotation when exceeding
+        # the context length mid-generation to ensure we don't lose the common prefix.
         n = min(self.offset, n)
         if n <= 0:
             return 0
@@ -44,10 +58,15 @@ def make_prompt_cache(
     Construct the model's cache for use in generation.
     This function will defer the cache construction to the model if it has a
     ``make_cache`` method, otherwise it will make a default KV cache.
+
+    See https://github.com/ml-explore/mlx-lm/blob/fd9b1909636d634ac2b848248b05939c9fbfbe19/mlx_lm/models/cache.py#L10
+    for the MLX-LM implementation. This is a temporary extension to support trimming
+    in the RotatingKVCache, which is not supported in the original MLX-LM implementation.
+
     Args:
         model (nn.Module): The language model.
         max_kv_size (Optional[int]): If provided and the model does not have a
-            ``make_cache`` method, a ``ShiftingKVCache`` is used with a maximum
+            ``make_cache`` method, an ``AlwaysTrimmableKVCache`` is used with a maximum
             size of ``max_kv_size``
     """
     if hasattr(model, "make_cache"):
         return model.make_cache()
     num_layers = len(model.layers)
     if max_kv_size is not None:
         return [
-            ShiftingKVCache(max_size=max_kv_size, keep=keep) for _ in range(num_layers)
+            AlwaysTrimmableKVCache(max_size=max_kv_size, keep=keep)
+            for _ in range(num_layers)
         ]
     else:
         return [KVCache() for _ in range(num_layers)]
diff --git a/tests/test_cache_generic.py b/tests/test_cache_generic.py
index 4a7b21f4..bb5c905a 100644
--- a/tests/test_cache_generic.py
+++ b/tests/test_cache_generic.py
@@ -1,7 +1,7 @@
 import unittest
 import mlx.core as mx
 from copy import deepcopy
-from mlx_engine.cache import ShiftingKVCache
+from mlx_engine.cache import AlwaysTrimmableKVCache
 
 
 class TestCache(unittest.TestCase):
@@ -26,7 +26,9 @@ def assertArrEqual(self, a: mx.array, b: mx.array):
         self.assertEqual(a.shape, b.shape)
         self.assertTrue(mx.allclose(a, b), "Tensors are not equal")
 
-    def add_random_to_cache(self, cache: ShiftingKVCache, seqlen: int) -> mx.array:
+    def add_random_to_cache(
+        self, cache: AlwaysTrimmableKVCache, seqlen: int
+    ) -> mx.array:
         """Add random values to the cache and return them"""
         base_kv = self.make_random_kv(seqlen)
         # base_kv is *assigned* to cache.keys/cache.values so returning base_kv
diff --git a/tests/test_cache_shift.py b/tests/test_cache_trim.py
similarity index 92%
rename from tests/test_cache_shift.py
rename to tests/test_cache_trim.py
index 6dc7ce06..ae0301f7 100644
--- a/tests/test_cache_shift.py
+++ b/tests/test_cache_trim.py
@@ -1,6 +1,6 @@
 import unittest
 import mlx.core as mx
-from mlx_engine.cache import ShiftingKVCache
+from mlx_engine.cache import AlwaysTrimmableKVCache
 from tests.test_cache_generic import TestCache
 
 
@@ -9,10 +9,10 @@ def idx(v: mx.array, i: int):
     return v[:, :, i : i + 1, :]
 
 
-class TestShiftingKVCache(TestCache):
+class TestAlwaysTrimmableKVCache(TestCache):
     def test_overwriting(self):
         """Test overwriting when the cache reaches max_size"""
-        cache = ShiftingKVCache(max_size=3, keep=1)
+        cache = AlwaysTrimmableKVCache(max_size=3, keep=1)
 
         # fill cache -> 123
         reference = self.add_random_to_cache(cache, 3)
@@ -30,7 +30,7 @@ def test_overwriting(self):
 
     def test_ensure_update_increases_offset_indefinitely(self):
         """Test single-token updates that should increase offset"""
-        cache = ShiftingKVCache(max_size=3, keep=1)
+        cache = AlwaysTrimmableKVCache(max_size=3, keep=1)
 
         for i in range(10):
             self.add_random_to_cache(cache, 1)
@@ -41,7 +41,7 @@ def test_ensure_reasonable_size_and_shift(self):
         than max_size. The default behavior of the cache is to write the entire thing,
         then trim it back down when the next KV is written.
""" - cache = ShiftingKVCache(max_size=3, keep=1) + cache = AlwaysTrimmableKVCache(max_size=3, keep=1) # fill cache -> 0123456789 reference = self.add_random_to_cache(cache, 10) @@ -78,7 +78,7 @@ def test_ensure_reasonable_size_and_shift(self): def test_update_keep_on_the_fly(self): """Test changing the keep value on the fly""" - cache = ShiftingKVCache(max_size=4, keep=1) + cache = AlwaysTrimmableKVCache(max_size=4, keep=1) # fill cache -> 1234 reference = self.add_random_to_cache(cache, 4) @@ -103,7 +103,7 @@ def test_update_keep_on_the_fly(self): def test_trim_before_full(self): """Test trimming from the end before the cache is full""" - cache = ShiftingKVCache(max_size=3, keep=1) + cache = AlwaysTrimmableKVCache(max_size=3, keep=1) # fill cache -> 12 reference = self.add_random_to_cache(cache, 2) @@ -128,7 +128,7 @@ def test_trim_before_full(self): def test_trim_after_overwrite(self): """Test trimming from the end when we've written past the cache max""" - cache = ShiftingKVCache(max_size=3, keep=1) + cache = AlwaysTrimmableKVCache(max_size=3, keep=1) # fill cache -> 123 reference = self.add_random_to_cache(cache, 3) @@ -152,7 +152,7 @@ def test_trim_after_overwrite(self): def test_trim_after_full(self): """Test trimming from the end when the cache is oversize""" - cache = ShiftingKVCache(max_size=3, keep=1) + cache = AlwaysTrimmableKVCache(max_size=3, keep=1) # fill cache oversize already -> 1234 reference = self.add_random_to_cache(cache, 4) diff --git a/tests/test_text_models.py b/tests/test_text_models.py index b116ab8f..1f4dd942 100644 --- a/tests/test_text_models.py +++ b/tests/test_text_models.py @@ -208,11 +208,10 @@ def generate(text_accumulator: list) -> None: ) self.assertEqual(generated_text_1, generated_text_2) - def test_cache_nuke_qwen2_5(self): model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit") model_kit = load_model(model_path=model_path, max_kv_size=32) - prompt = f"""<|im_start|>user + prompt = """<|im_start|>user Explain how the universe works. What was the Big Bang? What's redshifting? <|im_end|> <|im_start|>assistant @@ -256,7 +255,7 @@ def generate(text_accumulator: list) -> None: gen1_cache_layer0 = model_kit.cache_wrapper.cache[0] ### Generation 2 - trims cache - prompt = f"""<|im_start|>user + prompt = """<|im_start|>user Explain how the universe works. What was the Big Bang? 
 <|im_end|>
 <|im_start|>assistant
 """

From 71f1d534f85c5a8e23be002ee76540409a87c415 Mon Sep 17 00:00:00 2001
From: Christian Zhou-Zheng
Date: Fri, 11 Jul 2025 15:59:07 -0400
Subject: [PATCH 6/9] fix draft models

---
 mlx_engine/cache_wrapper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py
index e0427a6c..778a8039 100644
--- a/mlx_engine/cache_wrapper.py
+++ b/mlx_engine/cache_wrapper.py
@@ -312,7 +312,7 @@ def record_generated_token(self, token):
         """
         # this behavior is common to rolling window (n_keep = 0) and truncate middle
         # (n_keep > 0), and we should never get here with stop at max
-        if len(self.tokens) >= self.max_kv_size:
+        if self.max_kv_size is not None and len(self.tokens) >= self.max_kv_size:
             self.tokens = mx.concat(
                 [self.tokens[: self.keep], self.tokens[self.keep + 1 :]]
             )

From 559d0546ab1c43144b9d1ab7b5c22ced5efe7ded Mon Sep 17 00:00:00 2001
From: Christian Zhou-Zheng
Date: Fri, 11 Jul 2025 16:01:49 -0400
Subject: [PATCH 7/9] permalink to rotation

---
 mlx_engine/cache_wrapper.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py
index 778a8039..abe92082 100644
--- a/mlx_engine/cache_wrapper.py
+++ b/mlx_engine/cache_wrapper.py
@@ -308,7 +308,9 @@ def record_generated_token(self, token):
         """
         Add the generated token to the token list, so that we can map the token to the KV cache.
 
-        Also loop when the cache does so that we accurately track what's in cache.
+        Also loop when the cache does so that we accurately track what's in cache, if we're using
+        a RotatingKVCache or subclass of such. See the rotation implemented by MLX-LM here:
+        https://github.com/ml-explore/mlx-lm/blob/fd9b1909636d634ac2b848248b05939c9fbfbe19/mlx_lm/models/cache.py#L371
         """
         # this behavior is common to rolling window (n_keep = 0) and truncate middle
         # (n_keep > 0), and we should never get here with stop at max

From 395cafb8d20d1829eba8e16ad8b163a8789e1a59 Mon Sep 17 00:00:00 2001
From: Christian Zhou-Zheng
Date: Fri, 11 Jul 2025 16:36:35 -0400
Subject: [PATCH 8/9] gate logic

---
 mlx_engine/cache_wrapper.py | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py
index abe92082..3fb13dec 100644
--- a/mlx_engine/cache_wrapper.py
+++ b/mlx_engine/cache_wrapper.py
@@ -2,7 +2,11 @@
 
 from mlx_engine.logging import log_info, log_warn, log_error
 from mlx_engine.cache import make_prompt_cache
-from mlx_lm.models.cache import trim_prompt_cache, can_trim_prompt_cache
+from mlx_lm.models.cache import (
+    trim_prompt_cache,
+    can_trim_prompt_cache,
+    RotatingKVCache,
+)
 from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache
 import mlx.core as mx
 import mlx.nn as nn
@@ -36,6 +40,7 @@ def __init__(
         self.tokens: Optional[mx.array] = None
         self.keep = keep
         self.cache: List[Any] = make_prompt_cache(model, max_kv_size, keep)
+        self.is_rotating = all(isinstance(c, RotatingKVCache) for c in self.cache)
         self.model = model
         self.draft_model: Optional[nn.Module] = None
         self.max_kv_size = max_kv_size
@@ -115,7 +120,9 @@ def _get_unprocessed_tokens(
                 message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: "
                 "Cache is not trimmable. Clearing the cache instead.",
             )
             self.cache = make_prompt_cache(
-                self.model, self.max_kv_size, keep=self.keep
+                self.model,
+                max_kv_size=self.max_kv_size if self.is_rotating else None,
+                keep=self.keep,
             )
             self.tokens = prompt_tokens
             return self.tokens
@@ -128,7 +135,9 @@ def _get_unprocessed_tokens(
                 message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected "
                 " ({num_tokens_to_trim}). Clearing the cache.",
             )
             self.cache = make_prompt_cache(
-                self.model, self.max_kv_size, keep=self.keep
+                self.model,
+                max_kv_size=self.max_kv_size if self.is_rotating else None,
+                keep=self.keep,
             )
             self.tokens = prompt_tokens
             return self.tokens
@@ -225,6 +234,8 @@ def set_draft_model(self, draft_model: nn.Module):
         )
         self.tokens = None
         self.cache: List[Any] = make_prompt_cache(self.model, keep=self.keep)
+        # the above will never return a rotating cache since there is no max_kv_size set
+        self.is_rotating = False
         if draft_model is not None:
             self.cache += make_prompt_cache(draft_model, keep=self.keep)
         self.draft_model = draft_model
@@ -314,7 +325,12 @@ def record_generated_token(self, token):
         """
         # this behavior is common to rolling window (n_keep = 0) and truncate middle
         # (n_keep > 0), and we should never get here with stop at max
-        if self.max_kv_size is not None and len(self.tokens) >= self.max_kv_size:
+        if (
+            self.max_kv_size is not None
+            and self.is_rotating
+            and len(self.tokens) >= self.max_kv_size
+        ):
+            # rotate the token tracking buffer
             self.tokens = mx.concat(
                 [self.tokens[: self.keep], self.tokens[self.keep + 1 :]]
             )

From d8240d716033c23ce60157a5e4cf0fb4211b3c9f Mon Sep 17 00:00:00 2001
From: christian-lms
Date: Mon, 14 Jul 2025 13:12:36 -0400
Subject: [PATCH 9/9] Update mlx_engine/cache.py

Co-authored-by: Matt Clayton <156335168+mattjcly@users.noreply.github.com>
---
 mlx_engine/cache.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py
index cae49020..f22bc616 100644
--- a/mlx_engine/cache.py
+++ b/mlx_engine/cache.py
@@ -60,8 +60,8 @@ def make_prompt_cache(
     ``make_cache`` method, otherwise it will make a default KV cache.
 
     See https://github.com/ml-explore/mlx-lm/blob/fd9b1909636d634ac2b848248b05939c9fbfbe19/mlx_lm/models/cache.py#L10
-    for the MLX-LM implementation. This is a temporary extension to support trimming
-    in the RotatingKVCache, which is not supported in the original MLX-LM implementation.
+    for the MLX-LM implementation. This is a temporary extension to support more flexible
+    trimming than MLX-LM's original RotatingKVCache.
 
     Args:
         model (nn.Module): The language model.