Skip to content
99 changes: 74 additions & 25 deletions mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import functools
import json
import numbers
import sys
import time
from dataclasses import dataclass
Expand Down Expand Up @@ -927,11 +928,6 @@ def _merge_caches(caches):
return batch_cache


def _lazy_extract_cache(cache, i):
# Generators like lambdas are late bound so we can't just use it in the loop
return (c.extract(i) for c in cache)


class BatchGenerator:
@dataclass
class Response:
Expand Down Expand Up @@ -1009,8 +1005,35 @@ def insert(
if max_tokens is None or isinstance(max_tokens, int):
max_tokens = [max_tokens or self.max_tokens] * len(prompts)

if prompt_checkpoints is None or isinstance(prompt_checkpoints, int):
prompt_checkpoints = [prompt_checkpoints or -1] * len(prompts)
if prompt_checkpoints is None:
if self.prompt_checkpoint_callback is not None:
# Preserve the base contract for direct callback consumers:
# omitted prompt_checkpoints still means "checkpoint at the
# last token" unless the caller explicitly passes [None].
prompt_checkpoints = [-1] * len(prompts)
else:
prompt_checkpoints = [None] * len(prompts)
elif isinstance(prompt_checkpoints, int):
prompt_checkpoints = [prompt_checkpoints] * len(prompts)
elif len(prompt_checkpoints) != len(prompts):
raise ValueError("prompt checkpoints must match the number of prompts")

validated_prompt_checkpoints = []
for prompt, checkpoint in zip(prompts, prompt_checkpoints):
if checkpoint is None:
validated_prompt_checkpoints.append(None)
continue
if not isinstance(checkpoint, numbers.Integral) or isinstance(
checkpoint, bool
):
raise ValueError("prompt checkpoint must be an integer or None")
checkpoint = int(checkpoint)
if checkpoint == 0 or checkpoint > len(prompt) or checkpoint < -len(prompt):
raise ValueError(
"prompt checkpoint must be within prompt length and not zero"
)
validated_prompt_checkpoints.append(checkpoint)
prompt_checkpoints = validated_prompt_checkpoints

if caches is None:
caches = [None] * len(prompts)
Expand All @@ -1034,6 +1057,12 @@ def insert(
)
return uids

@staticmethod
def _normalized_prompt_checkpoint_length(length, checkpoint):
if checkpoint is None:
return 1
return length - checkpoint if checkpoint > 0 else -checkpoint

def remove(self, uids: List[int], return_prompt_caches: bool = False):
caches = {}
uids = set(uids)
Expand Down Expand Up @@ -1079,13 +1108,17 @@ def _process_prompts(self, prompts):
max_length = max(lengths)
padding = [max_length - l for l in lengths]

# Get the checkpoint token as an offset from the end of each prompt.
# Then select the largest one so that we perform the checkpoint at
# least `pc` before the end.
prompt_checkpoints = [
(l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints)
]
prompt_checkpoint = max(1, max(prompt_checkpoints))
checkpoint_indices = []
normalized_checkpoints = []
for idx, (length, checkpoint) in enumerate(zip(lengths, prompt_checkpoints)):
normalized_checkpoint = self._normalized_prompt_checkpoint_length(
length, checkpoint
)
if checkpoint is None:
continue
checkpoint_indices.append(idx)
normalized_checkpoints.append(normalized_checkpoint)
prompt_checkpoint = max(1, max(normalized_checkpoints, default=1))

self._stats.prompt_tokens += sum(lengths)

Expand Down Expand Up @@ -1126,8 +1159,8 @@ def _process_prompts(self, prompts):
prompt_cache = _merge_caches(caches)

for c in prompt_cache:
# subtract from lengths since we don't process the last
# `prompt_checkpoint` tokens during prefill
# Subtract the checkpoint span since we keep the tail prompt
# tokens for checkpoint extraction and the first decode step.
c.prepare(
lengths=[l - prompt_checkpoint for l in lengths],
right_padding=padding,
Expand Down Expand Up @@ -1155,19 +1188,17 @@ def _process_prompts(self, prompts):
for c in prompt_cache:
c.finalize()

# We processed L - prompt_checkpoint tokens so call the checkpoint
# callback.
if self.prompt_checkpoint_callback is not None:
if self.prompt_checkpoint_callback is not None and checkpoint_indices:
self.prompt_checkpoint_callback(
[
(uid, prompt_checkpoint, _lazy_extract_cache(prompt_cache, i))
for i, uid in enumerate(uids)
(uids[i], prompt_checkpoint, [c.extract(i) for c in prompt_cache])
for i in checkpoint_indices
]
)
# Process the remaining prompt_checkpoint-1 tokens
if prompt_checkpoint > 1:
self.model(inputs[:, : prompt_checkpoint - 1], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
inputs = inputs[:, prompt_checkpoint - 1 :]
mx.clear_cache()

y, logprobs = self._step(
Expand Down Expand Up @@ -1254,10 +1285,28 @@ def _next(self):
self._stats.generation_time += time.perf_counter() - tic
tic = time.perf_counter()

compatible_count = 0
min_length = None
shared_checkpoint = 1
for prompt in prompts:
length = len(prompt[1])
checkpoint = prompt[6]
checkpoint_size = self._normalized_prompt_checkpoint_length(
length, checkpoint
)
new_shared_checkpoint = max(shared_checkpoint, checkpoint_size)
new_min_length = (
length if min_length is None else min(min_length, length)
)
if compatible_count > 0 and new_min_length < new_shared_checkpoint:
break
compatible_count += 1
min_length = new_min_length
shared_checkpoint = new_shared_checkpoint
prompts = prompts[: max(1, compatible_count)]

batch = self._process_prompts(prompts)
self.unprocessed_prompts = self.unprocessed_prompts[
self.prefill_batch_size :
]
self.unprocessed_prompts = self.unprocessed_prompts[len(prompts) :]
prompt_processing = True
# If there was no active batch, set it
if self.active_batch is None:
Expand Down
36 changes: 36 additions & 0 deletions mlx_lm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,22 @@ def empty(self):
"""
raise NotImplementedError("Cache sub-class must implement this.")

def rewind(self, num_to_trim: int) -> bool:
    """Remove the last ``num_to_trim`` cached tokens.

    Abstract hook: concrete cache subclasses that support rewinding
    override this and return ``True`` on success. The base class always
    raises so that the override can be detected by method identity.
    """
    raise NotImplementedError("Cache sub-class must implement rewind.")

def _has_rewind_impl(self):
    """Check whether this cache has a real rewind implementation.

    True when the concrete class overrides ``rewind()`` beyond the
    ``_BaseCache`` default. Method identity (rather than an opt-in
    flag) lets third-party caches that implement ``rewind()``
    participate automatically without knowing about this helper.
    """
    try:
        override = type(self).rewind
        base = _BaseCache.rewind
    except Exception:
        # Best-effort: any lookup failure means "no usable override".
        return False
    return override is not base

@classmethod
def from_state(cls, state, meta_state):
# Create an instance of cls without calling __init__
Expand Down Expand Up @@ -1247,8 +1263,28 @@ def trim(self, n):
self._offset -= n
self._idx -= n
self.offset -= n
if self.rotated:
self.left_padding += n
return n

def can_rewind(self, num_to_trim: int) -> bool:
    """Report whether ``num_to_trim`` tokens could be rewound safely.

    A non-positive request is trivially satisfiable. Otherwise the
    cache must hold materialized keys/values, the write index must lie
    inside the key buffer, and the request may exceed neither the
    logical offset nor the write index.
    """
    if num_to_trim <= 0:
        return True
    if self.keys is None or self.values is None:
        return False
    buffer_len = self.keys.shape[2]
    if not (0 <= self._idx <= buffer_len):
        return False
    return num_to_trim <= min(self._offset, self._idx)

def rewind(self, num_to_trim: int) -> bool:
    """Rewind the cache by ``num_to_trim`` tokens, returning success.

    The request is validated via ``can_rewind`` up front so a partial
    trim is never attempted; a non-positive request succeeds without
    touching the cache.
    """
    if not self.can_rewind(num_to_trim):
        return False
    # Short-circuit: nothing to trim for a non-positive request.
    return num_to_trim <= 0 or self.trim(num_to_trim) == num_to_trim

def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
    """Convert this cache to a quantized KV cache.

    Not yet implemented for the batched rotating cache; always raises
    ``NotImplementedError``, so callers must keep it in full precision.
    """
    raise NotImplementedError("BatchRotatingKVCache Quantization NYI")

Expand Down
Loading