diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py
index 88fa4ad32..c633db9b4 100644
--- a/mlx_lm/models/cache.py
+++ b/mlx_lm/models/cache.py
@@ -13,6 +13,7 @@ def make_prompt_cache(
     model: nn.Module,
     max_kv_size: Optional[int] = None,
+    max_context: Optional[int] = None,
 ) -> List[Any]:
     """
     Construct the model's cache for use in generation.
 
@@ -25,6 +26,10 @@ def make_prompt_cache(
         max_kv_size (Optional[int]): If provided and the model does not have
             a ``make_cache`` method, a ``RotatingKVCache`` is used with a
             maximum size of ``max_kv_size``
+        max_context (Optional[int]): If provided, pre-allocate the KV cache
+            buffer to hold this many tokens. Eliminates reallocation and
+            concatenation during generation. Useful when the maximum context
+            length is known (e.g., from server configuration).
     """
     if hasattr(model, "make_cache"):
         return model.make_cache()
@@ -35,7 +40,7 @@ def make_prompt_cache(
             RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
         ]
     else:
-        return [KVCache() for _ in range(num_layers)]
+        return [KVCache(max_size=max_context) for _ in range(num_layers)]
 
 
 def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
@@ -230,12 +235,14 @@ def nbytes(self):
 class QuantizedKVCache(_BaseCache):
     step = 256
 
-    def __init__(self, group_size: int = 64, bits: int = 8):
+    def __init__(self, group_size: int = 64, bits: int = 8, step: Optional[int] = None):
         self.keys = None
         self.values = None
         self.offset = 0
         self.group_size = group_size
         self.bits = bits
+        if step is not None:
+            self.step = step
 
     def update_and_fetch(self, keys, values):
         B, n_kv_heads, num_steps, k_head_dim = keys.shape
@@ -323,19 +330,28 @@ def nbytes(self):
 class KVCache(_BaseCache):
     step = 256
 
-    def __init__(self):
+    def __init__(self, max_size: Optional[int] = None, step: Optional[int] = None):
         self.keys = None
         self.values = None
         self.offset = 0
+        self._max_size = max_size
+        if step is not None:
+            self.step = step
 
     def update_and_fetch(self, keys, values):
         prev = self.offset
         if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
             B, n_kv_heads, _, k_head_dim = keys.shape
             v_head_dim = values.shape[3]
-            n_steps = (self.step + keys.shape[2] - 1) // self.step
-            k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
-            v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
+            if self._max_size is not None and self.keys is None:
+                # Pre-allocate to max_size — eliminates all future
+                # boundary reallocations and concatenations.
+                alloc_size = self._max_size
+            else:
+                n_steps = (self.step + keys.shape[2] - 1) // self.step
+                alloc_size = n_steps * self.step
+            k_shape = (B, n_kv_heads, alloc_size, k_head_dim)
+            v_shape = (B, n_kv_heads, alloc_size, v_head_dim)
             new_k = mx.zeros(k_shape, keys.dtype)
             new_v = mx.zeros(v_shape, values.dtype)
             if self.keys is not None: