From f11604b15835de6aff917d79db885272b7d55a7a Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Mon, 7 Jul 2025 10:53:32 -0400 Subject: [PATCH 01/39] loop record generated token --- mlx_engine/cache_wrapper.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 36498fab..28abdd68 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -296,5 +296,15 @@ def prompt_progress_callback(x): def record_generated_token(self, token): """ Add the generated token to the token list, so that we can map the token to the KV cache. + + Also loop when the cache does so that we accurately track what's in cache. """ + # TODO(christian-lms): ensure that this works as intended when over length + # TODO(christian-lms): verify rolling window and truncate middle have n_keep as below + # TODO(christian-lms): this won't work until we pipe in keep from generate + n_keep = self.cache[0].keep + # this behavior is common to rolling window (n_keep = 0) and truncate middle + # (n_keep > 0), and we should never get here with stop at max + if len(self.tokens) >= n_keep: + self.tokens = mx.concat([self.tokens[:n_keep], self.tokens[n_keep+1:]]) self.tokens = mx.concat([self.tokens, mx.array([token])]) From 64113b35ade614f9b313915419b0e6a3116e90a9 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Mon, 7 Jul 2025 12:21:49 -0400 Subject: [PATCH 02/39] shifting kv cache --- mlx_engine/cache_wrapper.py | 182 +++++++++++++++++++++++++++++++++++- 1 file changed, 181 insertions(+), 1 deletion(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 28abdd68..44a3710e 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -2,9 +2,10 @@ from mlx_engine.logging import log_info, log_warn, log_error from mlx_lm.models.cache import ( - make_prompt_cache, trim_prompt_cache, can_trim_prompt_cache, + RotatingKVCache, + KVCache ) from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache import mlx.core as mx @@ -12,6 +13,183 @@ import sys +# TODO(christian-lms) DO NOT HARDCODE ME (or at least move it somewhere else) +MAYBE_ATTN_NAMES = ["self_attn", "attention", "attn", "mixer", "norm_attn_norm"] +MAYBE_ROPE_NAMES = ["rope", "rotary_emb"] + + +def _maybe_get_rope(layer: nn.Module) -> Optional[nn.Module]: + for maybe_rope_name in MAYBE_ROPE_NAMES: + if hasattr(layer, maybe_rope_name): + # found it + return getattr(layer, maybe_rope_name) + for maybe_attn_name in MAYBE_ATTN_NAMES: + if hasattr(layer, maybe_attn_name): + # move down one level + return _maybe_get_rope(getattr(layer, maybe_attn_name)) + # no dice + return None + + +def maybe_get_rope(model: nn.Module, layer_idx: int) -> Optional[nn.Module]: + """Attempt to find the RoPE module from a layer of an MLX-LM LLM. + + Args: + model (nn.Module): The LLM to search for the RoPE modules of. + layer_idx (int): The layer of the LLM to get the RoPE module from. + + Returns: + Optional[nn.Module]: The RoPE module if found, else None + """ + # we can assume model has attribute layers because make_prompt_cache does + if layer_idx > len(model.layers): + # TODO(christian-lms): fail silently or throw here? + return None + layer = model.layers[layer_idx] + if not isinstance(layer, nn.Module): + return None + return _maybe_get_rope(layer) + + +class ShiftingKVCache(RotatingKVCache): + def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): + self.rope = rope + self.reuse_offset = 0 + self.reuse_queue = [] + super().__init__(self, max_size, keep, step) + + def is_trimmable(self) -> bool: + return True + + def _temporal_order(self, v) -> mx.array: + """ + Rearrange the cache into temporal order, slicing off the end if unused. + """ + if self._idx == v.shape[2]: + return v + elif self._idx < self.offset: + shift_by = self.keep - self.idx + return mx.concatenate( + [ + v[..., : self.keep, :], + # TODO(christian-lms): verify that i work + # TODO(christian-lms): can you do this in 1 call to self.rope? + # N.B. this implicitly assumes the generation has not gone over twice + # the size of the rotating section of the cache, in which case the + # rotating section would be off by a multiple of (max_kv_size - keep) + # depending on how many times it rolled over. I feel like it's pretty + # safe to assume that this is a rare case + self.rope(v[..., self._idx :, :], shift_by), + self.rope(v[..., self.keep : self._idx, :], shift_by), + ], + axis=2, + ) + else: + return v[..., : self._idx, :] + + def reuse_section(self, write_start_idx: int, reuse_start_idx: int, reuse_length: int) -> None: + # offset indices to account for the fact that we move cache elements around + write_start_idx -= self.reuse_offset + reuse_start_idx -= self.reuse_offset + + # update position offsets for future reuse sections + shift_by = write_start_idx - reuse_start_idx + self.reuse_offset += shift_by + + # queue for reuse: everything is done in one pass at the end in do_reuse + self.reuse_queue.append((write_start_idx, reuse_start_idx, reuse_length)) + + def do_reuse(self) -> None: + last_i: int = len(self.reuse_queue) - 1 + for i, (write_start_idx, reuse_start_idx, reuse_length) in enumerate(self.reuse_queue): + shift_by: int = write_start_idx - reuse_start_idx + reuse_end_idx: int = reuse_start_idx + reuse_length + + keys_to_shift = self.keys[..., reuse_start_idx : reuse_end_idx, :] + values_to_shift = self.values[..., reuse_start_idx : reuse_end_idx, :] + + # perform rope shift + # N.B. we can also go back to the MLX-native "don't rope shift" method + # by + shifted_keys = self.rope(keys_to_shift, shift_by) + shifted_values = self.rope(values_to_shift, shift_by) + + # restructure cache with mx.concat + # TODO(christian-lms): maybe it would be better to use inplace ops. + # look into the mlx docs if that's even a thing + keycat = [ + self.keys[..., : write_start_idx, :], + shifted_keys + ] + valcat = [ + self.values[..., : write_start_idx, :], + shifted_values + ] + + # by not re-appending the end at the last one, we truncate the leftovers + if i != last_i: + keycat.append(self.keys[..., reuse_end_idx : , :]) + valcat.append(self.values[..., reuse_end_idx : , :]) + + self.keys = mx.concat(keycat, axis=2) + self.values = mx.concat(valcat, axis=2) + + self.offset -= shift_by + self.reuse_offset = 0 + self.reuse_queue = [] + # TODO(christian-lms): dunno if this number is correct/reasonable/whatever + self._idx = self.keys.shape[2] + + def trim(self, n) -> int: + # TODO(christian-lms): fix me + n = min(self.offset, n) + if n == 0: + return 0 + + if self.offset >= self.max_size: + self.keys = self._temporal_order(self.keys) + self.values = self._temporal_order(self.values) + n = n % (self.max_size - self.keep) + + # do trim: put us back into the state before the circular buffer is full + new_length = self.keys.shape[2] - n + self.keys = self.keys[..., :new_length, :] + self.values = self.values[..., :new_length, :] + + self.offset -= n + self._idx = new_length + return n + +def make_prompt_cache( + model: nn.Module, + max_kv_size: Optional[int] = None, +) -> List[Any]: + """ + Construct the model's cache for use in generation. + This function will defer the cache construction to the model if it has a + ``make_cache`` method, otherwise it will make a default KV cache. + Args: + model (nn.Module): The language model. + max_kv_size (Optional[int]): If provided and the model does not have a + ``make_cache`` method, a ``TrimmableRotatingKVCache`` is used + with a maximum size of ``max_kv_size`` + """ + if hasattr(model, "make_cache"): + return model.make_cache() + num_layers = len(model.layers) + if max_kv_size is not None: + cache = [] + for layer in range(num_layers): + rope = maybe_get_rope(model, layer) + if rope is None: + return [KVCache() for _ in range(num_layers)] + # TODO(christian-lms): change keep on the fly, must be setattr elsewhere + cache.append(ShiftingKVCache(rope, max_size=max_kv_size, keep=4)) + return cache + else: + return [KVCache() for _ in range(num_layers)] + + class CacheWrapper: """ Wrapper class for the MLX LM cache to maintain an in-memory cache @@ -256,6 +434,8 @@ def update_cache( def prompt_progress_callback(x): return None + + # TODO(christian-lms): truncation logic goes here now num_tokens_to_exclude = max(num_tokens_to_exclude, 1) prompt_tokens = self._get_unprocessed_tokens( From 19cde88a8d0999c812d8653ddfd2f6583de805cd Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Mon, 7 Jul 2025 12:25:57 -0400 Subject: [PATCH 03/39] override another method to use rope shift --- mlx_engine/cache_wrapper.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 44a3710e..ffed7fa1 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -61,6 +61,17 @@ def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): def is_trimmable(self) -> bool: return True + def _trim(self, trim_size, v, append=None): + to_cat = [] + shift_by = -trim_size + if trim_size > 0: + to_cat = [v[..., : self.keep, :], self.rope(v[..., trim_size + self.keep :, :], shift_by)] + else: + to_cat = [v] + if append is not None: + to_cat.append(append) + return mx.concatenate(to_cat, axis=2) + def _temporal_order(self, v) -> mx.array: """ Rearrange the cache into temporal order, slicing off the end if unused. @@ -102,7 +113,7 @@ def reuse_section(self, write_start_idx: int, reuse_start_idx: int, reuse_length def do_reuse(self) -> None: last_i: int = len(self.reuse_queue) - 1 for i, (write_start_idx, reuse_start_idx, reuse_length) in enumerate(self.reuse_queue): - shift_by: int = write_start_idx - reuse_start_idx + shift_by: int = write_start_idx - reuse_start_idx # < 0 reuse_end_idx: int = reuse_start_idx + reuse_length keys_to_shift = self.keys[..., reuse_start_idx : reuse_end_idx, :] From f4822f06f1c745b03bda3c2127cb164f04f1ca82 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Mon, 7 Jul 2025 12:30:01 -0400 Subject: [PATCH 04/39] add testing asserts --- mlx_engine/cache_wrapper.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index ffed7fa1..e142471f 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -64,6 +64,7 @@ def is_trimmable(self) -> bool: def _trim(self, trim_size, v, append=None): to_cat = [] shift_by = -trim_size + assert shift_by <= 0 if trim_size > 0: to_cat = [v[..., : self.keep, :], self.rope(v[..., trim_size + self.keep :, :], shift_by)] else: @@ -80,6 +81,7 @@ def _temporal_order(self, v) -> mx.array: return v elif self._idx < self.offset: shift_by = self.keep - self.idx + assert shift_by <= 0 return mx.concatenate( [ v[..., : self.keep, :], @@ -112,8 +114,10 @@ def reuse_section(self, write_start_idx: int, reuse_start_idx: int, reuse_length def do_reuse(self) -> None: last_i: int = len(self.reuse_queue) - 1 + for i, (write_start_idx, reuse_start_idx, reuse_length) in enumerate(self.reuse_queue): - shift_by: int = write_start_idx - reuse_start_idx # < 0 + shift_by: int = write_start_idx - reuse_start_idx + assert shift_by <= 0 reuse_end_idx: int = reuse_start_idx + reuse_length keys_to_shift = self.keys[..., reuse_start_idx : reuse_end_idx, :] @@ -121,7 +125,7 @@ def do_reuse(self) -> None: # perform rope shift # N.B. we can also go back to the MLX-native "don't rope shift" method - # by + # by removing RoPE here and removing the overrides for trim, temporal order shifted_keys = self.rope(keys_to_shift, shift_by) shifted_values = self.rope(values_to_shift, shift_by) @@ -137,6 +141,7 @@ def do_reuse(self) -> None: shifted_values ] + # TODO(christian-lms): surely there is a better way to do this? # by not re-appending the end at the last one, we truncate the leftovers if i != last_i: keycat.append(self.keys[..., reuse_end_idx : , :]) @@ -152,9 +157,8 @@ def do_reuse(self) -> None: self._idx = self.keys.shape[2] def trim(self, n) -> int: - # TODO(christian-lms): fix me n = min(self.offset, n) - if n == 0: + if n <= 0: return 0 if self.offset >= self.max_size: @@ -164,10 +168,11 @@ def trim(self, n) -> int: # do trim: put us back into the state before the circular buffer is full new_length = self.keys.shape[2] - n - self.keys = self.keys[..., :new_length, :] - self.values = self.values[..., :new_length, :] + self.keys = self.keys[..., : new_length, :] + self.values = self.values[..., : new_length, :] self.offset -= n + # TODO(christian-lms): verify that this is reasonable self._idx = new_length return n From e297092ad812b83c6715b55813362944114fb2d0 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Mon, 7 Jul 2025 12:32:43 -0400 Subject: [PATCH 05/39] warn --- mlx_engine/cache_wrapper.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index e142471f..49b9ae9b 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -43,7 +43,6 @@ def maybe_get_rope(model: nn.Module, layer_idx: int) -> Optional[nn.Module]: """ # we can assume model has attribute layers because make_prompt_cache does if layer_idx > len(model.layers): - # TODO(christian-lms): fail silently or throw here? return None layer = model.layers[layer_idx] if not isinstance(layer, nn.Module): @@ -197,7 +196,12 @@ def make_prompt_cache( cache = [] for layer in range(num_layers): rope = maybe_get_rope(model, layer) + # TODO(christian-lms): it is known that this will fail for some models + # like llama4 which has no rope module for every fourth layer. + # this will be figured out Later(tm) once the initial functionality works if rope is None: + log_warn("Attempted to build a KV cache of shiftable caches, but found" + f"None at layer {layer} of model {model}") return [KVCache() for _ in range(num_layers)] # TODO(christian-lms): change keep on the fly, must be setattr elsewhere cache.append(ShiftingKVCache(rope, max_size=max_kv_size, keep=4)) From ee316db945653f730405ddeae4367dc01e1c9fcf Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Mon, 7 Jul 2025 12:38:26 -0400 Subject: [PATCH 06/39] move cache into a separate file --- mlx_engine/cache.py | 206 ++++++++++++++++++++++++++++++++++++ mlx_engine/cache_wrapper.py | 202 +---------------------------------- 2 files changed, 208 insertions(+), 200 deletions(-) create mode 100644 mlx_engine/cache.py diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py new file mode 100644 index 00000000..444ab8e1 --- /dev/null +++ b/mlx_engine/cache.py @@ -0,0 +1,206 @@ +from typing import List, Optional, Any + +from mlx_engine.logging import log_info, log_warn, log_error +from mlx_lm.models.cache import ( + RotatingKVCache, + KVCache +) +import mlx.core as mx +import mlx.nn as nn + + +# TODO(christian-lms) DO NOT HARDCODE ME (or at least move it somewhere else) +MAYBE_ATTN_NAMES = ["self_attn", "attention", "attn", "mixer", "norm_attn_norm"] +MAYBE_ROPE_NAMES = ["rope", "rotary_emb"] + + +def _maybe_get_rope(layer: nn.Module) -> Optional[nn.Module]: + for maybe_rope_name in MAYBE_ROPE_NAMES: + if hasattr(layer, maybe_rope_name): + # found it + return getattr(layer, maybe_rope_name) + for maybe_attn_name in MAYBE_ATTN_NAMES: + if hasattr(layer, maybe_attn_name): + # move down one level + return _maybe_get_rope(getattr(layer, maybe_attn_name)) + # no dice + return None + + +def maybe_get_rope(model: nn.Module, layer_idx: int) -> Optional[nn.Module]: + """Attempt to find the RoPE module from a layer of an MLX-LM LLM. + + Args: + model (nn.Module): The LLM to search for the RoPE modules of. + layer_idx (int): The layer of the LLM to get the RoPE module from. + + Returns: + Optional[nn.Module]: The RoPE module if found, else None + """ + # we can assume model has attribute layers because make_prompt_cache does + if layer_idx > len(model.layers): + return None + layer = model.layers[layer_idx] + if not isinstance(layer, nn.Module): + return None + return _maybe_get_rope(layer) + + +class ShiftingKVCache(RotatingKVCache): + def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): + self.rope = rope + self.reuse_offset = 0 + self.reuse_queue = [] + super().__init__(self, max_size, keep, step) + + def is_trimmable(self) -> bool: + return True + + def _trim(self, trim_size, v, append=None): + to_cat = [] + shift_by = -trim_size + assert shift_by <= 0 + if trim_size > 0: + to_cat = [v[..., : self.keep, :], self.rope(v[..., trim_size + self.keep :, :], shift_by)] + else: + to_cat = [v] + if append is not None: + to_cat.append(append) + return mx.concatenate(to_cat, axis=2) + + def _temporal_order(self, v) -> mx.array: + """ + Rearrange the cache into temporal order, slicing off the end if unused. + """ + if self._idx == v.shape[2]: + return v + elif self._idx < self.offset: + shift_by = self.keep - self.idx + assert shift_by <= 0 + return mx.concatenate( + [ + v[..., : self.keep, :], + # TODO(christian-lms): verify that i work + # TODO(christian-lms): can you do this in 1 call to self.rope? + # N.B. this implicitly assumes the generation has not gone over twice + # the size of the rotating section of the cache, in which case the + # rotating section would be off by a multiple of (max_kv_size - keep) + # depending on how many times it rolled over. I feel like it's pretty + # safe to assume that this is a rare case + self.rope(v[..., self._idx :, :], shift_by), + self.rope(v[..., self.keep : self._idx, :], shift_by), + ], + axis=2, + ) + else: + return v[..., : self._idx, :] + + def reuse_section(self, write_start_idx: int, reuse_start_idx: int, reuse_length: int) -> None: + # offset indices to account for the fact that we move cache elements around + write_start_idx -= self.reuse_offset + reuse_start_idx -= self.reuse_offset + + # update position offsets for future reuse sections + shift_by = write_start_idx - reuse_start_idx + self.reuse_offset += shift_by + + # queue for reuse: everything is done in one pass at the end in do_reuse + self.reuse_queue.append((write_start_idx, reuse_start_idx, reuse_length)) + + def do_reuse(self) -> None: + last_i: int = len(self.reuse_queue) - 1 + + for i, (write_start_idx, reuse_start_idx, reuse_length) in enumerate(self.reuse_queue): + shift_by: int = write_start_idx - reuse_start_idx + assert shift_by <= 0 + reuse_end_idx: int = reuse_start_idx + reuse_length + + keys_to_shift = self.keys[..., reuse_start_idx : reuse_end_idx, :] + values_to_shift = self.values[..., reuse_start_idx : reuse_end_idx, :] + + # perform rope shift + # N.B. we can also go back to the MLX-native "don't rope shift" method + # by removing RoPE here and removing the overrides for trim, temporal order + shifted_keys = self.rope(keys_to_shift, shift_by) + shifted_values = self.rope(values_to_shift, shift_by) + + # restructure cache with mx.concat + # TODO(christian-lms): maybe it would be better to use inplace ops. + # look into the mlx docs if that's even a thing + keycat = [ + self.keys[..., : write_start_idx, :], + shifted_keys + ] + valcat = [ + self.values[..., : write_start_idx, :], + shifted_values + ] + + # TODO(christian-lms): surely there is a better way to do this? + # by not re-appending the end at the last one, we truncate the leftovers + if i != last_i: + keycat.append(self.keys[..., reuse_end_idx : , :]) + valcat.append(self.values[..., reuse_end_idx : , :]) + + self.keys = mx.concat(keycat, axis=2) + self.values = mx.concat(valcat, axis=2) + + self.offset -= shift_by + self.reuse_offset = 0 + self.reuse_queue = [] + # TODO(christian-lms): dunno if this number is correct/reasonable/whatever + self._idx = self.keys.shape[2] + + def trim(self, n) -> int: + n = min(self.offset, n) + if n <= 0: + return 0 + + if self.offset >= self.max_size: + self.keys = self._temporal_order(self.keys) + self.values = self._temporal_order(self.values) + n = n % (self.max_size - self.keep) + + # do trim: put us back into the state before the circular buffer is full + new_length = self.keys.shape[2] - n + self.keys = self.keys[..., : new_length, :] + self.values = self.values[..., : new_length, :] + + self.offset -= n + # TODO(christian-lms): verify that this is reasonable + self._idx = new_length + return n + +def make_prompt_cache( + model: nn.Module, + max_kv_size: Optional[int] = None, +) -> List[Any]: + """ + Construct the model's cache for use in generation. + This function will defer the cache construction to the model if it has a + ``make_cache`` method, otherwise it will make a default KV cache. + Args: + model (nn.Module): The language model. + max_kv_size (Optional[int]): If provided and the model does not have a + ``make_cache`` method, a ``TrimmableRotatingKVCache`` is used + with a maximum size of ``max_kv_size`` + """ + if hasattr(model, "make_cache"): + return model.make_cache() + num_layers = len(model.layers) + if max_kv_size is not None: + cache = [] + for layer in range(num_layers): + rope = maybe_get_rope(model, layer) + # TODO(christian-lms): it is known that this will fail for some models + # like llama4 which has no rope module for every fourth layer. + # this will be figured out Later(tm) once the initial functionality works + if rope is None: + log_warn("Attempted to build a KV cache of shiftable caches, but found" + f"None at layer {layer} of model {model}") + return [KVCache() for _ in range(num_layers)] + # TODO(christian-lms): change keep on the fly, must be setattr elsewhere + cache.append(ShiftingKVCache(rope, max_size=max_kv_size, keep=4)) + return cache + else: + return [KVCache() for _ in range(num_layers)] \ No newline at end of file diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 49b9ae9b..04b5b1cb 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -1,11 +1,10 @@ from typing import List, Optional, Any from mlx_engine.logging import log_info, log_warn, log_error +from mlx_engine.cache import make_prompt_cache from mlx_lm.models.cache import ( trim_prompt_cache, - can_trim_prompt_cache, - RotatingKVCache, - KVCache + can_trim_prompt_cache ) from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache import mlx.core as mx @@ -13,203 +12,6 @@ import sys -# TODO(christian-lms) DO NOT HARDCODE ME (or at least move it somewhere else) -MAYBE_ATTN_NAMES = ["self_attn", "attention", "attn", "mixer", "norm_attn_norm"] -MAYBE_ROPE_NAMES = ["rope", "rotary_emb"] - - -def _maybe_get_rope(layer: nn.Module) -> Optional[nn.Module]: - for maybe_rope_name in MAYBE_ROPE_NAMES: - if hasattr(layer, maybe_rope_name): - # found it - return getattr(layer, maybe_rope_name) - for maybe_attn_name in MAYBE_ATTN_NAMES: - if hasattr(layer, maybe_attn_name): - # move down one level - return _maybe_get_rope(getattr(layer, maybe_attn_name)) - # no dice - return None - - -def maybe_get_rope(model: nn.Module, layer_idx: int) -> Optional[nn.Module]: - """Attempt to find the RoPE module from a layer of an MLX-LM LLM. - - Args: - model (nn.Module): The LLM to search for the RoPE modules of. - layer_idx (int): The layer of the LLM to get the RoPE module from. - - Returns: - Optional[nn.Module]: The RoPE module if found, else None - """ - # we can assume model has attribute layers because make_prompt_cache does - if layer_idx > len(model.layers): - return None - layer = model.layers[layer_idx] - if not isinstance(layer, nn.Module): - return None - return _maybe_get_rope(layer) - - -class ShiftingKVCache(RotatingKVCache): - def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): - self.rope = rope - self.reuse_offset = 0 - self.reuse_queue = [] - super().__init__(self, max_size, keep, step) - - def is_trimmable(self) -> bool: - return True - - def _trim(self, trim_size, v, append=None): - to_cat = [] - shift_by = -trim_size - assert shift_by <= 0 - if trim_size > 0: - to_cat = [v[..., : self.keep, :], self.rope(v[..., trim_size + self.keep :, :], shift_by)] - else: - to_cat = [v] - if append is not None: - to_cat.append(append) - return mx.concatenate(to_cat, axis=2) - - def _temporal_order(self, v) -> mx.array: - """ - Rearrange the cache into temporal order, slicing off the end if unused. - """ - if self._idx == v.shape[2]: - return v - elif self._idx < self.offset: - shift_by = self.keep - self.idx - assert shift_by <= 0 - return mx.concatenate( - [ - v[..., : self.keep, :], - # TODO(christian-lms): verify that i work - # TODO(christian-lms): can you do this in 1 call to self.rope? - # N.B. this implicitly assumes the generation has not gone over twice - # the size of the rotating section of the cache, in which case the - # rotating section would be off by a multiple of (max_kv_size - keep) - # depending on how many times it rolled over. I feel like it's pretty - # safe to assume that this is a rare case - self.rope(v[..., self._idx :, :], shift_by), - self.rope(v[..., self.keep : self._idx, :], shift_by), - ], - axis=2, - ) - else: - return v[..., : self._idx, :] - - def reuse_section(self, write_start_idx: int, reuse_start_idx: int, reuse_length: int) -> None: - # offset indices to account for the fact that we move cache elements around - write_start_idx -= self.reuse_offset - reuse_start_idx -= self.reuse_offset - - # update position offsets for future reuse sections - shift_by = write_start_idx - reuse_start_idx - self.reuse_offset += shift_by - - # queue for reuse: everything is done in one pass at the end in do_reuse - self.reuse_queue.append((write_start_idx, reuse_start_idx, reuse_length)) - - def do_reuse(self) -> None: - last_i: int = len(self.reuse_queue) - 1 - - for i, (write_start_idx, reuse_start_idx, reuse_length) in enumerate(self.reuse_queue): - shift_by: int = write_start_idx - reuse_start_idx - assert shift_by <= 0 - reuse_end_idx: int = reuse_start_idx + reuse_length - - keys_to_shift = self.keys[..., reuse_start_idx : reuse_end_idx, :] - values_to_shift = self.values[..., reuse_start_idx : reuse_end_idx, :] - - # perform rope shift - # N.B. we can also go back to the MLX-native "don't rope shift" method - # by removing RoPE here and removing the overrides for trim, temporal order - shifted_keys = self.rope(keys_to_shift, shift_by) - shifted_values = self.rope(values_to_shift, shift_by) - - # restructure cache with mx.concat - # TODO(christian-lms): maybe it would be better to use inplace ops. - # look into the mlx docs if that's even a thing - keycat = [ - self.keys[..., : write_start_idx, :], - shifted_keys - ] - valcat = [ - self.values[..., : write_start_idx, :], - shifted_values - ] - - # TODO(christian-lms): surely there is a better way to do this? - # by not re-appending the end at the last one, we truncate the leftovers - if i != last_i: - keycat.append(self.keys[..., reuse_end_idx : , :]) - valcat.append(self.values[..., reuse_end_idx : , :]) - - self.keys = mx.concat(keycat, axis=2) - self.values = mx.concat(valcat, axis=2) - - self.offset -= shift_by - self.reuse_offset = 0 - self.reuse_queue = [] - # TODO(christian-lms): dunno if this number is correct/reasonable/whatever - self._idx = self.keys.shape[2] - - def trim(self, n) -> int: - n = min(self.offset, n) - if n <= 0: - return 0 - - if self.offset >= self.max_size: - self.keys = self._temporal_order(self.keys) - self.values = self._temporal_order(self.values) - n = n % (self.max_size - self.keep) - - # do trim: put us back into the state before the circular buffer is full - new_length = self.keys.shape[2] - n - self.keys = self.keys[..., : new_length, :] - self.values = self.values[..., : new_length, :] - - self.offset -= n - # TODO(christian-lms): verify that this is reasonable - self._idx = new_length - return n - -def make_prompt_cache( - model: nn.Module, - max_kv_size: Optional[int] = None, -) -> List[Any]: - """ - Construct the model's cache for use in generation. - This function will defer the cache construction to the model if it has a - ``make_cache`` method, otherwise it will make a default KV cache. - Args: - model (nn.Module): The language model. - max_kv_size (Optional[int]): If provided and the model does not have a - ``make_cache`` method, a ``TrimmableRotatingKVCache`` is used - with a maximum size of ``max_kv_size`` - """ - if hasattr(model, "make_cache"): - return model.make_cache() - num_layers = len(model.layers) - if max_kv_size is not None: - cache = [] - for layer in range(num_layers): - rope = maybe_get_rope(model, layer) - # TODO(christian-lms): it is known that this will fail for some models - # like llama4 which has no rope module for every fourth layer. - # this will be figured out Later(tm) once the initial functionality works - if rope is None: - log_warn("Attempted to build a KV cache of shiftable caches, but found" - f"None at layer {layer} of model {model}") - return [KVCache() for _ in range(num_layers)] - # TODO(christian-lms): change keep on the fly, must be setattr elsewhere - cache.append(ShiftingKVCache(rope, max_size=max_kv_size, keep=4)) - return cache - else: - return [KVCache() for _ in range(num_layers)] - - class CacheWrapper: """ Wrapper class for the MLX LM cache to maintain an in-memory cache From 3df5d31529cd6e6fa31f6544c7c7742b6cd36463 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Mon, 7 Jul 2025 13:27:45 -0400 Subject: [PATCH 07/39] begin raw unit tests --- tests/test_cache_shift.py | 56 +++++++++++++++++++++++++++++++++++++ tests/test_cache_wrapper.py | 3 ++ 2 files changed, 59 insertions(+) create mode 100644 tests/test_cache_shift.py diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py new file mode 100644 index 00000000..991c1620 --- /dev/null +++ b/tests/test_cache_shift.py @@ -0,0 +1,56 @@ +import unittest +import mlx.core as mx +import mlx.nn as nn +from mlx_engine.cache import ShiftingKVCache + +def idx(v: mx.array, i: int): + return v[:, :, i, :] + + +class ShiftingCacheTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Set up test resources that will be shared across all test methods""" + cls.rope = nn.RoPE(dims=8, traditional=False, base=10000, scale=1.0) + cls.bsz = 1 + cls.n_kv_heads = 1 + cls.kv_head_dim = 4 + + def make_random_kv(self, seqlen: int): + """Helper method to make a random key/value tensor of the right shape""" + return mx.random.normal((self.bsz, self.n_kv_heads, seqlen, self.kv_head_dim), dtype=mx.float16) + + # TODO: you can test to make sure that it's RoPEing right in the model overall by getting + # the post-shift value, then shifting it back to position 0 and checking the layer 0 kv + # matches the raw token embedding + + def test_overwriting(self): + cache = ShiftingKVCache(self.rope, max_size=3, keep=0) + base_kv = self.make_random_kv(3) + overwrite = self.make_random_kv(1) + overwrite_posemb_4 = self.rope(overwrite, 4) + cache.update_and_fetch(base_kv, base_kv) + keys, _ = cache.update_and_fetch(overwrite, overwrite) + self.assertEqual(overwrite_posemb_4, keys[:, :, 0, :]) + + def test_temporal_order_shift(self): + cache = ShiftingKVCache(self.rope, max_size=3, keep=0) + base_kv = self.make_random_kv(3) + overwrite = self.make_random_kv(1) + overwrite_posemb_3 = self.rope(overwrite, 3) + cache.update_and_fetch(base_kv, base_kv) + cache.update_and_fetch(overwrite, overwrite) + cache.keys = cache._temporal_order(cache.keys) + self.assertEqual(overwrite_posemb_3, cache.keys) + + def test_trim_internal(self): + pass + + def test_trim_before_full(self): + pass + + def test_trim_after_full(self): + pass + + def test_reuse(self): + pass \ No newline at end of file diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index d1d6456c..9acad9ac 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -37,6 +37,9 @@ def test_find_common_prefix_all_match(self): self.assertEqual( result, 4 ) # Should find 4 matching tokens (5-1 due to num_tokens_to_exclude) + + # TODO(christian-lms): write tests for cache shifting, which is high-level + # implemented in cachewrapper and so belongs here if __name__ == "__main__": From 514b6c53cf1d420eea29397ae869d3f4860f9255 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Mon, 7 Jul 2025 15:59:22 -0400 Subject: [PATCH 08/39] initial uncommented tests --- mlx_engine/cache.py | 32 ++++++++-- tests/test_cache_shift.py | 129 ++++++++++++++++++++++++++++++++------ 2 files changed, 137 insertions(+), 24 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 444ab8e1..a25e10d9 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -48,10 +48,26 @@ def maybe_get_rope(model: nn.Module, layer_idx: int) -> Optional[nn.Module]: class ShiftingKVCache(RotatingKVCache): def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): - self.rope = rope + self._rope = rope self.reuse_offset = 0 self.reuse_queue = [] - super().__init__(self, max_size, keep, step) + super().__init__(max_size, keep, step) + + def rope(self, v: mx.array, shift_by: int) -> mx.array: + # TODO(christian-lms): this is reeeeeeallllyyyy stupid. spin a proper block impl + if shift_by == 0: + return v + + # apply RoPE to each token individually with the same offset + shifted_tokens = [] + seq_len = v.shape[2] # sequence dimension + + for i in range(seq_len): + token = v[:, :, i:i+1, :] # shape [batch, heads, 1, head_dim] + shifted_token = self._rope(token, shift_by) + shifted_tokens.append(shifted_token) + + return mx.concatenate(shifted_tokens, axis=2) def is_trimmable(self) -> bool: return True @@ -59,7 +75,6 @@ def is_trimmable(self) -> bool: def _trim(self, trim_size, v, append=None): to_cat = [] shift_by = -trim_size - assert shift_by <= 0 if trim_size > 0: to_cat = [v[..., : self.keep, :], self.rope(v[..., trim_size + self.keep :, :], shift_by)] else: @@ -75,7 +90,7 @@ def _temporal_order(self, v) -> mx.array: if self._idx == v.shape[2]: return v elif self._idx < self.offset: - shift_by = self.keep - self.idx + shift_by = self.keep - self._idx assert shift_by <= 0 return mx.concatenate( [ @@ -152,14 +167,21 @@ def do_reuse(self) -> None: self._idx = self.keys.shape[2] def trim(self, n) -> int: + # TODO(christian-lms): should trim respect keep? currently, no n = min(self.offset, n) if n <= 0: return 0 + # TODO(christian-lms): so you used to need to wrap around because the code + # didn't know how much it was trying to trim, so it would go over the maximum allowed. + # but i think this was in large part due to improperly tracking the tokens that were + # actually in the cache, so this should not be an issue anymore. therefore this trim code + # will trim exactly n off the end wthout any wrapping around. but you can uncomment the line + # if it turns out that this assumption is faulty if self.offset >= self.max_size: self.keys = self._temporal_order(self.keys) self.values = self._temporal_order(self.values) - n = n % (self.max_size - self.keep) + # n = n % (self.max_size - self.keep) # do trim: put us back into the state before the circular buffer is full new_length = self.keys.shape[2] - n diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 991c1620..65443d3b 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -3,54 +3,145 @@ import mlx.nn as nn from mlx_engine.cache import ShiftingKVCache + def idx(v: mx.array, i: int): - return v[:, :, i, :] + """Helper function to index into a 4D tensor at the sequence length dimension""" + return v[:, :, i:i+1, :] class ShiftingCacheTest(unittest.TestCase): @classmethod def setUpClass(cls): """Set up test resources that will be shared across all test methods""" - cls.rope = nn.RoPE(dims=8, traditional=False, base=10000, scale=1.0) + cls.kv_head_dim = 4 cls.bsz = 1 cls.n_kv_heads = 1 - cls.kv_head_dim = 4 + # TODO: this won't work.............. nn.RoPE decides that it will increase the offset for each position + cls._rope = nn.RoPE(dims=cls.kv_head_dim, traditional=False, base=100000, scale=1.0) + + @classmethod + def rope(cls, v: mx.array, shift_by: int = 0) -> mx.array: + """Apply RoPE to the input tensor with an optional shift""" + # TODO(christian-lms): this is reeeeeeallllyyyy stupid. spin a proper block impl + if shift_by == 0: + return v + # Apply RoPE to each token individually with the same offset + shifted_tokens = [] + assert len(v.shape) == 4, "Expected input tensor to have 4 dimensions: [batch, heads, seq_len, head_dim]" + seq_len = v.shape[2] + + for i in range(seq_len): + shifted_token = cls._rope(idx(v, i), shift_by) + shifted_tokens.append(shifted_token) - def make_random_kv(self, seqlen: int): - """Helper method to make a random key/value tensor of the right shape""" - return mx.random.normal((self.bsz, self.n_kv_heads, seqlen, self.kv_head_dim), dtype=mx.float16) + return mx.concatenate(shifted_tokens, axis=2) + @classmethod + def make_random_kv(cls, seqlen: int): + """Helper method to make a random key/value tensor of the right shape""" + return mx.random.normal((cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim), scale=1.0, dtype=mx.float32) + + def assertArrEqual(self, a: mx.array, b: mx.array): + """Assert that two tensors are equal over the sequence length dimension""" + self.assertEqual(a.shape, b.shape) + self.assertTrue(mx.allclose(a, b), "Tensors are not equal") + # TODO: you can test to make sure that it's RoPEing right in the model overall by getting # the post-shift value, then shifting it back to position 0 and checking the layer 0 kv # matches the raw token embedding - + def test_overwriting(self): - cache = ShiftingKVCache(self.rope, max_size=3, keep=0) + cache = ShiftingKVCache(self._rope, max_size=3, keep=1) base_kv = self.make_random_kv(3) overwrite = self.make_random_kv(1) - overwrite_posemb_4 = self.rope(overwrite, 4) cache.update_and_fetch(base_kv, base_kv) keys, _ = cache.update_and_fetch(overwrite, overwrite) - self.assertEqual(overwrite_posemb_4, keys[:, :, 0, :]) + + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 1), overwrite) + self.assertArrEqual(idx(keys, 2), idx(base_kv, 2)) def test_temporal_order_shift(self): - cache = ShiftingKVCache(self.rope, max_size=3, keep=0) + cache = ShiftingKVCache(self._rope, max_size=3, keep=1) base_kv = self.make_random_kv(3) overwrite = self.make_random_kv(1) - overwrite_posemb_3 = self.rope(overwrite, 3) + overwrite_roped = self.rope(overwrite, -1) cache.update_and_fetch(base_kv, base_kv) cache.update_and_fetch(overwrite, overwrite) - cache.keys = cache._temporal_order(cache.keys) - self.assertEqual(overwrite_posemb_3, cache.keys) + print(base_kv) + print(cache.keys) + keys = cache._temporal_order(cache.keys) + + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 1), self.rope(idx(base_kv, 2), -1)) + self.assertArrEqual(idx(keys, 2), overwrite_roped) - def test_trim_internal(self): - pass + def test_trim_internal_shift(self): + cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + base_kv = self.make_random_kv(3) + cache.update_and_fetch(base_kv, base_kv) + + keys = cache._trim(1, cache.keys) + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 1), self.rope(idx(base_kv, 2), -1)) + + def test_ensure_reasonable_size_and_shift(self): + cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + base_kv = self.make_random_kv(10) + keys, _ = cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 10, self.kv_head_dim)) + overwrite = self.make_random_kv(1) + keys, _ = cache.update_and_fetch(overwrite, overwrite) + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 3, self.kv_head_dim)) + + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 1), overwrite) + self.assertArrEqual(idx(keys, 2), self.rope(idx(base_kv, 9), -7)) def test_trim_before_full(self): - pass + cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + base_kv = self.make_random_kv(2) + cache.update_and_fetch(base_kv, base_kv) + + cache.trim(1) + keys = cache.keys + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 1, self.kv_head_dim)) + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + + new_kv = self.make_random_kv(1) + keys, _ = cache.update_and_fetch(new_kv, new_kv) + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 1), new_kv) def test_trim_after_full(self): - pass + cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + base_kv = self.make_random_kv(4) + cache.update_and_fetch(base_kv, base_kv) + + cache.trim(2) + keys = cache.keys + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(keys, base_kv[:,:,:2,:]) + + new_kv = self.make_random_kv(2) + keys, _ = cache.update_and_fetch(new_kv, new_kv) + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 4, self.kv_head_dim)) + self.assertArrEqual(keys[:,:,:2,:], base_kv[:,:,:2,:]) + self.assertArrEqual(keys[:,:,2:,:], new_kv) def test_reuse(self): - pass \ No newline at end of file + cache = ShiftingKVCache(self._rope, max_size=6, keep=1) + base_kv = self.make_random_kv(8) + original_prompt_cache = base_kv[:, :, :6, :] + cache.update_and_fetch(base_kv, base_kv) + new_prompt_cache = mx.concatenate([base_kv[:,:,:3,:], self.rope(base_kv[:,:,4:,:], -1)], axis=2) + # here we know what to reuse so hardcode it, dynamic reuse is in test_cache_wrapper + cache.reuse_section(3, 4, 2) + cache.do_reuse() + keys = cache.keys + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) + self.assertArrEqual(keys, new_prompt_cache[:,:,:5,:]) From 3f6456613d04038a8a9faa5b6f1f2cc285e1f606 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Mon, 7 Jul 2025 17:41:16 -0400 Subject: [PATCH 09/39] manual rope stuff --- mlx_engine/cache.py | 16 +----- mlx_engine/rope.py | 107 ++++++++++++++++++++++++++++++++++++++ tests/test_cache_shift.py | 14 +---- 3 files changed, 110 insertions(+), 27 deletions(-) create mode 100644 mlx_engine/rope.py diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index a25e10d9..a7d3a3ea 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -55,20 +55,8 @@ def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): def rope(self, v: mx.array, shift_by: int) -> mx.array: # TODO(christian-lms): this is reeeeeeallllyyyy stupid. spin a proper block impl - if shift_by == 0: - return v - - # apply RoPE to each token individually with the same offset - shifted_tokens = [] - seq_len = v.shape[2] # sequence dimension - - for i in range(seq_len): - token = v[:, :, i:i+1, :] # shape [batch, heads, 1, head_dim] - shifted_token = self._rope(token, shift_by) - shifted_tokens.append(shifted_token) - - return mx.concatenate(shifted_tokens, axis=2) - + return mx.concatenate([self._rope(v[:,:,i:i+1,:], shift_by) for i in range(v.shape[2])], axis=2) + def is_trimmable(self) -> bool: return True diff --git a/mlx_engine/rope.py b/mlx_engine/rope.py new file mode 100644 index 00000000..eccf7331 --- /dev/null +++ b/mlx_engine/rope.py @@ -0,0 +1,107 @@ +"""So... + +...this isn't optimized yet. It turns out that at small sequence lengths literally just naively +applying RoPE to individual tokens is faster than doing the matrix multiplication, even with the overhead +introduced by the for loop and the theoretical optimization of the matrix multiplication. It does +begin to tip in favor of this shifting method @ larger seqlens (tested at [1,8,1000,128]) but the +overhead converting between MLX and torch still makes it slower overall than MLX-native. + +- the MLX rope shift is weird anyway, but it apparently still works: TODO ask awni why +- this implementation is naive and doesn't leverage the sparsity of the RoPE matrix +- honestly it's probably easier to just use the MLX RoPE shift directly in the model + because this allows us to not have to write custom modules for YaRN and llama3 and + what have you, but i'll leave this here for now in case it becomes useful later +""" + +import torch +import mlx.core as mx +from mlx_lm import load +import numpy as np + +def mlx_rope_shift(x, shift_amount, theta=10000.0, scale=1.0, traditional=False): + """ + MLX-compatible RoPE implementation using matrix multiplication. + Creates a rotation matrix and applies it via matmul to shift all positions by shift_amount. + + Args: + x: Input tensor of shape [bsz, n_kv_heads, seqlen, kv_head_dim] + shift_amount: Number of positions to shift (D) + theta: Base frequency for RoPE (default: 10000.0) + scale: Scaling factor for frequencies (default: 1.0) + traditional: If True, use traditional RoPE pairing (0,1), (2,3), ... + If False, use MLX-style pairing (0,d/2), (1,d/2+1), ... (default: False) + + Returns: + Rotated tensor of same shape as input + """ + bsz, n_heads, seqlen, head_dim = x.shape + device = x.device + + assert head_dim % 2 == 0, "Head dimension must be even" + dim_pairs = head_dim // 2 + + if traditional: + # traditional RoPE: pair adjacent dimensions (0,1), (2,3), (4,5), ... + frequencies = 1.0 / (theta ** (torch.arange(0, dim_pairs, dtype=torch.float32, device=device) * 2.0 / head_dim)) + else: + # MLX-style RoPE: pair first half with second half (0,d/2), (1,d/2+1), ... + frequencies = 1.0 / (theta ** (torch.arange(0, dim_pairs, dtype=torch.float32, device=device) * 2.0 / head_dim)) + + frequencies = frequencies * scale + angles = shift_amount * frequencies # shape: [dim_pairs] + cos_vals = torch.cos(angles) # shape: [dim_pairs] + sin_vals = torch.sin(angles) # shape: [dim_pairs] + + rotation_matrix = torch.eye(head_dim, device=device, dtype=x.dtype) + + if traditional: + for i in range(dim_pairs): + even_idx = i * 2 + odd_idx = i * 2 + 1 + + rotation_matrix[even_idx, even_idx] = cos_vals[i] + rotation_matrix[even_idx, odd_idx] = -sin_vals[i] + rotation_matrix[odd_idx, even_idx] = sin_vals[i] + rotation_matrix[odd_idx, odd_idx] = cos_vals[i] + else: + for i in range(dim_pairs): + first_idx = i + second_idx = i + dim_pairs + + cos_val = cos_vals[i] + sin_val = sin_vals[i] + + rotation_matrix[first_idx, first_idx] = cos_val + rotation_matrix[first_idx, second_idx] = -sin_val + rotation_matrix[second_idx, first_idx] = sin_val + rotation_matrix[second_idx, second_idx] = cos_val + + rotated = x @ rotation_matrix.T + + return rotated + + +def stupid_rope(r, v, shift_by: int = 0): + return mx.concatenate([r(v[:,:,i:i+1,:], shift_by) for i in range(v.shape[2])], axis=2) + +def main(): + model, _ = load("mlx-community/Qwen3-0.6B-bf16") + + v = mx.random.normal((1, 8, 10, 128), scale=1.0, dtype=mx.float32) + + import time + start_time = time.time() + silly = stupid_rope(model.layers[0].self_attn.rope, v, 7) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"MLX RoPE shift took {elapsed_time:.6f} seconds") + converted = torch.from_numpy(np.array(v)) + start_time = time.time() + eff = mlx_rope_shift(converted, 7, theta=1000000.0, scale=1.0, traditional=False) + end_time = time.time() + elapsed_time2 = end_time - start_time + print(f"Torch RoPE shift took {elapsed_time2:.6f} seconds") + print(torch.allclose(torch.from_numpy(np.array(silly)), eff, atol=1e-5)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 65443d3b..54c28d24 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -22,19 +22,7 @@ def setUpClass(cls): @classmethod def rope(cls, v: mx.array, shift_by: int = 0) -> mx.array: """Apply RoPE to the input tensor with an optional shift""" - # TODO(christian-lms): this is reeeeeeallllyyyy stupid. spin a proper block impl - if shift_by == 0: - return v - # Apply RoPE to each token individually with the same offset - shifted_tokens = [] - assert len(v.shape) == 4, "Expected input tensor to have 4 dimensions: [batch, heads, seq_len, head_dim]" - seq_len = v.shape[2] - - for i in range(seq_len): - shifted_token = cls._rope(idx(v, i), shift_by) - shifted_tokens.append(shifted_token) - - return mx.concatenate(shifted_tokens, axis=2) + return mx.concatenate([cls._rope(v[:,:,i:i+1,:], shift_by) for i in range(v.shape[2])], axis=2) @classmethod def make_random_kv(cls, seqlen: int): From 22d8a833ffc1b3276b61b29c139e2969a7fdf799 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 09:59:15 -0400 Subject: [PATCH 10/39] pipe in n_keep --- mlx_engine/cache.py | 5 ++--- mlx_engine/cache_wrapper.py | 19 ++++++++++++------- mlx_engine/generate.py | 4 ++++ mlx_engine/model_kit/model_kit.py | 2 ++ mlx_engine/utils/prompt_processing.py | 2 ++ 5 files changed, 22 insertions(+), 10 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index a7d3a3ea..f07e54bb 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -83,8 +83,6 @@ def _temporal_order(self, v) -> mx.array: return mx.concatenate( [ v[..., : self.keep, :], - # TODO(christian-lms): verify that i work - # TODO(christian-lms): can you do this in 1 call to self.rope? # N.B. this implicitly assumes the generation has not gone over twice # the size of the rotating section of the cache, in which case the # rotating section would be off by a multiple of (max_kv_size - keep) @@ -184,6 +182,7 @@ def trim(self, n) -> int: def make_prompt_cache( model: nn.Module, max_kv_size: Optional[int] = None, + keep: int = 4, ) -> List[Any]: """ Construct the model's cache for use in generation. @@ -210,7 +209,7 @@ def make_prompt_cache( f"None at layer {layer} of model {model}") return [KVCache() for _ in range(num_layers)] # TODO(christian-lms): change keep on the fly, must be setattr elsewhere - cache.append(ShiftingKVCache(rope, max_size=max_kv_size, keep=4)) + cache.append(ShiftingKVCache(rope, max_size=max_kv_size, keep=keep)) return cache else: return [KVCache() for _ in range(num_layers)] \ No newline at end of file diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 04b5b1cb..00d0486e 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -26,6 +26,7 @@ def __init__( kv_bits: Optional[int] = None, kv_group_size: Optional[int] = None, quantized_kv_start: Optional[int] = None, + keep: int = 4, ): """ Initialize the CacheWrapper. @@ -36,7 +37,8 @@ def __init__( """ # utilize a simple ordered list of tokens processed so far for cache invalidation checking self.tokens: Optional[mx.array] = None - self.cache: List[Any] = make_prompt_cache(model, max_kv_size) + self.keep = keep + self.cache: List[Any] = make_prompt_cache(model, max_kv_size, keep) self.model = model self.draft_model: Optional[nn.Module] = None self.max_kv_size = max_kv_size @@ -85,7 +87,7 @@ def _find_common_prefix( return common_length def _get_unprocessed_tokens( - self, prompt_tokens: mx.array, num_tokens_to_exclude: int + self, prompt_tokens: mx.array, num_tokens_to_exclude: int, keep: int = 4 ): """ Get the unprocessed tokens from the prompt. @@ -115,7 +117,7 @@ def _get_unprocessed_tokens( message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: " "Cache is not trimmable. Clearing the cache instead.", ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) + self.cache = make_prompt_cache(self.model, self.max_kv_size, keep=self.keep) self.tokens = prompt_tokens return self.tokens tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim) @@ -126,7 +128,7 @@ def _get_unprocessed_tokens( message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected " " ({num_tokens_to_trim}). Clearing the cache.", ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) + self.cache = make_prompt_cache(self.model, self.max_kv_size, keep=self.keep) self.tokens = prompt_tokens return self.tokens log_info( @@ -221,9 +223,9 @@ def set_draft_model(self, draft_model: nn.Module): message="Clearing current prompt cache and adding draft model to the cache", ) self.tokens = None - self.cache: List[Any] = make_prompt_cache(self.model) + self.cache: List[Any] = make_prompt_cache(self.model, keep=self.keep) if draft_model is not None: - self.cache += make_prompt_cache(draft_model) + self.cache += make_prompt_cache(draft_model, keep=self.keep) self.draft_model = draft_model def unset_draft_model(self): @@ -239,6 +241,7 @@ def update_cache( prompt_progress_callback, *, num_tokens_to_exclude: int = 1, + keep: int = 4 ) -> mx.array: """ Set up the KV cache for the next generation. @@ -248,15 +251,17 @@ def update_cache( prompt_tokens (mx.array): The prompt tokens. prompt_progress_callback (Callable): A callback function to report prompt processing progress. num_tokens_to_exclude (int): The number of tokens that should not be added to the cache. + keep (int): The number of tokens to always keep in the prefix of the prompt cache. Returns: mx.array: The prompt tokens to be used for the next generation. """ if prompt_progress_callback is None: - def prompt_progress_callback(x): return None + self.keep = keep + # TODO(christian-lms): truncation logic goes here now num_tokens_to_exclude = max(num_tokens_to_exclude, 1) diff --git a/mlx_engine/generate.py b/mlx_engine/generate.py index 6019cf16..54c581ad 100644 --- a/mlx_engine/generate.py +++ b/mlx_engine/generate.py @@ -137,6 +137,7 @@ def create_generator( max_tokens: Optional[int] = 10000000, speculative_decoding_toggle: Optional[bool] = None, num_draft_tokens: Optional[int] = None, + keep: Optional[int] = 4, ) -> Iterator[GenerationResult]: """ Create a generator that streams text generation results from the model. @@ -171,6 +172,8 @@ def create_generator( if a draft model is loaded. If set to true, draft model must be loaded or else error. If set to false, speculative decoding is disabled even if a draft model is loaded. num_draft_tokens (Optional[int]): Number of tokens to draft when using speculative decoding + keep (Optional[int]): Number of tokens to always keep in the prefix of the prompt cache. + Defaults to 4, which is the minimum number of tokens needed for a valid prompt. Yields: GenerationResult: A named tuple containing: @@ -218,6 +221,7 @@ def create_generator( prompt_progress_callback, generate_args, speculative_decoding_toggle, + keep=keep, ) if draft_model is None: # input embeddings not yet supported for speculative decoding in mlx-lm diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index a389dc6d..4f93437b 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -141,6 +141,7 @@ def process_prompt( prompt_progress_callback, generate_args, speculative_decoding_toggle: Optional[bool] = None, + keep: int = 4, ) -> Tuple[mx.array, Optional[mx.array]]: ### TEXT-ONLY PROCESS_PROMPT ### is_text_only_processing = images_b64 is None or len(images_b64) == 0 @@ -160,6 +161,7 @@ def process_prompt( self.draft_model, speculative_decoding_toggle, prompt_progress_callback, + keep=keep, ), None ### WITH IMAGES PROMPT PROCESSING ###s if self.vision_add_on is None: diff --git a/mlx_engine/utils/prompt_processing.py b/mlx_engine/utils/prompt_processing.py index 78a687d1..380cf243 100644 --- a/mlx_engine/utils/prompt_processing.py +++ b/mlx_engine/utils/prompt_processing.py @@ -13,6 +13,7 @@ def process_prompt_text_only( draft_model: Optional[nn.Module] = None, speculative_decoding_toggle: Optional[bool] = None, prompt_progress_callback: Optional[Callable[[float], None]] = None, + keep: int = 4, ): if cache_wrapper is None: raise ValueError("Cache wrapper is not initialized, cannot process prompt") @@ -38,6 +39,7 @@ def process_prompt_text_only( prompt_tokens = cache_wrapper.update_cache( prompt_tokens, prompt_progress_callback, + keep=keep, ) generate_args["prompt_cache"] = cache_wrapper.cache return prompt_tokens From d165867b8f1bfcfd6090344bd9b5bebba22eb617 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 10:14:27 -0400 Subject: [PATCH 11/39] a few more cache wrapper tests --- mlx_engine/cache_wrapper.py | 66 +++++++++++++++++------------- tests/test_cache_wrapper.py | 81 +++++++++++++++++++++++++++++++++---- 2 files changed, 112 insertions(+), 35 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 00d0486e..03a3dc65 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -50,41 +50,51 @@ def __init__( ) @staticmethod - def _find_common_prefix( - current_tokens: mx.array, prompt_tokens: mx.array, num_tokens_to_exclude: int + def _find_matching_sequence_length( + tokens1: mx.array, + tokens2: mx.array, + start1: int = 0, + start2: int = 0, + num_tokens_to_exclude: int = 0 ) -> int: """ - Determine the common prefix length between the current tokens and the prompt tokens. - + Find the length of matching token sequence between two token arrays. + Args: - current_tokens (mx.array): The cached tokens (self.tokens). - prompt_tokens (mx.array): The prompt tokens. - num_tokens_to_exclude (int): The minimum length of the remaining prompt tokens array. - + tokens1: First token array + start1: Starting position in first array + tokens2: Second token array + start2: Starting position in second array + num_tokens_to_exclude (int): The minimum length of the leftover non-matching + segment of tokens1, to be excluded from the match length. + Returns: - int: The length of the common prefix. + int: Length of matching sequence """ - prompt_tokens = prompt_tokens - current_tokens = current_tokens - # Find the minimum length between the two arrays - min_length = min(len(current_tokens), len(prompt_tokens)) - - # Compare elements up to the minimum length - mask = prompt_tokens[:min_length] == current_tokens[:min_length] - - # Find the index where the first mismatch occurs + # Calculate actual bounds + max_len1 = len(tokens1) - start1 + max_len2 = len(tokens2) - start2 + min_length = int(min(max_len1, max_len2)) + + # Extract subsequences to compare + seq1 = tokens1[start1 : start1 + min_length] + seq2 = tokens2[start2 : start2 + min_length] + + # Find first mismatch + mask = seq1 == seq2 if mx.any(mask == False): # noqa E712 - common_length = int(mx.argmax(mask == False)) # noqa E712 + match_length = int(mx.argmax(mask == False)) # noqa E712 else: - common_length = int(min_length) - - # Ensure that the prompt is at least num_tokens_to_exclude long - uncached_prompt_tokens_length = len(prompt_tokens[common_length:]) + match_length = min_length + + # Ensure that the leftover non-matching segment of tokens1 + # is at least num_tokens_to_exclude long + leftover_tokens1_length = len(tokens1[match_length:]) length_adjustment = max( - 0, num_tokens_to_exclude - uncached_prompt_tokens_length + 0, num_tokens_to_exclude - leftover_tokens1_length ) - common_length = max(common_length - length_adjustment, 0) - return common_length + match_length = max(match_length - length_adjustment, 0) + return match_length def _get_unprocessed_tokens( self, prompt_tokens: mx.array, num_tokens_to_exclude: int, keep: int = 4 @@ -104,8 +114,8 @@ def _get_unprocessed_tokens( return self.tokens # Find common KV between the last generation and the current prompt - common_prefix = self._find_common_prefix( - self.tokens, prompt_tokens, num_tokens_to_exclude + common_prefix = self._find_matching_sequence_length( + self.tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude ) # Trim the cache if the common prefix is shorter than the current cache diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index 9acad9ac..eabd8705 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -4,7 +4,7 @@ class TestCacheWrapper(unittest.TestCase): - def test_find_common_prefix_with_mismatch(self): + def test_find_matching_sequence_length_with_mismatch(self): """Test when there's a mismatch in the tokens""" # Create two arrays with a known common prefix [1, 2, 3] current_tokens = mx.array([1, 2, 3, 4, 5]) @@ -15,12 +15,12 @@ def test_find_common_prefix_with_mismatch(self): print(f"current_tokens: {current_tokens}") print(f"prompt_tokens: {prompt_tokens}") - result = CacheWrapper._find_common_prefix( - current_tokens, prompt_tokens, num_tokens_to_exclude + result = CacheWrapper._find_matching_sequence_length( + current_tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude ) self.assertEqual(result, 3) # Should find 3 matching tokens - def test_find_common_prefix_all_match(self): + def test_find_matching_sequence_length_all_match(self): """Test when all tokens match""" # Create two identical arrays current_tokens = mx.array([1, 2, 3, 4, 5]) @@ -31,13 +31,80 @@ def test_find_common_prefix_all_match(self): print(f"current_tokens: {current_tokens}") print(f"prompt_tokens: {prompt_tokens}") - result = CacheWrapper._find_common_prefix( - current_tokens, prompt_tokens, num_tokens_to_exclude + result = CacheWrapper._find_matching_sequence_length( + current_tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude ) self.assertEqual( result, 4 ) # Should find 4 matching tokens (5-1 due to num_tokens_to_exclude) - + + def test_find_matching_sequence_length_no_match(self): + """Test when no tokens match""" + # Create two arrays with no common prefix + current_tokens = mx.array([1, 2, 3, 4, 5]) + prompt_tokens = mx.array([6, 7, 8, 9, 10]) + num_tokens_to_exclude = 1 + + print("\nTest with no matching tokens:") + print(f"current_tokens: {current_tokens}") + print(f"prompt_tokens: {prompt_tokens}") + + result = CacheWrapper._find_matching_sequence_length( + current_tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude + ) + self.assertEqual(result, 0) # No matching tokens should return 0 + + def test_find_matching_sequence_length_offset_starts(self): + """Test when the current tokens start with a different offset""" + # Create two arrays where the current tokens start with a different offset + current_tokens = mx.array([2, 3, 4, 5]) + prompt_tokens = mx.array([1, 2, 3, 4, 5]) + num_tokens_to_exclude = 1 + + print("\nTest with offset starts:") + print(f"current_tokens: {current_tokens}") + print(f"prompt_tokens: {prompt_tokens}") + + result = CacheWrapper._find_matching_sequence_length( + current_tokens, + prompt_tokens, + start2=1, + num_tokens_to_exclude=num_tokens_to_exclude, + ) + self.assertEqual(result, 3) + + def test_find_matching_sequence_length_more_offsets(self): + """Test when the current tokens have more offsets""" + # Create two arrays where the current tokens have more offsets + current_tokens = mx.array([1, 2, 3, 4, 5, 6]) + prompt_tokens = mx.array([0, 9, 10, 3, 4, 7, 8]) + + print("\nTest with more offsets:") + print(f"current_tokens: {current_tokens}") + print(f"prompt_tokens: {prompt_tokens}") + + result = CacheWrapper._find_matching_sequence_length( + current_tokens, prompt_tokens + ) + self.assertEqual(result, 0) + + result = CacheWrapper._find_matching_sequence_length( + current_tokens, + prompt_tokens, + start1=2, + start2=3, + ) + self.assertEqual(result, 2) + + result = CacheWrapper._find_matching_sequence_length( + current_tokens, + prompt_tokens, + start1=2, + start2=3, + num_tokens_to_exclude=1, + ) + self.assertEqual(result, 2) # there are leftovers anyway + # TODO(christian-lms): write tests for cache shifting, which is high-level # implemented in cachewrapper and so belongs here From 199d2319bdaf1ee989a8cae8ba901e800bd324d8 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 10:14:46 -0400 Subject: [PATCH 12/39] ruff formatting lmao --- mlx_engine/cache.py | 82 +++++++++++++++++++------------------ mlx_engine/cache_wrapper.py | 48 +++++++++++----------- tests/test_cache_shift.py | 56 +++++++++++++------------ 3 files changed, 96 insertions(+), 90 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index f07e54bb..51dec7b2 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -1,10 +1,7 @@ from typing import List, Optional, Any -from mlx_engine.logging import log_info, log_warn, log_error -from mlx_lm.models.cache import ( - RotatingKVCache, - KVCache -) +from mlx_engine.logging import log_warn +from mlx_lm.models.cache import RotatingKVCache, KVCache import mlx.core as mx import mlx.nn as nn @@ -31,7 +28,7 @@ def maybe_get_rope(model: nn.Module, layer_idx: int) -> Optional[nn.Module]: """Attempt to find the RoPE module from a layer of an MLX-LM LLM. Args: - model (nn.Module): The LLM to search for the RoPE modules of. + model (nn.Module): The LLM to search for the RoPE modules of. layer_idx (int): The layer of the LLM to get the RoPE module from. Returns: @@ -55,22 +52,28 @@ def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): def rope(self, v: mx.array, shift_by: int) -> mx.array: # TODO(christian-lms): this is reeeeeeallllyyyy stupid. spin a proper block impl - return mx.concatenate([self._rope(v[:,:,i:i+1,:], shift_by) for i in range(v.shape[2])], axis=2) + return mx.concatenate( + [self._rope(v[:, :, i : i + 1, :], shift_by) for i in range(v.shape[2])], + axis=2, + ) def is_trimmable(self) -> bool: return True - + def _trim(self, trim_size, v, append=None): to_cat = [] shift_by = -trim_size if trim_size > 0: - to_cat = [v[..., : self.keep, :], self.rope(v[..., trim_size + self.keep :, :], shift_by)] + to_cat = [ + v[..., : self.keep, :], + self.rope(v[..., trim_size + self.keep :, :], shift_by), + ] else: to_cat = [v] if append is not None: to_cat.append(append) return mx.concatenate(to_cat, axis=2) - + def _temporal_order(self, v) -> mx.array: """ Rearrange the cache into temporal order, slicing off the end if unused. @@ -95,69 +98,67 @@ def _temporal_order(self, v) -> mx.array: ) else: return v[..., : self._idx, :] - - def reuse_section(self, write_start_idx: int, reuse_start_idx: int, reuse_length: int) -> None: + + def reuse_section( + self, write_start_idx: int, reuse_start_idx: int, reuse_length: int + ) -> None: # offset indices to account for the fact that we move cache elements around write_start_idx -= self.reuse_offset reuse_start_idx -= self.reuse_offset - + # update position offsets for future reuse sections shift_by = write_start_idx - reuse_start_idx self.reuse_offset += shift_by # queue for reuse: everything is done in one pass at the end in do_reuse self.reuse_queue.append((write_start_idx, reuse_start_idx, reuse_length)) - + def do_reuse(self) -> None: last_i: int = len(self.reuse_queue) - 1 - for i, (write_start_idx, reuse_start_idx, reuse_length) in enumerate(self.reuse_queue): + for i, (write_start_idx, reuse_start_idx, reuse_length) in enumerate( + self.reuse_queue + ): shift_by: int = write_start_idx - reuse_start_idx assert shift_by <= 0 reuse_end_idx: int = reuse_start_idx + reuse_length - keys_to_shift = self.keys[..., reuse_start_idx : reuse_end_idx, :] - values_to_shift = self.values[..., reuse_start_idx : reuse_end_idx, :] + keys_to_shift = self.keys[..., reuse_start_idx:reuse_end_idx, :] + values_to_shift = self.values[..., reuse_start_idx:reuse_end_idx, :] # perform rope shift # N.B. we can also go back to the MLX-native "don't rope shift" method # by removing RoPE here and removing the overrides for trim, temporal order shifted_keys = self.rope(keys_to_shift, shift_by) shifted_values = self.rope(values_to_shift, shift_by) - + # restructure cache with mx.concat # TODO(christian-lms): maybe it would be better to use inplace ops. # look into the mlx docs if that's even a thing - keycat = [ - self.keys[..., : write_start_idx, :], - shifted_keys - ] - valcat = [ - self.values[..., : write_start_idx, :], - shifted_values - ] - + keycat = [self.keys[..., :write_start_idx, :], shifted_keys] + valcat = [self.values[..., :write_start_idx, :], shifted_values] + # TODO(christian-lms): surely there is a better way to do this? # by not re-appending the end at the last one, we truncate the leftovers if i != last_i: - keycat.append(self.keys[..., reuse_end_idx : , :]) - valcat.append(self.values[..., reuse_end_idx : , :]) + keycat.append(self.keys[..., reuse_end_idx:, :]) + valcat.append(self.values[..., reuse_end_idx:, :]) self.keys = mx.concat(keycat, axis=2) self.values = mx.concat(valcat, axis=2) - + self.offset -= shift_by self.reuse_offset = 0 self.reuse_queue = [] # TODO(christian-lms): dunno if this number is correct/reasonable/whatever self._idx = self.keys.shape[2] - + def trim(self, n) -> int: # TODO(christian-lms): should trim respect keep? currently, no n = min(self.offset, n) if n <= 0: return 0 - + # TODO(christian-lms): so you used to need to wrap around because the code # didn't know how much it was trying to trim, so it would go over the maximum allowed. # but i think this was in large part due to improperly tracking the tokens that were @@ -171,14 +172,15 @@ def trim(self, n) -> int: # do trim: put us back into the state before the circular buffer is full new_length = self.keys.shape[2] - n - self.keys = self.keys[..., : new_length, :] - self.values = self.values[..., : new_length, :] - + self.keys = self.keys[..., :new_length, :] + self.values = self.values[..., :new_length, :] + self.offset -= n # TODO(christian-lms): verify that this is reasonable self._idx = new_length return n - + + def make_prompt_cache( model: nn.Module, max_kv_size: Optional[int] = None, @@ -205,11 +207,13 @@ def make_prompt_cache( # like llama4 which has no rope module for every fourth layer. # this will be figured out Later(tm) once the initial functionality works if rope is None: - log_warn("Attempted to build a KV cache of shiftable caches, but found" - f"None at layer {layer} of model {model}") + log_warn( + "Attempted to build a KV cache of shiftable caches, but found" + f"None at layer {layer} of model {model}" + ) return [KVCache() for _ in range(num_layers)] # TODO(christian-lms): change keep on the fly, must be setattr elsewhere cache.append(ShiftingKVCache(rope, max_size=max_kv_size, keep=keep)) return cache else: - return [KVCache() for _ in range(num_layers)] \ No newline at end of file + return [KVCache() for _ in range(num_layers)] diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 03a3dc65..77ca74c5 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -2,10 +2,7 @@ from mlx_engine.logging import log_info, log_warn, log_error from mlx_engine.cache import make_prompt_cache -from mlx_lm.models.cache import ( - trim_prompt_cache, - can_trim_prompt_cache -) +from mlx_lm.models.cache import trim_prompt_cache, can_trim_prompt_cache from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache import mlx.core as mx import mlx.nn as nn @@ -51,23 +48,23 @@ def __init__( @staticmethod def _find_matching_sequence_length( - tokens1: mx.array, - tokens2: mx.array, + tokens1: mx.array, + tokens2: mx.array, start1: int = 0, start2: int = 0, - num_tokens_to_exclude: int = 0 + num_tokens_to_exclude: int = 0, ) -> int: """ Find the length of matching token sequence between two token arrays. - + Args: tokens1: First token array start1: Starting position in first array - tokens2: Second token array + tokens2: Second token array start2: Starting position in second array num_tokens_to_exclude (int): The minimum length of the leftover non-matching segment of tokens1, to be excluded from the match length. - + Returns: int: Length of matching sequence """ @@ -75,24 +72,22 @@ def _find_matching_sequence_length( max_len1 = len(tokens1) - start1 max_len2 = len(tokens2) - start2 min_length = int(min(max_len1, max_len2)) - + # Extract subsequences to compare seq1 = tokens1[start1 : start1 + min_length] seq2 = tokens2[start2 : start2 + min_length] - + # Find first mismatch mask = seq1 == seq2 if mx.any(mask == False): # noqa E712 match_length = int(mx.argmax(mask == False)) # noqa E712 else: match_length = min_length - + # Ensure that the leftover non-matching segment of tokens1 # is at least num_tokens_to_exclude long leftover_tokens1_length = len(tokens1[match_length:]) - length_adjustment = max( - 0, num_tokens_to_exclude - leftover_tokens1_length - ) + length_adjustment = max(0, num_tokens_to_exclude - leftover_tokens1_length) match_length = max(match_length - length_adjustment, 0) return match_length @@ -127,7 +122,9 @@ def _get_unprocessed_tokens( message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: " "Cache is not trimmable. Clearing the cache instead.", ) - self.cache = make_prompt_cache(self.model, self.max_kv_size, keep=self.keep) + self.cache = make_prompt_cache( + self.model, self.max_kv_size, keep=self.keep + ) self.tokens = prompt_tokens return self.tokens tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim) @@ -138,7 +135,9 @@ def _get_unprocessed_tokens( message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected " " ({num_tokens_to_trim}). Clearing the cache.", ) - self.cache = make_prompt_cache(self.model, self.max_kv_size, keep=self.keep) + self.cache = make_prompt_cache( + self.model, self.max_kv_size, keep=self.keep + ) self.tokens = prompt_tokens return self.tokens log_info( @@ -251,7 +250,7 @@ def update_cache( prompt_progress_callback, *, num_tokens_to_exclude: int = 1, - keep: int = 4 + keep: int = 4, ) -> mx.array: """ Set up the KV cache for the next generation. @@ -267,11 +266,12 @@ def update_cache( mx.array: The prompt tokens to be used for the next generation. """ if prompt_progress_callback is None: + def prompt_progress_callback(x): return None - + self.keep = keep - + # TODO(christian-lms): truncation logic goes here now num_tokens_to_exclude = max(num_tokens_to_exclude, 1) @@ -313,8 +313,8 @@ def prompt_progress_callback(x): def record_generated_token(self, token): """ Add the generated token to the token list, so that we can map the token to the KV cache. - - Also loop when the cache does so that we accurately track what's in cache. + + Also loop when the cache does so that we accurately track what's in cache. """ # TODO(christian-lms): ensure that this works as intended when over length # TODO(christian-lms): verify rolling window and truncate middle have n_keep as below @@ -323,5 +323,5 @@ def record_generated_token(self, token): # this behavior is common to rolling window (n_keep = 0) and truncate middle # (n_keep > 0), and we should never get here with stop at max if len(self.tokens) >= n_keep: - self.tokens = mx.concat([self.tokens[:n_keep], self.tokens[n_keep+1:]]) + self.tokens = mx.concat([self.tokens[:n_keep], self.tokens[n_keep + 1 :]]) self.tokens = mx.concat([self.tokens, mx.array([token])]) diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 54c28d24..7ceedcc9 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -6,7 +6,7 @@ def idx(v: mx.array, i: int): """Helper function to index into a 4D tensor at the sequence length dimension""" - return v[:, :, i:i+1, :] + return v[:, :, i : i + 1, :] class ShiftingCacheTest(unittest.TestCase): @@ -16,18 +16,19 @@ def setUpClass(cls): cls.kv_head_dim = 4 cls.bsz = 1 cls.n_kv_heads = 1 - # TODO: this won't work.............. nn.RoPE decides that it will increase the offset for each position - cls._rope = nn.RoPE(dims=cls.kv_head_dim, traditional=False, base=100000, scale=1.0) + # cannot be used raw: must be wrapped in the cache.rope workaround impl + cls._rope = nn.RoPE( + dims=cls.kv_head_dim, traditional=False, base=100000, scale=1.0 + ) - @classmethod - def rope(cls, v: mx.array, shift_by: int = 0) -> mx.array: - """Apply RoPE to the input tensor with an optional shift""" - return mx.concatenate([cls._rope(v[:,:,i:i+1,:], shift_by) for i in range(v.shape[2])], axis=2) - @classmethod def make_random_kv(cls, seqlen: int): """Helper method to make a random key/value tensor of the right shape""" - return mx.random.normal((cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim), scale=1.0, dtype=mx.float32) + return mx.random.normal( + (cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim), + scale=1.0, + dtype=mx.float32, + ) def assertArrEqual(self, a: mx.array, b: mx.array): """Assert that two tensors are equal over the sequence length dimension""" @@ -36,7 +37,7 @@ def assertArrEqual(self, a: mx.array, b: mx.array): # TODO: you can test to make sure that it's RoPEing right in the model overall by getting # the post-shift value, then shifting it back to position 0 and checking the layer 0 kv - # matches the raw token embedding + # matches the raw token embedding def test_overwriting(self): cache = ShiftingKVCache(self._rope, max_size=3, keep=1) @@ -48,12 +49,12 @@ def test_overwriting(self): self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), overwrite) self.assertArrEqual(idx(keys, 2), idx(base_kv, 2)) - + def test_temporal_order_shift(self): cache = ShiftingKVCache(self._rope, max_size=3, keep=1) base_kv = self.make_random_kv(3) overwrite = self.make_random_kv(1) - overwrite_roped = self.rope(overwrite, -1) + overwrite_roped = cache.rope(overwrite, -1) cache.update_and_fetch(base_kv, base_kv) cache.update_and_fetch(overwrite, overwrite) print(base_kv) @@ -61,19 +62,19 @@ def test_temporal_order_shift(self): keys = cache._temporal_order(cache.keys) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) - self.assertArrEqual(idx(keys, 1), self.rope(idx(base_kv, 2), -1)) + self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) self.assertArrEqual(idx(keys, 2), overwrite_roped) - + def test_trim_internal_shift(self): cache = ShiftingKVCache(self._rope, max_size=3, keep=1) base_kv = self.make_random_kv(3) cache.update_and_fetch(base_kv, base_kv) - + keys = cache._trim(1, cache.keys) - + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) - self.assertArrEqual(idx(keys, 1), self.rope(idx(base_kv, 2), -1)) + self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) def test_ensure_reasonable_size_and_shift(self): cache = ShiftingKVCache(self._rope, max_size=3, keep=1) @@ -86,13 +87,13 @@ def test_ensure_reasonable_size_and_shift(self): self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), overwrite) - self.assertArrEqual(idx(keys, 2), self.rope(idx(base_kv, 9), -7)) + self.assertArrEqual(idx(keys, 2), cache.rope(idx(base_kv, 9), -7)) def test_trim_before_full(self): cache = ShiftingKVCache(self._rope, max_size=3, keep=1) base_kv = self.make_random_kv(2) cache.update_and_fetch(base_kv, base_kv) - + cache.trim(1) keys = cache.keys @@ -104,7 +105,7 @@ def test_trim_before_full(self): self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), new_kv) - + def test_trim_after_full(self): cache = ShiftingKVCache(self._rope, max_size=3, keep=1) base_kv = self.make_random_kv(4) @@ -113,23 +114,24 @@ def test_trim_after_full(self): cache.trim(2) keys = cache.keys self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) - self.assertArrEqual(keys, base_kv[:,:,:2,:]) + self.assertArrEqual(keys, base_kv[:, :, :2, :]) new_kv = self.make_random_kv(2) keys, _ = cache.update_and_fetch(new_kv, new_kv) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 4, self.kv_head_dim)) - self.assertArrEqual(keys[:,:,:2,:], base_kv[:,:,:2,:]) - self.assertArrEqual(keys[:,:,2:,:], new_kv) - + self.assertArrEqual(keys[:, :, :2, :], base_kv[:, :, :2, :]) + self.assertArrEqual(keys[:, :, 2:, :], new_kv) + def test_reuse(self): cache = ShiftingKVCache(self._rope, max_size=6, keep=1) base_kv = self.make_random_kv(8) - original_prompt_cache = base_kv[:, :, :6, :] cache.update_and_fetch(base_kv, base_kv) - new_prompt_cache = mx.concatenate([base_kv[:,:,:3,:], self.rope(base_kv[:,:,4:,:], -1)], axis=2) + new_prompt_cache = mx.concatenate( + [base_kv[:, :, :3, :], cache.rope(base_kv[:, :, 4:, :], -1)], axis=2 + ) # here we know what to reuse so hardcode it, dynamic reuse is in test_cache_wrapper cache.reuse_section(3, 4, 2) cache.do_reuse() keys = cache.keys self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) - self.assertArrEqual(keys, new_prompt_cache[:,:,:5,:]) + self.assertArrEqual(keys, new_prompt_cache[:, :, :5, :]) From 8dcb82d69a850e59558c6a606630630e1ae845e4 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 10:23:37 -0400 Subject: [PATCH 13/39] record_generated_token test and fix --- mlx_engine/cache_wrapper.py | 14 ++---- tests/test_cache_wrapper.py | 25 ++++++++++ tests/utils.py | 95 +++++++++++++++++++------------------ 3 files changed, 80 insertions(+), 54 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 77ca74c5..09d96529 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -92,7 +92,7 @@ def _find_matching_sequence_length( return match_length def _get_unprocessed_tokens( - self, prompt_tokens: mx.array, num_tokens_to_exclude: int, keep: int = 4 + self, prompt_tokens: mx.array, num_tokens_to_exclude: int ): """ Get the unprocessed tokens from the prompt. @@ -113,6 +113,8 @@ def _get_unprocessed_tokens( self.tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude ) + # TODO reuse logic goes here + # Trim the cache if the common prefix is shorter than the current cache num_tokens_to_trim = self.cache[0].offset - common_prefix if num_tokens_to_trim > 0: @@ -272,8 +274,6 @@ def prompt_progress_callback(x): self.keep = keep - # TODO(christian-lms): truncation logic goes here now - num_tokens_to_exclude = max(num_tokens_to_exclude, 1) prompt_tokens = self._get_unprocessed_tokens( prompt_tokens, num_tokens_to_exclude @@ -316,12 +316,8 @@ def record_generated_token(self, token): Also loop when the cache does so that we accurately track what's in cache. """ - # TODO(christian-lms): ensure that this works as intended when over length - # TODO(christian-lms): verify rolling window and truncate middle have n_keep as below - # TODO(christian-lms): this won't work until we pipe in keep from generate - n_keep = self.cache[0].keep # this behavior is common to rolling window (n_keep = 0) and truncate middle # (n_keep > 0), and we should never get here with stop at max - if len(self.tokens) >= n_keep: - self.tokens = mx.concat([self.tokens[:n_keep], self.tokens[n_keep + 1 :]]) + if len(self.tokens) >= self.max_kv_size: + self.tokens = mx.concat([self.tokens[:self.keep], self.tokens[self.keep + 1 :]]) self.tokens = mx.concat([self.tokens, mx.array([token])]) diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index eabd8705..3807b0ba 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -1,6 +1,7 @@ import unittest import mlx.core as mx from mlx_engine.cache_wrapper import CacheWrapper +from tests.utils import DummyModel class TestCacheWrapper(unittest.TestCase): @@ -105,9 +106,33 @@ def test_find_matching_sequence_length_more_offsets(self): ) self.assertEqual(result, 2) # there are leftovers anyway + def test_record_generated_token_loops(self): + cache = CacheWrapper( + model=DummyModel(), + max_kv_size=5, + keep=2, + ) + cache.tokens = mx.array([]) + cache.record_generated_token(1) + cache.record_generated_token(2) + cache.record_generated_token(3) + cache.record_generated_token(4) + cache.record_generated_token(5) + self.assertListEqual( + cache.tokens.tolist(), + [1, 2, 3, 4, 5], + ) + cache.record_generated_token(6) + self.assertListEqual( + cache.tokens.tolist(), + [1, 2, 4, 5, 6], + ) + # TODO(christian-lms): write tests for cache shifting, which is high-level # implemented in cachewrapper and so belongs here + # TODO(christian-lms): write tests for record_generated_token looping + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/utils.py b/tests/utils.py index 6d1c95dd..dd31698a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,62 +2,67 @@ import sys import subprocess -from mlx_engine.generate import load_model, load_draft_model, tokenize +# from mlx_engine.generate import load_model, load_draft_model, tokenize -def model_getter(model_name: str): - """Helper method to get a model, prompt user to download if not found""" +class DummyModel: + """Dummy model class for testing""" + layers = [0] - with open(Path("~/.lmstudio-home-pointer").expanduser().resolve(), "r") as f: - lmstudio_home = Path(f.read().strip()) - model_path = lmstudio_home / "models" / model_name - # Check if model exists, if not prompt user to download - if not model_path.exists(): - print(f"\nModel {model_name} not found at {model_path}") +# def model_getter(model_name: str): +# """Helper method to get a model, prompt user to download if not found""" - def greenify(text): - return f"\033[92m{text}\033[0m" +# with open(Path("~/.lmstudio-home-pointer").expanduser().resolve(), "r") as f: +# lmstudio_home = Path(f.read().strip()) +# model_path = lmstudio_home / "models" / model_name - response = input( - f"Would you like to download the model {greenify(model_name)}? (y/N): " - ) - if response.lower() == "y": - print(f"Downloading model with command: lms get {model_name}") - subprocess.run(["lms", "get", model_name], check=True) - else: - print(f"Model {model_name} not found") - sys.exit(1) +# # Check if model exists, if not prompt user to download +# if not model_path.exists(): +# print(f"\nModel {model_name} not found at {model_path}") - return model_path +# def greenify(text): +# return f"\033[92m{text}\033[0m" +# response = input( +# f"Would you like to download the model {greenify(model_name)}? (y/N): " +# ) +# if response.lower() == "y": +# print(f"Downloading model with command: lms get {model_name}") +# subprocess.run(["lms", "get", model_name], check=True) +# else: +# print(f"Model {model_name} not found") +# sys.exit(1) -def model_load_and_tokenize_prompt( - model_name: str, - prompt: str, - max_kv_size=4096, - trust_remote_code=False, - draft_model_name=None, -): - """Helper method to test a model""" - print(f"Testing model {model_name}") +# return model_path - # Check if model exists, if not prompt user to download - model_path = model_getter(model_name) - # Load the model - model_kit = load_model( - model_path=model_path, - max_kv_size=max_kv_size, - trust_remote_code=trust_remote_code, - ) +# def model_load_and_tokenize_prompt( +# model_name: str, +# prompt: str, +# max_kv_size=4096, +# trust_remote_code=False, +# draft_model_name=None, +# ): +# """Helper method to test a model""" +# print(f"Testing model {model_name}") - # Load the draft model if any - if draft_model_name is not None: - draft_model_path = model_getter(draft_model_name) - load_draft_model(model_kit, draft_model_path) +# # Check if model exists, if not prompt user to download +# model_path = model_getter(model_name) - # Tokenize the prompt - prompt_tokens = tokenize(model_kit, prompt) +# # Load the model +# model_kit = load_model( +# model_path=model_path, +# max_kv_size=max_kv_size, +# trust_remote_code=trust_remote_code, +# ) - return model_kit, prompt_tokens +# # Load the draft model if any +# if draft_model_name is not None: +# draft_model_path = model_getter(draft_model_name) +# load_draft_model(model_kit, draft_model_path) + +# # Tokenize the prompt +# prompt_tokens = tokenize(model_kit, prompt) + +# return model_kit, prompt_tokens From 56d34e01f5c9ebbd081c5ee9441fa01288539306 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 10:36:05 -0400 Subject: [PATCH 14/39] prelim reuse code --- mlx_engine/cache_wrapper.py | 65 +++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 09d96529..3f4aba28 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -90,6 +90,59 @@ def _find_matching_sequence_length( length_adjustment = max(0, num_tokens_to_exclude - leftover_tokens1_length) match_length = max(match_length - length_adjustment, 0) return match_length + + def _truncate_cache( + self, + prompt_tokens: mx.array, + common_prefix_len: int, + non_prefix_reuse_min_seq_len: int = 256 + ) -> tuple[mx.array, int]: + cache_size = len(self.tokens) + prompt_size = len(prompt_tokens) + + # start scanning from after the common prefix + cache_head_idx = common_prefix_len + prompt_head_idx = common_prefix_len + total_reused = 0 + + if self.verbose: + print(f"Looking for non-prefix sequences of length >= {non_prefix_reuse_min_seq_len}", file=sys.stderr) + + while cache_head_idx < cache_size and prompt_head_idx < prompt_size: + match_length = self._find_matching_sequence_length( + self.tokens, cache_head_idx, + prompt_tokens, prompt_head_idx + ) + + if match_length < non_prefix_reuse_min_seq_len: + # sequence too short - advance cache pointer to find next potential match + cache_head_idx += 1 + else: + if self.verbose: + print(f"Reusing {match_length} tokens from cache", file=sys.stderr) + + # found reusable sequence - shift cache content + self.cache.reuse_section( + source_pos=cache_head_idx, + target_pos=prompt_head_idx, + length=match_length + ) + + # update the tokens to reflect the reused sequence + for i in range(match_length): + self.tokens[prompt_head_idx + i] = self.tokens[cache_head_idx + i] + + # advance pointers + cache_head_idx += match_length + prompt_head_idx += match_length + total_reused += match_length + + self.cache.do_reuse() + # TODO(christian-lms): ensure that this works + self.tokens = self.tokens[:common_prefix_len + total_reused] + prompt_tokens = prompt_tokens[total_reused:] + + return prompt_tokens, total_reused def _get_unprocessed_tokens( self, prompt_tokens: mx.array, num_tokens_to_exclude: int @@ -113,6 +166,18 @@ def _get_unprocessed_tokens( self.tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude ) + if hasattr(self.cache, "reuse_section"): + prompt_tokens, n_reused_tokens = self._truncate_cache( + prompt_tokens, + common_prefix, + ) + if n_reused_tokens > 0: + log_info( + prefix="CacheWrapper", + message=f"Reused {n_reused_tokens} tokens from the cache" + ) + common_prefix += n_reused_tokens + # TODO reuse logic goes here # Trim the cache if the common prefix is shorter than the current cache From 374150947d987a966eda1a987b264debdf945dcf Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 10:58:36 -0400 Subject: [PATCH 15/39] maybe reuse unit test --- mlx_engine/cache_wrapper.py | 13 ++++---- tests/test_cache_wrapper.py | 65 ++++++++++++++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 3f4aba28..0d262041 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -96,7 +96,7 @@ def _truncate_cache( prompt_tokens: mx.array, common_prefix_len: int, non_prefix_reuse_min_seq_len: int = 256 - ) -> tuple[mx.array, int]: + ) -> int: cache_size = len(self.tokens) prompt_size = len(prompt_tokens) @@ -140,9 +140,8 @@ def _truncate_cache( self.cache.do_reuse() # TODO(christian-lms): ensure that this works self.tokens = self.tokens[:common_prefix_len + total_reused] - prompt_tokens = prompt_tokens[total_reused:] - return prompt_tokens, total_reused + return total_reused def _get_unprocessed_tokens( self, prompt_tokens: mx.array, num_tokens_to_exclude: int @@ -163,11 +162,13 @@ def _get_unprocessed_tokens( # Find common KV between the last generation and the current prompt common_prefix = self._find_matching_sequence_length( + # TODO(christian-lms): BLOCKING: num_tokens_to_exclude must be moved after reuse + # this means you should move it out of _find_matching_sequence_length self.tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude ) - + if hasattr(self.cache, "reuse_section"): - prompt_tokens, n_reused_tokens = self._truncate_cache( + n_reused_tokens = self._truncate_cache( prompt_tokens, common_prefix, ) @@ -178,8 +179,6 @@ def _get_unprocessed_tokens( ) common_prefix += n_reused_tokens - # TODO reuse logic goes here - # Trim the cache if the common prefix is shorter than the current cache num_tokens_to_trim = self.cache[0].offset - common_prefix if num_tokens_to_trim > 0: diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index 3807b0ba..8165d70f 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -1,10 +1,37 @@ import unittest import mlx.core as mx +import mlx.nn as nn from mlx_engine.cache_wrapper import CacheWrapper +from mlx_engine.cache import ShiftingKVCache from tests.utils import DummyModel class TestCacheWrapper(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Set up test resources that will be shared across all test methods""" + cls.kv_head_dim = 4 + cls.bsz = 1 + cls.n_kv_heads = 1 + # cannot be used raw: must be wrapped in the cache.rope workaround impl + cls._rope = nn.RoPE( + dims=cls.kv_head_dim, traditional=False, base=100000, scale=1.0 + ) + + @classmethod + def make_random_kv(cls, seqlen: int): + """Helper method to make a random key/value tensor of the right shape""" + return mx.random.normal( + (cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim), + scale=1.0, + dtype=mx.float32, + ) + + def assertArrEqual(self, a: mx.array, b: mx.array): + """Assert that two tensors are equal over the sequence length dimension""" + self.assertEqual(a.shape, b.shape) + self.assertTrue(mx.allclose(a, b), "Tensors are not equal") + def test_find_matching_sequence_length_with_mismatch(self): """Test when there's a mismatch in the tokens""" # Create two arrays with a known common prefix [1, 2, 3] @@ -104,7 +131,7 @@ def test_find_matching_sequence_length_more_offsets(self): start2=3, num_tokens_to_exclude=1, ) - self.assertEqual(result, 2) # there are leftovers anyway + self.assertEqual(result, 2) # there are leftovers anyway def test_record_generated_token_loops(self): cache = CacheWrapper( @@ -128,10 +155,40 @@ def test_record_generated_token_loops(self): [1, 2, 4, 5, 6], ) - # TODO(christian-lms): write tests for cache shifting, which is high-level - # implemented in cachewrapper and so belongs here + def test_cache_reuse(self): + cache = CacheWrapper(DummyModel(), 10) + cache.cache = ShiftingKVCache(self._rope, max_size=10, keep=2) + + # set up pretend cache + cached_tokens = mx.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + cache_kv = self.make_random_kv(10) + cache.tokens = cached_tokens + cache.cache.update_and_fetch(cache_kv, cache_kv) + + prompt_tokens = mx.array([1, 2, 4, 7, 8, 9, 11]) + prefix_len = cache._find_matching_sequence_length( + cached_tokens, prompt_tokens, 0 + ) + self.assertEqual(prefix_len, 2) + + total_reused = cache._truncate_cache( + prompt_tokens=prompt_tokens, + common_prefix_len=prefix_len, + non_prefix_reuse_min_seq_len=1, + ) + + should_be_tokens = mx.array([1, 2, 4, 7, 8, 9]) + + def idx(v, a, b): + return v[:, :, a:b, :] + + should_be_kv = mx.concatenate( + [idx(cache_kv, 0, 2), idx(cache_kv, 3, 4), idx(cache_kv, 6, 9)] + ) - # TODO(christian-lms): write tests for record_generated_token looping + self.assertEqual(total_reused, 4) + self.assertArrEqual(cache.tokens, should_be_tokens) + self.assertArrEqual(cache.cache.keys, should_be_kv) if __name__ == "__main__": From a677f861f214f55f26065fdf5c4e3fa1dcc839e2 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 11:00:15 -0400 Subject: [PATCH 16/39] code reuse! --- tests/test_cache_generic.py | 30 ++++++++++++++++++++++++++++++ tests/test_cache_shift.py | 37 ++++++------------------------------- tests/test_cache_wrapper.py | 29 ++--------------------------- 3 files changed, 38 insertions(+), 58 deletions(-) create mode 100644 tests/test_cache_generic.py diff --git a/tests/test_cache_generic.py b/tests/test_cache_generic.py new file mode 100644 index 00000000..0d0d70b0 --- /dev/null +++ b/tests/test_cache_generic.py @@ -0,0 +1,30 @@ +import unittest +import mlx.core as mx +import mlx.nn as nn + + +class TestCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Set up test resources that will be shared across all test methods""" + cls.kv_head_dim = 4 + cls.bsz = 1 + cls.n_kv_heads = 1 + # cannot be used raw: must be wrapped in the cache.rope workaround impl + cls._rope = nn.RoPE( + dims=cls.kv_head_dim, traditional=False, base=100000, scale=1.0 + ) + + @classmethod + def make_random_kv(cls, seqlen: int): + """Helper method to make a random key/value tensor of the right shape""" + return mx.random.normal( + (cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim), + scale=1.0, + dtype=mx.float32, + ) + + def assertArrEqual(self, a: mx.array, b: mx.array): + """Assert that two tensors are equal over the sequence length dimension""" + self.assertEqual(a.shape, b.shape) + self.assertTrue(mx.allclose(a, b), "Tensors are not equal") \ No newline at end of file diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 7ceedcc9..bc7442c7 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -1,7 +1,7 @@ import unittest import mlx.core as mx -import mlx.nn as nn from mlx_engine.cache import ShiftingKVCache +from tests.test_cache_generic import TestCache def idx(v: mx.array, i: int): @@ -9,36 +9,7 @@ def idx(v: mx.array, i: int): return v[:, :, i : i + 1, :] -class ShiftingCacheTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - """Set up test resources that will be shared across all test methods""" - cls.kv_head_dim = 4 - cls.bsz = 1 - cls.n_kv_heads = 1 - # cannot be used raw: must be wrapped in the cache.rope workaround impl - cls._rope = nn.RoPE( - dims=cls.kv_head_dim, traditional=False, base=100000, scale=1.0 - ) - - @classmethod - def make_random_kv(cls, seqlen: int): - """Helper method to make a random key/value tensor of the right shape""" - return mx.random.normal( - (cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim), - scale=1.0, - dtype=mx.float32, - ) - - def assertArrEqual(self, a: mx.array, b: mx.array): - """Assert that two tensors are equal over the sequence length dimension""" - self.assertEqual(a.shape, b.shape) - self.assertTrue(mx.allclose(a, b), "Tensors are not equal") - - # TODO: you can test to make sure that it's RoPEing right in the model overall by getting - # the post-shift value, then shifting it back to position 0 and checking the layer 0 kv - # matches the raw token embedding - +class TestShiftingKVCache(TestCache): def test_overwriting(self): cache = ShiftingKVCache(self._rope, max_size=3, keep=1) base_kv = self.make_random_kv(3) @@ -135,3 +106,7 @@ def test_reuse(self): keys = cache.keys self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) self.assertArrEqual(keys, new_prompt_cache[:, :, :5, :]) + + +if __name__ == "__main__": + unittest.main(verbosity=2) \ No newline at end of file diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index 8165d70f..1697e037 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -1,37 +1,12 @@ import unittest import mlx.core as mx -import mlx.nn as nn from mlx_engine.cache_wrapper import CacheWrapper from mlx_engine.cache import ShiftingKVCache +from tests.test_cache_generic import TestCache from tests.utils import DummyModel -class TestCacheWrapper(unittest.TestCase): - @classmethod - def setUpClass(cls): - """Set up test resources that will be shared across all test methods""" - cls.kv_head_dim = 4 - cls.bsz = 1 - cls.n_kv_heads = 1 - # cannot be used raw: must be wrapped in the cache.rope workaround impl - cls._rope = nn.RoPE( - dims=cls.kv_head_dim, traditional=False, base=100000, scale=1.0 - ) - - @classmethod - def make_random_kv(cls, seqlen: int): - """Helper method to make a random key/value tensor of the right shape""" - return mx.random.normal( - (cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim), - scale=1.0, - dtype=mx.float32, - ) - - def assertArrEqual(self, a: mx.array, b: mx.array): - """Assert that two tensors are equal over the sequence length dimension""" - self.assertEqual(a.shape, b.shape) - self.assertTrue(mx.allclose(a, b), "Tensors are not equal") - +class TestCacheWrapper(TestCache): def test_find_matching_sequence_length_with_mismatch(self): """Test when there's a mismatch in the tokens""" # Create two arrays with a known common prefix [1, 2, 3] From 77e523ce462143f73ecc6745529501750d067cc4 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 11:16:17 -0400 Subject: [PATCH 17/39] cache shift test comments --- mlx_engine/cache.py | 1 + tests/test_cache_shift.py | 70 ++++++++++++++++++++++++++++++------- tests/test_cache_wrapper.py | 6 +++- 3 files changed, 64 insertions(+), 13 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 51dec7b2..9e082e4d 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -50,6 +50,7 @@ def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): self.reuse_queue = [] super().__init__(max_size, keep, step) + # TODO(christian-lms): stop rope shifting your values!!!!!!!!!!!!! def rope(self, v: mx.array, shift_by: int) -> mx.array: # TODO(christian-lms): this is reeeeeeallllyyyy stupid. spin a proper block impl return mx.concatenate( diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index bc7442c7..476aac52 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -11,10 +11,15 @@ def idx(v: mx.array, i: int): class TestShiftingKVCache(TestCache): def test_overwriting(self): + """Test overwriting when the cache reaches max_size""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + + # fill cache -> 123 base_kv = self.make_random_kv(3) - overwrite = self.make_random_kv(1) cache.update_and_fetch(base_kv, base_kv) + + # attempt to write another element 4 -> 143 + overwrite = self.make_random_kv(1) keys, _ = cache.update_and_fetch(overwrite, overwrite) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) @@ -22,90 +27,131 @@ def test_overwriting(self): self.assertArrEqual(idx(keys, 2), idx(base_kv, 2)) def test_temporal_order_shift(self): + """Test the RoPE shift in _temporal_order""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + + # fill cache -> 123 base_kv = self.make_random_kv(3) - overwrite = self.make_random_kv(1) - overwrite_roped = cache.rope(overwrite, -1) cache.update_and_fetch(base_kv, base_kv) + + # attempt to write another element 4 -> 143 + overwrite = self.make_random_kv(1) cache.update_and_fetch(overwrite, overwrite) - print(base_kv) - print(cache.keys) + + # put the cache in temporal order -> 134 -> 123 (rope shift) keys = cache._temporal_order(cache.keys) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) - self.assertArrEqual(idx(keys, 2), overwrite_roped) + self.assertArrEqual(idx(keys, 2), cache.rope(overwrite, -1)) def test_trim_internal_shift(self): + """Test the RoPE shift in _trim (internal method)""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + + # fill cache -> 123 base_kv = self.make_random_kv(3) cache.update_and_fetch(base_kv, base_kv) + # trim 1 from middle -> 13 keys = cache._trim(1, cache.keys) + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) def test_ensure_reasonable_size_and_shift(self): + """Test behavior when the cache gets a KV batch-written that is much larger + than max_size. The default behavior of the cache is to write the entire thing, + then trim it back down when the next KV is written. + """ cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + + # fill cache -> 0123456789 base_kv = self.make_random_kv(10) keys, _ = cache.update_and_fetch(base_kv, base_kv) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 10, self.kv_head_dim)) + + # trigger trim -> 0X9 -> (rope) 021 overwrite = self.make_random_kv(1) keys, _ = cache.update_and_fetch(overwrite, overwrite) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 3, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + # TODO(christian-lms): this should also be rope unshifted because it's coming in + # w/ pos emb @ position X and then being sent to 2. figure out where this goes self.assertArrEqual(idx(keys, 1), overwrite) + # TODO(christian-lms): is this position 2 or 1? it should be 1 self.assertArrEqual(idx(keys, 2), cache.rope(idx(base_kv, 9), -7)) def test_trim_before_full(self): + """Test trimming from the end before the cache is full""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + + # fill cache -> 12 base_kv = self.make_random_kv(2) cache.update_and_fetch(base_kv, base_kv) + # trim 1 from end -> 1 cache.trim(1) keys = cache.keys self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 1, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + # ensure adding another value works fine new_kv = self.make_random_kv(1) keys, _ = cache.update_and_fetch(new_kv, new_kv) + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), new_kv) + # TODO(christian-lms): this doesn't actually test the overwriting, for that you + # need to fill it to 3 first then add 1 then try trim def test_trim_after_full(self): + """Test trimming from the end when the cache is oversize""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + + # fill cache oversize already -> 1234 base_kv = self.make_random_kv(4) cache.update_and_fetch(base_kv, base_kv) + # trim 2 from end -> 12 cache.trim(2) keys = cache.keys self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(keys, base_kv[:, :, :2, :]) + # ensure adding more values works fines new_kv = self.make_random_kv(2) keys, _ = cache.update_and_fetch(new_kv, new_kv) + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 4, self.kv_head_dim)) self.assertArrEqual(keys[:, :, :2, :], base_kv[:, :, :2, :]) self.assertArrEqual(keys[:, :, 2:, :], new_kv) def test_reuse(self): - cache = ShiftingKVCache(self._rope, max_size=6, keep=1) + """Test basic reuse APIs""" + cache = ShiftingKVCache(self._rope, max_size=8, keep=1) + + # fill cache -> 12345678 base_kv = self.make_random_kv(8) cache.update_and_fetch(base_kv, base_kv) - new_prompt_cache = mx.concatenate( - [base_kv[:, :, :3, :], cache.rope(base_kv[:, :, 4:, :], -1)], axis=2 - ) - # here we know what to reuse so hardcode it, dynamic reuse is in test_cache_wrapper + + # reuse a specific section (hardcoded), dynamic reuse is in test_cache_wrapper cache.reuse_section(3, 4, 2) cache.do_reuse() keys = cache.keys + + # this is what the remaining cache should look like + should_be_keys = mx.concatenate( + [base_kv[:, :, :3, :], cache.rope(base_kv[:, :, 4:6, :], -1)], axis=2 + ) + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) - self.assertArrEqual(keys, new_prompt_cache[:, :, :5, :]) + self.assertArrEqual(keys, should_be_keys) if __name__ == "__main__": diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index 1697e037..99180a00 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -158,7 +158,11 @@ def idx(v, a, b): return v[:, :, a:b, :] should_be_kv = mx.concatenate( - [idx(cache_kv, 0, 2), idx(cache_kv, 3, 4), idx(cache_kv, 6, 9)] + [ + idx(cache_kv, 0, 2), + cache.cache.rope(idx(cache_kv, 3, 4), -1), + cache.cache.rope(idx(cache_kv, 6, 9), -3), + ] ) self.assertEqual(total_reused, 4) From 2a8855aceea986c1818072263d1ecd0de132a0cf Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 11:25:32 -0400 Subject: [PATCH 18/39] stop rope shifting values and set keep --- mlx_engine/cache.py | 89 +++++++++++++++++++++++++++++++------ mlx_engine/cache_wrapper.py | 2 + 2 files changed, 78 insertions(+), 13 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 9e082e4d..0df03dca 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -43,6 +43,7 @@ def maybe_get_rope(model: nn.Module, layer_idx: int) -> Optional[nn.Module]: return _maybe_get_rope(layer) +# TODO(christian-lms): you end up basically overriding EVERYTHING so maybe decouple class ShiftingKVCache(RotatingKVCache): def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): self._rope = rope @@ -50,24 +51,26 @@ def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): self.reuse_queue = [] super().__init__(max_size, keep, step) - # TODO(christian-lms): stop rope shifting your values!!!!!!!!!!!!! def rope(self, v: mx.array, shift_by: int) -> mx.array: # TODO(christian-lms): this is reeeeeeallllyyyy stupid. spin a proper block impl return mx.concatenate( [self._rope(v[:, :, i : i + 1, :], shift_by) for i in range(v.shape[2])], axis=2, ) + + def rope_if(self, v: mx.array, shift_by: int, do: bool = False) -> mx.array: + return self.rope(v, shift_by) if do else v def is_trimmable(self) -> bool: return True - def _trim(self, trim_size, v, append=None): + def _trim(self, trim_size, v, append=None, is_key=False): to_cat = [] shift_by = -trim_size if trim_size > 0: to_cat = [ v[..., : self.keep, :], - self.rope(v[..., trim_size + self.keep :, :], shift_by), + self.rope_if(v[..., trim_size + self.keep :, :], shift_by, do=is_key), ] else: to_cat = [v] @@ -75,7 +78,7 @@ def _trim(self, trim_size, v, append=None): to_cat.append(append) return mx.concatenate(to_cat, axis=2) - def _temporal_order(self, v) -> mx.array: + def _temporal_order(self, v, is_key=False) -> mx.array: """ Rearrange the cache into temporal order, slicing off the end if unused. """ @@ -92,13 +95,75 @@ def _temporal_order(self, v) -> mx.array: # rotating section would be off by a multiple of (max_kv_size - keep) # depending on how many times it rolled over. I feel like it's pretty # safe to assume that this is a rare case - self.rope(v[..., self._idx :, :], shift_by), - self.rope(v[..., self.keep : self._idx, :], shift_by), + self.rope_if(v[..., self._idx :, :], shift_by, do=is_key), + self.rope_if(v[..., self.keep : self._idx, :], shift_by, do=is_key), ], axis=2, ) else: return v[..., : self._idx, :] + + def _update_concat(self, keys, values): + if self.keys is None: + self.keys = keys + self.values = values + else: + # Put the keys/values in temporal order to + # preserve context + self.keys = self._temporal_order(self.keys, is_key=True) + self.values = self._temporal_order(self.values, is_key=False) + + # The largest size is self.max_size + S to ensure + # every token gets at least self.max_size context + trim_size = self._idx - self.max_size + self.keys = self._trim(trim_size, self.keys, keys, is_key=True) + self.values = self._trim(trim_size, self.values, values, is_key=False) + self.offset += keys.shape[2] + self._idx = self.keys.shape[2] + return self.keys, self.values + + def _update_in_place(self, keys, values): + # May not have hit the max size yet, so potentially + # keep growing the cache + B, n_kv_heads, S, k_head_dim = keys.shape + prev = self.offset + if self.keys is None or ( + prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size + ): + v_head_dim = values.shape[3] + new_size = min(self.step, self.max_size - prev) + k_shape = (B, n_kv_heads, new_size, k_head_dim) + v_shape = (B, n_kv_heads, new_size, v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + self._idx = prev + + # Trim if needed + trim_size = self.keys.shape[2] - self.max_size + if trim_size > 0: + self.keys = self._trim(trim_size, self.keys, is_key=True) + self.values = self._trim(trim_size, self.values, is_key=False) + self._idx = self.max_size + + # Rotate + if self._idx == self.max_size: + self._idx = self.keep + + # Assign + self.keys[..., self._idx : self._idx + S, :] = keys + self.values[..., self._idx : self._idx + S, :] = values + self.offset += S + self._idx += S + + # If the buffer is not full, slice off the end + if self.offset < self.max_size: + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + return self.keys, self.values def reuse_section( self, write_start_idx: int, reuse_start_idx: int, reuse_length: int @@ -131,13 +196,12 @@ def do_reuse(self) -> None: # N.B. we can also go back to the MLX-native "don't rope shift" method # by removing RoPE here and removing the overrides for trim, temporal order shifted_keys = self.rope(keys_to_shift, shift_by) - shifted_values = self.rope(values_to_shift, shift_by) # restructure cache with mx.concat # TODO(christian-lms): maybe it would be better to use inplace ops. # look into the mlx docs if that's even a thing keycat = [self.keys[..., :write_start_idx, :], shifted_keys] - valcat = [self.values[..., :write_start_idx, :], shifted_values] + valcat = [self.values[..., :write_start_idx, :], values_to_shift] # TODO(christian-lms): surely there is a better way to do this? # by not re-appending the end at the last one, we truncate the leftovers @@ -145,8 +209,8 @@ def do_reuse(self) -> None: keycat.append(self.keys[..., reuse_end_idx:, :]) valcat.append(self.values[..., reuse_end_idx:, :]) - self.keys = mx.concat(keycat, axis=2) - self.values = mx.concat(valcat, axis=2) + self.keys = mx.concatenate(keycat, axis=2) + self.values = mx.concatenate(valcat, axis=2) self.offset -= shift_by self.reuse_offset = 0 @@ -167,8 +231,8 @@ def trim(self, n) -> int: # will trim exactly n off the end wthout any wrapping around. but you can uncomment the line # if it turns out that this assumption is faulty if self.offset >= self.max_size: - self.keys = self._temporal_order(self.keys) - self.values = self._temporal_order(self.values) + self.keys = self._temporal_order(self.keys, is_key=True) + self.values = self._temporal_order(self.values, is_key=False) # n = n % (self.max_size - self.keep) # do trim: put us back into the state before the circular buffer is full @@ -213,7 +277,6 @@ def make_prompt_cache( f"None at layer {layer} of model {model}" ) return [KVCache() for _ in range(num_layers)] - # TODO(christian-lms): change keep on the fly, must be setattr elsewhere cache.append(ShiftingKVCache(rope, max_size=max_kv_size, keep=keep)) return cache else: diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 0d262041..05d93113 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -336,7 +336,9 @@ def update_cache( def prompt_progress_callback(x): return None + # update keep tracking self.keep = keep + setattr(self.cache, "keep", keep) num_tokens_to_exclude = max(num_tokens_to_exclude, 1) prompt_tokens = self._get_unprocessed_tokens( From 447a134eb4df76339e183ab4f37f9c984916394a Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 11:47:31 -0400 Subject: [PATCH 19/39] cache is a list, and exclude tokens in the right place --- mlx_engine/cache_wrapper.py | 80 ++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 05d93113..c6f81a64 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -52,7 +52,6 @@ def _find_matching_sequence_length( tokens2: mx.array, start1: int = 0, start2: int = 0, - num_tokens_to_exclude: int = 0, ) -> int: """ Find the length of matching token sequence between two token arrays. @@ -62,8 +61,6 @@ def _find_matching_sequence_length( start1: Starting position in first array tokens2: Second token array start2: Starting position in second array - num_tokens_to_exclude (int): The minimum length of the leftover non-matching - segment of tokens1, to be excluded from the match length. Returns: int: Length of matching sequence @@ -79,67 +76,61 @@ def _find_matching_sequence_length( # Find first mismatch mask = seq1 == seq2 - if mx.any(mask == False): # noqa E712 - match_length = int(mx.argmax(mask == False)) # noqa E712 - else: - match_length = min_length - - # Ensure that the leftover non-matching segment of tokens1 - # is at least num_tokens_to_exclude long - leftover_tokens1_length = len(tokens1[match_length:]) - length_adjustment = max(0, num_tokens_to_exclude - leftover_tokens1_length) - match_length = max(match_length - length_adjustment, 0) - return match_length - + return int(mx.argmax(mask == False)) if mx.any(mask == False) else min_length # noqa E712 + def _truncate_cache( self, - prompt_tokens: mx.array, + prompt_tokens: mx.array, common_prefix_len: int, - non_prefix_reuse_min_seq_len: int = 256 + non_prefix_reuse_min_seq_len: int = 256, ) -> int: cache_size = len(self.tokens) prompt_size = len(prompt_tokens) - + # start scanning from after the common prefix cache_head_idx = common_prefix_len prompt_head_idx = common_prefix_len total_reused = 0 if self.verbose: - print(f"Looking for non-prefix sequences of length >= {non_prefix_reuse_min_seq_len}", file=sys.stderr) - + print( + f"Looking for non-prefix sequences of length >= {non_prefix_reuse_min_seq_len}", + file=sys.stderr, + ) + while cache_head_idx < cache_size and prompt_head_idx < prompt_size: match_length = self._find_matching_sequence_length( - self.tokens, cache_head_idx, - prompt_tokens, prompt_head_idx + self.tokens, cache_head_idx, prompt_tokens, prompt_head_idx ) - + if match_length < non_prefix_reuse_min_seq_len: # sequence too short - advance cache pointer to find next potential match cache_head_idx += 1 else: if self.verbose: print(f"Reusing {match_length} tokens from cache", file=sys.stderr) - + # found reusable sequence - shift cache content - self.cache.reuse_section( - source_pos=cache_head_idx, - target_pos=prompt_head_idx, - length=match_length - ) + for cache in self.cache: + cache.reuse_section( + source_pos=cache_head_idx, + target_pos=prompt_head_idx, + length=match_length, + ) # update the tokens to reflect the reused sequence for i in range(match_length): self.tokens[prompt_head_idx + i] = self.tokens[cache_head_idx + i] - + # advance pointers cache_head_idx += match_length prompt_head_idx += match_length total_reused += match_length - - self.cache.do_reuse() + + for cache in self.cache: + cache.do_reuse() # TODO(christian-lms): ensure that this works - self.tokens = self.tokens[:common_prefix_len + total_reused] + self.tokens = self.tokens[: common_prefix_len + total_reused] return total_reused @@ -162,12 +153,12 @@ def _get_unprocessed_tokens( # Find common KV between the last generation and the current prompt common_prefix = self._find_matching_sequence_length( - # TODO(christian-lms): BLOCKING: num_tokens_to_exclude must be moved after reuse - # this means you should move it out of _find_matching_sequence_length - self.tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude + self.tokens, + prompt_tokens, ) - - if hasattr(self.cache, "reuse_section"): + + # do reuse but only if the cache has it + if hasattr(self.cache[0], "reuse_section"): n_reused_tokens = self._truncate_cache( prompt_tokens, common_prefix, @@ -175,10 +166,14 @@ def _get_unprocessed_tokens( if n_reused_tokens > 0: log_info( prefix="CacheWrapper", - message=f"Reused {n_reused_tokens} tokens from the cache" + message=f"Reused {n_reused_tokens} tokens from the cache", ) common_prefix += n_reused_tokens + # exclude some tokens from end, e.g. for kicking off generation + if common_prefix >= len(prompt_tokens) - num_tokens_to_exclude: + common_prefix = len(prompt_tokens) - num_tokens_to_exclude + # Trim the cache if the common prefix is shorter than the current cache num_tokens_to_trim = self.cache[0].offset - common_prefix if num_tokens_to_trim > 0: @@ -338,7 +333,8 @@ def prompt_progress_callback(x): # update keep tracking self.keep = keep - setattr(self.cache, "keep", keep) + for cache in self.cache: + setattr(cache, "keep", keep) num_tokens_to_exclude = max(num_tokens_to_exclude, 1) prompt_tokens = self._get_unprocessed_tokens( @@ -385,5 +381,7 @@ def record_generated_token(self, token): # this behavior is common to rolling window (n_keep = 0) and truncate middle # (n_keep > 0), and we should never get here with stop at max if len(self.tokens) >= self.max_kv_size: - self.tokens = mx.concat([self.tokens[:self.keep], self.tokens[self.keep + 1 :]]) + self.tokens = mx.concat( + [self.tokens[: self.keep], self.tokens[self.keep + 1 :]] + ) self.tokens = mx.concat([self.tokens, mx.array([token])]) From 16fc7a12cd1dc9e825ad1025cb01119bdb848eb4 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 11:48:49 -0400 Subject: [PATCH 20/39] same for tests --- tests/test_cache_wrapper.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index 99180a00..6c0a8575 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -132,13 +132,13 @@ def test_record_generated_token_loops(self): def test_cache_reuse(self): cache = CacheWrapper(DummyModel(), 10) - cache.cache = ShiftingKVCache(self._rope, max_size=10, keep=2) + cache.cache[0] = ShiftingKVCache(self._rope, max_size=10, keep=2) # set up pretend cache cached_tokens = mx.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) cache_kv = self.make_random_kv(10) cache.tokens = cached_tokens - cache.cache.update_and_fetch(cache_kv, cache_kv) + cache.cache[0].update_and_fetch(cache_kv, cache_kv) prompt_tokens = mx.array([1, 2, 4, 7, 8, 9, 11]) prefix_len = cache._find_matching_sequence_length( @@ -160,14 +160,14 @@ def idx(v, a, b): should_be_kv = mx.concatenate( [ idx(cache_kv, 0, 2), - cache.cache.rope(idx(cache_kv, 3, 4), -1), - cache.cache.rope(idx(cache_kv, 6, 9), -3), + cache.cache[0].rope(idx(cache_kv, 3, 4), -1), + cache.cache[0].rope(idx(cache_kv, 6, 9), -3), ] ) self.assertEqual(total_reused, 4) self.assertArrEqual(cache.tokens, should_be_tokens) - self.assertArrEqual(cache.cache.keys, should_be_kv) + self.assertArrEqual(cache.cache[0].keys, should_be_kv) if __name__ == "__main__": From 85f2241899ce9126a0fb9e9538e28a6c25cd9a13 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 11:55:12 -0400 Subject: [PATCH 21/39] apply that to tests too oops --- tests/test_cache_wrapper.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index 6c0a8575..e617d523 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -12,14 +12,13 @@ def test_find_matching_sequence_length_with_mismatch(self): # Create two arrays with a known common prefix [1, 2, 3] current_tokens = mx.array([1, 2, 3, 4, 5]) prompt_tokens = mx.array([1, 2, 3, 6, 7]) # Mismatch at index 3 - num_tokens_to_exclude = 1 print("\nTest with mismatch:") print(f"current_tokens: {current_tokens}") print(f"prompt_tokens: {prompt_tokens}") result = CacheWrapper._find_matching_sequence_length( - current_tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude + current_tokens, prompt_tokens ) self.assertEqual(result, 3) # Should find 3 matching tokens @@ -28,14 +27,13 @@ def test_find_matching_sequence_length_all_match(self): # Create two identical arrays current_tokens = mx.array([1, 2, 3, 4, 5]) prompt_tokens = mx.array([1, 2, 3, 4, 5]) # All tokens match - num_tokens_to_exclude = 1 print("\nTest with all matching:") print(f"current_tokens: {current_tokens}") print(f"prompt_tokens: {prompt_tokens}") result = CacheWrapper._find_matching_sequence_length( - current_tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude + current_tokens, prompt_tokens ) self.assertEqual( result, 4 @@ -46,14 +44,13 @@ def test_find_matching_sequence_length_no_match(self): # Create two arrays with no common prefix current_tokens = mx.array([1, 2, 3, 4, 5]) prompt_tokens = mx.array([6, 7, 8, 9, 10]) - num_tokens_to_exclude = 1 print("\nTest with no matching tokens:") print(f"current_tokens: {current_tokens}") print(f"prompt_tokens: {prompt_tokens}") result = CacheWrapper._find_matching_sequence_length( - current_tokens, prompt_tokens, num_tokens_to_exclude=num_tokens_to_exclude + current_tokens, prompt_tokens ) self.assertEqual(result, 0) # No matching tokens should return 0 @@ -62,7 +59,6 @@ def test_find_matching_sequence_length_offset_starts(self): # Create two arrays where the current tokens start with a different offset current_tokens = mx.array([2, 3, 4, 5]) prompt_tokens = mx.array([1, 2, 3, 4, 5]) - num_tokens_to_exclude = 1 print("\nTest with offset starts:") print(f"current_tokens: {current_tokens}") @@ -72,9 +68,8 @@ def test_find_matching_sequence_length_offset_starts(self): current_tokens, prompt_tokens, start2=1, - num_tokens_to_exclude=num_tokens_to_exclude, ) - self.assertEqual(result, 3) + self.assertEqual(result, 4) def test_find_matching_sequence_length_more_offsets(self): """Test when the current tokens have more offsets""" @@ -99,15 +94,6 @@ def test_find_matching_sequence_length_more_offsets(self): ) self.assertEqual(result, 2) - result = CacheWrapper._find_matching_sequence_length( - current_tokens, - prompt_tokens, - start1=2, - start2=3, - num_tokens_to_exclude=1, - ) - self.assertEqual(result, 2) # there are leftovers anyway - def test_record_generated_token_loops(self): cache = CacheWrapper( model=DummyModel(), From 6ac8d2fbda6af6ab12435bc06d84ffc777ed4b46 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 11:59:53 -0400 Subject: [PATCH 22/39] decouple from rotatingkvcache since so much of it was rewritten anyway (had to pipe in is_key) --- mlx_engine/cache.py | 179 ++++++++++++++++++++++++++------------------ 1 file changed, 108 insertions(+), 71 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 0df03dca..805915c8 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -1,12 +1,12 @@ from typing import List, Optional, Any from mlx_engine.logging import log_warn -from mlx_lm.models.cache import RotatingKVCache, KVCache +from mlx_lm.models.cache import _BaseCache, KVCache import mlx.core as mx import mlx.nn as nn -# TODO(christian-lms) DO NOT HARDCODE ME (or at least move it somewhere else) +# unfortunate that this is hardcoded but what else is one to do MAYBE_ATTN_NAMES = ["self_attn", "attention", "attn", "mixer", "norm_attn_norm"] MAYBE_ROPE_NAMES = ["rope", "rotary_emb"] @@ -43,13 +43,18 @@ def maybe_get_rope(model: nn.Module, layer_idx: int) -> Optional[nn.Module]: return _maybe_get_rope(layer) -# TODO(christian-lms): you end up basically overriding EVERYTHING so maybe decouple -class ShiftingKVCache(RotatingKVCache): - def __init__(self, rope: nn.Module, max_size=None, keep=0, step=256): +class ShiftingKVCache(_BaseCache): + def __init__(self, rope: nn.Module, max_size=256, keep=0, step=256): + self.keep = keep + self.keys = None + self.values = None + self.offset = 0 + self.max_size = max_size + self.step = step + self._idx = 0 self._rope = rope self.reuse_offset = 0 self.reuse_queue = [] - super().__init__(max_size, keep, step) def rope(self, v: mx.array, shift_by: int) -> mx.array: # TODO(christian-lms): this is reeeeeeallllyyyy stupid. spin a proper block impl @@ -61,9 +66,6 @@ def rope(self, v: mx.array, shift_by: int) -> mx.array: def rope_if(self, v: mx.array, shift_by: int, do: bool = False) -> mx.array: return self.rope(v, shift_by) if do else v - def is_trimmable(self) -> bool: - return True - def _trim(self, trim_size, v, append=None, is_key=False): to_cat = [] shift_by = -trim_size @@ -102,68 +104,6 @@ def _temporal_order(self, v, is_key=False) -> mx.array: ) else: return v[..., : self._idx, :] - - def _update_concat(self, keys, values): - if self.keys is None: - self.keys = keys - self.values = values - else: - # Put the keys/values in temporal order to - # preserve context - self.keys = self._temporal_order(self.keys, is_key=True) - self.values = self._temporal_order(self.values, is_key=False) - - # The largest size is self.max_size + S to ensure - # every token gets at least self.max_size context - trim_size = self._idx - self.max_size - self.keys = self._trim(trim_size, self.keys, keys, is_key=True) - self.values = self._trim(trim_size, self.values, values, is_key=False) - self.offset += keys.shape[2] - self._idx = self.keys.shape[2] - return self.keys, self.values - - def _update_in_place(self, keys, values): - # May not have hit the max size yet, so potentially - # keep growing the cache - B, n_kv_heads, S, k_head_dim = keys.shape - prev = self.offset - if self.keys is None or ( - prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size - ): - v_head_dim = values.shape[3] - new_size = min(self.step, self.max_size - prev) - k_shape = (B, n_kv_heads, new_size, k_head_dim) - v_shape = (B, n_kv_heads, new_size, v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - self._idx = prev - - # Trim if needed - trim_size = self.keys.shape[2] - self.max_size - if trim_size > 0: - self.keys = self._trim(trim_size, self.keys, is_key=True) - self.values = self._trim(trim_size, self.values, is_key=False) - self._idx = self.max_size - - # Rotate - if self._idx == self.max_size: - self._idx = self.keep - - # Assign - self.keys[..., self._idx : self._idx + S, :] = keys - self.values[..., self._idx : self._idx + S, :] = values - self.offset += S - self._idx += S - - # If the buffer is not full, slice off the end - if self.offset < self.max_size: - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - return self.keys, self.values def reuse_section( self, write_start_idx: int, reuse_start_idx: int, reuse_length: int @@ -244,6 +184,103 @@ def trim(self, n) -> int: # TODO(christian-lms): verify that this is reasonable self._idx = new_length return n + + def _update_concat(self, keys, values): + if self.keys is None: + self.keys = keys + self.values = values + else: + # Put the keys/values in temporal order to + # preserve context + self.keys = self._temporal_order(self.keys, is_key=True) + self.values = self._temporal_order(self.values, is_key=False) + + # The largest size is self.max_size + S to ensure + # every token gets at least self.max_size context + trim_size = self._idx - self.max_size + self.keys = self._trim(trim_size, self.keys, keys, is_key=True) + self.values = self._trim(trim_size, self.values, values, is_key=False) + self.offset += keys.shape[2] + self._idx = self.keys.shape[2] + return self.keys, self.values + + def _update_in_place(self, keys, values): + # May not have hit the max size yet, so potentially + # keep growing the cache + B, n_kv_heads, S, k_head_dim = keys.shape + prev = self.offset + if self.keys is None or ( + prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size + ): + v_head_dim = values.shape[3] + new_size = min(self.step, self.max_size - prev) + k_shape = (B, n_kv_heads, new_size, k_head_dim) + v_shape = (B, n_kv_heads, new_size, v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + self._idx = prev + + # Trim if needed + trim_size = self.keys.shape[2] - self.max_size + if trim_size > 0: + self.keys = self._trim(trim_size, self.keys, is_key=True) + self.values = self._trim(trim_size, self.values, is_key=False) + self._idx = self.max_size + + # Rotate + if self._idx == self.max_size: + self._idx = self.keep + + # Assign + self.keys[..., self._idx : self._idx + S, :] = keys + self.values[..., self._idx : self._idx + S, :] = values + self.offset += S + self._idx += S + + # If the buffer is not full, slice off the end + if self.offset < self.max_size: + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + return self.keys, self.values + + def update_and_fetch(self, keys, values): + if keys.shape[2] == 1: + return self._update_in_place(keys, values) + return self._update_concat(keys, values) + + @property + def state(self): + if self.offset < self.keys.shape[2]: + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + else: + return self.keys, self.values + + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple( + map(str, (self.keep, self.max_size, self.step, self.offset, self._idx)) + ) + + @meta_state.setter + def meta_state(self, v): + self.keep, self.max_size, self.step, self.offset, self._idx = map( + int, + v, + ) + + def is_trimmable(self) -> bool: + return True + + def to_quantized(self, group_size: int = 64, bits: int = 4) -> Any: + raise NotImplementedError("ShiftingKVCache Quantization NYI") def make_prompt_cache( From d124c0e6ded89b00914c237d41e16d9fa7094740 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 13:38:13 -0400 Subject: [PATCH 23/39] working reuse test --- mlx_engine/cache.py | 92 ++++++++++++++++++++----------------- mlx_engine/cache_wrapper.py | 7 ++- tests/test_cache_shift.py | 43 +++++++++++++++-- tests/test_cache_wrapper.py | 29 ++++++++---- 4 files changed, 111 insertions(+), 60 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 805915c8..db2f0e27 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -53,7 +53,6 @@ def __init__(self, rope: nn.Module, max_size=256, keep=0, step=256): self.step = step self._idx = 0 self._rope = rope - self.reuse_offset = 0 self.reuse_queue = [] def rope(self, v: mx.array, shift_by: int) -> mx.array: @@ -108,54 +107,47 @@ def _temporal_order(self, v, is_key=False) -> mx.array: def reuse_section( self, write_start_idx: int, reuse_start_idx: int, reuse_length: int ) -> None: - # offset indices to account for the fact that we move cache elements around - write_start_idx -= self.reuse_offset - reuse_start_idx -= self.reuse_offset - - # update position offsets for future reuse sections - shift_by = write_start_idx - reuse_start_idx - self.reuse_offset += shift_by - # queue for reuse: everything is done in one pass at the end in do_reuse self.reuse_queue.append((write_start_idx, reuse_start_idx, reuse_length)) def do_reuse(self) -> None: - last_i: int = len(self.reuse_queue) - 1 - - for i, (write_start_idx, reuse_start_idx, reuse_length) in enumerate( - self.reuse_queue - ): - shift_by: int = write_start_idx - reuse_start_idx - assert shift_by <= 0 - reuse_end_idx: int = reuse_start_idx + reuse_length - - keys_to_shift = self.keys[..., reuse_start_idx:reuse_end_idx, :] - values_to_shift = self.values[..., reuse_start_idx:reuse_end_idx, :] - - # perform rope shift - # N.B. we can also go back to the MLX-native "don't rope shift" method - # by removing RoPE here and removing the overrides for trim, temporal order - shifted_keys = self.rope(keys_to_shift, shift_by) - - # restructure cache with mx.concat - # TODO(christian-lms): maybe it would be better to use inplace ops. - # look into the mlx docs if that's even a thing - keycat = [self.keys[..., :write_start_idx, :], shifted_keys] - valcat = [self.values[..., :write_start_idx, :], values_to_shift] + if not self.reuse_queue: + return - # TODO(christian-lms): surely there is a better way to do this? - # by not re-appending the end at the last one, we truncate the leftovers - if i != last_i: - keycat.append(self.keys[..., reuse_end_idx:, :]) - valcat.append(self.values[..., reuse_end_idx:, :]) - - self.keys = mx.concatenate(keycat, axis=2) - self.values = mx.concatenate(valcat, axis=2) - - self.offset -= shift_by - self.reuse_offset = 0 + # just in case, sort in write order + self.reuse_queue.sort(key=lambda x: x[0]) + + key_segments = [] + value_segments = [] + current_pos = 0 + + for write_start_idx, reuse_start_idx, reuse_length in self.reuse_queue: + # add any gap before this write position + if current_pos < write_start_idx: + key_segments.append(self.keys[..., current_pos:write_start_idx, :]) + value_segments.append(self.values[..., current_pos:write_start_idx, :]) + + # add the reused segment with RoPE shift + shift_by = write_start_idx - reuse_start_idx # intentionally negative!!! + reuse_end_idx = reuse_start_idx + reuse_length + + keys_to_reuse = self.keys[..., reuse_start_idx:reuse_end_idx, :] + values_to_reuse = self.values[..., reuse_start_idx:reuse_end_idx, :] + + # only keys require rope + shifted_keys = self.rope(keys_to_reuse, shift_by) + + key_segments.append(shifted_keys) + value_segments.append(values_to_reuse) + + current_pos = write_start_idx + reuse_length + self.offset += shift_by + + self.keys = mx.concatenate(key_segments, axis=2) + self.values = mx.concatenate(value_segments, axis=2) + + # clean up self.reuse_queue = [] - # TODO(christian-lms): dunno if this number is correct/reasonable/whatever self._idx = self.keys.shape[2] def trim(self, n) -> int: @@ -214,7 +206,10 @@ def _update_in_place(self, keys, values): ): v_head_dim = values.shape[3] new_size = min(self.step, self.max_size - prev) + print(self.max_size) + print(prev) k_shape = (B, n_kv_heads, new_size, k_head_dim) + print(k_shape) v_shape = (B, n_kv_heads, new_size, v_head_dim) new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) @@ -299,6 +294,17 @@ def make_prompt_cache( with a maximum size of ``max_kv_size`` """ if hasattr(model, "make_cache"): + # TODO(christian-lms): gah what are you gonna do about models that do this + # afm7 baichuan_m1 cohere2 gemma3(+friends) llama4 mamba plamo2 recurrent_gemma + # m1 mamba plamo2 recurrent_gemma are hybrid + # - afm7 is trivially overridable + # - cohere2 is swa on some layers but can probably be overridden + # - gemma3 see cohere2 + # - llama4 uses chunked kv on some layers but can maybe be overridden + # though these layers don't have rope modules + + # try to get the model name from model.args.model_type but i suppose this will + # not always work. that or literally model.__name__ hopefully return model.make_cache() num_layers = len(model.layers) if max_kv_size is not None: diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index c6f81a64..cdd5cfb2 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -100,7 +100,7 @@ def _truncate_cache( while cache_head_idx < cache_size and prompt_head_idx < prompt_size: match_length = self._find_matching_sequence_length( - self.tokens, cache_head_idx, prompt_tokens, prompt_head_idx + prompt_tokens, self.tokens, prompt_head_idx, cache_head_idx ) if match_length < non_prefix_reuse_min_seq_len: @@ -109,13 +109,12 @@ def _truncate_cache( else: if self.verbose: print(f"Reusing {match_length} tokens from cache", file=sys.stderr) + print(f"idx {prompt_head_idx} {cache_head_idx}") # found reusable sequence - shift cache content for cache in self.cache: cache.reuse_section( - source_pos=cache_head_idx, - target_pos=prompt_head_idx, - length=match_length, + prompt_head_idx, cache_head_idx, match_length ) # update the tokens to reflect the reused sequence diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 476aac52..60d4cf1f 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -26,7 +26,7 @@ def test_overwriting(self): self.assertArrEqual(idx(keys, 1), overwrite) self.assertArrEqual(idx(keys, 2), idx(base_kv, 2)) - def test_temporal_order_shift(self): + def test_temporal_order_shift_rope(self): """Test the RoPE shift in _temporal_order""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) @@ -39,13 +39,32 @@ def test_temporal_order_shift(self): cache.update_and_fetch(overwrite, overwrite) # put the cache in temporal order -> 134 -> 123 (rope shift) - keys = cache._temporal_order(cache.keys) + keys = cache._temporal_order(cache.keys, is_key=True) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) self.assertArrEqual(idx(keys, 2), cache.rope(overwrite, -1)) - def test_trim_internal_shift(self): + def test_temporal_order_shift_no_rope(self): + """Test putting the cache in temporal order""" + cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + + # fill cache -> 123 + base_kv = self.make_random_kv(3) + cache.update_and_fetch(base_kv, base_kv) + + # attempt to write another element 4 -> 143 + overwrite = self.make_random_kv(1) + cache.update_and_fetch(overwrite, overwrite) + + # put the cache in temporal order -> 134 (no rope shift) + keys = cache._temporal_order(cache.keys, is_key=False) + + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 1), idx(base_kv, 2)) + self.assertArrEqual(idx(keys, 2), overwrite) + + def test_trim_internal_shift_rope(self): """Test the RoPE shift in _trim (internal method)""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) @@ -54,13 +73,27 @@ def test_trim_internal_shift(self): cache.update_and_fetch(base_kv, base_kv) # trim 1 from middle -> 13 - keys = cache._trim(1, cache.keys) - + keys = cache._trim(1, cache.keys, is_key=True) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) + def test_trim_internal_shift_no_rope(self): + """Test the RoPE shift in _trim (internal method)""" + cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + + # fill cache -> 123 + base_kv = self.make_random_kv(3) + cache.update_and_fetch(base_kv, base_kv) + + # trim 1 from middle -> 13 + keys = cache._trim(1, cache.keys, is_key=False) + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 1), idx(base_kv, 2)) + def test_ensure_reasonable_size_and_shift(self): """Test behavior when the cache gets a KV batch-written that is much larger than max_size. The default behavior of the cache is to write the entire thing, diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index e617d523..ba57af5d 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -35,9 +35,7 @@ def test_find_matching_sequence_length_all_match(self): result = CacheWrapper._find_matching_sequence_length( current_tokens, prompt_tokens ) - self.assertEqual( - result, 4 - ) # Should find 4 matching tokens (5-1 due to num_tokens_to_exclude) + self.assertEqual(result, 5) # Should find 5 matching tokens def test_find_matching_sequence_length_no_match(self): """Test when no tokens match""" @@ -116,7 +114,7 @@ def test_record_generated_token_loops(self): [1, 2, 4, 5, 6], ) - def test_cache_reuse(self): + def test_cache_reuse_heavy(self): cache = CacheWrapper(DummyModel(), 10) cache.cache[0] = ShiftingKVCache(self._rope, max_size=10, keep=2) @@ -126,7 +124,9 @@ def test_cache_reuse(self): cache.tokens = cached_tokens cache.cache[0].update_and_fetch(cache_kv, cache_kv) + # set up pretend prompt prompt_tokens = mx.array([1, 2, 4, 7, 8, 9, 11]) + prefix_len = cache._find_matching_sequence_length( cached_tokens, prompt_tokens, 0 ) @@ -138,22 +138,35 @@ def test_cache_reuse(self): non_prefix_reuse_min_seq_len=1, ) - should_be_tokens = mx.array([1, 2, 4, 7, 8, 9]) - + # prepare references def idx(v, a, b): return v[:, :, a:b, :] + should_be_tokens = mx.array([1, 2, 4, 7, 8, 9]) should_be_kv = mx.concatenate( [ idx(cache_kv, 0, 2), cache.cache[0].rope(idx(cache_kv, 3, 4), -1), cache.cache[0].rope(idx(cache_kv, 6, 9), -3), - ] + ], + axis=2, ) - + self.assertEqual(total_reused, 4) self.assertArrEqual(cache.tokens, should_be_tokens) self.assertArrEqual(cache.cache[0].keys, should_be_kv) + + # ensure updating works as intended + new_kv = self.make_random_kv(1) + keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) + should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) + self.assertArrEqual(keys, should_be_kv) + + # ensure batch concat works as intended + new_kv = self.make_random_kv(2) + keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) + should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) + self.assertArrEqual(keys, should_be_kv) if __name__ == "__main__": From 8ac2baecc409e80b09e9db7f98ae3e538e56e492 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 14:26:39 -0400 Subject: [PATCH 24/39] cache offsets ooooooooooooooooooooops --- mlx_engine/cache.py | 18 +++++++- mlx_engine/cache_wrapper.py | 3 +- tests/test_cache_shift.py | 83 ++++++++++++++++++++++++++++++++++--- 3 files changed, 97 insertions(+), 7 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index db2f0e27..8b0e17d0 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -64,6 +64,9 @@ def rope(self, v: mx.array, shift_by: int) -> mx.array: def rope_if(self, v: mx.array, shift_by: int, do: bool = False) -> mx.array: return self.rope(v, shift_by) if do else v + + # TODO(christian-lms): maybe the solution to the below is + # to make these fns operate on both k/v at once def _trim(self, trim_size, v, append=None, is_key=False): to_cat = [] @@ -77,6 +80,9 @@ def _trim(self, trim_size, v, append=None, is_key=False): to_cat = [v] if append is not None: to_cat.append(append) + # TODO(christian-lms): necessary? stupid hack anyway + if is_key and trim_size > 0: + self.offset -= trim_size return mx.concatenate(to_cat, axis=2) def _temporal_order(self, v, is_key=False) -> mx.array: @@ -86,8 +92,11 @@ def _temporal_order(self, v, is_key=False) -> mx.array: if self._idx == v.shape[2]: return v elif self._idx < self.offset: - shift_by = self.keep - self._idx + shift_by = self.keep - self._idx # intentionally negative!!! assert shift_by <= 0 + # TODO(christian-lms): necessary? stupid hack anyway + if is_key: + self.offset += shift_by return mx.concatenate( [ v[..., : self.keep, :], @@ -149,6 +158,7 @@ def do_reuse(self) -> None: # clean up self.reuse_queue = [] self._idx = self.keys.shape[2] + self.offset = self.keys.shape[2] def trim(self, n) -> int: # TODO(christian-lms): should trim respect keep? currently, no @@ -246,6 +256,12 @@ def update_and_fetch(self, keys, values): if keys.shape[2] == 1: return self._update_in_place(keys, values) return self._update_concat(keys, values) + + def set_keep(self, keep): + # kv must be in temporal order, else we will keep the wrong thing + self.keys = self._temporal_order(self.keys, is_key=True) + self.values = self._temporal_order(self.values, is_key=False) + self.keep = keep @property def state(self): diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index cdd5cfb2..49f7a046 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -333,7 +333,8 @@ def prompt_progress_callback(x): # update keep tracking self.keep = keep for cache in self.cache: - setattr(cache, "keep", keep) + if hasattr(cache, "set_keep"): + cache.set_keep(keep) num_tokens_to_exclude = max(num_tokens_to_exclude, 1) prompt_tokens = self._get_unprocessed_tokens( diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 60d4cf1f..b3611e11 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -17,6 +17,7 @@ def test_overwriting(self): # fill cache -> 123 base_kv = self.make_random_kv(3) cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 3) # attempt to write another element 4 -> 143 overwrite = self.make_random_kv(1) @@ -25,6 +26,16 @@ def test_overwriting(self): self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), overwrite) self.assertArrEqual(idx(keys, 2), idx(base_kv, 2)) + self.assertEqual(cache.offset, 4) + + def test_ensure_update_increases_offset_indefinitely(self): + """Test single-token updates that should increase offset""" + cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + + for i in range(10): + kv = self.make_random_kv(1) + cache.update_and_fetch(kv, kv) + self.assertEqual(cache.offset - 1, i) def test_temporal_order_shift_rope(self): """Test the RoPE shift in _temporal_order""" @@ -33,10 +44,12 @@ def test_temporal_order_shift_rope(self): # fill cache -> 123 base_kv = self.make_random_kv(3) cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 3) # attempt to write another element 4 -> 143 overwrite = self.make_random_kv(1) cache.update_and_fetch(overwrite, overwrite) + self.assertEqual(cache.offset, 4) # put the cache in temporal order -> 134 -> 123 (rope shift) keys = cache._temporal_order(cache.keys, is_key=True) @@ -44,6 +57,7 @@ def test_temporal_order_shift_rope(self): self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) self.assertArrEqual(idx(keys, 2), cache.rope(overwrite, -1)) + self.assertEqual(cache.offset, 3) def test_temporal_order_shift_no_rope(self): """Test putting the cache in temporal order""" @@ -52,10 +66,12 @@ def test_temporal_order_shift_no_rope(self): # fill cache -> 123 base_kv = self.make_random_kv(3) cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 3) # attempt to write another element 4 -> 143 overwrite = self.make_random_kv(1) cache.update_and_fetch(overwrite, overwrite) + self.assertEqual(cache.offset, 4) # put the cache in temporal order -> 134 (no rope shift) keys = cache._temporal_order(cache.keys, is_key=False) @@ -63,6 +79,7 @@ def test_temporal_order_shift_no_rope(self): self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), idx(base_kv, 2)) self.assertArrEqual(idx(keys, 2), overwrite) + self.assertEqual(cache.offset, 4) def test_trim_internal_shift_rope(self): """Test the RoPE shift in _trim (internal method)""" @@ -71,6 +88,7 @@ def test_trim_internal_shift_rope(self): # fill cache -> 123 base_kv = self.make_random_kv(3) cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 3) # trim 1 from middle -> 13 keys = cache._trim(1, cache.keys, is_key=True) @@ -78,6 +96,8 @@ def test_trim_internal_shift_rope(self): self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) + # trim should trigger offset change with is_key=True + self.assertEqual(cache.offset, 2) def test_trim_internal_shift_no_rope(self): """Test the RoPE shift in _trim (internal method)""" @@ -86,13 +106,15 @@ def test_trim_internal_shift_no_rope(self): # fill cache -> 123 base_kv = self.make_random_kv(3) cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 3) - # trim 1 from middle -> 13 + # trim 1 from middle -> 13 -> 12 keys = cache._trim(1, cache.keys, is_key=False) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), idx(base_kv, 2)) + self.assertEqual(cache.offset, 3) def test_ensure_reasonable_size_and_shift(self): """Test behavior when the cache gets a KV batch-written that is much larger @@ -105,19 +127,64 @@ def test_ensure_reasonable_size_and_shift(self): base_kv = self.make_random_kv(10) keys, _ = cache.update_and_fetch(base_kv, base_kv) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 10, self.kv_head_dim)) + self.assertEqual(cache.offset, 10) # trigger trim -> 0X9 -> (rope) 021 overwrite = self.make_random_kv(1) keys, _ = cache.update_and_fetch(overwrite, overwrite) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 3, self.kv_head_dim)) + # this should be 4 since this mimics autoregression + self.assertEqual(cache.offset, 4) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) - # TODO(christian-lms): this should also be rope unshifted because it's coming in - # w/ pos emb @ position X and then being sent to 2. figure out where this goes self.assertArrEqual(idx(keys, 1), overwrite) - # TODO(christian-lms): is this position 2 or 1? it should be 1 self.assertArrEqual(idx(keys, 2), cache.rope(idx(base_kv, 9), -7)) + # make sure pos embs are right + keys = cache._temporal_order(keys, is_key=True) + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 9), -8)) + self.assertArrEqual(idx(keys, 2), cache.rope(overwrite, -1)) + self.assertEqual(cache.offset, 3) + + # ensure offset keeps increasing + overwrite = self.make_random_kv(1) + cache.update_and_fetch(overwrite, overwrite) + self.assertEqual(cache.offset, 4) + + overwrite = self.make_random_kv(1) + cache.update_and_fetch(overwrite, overwrite) + self.assertEqual(cache.offset, 5) + + def test_update_keep_on_the_fly(self): + """Test changing the keep value on the fly""" + cache = ShiftingKVCache(self._rope, max_size=4, keep=1) + + # fill cache -> 1234 + base_kv = self.make_random_kv(4) + cache.update_and_fetch(base_kv, base_kv) + + # attempt to write another element 5 -> 1534 + overwrite = self.make_random_kv(1) + cache.update_and_fetch(overwrite, overwrite) + self.assertEqual(cache.offset, 5) + + # update keep -> 1345 -> 1234 implicitly + # and attempt to write another element 5 -> 1254 + # offset updates after set_keep (anytime we reorder/rope shift) + cache.set_keep(2) + overwrite2 = self.make_random_kv(1) + self.assertEqual(cache.offset, 4) + keys, _ = cache.update_and_fetch(overwrite2, overwrite2) + self.assertEqual(cache.offset, 5) + + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) + self.assertArrEqual(idx(keys, 2), overwrite2) + self.assertArrEqual(idx(keys, 3), cache.rope(overwrite, -1)) + + # TODO add offset assertions everywhere to make sure you're good + def test_trim_before_full(self): """Test trimming from the end before the cache is full""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) @@ -132,6 +199,7 @@ def test_trim_before_full(self): self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 1, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertEqual(cache.offset, 1) # ensure adding another value works fine new_kv = self.make_random_kv(1) @@ -140,6 +208,7 @@ def test_trim_before_full(self): self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), new_kv) + self.assertEqual(cache.offset, 2) # TODO(christian-lms): this doesn't actually test the overwriting, for that you # need to fill it to 3 first then add 1 then try trim @@ -150,16 +219,19 @@ def test_trim_after_full(self): # fill cache oversize already -> 1234 base_kv = self.make_random_kv(4) cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 4) # trim 2 from end -> 12 cache.trim(2) keys = cache.keys self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(keys, base_kv[:, :, :2, :]) + self.assertEqual(cache.offset, 2) - # ensure adding more values works fines + # ensure adding more values works fine new_kv = self.make_random_kv(2) keys, _ = cache.update_and_fetch(new_kv, new_kv) + self.assertEqual(cache.offset, 4) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 4, self.kv_head_dim)) self.assertArrEqual(keys[:, :, :2, :], base_kv[:, :, :2, :]) @@ -185,6 +257,7 @@ def test_reuse(self): self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) self.assertArrEqual(keys, should_be_keys) + self.assertEqual(cache.offset, 5) if __name__ == "__main__": From e373181fe51c792d30b8b3690fde7da9802a3922 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 14:42:00 -0400 Subject: [PATCH 25/39] refactor trim/temporal order internal interfaces to operate on both k/v at once, which bypasses the former stupid hack --- mlx_engine/cache.py | 129 +++++++++++++++++++------------------- tests/test_cache_shift.py | 34 +++++----- 2 files changed, 85 insertions(+), 78 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 8b0e17d0..12864270 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -61,57 +61,71 @@ def rope(self, v: mx.array, shift_by: int) -> mx.array: [self._rope(v[:, :, i : i + 1, :], shift_by) for i in range(v.shape[2])], axis=2, ) - - def rope_if(self, v: mx.array, shift_by: int, do: bool = False) -> mx.array: - return self.rope(v, shift_by) if do else v - - # TODO(christian-lms): maybe the solution to the below is - # to make these fns operate on both k/v at once - - def _trim(self, trim_size, v, append=None, is_key=False): - to_cat = [] + + def _trim( + self, trim_size, append_k=None, append_v=None + ) -> None: + k = self.keys + v = self.values + assert k.shape == v.shape shift_by = -trim_size if trim_size > 0: - to_cat = [ - v[..., : self.keep, :], - self.rope_if(v[..., trim_size + self.keep :, :], shift_by, do=is_key), + k_cat = [ + k[..., : self.keep, :], + self.rope(k[..., trim_size + self.keep :, :], shift_by), ] - else: - to_cat = [v] - if append is not None: - to_cat.append(append) - # TODO(christian-lms): necessary? stupid hack anyway - if is_key and trim_size > 0: + v_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] self.offset -= trim_size - return mx.concatenate(to_cat, axis=2) - - def _temporal_order(self, v, is_key=False) -> mx.array: + else: + k_cat = [k] + v_cat = [v] + if append_k is not None: + assert append_v is not None + k_cat.append(append_k) + v_cat.append(append_v) + if append_v is not None: + assert append_k is not None + # already done + self.keys, self.values = mx.concatenate(k_cat, axis=2), mx.concatenate(v_cat, axis=2) + + def _temporal_order(self) -> None: """ Rearrange the cache into temporal order, slicing off the end if unused. """ + k = self.keys + v = self.values + assert k.shape == v.shape if self._idx == v.shape[2]: - return v + pass elif self._idx < self.offset: shift_by = self.keep - self._idx # intentionally negative!!! assert shift_by <= 0 # TODO(christian-lms): necessary? stupid hack anyway - if is_key: - self.offset += shift_by - return mx.concatenate( + self.offset += shift_by + kcat = mx.concatenate( [ - v[..., : self.keep, :], + k[..., : self.keep, :], # N.B. this implicitly assumes the generation has not gone over twice # the size of the rotating section of the cache, in which case the # rotating section would be off by a multiple of (max_kv_size - keep) # depending on how many times it rolled over. I feel like it's pretty # safe to assume that this is a rare case - self.rope_if(v[..., self._idx :, :], shift_by, do=is_key), - self.rope_if(v[..., self.keep : self._idx, :], shift_by, do=is_key), + self.rope(k[..., self._idx :, :], shift_by), + self.rope(k[..., self.keep : self._idx, :], shift_by), ], axis=2, ) + vcat = mx.concatenate( + [ + v[..., : self.keep, :], + v[..., self._idx :, :], + v[..., self.keep : self._idx, :], + ], + axis=2, + ) + self.keys, self.values = kcat, vcat else: - return v[..., : self._idx, :] + self.keys, self.values = k[..., : self._idx, :], v[..., : self._idx, :] def reuse_section( self, write_start_idx: int, reuse_start_idx: int, reuse_length: int @@ -123,38 +137,38 @@ def do_reuse(self) -> None: if not self.reuse_queue: return - # just in case, sort in write order + # just in case, sort in write order self.reuse_queue.sort(key=lambda x: x[0]) - + key_segments = [] value_segments = [] current_pos = 0 - + for write_start_idx, reuse_start_idx, reuse_length in self.reuse_queue: # add any gap before this write position if current_pos < write_start_idx: key_segments.append(self.keys[..., current_pos:write_start_idx, :]) value_segments.append(self.values[..., current_pos:write_start_idx, :]) - + # add the reused segment with RoPE shift shift_by = write_start_idx - reuse_start_idx # intentionally negative!!! reuse_end_idx = reuse_start_idx + reuse_length - + keys_to_reuse = self.keys[..., reuse_start_idx:reuse_end_idx, :] values_to_reuse = self.values[..., reuse_start_idx:reuse_end_idx, :] - + # only keys require rope shifted_keys = self.rope(keys_to_reuse, shift_by) - + key_segments.append(shifted_keys) value_segments.append(values_to_reuse) - + current_pos = write_start_idx + reuse_length self.offset += shift_by - + self.keys = mx.concatenate(key_segments, axis=2) self.values = mx.concatenate(value_segments, axis=2) - + # clean up self.reuse_queue = [] self._idx = self.keys.shape[2] @@ -166,18 +180,8 @@ def trim(self, n) -> int: if n <= 0: return 0 - # TODO(christian-lms): so you used to need to wrap around because the code - # didn't know how much it was trying to trim, so it would go over the maximum allowed. - # but i think this was in large part due to improperly tracking the tokens that were - # actually in the cache, so this should not be an issue anymore. therefore this trim code - # will trim exactly n off the end wthout any wrapping around. but you can uncomment the line - # if it turns out that this assumption is faulty - if self.offset >= self.max_size: - self.keys = self._temporal_order(self.keys, is_key=True) - self.values = self._temporal_order(self.values, is_key=False) - # n = n % (self.max_size - self.keep) - # do trim: put us back into the state before the circular buffer is full + self._temporal_order() new_length = self.keys.shape[2] - n self.keys = self.keys[..., :new_length, :] self.values = self.values[..., :new_length, :] @@ -186,7 +190,7 @@ def trim(self, n) -> int: # TODO(christian-lms): verify that this is reasonable self._idx = new_length return n - + def _update_concat(self, keys, values): if self.keys is None: self.keys = keys @@ -194,14 +198,14 @@ def _update_concat(self, keys, values): else: # Put the keys/values in temporal order to # preserve context - self.keys = self._temporal_order(self.keys, is_key=True) - self.values = self._temporal_order(self.values, is_key=False) + self._temporal_order() # The largest size is self.max_size + S to ensure # every token gets at least self.max_size context trim_size = self._idx - self.max_size - self.keys = self._trim(trim_size, self.keys, keys, is_key=True) - self.values = self._trim(trim_size, self.values, values, is_key=False) + self._trim( + trim_size, append_k=keys, append_v=values + ) self.offset += keys.shape[2] self._idx = self.keys.shape[2] return self.keys, self.values @@ -216,10 +220,7 @@ def _update_in_place(self, keys, values): ): v_head_dim = values.shape[3] new_size = min(self.step, self.max_size - prev) - print(self.max_size) - print(prev) k_shape = (B, n_kv_heads, new_size, k_head_dim) - print(k_shape) v_shape = (B, n_kv_heads, new_size, v_head_dim) new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) @@ -233,8 +234,9 @@ def _update_in_place(self, keys, values): # Trim if needed trim_size = self.keys.shape[2] - self.max_size if trim_size > 0: - self.keys = self._trim(trim_size, self.keys, is_key=True) - self.values = self._trim(trim_size, self.values, is_key=False) + self._trim( + trim_size, + ) self._idx = self.max_size # Rotate @@ -256,11 +258,10 @@ def update_and_fetch(self, keys, values): if keys.shape[2] == 1: return self._update_in_place(keys, values) return self._update_concat(keys, values) - + def set_keep(self, keep): # kv must be in temporal order, else we will keep the wrong thing - self.keys = self._temporal_order(self.keys, is_key=True) - self.values = self._temporal_order(self.values, is_key=False) + self._temporal_order() self.keep = keep @property @@ -318,7 +319,7 @@ def make_prompt_cache( # - gemma3 see cohere2 # - llama4 uses chunked kv on some layers but can maybe be overridden # though these layers don't have rope modules - + # try to get the model name from model.args.model_type but i suppose this will # not always work. that or literally model.__name__ hopefully return model.make_cache() diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index b3611e11..98b57dcb 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -52,7 +52,8 @@ def test_temporal_order_shift_rope(self): self.assertEqual(cache.offset, 4) # put the cache in temporal order -> 134 -> 123 (rope shift) - keys = cache._temporal_order(cache.keys, is_key=True) + cache._temporal_order() + keys = cache.keys self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) @@ -74,12 +75,13 @@ def test_temporal_order_shift_no_rope(self): self.assertEqual(cache.offset, 4) # put the cache in temporal order -> 134 (no rope shift) - keys = cache._temporal_order(cache.keys, is_key=False) + cache._temporal_order() + values = cache.values - self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) - self.assertArrEqual(idx(keys, 1), idx(base_kv, 2)) - self.assertArrEqual(idx(keys, 2), overwrite) - self.assertEqual(cache.offset, 4) + self.assertArrEqual(idx(values, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(values, 1), idx(base_kv, 2)) + self.assertArrEqual(idx(values, 2), overwrite) + self.assertEqual(cache.offset, 3) def test_trim_internal_shift_rope(self): """Test the RoPE shift in _trim (internal method)""" @@ -91,7 +93,8 @@ def test_trim_internal_shift_rope(self): self.assertEqual(cache.offset, 3) # trim 1 from middle -> 13 - keys = cache._trim(1, cache.keys, is_key=True) + cache._trim(1) + keys = cache.keys self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) @@ -109,12 +112,13 @@ def test_trim_internal_shift_no_rope(self): self.assertEqual(cache.offset, 3) # trim 1 from middle -> 13 -> 12 - keys = cache._trim(1, cache.keys, is_key=False) - - self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) - self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) - self.assertArrEqual(idx(keys, 1), idx(base_kv, 2)) - self.assertEqual(cache.offset, 3) + cache._trim(1) + values = cache.values + + self.assertEqual(values.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(idx(values, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(values, 1), idx(base_kv, 2)) + self.assertEqual(cache.offset, 2) def test_ensure_reasonable_size_and_shift(self): """Test behavior when the cache gets a KV batch-written that is much larger @@ -141,7 +145,9 @@ def test_ensure_reasonable_size_and_shift(self): self.assertArrEqual(idx(keys, 2), cache.rope(idx(base_kv, 9), -7)) # make sure pos embs are right - keys = cache._temporal_order(keys, is_key=True) + cache._temporal_order() + keys = cache.keys + self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 9), -8)) self.assertArrEqual(idx(keys, 2), cache.rope(overwrite, -1)) From 4b938bea8e3d5bea0b857402fbcbae51b521330e Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 14:55:44 -0400 Subject: [PATCH 26/39] more test fixes --- mlx_engine/cache.py | 12 +++--- mlx_engine/cache_wrapper.py | 1 - tests/test_cache_shift.py | 84 ++++++++++++++++++++++++++----------- 3 files changed, 65 insertions(+), 32 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 12864270..66399211 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -56,7 +56,10 @@ def __init__(self, rope: nn.Module, max_size=256, keep=0, step=256): self.reuse_queue = [] def rope(self, v: mx.array, shift_by: int) -> mx.array: - # TODO(christian-lms): this is reeeeeeallllyyyy stupid. spin a proper block impl + # you'd think this is inefficient, but it seems faster than spinning + # a custom implementation somehow. also it allows us to easily use the + # sustk scaled rope/yarn/llama3 rope impls in mlx_lm without having to + # spin a custom implementation for those too (and any future rope variants) return mx.concatenate( [self._rope(v[:, :, i : i + 1, :], shift_by) for i in range(v.shape[2])], axis=2, @@ -100,7 +103,6 @@ def _temporal_order(self) -> None: elif self._idx < self.offset: shift_by = self.keep - self._idx # intentionally negative!!! assert shift_by <= 0 - # TODO(christian-lms): necessary? stupid hack anyway self.offset += shift_by kcat = mx.concatenate( [ @@ -175,7 +177,7 @@ def do_reuse(self) -> None: self.offset = self.keys.shape[2] def trim(self, n) -> int: - # TODO(christian-lms): should trim respect keep? currently, no + # trim does not respect keep and it will stay this way n = min(self.offset, n) if n <= 0: return 0 @@ -187,7 +189,6 @@ def trim(self, n) -> int: self.values = self.values[..., :new_length, :] self.offset -= n - # TODO(christian-lms): verify that this is reasonable self._idx = new_length return n @@ -196,8 +197,7 @@ def _update_concat(self, keys, values): self.keys = keys self.values = values else: - # Put the keys/values in temporal order to - # preserve context + # Put the keys/values in temporal order to preserve context self._temporal_order() # The largest size is self.max_size + S to ensure diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 49f7a046..1fc62d8e 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -128,7 +128,6 @@ def _truncate_cache( for cache in self.cache: cache.do_reuse() - # TODO(christian-lms): ensure that this works self.tokens = self.tokens[: common_prefix_len + total_reused] return total_reused diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 98b57dcb..63496257 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -1,5 +1,6 @@ import unittest import mlx.core as mx +from copy import deepcopy from mlx_engine.cache import ShiftingKVCache from tests.test_cache_generic import TestCache @@ -9,6 +10,9 @@ def idx(v: mx.array, i: int): return v[:, :, i : i + 1, :] +# TODO(christian-lms): helper fn for setup and for concatenate along axis 2 + + class TestShiftingKVCache(TestCache): def test_overwriting(self): """Test overwriting when the cache reaches max_size""" @@ -16,6 +20,7 @@ def test_overwriting(self): # fill cache -> 123 base_kv = self.make_random_kv(3) + reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) self.assertEqual(cache.offset, 3) @@ -23,9 +28,9 @@ def test_overwriting(self): overwrite = self.make_random_kv(1) keys, _ = cache.update_and_fetch(overwrite, overwrite) - self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), overwrite) - self.assertArrEqual(idx(keys, 2), idx(base_kv, 2)) + self.assertArrEqual(idx(keys, 2), idx(reference, 2)) self.assertEqual(cache.offset, 4) def test_ensure_update_increases_offset_indefinitely(self): @@ -43,6 +48,7 @@ def test_temporal_order_shift_rope(self): # fill cache -> 123 base_kv = self.make_random_kv(3) + reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) self.assertEqual(cache.offset, 3) @@ -55,8 +61,8 @@ def test_temporal_order_shift_rope(self): cache._temporal_order() keys = cache.keys - self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) - self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 2), -1)) self.assertArrEqual(idx(keys, 2), cache.rope(overwrite, -1)) self.assertEqual(cache.offset, 3) @@ -66,6 +72,7 @@ def test_temporal_order_shift_no_rope(self): # fill cache -> 123 base_kv = self.make_random_kv(3) + reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) self.assertEqual(cache.offset, 3) @@ -78,8 +85,8 @@ def test_temporal_order_shift_no_rope(self): cache._temporal_order() values = cache.values - self.assertArrEqual(idx(values, 0), idx(base_kv, 0)) - self.assertArrEqual(idx(values, 1), idx(base_kv, 2)) + self.assertArrEqual(idx(values, 0), idx(reference, 0)) + self.assertArrEqual(idx(values, 1), idx(reference, 2)) self.assertArrEqual(idx(values, 2), overwrite) self.assertEqual(cache.offset, 3) @@ -89,6 +96,7 @@ def test_trim_internal_shift_rope(self): # fill cache -> 123 base_kv = self.make_random_kv(3) + reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) self.assertEqual(cache.offset, 3) @@ -97,8 +105,8 @@ def test_trim_internal_shift_rope(self): keys = cache.keys self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) - self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) - self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 2), -1)) # trim should trigger offset change with is_key=True self.assertEqual(cache.offset, 2) @@ -108,6 +116,7 @@ def test_trim_internal_shift_no_rope(self): # fill cache -> 123 base_kv = self.make_random_kv(3) + reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) self.assertEqual(cache.offset, 3) @@ -116,8 +125,8 @@ def test_trim_internal_shift_no_rope(self): values = cache.values self.assertEqual(values.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) - self.assertArrEqual(idx(values, 0), idx(base_kv, 0)) - self.assertArrEqual(idx(values, 1), idx(base_kv, 2)) + self.assertArrEqual(idx(values, 0), idx(reference, 0)) + self.assertArrEqual(idx(values, 1), idx(reference, 2)) self.assertEqual(cache.offset, 2) def test_ensure_reasonable_size_and_shift(self): @@ -129,6 +138,7 @@ def test_ensure_reasonable_size_and_shift(self): # fill cache -> 0123456789 base_kv = self.make_random_kv(10) + reference = deepcopy(base_kv) keys, _ = cache.update_and_fetch(base_kv, base_kv) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 10, self.kv_head_dim)) self.assertEqual(cache.offset, 10) @@ -140,16 +150,16 @@ def test_ensure_reasonable_size_and_shift(self): # this should be 4 since this mimics autoregression self.assertEqual(cache.offset, 4) - self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), overwrite) - self.assertArrEqual(idx(keys, 2), cache.rope(idx(base_kv, 9), -7)) + self.assertArrEqual(idx(keys, 2), cache.rope(idx(reference, 9), -7)) # make sure pos embs are right cache._temporal_order() keys = cache.keys - self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) - self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 9), -8)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 9), -8)) self.assertArrEqual(idx(keys, 2), cache.rope(overwrite, -1)) self.assertEqual(cache.offset, 3) @@ -168,6 +178,7 @@ def test_update_keep_on_the_fly(self): # fill cache -> 1234 base_kv = self.make_random_kv(4) + reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) # attempt to write another element 5 -> 1534 @@ -184,19 +195,18 @@ def test_update_keep_on_the_fly(self): keys, _ = cache.update_and_fetch(overwrite2, overwrite2) self.assertEqual(cache.offset, 5) - self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) - self.assertArrEqual(idx(keys, 1), cache.rope(idx(base_kv, 2), -1)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 2), -1)) self.assertArrEqual(idx(keys, 2), overwrite2) self.assertArrEqual(idx(keys, 3), cache.rope(overwrite, -1)) - # TODO add offset assertions everywhere to make sure you're good - def test_trim_before_full(self): """Test trimming from the end before the cache is full""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) # fill cache -> 12 base_kv = self.make_random_kv(2) + reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) # trim 1 from end -> 1 @@ -204,7 +214,7 @@ def test_trim_before_full(self): keys = cache.keys self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 1, self.kv_head_dim)) - self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertEqual(cache.offset, 1) # ensure adding another value works fine @@ -212,18 +222,41 @@ def test_trim_before_full(self): keys, _ = cache.update_and_fetch(new_kv, new_kv) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) - self.assertArrEqual(idx(keys, 0), idx(base_kv, 0)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), new_kv) self.assertEqual(cache.offset, 2) + + def test_trim_after_overwrite(self): + """Test trimming from the end when we've written past the cache max""" + cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + + # fill cache -> 123 + base_kv = self.make_random_kv(3) + reference = deepcopy(base_kv) + cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 3) + + # overwrite so offset goes over max_size -> 143 + base_kv = self.make_random_kv(1) + cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 4) + + # trim 1 from end -> 13 -> 12 (rope), ideally + cache.trim(1) + + keys = cache.keys + should_be_kv = mx.concatenate([reference[:, :, :1, :], cache.rope(reference[:, :, 2:3, :], -1)], axis=2) + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(keys, should_be_kv) + self.assertEqual(cache.offset, 2) - # TODO(christian-lms): this doesn't actually test the overwriting, for that you - # need to fill it to 3 first then add 1 then try trim def test_trim_after_full(self): """Test trimming from the end when the cache is oversize""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) # fill cache oversize already -> 1234 base_kv = self.make_random_kv(4) + reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) self.assertEqual(cache.offset, 4) @@ -231,7 +264,7 @@ def test_trim_after_full(self): cache.trim(2) keys = cache.keys self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) - self.assertArrEqual(keys, base_kv[:, :, :2, :]) + self.assertArrEqual(keys, reference[:, :, :2, :]) self.assertEqual(cache.offset, 2) # ensure adding more values works fine @@ -240,7 +273,7 @@ def test_trim_after_full(self): self.assertEqual(cache.offset, 4) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 4, self.kv_head_dim)) - self.assertArrEqual(keys[:, :, :2, :], base_kv[:, :, :2, :]) + self.assertArrEqual(keys[:, :, :2, :], reference[:, :, :2, :]) self.assertArrEqual(keys[:, :, 2:, :], new_kv) def test_reuse(self): @@ -249,6 +282,7 @@ def test_reuse(self): # fill cache -> 12345678 base_kv = self.make_random_kv(8) + reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) # reuse a specific section (hardcoded), dynamic reuse is in test_cache_wrapper @@ -258,7 +292,7 @@ def test_reuse(self): # this is what the remaining cache should look like should_be_keys = mx.concatenate( - [base_kv[:, :, :3, :], cache.rope(base_kv[:, :, 4:6, :], -1)], axis=2 + [reference[:, :, :3, :], cache.rope(reference[:, :, 4:6, :], -1)], axis=2 ) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) From d7c4ce75ee9ed18a2bcafffb872d8a6cb78f4522 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 15:29:30 -0400 Subject: [PATCH 27/39] refactor tests --- mlx_engine/cache.py | 31 ++++--- mlx_engine/cache_wrapper.py | 4 +- tests/test_cache_generic.py | 16 +++- tests/test_cache_shift.py | 163 +++++++++++++++--------------------- tests/test_cache_wrapper.py | 17 ++-- 5 files changed, 107 insertions(+), 124 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 66399211..c00b63e2 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -11,6 +11,11 @@ MAYBE_ROPE_NAMES = ["rope", "rotary_emb"] +def cat(v: mx.array): + """Alias for mx.concatenate(v, axis=2) since that's used all over the place""" + return mx.concatenate(v, axis=2) + + def _maybe_get_rope(layer: nn.Module) -> Optional[nn.Module]: for maybe_rope_name in MAYBE_ROPE_NAMES: if hasattr(layer, maybe_rope_name): @@ -60,14 +65,11 @@ def rope(self, v: mx.array, shift_by: int) -> mx.array: # a custom implementation somehow. also it allows us to easily use the # sustk scaled rope/yarn/llama3 rope impls in mlx_lm without having to # spin a custom implementation for those too (and any future rope variants) - return mx.concatenate( + return cat( [self._rope(v[:, :, i : i + 1, :], shift_by) for i in range(v.shape[2])], - axis=2, ) - def _trim( - self, trim_size, append_k=None, append_v=None - ) -> None: + def _trim(self, trim_size, append_k=None, append_v=None) -> None: k = self.keys v = self.values assert k.shape == v.shape @@ -89,7 +91,7 @@ def _trim( if append_v is not None: assert append_k is not None # already done - self.keys, self.values = mx.concatenate(k_cat, axis=2), mx.concatenate(v_cat, axis=2) + self.keys, self.values = cat(k_cat), cat(v_cat) def _temporal_order(self) -> None: """ @@ -104,7 +106,7 @@ def _temporal_order(self) -> None: shift_by = self.keep - self._idx # intentionally negative!!! assert shift_by <= 0 self.offset += shift_by - kcat = mx.concatenate( + kcat = cat( [ k[..., : self.keep, :], # N.B. this implicitly assumes the generation has not gone over twice @@ -115,15 +117,13 @@ def _temporal_order(self) -> None: self.rope(k[..., self._idx :, :], shift_by), self.rope(k[..., self.keep : self._idx, :], shift_by), ], - axis=2, ) - vcat = mx.concatenate( + vcat = cat( [ v[..., : self.keep, :], v[..., self._idx :, :], v[..., self.keep : self._idx, :], ], - axis=2, ) self.keys, self.values = kcat, vcat else: @@ -168,8 +168,7 @@ def do_reuse(self) -> None: current_pos = write_start_idx + reuse_length self.offset += shift_by - self.keys = mx.concatenate(key_segments, axis=2) - self.values = mx.concatenate(value_segments, axis=2) + self.keys, self.values = cat(key_segments), cat(value_segments) # clean up self.reuse_queue = [] @@ -203,9 +202,7 @@ def _update_concat(self, keys, values): # The largest size is self.max_size + S to ensure # every token gets at least self.max_size context trim_size = self._idx - self.max_size - self._trim( - trim_size, append_k=keys, append_v=values - ) + self._trim(trim_size, append_k=keys, append_v=values) self.offset += keys.shape[2] self._idx = self.keys.shape[2] return self.keys, self.values @@ -225,8 +222,8 @@ def _update_in_place(self, keys, values): new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) if self.keys is not None: - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) + self.keys = cat([self.keys, new_k]) + self.values = cat([self.values, new_v]) else: self.keys, self.values = new_k, new_v self._idx = prev diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 1fc62d8e..4640b9e1 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -113,9 +113,7 @@ def _truncate_cache( # found reusable sequence - shift cache content for cache in self.cache: - cache.reuse_section( - prompt_head_idx, cache_head_idx, match_length - ) + cache.reuse_section(prompt_head_idx, cache_head_idx, match_length) # update the tokens to reflect the reused sequence for i in range(match_length): diff --git a/tests/test_cache_generic.py b/tests/test_cache_generic.py index 0d0d70b0..a558b1ac 100644 --- a/tests/test_cache_generic.py +++ b/tests/test_cache_generic.py @@ -1,6 +1,8 @@ import unittest import mlx.core as mx import mlx.nn as nn +from copy import deepcopy +from mlx_engine.cache import ShiftingKVCache class TestCache(unittest.TestCase): @@ -27,4 +29,16 @@ def make_random_kv(cls, seqlen: int): def assertArrEqual(self, a: mx.array, b: mx.array): """Assert that two tensors are equal over the sequence length dimension""" self.assertEqual(a.shape, b.shape) - self.assertTrue(mx.allclose(a, b), "Tensors are not equal") \ No newline at end of file + self.assertTrue(mx.allclose(a, b), "Tensors are not equal") + + def add_random_to_cache( + self, cache: ShiftingKVCache, seqlen: int + ) -> mx.array: + """Add random values to the cache and return them""" + base_kv = self.make_random_kv(seqlen) + # base_kv is *assigned* to cache.keys/cache.values so returning base_kv + # would return a reference to cache.keys, which is pointless. so copy it + reference = deepcopy(base_kv) + cache.update_and_fetch(base_kv, base_kv) + return reference + diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 63496257..5ea812f2 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -1,7 +1,7 @@ import unittest import mlx.core as mx from copy import deepcopy -from mlx_engine.cache import ShiftingKVCache +from mlx_engine.cache import ShiftingKVCache, cat from tests.test_cache_generic import TestCache @@ -10,36 +10,31 @@ def idx(v: mx.array, i: int): return v[:, :, i : i + 1, :] -# TODO(christian-lms): helper fn for setup and for concatenate along axis 2 - - class TestShiftingKVCache(TestCache): def test_overwriting(self): """Test overwriting when the cache reaches max_size""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) # fill cache -> 123 - base_kv = self.make_random_kv(3) - reference = deepcopy(base_kv) - cache.update_and_fetch(base_kv, base_kv) + reference = self.add_random_to_cache(cache, 3) self.assertEqual(cache.offset, 3) - + # attempt to write another element 4 -> 143 - overwrite = self.make_random_kv(1) - keys, _ = cache.update_and_fetch(overwrite, overwrite) + overwrite = self.add_random_to_cache(cache, 1) + # access k/v as cache.state[0]/[1] due to possibly empty buffer + keys = cache.state[0] self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), overwrite) self.assertArrEqual(idx(keys, 2), idx(reference, 2)) self.assertEqual(cache.offset, 4) - + def test_ensure_update_increases_offset_indefinitely(self): """Test single-token updates that should increase offset""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) - + for i in range(10): - kv = self.make_random_kv(1) - cache.update_and_fetch(kv, kv) + self.add_random_to_cache(cache, 1) self.assertEqual(cache.offset - 1, i) def test_temporal_order_shift_rope(self): @@ -47,19 +42,16 @@ def test_temporal_order_shift_rope(self): cache = ShiftingKVCache(self._rope, max_size=3, keep=1) # fill cache -> 123 - base_kv = self.make_random_kv(3) - reference = deepcopy(base_kv) - cache.update_and_fetch(base_kv, base_kv) + reference = self.add_random_to_cache(cache, 3) self.assertEqual(cache.offset, 3) - + # attempt to write another element 4 -> 143 - overwrite = self.make_random_kv(1) - cache.update_and_fetch(overwrite, overwrite) + overwrite = self.add_random_to_cache(cache, 1) self.assertEqual(cache.offset, 4) # put the cache in temporal order -> 134 -> 123 (rope shift) cache._temporal_order() - keys = cache.keys + keys = cache.state[0] self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 2), -1)) @@ -71,19 +63,16 @@ def test_temporal_order_shift_no_rope(self): cache = ShiftingKVCache(self._rope, max_size=3, keep=1) # fill cache -> 123 - base_kv = self.make_random_kv(3) - reference = deepcopy(base_kv) - cache.update_and_fetch(base_kv, base_kv) + reference = self.add_random_to_cache(cache, 3) self.assertEqual(cache.offset, 3) - + # attempt to write another element 4 -> 143 - overwrite = self.make_random_kv(1) - cache.update_and_fetch(overwrite, overwrite) + overwrite = self.add_random_to_cache(cache, 1) self.assertEqual(cache.offset, 4) # put the cache in temporal order -> 134 (no rope shift) cache._temporal_order() - values = cache.values + values = cache.state[1] self.assertArrEqual(idx(values, 0), idx(reference, 0)) self.assertArrEqual(idx(values, 1), idx(reference, 2)) @@ -93,16 +82,14 @@ def test_temporal_order_shift_no_rope(self): def test_trim_internal_shift_rope(self): """Test the RoPE shift in _trim (internal method)""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) - + # fill cache -> 123 - base_kv = self.make_random_kv(3) - reference = deepcopy(base_kv) - cache.update_and_fetch(base_kv, base_kv) + reference = self.add_random_to_cache(cache, 3) self.assertEqual(cache.offset, 3) # trim 1 from middle -> 13 cache._trim(1) - keys = cache.keys + keys = cache.state[0] self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(reference, 0)) @@ -113,17 +100,15 @@ def test_trim_internal_shift_rope(self): def test_trim_internal_shift_no_rope(self): """Test the RoPE shift in _trim (internal method)""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) - + # fill cache -> 123 - base_kv = self.make_random_kv(3) - reference = deepcopy(base_kv) - cache.update_and_fetch(base_kv, base_kv) + reference = self.add_random_to_cache(cache, 3) self.assertEqual(cache.offset, 3) # trim 1 from middle -> 13 -> 12 cache._trim(1) - values = cache.values - + values = cache.state[1] + self.assertEqual(values.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(values, 0), idx(reference, 0)) self.assertArrEqual(idx(values, 1), idx(reference, 2)) @@ -135,17 +120,16 @@ def test_ensure_reasonable_size_and_shift(self): then trim it back down when the next KV is written. """ cache = ShiftingKVCache(self._rope, max_size=3, keep=1) - + # fill cache -> 0123456789 - base_kv = self.make_random_kv(10) - reference = deepcopy(base_kv) - keys, _ = cache.update_and_fetch(base_kv, base_kv) + reference = self.add_random_to_cache(cache, 10) + keys = cache.state[0] self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 10, self.kv_head_dim)) self.assertEqual(cache.offset, 10) - + # trigger trim -> 0X9 -> (rope) 021 - overwrite = self.make_random_kv(1) - keys, _ = cache.update_and_fetch(overwrite, overwrite) + overwrite = self.add_random_to_cache(cache, 1) + keys = cache.state[0] self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 3, self.kv_head_dim)) # this should be 4 since this mimics autoregression self.assertEqual(cache.offset, 4) @@ -156,86 +140,78 @@ def test_ensure_reasonable_size_and_shift(self): # make sure pos embs are right cache._temporal_order() - keys = cache.keys - + keys = cache.state[0] + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 9), -8)) self.assertArrEqual(idx(keys, 2), cache.rope(overwrite, -1)) self.assertEqual(cache.offset, 3) - + # ensure offset keeps increasing - overwrite = self.make_random_kv(1) - cache.update_and_fetch(overwrite, overwrite) + self.add_random_to_cache(cache, 1) self.assertEqual(cache.offset, 4) - overwrite = self.make_random_kv(1) - cache.update_and_fetch(overwrite, overwrite) + self.add_random_to_cache(cache, 1) self.assertEqual(cache.offset, 5) - + def test_update_keep_on_the_fly(self): """Test changing the keep value on the fly""" cache = ShiftingKVCache(self._rope, max_size=4, keep=1) # fill cache -> 1234 - base_kv = self.make_random_kv(4) - reference = deepcopy(base_kv) - cache.update_and_fetch(base_kv, base_kv) + reference = self.add_random_to_cache(cache, 4) # attempt to write another element 5 -> 1534 - overwrite = self.make_random_kv(1) - cache.update_and_fetch(overwrite, overwrite) + overwrite = self.add_random_to_cache(cache, 1) self.assertEqual(cache.offset, 5) # update keep -> 1345 -> 1234 implicitly # and attempt to write another element 5 -> 1254 - # offset updates after set_keep (anytime we reorder/rope shift) + # offset updates after set_keep (anytime we reorder/rope shift) cache.set_keep(2) - overwrite2 = self.make_random_kv(1) self.assertEqual(cache.offset, 4) - keys, _ = cache.update_and_fetch(overwrite2, overwrite2) + overwrite2 = self.add_random_to_cache(cache, 1) self.assertEqual(cache.offset, 5) + keys = cache.state[0] self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 2), -1)) self.assertArrEqual(idx(keys, 2), overwrite2) self.assertArrEqual(idx(keys, 3), cache.rope(overwrite, -1)) - + def test_trim_before_full(self): """Test trimming from the end before the cache is full""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) - + # fill cache -> 12 - base_kv = self.make_random_kv(2) - reference = deepcopy(base_kv) - cache.update_and_fetch(base_kv, base_kv) + reference = self.add_random_to_cache(cache, 2) # trim 1 from end -> 1 cache.trim(1) - keys = cache.keys + keys = cache.state[0] self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 1, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertEqual(cache.offset, 1) # ensure adding another value works fine - new_kv = self.make_random_kv(1) - keys, _ = cache.update_and_fetch(new_kv, new_kv) + new_kv = self.add_random_to_cache(cache, 1) + keys = cache.state[0] + self.assertEqual(cache.offset, 2) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), new_kv) self.assertEqual(cache.offset, 2) - + def test_trim_after_overwrite(self): """Test trimming from the end when we've written past the cache max""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) - + # fill cache -> 123 - base_kv = self.make_random_kv(3) - reference = deepcopy(base_kv) - cache.update_and_fetch(base_kv, base_kv) + reference = self.add_random_to_cache(cache, 3) self.assertEqual(cache.offset, 3) - + # overwrite so offset goes over max_size -> 143 base_kv = self.make_random_kv(1) cache.update_and_fetch(base_kv, base_kv) @@ -243,9 +219,11 @@ def test_trim_after_overwrite(self): # trim 1 from end -> 13 -> 12 (rope), ideally cache.trim(1) + keys = cache.state[0] - keys = cache.keys - should_be_kv = mx.concatenate([reference[:, :, :1, :], cache.rope(reference[:, :, 2:3, :], -1)], axis=2) + should_be_kv = cat( + [reference[:, :, :1, :], cache.rope(reference[:, :, 2:3, :], -1)] + ) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(keys, should_be_kv) self.assertEqual(cache.offset, 2) @@ -253,23 +231,22 @@ def test_trim_after_overwrite(self): def test_trim_after_full(self): """Test trimming from the end when the cache is oversize""" cache = ShiftingKVCache(self._rope, max_size=3, keep=1) - + # fill cache oversize already -> 1234 - base_kv = self.make_random_kv(4) - reference = deepcopy(base_kv) - cache.update_and_fetch(base_kv, base_kv) + reference = self.add_random_to_cache(cache, 4) self.assertEqual(cache.offset, 4) # trim 2 from end -> 12 cache.trim(2) - keys = cache.keys + keys = cache.state[0] + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(keys, reference[:, :, :2, :]) self.assertEqual(cache.offset, 2) # ensure adding more values works fine - new_kv = self.make_random_kv(2) - keys, _ = cache.update_and_fetch(new_kv, new_kv) + new_kv = self.add_random_to_cache(cache, 2) + keys = cache.state[0] self.assertEqual(cache.offset, 4) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 4, self.kv_head_dim)) @@ -279,20 +256,18 @@ def test_trim_after_full(self): def test_reuse(self): """Test basic reuse APIs""" cache = ShiftingKVCache(self._rope, max_size=8, keep=1) - + # fill cache -> 12345678 - base_kv = self.make_random_kv(8) - reference = deepcopy(base_kv) - cache.update_and_fetch(base_kv, base_kv) - + reference = self.add_random_to_cache(cache, 8) + # reuse a specific section (hardcoded), dynamic reuse is in test_cache_wrapper cache.reuse_section(3, 4, 2) cache.do_reuse() - keys = cache.keys + keys = cache.state[0] # this is what the remaining cache should look like - should_be_keys = mx.concatenate( - [reference[:, :, :3, :], cache.rope(reference[:, :, 4:6, :], -1)], axis=2 + should_be_keys = cat( + [reference[:, :, :3, :], cache.rope(reference[:, :, 4:6, :], -1)] ) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) @@ -301,4 +276,4 @@ def test_reuse(self): if __name__ == "__main__": - unittest.main(verbosity=2) \ No newline at end of file + unittest.main(verbosity=2) diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index ba57af5d..bb908b81 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -1,7 +1,7 @@ import unittest import mlx.core as mx from mlx_engine.cache_wrapper import CacheWrapper -from mlx_engine.cache import ShiftingKVCache +from mlx_engine.cache import ShiftingKVCache, cat from tests.test_cache_generic import TestCache from tests.utils import DummyModel @@ -126,7 +126,7 @@ def test_cache_reuse_heavy(self): # set up pretend prompt prompt_tokens = mx.array([1, 2, 4, 7, 8, 9, 11]) - + prefix_len = cache._find_matching_sequence_length( cached_tokens, prompt_tokens, 0 ) @@ -143,29 +143,28 @@ def idx(v, a, b): return v[:, :, a:b, :] should_be_tokens = mx.array([1, 2, 4, 7, 8, 9]) - should_be_kv = mx.concatenate( + should_be_kv = cat( [ idx(cache_kv, 0, 2), cache.cache[0].rope(idx(cache_kv, 3, 4), -1), cache.cache[0].rope(idx(cache_kv, 6, 9), -3), ], - axis=2, ) - + self.assertEqual(total_reused, 4) self.assertArrEqual(cache.tokens, should_be_tokens) self.assertArrEqual(cache.cache[0].keys, should_be_kv) - + # ensure updating works as intended new_kv = self.make_random_kv(1) keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) - should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) + should_be_kv = cat([should_be_kv, new_kv]) self.assertArrEqual(keys, should_be_kv) - + # ensure batch concat works as intended new_kv = self.make_random_kv(2) keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) - should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) + should_be_kv = cat([should_be_kv, new_kv]) self.assertArrEqual(keys, should_be_kv) From 5d60f1327b5417f7e7eb494f5f393615a6f1bdb5 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Tue, 8 Jul 2025 17:23:25 -0400 Subject: [PATCH 28/39] technically if you ran this it would work --- mlx_engine/cache.py | 14 ++++++++++++-- tests/test_cache_shift.py | 1 - 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index c00b63e2..8ec291ac 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -72,6 +72,8 @@ def rope(self, v: mx.array, shift_by: int) -> mx.array: def _trim(self, trim_size, append_k=None, append_v=None) -> None: k = self.keys v = self.values + if k is None or v is None: + return assert k.shape == v.shape shift_by = -trim_size if trim_size > 0: @@ -99,6 +101,8 @@ def _temporal_order(self) -> None: """ k = self.keys v = self.values + if k is None or v is None: + return assert k.shape == v.shape if self._idx == v.shape[2]: pass @@ -178,17 +182,23 @@ def do_reuse(self) -> None: def trim(self, n) -> int: # trim does not respect keep and it will stay this way n = min(self.offset, n) + print(f"debug: before: shape {self.keys.shape} keep {self.keep} n={n} {self.offset}os {self._idx}idx", file=sys.stderr) if n <= 0: return 0 # do trim: put us back into the state before the circular buffer is full self._temporal_order() - new_length = self.keys.shape[2] - n + print(f"after TO: shape {self.keys.shape} keep {self.keep} n={n} {self.offset}os {self._idx}idx", file=sys.stderr) + + # TODO(christian-lms): stupid hack that belies a bigger problem. mayeb 184 shoudl be min vs. sks2 + new_length = max(self.keys.shape[2] - n, 0) self.keys = self.keys[..., :new_length, :] self.values = self.values[..., :new_length, :] - self.offset -= n + # TODO(christian-lms): maybe this is wrong??? maybe you have bigger problems elsewhere + self.offset = new_length self._idx = new_length + print(self.keys.shape, self.offset, self._idx, file=sys.stderr) return n def _update_concat(self, keys, values): diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 5ea812f2..cd3986a8 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -1,6 +1,5 @@ import unittest import mlx.core as mx -from copy import deepcopy from mlx_engine.cache import ShiftingKVCache, cat from tests.test_cache_generic import TestCache From 9c378e62ef97b211f2043e5d8b3527111976b929 Mon Sep 17 00:00:00 2001 From: christian-lms Date: Thu, 10 Jul 2025 16:10:20 -0400 Subject: [PATCH 29/39] properly works now (i think) --- mlx_engine/cache.py | 14 +++++++------- mlx_engine/cache_wrapper.py | 16 +++++++++------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 8ec291ac..48ddae03 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -4,6 +4,7 @@ from mlx_lm.models.cache import _BaseCache, KVCache import mlx.core as mx import mlx.nn as nn +import sys # unfortunate that this is hardcoded but what else is one to do @@ -65,9 +66,10 @@ def rope(self, v: mx.array, shift_by: int) -> mx.array: # a custom implementation somehow. also it allows us to easily use the # sustk scaled rope/yarn/llama3 rope impls in mlx_lm without having to # spin a custom implementation for those too (and any future rope variants) - return cat( - [self._rope(v[:, :, i : i + 1, :], shift_by) for i in range(v.shape[2])], - ) + # return cat( + # [self._rope(v[:, :, i : i + 1, :], shift_by) for i in range(v.shape[2])], + # ) + return v def _trim(self, trim_size, append_k=None, append_v=None) -> None: k = self.keys @@ -182,13 +184,11 @@ def do_reuse(self) -> None: def trim(self, n) -> int: # trim does not respect keep and it will stay this way n = min(self.offset, n) - print(f"debug: before: shape {self.keys.shape} keep {self.keep} n={n} {self.offset}os {self._idx}idx", file=sys.stderr) if n <= 0: return 0 # do trim: put us back into the state before the circular buffer is full self._temporal_order() - print(f"after TO: shape {self.keys.shape} keep {self.keep} n={n} {self.offset}os {self._idx}idx", file=sys.stderr) # TODO(christian-lms): stupid hack that belies a bigger problem. mayeb 184 shoudl be min vs. sks2 new_length = max(self.keys.shape[2] - n, 0) @@ -198,7 +198,6 @@ def trim(self, n) -> int: # TODO(christian-lms): maybe this is wrong??? maybe you have bigger problems elsewhere self.offset = new_length self._idx = new_length - print(self.keys.shape, self.offset, self._idx, file=sys.stderr) return n def _update_concat(self, keys, values): @@ -269,6 +268,7 @@ def update_and_fetch(self, keys, values): def set_keep(self, keep): # kv must be in temporal order, else we will keep the wrong thing self._temporal_order() + print(f"setting keep to {keep} with offset {self.offset} and idx {self._idx}", file=sys.stderr) self.keep = keep @property @@ -305,7 +305,7 @@ def to_quantized(self, group_size: int = 64, bits: int = 4) -> Any: def make_prompt_cache( model: nn.Module, max_kv_size: Optional[int] = None, - keep: int = 4, + keep: int = 0, ) -> List[Any]: """ Construct the model's cache for use in generation. diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 4640b9e1..f32ab42f 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -98,6 +98,9 @@ def _truncate_cache( file=sys.stderr, ) + print(f"self tokens: {self.tokens.tolist()}", file=sys.stderr) + print(f"prompt tokens: {prompt_tokens.tolist()}", file=sys.stderr) + while cache_head_idx < cache_size and prompt_head_idx < prompt_size: match_length = self._find_matching_sequence_length( prompt_tokens, self.tokens, prompt_head_idx, cache_head_idx @@ -109,7 +112,6 @@ def _truncate_cache( else: if self.verbose: print(f"Reusing {match_length} tokens from cache", file=sys.stderr) - print(f"idx {prompt_head_idx} {cache_head_idx}") # found reusable sequence - shift cache content for cache in self.cache: @@ -128,6 +130,7 @@ def _truncate_cache( cache.do_reuse() self.tokens = self.tokens[: common_prefix_len + total_reused] + print(f"self post tokens: {self.tokens.tolist()}", file=sys.stderr) return total_reused def _get_unprocessed_tokens( @@ -159,12 +162,11 @@ def _get_unprocessed_tokens( prompt_tokens, common_prefix, ) - if n_reused_tokens > 0: - log_info( - prefix="CacheWrapper", - message=f"Reused {n_reused_tokens} tokens from the cache", - ) - common_prefix += n_reused_tokens + log_info( + prefix="CacheWrapper", + message=f"Reused {n_reused_tokens} tokens from the cache", + ) + common_prefix += n_reused_tokens # exclude some tokens from end, e.g. for kicking off generation if common_prefix >= len(prompt_tokens) - num_tokens_to_exclude: From aed67a9784f8e827f3c884bd006b2e7ad134678b Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Thu, 10 Jul 2025 16:22:38 -0400 Subject: [PATCH 30/39] try to remove rope --- mlx_engine/cache.py | 110 ++++++------------------------------ mlx_engine/rope.py | 107 ----------------------------------- tests/test_cache_generic.py | 4 -- tests/test_cache_shift.py | 44 +++++++-------- tests/test_cache_wrapper.py | 6 +- 5 files changed, 41 insertions(+), 230 deletions(-) delete mode 100644 mlx_engine/rope.py diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 48ddae03..7ca0bb65 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -1,7 +1,6 @@ from typing import List, Optional, Any -from mlx_engine.logging import log_warn -from mlx_lm.models.cache import _BaseCache, KVCache +from mlx_lm.models.cache import RotatingKVCache, KVCache import mlx.core as mx import mlx.nn as nn import sys @@ -12,45 +11,14 @@ MAYBE_ROPE_NAMES = ["rope", "rotary_emb"] +# TODO(christian-lms): stop doing me def cat(v: mx.array): """Alias for mx.concatenate(v, axis=2) since that's used all over the place""" return mx.concatenate(v, axis=2) -def _maybe_get_rope(layer: nn.Module) -> Optional[nn.Module]: - for maybe_rope_name in MAYBE_ROPE_NAMES: - if hasattr(layer, maybe_rope_name): - # found it - return getattr(layer, maybe_rope_name) - for maybe_attn_name in MAYBE_ATTN_NAMES: - if hasattr(layer, maybe_attn_name): - # move down one level - return _maybe_get_rope(getattr(layer, maybe_attn_name)) - # no dice - return None - - -def maybe_get_rope(model: nn.Module, layer_idx: int) -> Optional[nn.Module]: - """Attempt to find the RoPE module from a layer of an MLX-LM LLM. - - Args: - model (nn.Module): The LLM to search for the RoPE modules of. - layer_idx (int): The layer of the LLM to get the RoPE module from. - - Returns: - Optional[nn.Module]: The RoPE module if found, else None - """ - # we can assume model has attribute layers because make_prompt_cache does - if layer_idx > len(model.layers): - return None - layer = model.layers[layer_idx] - if not isinstance(layer, nn.Module): - return None - return _maybe_get_rope(layer) - - -class ShiftingKVCache(_BaseCache): - def __init__(self, rope: nn.Module, max_size=256, keep=0, step=256): +class ShiftingKVCache(RotatingKVCache): + def __init__(self, max_size=256, keep=0, step=256): self.keep = keep self.keys = None self.values = None @@ -58,32 +26,22 @@ def __init__(self, rope: nn.Module, max_size=256, keep=0, step=256): self.max_size = max_size self.step = step self._idx = 0 - self._rope = rope self.reuse_queue = [] - def rope(self, v: mx.array, shift_by: int) -> mx.array: - # you'd think this is inefficient, but it seems faster than spinning - # a custom implementation somehow. also it allows us to easily use the - # sustk scaled rope/yarn/llama3 rope impls in mlx_lm without having to - # spin a custom implementation for those too (and any future rope variants) - # return cat( - # [self._rope(v[:, :, i : i + 1, :], shift_by) for i in range(v.shape[2])], - # ) - return v - + # TODO(christian-lms): does it matter if you don't change offsets? def _trim(self, trim_size, append_k=None, append_v=None) -> None: k = self.keys v = self.values if k is None or v is None: return assert k.shape == v.shape - shift_by = -trim_size if trim_size > 0: k_cat = [ k[..., : self.keep, :], - self.rope(k[..., trim_size + self.keep :, :], shift_by), + k[..., trim_size + self.keep :, :], ] v_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] + # TODO(christian-lms): try removing me. if it seems fine then revert self.offset -= trim_size else: k_cat = [k] @@ -94,7 +52,6 @@ def _trim(self, trim_size, append_k=None, append_v=None) -> None: v_cat.append(append_v) if append_v is not None: assert append_k is not None - # already done self.keys, self.values = cat(k_cat), cat(v_cat) def _temporal_order(self) -> None: @@ -111,17 +68,13 @@ def _temporal_order(self) -> None: elif self._idx < self.offset: shift_by = self.keep - self._idx # intentionally negative!!! assert shift_by <= 0 + # TODO(christian-lms): try removing me. if it seems fine then revert self.offset += shift_by kcat = cat( [ k[..., : self.keep, :], - # N.B. this implicitly assumes the generation has not gone over twice - # the size of the rotating section of the cache, in which case the - # rotating section would be off by a multiple of (max_kv_size - keep) - # depending on how many times it rolled over. I feel like it's pretty - # safe to assume that this is a rare case - self.rope(k[..., self._idx :, :], shift_by), - self.rope(k[..., self.keep : self._idx, :], shift_by), + k[..., self._idx :, :], + k[..., self.keep : self._idx, :], ], ) vcat = cat( @@ -162,14 +115,8 @@ def do_reuse(self) -> None: shift_by = write_start_idx - reuse_start_idx # intentionally negative!!! reuse_end_idx = reuse_start_idx + reuse_length - keys_to_reuse = self.keys[..., reuse_start_idx:reuse_end_idx, :] - values_to_reuse = self.values[..., reuse_start_idx:reuse_end_idx, :] - - # only keys require rope - shifted_keys = self.rope(keys_to_reuse, shift_by) - - key_segments.append(shifted_keys) - value_segments.append(values_to_reuse) + key_segments.append(self.keys[..., reuse_start_idx:reuse_end_idx, :]) + value_segments.append(self.values[..., reuse_start_idx:reuse_end_idx, :]) current_pos = write_start_idx + reuse_length self.offset += shift_by @@ -190,7 +137,6 @@ def trim(self, n) -> int: # do trim: put us back into the state before the circular buffer is full self._temporal_order() - # TODO(christian-lms): stupid hack that belies a bigger problem. mayeb 184 shoudl be min vs. sks2 new_length = max(self.keys.shape[2] - n, 0) self.keys = self.keys[..., :new_length, :] self.values = self.values[..., :new_length, :] @@ -305,7 +251,7 @@ def to_quantized(self, group_size: int = 64, bits: int = 4) -> Any: def make_prompt_cache( model: nn.Module, max_kv_size: Optional[int] = None, - keep: int = 0, + keep: int = 4, ) -> List[Any]: """ Construct the model's cache for use in generation. @@ -314,37 +260,13 @@ def make_prompt_cache( Args: model (nn.Module): The language model. max_kv_size (Optional[int]): If provided and the model does not have a - ``make_cache`` method, a ``TrimmableRotatingKVCache`` is used - with a maximum size of ``max_kv_size`` + ``make_cache`` method, a ``ShiftingKVCache`` is used with a maximum + size of ``max_kv_size`` """ if hasattr(model, "make_cache"): - # TODO(christian-lms): gah what are you gonna do about models that do this - # afm7 baichuan_m1 cohere2 gemma3(+friends) llama4 mamba plamo2 recurrent_gemma - # m1 mamba plamo2 recurrent_gemma are hybrid - # - afm7 is trivially overridable - # - cohere2 is swa on some layers but can probably be overridden - # - gemma3 see cohere2 - # - llama4 uses chunked kv on some layers but can maybe be overridden - # though these layers don't have rope modules - - # try to get the model name from model.args.model_type but i suppose this will - # not always work. that or literally model.__name__ hopefully return model.make_cache() num_layers = len(model.layers) if max_kv_size is not None: - cache = [] - for layer in range(num_layers): - rope = maybe_get_rope(model, layer) - # TODO(christian-lms): it is known that this will fail for some models - # like llama4 which has no rope module for every fourth layer. - # this will be figured out Later(tm) once the initial functionality works - if rope is None: - log_warn( - "Attempted to build a KV cache of shiftable caches, but found" - f"None at layer {layer} of model {model}" - ) - return [KVCache() for _ in range(num_layers)] - cache.append(ShiftingKVCache(rope, max_size=max_kv_size, keep=keep)) - return cache + return [ShiftingKVCache(max_size=max_kv_size, keep=keep) for _ in range(num_layers)] else: return [KVCache() for _ in range(num_layers)] diff --git a/mlx_engine/rope.py b/mlx_engine/rope.py deleted file mode 100644 index eccf7331..00000000 --- a/mlx_engine/rope.py +++ /dev/null @@ -1,107 +0,0 @@ -"""So... - -...this isn't optimized yet. It turns out that at small sequence lengths literally just naively -applying RoPE to individual tokens is faster than doing the matrix multiplication, even with the overhead -introduced by the for loop and the theoretical optimization of the matrix multiplication. It does -begin to tip in favor of this shifting method @ larger seqlens (tested at [1,8,1000,128]) but the -overhead converting between MLX and torch still makes it slower overall than MLX-native. - -- the MLX rope shift is weird anyway, but it apparently still works: TODO ask awni why -- this implementation is naive and doesn't leverage the sparsity of the RoPE matrix -- honestly it's probably easier to just use the MLX RoPE shift directly in the model - because this allows us to not have to write custom modules for YaRN and llama3 and - what have you, but i'll leave this here for now in case it becomes useful later -""" - -import torch -import mlx.core as mx -from mlx_lm import load -import numpy as np - -def mlx_rope_shift(x, shift_amount, theta=10000.0, scale=1.0, traditional=False): - """ - MLX-compatible RoPE implementation using matrix multiplication. - Creates a rotation matrix and applies it via matmul to shift all positions by shift_amount. - - Args: - x: Input tensor of shape [bsz, n_kv_heads, seqlen, kv_head_dim] - shift_amount: Number of positions to shift (D) - theta: Base frequency for RoPE (default: 10000.0) - scale: Scaling factor for frequencies (default: 1.0) - traditional: If True, use traditional RoPE pairing (0,1), (2,3), ... - If False, use MLX-style pairing (0,d/2), (1,d/2+1), ... (default: False) - - Returns: - Rotated tensor of same shape as input - """ - bsz, n_heads, seqlen, head_dim = x.shape - device = x.device - - assert head_dim % 2 == 0, "Head dimension must be even" - dim_pairs = head_dim // 2 - - if traditional: - # traditional RoPE: pair adjacent dimensions (0,1), (2,3), (4,5), ... - frequencies = 1.0 / (theta ** (torch.arange(0, dim_pairs, dtype=torch.float32, device=device) * 2.0 / head_dim)) - else: - # MLX-style RoPE: pair first half with second half (0,d/2), (1,d/2+1), ... - frequencies = 1.0 / (theta ** (torch.arange(0, dim_pairs, dtype=torch.float32, device=device) * 2.0 / head_dim)) - - frequencies = frequencies * scale - angles = shift_amount * frequencies # shape: [dim_pairs] - cos_vals = torch.cos(angles) # shape: [dim_pairs] - sin_vals = torch.sin(angles) # shape: [dim_pairs] - - rotation_matrix = torch.eye(head_dim, device=device, dtype=x.dtype) - - if traditional: - for i in range(dim_pairs): - even_idx = i * 2 - odd_idx = i * 2 + 1 - - rotation_matrix[even_idx, even_idx] = cos_vals[i] - rotation_matrix[even_idx, odd_idx] = -sin_vals[i] - rotation_matrix[odd_idx, even_idx] = sin_vals[i] - rotation_matrix[odd_idx, odd_idx] = cos_vals[i] - else: - for i in range(dim_pairs): - first_idx = i - second_idx = i + dim_pairs - - cos_val = cos_vals[i] - sin_val = sin_vals[i] - - rotation_matrix[first_idx, first_idx] = cos_val - rotation_matrix[first_idx, second_idx] = -sin_val - rotation_matrix[second_idx, first_idx] = sin_val - rotation_matrix[second_idx, second_idx] = cos_val - - rotated = x @ rotation_matrix.T - - return rotated - - -def stupid_rope(r, v, shift_by: int = 0): - return mx.concatenate([r(v[:,:,i:i+1,:], shift_by) for i in range(v.shape[2])], axis=2) - -def main(): - model, _ = load("mlx-community/Qwen3-0.6B-bf16") - - v = mx.random.normal((1, 8, 10, 128), scale=1.0, dtype=mx.float32) - - import time - start_time = time.time() - silly = stupid_rope(model.layers[0].self_attn.rope, v, 7) - end_time = time.time() - elapsed_time = end_time - start_time - print(f"MLX RoPE shift took {elapsed_time:.6f} seconds") - converted = torch.from_numpy(np.array(v)) - start_time = time.time() - eff = mlx_rope_shift(converted, 7, theta=1000000.0, scale=1.0, traditional=False) - end_time = time.time() - elapsed_time2 = end_time - start_time - print(f"Torch RoPE shift took {elapsed_time2:.6f} seconds") - print(torch.allclose(torch.from_numpy(np.array(silly)), eff, atol=1e-5)) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/test_cache_generic.py b/tests/test_cache_generic.py index a558b1ac..6d723547 100644 --- a/tests/test_cache_generic.py +++ b/tests/test_cache_generic.py @@ -12,10 +12,6 @@ def setUpClass(cls): cls.kv_head_dim = 4 cls.bsz = 1 cls.n_kv_heads = 1 - # cannot be used raw: must be wrapped in the cache.rope workaround impl - cls._rope = nn.RoPE( - dims=cls.kv_head_dim, traditional=False, base=100000, scale=1.0 - ) @classmethod def make_random_kv(cls, seqlen: int): diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index cd3986a8..5cdb57b3 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -12,7 +12,7 @@ def idx(v: mx.array, i: int): class TestShiftingKVCache(TestCache): def test_overwriting(self): """Test overwriting when the cache reaches max_size""" - cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + cache = ShiftingKVCache(max_size=3, keep=1) # fill cache -> 123 reference = self.add_random_to_cache(cache, 3) @@ -30,7 +30,7 @@ def test_overwriting(self): def test_ensure_update_increases_offset_indefinitely(self): """Test single-token updates that should increase offset""" - cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + cache = ShiftingKVCache(max_size=3, keep=1) for i in range(10): self.add_random_to_cache(cache, 1) @@ -38,7 +38,7 @@ def test_ensure_update_increases_offset_indefinitely(self): def test_temporal_order_shift_rope(self): """Test the RoPE shift in _temporal_order""" - cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + cache = ShiftingKVCache(max_size=3, keep=1) # fill cache -> 123 reference = self.add_random_to_cache(cache, 3) @@ -53,13 +53,13 @@ def test_temporal_order_shift_rope(self): keys = cache.state[0] self.assertArrEqual(idx(keys, 0), idx(reference, 0)) - self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 2), -1)) - self.assertArrEqual(idx(keys, 2), cache.rope(overwrite, -1)) + self.assertArrEqual(idx(keys, 1), idx(reference, 2)) + self.assertArrEqual(idx(keys, 2), overwrite) self.assertEqual(cache.offset, 3) def test_temporal_order_shift_no_rope(self): """Test putting the cache in temporal order""" - cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + cache = ShiftingKVCache(max_size=3, keep=1) # fill cache -> 123 reference = self.add_random_to_cache(cache, 3) @@ -80,7 +80,7 @@ def test_temporal_order_shift_no_rope(self): def test_trim_internal_shift_rope(self): """Test the RoPE shift in _trim (internal method)""" - cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + cache = ShiftingKVCache(max_size=3, keep=1) # fill cache -> 123 reference = self.add_random_to_cache(cache, 3) @@ -92,13 +92,13 @@ def test_trim_internal_shift_rope(self): self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(idx(keys, 0), idx(reference, 0)) - self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 2), -1)) + self.assertArrEqual(idx(keys, 1), idx(reference, 2)) # trim should trigger offset change with is_key=True self.assertEqual(cache.offset, 2) def test_trim_internal_shift_no_rope(self): """Test the RoPE shift in _trim (internal method)""" - cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + cache = ShiftingKVCache(max_size=3, keep=1) # fill cache -> 123 reference = self.add_random_to_cache(cache, 3) @@ -118,7 +118,7 @@ def test_ensure_reasonable_size_and_shift(self): than max_size. The default behavior of the cache is to write the entire thing, then trim it back down when the next KV is written. """ - cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + cache = ShiftingKVCache(max_size=3, keep=1) # fill cache -> 0123456789 reference = self.add_random_to_cache(cache, 10) @@ -135,15 +135,15 @@ def test_ensure_reasonable_size_and_shift(self): self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), overwrite) - self.assertArrEqual(idx(keys, 2), cache.rope(idx(reference, 9), -7)) + self.assertArrEqual(idx(keys, 2), idx(reference, 9)) # make sure pos embs are right cache._temporal_order() keys = cache.state[0] self.assertArrEqual(idx(keys, 0), idx(reference, 0)) - self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 9), -8)) - self.assertArrEqual(idx(keys, 2), cache.rope(overwrite, -1)) + self.assertArrEqual(idx(keys, 1), idx(reference, 9)) + self.assertArrEqual(idx(keys, 2), overwrite) self.assertEqual(cache.offset, 3) # ensure offset keeps increasing @@ -155,7 +155,7 @@ def test_ensure_reasonable_size_and_shift(self): def test_update_keep_on_the_fly(self): """Test changing the keep value on the fly""" - cache = ShiftingKVCache(self._rope, max_size=4, keep=1) + cache = ShiftingKVCache(max_size=4, keep=1) # fill cache -> 1234 reference = self.add_random_to_cache(cache, 4) @@ -174,13 +174,13 @@ def test_update_keep_on_the_fly(self): keys = cache.state[0] self.assertArrEqual(idx(keys, 0), idx(reference, 0)) - self.assertArrEqual(idx(keys, 1), cache.rope(idx(reference, 2), -1)) + self.assertArrEqual(idx(keys, 1), idx(reference, 2)) self.assertArrEqual(idx(keys, 2), overwrite2) - self.assertArrEqual(idx(keys, 3), cache.rope(overwrite, -1)) + self.assertArrEqual(idx(keys, 3), overwrite) def test_trim_before_full(self): """Test trimming from the end before the cache is full""" - cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + cache = ShiftingKVCache(max_size=3, keep=1) # fill cache -> 12 reference = self.add_random_to_cache(cache, 2) @@ -205,7 +205,7 @@ def test_trim_before_full(self): def test_trim_after_overwrite(self): """Test trimming from the end when we've written past the cache max""" - cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + cache = ShiftingKVCache(max_size=3, keep=1) # fill cache -> 123 reference = self.add_random_to_cache(cache, 3) @@ -221,7 +221,7 @@ def test_trim_after_overwrite(self): keys = cache.state[0] should_be_kv = cat( - [reference[:, :, :1, :], cache.rope(reference[:, :, 2:3, :], -1)] + [reference[:, :, :1, :], reference[:, :, 2:3, :]] ) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(keys, should_be_kv) @@ -229,7 +229,7 @@ def test_trim_after_overwrite(self): def test_trim_after_full(self): """Test trimming from the end when the cache is oversize""" - cache = ShiftingKVCache(self._rope, max_size=3, keep=1) + cache = ShiftingKVCache(max_size=3, keep=1) # fill cache oversize already -> 1234 reference = self.add_random_to_cache(cache, 4) @@ -254,7 +254,7 @@ def test_trim_after_full(self): def test_reuse(self): """Test basic reuse APIs""" - cache = ShiftingKVCache(self._rope, max_size=8, keep=1) + cache = ShiftingKVCache(max_size=8, keep=1) # fill cache -> 12345678 reference = self.add_random_to_cache(cache, 8) @@ -266,7 +266,7 @@ def test_reuse(self): # this is what the remaining cache should look like should_be_keys = cat( - [reference[:, :, :3, :], cache.rope(reference[:, :, 4:6, :], -1)] + [reference[:, :, :3, :], reference[:, :, 4:6, :]] ) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index bb908b81..61eb01c5 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -116,7 +116,7 @@ def test_record_generated_token_loops(self): def test_cache_reuse_heavy(self): cache = CacheWrapper(DummyModel(), 10) - cache.cache[0] = ShiftingKVCache(self._rope, max_size=10, keep=2) + cache.cache[0] = ShiftingKVCache(max_size=10, keep=2) # set up pretend cache cached_tokens = mx.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) @@ -146,8 +146,8 @@ def idx(v, a, b): should_be_kv = cat( [ idx(cache_kv, 0, 2), - cache.cache[0].rope(idx(cache_kv, 3, 4), -1), - cache.cache[0].rope(idx(cache_kv, 6, 9), -3), + idx(cache_kv, 3, 4), + idx(cache_kv, 6, 9), ], ) From 1b9af40645da7c0aaf3cca4b48b2dfc7e9dd074a Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Thu, 10 Jul 2025 16:36:40 -0400 Subject: [PATCH 31/39] simplify cache again --- mlx_engine/cache.py | 183 ++------------------------------------------ 1 file changed, 7 insertions(+), 176 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 7ca0bb65..7d010262 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -3,90 +3,12 @@ from mlx_lm.models.cache import RotatingKVCache, KVCache import mlx.core as mx import mlx.nn as nn -import sys - - -# unfortunate that this is hardcoded but what else is one to do -MAYBE_ATTN_NAMES = ["self_attn", "attention", "attn", "mixer", "norm_attn_norm"] -MAYBE_ROPE_NAMES = ["rope", "rotary_emb"] - - -# TODO(christian-lms): stop doing me -def cat(v: mx.array): - """Alias for mx.concatenate(v, axis=2) since that's used all over the place""" - return mx.concatenate(v, axis=2) class ShiftingKVCache(RotatingKVCache): def __init__(self, max_size=256, keep=0, step=256): - self.keep = keep - self.keys = None - self.values = None - self.offset = 0 - self.max_size = max_size - self.step = step - self._idx = 0 self.reuse_queue = [] - - # TODO(christian-lms): does it matter if you don't change offsets? - def _trim(self, trim_size, append_k=None, append_v=None) -> None: - k = self.keys - v = self.values - if k is None or v is None: - return - assert k.shape == v.shape - if trim_size > 0: - k_cat = [ - k[..., : self.keep, :], - k[..., trim_size + self.keep :, :], - ] - v_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] - # TODO(christian-lms): try removing me. if it seems fine then revert - self.offset -= trim_size - else: - k_cat = [k] - v_cat = [v] - if append_k is not None: - assert append_v is not None - k_cat.append(append_k) - v_cat.append(append_v) - if append_v is not None: - assert append_k is not None - self.keys, self.values = cat(k_cat), cat(v_cat) - - def _temporal_order(self) -> None: - """ - Rearrange the cache into temporal order, slicing off the end if unused. - """ - k = self.keys - v = self.values - if k is None or v is None: - return - assert k.shape == v.shape - if self._idx == v.shape[2]: - pass - elif self._idx < self.offset: - shift_by = self.keep - self._idx # intentionally negative!!! - assert shift_by <= 0 - # TODO(christian-lms): try removing me. if it seems fine then revert - self.offset += shift_by - kcat = cat( - [ - k[..., : self.keep, :], - k[..., self._idx :, :], - k[..., self.keep : self._idx, :], - ], - ) - vcat = cat( - [ - v[..., : self.keep, :], - v[..., self._idx :, :], - v[..., self.keep : self._idx, :], - ], - ) - self.keys, self.values = kcat, vcat - else: - self.keys, self.values = k[..., : self._idx, :], v[..., : self._idx, :] + super().__init__(max_size=max_size, keep=keep, step=step) def reuse_section( self, write_start_idx: int, reuse_start_idx: int, reuse_length: int @@ -121,7 +43,8 @@ def do_reuse(self) -> None: current_pos = write_start_idx + reuse_length self.offset += shift_by - self.keys, self.values = cat(key_segments), cat(value_segments) + self.keys = mx.concatenate(key_segments, axis=2) + self.values = mx.concatenate(value_segments, axis=2) # clean up self.reuse_queue = [] @@ -135,118 +58,26 @@ def trim(self, n) -> int: return 0 # do trim: put us back into the state before the circular buffer is full - self._temporal_order() + self.keys = self._temporal_order(self.keys) + self.values = self._temporal_order(self.values) new_length = max(self.keys.shape[2] - n, 0) self.keys = self.keys[..., :new_length, :] self.values = self.values[..., :new_length, :] - # TODO(christian-lms): maybe this is wrong??? maybe you have bigger problems elsewhere self.offset = new_length self._idx = new_length return n - def _update_concat(self, keys, values): - if self.keys is None: - self.keys = keys - self.values = values - else: - # Put the keys/values in temporal order to preserve context - self._temporal_order() - - # The largest size is self.max_size + S to ensure - # every token gets at least self.max_size context - trim_size = self._idx - self.max_size - self._trim(trim_size, append_k=keys, append_v=values) - self.offset += keys.shape[2] - self._idx = self.keys.shape[2] - return self.keys, self.values - - def _update_in_place(self, keys, values): - # May not have hit the max size yet, so potentially - # keep growing the cache - B, n_kv_heads, S, k_head_dim = keys.shape - prev = self.offset - if self.keys is None or ( - prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size - ): - v_head_dim = values.shape[3] - new_size = min(self.step, self.max_size - prev) - k_shape = (B, n_kv_heads, new_size, k_head_dim) - v_shape = (B, n_kv_heads, new_size, v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - self.keys = cat([self.keys, new_k]) - self.values = cat([self.values, new_v]) - else: - self.keys, self.values = new_k, new_v - self._idx = prev - - # Trim if needed - trim_size = self.keys.shape[2] - self.max_size - if trim_size > 0: - self._trim( - trim_size, - ) - self._idx = self.max_size - - # Rotate - if self._idx == self.max_size: - self._idx = self.keep - - # Assign - self.keys[..., self._idx : self._idx + S, :] = keys - self.values[..., self._idx : self._idx + S, :] = values - self.offset += S - self._idx += S - - # If the buffer is not full, slice off the end - if self.offset < self.max_size: - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - return self.keys, self.values - - def update_and_fetch(self, keys, values): - if keys.shape[2] == 1: - return self._update_in_place(keys, values) - return self._update_concat(keys, values) - def set_keep(self, keep): # kv must be in temporal order, else we will keep the wrong thing - self._temporal_order() - print(f"setting keep to {keep} with offset {self.offset} and idx {self._idx}", file=sys.stderr) + self.keys = self._temporal_order(self.keys) + self.values = self._temporal_order(self.values) self.keep = keep - @property - def state(self): - if self.offset < self.keys.shape[2]: - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - else: - return self.keys, self.values - - @state.setter - def state(self, v): - self.keys, self.values = v - - @property - def meta_state(self): - return tuple( - map(str, (self.keep, self.max_size, self.step, self.offset, self._idx)) - ) - - @meta_state.setter - def meta_state(self, v): - self.keep, self.max_size, self.step, self.offset, self._idx = map( - int, - v, - ) - def is_trimmable(self) -> bool: return True - def to_quantized(self, group_size: int = 64, bits: int = 4) -> Any: - raise NotImplementedError("ShiftingKVCache Quantization NYI") - def make_prompt_cache( model: nn.Module, From a1521d8dcf84cc13e2b04a94c91a5491d34f73c6 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Thu, 10 Jul 2025 16:44:03 -0400 Subject: [PATCH 32/39] more reductionism --- mlx_engine/cache.py | 4 +- tests/test_cache_shift.py | 103 +++++------------------------------- tests/test_cache_wrapper.py | 9 ++-- 3 files changed, 21 insertions(+), 95 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 7d010262..817910f9 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -98,6 +98,8 @@ def make_prompt_cache( return model.make_cache() num_layers = len(model.layers) if max_kv_size is not None: - return [ShiftingKVCache(max_size=max_kv_size, keep=keep) for _ in range(num_layers)] + return [ + ShiftingKVCache(max_size=max_kv_size, keep=keep) for _ in range(num_layers) + ] else: return [KVCache() for _ in range(num_layers)] diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 5cdb57b3..22e981aa 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -1,6 +1,6 @@ import unittest import mlx.core as mx -from mlx_engine.cache import ShiftingKVCache, cat +from mlx_engine.cache import ShiftingKVCache from tests.test_cache_generic import TestCache @@ -36,83 +36,6 @@ def test_ensure_update_increases_offset_indefinitely(self): self.add_random_to_cache(cache, 1) self.assertEqual(cache.offset - 1, i) - def test_temporal_order_shift_rope(self): - """Test the RoPE shift in _temporal_order""" - cache = ShiftingKVCache(max_size=3, keep=1) - - # fill cache -> 123 - reference = self.add_random_to_cache(cache, 3) - self.assertEqual(cache.offset, 3) - - # attempt to write another element 4 -> 143 - overwrite = self.add_random_to_cache(cache, 1) - self.assertEqual(cache.offset, 4) - - # put the cache in temporal order -> 134 -> 123 (rope shift) - cache._temporal_order() - keys = cache.state[0] - - self.assertArrEqual(idx(keys, 0), idx(reference, 0)) - self.assertArrEqual(idx(keys, 1), idx(reference, 2)) - self.assertArrEqual(idx(keys, 2), overwrite) - self.assertEqual(cache.offset, 3) - - def test_temporal_order_shift_no_rope(self): - """Test putting the cache in temporal order""" - cache = ShiftingKVCache(max_size=3, keep=1) - - # fill cache -> 123 - reference = self.add_random_to_cache(cache, 3) - self.assertEqual(cache.offset, 3) - - # attempt to write another element 4 -> 143 - overwrite = self.add_random_to_cache(cache, 1) - self.assertEqual(cache.offset, 4) - - # put the cache in temporal order -> 134 (no rope shift) - cache._temporal_order() - values = cache.state[1] - - self.assertArrEqual(idx(values, 0), idx(reference, 0)) - self.assertArrEqual(idx(values, 1), idx(reference, 2)) - self.assertArrEqual(idx(values, 2), overwrite) - self.assertEqual(cache.offset, 3) - - def test_trim_internal_shift_rope(self): - """Test the RoPE shift in _trim (internal method)""" - cache = ShiftingKVCache(max_size=3, keep=1) - - # fill cache -> 123 - reference = self.add_random_to_cache(cache, 3) - self.assertEqual(cache.offset, 3) - - # trim 1 from middle -> 13 - cache._trim(1) - keys = cache.state[0] - - self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) - self.assertArrEqual(idx(keys, 0), idx(reference, 0)) - self.assertArrEqual(idx(keys, 1), idx(reference, 2)) - # trim should trigger offset change with is_key=True - self.assertEqual(cache.offset, 2) - - def test_trim_internal_shift_no_rope(self): - """Test the RoPE shift in _trim (internal method)""" - cache = ShiftingKVCache(max_size=3, keep=1) - - # fill cache -> 123 - reference = self.add_random_to_cache(cache, 3) - self.assertEqual(cache.offset, 3) - - # trim 1 from middle -> 13 -> 12 - cache._trim(1) - values = cache.state[1] - - self.assertEqual(values.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) - self.assertArrEqual(idx(values, 0), idx(reference, 0)) - self.assertArrEqual(idx(values, 1), idx(reference, 2)) - self.assertEqual(cache.offset, 2) - def test_ensure_reasonable_size_and_shift(self): """Test behavior when the cache gets a KV batch-written that is much larger than max_size. The default behavior of the cache is to write the entire thing, @@ -130,28 +53,28 @@ def test_ensure_reasonable_size_and_shift(self): overwrite = self.add_random_to_cache(cache, 1) keys = cache.state[0] self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 3, self.kv_head_dim)) - # this should be 4 since this mimics autoregression - self.assertEqual(cache.offset, 4) + self.assertEqual(cache.offset, 11) self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), overwrite) self.assertArrEqual(idx(keys, 2), idx(reference, 9)) # make sure pos embs are right - cache._temporal_order() + cache.keys = cache._temporal_order(cache.keys) + cache.values = cache._temporal_order(cache.values) keys = cache.state[0] self.assertArrEqual(idx(keys, 0), idx(reference, 0)) self.assertArrEqual(idx(keys, 1), idx(reference, 9)) self.assertArrEqual(idx(keys, 2), overwrite) - self.assertEqual(cache.offset, 3) + self.assertEqual(cache.offset, 11) # ensure offset keeps increasing self.add_random_to_cache(cache, 1) - self.assertEqual(cache.offset, 4) + self.assertEqual(cache.offset, 12) self.add_random_to_cache(cache, 1) - self.assertEqual(cache.offset, 5) + self.assertEqual(cache.offset, 13) def test_update_keep_on_the_fly(self): """Test changing the keep value on the fly""" @@ -168,9 +91,9 @@ def test_update_keep_on_the_fly(self): # and attempt to write another element 5 -> 1254 # offset updates after set_keep (anytime we reorder/rope shift) cache.set_keep(2) - self.assertEqual(cache.offset, 4) - overwrite2 = self.add_random_to_cache(cache, 1) self.assertEqual(cache.offset, 5) + overwrite2 = self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 6) keys = cache.state[0] self.assertArrEqual(idx(keys, 0), idx(reference, 0)) @@ -220,8 +143,8 @@ def test_trim_after_overwrite(self): cache.trim(1) keys = cache.state[0] - should_be_kv = cat( - [reference[:, :, :1, :], reference[:, :, 2:3, :]] + should_be_kv = mx.concatenate( + [reference[:, :, :1, :], reference[:, :, 2:3, :]], axis=2 ) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) self.assertArrEqual(keys, should_be_kv) @@ -265,8 +188,8 @@ def test_reuse(self): keys = cache.state[0] # this is what the remaining cache should look like - should_be_keys = cat( - [reference[:, :, :3, :], reference[:, :, 4:6, :]] + should_be_keys = mx.concatenate( + [reference[:, :, :3, :], reference[:, :, 4:6, :]], axis=2 ) self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index 61eb01c5..a6da73c4 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -1,7 +1,7 @@ import unittest import mlx.core as mx from mlx_engine.cache_wrapper import CacheWrapper -from mlx_engine.cache import ShiftingKVCache, cat +from mlx_engine.cache import ShiftingKVCache from tests.test_cache_generic import TestCache from tests.utils import DummyModel @@ -143,12 +143,13 @@ def idx(v, a, b): return v[:, :, a:b, :] should_be_tokens = mx.array([1, 2, 4, 7, 8, 9]) - should_be_kv = cat( + should_be_kv = mx.concatenate( [ idx(cache_kv, 0, 2), idx(cache_kv, 3, 4), idx(cache_kv, 6, 9), ], + axis=2, ) self.assertEqual(total_reused, 4) @@ -158,13 +159,13 @@ def idx(v, a, b): # ensure updating works as intended new_kv = self.make_random_kv(1) keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) - should_be_kv = cat([should_be_kv, new_kv]) + should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) self.assertArrEqual(keys, should_be_kv) # ensure batch concat works as intended new_kv = self.make_random_kv(2) keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) - should_be_kv = cat([should_be_kv, new_kv]) + should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) self.assertArrEqual(keys, should_be_kv) From dd205ba45877cedaf49bf56788ddc6a9eaa9f98e Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Thu, 10 Jul 2025 16:46:11 -0400 Subject: [PATCH 33/39] remove prints --- mlx_engine/cache.py | 2 +- mlx_engine/cache_wrapper.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 817910f9..f9c36094 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -52,7 +52,7 @@ def do_reuse(self) -> None: self.offset = self.keys.shape[2] def trim(self, n) -> int: - # trim does not respect keep and it will stay this way + # trim does not respect keep, which must be the case n = min(self.offset, n) if n <= 0: return 0 diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index f32ab42f..4f3835b3 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -98,9 +98,6 @@ def _truncate_cache( file=sys.stderr, ) - print(f"self tokens: {self.tokens.tolist()}", file=sys.stderr) - print(f"prompt tokens: {prompt_tokens.tolist()}", file=sys.stderr) - while cache_head_idx < cache_size and prompt_head_idx < prompt_size: match_length = self._find_matching_sequence_length( prompt_tokens, self.tokens, prompt_head_idx, cache_head_idx @@ -130,7 +127,6 @@ def _truncate_cache( cache.do_reuse() self.tokens = self.tokens[: common_prefix_len + total_reused] - print(f"self post tokens: {self.tokens.tolist()}", file=sys.stderr) return total_reused def _get_unprocessed_tokens( From a740335c053317dbe19f1c153ed4f13ede3d9ac1 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Thu, 10 Jul 2025 16:48:33 -0400 Subject: [PATCH 34/39] ??? oops --- tests/utils.py | 90 +++++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index dd31698a..e3004324 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,7 +2,7 @@ import sys import subprocess -# from mlx_engine.generate import load_model, load_draft_model, tokenize +from mlx_engine.generate import load_model, load_draft_model, tokenize class DummyModel: @@ -10,59 +10,59 @@ class DummyModel: layers = [0] -# def model_getter(model_name: str): -# """Helper method to get a model, prompt user to download if not found""" +def model_getter(model_name: str): + """Helper method to get a model, prompt user to download if not found""" -# with open(Path("~/.lmstudio-home-pointer").expanduser().resolve(), "r") as f: -# lmstudio_home = Path(f.read().strip()) -# model_path = lmstudio_home / "models" / model_name + with open(Path("~/.lmstudio-home-pointer").expanduser().resolve(), "r") as f: + lmstudio_home = Path(f.read().strip()) + model_path = lmstudio_home / "models" / model_name -# # Check if model exists, if not prompt user to download -# if not model_path.exists(): -# print(f"\nModel {model_name} not found at {model_path}") + # Check if model exists, if not prompt user to download + if not model_path.exists(): + print(f"\nModel {model_name} not found at {model_path}") -# def greenify(text): -# return f"\033[92m{text}\033[0m" + def greenify(text): + return f"\033[92m{text}\033[0m" -# response = input( -# f"Would you like to download the model {greenify(model_name)}? (y/N): " -# ) -# if response.lower() == "y": -# print(f"Downloading model with command: lms get {model_name}") -# subprocess.run(["lms", "get", model_name], check=True) -# else: -# print(f"Model {model_name} not found") -# sys.exit(1) + response = input( + f"Would you like to download the model {greenify(model_name)}? (y/N): " + ) + if response.lower() == "y": + print(f"Downloading model with command: lms get {model_name}") + subprocess.run(["lms", "get", model_name], check=True) + else: + print(f"Model {model_name} not found") + sys.exit(1) -# return model_path + return model_path -# def model_load_and_tokenize_prompt( -# model_name: str, -# prompt: str, -# max_kv_size=4096, -# trust_remote_code=False, -# draft_model_name=None, -# ): -# """Helper method to test a model""" -# print(f"Testing model {model_name}") +def model_load_and_tokenize_prompt( + model_name: str, + prompt: str, + max_kv_size=4096, + trust_remote_code=False, + draft_model_name=None, +): + """Helper method to test a model""" + print(f"Testing model {model_name}") -# # Check if model exists, if not prompt user to download -# model_path = model_getter(model_name) + # Check if model exists, if not prompt user to download + model_path = model_getter(model_name) -# # Load the model -# model_kit = load_model( -# model_path=model_path, -# max_kv_size=max_kv_size, -# trust_remote_code=trust_remote_code, -# ) + # Load the model + model_kit = load_model( + model_path=model_path, + max_kv_size=max_kv_size, + trust_remote_code=trust_remote_code, + ) -# # Load the draft model if any -# if draft_model_name is not None: -# draft_model_path = model_getter(draft_model_name) -# load_draft_model(model_kit, draft_model_path) + # Load the draft model if any + if draft_model_name is not None: + draft_model_path = model_getter(draft_model_name) + load_draft_model(model_kit, draft_model_path) -# # Tokenize the prompt -# prompt_tokens = tokenize(model_kit, prompt) + # Tokenize the prompt + prompt_tokens = tokenize(model_kit, prompt) -# return model_kit, prompt_tokens + return model_kit, prompt_tokens From 349b5a80de22204fe8e0761f4b2605008de83d14 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Thu, 10 Jul 2025 16:54:33 -0400 Subject: [PATCH 35/39] final fixes for now --- mlx_engine/cache.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index f9c36094..3ea58ac4 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -71,8 +71,10 @@ def trim(self, n) -> int: def set_keep(self, keep): # kv must be in temporal order, else we will keep the wrong thing - self.keys = self._temporal_order(self.keys) - self.values = self._temporal_order(self.values) + if self.keys is not None: + self.keys = self._temporal_order(self.keys) + if self.values is not None: + self.values = self._temporal_order(self.values) self.keep = keep def is_trimmable(self) -> bool: From 1c4cf24bcfad52d6ff7d590d8e412552fd11194e Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Thu, 10 Jul 2025 17:15:45 -0400 Subject: [PATCH 36/39] more fixes --- mlx_engine/cache.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index 3ea58ac4..c5057cd7 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -33,16 +33,12 @@ def do_reuse(self) -> None: key_segments.append(self.keys[..., current_pos:write_start_idx, :]) value_segments.append(self.values[..., current_pos:write_start_idx, :]) - # add the reused segment with RoPE shift - shift_by = write_start_idx - reuse_start_idx # intentionally negative!!! reuse_end_idx = reuse_start_idx + reuse_length + current_pos = write_start_idx + reuse_length key_segments.append(self.keys[..., reuse_start_idx:reuse_end_idx, :]) value_segments.append(self.values[..., reuse_start_idx:reuse_end_idx, :]) - current_pos = write_start_idx + reuse_length - self.offset += shift_by - self.keys = mx.concatenate(key_segments, axis=2) self.values = mx.concatenate(value_segments, axis=2) @@ -52,12 +48,12 @@ def do_reuse(self) -> None: self.offset = self.keys.shape[2] def trim(self, n) -> int: - # trim does not respect keep, which must be the case + # trim must not respect keep n = min(self.offset, n) if n <= 0: return 0 - # do trim: put us back into the state before the circular buffer is full + # put us back into the state before the circular buffer is full self.keys = self._temporal_order(self.keys) self.values = self._temporal_order(self.values) From bf66e2b3e392506f320cd697b3da482e928de6cd Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Thu, 10 Jul 2025 17:22:59 -0400 Subject: [PATCH 37/39] make linter happy --- tests/test_cache_generic.py | 6 +----- tests/utils.py | 1 + 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_cache_generic.py b/tests/test_cache_generic.py index 6d723547..4a7b21f4 100644 --- a/tests/test_cache_generic.py +++ b/tests/test_cache_generic.py @@ -1,6 +1,5 @@ import unittest import mlx.core as mx -import mlx.nn as nn from copy import deepcopy from mlx_engine.cache import ShiftingKVCache @@ -27,9 +26,7 @@ def assertArrEqual(self, a: mx.array, b: mx.array): self.assertEqual(a.shape, b.shape) self.assertTrue(mx.allclose(a, b), "Tensors are not equal") - def add_random_to_cache( - self, cache: ShiftingKVCache, seqlen: int - ) -> mx.array: + def add_random_to_cache(self, cache: ShiftingKVCache, seqlen: int) -> mx.array: """Add random values to the cache and return them""" base_kv = self.make_random_kv(seqlen) # base_kv is *assigned* to cache.keys/cache.values so returning base_kv @@ -37,4 +34,3 @@ def add_random_to_cache( reference = deepcopy(base_kv) cache.update_and_fetch(base_kv, base_kv) return reference - diff --git a/tests/utils.py b/tests/utils.py index e3004324..ad3fd8e7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,6 +7,7 @@ class DummyModel: """Dummy model class for testing""" + layers = [0] From 3d51d580cd25bf56ab61d4bcd7577b2bf6694222 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Fri, 11 Jul 2025 11:08:44 -0400 Subject: [PATCH 38/39] fix trim --- mlx_engine/cache_wrapper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 4f3835b3..7cbd5c3e 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -169,7 +169,8 @@ def _get_unprocessed_tokens( common_prefix = len(prompt_tokens) - num_tokens_to_exclude # Trim the cache if the common prefix is shorter than the current cache - num_tokens_to_trim = self.cache[0].offset - common_prefix + # state[0] is an alias for keys that accounts for partially filled buffers + num_tokens_to_trim = self.cache[0].state[0].shape[2] - common_prefix if num_tokens_to_trim > 0: if not can_trim_prompt_cache(self.cache): log_warn( From 18b6dc33d7b3ec049fa940f3ea33f8c76c61df01 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Fri, 11 Jul 2025 12:42:55 -0400 Subject: [PATCH 39/39] extra tests (in progress) --- mlx_engine/cache.py | 4 ++ tests/test_cache_shift.py | 28 ++++++++ tests/test_cache_wrapper.py | 140 +++++++++++++++++++++++++++++++++++- 3 files changed, 170 insertions(+), 2 deletions(-) diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py index c5057cd7..9f7088a8 100644 --- a/mlx_engine/cache.py +++ b/mlx_engine/cache.py @@ -20,6 +20,10 @@ def do_reuse(self) -> None: if not self.reuse_queue: return + # just in case maybe + self.keys = self._temporal_order(self.keys) + self.values = self._temporal_order(self.values) + # just in case, sort in write order self.reuse_queue.sort(key=lambda x: x[0]) diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py index 22e981aa..0ae01004 100644 --- a/tests/test_cache_shift.py +++ b/tests/test_cache_shift.py @@ -196,6 +196,34 @@ def test_reuse(self): self.assertArrEqual(keys, should_be_keys) self.assertEqual(cache.offset, 5) + def test_reuse_after_overwrite(self): + """Test basic reuse APIs after an overwrite""" + cache = ShiftingKVCache(max_size=8, keep=1) + + # fill cache -> 12345678 + reference = self.add_random_to_cache(cache, 8) + news = self.add_random_to_cache(cache, 1) # overwrite to 13456789 after TO + self.assertArrEqual( + cache.state[0], mx.concatenate( + [reference[:, :, :1, :], news, reference[:, :, 2:8, :]], axis=2 + ) + ) + + # suppose the prompt coming in is now 13678 + # reuse from 2 to 4 length 3 + cache.reuse_section(2, 4, 3) + cache.do_reuse() + keys = cache.state[0] + + # the remaining cache should be 13678 + should_be_keys = mx.concatenate( + [reference[:, :, :1, :], reference[:,:,2:3,:], reference[:, :, 5:8, :]], axis=2 + ) + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) + self.assertArrEqual(keys, should_be_keys) + self.assertEqual(cache.offset, 5) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index a6da73c4..61149a6b 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -3,7 +3,8 @@ from mlx_engine.cache_wrapper import CacheWrapper from mlx_engine.cache import ShiftingKVCache from tests.test_cache_generic import TestCache -from tests.utils import DummyModel +from tests.utils import DummyModel, model_getter +from mlx_engine.generate import load_model, create_generator class TestCacheWrapper(TestCache): @@ -115,7 +116,7 @@ def test_record_generated_token_loops(self): ) def test_cache_reuse_heavy(self): - cache = CacheWrapper(DummyModel(), 10) + cache = CacheWrapper(DummyModel(), 10, keep=2) cache.cache[0] = ShiftingKVCache(max_size=10, keep=2) # set up pretend cache @@ -168,6 +169,141 @@ def idx(v, a, b): should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) self.assertArrEqual(keys, should_be_kv) + def test_cache_reuse_overwrite_heavy(self): + cache = CacheWrapper(DummyModel(), 10, keep=2) + cache.cache[0] = ShiftingKVCache(max_size=10, keep=2) + + # set up pretend cache + cached_tokens = mx.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + cache_kv = self.make_random_kv(10) + for i in range(10): + cache.record_generated_token(cached_tokens[i]) + cache.cache[0].update_and_fetch(cache_kv, cache_kv) + + # append another one to overwrite + cache.record_generated_token(11) + cache_new_kv = self.make_random_kv(1) + cache.cache[0].update_and_fetch(cache_new_kv, cache_new_kv) + + print(cache.tokens) + self.assertArrEqual(cache.tokens, mx.array([1, 2, 4, 5, 6, 7, 8, 9, 10, 11])) + self.assertEqual(cache.cache[0].keys.shape[2], 10) + + # set up pretend prompt + prompt_tokens = mx.array([1, 2, 4, 7, 8, 9, 12]) + + prefix_len = cache._find_matching_sequence_length( + cached_tokens, prompt_tokens, 0 + ) + self.assertEqual(prefix_len, 2) + + # prepare references + def idx(v, a, b): + return v[:, :, a:b, :] + + should_be_tokens = mx.array([1, 2, 4, 7, 8, 9]) + should_be_kv = mx.concatenate( + [ + idx(cache_kv, 0, 2), + idx(cache_kv, 3, 4), + idx(cache_kv, 6, 9), + ], + axis=2, + ) + + total_reused = cache._truncate_cache( + prompt_tokens=prompt_tokens, + common_prefix_len=prefix_len, + non_prefix_reuse_min_seq_len=1, + ) + + self.assertEqual(total_reused, 4) + self.assertArrEqual(cache.tokens, should_be_tokens) + self.assertArrEqual(cache.cache[0].keys, should_be_kv) + + # ensure updating works as intended + new_kv = self.make_random_kv(1) + keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) + should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) + self.assertArrEqual(keys, should_be_kv) + + # ensure batch concat works as intended + new_kv = self.make_random_kv(2) + keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) + should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) + self.assertArrEqual(keys, should_be_kv) + + def test_update_cache_heavy(self): + """Test that the cache updates correctly during generation""" + # TODO(christian-lms): you need to pipe in nonprefix reuse min seq len + model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit") + model_kit = load_model(model_path=model_path, max_kv_size=10) + + # set up pretend cache + prompt_tokens = mx.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + non_prefill_tokens = model_kit.cache_wrapper.update_cache(prompt_tokens, prompt_progress_callback=None, keep=2) + layer_0_cache = model_kit.cache_wrapper.cache[0] + from copy import deepcopy + original_keys = deepcopy(layer_0_cache.state[0]) + + # generate a few tokens + for result in create_generator( + model_kit=model_kit, + prompt_tokens=prompt_tokens, + seed=0, + max_tokens=2, + temp=0.0, + prompt_progress_callback=None, + keep=2, + ): + print(model_kit.cache_wrapper.tokens.tolist()) + print(result.tokens) + + result_tokens = mx.array([1, 2, 6, 7, 8, 9, 10, 4999, 1725, 1725]) + self.assertArrEqual(model_kit.cache_wrapper.tokens, result_tokens) + + _compA = model_kit.cache_wrapper.cache[0]._temporal_order(model_kit.cache_wrapper.cache[0].state[0]) + compA = _compA[..., :7, :] + print(_compA[0,0,:,:1].tolist()) + compB = mx.concat( + [original_keys[..., :2, :], original_keys[..., 4:, :]], axis=2) + self.assertArrEqual(compA, compB) + print("--- ---") + + new_prompt_tokens = mx.array([1, 2, 8, 9, 10, 4999, 1725, 1725]) + for result in create_generator( + model_kit=model_kit, + prompt_tokens=new_prompt_tokens, + seed=0, + max_tokens=2, + temp=0.0, + prompt_progress_callback=None, + keep=2, + ): + self.assertEqual(len(model_kit.cache_wrapper.tokens), model_kit.cache_wrapper.cache[0].state[0].shape[2]) + print(f"HOASDOSIADN {result.tokens}") + print(model_kit.cache_wrapper.tokens.tolist()) + print(result.tokens) + + print(model_kit.cache_wrapper.tokens.tolist()) + new_result_tokens = mx.array([1, 2, 9, 10, 4999, 1725, 1725, 21002, 1177, 1177]) + self.assertArrEqual(model_kit.cache_wrapper.tokens, new_result_tokens) + + _compC = model_kit.cache_wrapper.cache[0]._temporal_order(model_kit.cache_wrapper.cache[0].state[0]) + compC = _compC[..., :3, :] + print(_compC[0,0,:,:1].tolist()) + print(original_keys[0,0,:,:1].tolist()) + compD = mx.concat( + [original_keys[..., :2, :], original_keys[..., 8:, :]], axis=2) + self.assertArrEqual(compC, compD) + compE = _compC[..., 3:6, :] + compF = _compA[..., 7:, :] + print("--- ---") + print(_compC[0,0,2:5,:1].tolist()) + print(_compA[0,0,7:,:1].tolist()) + self.assertArrEqual(compE, compF) + raise ValueError() + if __name__ == "__main__": unittest.main(verbosity=2)