Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions mlx_lm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
max_context: Optional[int] = None,
) -> List[Any]:
"""
Construct the model's cache for use in generation.
Expand All @@ -25,6 +26,10 @@ def make_prompt_cache(
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
max_context (Optional[int]): If provided, pre-allocate the KV cache
buffer to hold this many tokens. Eliminates reallocation and
concatenation during generation. Useful when the maximum context
length is known (e.g., from server configuration).
"""
if hasattr(model, "make_cache"):
return model.make_cache()
Expand All @@ -35,7 +40,7 @@ def make_prompt_cache(
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
return [KVCache(max_size=max_context) for _ in range(num_layers)]


def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
Expand Down Expand Up @@ -230,12 +235,14 @@ def nbytes(self):
class QuantizedKVCache(_BaseCache):
step = 256

def __init__(self, group_size: int = 64, bits: int = 8, step: Optional[int] = None):
    """Initialize an empty quantized key/value cache.

    Args:
        group_size (int): Group size used for the quantized representation
            of the cached keys and values.
        bits (int): Bit width used for the quantized representation.
        step (Optional[int]): If given, overrides the class-level buffer
            growth increment (default 256) for this instance only.
    """
    self.group_size = group_size
    self.bits = bits
    # Buffers are created lazily on the first update; nothing cached yet.
    self.keys = self.values = None
    self.offset = 0
    # Shadow the class attribute only when an override was requested, so
    # instances without an explicit step keep sharing the class default.
    if step is not None:
        self.step = step

def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, k_head_dim = keys.shape
Expand Down Expand Up @@ -323,19 +330,28 @@ def nbytes(self):
class KVCache(_BaseCache):
step = 256

def __init__(self, max_size: Optional[int] = None, step: Optional[int] = None):
    """Initialize an empty key/value cache.

    Args:
        max_size (Optional[int]): If given, the backing buffers are
            pre-allocated to hold this many tokens on first use instead
            of growing in ``step``-sized increments.
        step (Optional[int]): If given, overrides the class-level buffer
            growth increment (default 256) for this instance only.
    """
    # Buffers are created lazily on the first update; nothing cached yet.
    self.keys = self.values = None
    self.offset = 0
    self._max_size = max_size
    # Shadow the class attribute only when an override was requested, so
    # instances without an explicit step keep sharing the class default.
    if step is not None:
        self.step = step

def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B, n_kv_heads, _, k_head_dim = keys.shape
v_head_dim = values.shape[3]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
if self._max_size is not None and self.keys is None:
# Pre-allocate to max_size — eliminates all future
# boundary reallocations and concatenations.
alloc_size = self._max_size
else:
n_steps = (self.step + keys.shape[2] - 1) // self.step
alloc_size = n_steps * self.step
k_shape = (B, n_kv_heads, alloc_size, k_head_dim)
v_shape = (B, n_kv_heads, alloc_size, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
Expand Down