diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index ef8dbf7bf..eb6d36314 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -4,6 +4,7 @@ import contextlib import functools import json +import numbers import sys import time from dataclasses import dataclass @@ -927,11 +928,6 @@ def _merge_caches(caches): return batch_cache -def _lazy_extract_cache(cache, i): - # Generators like lambdas are late bound so we can't just use it in the loop - return (c.extract(i) for c in cache) - - class BatchGenerator: @dataclass class Response: @@ -1009,8 +1005,35 @@ def insert( if max_tokens is None or isinstance(max_tokens, int): max_tokens = [max_tokens or self.max_tokens] * len(prompts) - if prompt_checkpoints is None or isinstance(prompt_checkpoints, int): - prompt_checkpoints = [prompt_checkpoints or -1] * len(prompts) + if prompt_checkpoints is None: + if self.prompt_checkpoint_callback is not None: + # Preserve the base contract for direct callback consumers: + # omitted prompt_checkpoints still means "checkpoint at the + # last token" unless the caller explicitly passes [None]. + prompt_checkpoints = [-1] * len(prompts) + else: + prompt_checkpoints = [None] * len(prompts) + elif isinstance(prompt_checkpoints, int): + prompt_checkpoints = [prompt_checkpoints] * len(prompts) + elif len(prompt_checkpoints) != len(prompts): + raise ValueError("prompt checkpoints must match the number of prompts") + + validated_prompt_checkpoints = [] + for prompt, checkpoint in zip(prompts, prompt_checkpoints): + if checkpoint is None: + validated_prompt_checkpoints.append(None) + continue + if not isinstance(checkpoint, numbers.Integral) or isinstance( + checkpoint, bool + ): + raise ValueError("prompt checkpoint must be an integer or None") + checkpoint = int(checkpoint) + if checkpoint == 0 or checkpoint > len(prompt) or checkpoint < -len(prompt): + raise ValueError( + "prompt checkpoint must be within prompt length and not zero" + ) + validated_prompt_checkpoints.append(checkpoint) + prompt_checkpoints = validated_prompt_checkpoints if caches is None: caches = [None] * len(prompts) @@ -1034,6 +1057,12 @@ def insert( ) return uids + @staticmethod + def _normalized_prompt_checkpoint_length(length, checkpoint): + if checkpoint is None: + return 1 + return length - checkpoint if checkpoint > 0 else -checkpoint + def remove(self, uids: List[int], return_prompt_caches: bool = False): caches = {} uids = set(uids) @@ -1079,13 +1108,17 @@ def _process_prompts(self, prompts): max_length = max(lengths) padding = [max_length - l for l in lengths] - # Get the checkpoint token as an offset from the end of each prompt. - # Then select the largest one so that we perform the checkpoint at - # least `pc` before the end. 
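A minimal usage sketch of the reworked prompt-checkpoint API, based only on the tests added later in this patch; it is illustrative and not part of the diff. It assumes the model and tokenizer used by those tests ("mlx-community/Qwen1.5-0.5B-Chat-4bit"). A negative checkpoint keeps the last |n| prompt tokens out of the checkpointed cache, a positive value checkpoints after the first n prompt tokens, None skips the checkpoint for that prompt, and zero or out-of-range values now raise ValueError. The callback receives (uid, checkpoint_size, cache) records, where checkpoint_size prompt tokens still need to be fed when resuming from cache.

    from mlx_lm.generate import BatchGenerator
    from mlx_lm.utils import load

    # Same tiny chat model the new tests use; any mlx-lm model should work.
    model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
    prompt = tokenizer.encode("Write a short paragraph about caches.")

    checkpoints = []
    gen = BatchGenerator(
        model,
        stop_tokens=tokenizer.eos_token_ids,
        max_tokens=8,
        prompt_checkpoint_callback=checkpoints.extend,
    )
    # Checkpoint the KV cache two tokens before the end of the prompt.
    (uid,) = gen.insert([prompt], prompt_checkpoints=-2)
    while responses := gen.next():
        pass
    gen.close()

    # checkpoint_size == 2 here: feed prompt[-2:] when resuming from `cache`.
    _uid, checkpoint_size, cache = checkpoints[0]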
- prompt_checkpoints = [ - (l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints) - ] - prompt_checkpoint = max(1, max(prompt_checkpoints)) + checkpoint_indices = [] + normalized_checkpoints = [] + for idx, (length, checkpoint) in enumerate(zip(lengths, prompt_checkpoints)): + normalized_checkpoint = self._normalized_prompt_checkpoint_length( + length, checkpoint + ) + if checkpoint is None: + continue + checkpoint_indices.append(idx) + normalized_checkpoints.append(normalized_checkpoint) + prompt_checkpoint = max(1, max(normalized_checkpoints, default=1)) self._stats.prompt_tokens += sum(lengths) @@ -1126,8 +1159,8 @@ def _process_prompts(self, prompts): prompt_cache = _merge_caches(caches) for c in prompt_cache: - # subtract from lengths since we don't process the last - # `prompt_checkpoint` tokens during prefill + # Subtract the checkpoint span since we keep the tail prompt + # tokens for checkpoint extraction and the first decode step. c.prepare( lengths=[l - prompt_checkpoint for l in lengths], right_padding=padding, @@ -1155,19 +1188,17 @@ def _process_prompts(self, prompts): for c in prompt_cache: c.finalize() - # We processed L - prompt_checkpoint tokens so call the checkpoint - # callback. - if self.prompt_checkpoint_callback is not None: + if self.prompt_checkpoint_callback is not None and checkpoint_indices: self.prompt_checkpoint_callback( [ - (uid, prompt_checkpoint, _lazy_extract_cache(prompt_cache, i)) - for i, uid in enumerate(uids) + (uids[i], prompt_checkpoint, [c.extract(i) for c in prompt_cache]) + for i in checkpoint_indices ] ) - # Process the remaining prompt_checkpoint-1 tokens if prompt_checkpoint > 1: self.model(inputs[:, : prompt_checkpoint - 1], cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) + inputs = inputs[:, prompt_checkpoint - 1 :] mx.clear_cache() y, logprobs = self._step( @@ -1254,10 +1285,28 @@ def _next(self): self._stats.generation_time += time.perf_counter() - tic tic = time.perf_counter() + compatible_count = 0 + min_length = None + shared_checkpoint = 1 + for prompt in prompts: + length = len(prompt[1]) + checkpoint = prompt[6] + checkpoint_size = self._normalized_prompt_checkpoint_length( + length, checkpoint + ) + new_shared_checkpoint = max(shared_checkpoint, checkpoint_size) + new_min_length = ( + length if min_length is None else min(min_length, length) + ) + if compatible_count > 0 and new_min_length < new_shared_checkpoint: + break + compatible_count += 1 + min_length = new_min_length + shared_checkpoint = new_shared_checkpoint + prompts = prompts[: max(1, compatible_count)] + batch = self._process_prompts(prompts) - self.unprocessed_prompts = self.unprocessed_prompts[ - self.prefill_batch_size : - ] + self.unprocessed_prompts = self.unprocessed_prompts[len(prompts) :] prompt_processing = True # If there was no active batch, set it if self.active_batch is None: diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 88fa4ad32..11ab03f69 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -164,6 +164,22 @@ def empty(self): """ raise NotImplementedError("Cache sub-class must implement this.") + def rewind(self, num_to_trim: int) -> bool: + raise NotImplementedError("Cache sub-class must implement rewind.") + + def _has_rewind_impl(self): + """Check whether this cache has a real rewind implementation. + + Returns True if the concrete class overrides rewind() beyond the + _BaseCache default. 
This uses method identity rather than a separate + opt-in flag so that third-party caches that implement rewind() + participate automatically without needing to know about this helper. + """ + try: + return type(self).rewind is not _BaseCache.rewind + except Exception: + return False + @classmethod def from_state(cls, state, meta_state): # Create an instance of cls without calling __init__ @@ -1247,8 +1263,28 @@ def trim(self, n): self._offset -= n self._idx -= n self.offset -= n + if self.rotated: + self.left_padding += n return n + def can_rewind(self, num_to_trim: int) -> bool: + if num_to_trim <= 0: + return True + if self.keys is None or self.values is None: + return False + if self._idx < 0 or self._idx > self.keys.shape[2]: + return False + if num_to_trim > self._offset or num_to_trim > self._idx: + return False + return True + + def rewind(self, num_to_trim: int) -> bool: + if not self.can_rewind(num_to_trim): + return False + if num_to_trim <= 0: + return True + return self.trim(num_to_trim) == num_to_trim + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: raise NotImplementedError("BatchRotatingKVCache Quantization NYI") diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 7fc91fa2a..b3d45bbe2 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -2,9 +2,9 @@ import argparse import copy -import heapq import json import logging +import numbers import pickle import platform import socket @@ -36,11 +36,7 @@ from ._version import __version__ from .generate import BatchGenerator, generation_stream, stream_generate -from .models.cache import ( - can_trim_prompt_cache, - make_prompt_cache, - trim_prompt_cache, -) +from .models.cache import _BaseCache, make_prompt_cache from .sample_utils import make_logits_processors, make_sampler from .utils import _parse_size, load, sharded_load @@ -175,19 +171,21 @@ class LRUPromptCache: @dataclass class CacheEntry: prompt_cache: List[Any] + ref_count: int nbytes: int + checkpoint: bool = False class CacheOrder: def __init__(self): - self._lru_checkpoints = deque() self._lru = deque() + self._lru_checkpoints = deque() def __len__(self): return len(self._lru) + len(self._lru_checkpoints) def push(self, model, tokens, checkpoint: bool = False): - c = self._lru_checkpoints if checkpoint else self._lru - c.append((model, tokens)) + queue = self._lru_checkpoints if checkpoint else self._lru + queue.append((model, tokens)) def remove(self, model, tokens): try: @@ -198,15 +196,14 @@ def remove(self, model, tokens): def pop(self): if len(self._lru) >= len(self._lru_checkpoints): return self._lru.popleft() - else: - return self._lru_checkpoints.popleft() + return self._lru_checkpoints.popleft() @dataclass class SearchResult: model: Any exact: List[int] shorter: List[int] - longer: List[int] + longer: List[List[int]] common_prefix: int def __init__(self, max_size: int = 10, max_bytes: int = 1 << 63): @@ -229,6 +226,10 @@ def _search(self, model, tokens): return self.SearchResult(model, None, None, None, 0) current = self._cache[model] + if not tokens: + if "cache" in current: + return self.SearchResult(model, tokens, None, None, 0) + return self.SearchResult(model, None, None, None, 0) last_cache_index = -1 index = 0 @@ -244,24 +245,26 @@ def _search(self, model, tokens): # Find the shorter cache shorter = None - if last_cache_index > 0: + if last_cache_index >= 0: shorter = tokens[: last_cache_index + 1] # Check for caches that are longer longer = None common_prefix = index if index > 0: - best = None + candidates = [] 
stack = [(current, [])] while stack: current, extra = stack.pop() - if "cache" in current: - if best is None or len(extra) < len(best): - best = extra - else: - for tok in current: - stack.append((current[tok], extra + [tok])) - longer = tokens[:index] + best + if "cache" in current and extra: + candidates.append(extra) + for tok in current: + if tok == "cache": + continue + stack.append((current[tok], extra + [tok])) + if candidates: + candidates.sort(key=lambda extra: (len(extra), extra)) + longer = [tokens[:index] + extra for extra in candidates] return self.SearchResult(model, None, shorter, longer, common_prefix) def _get(self, model, tokens): @@ -283,51 +286,179 @@ def _delete(self, model, tokens): break del d_prev[t] + logging.debug(f"[LRUPromptCache] Removed {cache_bytes} bytes from the cache") + + def _extract(self, model, tokens): + cache_entry = self._get(model, tokens) + if cache_entry.checkpoint: + try: + extracted_cache = copy.deepcopy(cache_entry.prompt_cache) + except Exception: + return None + self._lru.remove(model, tokens) + self._lru.push(model, tokens, checkpoint=True) + return self.CacheEntry(extracted_cache, 1, cache_entry.nbytes, True) + + if cache_entry.ref_count == 1: + self._delete(model, tokens) + self._lru.remove(model, tokens) + return cache_entry + + try: + extracted_cache = copy.deepcopy(cache_entry.prompt_cache) + except Exception: + return None + cache_entry.ref_count -= 1 + self._refresh_recency(model, tokens, checkpoint=False) + return self.CacheEntry(extracted_cache, 1, cache_entry.nbytes) + + def _refresh_recency(self, model, tokens, checkpoint: bool): + self._lru.remove(model, tokens) + self._lru.push(model, tokens, checkpoint=checkpoint) + + def _can_rewind_layer_cache(self, layer_cache, num_to_trim): + can_rewind = getattr(layer_cache, "can_rewind", None) + is_trimmable = getattr(layer_cache, "is_trimmable", None) + trim = getattr(layer_cache, "trim", None) + if callable(can_rewind): + if isinstance(layer_cache, _BaseCache): + has_custom_rewind = layer_cache._has_rewind_impl() + else: + has_custom_rewind = callable(getattr(layer_cache, "rewind", None)) + has_execution_path = has_custom_rewind or ( + callable(is_trimmable) and callable(trim) + ) + if not has_execution_path: + return False + try: + can_rewind_result = can_rewind(num_to_trim) + if isinstance(can_rewind_result, bool): + return can_rewind_result + if isinstance(can_rewind_result, numbers.Integral) and not isinstance( + can_rewind_result, bool + ): + return int(can_rewind_result) >= num_to_trim + return False + except Exception: + return False + + # Compatibility fallback for custom caches that only implement the + # legacy is_trimmable()/trim()/rewind() contract. + rewind = getattr(layer_cache, "rewind", None) + if not callable(is_trimmable) or (not callable(trim) and not callable(rewind)): + return False + try: + if not bool(is_trimmable()): + return False + if num_to_trim <= 0: + return True + + # If legacy cache exposes an offset, avoid deepcopy on guaranteed + # misses where trim can never satisfy the requested rewind. 
+ offset = getattr(layer_cache, "offset", None) + if isinstance(offset, numbers.Integral): + return num_to_trim <= offset + return True + except Exception: + return False + + def _can_rewind_prompt_cache(self, cache, num_to_trim): + return all( + self._can_rewind_layer_cache(layer_cache, num_to_trim) + for layer_cache in cache + ) + + def _rewind_layer_cache(self, layer_cache, num_to_trim): + rewind = getattr(layer_cache, "rewind", None) + has_real_rewind = ( + isinstance(layer_cache, _BaseCache) and layer_cache._has_rewind_impl() + ) or (not isinstance(layer_cache, _BaseCache) and callable(rewind)) + if has_real_rewind: + try: + rewind_result = rewind(num_to_trim) + if isinstance(rewind_result, bool): + return rewind_result + if isinstance(rewind_result, numbers.Integral) and not isinstance( + rewind_result, bool + ): + return int(rewind_result) == num_to_trim + return False + except Exception: + return False + + # Compatibility fallback for caches that only implement the + # legacy is_trimmable()/trim() contract. + is_trimmable = getattr(layer_cache, "is_trimmable", None) + trim = getattr(layer_cache, "trim", None) + if not callable(is_trimmable) or not callable(trim): + return False + try: + return bool(is_trimmable()) and trim(num_to_trim) == num_to_trim + except Exception: + return False + + def _rewind_prompt_cache(self, cache, num_to_trim): + return all( + self._rewind_layer_cache(layer_cache, num_to_trim) for layer_cache in cache + ) + def fetch_nearest_cache(self, model, tokens): result = self._search(model, tokens) if result.exact is not None: - cache_entry = self._get(result.model, result.exact) - return copy.deepcopy(cache_entry.prompt_cache), [] - - short_length = len(result.shorter) if result.shorter is not None else 0 - if result.longer is not None and result.common_prefix > short_length: - cache_entry = self._get(result.model, result.longer) - if can_trim_prompt_cache(cache_entry.prompt_cache): - cache = copy.deepcopy(cache_entry.prompt_cache) - prefix = min(len(tokens) - 1, result.common_prefix) - num_to_trim = len(result.longer) - prefix - trim_prompt_cache(cache, num_to_trim) - return cache, tokens[prefix:] - - if short_length > 0: - cache_entry = self._get(result.model, result.shorter) - return copy.deepcopy(cache_entry.prompt_cache), tokens[short_length:] + cache_entry = self._extract(result.model, result.exact) + if cache_entry is None: + return None, tokens + return cache_entry.prompt_cache, [] + + if result.longer is not None: + prefix = min(len(tokens) - 1, result.common_prefix) + for longer_tokens in result.longer: + cache_entry = self._get(result.model, longer_tokens) + num_to_trim = len(longer_tokens) - prefix + + if not self._can_rewind_prompt_cache( + cache_entry.prompt_cache, num_to_trim + ): + continue + try: + cache = copy.deepcopy(cache_entry.prompt_cache) + except Exception: + cache = None + if cache is not None and self._rewind_prompt_cache(cache, num_to_trim): + self._refresh_recency( + result.model, longer_tokens, checkpoint=cache_entry.checkpoint + ) + return cache, tokens[prefix:] + + if result.shorter is not None: + cache_entry = self._extract(result.model, result.shorter) + if cache_entry is None: + return None, tokens + prefix_len = len(result.shorter) + return cache_entry.prompt_cache, tokens[prefix_len:] return None, tokens def insert_cache(self, model, tokens, prompt_cache, checkpoint: bool = False): - is_trimmable = can_trim_prompt_cache(prompt_cache) - if model not in self._cache: self._cache[model] = {} current = self._cache[model] - for i, 
tok in enumerate(tokens): + for tok in tokens: if tok not in current: current[tok] = {} - if is_trimmable and "cache" in current: - self._n_bytes -= current["cache"].nbytes - del current["cache"] - self._lru.remove(model, tokens[:i]) current = current[tok] if "cache" in current: + current["cache"].ref_count += 1 + current["cache"].checkpoint = current["cache"].checkpoint or checkpoint self._lru.remove(model, tokens) else: cache_bytes = sum(c.nbytes for c in prompt_cache) - current["cache"] = self.CacheEntry(prompt_cache, cache_bytes) + current["cache"] = self.CacheEntry(prompt_cache, 1, cache_bytes, checkpoint) self._n_bytes += cache_bytes + logging.debug(f"[LRUPromptCache] Adding {cache_bytes} to the cache") - self._lru.push(model, tokens, checkpoint=checkpoint) + self._lru.push(model, tokens, checkpoint=current["cache"].checkpoint) if len(self._lru) > self.max_size: model, tokens = self._lru.pop() self._delete(model, tokens) @@ -737,12 +868,14 @@ def _tokenize(self, tokenizer, request, args): def _compute_prompt_checkpoint(self, tokenizer, request, prompt): if request.request_type != "chat": return False, -1 - if request.messages[-1]["role"] != "user": + if not request.messages: + raise ValueError("Chat request messages must be a non-empty list") + last_message = request.messages[-1] + if not isinstance(last_message, dict) or "role" not in last_message: + raise ValueError("Chat request last message must include a role") + if last_message["role"] != "user": return False, -1 - # Save the KV cache at the end of the prompt just before - # the think start token which will likely be removed in the - # next turn. prompt_checkpoint = -1 if tokenizer.has_thinking: for i in range(1, min(11, len(prompt)) - 1, 1): @@ -752,6 +885,33 @@ def _compute_prompt_checkpoint(self, tokenizer, request, prompt): return True, prompt_checkpoint + def _localize_prompt_checkpoint(self, prompt, rest, checkpoint_position): + prompt_len = len(prompt) + rest_offset = prompt_len - len(rest) + checkpoint_prefix = ( + checkpoint_position + if checkpoint_position > 0 + else prompt_len + checkpoint_position + ) + if checkpoint_prefix < rest_offset or checkpoint_prefix >= prompt_len: + return None + return -(prompt_len - checkpoint_prefix) + + def _materialize_prompt_tail_for_generation(self, prompt, cache, rest): + if cache is None or rest or not prompt: + return cache, rest + + # Exact prompt-cache hits need one token outside the cache so generation + # can resume through the normal prefill/decode entry points. + if self.prompt_cache._can_rewind_prompt_cache( + cache, 1 + ) and self.prompt_cache._rewind_prompt_cache(cache, 1): + return cache, prompt[-1:] + + # If the extracted cache cannot be safely rewound, fall back to replaying + # the full prompt instead of forwarding an unusable empty remainder. 
+ return None, prompt + def _is_batchable(self, args): if not self.model_provider.is_batchable: return False @@ -784,12 +944,15 @@ def progress_callback(info): def checkpoint_callback(prompts): for uid, prompt_end, cache in prompts: - rs = batch_results[uid] - if not rs["checkpoint"]: + result = batch_results.get(uid) + if result is None or not result["checkpoint"]: + continue + cache_key = result["cache_key"][:-prompt_end] + if not cache_key: continue self.prompt_cache.insert_cache( current_model_key, - rs["cache_key"][:-prompt_end], + cache_key, list(cache), checkpoint=True, ) @@ -820,6 +983,11 @@ def checkpoint_callback(prompts): ): try: prompt = self._tokenize(current_tokenizer, request, args) + do_checkpoint, checkpoint_position = ( + self._compute_prompt_checkpoint( + current_tokenizer, request, prompt + ) + ) except Exception as e: rqueue.put(e) continue @@ -842,17 +1010,22 @@ def checkpoint_callback(prompts): ) rqueue.put(ctx) - self.prompt_cache.log_cache_stats() cache, rest = self.prompt_cache.fetch_nearest_cache( current_model_key, prompt ) + cache, rest = self._materialize_prompt_tail_for_generation( + prompt, cache, rest + ) ctx.prompt_cache_count = len(prompt) - len(rest) if cache is None: cache = make_prompt_cache(self.model_provider.model) - do_checkpoint, checkpoint_position = ( - self._compute_prompt_checkpoint(tokenizer, request, prompt) - ) + localized_checkpoint = None + if do_checkpoint: + localized_checkpoint = self._localize_prompt_checkpoint( + prompt, rest, checkpoint_position + ) + do_checkpoint = localized_checkpoint is not None (uid,) = batch_generator.insert( [rest], @@ -860,7 +1033,7 @@ def checkpoint_callback(prompts): caches=[cache], samplers=[_make_sampler(args, tokenizer)], logits_processors=[_make_logits_processors(args)], - prompt_checkpoints=[checkpoint_position], + prompt_checkpoints=[localized_checkpoint], ) batch_results[uid] = { "ctx": ctx, @@ -1027,6 +1200,9 @@ def progress(tokens_processed, tokens_total): cache, rest = self.prompt_cache.fetch_nearest_cache( self.model_provider.model_key, prompt ) + cache, rest = self._materialize_prompt_tail_for_generation( + prompt, cache, rest + ) ctx.prompt_cache_count = len(prompt) - len(rest) cache_key = prompt[:] if cache is None: diff --git a/tests/prompt_cache_test_utils.py b/tests/prompt_cache_test_utils.py new file mode 100644 index 000000000..f425f775f --- /dev/null +++ b/tests/prompt_cache_test_utils.py @@ -0,0 +1,183 @@ +# Copyright © 2024 Apple Inc. + +import mlx.core as mx + +from mlx_lm.models.cache import RotatingKVCache + + +def make_tiny_step3p5_model(): + from mlx_lm.models import step3p5 + + # Keep this config minimal and centralized so schema churn in one place + # does not ripple through multiple cache behavior assertions. 
+ args = step3p5.ModelArgs.from_dict( + { + "model_type": "step3p5", + "hidden_size": 128, + "num_hidden_layers": 4, + "vocab_size": 256, + "num_attention_heads": 4, + "num_attention_groups": 2, + "head_dim": 32, + "intermediate_size": 256, + "rms_norm_eps": 1e-5, + "rope_theta": [10000.0, 10000.0, 10000.0, 10000.0], + "sliding_window": 4, + "layer_types": [ + "full_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + ], + "partial_rotary_factors": [1.0, 1.0, 1.0, 1.0], + "attention_other_setting": { + "num_attention_heads": 4, + "num_attention_groups": 2, + }, + "use_head_wise_attn_gate": True, + "moe_num_experts": 4, + "moe_top_k": 2, + "moe_intermediate_size": 128, + "share_expert_dim": 128, + "moe_layers_enum": "1,2,3", + } + ) + return step3p5.Model(args) + + +def build_real_rotating_cache(*, max_size=4, total_tokens=4): + cache = RotatingKVCache(max_size=max_size) + kv = mx.arange(total_tokens, dtype=mx.float32).reshape(1, 1, total_tokens, 1) + cache.update_and_fetch(kv, kv) + mx.eval(cache.keys, cache.values) + return cache + + +def snapshot_cache_arrays(cache): + keys = mx.array(cache.keys) if cache.keys is not None else None + values = mx.array(cache.values) if cache.values is not None else None + if keys is not None: + mx.eval(keys) + if values is not None: + mx.eval(values) + return keys, values + + +class RewindRecorderLayer: + def __init__( + self, + *, + max_rewind=4, + rewind_result=True, + offset=None, + rewind_calls=None, + can_rewind_calls=None, + ): + self.max_rewind = max_rewind + self.rewind_result = rewind_result + self.rewind_calls = rewind_calls if rewind_calls is not None else [] + self.can_rewind_calls = can_rewind_calls if can_rewind_calls is not None else [] + self.offset = max_rewind if offset is None else offset + + @property + def nbytes(self): + return 1 + + def can_rewind(self, n): + self.can_rewind_calls.append(n) + return n <= self.max_rewind + + def rewind(self, n): + self.rewind_calls.append(n) + if self.rewind_result: + self.offset = max(0, self.offset - n) + return self.rewind_result + + def __deepcopy__(self, memo): + return type(self)( + max_rewind=self.max_rewind, + rewind_result=self.rewind_result, + offset=self.offset, + rewind_calls=self.rewind_calls, + can_rewind_calls=self.can_rewind_calls, + ) + + +class LegacyTrimLayer: + def __init__( + self, + *, + offset=4, + trim_shortfall=0, + trim_calls=None, + deepcopy_calls=None, + ): + self.offset = offset + self.trim_shortfall = trim_shortfall + self.trim_calls = trim_calls if trim_calls is not None else [] + self.deepcopy_calls = deepcopy_calls if deepcopy_calls is not None else [] + + @property + def nbytes(self): + return 1 + + def is_trimmable(self): + return True + + def trim(self, n): + self.trim_calls.append(n) + trimmed = max(0, n - self.trim_shortfall) + self.offset = max(0, self.offset - trimmed) + return trimmed + + def __deepcopy__(self, memo): + self.deepcopy_calls.append("deepcopy") + return type(self)( + offset=self.offset, + trim_shortfall=self.trim_shortfall, + trim_calls=self.trim_calls, + deepcopy_calls=self.deepcopy_calls, + ) + + +class UnknownNonTrimmableLayer: + @property + def nbytes(self): + return 1 + + +class UnknownNonTrimmableNoDeepcopy: + @property + def nbytes(self): + return 1 + + def __deepcopy__(self, memo): + raise AssertionError( + "deepcopy should be skipped for unknown non-trimmable layers" + ) + + +class UnknownLayerWithoutLegacyHooks: + @property + def nbytes(self): + return 1 + + def __deepcopy__(self, memo): + raise 
AssertionError( + "deepcopy should be skipped when layer lacks rewind and trim contracts" + ) + + +class DeepcopyShouldNotRunLayer: + @property + def nbytes(self): + return 1 + + def can_rewind(self, n): + return True + + def rewind(self, n): + return True + + def __deepcopy__(self, memo): + raise AssertionError("deepcopy should be skipped on known-safe miss") diff --git a/tests/test_generate.py b/tests/test_generate.py index fee5801a6..cc866ef42 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -633,5 +633,457 @@ def test_batch_continued_generation_gated_delta(self): self._continued_generation_test_helper(model) +class TestBatchGeneratorPromptCheckpoints(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" + cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH) + cls.model.set_dtype(mx.float32) + + def test_batch_generator_rejects_prompt_checkpoint_outside_prompt_length(self): + prompt = self.tokenizer.encode( + "Write a short paragraph about caches and rewind behavior." + ) + for checkpoint in (0, len(prompt) + 1, -(len(prompt) + 1)): + with self.subTest(checkpoint=checkpoint): + gen = BatchGenerator( + self.model, + stop_tokens=self.tokenizer.eos_token_ids, + max_tokens=1, + ) + try: + with self.assertRaisesRegex(ValueError, "prompt checkpoint"): + gen.insert([prompt], prompt_checkpoints=checkpoint) + finally: + gen.close() + + def test_batch_generator_rejects_list_prompt_checkpoint_outside_prompt_length(self): + prompts = [ + self.tokenizer.encode("Brief note about caches."), + self.tokenizer.encode( + "Write a short paragraph about caches and rewind behavior." + ), + ] + invalid_checkpoints = [ + [0, -2], + [-2, -(len(prompts[1]) + 1)], + [2, len(prompts[1]) + 1], + ] + for prompt_checkpoints in invalid_checkpoints: + with self.subTest(prompt_checkpoints=prompt_checkpoints): + gen = BatchGenerator( + self.model, + stop_tokens=self.tokenizer.eos_token_ids, + max_tokens=1, + ) + try: + with self.assertRaisesRegex(ValueError, "prompt checkpoint"): + gen.insert(prompts, prompt_checkpoints=prompt_checkpoints) + finally: + gen.close() + + def test_batch_generator_callback_without_prompt_checkpoints_defaults_to_last_token_checkpoint( + self, + ): + prompt = self.tokenizer.encode( + "Write a short paragraph about caches and rewind behavior." 
+ ) + checkpoint_batches = [] + + def checkpoint_callback(records): + checkpoint_batches.append(list(records)) + + gen = BatchGenerator( + self.model, + stop_tokens=self.tokenizer.eos_token_ids, + max_tokens=1, + prompt_checkpoint_callback=checkpoint_callback, + ) + + (uid,) = gen.insert([prompt]) + batch_responses = {} + while responses := gen.next(): + for response in responses: + batch_responses[response.uid] = response + + self.assertEqual(len(checkpoint_batches), 1) + self.assertEqual(len(checkpoint_batches[0]), 1) + checkpoint_uid, checkpoint_size, checkpoint_cache_iter = checkpoint_batches[0][ + 0 + ] + self.assertEqual(checkpoint_uid, uid) + self.assertEqual(checkpoint_size, 1) + + checkpoint_cache = list(checkpoint_cache_iter) + resumed = next( + generate_step( + prompt=mx.array(prompt[-checkpoint_size:], dtype=mx.uint32), + model=self.model, + prompt_cache=checkpoint_cache, + max_tokens=1, + ) + ) + self.assertEqual(batch_responses[uid].token, resumed[0]) + self.assertTrue( + mx.allclose(batch_responses[uid].logprobs, resumed[1], rtol=1e-4, atol=1e-4) + ) + + def test_batch_generator_negative_prompt_checkpoint_resume_matches_full_prompt( + self, + ): + prompt = self.tokenizer.encode( + "Write a short paragraph about caches and rewind behavior." + ) + checkpoint_batches = [] + + def checkpoint_callback(records): + checkpoint_batches.append(list(records)) + + gen = BatchGenerator( + self.model, + stop_tokens=self.tokenizer.eos_token_ids, + max_tokens=1, + prompt_checkpoint_callback=checkpoint_callback, + ) + + (uid,) = gen.insert([prompt], prompt_checkpoints=-2) + batch_responses = {} + while responses := gen.next(): + for response in responses: + batch_responses[response.uid] = response + + self.assertEqual(len(checkpoint_batches), 1) + self.assertEqual(len(checkpoint_batches[0]), 1) + checkpoint_uid, checkpoint_size, checkpoint_cache_iter = checkpoint_batches[0][ + 0 + ] + self.assertEqual(checkpoint_uid, uid) + self.assertEqual(checkpoint_size, 2) + + checkpoint_cache = list(checkpoint_cache_iter) + self.assertEqual(len(checkpoint_cache), len(self.model.layers)) + resumed = next( + generate_step( + prompt=mx.array(prompt[-checkpoint_size:], dtype=mx.uint32), + model=self.model, + prompt_cache=checkpoint_cache, + max_tokens=1, + ) + ) + self.assertEqual(batch_responses[uid].token, resumed[0]) + self.assertTrue( + mx.allclose(batch_responses[uid].logprobs, resumed[1], rtol=1e-4, atol=1e-4) + ) + + def test_batch_generator_prompt_checkpoints_use_shared_batch_checkpoint_and_resume( + self, + ): + prompts = [ + self.tokenizer.encode("Tell me something about the moon."), + self.tokenizer.encode("Tell me something about the sun."), + ] + checkpoint_batches = [] + + def checkpoint_callback(records): + checkpoint_batches.append(list(records)) + + gen = BatchGenerator( + self.model, + stop_tokens=self.tokenizer.eos_token_ids, + max_tokens=1, + prompt_checkpoint_callback=checkpoint_callback, + ) + + uids = gen.insert(prompts, prompt_checkpoints=[-1, -3]) + batch_responses = {} + while responses := gen.next(): + for response in responses: + batch_responses[response.uid] = response + + self.assertEqual(len(checkpoint_batches), 1) + checkpoint_records = checkpoint_batches[0] + self.assertEqual(len(checkpoint_records), len(uids)) + self.assertEqual({record[0] for record in checkpoint_records}, set(uids)) + self.assertEqual({record[1] for record in checkpoint_records}, {3}) + records_by_uid = {record[0]: record for record in checkpoint_records} + for uid, prompt in zip(uids, prompts): + _uid, 
checkpoint_size, checkpoint_cache_iter = records_by_uid[uid] + checkpoint_cache = list(checkpoint_cache_iter) + self.assertEqual(len(checkpoint_cache), len(self.model.layers)) + resumed = next( + generate_step( + prompt=mx.array(prompt[-checkpoint_size:], dtype=mx.uint32), + model=self.model, + prompt_cache=checkpoint_cache, + max_tokens=1, + ) + ) + self.assertEqual(batch_responses[uid].token, resumed[0]) + self.assertTrue( + mx.allclose( + batch_responses[uid].logprobs, + resumed[1], + rtol=1e-4, + atol=1e-4, + ) + ) + + def test_batch_generator_positive_prompt_checkpoint_is_length_normalized_and_resumable( + self, + ): + prompt = self.tokenizer.encode( + "Write a short paragraph about the mountains and the clouds." + ) + expected_checkpoint_size = max(1, len(prompt) - 2) + checkpoint_batches = [] + + def checkpoint_callback(records): + checkpoint_batches.append(list(records)) + + gen = BatchGenerator( + self.model, + stop_tokens=self.tokenizer.eos_token_ids, + max_tokens=1, + prompt_checkpoint_callback=checkpoint_callback, + ) + + (uid,) = gen.insert([prompt], prompt_checkpoints=2) + batch_responses = {} + while responses := gen.next(): + for response in responses: + batch_responses[response.uid] = response + + self.assertEqual(len(checkpoint_batches), 1) + self.assertEqual(len(checkpoint_batches[0]), 1) + checkpoint_uid, checkpoint_size, checkpoint_cache_iter = checkpoint_batches[0][ + 0 + ] + self.assertEqual(checkpoint_uid, uid) + self.assertEqual(checkpoint_size, expected_checkpoint_size) + + checkpoint_cache = list(checkpoint_cache_iter) + resumed = next( + generate_step( + prompt=mx.array(prompt[-checkpoint_size:], dtype=mx.uint32), + model=self.model, + prompt_cache=checkpoint_cache, + max_tokens=1, + ) + ) + self.assertEqual(batch_responses[uid].token, resumed[0]) + self.assertTrue( + mx.allclose(batch_responses[uid].logprobs, resumed[1], rtol=1e-4, atol=1e-4) + ) + + def test_batch_generator_checkpoint_resume_matches_merge_path_continuation(self): + prompts_a = [ + self.tokenizer.encode("A short warmup prompt about the ocean."), + self.tokenizer.encode("A short warmup prompt about the forest."), + ] + prompts_b = [ + self.tokenizer.encode("Now continue with one sentence about tides."), + self.tokenizer.encode("Now continue with one sentence about moss."), + ] + + base_gen = BatchGenerator( + self.model, stop_tokens=self.tokenizer.eos_token_ids, max_tokens=1 + ) + base_uids = base_gen.insert(prompts_a) + base_caches = {uid: None for uid in base_uids} + base_tokens = {} + while responses := base_gen.next(): + for response in responses: + if response.finish_reason is not None: + base_caches[response.uid] = response.prompt_cache + base_tokens[response.uid] = response.token + caches = [base_caches[uid] for uid in base_uids] + for cache_value in caches: + self.assertIsNotNone(cache_value) + + checkpoint_batches = [] + + def checkpoint_callback(records): + checkpoint_batches.append(list(records)) + + gen = BatchGenerator( + self.model, + stop_tokens=self.tokenizer.eos_token_ids, + max_tokens=1, + prompt_checkpoint_callback=checkpoint_callback, + ) + uids = gen.insert(prompts_b, caches=caches, prompt_checkpoints=-2) + batch_responses = {} + while responses := gen.next(): + for response in responses: + batch_responses[response.uid] = response + + self.assertEqual(len(checkpoint_batches), 1) + checkpoint_records = checkpoint_batches[0] + self.assertEqual(len(checkpoint_records), len(uids)) + self.assertEqual({record[0] for record in checkpoint_records}, set(uids)) + 
self.assertEqual({record[1] for record in checkpoint_records}, {2}) + records_by_uid = {record[0]: record for record in checkpoint_records} + for uid, base_uid, prompt_a, prompt_b in zip( + uids, base_uids, prompts_a, prompts_b + ): + _uid, checkpoint_size, checkpoint_cache_iter = records_by_uid[uid] + checkpoint_cache = list(checkpoint_cache_iter) + expected = next( + generate_step( + prompt=mx.array( + prompt_a + [base_tokens[base_uid]] + prompt_b, + dtype=mx.uint32, + ), + model=self.model, + max_tokens=1, + ) + ) + resumed = next( + generate_step( + prompt=mx.array(prompt_b[-checkpoint_size:], dtype=mx.uint32), + model=self.model, + prompt_cache=checkpoint_cache, + max_tokens=1, + ) + ) + self.assertEqual(batch_responses[uid].token, expected[0]) + self.assertTrue( + mx.allclose( + batch_responses[uid].logprobs, + expected[1], + rtol=1e-4, + atol=1e-4, + ) + ) + self.assertEqual(resumed[0], expected[0]) + self.assertTrue( + mx.allclose( + resumed[1], + expected[1], + rtol=1e-4, + atol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + batch_responses[uid].logprobs, + resumed[1], + rtol=1e-4, + atol=1e-4, + ) + ) + + def test_batch_generator_merge_path_allows_mixed_none_and_longer_checkpoint_tails( + self, + ): + prompts_a = [ + self.tokenizer.encode("A short warmup prompt about the ocean."), + self.tokenizer.encode("A short warmup prompt about the forest."), + ] + prompts_b = [ + self.tokenizer.encode("Brief.")[:1], + self.tokenizer.encode("Now continue with one sentence about moss."), + ] + self.assertLess(len(prompts_b[0]), 3) + + base_gen = BatchGenerator( + self.model, stop_tokens=self.tokenizer.eos_token_ids, max_tokens=1 + ) + base_uids = base_gen.insert(prompts_a) + base_caches = {uid: None for uid in base_uids} + base_tokens = {} + while responses := base_gen.next(): + for response in responses: + if response.finish_reason is not None: + base_caches[response.uid] = response.prompt_cache + base_tokens[response.uid] = response.token + caches = [base_caches[uid] for uid in base_uids] + for cache_value in caches: + self.assertIsNotNone(cache_value) + + checkpoint_batches = [] + + def checkpoint_callback(records): + checkpoint_batches.append(list(records)) + + gen = BatchGenerator( + self.model, + stop_tokens=self.tokenizer.eos_token_ids, + max_tokens=1, + prompt_checkpoint_callback=checkpoint_callback, + ) + uids = gen.insert(prompts_b, caches=caches, prompt_checkpoints=[None, -3]) + batch_responses = {} + while responses := gen.next(): + for response in responses: + batch_responses[response.uid] = response + + self.assertEqual(set(batch_responses), set(uids)) + expected_short = next( + generate_step( + prompt=mx.array( + prompts_a[0] + [base_tokens[base_uids[0]]] + prompts_b[0], + dtype=mx.uint32, + ), + model=self.model, + max_tokens=1, + ) + ) + self.assertEqual(batch_responses[uids[0]].token, expected_short[0]) + self.assertTrue( + mx.allclose( + batch_responses[uids[0]].logprobs, + expected_short[1], + rtol=1e-4, + atol=1e-4, + ) + ) + self.assertEqual(len(checkpoint_batches), 1) + self.assertEqual(len(checkpoint_batches[0]), 1) + checkpoint_uid, checkpoint_size, checkpoint_cache_iter = checkpoint_batches[0][ + 0 + ] + self.assertEqual(checkpoint_uid, uids[1]) + self.assertEqual(checkpoint_size, 3) + + checkpoint_cache = list(checkpoint_cache_iter) + expected = next( + generate_step( + prompt=mx.array( + prompts_a[1] + [base_tokens[base_uids[1]]] + prompts_b[1], + dtype=mx.uint32, + ), + model=self.model, + max_tokens=1, + ) + ) + resumed = next( + generate_step( + 
prompt=mx.array(prompts_b[1][-checkpoint_size:], dtype=mx.uint32), + model=self.model, + prompt_cache=checkpoint_cache, + max_tokens=1, + ) + ) + self.assertEqual(batch_responses[uids[1]].token, expected[0]) + self.assertEqual(resumed[0], expected[0]) + self.assertTrue( + mx.allclose( + batch_responses[uids[1]].logprobs, + expected[1], + rtol=1e-4, + atol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + resumed[1], + expected[1], + rtol=1e-4, + atol=1e-4, + ) + ) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_prompt_cache.py b/tests/test_prompt_cache.py index 05dcd7dc4..02b0db0cb 100644 --- a/tests/test_prompt_cache.py +++ b/tests/test_prompt_cache.py @@ -1,6 +1,7 @@ # Copyright © 2024 Apple Inc. import copy +import gc import os import tempfile import unittest @@ -28,6 +29,71 @@ HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" +class TestBatchRotatingKVCacheState(unittest.TestCase): + + @staticmethod + def _run_batch_rotating_memory_loop(*, step_size: int, force_padding_eval: bool): + gc.collect() + mx.clear_cache() + + cache = BatchRotatingKVCache(max_size=4, left_padding=[2, 0]) + prompt_kv = mx.zeros((2, 1, 4, 8)) + prompt_k, prompt_v = cache.update_and_fetch(prompt_kv, prompt_kv) + mx.eval(prompt_k, prompt_v) + + step_kv = mx.zeros((2, 1, step_size, 8)) + mx.reset_peak_memory() + for _ in range(120): + keys, values = cache.update_and_fetch(step_kv, step_kv) + output = keys.sum() + values.sum() + if force_padding_eval: + mx.eval(output, cache.left_padding, cache.offset) + else: + mx.eval(output) + + return mx.get_peak_memory(), mx.get_active_memory() + + def test_update_eval_preserves_pre_decode_mask_snapshot(self): + cache = BatchRotatingKVCache(max_size=4, left_padding=[2, 0]) + + prompt_kv = mx.zeros((2, 1, 4, 8)) + prompt_k, prompt_v = cache.update_and_fetch(prompt_kv, prompt_kv) + mx.eval(prompt_k, prompt_v) + + pre_decode_mask = cache.make_mask(1) + + decode_kv = mx.zeros((2, 1, 1, 8)) + decode_k, decode_v = cache.update_and_fetch(decode_kv, decode_kv) + mx.eval(decode_k, decode_v) + + self.assertEqual(pre_decode_mask.shape, (2, 1, 1, 4)) + self.assertEqual(pre_decode_mask[0, 0, 0].tolist(), [True, False, True, True]) + self.assertEqual(pre_decode_mask[1, 0, 0].tolist(), [True, True, True, True]) + + post_decode_mask = cache.make_mask(1) + mx.eval(post_decode_mask) + self.assertEqual(post_decode_mask.shape, (2, 1, 1, 4)) + self.assertEqual(post_decode_mask[0, 0, 0].tolist(), [True, True, True, True]) + self.assertEqual(post_decode_mask[1, 0, 0].tolist(), [True, True, True, True]) + + def test_returned_tensor_eval_keeps_batch_rotating_memory_close_to_explicit_padding_eval( + self, + ): + for path_name, step_size in (("decode", 1), ("prompt", 4)): + with self.subTest(path=path_name): + baseline_peak, baseline_active = self._run_batch_rotating_memory_loop( + step_size=step_size, + force_padding_eval=True, + ) + trial_peak, trial_active = self._run_batch_rotating_memory_loop( + step_size=step_size, + force_padding_eval=False, + ) + + self.assertLessEqual(trial_peak, int(baseline_peak * 2.0)) + self.assertLessEqual(trial_active, int(baseline_active * 2.0)) + + class TestPromptCache(unittest.TestCase): @classmethod diff --git a/tests/test_prompt_cache_server_behavior.py b/tests/test_prompt_cache_server_behavior.py new file mode 100644 index 000000000..9d8f8db31 --- /dev/null +++ b/tests/test_prompt_cache_server_behavior.py @@ -0,0 +1,299 @@ +# Copyright © 2024 Apple Inc. 
+ +import unittest + +from mlx_lm.server import LRUPromptCache +from tests.prompt_cache_test_utils import RewindRecorderLayer + + +class MockCache: + def __init__(self, value): + self.value = value + + @property + def nbytes(self): + return len(self.value) + + def __eq__(self, other): + return other.value == self.value + + +class TestLRUPromptCacheBehavior(unittest.TestCase): + def test_regular_refcounted_hit_refreshes_regular_lru_recency(self): + cache = LRUPromptCache(max_size=2) + model = ("regular-hit-refresh", None, None) + + cache.insert_cache(model, [1], [MockCache("test1")]) + cache.insert_cache(model, [1], [MockCache("test1")]) + cache.insert_cache(model, [2], [MockCache("test2")]) + + c, t = cache.fetch_nearest_cache(model, [1]) + self.assertEqual(c, [MockCache("test1")]) + self.assertEqual(t, []) + + cache.insert_cache(model, [3], [MockCache("test3")]) + + c, t = cache.fetch_nearest_cache(model, [1]) + self.assertEqual(c, [MockCache("test1")]) + self.assertEqual(t, []) + c, t = cache.fetch_nearest_cache(model, [2]) + self.assertIsNone(c) + self.assertEqual(t, [2]) + c, t = cache.fetch_nearest_cache(model, [3]) + self.assertEqual(c, [MockCache("test3")]) + self.assertEqual(t, []) + + def test_checkpoint_hit_refreshes_checkpoint_lru_recency(self): + cache = LRUPromptCache(max_size=2) + model = ("checkpoint-hit-refresh", None, None) + + cache.insert_cache(model, [1], [MockCache("test1")], checkpoint=True) + cache.insert_cache(model, [2], [MockCache("test2")], checkpoint=True) + + c, t = cache.fetch_nearest_cache(model, [1, 99]) + self.assertEqual(c, [MockCache("test1")]) + self.assertEqual(t, [99]) + + cache.insert_cache(model, [3], [MockCache("test3")], checkpoint=True) + + c, t = cache.fetch_nearest_cache(model, [1, 98]) + self.assertEqual(c, [MockCache("test1")]) + self.assertEqual(t, [98]) + c, t = cache.fetch_nearest_cache(model, [2, 77]) + self.assertIsNone(c) + self.assertEqual(t, [2, 77]) + c, t = cache.fetch_nearest_cache(model, [3, 55]) + self.assertEqual(c, [MockCache("test3")]) + self.assertEqual(t, [55]) + + def test_farther_rewindable_prefix_outranks_nearer_non_rewindable_longer_prefix( + self, + ): + lru = LRUPromptCache(max_size=10) + model = ("farther-rewindable-prefix", None, None) + requested_tokens = [1, 2, 5] + + shorter_cache = RewindRecorderLayer(max_rewind=0, offset=1) + nearer_longer_cache = RewindRecorderLayer(max_rewind=0, offset=3) + farther_longer_cache = RewindRecorderLayer(max_rewind=2, offset=4) + lru.insert_cache(model, [1], [shorter_cache]) + lru.insert_cache(model, [1, 2, 9], [nearer_longer_cache]) + lru.insert_cache(model, [1, 2, 3, 4], [farther_longer_cache]) + + reused_cache, remaining = lru.fetch_nearest_cache(model, requested_tokens) + self.assertIsNotNone(reused_cache) + self.assertEqual(remaining, [5]) + self.assertEqual(reused_cache[0].offset, 2) + self.assertEqual(farther_longer_cache.can_rewind_calls, [2]) + self.assertEqual(farther_longer_cache.rewind_calls, [2]) + + def test_longer_rewindable_prefix_outranks_shorter_stored_prefix(self): + lru = LRUPromptCache(max_size=10) + model = ("longer-rewindable-prefix", None, None) + shorter_tokens = [1] + longer_tokens = [1, 2, 9] + requested_tokens = [1, 2, 5] + + shorter_cache = RewindRecorderLayer(max_rewind=0, offset=1) + longer_cache = RewindRecorderLayer(max_rewind=1, offset=3) + lru.insert_cache(model, shorter_tokens, [shorter_cache]) + lru.insert_cache(model, longer_tokens, [longer_cache]) + + reused_cache, remaining = lru.fetch_nearest_cache(model, requested_tokens) + 
self.assertIsNotNone(reused_cache) + self.assertEqual(remaining, [5]) + self.assertEqual(reused_cache[0].offset, 2) + self.assertEqual(longer_cache.can_rewind_calls, [1]) + self.assertEqual(longer_cache.rewind_calls, [1]) + + exact_cache, exact_remaining = lru.fetch_nearest_cache(model, longer_tokens) + self.assertIsNotNone(exact_cache) + self.assertEqual(exact_remaining, []) + self.assertEqual(exact_cache[0].offset, 3) + + def test_longer_path_reuse_refreshes_recency_for_regular_and_checkpoint_entries( + self, + ): + for checkpoint in (False, True): + with self.subTest(checkpoint=checkpoint): + lru = LRUPromptCache(max_size=2) + model = ("longer-path-refresh", checkpoint, None) + reused_tokens = [1, 2, 9] + sibling_tokens = [1, 3, 9] + fresh_tokens = [4, 5, 6] + + reused_entry = RewindRecorderLayer(max_rewind=1, offset=3) + sibling_entry = RewindRecorderLayer(max_rewind=1, offset=3) + lru.insert_cache( + model, reused_tokens, [reused_entry], checkpoint=checkpoint + ) + lru.insert_cache( + model, sibling_tokens, [sibling_entry], checkpoint=checkpoint + ) + + reused_cache, remaining = lru.fetch_nearest_cache(model, [1, 2, 5]) + self.assertIsNotNone(reused_cache) + self.assertEqual(remaining, [5]) + self.assertEqual(reused_cache[0].offset, 2) + + lru.insert_cache( + model, + fresh_tokens, + [MockCache("fresh")], + checkpoint=checkpoint, + ) + + cache_entry, rest = lru.fetch_nearest_cache(model, reused_tokens) + self.assertIsNotNone(cache_entry) + self.assertEqual(rest, []) + cache_entry, rest = lru.fetch_nearest_cache(model, sibling_tokens) + self.assertIsNone(cache_entry) + self.assertEqual(rest, sibling_tokens) + + # -- Extraction and ref_count semantics -- + + def test_checkpoint_extract_persists_through_multiple_fetches(self): + """Checkpoint entries are persistent: extraction always deepcopies and + never removes the entry.""" + cache = LRUPromptCache(max_size=10) + model = ("checkpoint-persist", None, None) + + cache.insert_cache(model, [1, 2], [MockCache("ckpt")], checkpoint=True) + + c1, t1 = cache.fetch_nearest_cache(model, [1, 2]) + self.assertIsNotNone(c1) + self.assertEqual(t1, []) + + c2, t2 = cache.fetch_nearest_cache(model, [1, 2]) + self.assertIsNotNone(c2) + self.assertEqual(t2, []) + + # Entry is still alive. + result = cache._search(model, [1, 2]) + self.assertIsNotNone(result.exact) + self.assertEqual(len(cache), 1) + + def test_regular_entry_promoted_to_checkpoint_becomes_persistent(self): + """A regular entry promoted to checkpoint via a subsequent checkpoint + insert becomes persistent (extract no longer consumes it).""" + cache = LRUPromptCache(max_size=10) + model = ("promote-checkpoint", None, None) + + cache.insert_cache(model, [1, 2], [MockCache("reg")]) + cache.insert_cache(model, [1, 2], [MockCache("reg")], checkpoint=True) + + c1, _ = cache.fetch_nearest_cache(model, [1, 2]) + self.assertIsNotNone(c1) + c2, _ = cache.fetch_nearest_cache(model, [1, 2]) + self.assertIsNotNone(c2) + self.assertEqual(len(cache), 1) + + def test_insert_existing_key_keeps_original_cache(self): + """Re-inserting the same token key increments ref_count but keeps the + original prompt_cache list (the new one is silently dropped).""" + cache = LRUPromptCache(max_size=10) + model = ("reinsert-keeps-original", None, None) + + cache.insert_cache(model, [1, 2], [MockCache("original")]) + cache.insert_cache(model, [1, 2], [MockCache("different")]) + + # First extract: deepcopy (ref_count 2 → 1), returns original value. 
+ c, t = cache.fetch_nearest_cache(model, [1, 2]) + self.assertEqual(t, []) + self.assertEqual(c, [MockCache("original")]) + + # Second extract: ref_count==1 ownership transfer, still original. + c2, t2 = cache.fetch_nearest_cache(model, [1, 2]) + self.assertEqual(t2, []) + self.assertEqual(c2, [MockCache("original")]) + + # Now fully consumed. + c3, t3 = cache.fetch_nearest_cache(model, [1, 2]) + self.assertIsNone(c3) + self.assertEqual(t3, [1, 2]) + + def test_deepcopy_failure_on_refcounted_entry_does_not_decrement_ref_count(self): + """When deepcopy fails on a refcounted (ref_count > 1) non-checkpoint entry, + _extract returns None without decrementing the ref_count.""" + + class FailDeepCopy: + @property + def nbytes(self): + return 1 + + def __deepcopy__(self, memo): + raise RuntimeError("deepcopy fails") + + cache = LRUPromptCache(max_size=10) + model = ("deepcopy-fail-refcount", None, None) + + cache.insert_cache(model, [1, 2], [FailDeepCopy()]) + cache.insert_cache(model, [1, 2], [FailDeepCopy()]) # ref_count -> 2 + + c, t = cache.fetch_nearest_cache(model, [1, 2]) + self.assertIsNone(c) + self.assertEqual(t, [1, 2]) + + # Entry still alive — ref_count was not decremented. + result = cache._search(model, [1, 2]) + self.assertIsNotNone(result.exact) + self.assertEqual(len(cache), 1) + + # -- Rewind safety -- + + def test_partial_rewind_on_longer_hit_discards_copy_preserves_original(self): + """When rewind succeeds on some layers but fails on others in a longer + cache hit, the corrupted deepcopy is discarded and the original entry + is unmodified.""" + cache = LRUPromptCache(max_size=10) + model = ("partial-rewind-discard", None, None) + + good_layer = RewindRecorderLayer(max_rewind=2, offset=4) + bad_layer = RewindRecorderLayer(max_rewind=2, offset=4, rewind_result=False) + cache.insert_cache(model, [1, 2, 3, 4], [good_layer, bad_layer]) + + # Request [1, 2, 5] — longer candidate needs to rewind 2. + c, t = cache.fetch_nearest_cache(model, [1, 2, 5]) + + # Falls through to no match. + self.assertIsNone(c) + self.assertEqual(t, [1, 2, 5]) + + # Original entry intact. 
+ result = cache._search(model, [1, 2, 3, 4]) + self.assertIsNotNone(result.exact) + original = cache._get(model, [1, 2, 3, 4]) + self.assertEqual(original.prompt_cache[0].offset, 4) + self.assertEqual(original.prompt_cache[1].offset, 4) + + # -- Search behavior changes from upstream -- + + def test_single_token_shorter_match_is_valid(self): + """A single-token prefix cache is returned as a valid shorter match.""" + cache = LRUPromptCache(max_size=10) + model = ("single-token-shorter", None, None) + + cache.insert_cache(model, [1], [MockCache("one")]) + + c, t = cache.fetch_nearest_cache(model, [1, 2, 3]) + self.assertEqual(c, [MockCache("one")]) + self.assertEqual(t, [2, 3]) + + def test_shorter_prefix_not_evicted_by_longer_insert(self): + """Inserting a longer token sequence does not evict a shorter prefix + entry.""" + cache = LRUPromptCache(max_size=10) + model = ("no-prefix-eviction", None, None) + + cache.insert_cache(model, [1, 2], [MockCache("short")]) + cache.insert_cache(model, [1, 2, 3, 4], [MockCache("long")]) + + self.assertEqual(len(cache), 2) + + c, t = cache.fetch_nearest_cache(model, [1, 2]) + self.assertEqual(c, [MockCache("short")]) + self.assertEqual(t, []) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prompt_cache_server_rewind_internal.py b/tests/test_prompt_cache_server_rewind_internal.py new file mode 100644 index 000000000..8d71f38d1 --- /dev/null +++ b/tests/test_prompt_cache_server_rewind_internal.py @@ -0,0 +1,185 @@ +# Copyright © 2024 Apple Inc. + +import unittest + +import mlx.core as mx + +from mlx_lm.models.cache import BatchRotatingKVCache, _BaseCache +from mlx_lm.server import LRUPromptCache + + +class TestLRUPromptCacheRewindInternals(unittest.TestCase): + def test_batch_rotating_rewind_after_rotation_restores_pre_step_behavior(self): + prompt_kv = mx.zeros((2, 1, 4, 1), dtype=mx.float32) + decode_kv = mx.array([[[[11.0]]], [[[22.0]]]], dtype=mx.float32) + + cache = BatchRotatingKVCache(max_size=4, left_padding=[2, 0]) + cache.update_and_fetch(prompt_kv, prompt_kv) + cache.update_and_fetch(decode_kv, decode_kv) + mx.eval(cache.keys, cache.values) + + self.assertTrue(cache.rotated) + self.assertEqual(cache._offset, 5) + self.assertEqual(cache._idx, 1) + + pre_rewind_offsets = mx.array(cache.offset) + self.assertTrue(cache.rewind(1)) + self.assertEqual(cache._offset, 4) + self.assertEqual(cache._idx, 0) + self.assertTrue(mx.array_equal(cache.offset, pre_rewind_offsets - 1)) + self.assertFalse(cache.can_rewind(1)) + + rewind_mask = cache.make_mask(1) + mx.eval(rewind_mask) + self.assertEqual(rewind_mask.shape, (2, 1, 1, 4)) + self.assertEqual(rewind_mask[0, 0, 0].tolist(), [True, False, True, True]) + self.assertEqual(rewind_mask[1, 0, 0].tolist(), [True, True, True, True]) + + +class TestHasRewindImpl(unittest.TestCase): + def test_base_cache_has_no_rewind_impl(self): + """_BaseCache itself should not report a rewind implementation.""" + base = _BaseCache.__new__(_BaseCache) + self.assertFalse(base._has_rewind_impl()) + + def test_subclass_without_rewind_override_has_no_impl(self): + """A _BaseCache subclass that does not override rewind() should not + report a rewind implementation.""" + + class NoRewind(_BaseCache): + pass + + cache = NoRewind.__new__(NoRewind) + self.assertFalse(cache._has_rewind_impl()) + + def test_subclass_with_rewind_override_has_impl(self): + """A _BaseCache subclass that overrides rewind() should be recognized + as having a rewind implementation — this is the contract that lets + third-party caches 
participate without an explicit opt-in flag.""" + + class CustomRewind(_BaseCache): + def rewind(self, num_to_trim): + return True + + cache = CustomRewind.__new__(CustomRewind) + self.assertTrue(cache._has_rewind_impl()) + + def test_batch_rotating_has_rewind_impl(self): + """BatchRotatingKVCache should be recognized as having rewind.""" + cache = BatchRotatingKVCache(max_size=4, left_padding=[0]) + self.assertTrue(cache._has_rewind_impl()) + + +class _LegacyTrimCache: + """Non-_BaseCache cache with only the legacy is_trimmable/trim contract. + No can_rewind, no rewind. Exercises the legacy fallback in both + _can_rewind_layer_cache and _rewind_layer_cache.""" + + def __init__(self, offset=10): + self._offset = offset + + @property + def offset(self): + return self._offset + + def is_trimmable(self): + return True + + def trim(self, num_to_trim): + if num_to_trim > self._offset: + return 0 + self._offset -= num_to_trim + return num_to_trim + + +class _BaseCacheWithTrim(_BaseCache): + """_BaseCache subclass with is_trimmable/trim but no rewind override. + This is the shape of KVCache / RotatingKVCache — inherits the base-class + rewind() stub. The rewind path must NOT be attempted; the legacy trim + fallback must be used instead.""" + + def __init__(self, offset=10): + self._offset = offset + + @property + def offset(self): + return self._offset + + def is_trimmable(self): + return True + + def trim(self, num_to_trim): + if num_to_trim > self._offset: + return 0 + self._offset -= num_to_trim + return num_to_trim + + @property + def state(self): + raise NotImplementedError + + def is_empty(self): + return self._offset == 0 + + +class TestCanRewindAndRewindAgreement(unittest.TestCase): + """Verify that _can_rewind_layer_cache and _rewind_layer_cache agree: + if _can_rewind says yes, _rewind must succeed (not waste a deepcopy).""" + + def setUp(self): + self.lru = LRUPromptCache(max_size=10) + + def test_legacy_non_basecache_can_rewind(self): + """Legacy cache (no can_rewind) should be rewindable via the + is_trimmable/trim fallback.""" + cache = _LegacyTrimCache(offset=10) + self.assertTrue(self.lru._can_rewind_layer_cache(cache, 3)) + + def test_legacy_non_basecache_rewind_succeeds(self): + """If _can_rewind says yes, _rewind must actually succeed.""" + cache = _LegacyTrimCache(offset=10) + can = self.lru._can_rewind_layer_cache(cache, 3) + did = self.lru._rewind_layer_cache(cache, 3) + self.assertTrue(can) + self.assertTrue(did) + + def test_basecache_with_trim_no_rewind_override_can_rewind(self): + """_BaseCache subclass with trim but no rewind override should still + be rewindable via legacy fallback — not via the base-class stub.""" + cache = _BaseCacheWithTrim(offset=10) + self.assertTrue(self.lru._can_rewind_layer_cache(cache, 3)) + + def test_basecache_with_trim_no_rewind_override_rewind_succeeds(self): + """The critical regression: _rewind must use the trim fallback, not + call the base-class rewind() stub that raises NotImplementedError.""" + cache = _BaseCacheWithTrim(offset=10) + can = self.lru._can_rewind_layer_cache(cache, 3) + did = self.lru._rewind_layer_cache(cache, 3) + self.assertTrue(can) + self.assertTrue(did) + self.assertEqual(cache.offset, 7) + + def test_basecache_with_trim_rewind_beyond_offset_fails(self): + """Rewinding more than available offset should fail gracefully.""" + cache = _BaseCacheWithTrim(offset=2) + self.assertFalse(self.lru._can_rewind_layer_cache(cache, 5)) + self.assertFalse(self.lru._rewind_layer_cache(cache, 5)) + + def 
test_can_rewind_and_rewind_agree_for_batch_rotating(self): + """BatchRotatingKVCache uses the real rewind path — sanity check + that the agreement holds here too.""" + kv = mx.zeros((1, 1, 4, 1), dtype=mx.float32) + decode = mx.array([[[[1.0]]]], dtype=mx.float32) + cache = BatchRotatingKVCache(max_size=4, left_padding=[0]) + cache.update_and_fetch(kv, kv) + cache.update_and_fetch(decode, decode) + mx.eval(cache.keys, cache.values) + + can = self.lru._can_rewind_layer_cache(cache, 1) + did = self.lru._rewind_layer_cache(cache, 1) + self.assertTrue(can) + self.assertTrue(did) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_server.py b/tests/test_server.py index c5a815e4f..1fd5b9ae3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,12 +5,24 @@ import json import threading import unittest +from queue import Empty, Queue +from unittest.mock import patch import mlx.core as mx import requests +from mlx_lm.generate import BatchGenerator from mlx_lm.models.cache import KVCache -from mlx_lm.server import APIHandler, LRUPromptCache, ResponseGenerator +from mlx_lm.server import ( + APIHandler, + CompletionRequest, + GenerationArguments, + LogitsProcessorArguments, + LRUPromptCache, + ModelDescription, + ResponseGenerator, + SamplingArguments, +) from mlx_lm.utils import load @@ -446,15 +458,14 @@ def get_kv(n): c[0].update_and_fetch(*get_kv(24)) cache.insert_cache(model, t, c) - # Fetching a cache that is strictly a prefix doesn't remove it from the - # lru cache + # Fetching a strict shorter-prefix hit consumes the only stored entry. tokens = tokens + [20] * 5 c, t = cache.fetch_nearest_cache(model, tokens) k, v = c[0].state self.assertTrue((k == v).all().item()) self.assertTrue((k.flatten() == mx.arange(24)).all().item()) self.assertEqual(t, [20] * 5) - self.assertEqual(len(cache), 1) + self.assertEqual(len(cache), 0) # Inserting a trimmable cache with shared prefix removes the prefixes tokens = tokens + [30] * 3 @@ -482,23 +493,27 @@ def test_lru(self): cache = LRUPromptCache(max_size=2) model = ("test", None, None) cache.insert_cache(model, [1, 2], [MockCache("test1")]) - cache.insert_cache(model, [2, 3], [MockCache("test2")]) + cache.insert_cache(model, [1, 2], [MockCache("test1")]) c, t = cache.fetch_nearest_cache(model, [1, 2]) self.assertEqual(c, [MockCache("test1")]) self.assertEqual(t, []) - c, t = cache.fetch_nearest_cache(model, [1]) + c, t = cache.fetch_nearest_cache(model, [1, 2]) self.assertEqual(c, [MockCache("test1")]) - self.assertEqual(t, [1]) - c, t = cache.fetch_nearest_cache(model, [1, 3, 4]) + self.assertEqual(t, []) + c, t = cache.fetch_nearest_cache(model, [1, 2]) + self.assertIsNone(c) + self.assertEqual(t, [1, 2]) + + cache.insert_cache(model, [1, 2], [MockCache("test1")]) + cache.insert_cache(model, [2, 3], [MockCache("test2")]) + + c, t = cache.fetch_nearest_cache(model, [1, 2]) self.assertEqual(c, [MockCache("test1")]) - self.assertEqual(t, [3, 4]) + self.assertEqual(t, []) c, t = cache.fetch_nearest_cache(model, [2, 3, 4]) self.assertEqual(c, [MockCache("test2")]) self.assertEqual(t, [4]) - c, t = cache.fetch_nearest_cache(model, [2, 4, 5]) - self.assertEqual(c, [MockCache("test2")]) - self.assertEqual(t, [4, 5]) cache.insert_cache(model, [1, 2], [MockCache("test1")]) cache.insert_cache(model, [2, 3], [MockCache("test2")]) @@ -519,8 +534,8 @@ def test_lru(self): self.assertEqual(c, None) self.assertEqual(t, [2, 3]) c, t = cache.fetch_nearest_cache(model, [3, 4]) - self.assertEqual(c, [MockCache("test3")]) - self.assertEqual(t, 
[]) + self.assertEqual(c, None) + self.assertEqual(t, [3, 4]) c, t = cache.fetch_nearest_cache(model, [4, 5]) self.assertEqual(c, [MockCache("test4")]) self.assertEqual(t, []) @@ -561,5 +576,760 @@ def test_lru_bytes(self): self.assertEqual(t, [3, 4]) +class TestResponseGeneratorBatchPromptCheckpoints(unittest.TestCase): + @staticmethod + def _generation_args(): + return GenerationArguments( + model=ModelDescription("default_model", None, None), + sampling=SamplingArguments(0.0, 1.0, 0, 0.0, 0.0, 0.0), + logits=LogitsProcessorArguments(None, 1.0, 20, 0.0, 20, 0.0, 20), + stop_words=[], + max_tokens=2, + num_draft_tokens=3, + logprobs=False, + top_logprobs=0, + seed=None, + chat_template_kwargs=None, + ) + + @staticmethod + def _make_text_request(prompt="hello"): + return CompletionRequest( + request_type="text", + prompt=prompt, + messages=[], + tools=None, + role_mapping=None, + ) + + def _build_response_generator(self): + class FakeModel: + def make_cache(self): + return [KVCache()] + + class FakeTokenizer: + has_tool_calling = False + tool_call_start = "" + tool_call_end = "" + tool_parser = staticmethod(lambda text, _: {}) + detokenizer = None + has_thinking = False + think_start_id = 0 + think_end_id = 0 + think_end = "" + eos_token_id = 0 + eos_token_ids = set() + + def encode(self, text, add_special_tokens=False): + return [1, 2, 3] + + class FakeProvider: + is_batchable = True + + def __init__(self): + self.cli_args = type( + "obj", + (object,), + { + "decode_concurrency": 4, + "prompt_concurrency": 2, + "prefill_step_size": 77, + "prompt_cache_bytes": None, + }, + ) + self.model = FakeModel() + self.tokenizer = FakeTokenizer() + self.draft_model = None + self.model_key = ("fake-model", None, None) + + def load(self, model, adapter=None, draft_model=None): + return self.model, self.tokenizer + + generator = ResponseGenerator.__new__(ResponseGenerator) + generator.model_provider = FakeProvider() + generator.prompt_cache = LRUPromptCache(max_size=10) + generator.requests = Queue() + generator._is_distributed = False + generator._rank = 0 + generator._stop = False + generator._time_budget = [] + return generator + + def _run_batch_checkpoint_probe( + self, + *, + request, + tokenized_prompt, + callback_prompt_end=None, + has_thinking=False, + think_start_id=0, + seeded_entries=None, + ): + generator = self._build_response_generator() + generator._time_budget = [None] + generator.model_provider.tokenizer.detokenizer = type( + "FakeDetokenizer", + (), + { + "last_segment": "", + "add_token": lambda self, token: None, + }, + )() + generator.model_provider.tokenizer.has_thinking = has_thinking + generator.model_provider.tokenizer.think_start_id = think_start_id + for entry in seeded_entries or []: + if len(entry) == 3: + tokens, prompt_cache, checkpoint = entry + else: + tokens, prompt_cache = entry + checkpoint = False + generator.prompt_cache.insert_cache( + generator.model_provider.model_key, + tokens, + prompt_cache, + checkpoint=checkpoint, + ) + + request_queue = Queue() + request_args = self._generation_args() + request_seen = False + captured = {"request_queue": request_queue} + + def next_request(timeout=None): + nonlocal request_seen + if request_seen: + return None + request_seen = True + return (request_queue, request, request_args) + + class FakeBatchGenerator: + prompt_cache_nbytes = 0 + + def __init__(self, *args, **kwargs): + captured["constructor_kwargs"] = kwargs + captured["prompt_checkpoint_callback"] = kwargs.get( + "prompt_checkpoint_callback" + ) + + def insert( + 
self, + prompts, + max_tokens=None, + caches=None, + samplers=None, + logits_processors=None, + prompt_checkpoints=None, + ): + captured["insert_prompts"] = prompts + captured["insert_caches"] = caches + captured["insert_prompt_checkpoints"] = prompt_checkpoints + prompt_checkpoint = None + if prompt_checkpoints is not None: + prompt_checkpoint = prompt_checkpoints[0] + if callback_prompt_end is not None: + captured["effective_prompt_end"] = callback_prompt_end + elif prompt_checkpoint is None: + captured["effective_prompt_end"] = 1 + elif prompt_checkpoint > 0: + captured["effective_prompt_end"] = max( + 1, len(prompts[0]) - prompt_checkpoint + ) + else: + captured["effective_prompt_end"] = -prompt_checkpoint + return [123] + + def next(self): + checkpoint_callback = captured.get("prompt_checkpoint_callback") + if ( + checkpoint_callback is not None + and captured.get("insert_prompt_checkpoints") is not None + ): + checkpoint_callback( + [ + ( + 123, + captured["effective_prompt_end"], + iter([MockCache("checkpoint")]), + ) + ] + ) + generator._stop = True + return [ + type( + "BatchResponse", + (), + { + "uid": 123, + "token": 0, + "logprobs": mx.array([0.0], dtype=mx.float32), + "finish_reason": "stop", + "prompt_cache": [MockCache("final")], + }, + )() + ] + + generator._next_request = next_request + with patch.object(generator, "_tokenize", return_value=tokenized_prompt): + with patch("mlx_lm.server.BatchGenerator", FakeBatchGenerator): + generator._generate() + + return generator, captured + + def _run_malformed_then_valid_batch_probe(self, malformed_request): + generator = self._build_response_generator() + generator._time_budget = [None] + generator.model_provider.tokenizer.detokenizer = type( + "FakeDetokenizer", + (), + { + "last_segment": "", + "add_token": lambda self, token: None, + }, + )() + malformed_queue = Queue() + valid_queue = Queue() + request_args = self._generation_args() + valid_request = CompletionRequest( + request_type="chat", + prompt="", + messages=[{"role": "user", "content": "hello"}], + tools=None, + role_mapping=None, + ) + requests = [ + (malformed_queue, malformed_request, request_args), + (valid_queue, valid_request, request_args), + ] + captured = {"insert_prompts": []} + + def next_request(timeout=None): + if requests: + return requests.pop(0) + generator._stop = True + return None + + class FakeBatchGenerator: + prompt_cache_nbytes = 0 + + def __init__(self, *args, **kwargs): + self._done = False + + def insert( + self, + prompts, + max_tokens=None, + caches=None, + samplers=None, + logits_processors=None, + prompt_checkpoints=None, + ): + captured["insert_prompts"].append(prompts) + return [123] + + def next(self): + if self._done: + return [] + self._done = True + return [ + type( + "BatchResponse", + (), + { + "uid": 123, + "token": 0, + "logprobs": mx.array([0.0], dtype=mx.float32), + "finish_reason": "stop", + "prompt_cache": [MockCache("final")], + }, + )() + ] + + generator._next_request = next_request + with patch.object( + generator, "_tokenize", side_effect=[[11, 12, 13, 14], [11, 12, 13, 14]] + ): + with patch("mlx_lm.server.BatchGenerator", FakeBatchGenerator): + generator._generate() + + return malformed_queue, valid_queue, captured + + def test_generate_batch_mode_forwards_checkpoint_callback_and_prompt_checkpoints( + self, + ): + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[{"role": "user", "content": "hello"}], + tools=None, + role_mapping=None, + ) + generator, captured = 
self._run_batch_checkpoint_probe( + request=request, + tokenized_prompt=[11, 12, 13, 14], + has_thinking=True, + think_start_id=99, + ) + + self.assertIn("prompt_checkpoint_callback", captured["constructor_kwargs"]) + self.assertEqual(captured["insert_prompt_checkpoints"], [-1]) + self.assertEqual(len(generator.prompt_cache), 2) + + checkpoint_cache, rest = generator.prompt_cache.fetch_nearest_cache( + generator.model_provider.model_key, + [11, 12, 13], + ) + self.assertEqual(rest, []) + self.assertEqual([cache.value for cache in checkpoint_cache], ["checkpoint"]) + + def test_generate_batch_mode_non_thinking_model_stores_checkpoint(self): + """Non-thinking models still save a checkpoint at -1 (last token). + + This is important for models with non-trimmable caches (ArraysCache) + where the completion entry can't be rewound, but a checkpoint entry + at the prompt boundary enables reuse via the shorter-cache path. + """ + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[{"role": "user", "content": "hello"}], + tools=None, + role_mapping=None, + ) + generator, captured = self._run_batch_checkpoint_probe( + request=request, + tokenized_prompt=[11, 12, 13, 14], + ) + + self.assertIn("prompt_checkpoint_callback", captured["constructor_kwargs"]) + self.assertEqual(captured["insert_prompt_checkpoints"], [-1]) + self.assertEqual(len(generator.prompt_cache), 2) + + checkpoint_cache, rest = generator.prompt_cache.fetch_nearest_cache( + generator.model_provider.model_key, + [11, 12, 13], + ) + self.assertEqual(rest, []) + self.assertEqual([cache.value for cache in checkpoint_cache], ["checkpoint"]) + + def test_generate_batch_mode_does_not_store_checkpoint_for_non_user_terminal_chat( + self, + ): + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ], + tools=None, + role_mapping=None, + ) + generator, captured = self._run_batch_checkpoint_probe( + request=request, + tokenized_prompt=[11, 12, 13, 14], + ) + + self.assertIn("prompt_checkpoint_callback", captured["constructor_kwargs"]) + self.assertEqual(captured["insert_prompt_checkpoints"], [None]) + self.assertEqual(len(generator.prompt_cache), 1) + + self.assertIsNone( + generator.prompt_cache._search( + generator.model_provider.model_key, [11, 12, 13] + ).exact + ) + + def test_generate_batch_mode_uses_think_start_checkpoint_offset(self): + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[{"role": "user", "content": "hello"}], + tools=None, + role_mapping=None, + ) + generator, captured = self._run_batch_checkpoint_probe( + request=request, + tokenized_prompt=[11, 12, 99, 13, 14], + callback_prompt_end=4, + has_thinking=True, + think_start_id=99, + ) + + self.assertIn("prompt_checkpoint_callback", captured["constructor_kwargs"]) + self.assertEqual(captured["insert_prompt_checkpoints"], [-4]) + self.assertEqual(len(generator.prompt_cache), 2) + + checkpoint_cache, rest = generator.prompt_cache.fetch_nearest_cache( + generator.model_provider.model_key, + [11], + ) + self.assertEqual(rest, []) + self.assertEqual([cache.value for cache in checkpoint_cache], ["checkpoint"]) + + def test_generate_batch_mode_does_not_store_empty_key_checkpoint_entry(self): + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[{"role": "user", "content": "hello"}], + tools=None, + role_mapping=None, + ) + generator, captured = self._run_batch_checkpoint_probe( + request=request, + 
tokenized_prompt=[11, 12, 13, 14], + callback_prompt_end=4, + has_thinking=True, + think_start_id=99, + ) + + self.assertIn("prompt_checkpoint_callback", captured["constructor_kwargs"]) + self.assertEqual(captured["insert_prompt_checkpoints"], [-1]) + self.assertEqual(len(generator.prompt_cache), 1) + + root_cache, root_rest = generator.prompt_cache.fetch_nearest_cache( + generator.model_provider.model_key, [] + ) + self.assertIsNone(root_cache) + self.assertEqual(root_rest, []) + + final_cache, final_rest = generator.prompt_cache.fetch_nearest_cache( + generator.model_provider.model_key, + [11, 12, 13, 14, 0], + ) + self.assertEqual(final_rest, []) + self.assertEqual(final_cache, [MockCache("final")]) + + def test_generate_batch_mode_does_not_store_checkpoint_for_text_requests(self): + request = self._make_text_request(prompt="hello world") + generator, captured = self._run_batch_checkpoint_probe( + request=request, + tokenized_prompt=[11, 12, 13, 14], + ) + + self.assertIn("prompt_checkpoint_callback", captured["constructor_kwargs"]) + self.assertEqual(captured["insert_prompt_checkpoints"], [None]) + self.assertEqual(len(generator.prompt_cache), 1) + + self.assertIsNone( + generator.prompt_cache._search( + generator.model_provider.model_key, [11, 12, 13] + ).exact + ) + + def test_generate_batch_mode_empty_chat_messages_reports_request_error(self): + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[], + tools=None, + role_mapping=None, + ) + error_queue, valid_queue, captured = self._run_malformed_then_valid_batch_probe( + request + ) + + error = error_queue.get_nowait() + self.assertIsInstance(error, ValueError) + self.assertEqual(str(error), "Chat request messages must be a non-empty list") + with self.assertRaises(Empty): + error_queue.get_nowait() + self.assertEqual(captured["insert_prompts"], [[[11, 12, 13, 14]]]) + valid_ctx = valid_queue.get_nowait() + valid_response = valid_queue.get_nowait() + self.assertFalse(isinstance(valid_ctx, Exception)) + self.assertTrue(hasattr(valid_ctx, "prompt")) + self.assertEqual(valid_response.finish_reason, "stop") + self.assertIsNone(valid_queue.get_nowait()) + + def test_generate_batch_mode_missing_last_role_reports_request_error(self): + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[{"content": "hello"}], + tools=None, + role_mapping=None, + ) + error_queue, valid_queue, captured = self._run_malformed_then_valid_batch_probe( + request + ) + + error = error_queue.get_nowait() + self.assertIsInstance(error, ValueError) + self.assertEqual(str(error), "Chat request last message must include a role") + with self.assertRaises(Empty): + error_queue.get_nowait() + self.assertEqual(captured["insert_prompts"], [[[11, 12, 13, 14]]]) + valid_ctx = valid_queue.get_nowait() + valid_response = valid_queue.get_nowait() + self.assertFalse(isinstance(valid_ctx, Exception)) + self.assertTrue(hasattr(valid_ctx, "prompt")) + self.assertEqual(valid_response.finish_reason, "stop") + self.assertIsNone(valid_queue.get_nowait()) + + def test_generate_batch_mode_does_not_forward_impossible_checkpoint_with_warm_cache( + self, + ): + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[{"role": "user", "content": "hello"}], + tools=None, + role_mapping=None, + ) + generator, captured = self._run_batch_checkpoint_probe( + request=request, + tokenized_prompt=[11, 12, 99, 13, 14], + callback_prompt_end=4, + has_thinking=True, + think_start_id=99, + seeded_entries=[([11, 12], 
[MockCache("seed")])], + ) + + self.assertIn("prompt_checkpoint_callback", captured["constructor_kwargs"]) + self.assertEqual(captured["insert_prompts"], [[99, 13, 14]]) + self.assertEqual(captured["insert_prompt_checkpoints"], [None]) + + def test_localize_prompt_checkpoint_at_prompt_start_suppressed_with_warm_cache( + self, + ): + """When checkpoint_position = -len(prompt) (position 0 = start of + prompt) and a warm cache covers a prefix, _localize_prompt_checkpoint + returns None because the checkpoint falls before the reused region.""" + generator = self._build_response_generator() + prompt = [11, 12, 13, 14] + rest = [13, 14] # cache hit covered [11, 12] + + # checkpoint_position = -4 means position 0 (start of prompt). + # rest_offset = 4 - 2 = 2, checkpoint_prefix = 0 < 2 → suppressed. + result = generator._localize_prompt_checkpoint(prompt, rest, -len(prompt)) + self.assertIsNone(result) + + # Also verify that -1 (last token) is NOT suppressed in same scenario. + result2 = generator._localize_prompt_checkpoint(prompt, rest, -1) + self.assertIsNotNone(result2) + self.assertEqual(result2, -1) + + def test_generate_batch_mode_real_generator_stores_checkpoint_cache(self): + class DeterministicBatchModel: + layers = [object()] + + def make_cache(self): + return [KVCache()] + + def __call__(self, input_tokens, cache=None, input_embeddings=None): + if cache is not None: + for layer_cache in cache: + kv = mx.zeros( + (input_tokens.shape[0], 1, input_tokens.shape[1], 1), + dtype=mx.float32, + ) + layer_cache.update_and_fetch(kv, kv) + batch, seq_len = input_tokens.shape + vocab_size = 4 + logits = -1000.0 * mx.ones((vocab_size,), dtype=mx.float32) + logits = logits + (2000.0 * (mx.arange(vocab_size) == 0)) + return mx.broadcast_to(logits, (batch, seq_len, vocab_size)) + + generator = self._build_response_generator() + generator._time_budget = [None] + generator.model_provider.model = DeterministicBatchModel() + generator.model_provider.tokenizer.has_thinking = True + generator.model_provider.tokenizer.think_start_id = 99 + generator.model_provider.tokenizer.detokenizer = type( + "FakeDetokenizer", + (), + { + "last_segment": "", + "add_token": lambda self, token: None, + }, + )() + request_queue = Queue() + request_args = self._generation_args() + request_args.max_tokens = 1 + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[{"role": "user", "content": "hello"}], + tools=None, + role_mapping=None, + ) + request_seen = False + original_next = BatchGenerator.next + + def next_request(timeout=None): + nonlocal request_seen + if request_seen: + return None + request_seen = True + return (request_queue, request, request_args) + + def stopping_next(batch_generator): + responses = original_next(batch_generator) + if responses and all(r.finish_reason is not None for r in responses): + generator._stop = True + return responses + + generator._next_request = next_request + with patch.object(generator, "_tokenize", return_value=[11, 12, 13, 14]): + with patch.object(BatchGenerator, "next", new=stopping_next): + generator._generate() + + self.assertEqual(len(generator.prompt_cache), 2) + checkpoint_cache, rest = generator.prompt_cache.fetch_nearest_cache( + generator.model_provider.model_key, + [11, 12, 13], + ) + self.assertEqual(rest, []) + self.assertIsNotNone(checkpoint_cache) + self.assertEqual([layer.offset for layer in checkpoint_cache], [3]) + self.assertEqual(len(generator.prompt_cache), 2) + + def 
test_generate_batch_mode_exact_checkpoint_hit_does_not_forward_empty_prompt( + self, + ): + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[{"role": "user", "content": "hello"}], + tools=None, + role_mapping=None, + ) + generator, captured = self._run_batch_checkpoint_probe( + request=request, + tokenized_prompt=[11, 12, 13], + seeded_entries=[([11, 12, 13], [MockCache("checkpoint")], True)], + ) + + self.assertTrue(captured["insert_prompts"][0]) + request_queue = captured["request_queue"] + ctx = request_queue.get_nowait() + response = request_queue.get_nowait() + self.assertFalse(isinstance(ctx, Exception)) + self.assertEqual(response.finish_reason, "stop") + self.assertIsNone(request_queue.get_nowait()) + + def test_serve_single_exact_checkpoint_hit_does_not_forward_empty_prompt(self): + generator = self._build_response_generator() + generator.prompt_cache.insert_cache( + generator.model_provider.model_key, + [11, 12, 13], + [MockCache("checkpoint")], + checkpoint=True, + ) + request_queue = Queue() + gen_result = type( + "GenResult", + (), + { + "text": "x", + "token": 0, + "logprobs": mx.array([0.0], dtype=mx.float32), + "finish_reason": "stop", + }, + )() + + def stream_generate_probe(**stream_kwargs): + self.assertTrue(stream_kwargs["prompt"]) + yield gen_result + + with patch("mlx_lm.server.stream_generate", side_effect=stream_generate_probe): + with patch.object(generator, "_tokenize", return_value=[11, 12, 13]): + generator._serve_single( + (request_queue, self._make_text_request(), self._generation_args()) + ) + + ctx = request_queue.get_nowait() + response = request_queue.get_nowait() + self.assertFalse(isinstance(ctx, Exception)) + self.assertFalse(isinstance(response, Exception)) + self.assertEqual(response.finish_reason, "stop") + self.assertIsNone(request_queue.get_nowait()) + + def test_generate_batch_mode_real_generator_suppresses_impossible_warm_cache_checkpoint( + self, + ): + class DeterministicBatchModel: + layers = [object()] + + def make_cache(self): + return [KVCache()] + + def __call__(self, input_tokens, cache=None, input_embeddings=None): + if cache is not None: + for layer_cache in cache: + kv = mx.zeros( + (input_tokens.shape[0], 1, input_tokens.shape[1], 1), + dtype=mx.float32, + ) + layer_cache.update_and_fetch(kv, kv) + batch, seq_len = input_tokens.shape + vocab_size = 4 + logits = -1000.0 * mx.ones((vocab_size,), dtype=mx.float32) + logits = logits + (2000.0 * (mx.arange(vocab_size) == 0)) + return mx.broadcast_to(logits, (batch, seq_len, vocab_size)) + + generator = self._build_response_generator() + generator._time_budget = [None] + generator.model_provider.model = DeterministicBatchModel() + generator.model_provider.tokenizer.detokenizer = type( + "FakeDetokenizer", + (), + { + "last_segment": "", + "add_token": lambda self, token: None, + }, + )() + generator.model_provider.tokenizer.has_thinking = True + generator.model_provider.tokenizer.think_start_id = 99 + + seeded_cache = generator.model_provider.model.make_cache() + generator.model_provider.model( + mx.array([[11, 12]], dtype=mx.uint32), cache=seeded_cache + ) + generator.prompt_cache.insert_cache( + generator.model_provider.model_key, + [11, 12], + seeded_cache, + ) + + request_queue = Queue() + request_args = self._generation_args() + request_args.max_tokens = 1 + request = CompletionRequest( + request_type="chat", + prompt="", + messages=[{"role": "user", "content": "hello"}], + tools=None, + role_mapping=None, + ) + request_seen = False + original_next = 
BatchGenerator.next + + def next_request(timeout=None): + nonlocal request_seen + if request_seen: + return None + request_seen = True + return (request_queue, request, request_args) + + def stopping_next(batch_generator): + responses = original_next(batch_generator) + if responses and all(r.finish_reason is not None for r in responses): + generator._stop = True + return responses + + generator._next_request = next_request + with patch.object(generator, "_tokenize", return_value=[11, 12, 99, 13, 14]): + with patch.object(BatchGenerator, "next", new=stopping_next): + generator._generate() + + self.assertEqual(len(generator.prompt_cache), 1) + self.assertIsNone( + generator.prompt_cache._search( + generator.model_provider.model_key, [11] + ).exact + ) + + if __name__ == "__main__": unittest.main()
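
The batch-mode tests above all exercise the same prompt_checkpoint_callback contract: each callback entry is a tuple (uid, n_remaining, layer_caches), where the layer caches cover prompt[:-n_remaining], and an entry whose key would be empty must not be stored. The following is a minimal illustrative sketch of a consumer under that contract, not part of the patch; prompts_by_uid, prompt_cache_store, and model_key are hypothetical stand-ins for whatever bookkeeping the server keeps per request.

# Sketch only (assumed names): a prompt_checkpoint_callback consumer matching
# the contract the tests above exercise. Each entry is
# (uid, n_remaining, layer_caches); the caches cover prompt[:-n_remaining].
def make_checkpoint_callback(prompts_by_uid, prompt_cache_store, model_key):
    def on_prompt_checkpoint(entries):
        for uid, n_remaining, layer_caches in entries:
            tokens = prompts_by_uid[uid]
            # Key the checkpoint on the prefix the caches actually cover.
            key = tokens[:-n_remaining] if n_remaining else tokens
            if not key:
                # Skip empty-key entries, as
                # test_generate_batch_mode_does_not_store_empty_key_checkpoint_entry requires.
                continue
            # LRUPromptCache.insert_cache accepts checkpoint=True per the tests above.
            prompt_cache_store.insert_cache(
                model_key, key, list(layer_caches), checkpoint=True
            )
    return on_prompt_checkpoint

Under this sketch, the probe's callback entry (123, 1, [MockCache("checkpoint")]) for the prompt [11, 12, 13, 14] would store the checkpoint under [11, 12, 13], which is the same prefix the fetch_nearest_cache assertions in the tests expect to hit.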