From 9acad8df8d945cecaaf2737220790156cfbc963e Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sat, 21 Mar 2026 19:28:56 -0500 Subject: [PATCH] feat: configurable KVCache step size and pre-allocation Add optional parameters to KVCache and QuantizedKVCache: - `step` (int): Override the class-level step size (default 256). Larger values reduce the number of boundary reallocations and GPU sync points during generation. For example, step=1024 reduces reallocation frequency 4x. - `max_size` (int, KVCache only): Pre-allocate the buffer to hold this many tokens on first use. Eliminates ALL subsequent boundary reallocations and concatenations as long as the sequence stays within `max_size` tokens; beyond that, the cache falls back to growing in `step`-sized increments. When the maximum context length is known (e.g., from server configuration), this avoids the repeated allocate-concatenate-free cycle that causes transient memory spikes and GPU sync points. Also adds `max_context` parameter to `make_prompt_cache()` to pass through to KVCache constructors. It only takes effect on the default `KVCache` path: models that define `make_cache()` and calls that set `max_kv_size` ignore it. All parameters are optional with backward-compatible defaults. Existing code calling `KVCache()` or `make_prompt_cache(model)` is unaffected. Motivation: On M2 Ultra 128GB with a 122B MoE model (~82 GB weights), the repeated KV cache boundary reallocations (every 256 tokens across 12 attention layers) create transient memory spikes of ~2x the cache size at each boundary. With step=256, a 4000-token generation crosses 15 boundaries, each requiring 24 concatenation operations (12 layers x K+V). Pre-allocation eliminates this entirely. --- mlx_lm/models/cache.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 88fa4ad32..c633db9b4 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -13,6 +13,7 @@ def make_prompt_cache( model: nn.Module, max_kv_size: Optional[int] = None, + max_context: Optional[int] = None, ) -> List[Any]: """ Construct the model's cache for use in generation. 
@@ -25,6 +26,10 @@ def make_prompt_cache( max_kv_size (Optional[int]): If provided and the model does not have a ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum size of ``max_kv_size`` + max_context (Optional[int]): If provided, pre-allocate the KV cache + buffer to hold this many tokens. Eliminates reallocation and + concatenation during generation. Useful when the maximum context + length is known (e.g., from server configuration). """ if hasattr(model, "make_cache"): return model.make_cache() @@ -35,7 +40,7 @@ def make_prompt_cache( RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers) ] else: - return [KVCache() for _ in range(num_layers)] + return [KVCache(max_size=max_context) for _ in range(num_layers)] def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}): @@ -230,12 +235,14 @@ def nbytes(self): class QuantizedKVCache(_BaseCache): step = 256 - def __init__(self, group_size: int = 64, bits: int = 8): + def __init__(self, group_size: int = 64, bits: int = 8, step: Optional[int] = None): self.keys = None self.values = None self.offset = 0 self.group_size = group_size self.bits = bits + if step is not None: + self.step = step def update_and_fetch(self, keys, values): B, n_kv_heads, num_steps, k_head_dim = keys.shape @@ -323,19 +330,28 @@ def nbytes(self): class KVCache(_BaseCache): step = 256 - def __init__(self): + def __init__(self, max_size: Optional[int] = None, step: Optional[int] = None): self.keys = None self.values = None self.offset = 0 + self._max_size = max_size + if step is not None: + self.step = step def update_and_fetch(self, keys, values): prev = self.offset if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: B, n_kv_heads, _, k_head_dim = keys.shape v_head_dim = values.shape[3] - n_steps = (self.step + keys.shape[2] - 1) // self.step - k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim) - v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim) + 
if self._max_size is not None and self.keys is None: + # Pre-allocate to max_size — eliminates all future + # boundary reallocations and concatenations. + alloc_size = self._max_size + else: + n_steps = (self.step + keys.shape[2] - 1) // self.step + alloc_size = n_steps * self.step + k_shape = (B, n_kv_heads, alloc_size, k_head_dim) + v_shape = (B, n_kv_heads, alloc_size, v_head_dim) new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) if self.keys is not None: