From a9944448e6c9720da873c2a52618970f4bd77c81 Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Sat, 28 Feb 2026 21:32:04 +0100 Subject: [PATCH 01/11] server: return 503 on Metal OOM and harden generation failures --- mlx_lm/server.py | 530 +++++++++++++++++++++++++------------------ tests/test_server.py | 224 +++++++++++++----- 2 files changed, 478 insertions(+), 276 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 7fc91fa2a..c63d54927 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -2,7 +2,6 @@ import argparse import copy -import heapq import json import logging import pickle @@ -42,7 +41,7 @@ trim_prompt_cache, ) from .sample_utils import make_logits_processors, make_sampler -from .utils import _parse_size, load, sharded_load +from .utils import load, sharded_load def get_system_fingerprint(): @@ -50,6 +49,67 @@ def get_system_fingerprint(): return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" +def parse_size(x): + sizes = {"M": 1e6, "G": 1e9, "MB": 1e6, "GB": 1e9, "": 1} + split = 0 + for xi in x: + if not (xi.isdigit() or xi == "."): + break + split += 1 + digits = float(x[:split]) + size = (x[split:]).strip().upper() + return int(digits * sizes[size]) + + +def is_metal_oom_error(error: Exception) -> bool: + text = str(error).lower() + patterns = ( + "out of memory", + "insufficient memory", + "kiogpucommandbuffercallbackerroroutofmemory", + "mps backend out of memory", + ) + return any(pattern in text for pattern in patterns) + + +def projected_kv_bytes(prompt_cache: List[Any], extra_tokens: int) -> int: + cache_bytes = sum(c.nbytes for c in prompt_cache) + if cache_bytes <= 0 or extra_tokens <= 0: + return cache_bytes + + cache_tokens = max( + (c.size() for c in prompt_cache if hasattr(c, "size")), default=0 + ) + if cache_tokens <= 0: + return cache_bytes + + bytes_per_token = cache_bytes / cache_tokens + return cache_bytes + int(bytes_per_token * extra_tokens) + + +def apply_prompt_token_limit( + tokens: List[int], + *, + max_prompt_tokens: Optional[int], + overflow_policy: str, + keep_tokens: int, +) -> List[int]: + if max_prompt_tokens is None or len(tokens) <= max_prompt_tokens: + return tokens + + if overflow_policy == "error": + raise ValueError( + "Prompt exceeds max prompt token limit: " + f"prompt_tokens={len(tokens)}, max_prompt_tokens={max_prompt_tokens}" + ) + + keep_tokens = max(0, min(keep_tokens, max_prompt_tokens)) + tail_tokens = max_prompt_tokens - keep_tokens + if tail_tokens <= 0: + return tokens[:max_prompt_tokens] + return tokens[:keep_tokens] + tokens[-tail_tokens:] + + class StopCondition(NamedTuple): stop_met: bool trim_length: int @@ -172,35 +232,13 @@ def process_message_content(messages): class LRUPromptCache: + @dataclass class CacheEntry: prompt_cache: List[Any] + count: int nbytes: int - class CacheOrder: - def __init__(self): - self._lru_checkpoints = deque() - self._lru = deque() - - def __len__(self): - return len(self._lru) + len(self._lru_checkpoints) - - def push(self, model, tokens, checkpoint: bool = False): - c = self._lru_checkpoints if checkpoint else self._lru - c.append((model, tokens)) - - def remove(self, model, tokens): - try: - self._lru.remove((model, tokens)) - except ValueError: - self._lru_checkpoints.remove((model, tokens)) - - def pop(self): - if len(self._lru) >= len(self._lru_checkpoints): - return self._lru.popleft() - else: - return self._lru_checkpoints.popleft() - @dataclass class SearchResult: model: Any @@ -213,7 +251,7 @@ def __init__(self, max_size: int = 10, max_bytes: 
int = 1 << 63): self.max_size = max_size self.max_bytes = max_bytes self._cache = {} - self._lru = self.CacheOrder() + self._lru = deque() self._n_bytes = 0 def __len__(self): @@ -250,7 +288,7 @@ def _search(self, model, tokens): # Check for caches that are longer longer = None common_prefix = index - if index > 0: + if index > 0 and last_cache_index <= 0: best = None stack = [(current, [])] while stack: @@ -283,14 +321,32 @@ def _delete(self, model, tokens): break del d_prev[t] + logging.debug(f"[LRUPromptCache] Removed {cache_bytes} bytes from the cache") + + def _extract(self, model, tokens): + cache_entry = self._get(model, tokens) + if cache_entry.count == 1: + self._delete(model, tokens) + self._lru.remove((model, tokens)) + return cache_entry + + cache_entry.count -= 1 + return self.CacheEntry( + copy.deepcopy(cache_entry.prompt_cache), 1, cache_entry.nbytes + ) + def fetch_nearest_cache(self, model, tokens): result = self._search(model, tokens) if result.exact is not None: - cache_entry = self._get(result.model, result.exact) - return copy.deepcopy(cache_entry.prompt_cache), [] + cache_entry = self._extract(result.model, result.exact) + return cache_entry.prompt_cache, [] - short_length = len(result.shorter) if result.shorter is not None else 0 - if result.longer is not None and result.common_prefix > short_length: + if result.shorter is not None: + cache_entry = self._extract(result.model, result.shorter) + prefix_len = len(result.shorter) + return cache_entry.prompt_cache, tokens[prefix_len:] + + if result.longer is not None: cache_entry = self._get(result.model, result.longer) if can_trim_prompt_cache(cache_entry.prompt_cache): cache = copy.deepcopy(cache_entry.prompt_cache) @@ -299,40 +355,32 @@ def fetch_nearest_cache(self, model, tokens): trim_prompt_cache(cache, num_to_trim) return cache, tokens[prefix:] - if short_length > 0: - cache_entry = self._get(result.model, result.shorter) - return copy.deepcopy(cache_entry.prompt_cache), tokens[short_length:] - return None, tokens - def insert_cache(self, model, tokens, prompt_cache, checkpoint: bool = False): - is_trimmable = can_trim_prompt_cache(prompt_cache) - + def insert_cache(self, model, tokens, prompt_cache): if model not in self._cache: self._cache[model] = {} current = self._cache[model] - for i, tok in enumerate(tokens): + for tok in tokens: if tok not in current: current[tok] = {} - if is_trimmable and "cache" in current: - self._n_bytes -= current["cache"].nbytes - del current["cache"] - self._lru.remove(model, tokens[:i]) current = current[tok] if "cache" in current: - self._lru.remove(model, tokens) + current["cache"].count += 1 + self._lru.remove((model, tokens)) else: cache_bytes = sum(c.nbytes for c in prompt_cache) - current["cache"] = self.CacheEntry(prompt_cache, cache_bytes) + current["cache"] = self.CacheEntry(prompt_cache, 1, cache_bytes) self._n_bytes += cache_bytes + logging.debug(f"[LRUPromptCache] Adding {cache_bytes} to the cache") - self._lru.push(model, tokens, checkpoint=checkpoint) + self._lru.append((model, tokens)) if len(self._lru) > self.max_size: - model, tokens = self._lru.pop() + model, tokens = self._lru.popleft() self._delete(model, tokens) while self._n_bytes > self.max_bytes and len(self._lru) > 1: - model, tokens = self._lru.pop() + model, tokens = self._lru.popleft() self._delete(model, tokens) def trim_to( @@ -342,23 +390,12 @@ def trim_to( n_bytes = max(0, n_bytes) if n_bytes is not None else 1 << 63 while len(self._lru) > n_sequences: - model, tokens = self._lru.pop() + model, 
tokens = self._lru.popleft() self._delete(model, tokens) while self._n_bytes > n_bytes: - model, tokens = self._lru.pop() + model, tokens = self._lru.popleft() self._delete(model, tokens) - def log_cache_stats(self): - ncaches, nbytes = len(self), self.nbytes - ntok = ( - len(self._lru._lru_checkpoints[-1][1]) - if len(self._lru._lru_checkpoints) > 0 - else 0 - ) - logging.info( - f"KV Caches: {ncaches} seq, {nbytes / 1e9:.2f} GB, latest user cache {ntok} tokens" - ) - @dataclass class ModelDescription: @@ -382,10 +419,6 @@ class LogitsProcessorArguments: logit_bias: Optional[Dict[int, float]] repetition_penalty: float repetition_context_size: int - presence_penalty: float - presence_context_size: int - frequency_penalty: float - frequency_context_size: int @dataclass @@ -614,10 +647,6 @@ def _make_logits_processors(args): args.logits.logit_bias, args.logits.repetition_penalty, args.logits.repetition_context_size, - args.logits.presence_penalty, - args.logits.presence_context_size, - args.logits.frequency_penalty, - args.logits.frequency_context_size, ) @@ -722,35 +751,35 @@ def _tokenize(self, tokenizer, request, args): if args.chat_template_kwargs: chat_template_args = chat_template_args.copy() chat_template_args.update(args.chat_template_kwargs) - return tokenizer.apply_chat_template( + tokens = tokenizer.apply_chat_template( messages, tools=tools, add_generation_prompt=True, tokenize=True, **chat_template_args, ) + return apply_prompt_token_limit( + tokens, + max_prompt_tokens=self.cli_args.max_prompt_tokens, + overflow_policy=self.cli_args.prompt_overflow_policy, + keep_tokens=self.cli_args.prompt_keep_tokens, + ) else: - return tokenizer.encode(convert_chat(messages, role_mapping)) + tokens = tokenizer.encode(convert_chat(messages, role_mapping)) + return apply_prompt_token_limit( + tokens, + max_prompt_tokens=self.cli_args.max_prompt_tokens, + overflow_policy=self.cli_args.prompt_overflow_policy, + keep_tokens=self.cli_args.prompt_keep_tokens, + ) else: - return tokenizer.encode(request.prompt) - - def _compute_prompt_checkpoint(self, tokenizer, request, prompt): - if request.request_type != "chat": - return False, -1 - if request.messages[-1]["role"] != "user": - return False, -1 - - # Save the KV cache at the end of the prompt just before - # the think start token which will likely be removed in the - # next turn. 
- prompt_checkpoint = -1 - if tokenizer.has_thinking: - for i in range(1, min(11, len(prompt)) - 1, 1): - if prompt[-i] == tokenizer.think_start_id: - prompt_checkpoint = -i - 1 - break - - return True, prompt_checkpoint + tokens = tokenizer.encode(request.prompt) + return apply_prompt_token_limit( + tokens, + max_prompt_tokens=self.cli_args.max_prompt_tokens, + overflow_policy=self.cli_args.prompt_overflow_policy, + keep_tokens=self.cli_args.prompt_keep_tokens, + ) def _is_batchable(self, args): if not self.model_provider.is_batchable: @@ -760,6 +789,44 @@ def _is_batchable(self, args): return True + def _make_prompt_cache(self, model, draft_model=None): + cache = make_prompt_cache( + model, + max_kv_size=self.model_provider.cli_args.max_kv_size, + ) + if draft_model is not None: + cache += make_prompt_cache( + draft_model, + max_kv_size=self.model_provider.cli_args.max_kv_size, + ) + return cache + + def _memory_admission_error( + self, prompt_cache: List[Any], extra_tokens: int, active_bytes: int = 0 + ) -> Optional[str]: + limit = self.model_provider.cli_args.max_active_kv_bytes + if limit is None: + return None + projected = active_bytes + projected_kv_bytes(prompt_cache, extra_tokens) + if projected <= limit: + return None + return ( + "Projected KV memory usage exceeds configured active KV limit. " + f"projected={projected} bytes, limit={limit} bytes" + ) + + def _check_active_memory_limit(self): + limit = self.model_provider.cli_args.max_active_memory_bytes + if limit is None: + return + active = mx.get_active_memory() + if active > limit: + raise MemoryError( + "Active MLX memory exceeded configured limit: " + f"active={active} bytes, limit={limit} bytes. " + "Consider lowering prompt length or max_tokens." + ) + def _generate(self): current_model = None current_sampling = None @@ -778,22 +845,11 @@ def get_next_request(timeout=None): return self._next_request(timeout) def progress_callback(info): + self._check_active_memory_limit() for uid, processed, total in info: if uid in batch_results: batch_results[uid]["rqueue"].put((min(processed, total), total)) - def checkpoint_callback(prompts): - for uid, prompt_end, cache in prompts: - rs = batch_results[uid] - if not rs["checkpoint"]: - continue - self.prompt_cache.insert_cache( - current_model_key, - rs["cache_key"][:-prompt_end], - list(cache), - checkpoint=True, - ) - if self._is_distributed: seed = mx.distributed.all_sum(mx.random.state[0]).view(mx.uint64).item() mx.random.seed(seed) @@ -842,16 +898,25 @@ def checkpoint_callback(prompts): ) rqueue.put(ctx) - self.prompt_cache.log_cache_stats() cache, rest = self.prompt_cache.fetch_nearest_cache( current_model_key, prompt ) ctx.prompt_cache_count = len(prompt) - len(rest) if cache is None: - cache = make_prompt_cache(self.model_provider.model) + cache = self._make_prompt_cache(self.model_provider.model) + + admission_error = self._memory_admission_error( + cache, + len(rest) + args.max_tokens, + batch_generator.prompt_cache_nbytes, + ) + if admission_error is not None: + rqueue.put(MemoryError(admission_error)) + continue - do_checkpoint, checkpoint_position = ( - self._compute_prompt_checkpoint(tokenizer, request, prompt) + ncaches, nbytes = len(self.prompt_cache), self.prompt_cache.nbytes + logging.info( + f"We have {ncaches} kv caches that take {nbytes/1e9:.2f} GB" ) (uid,) = batch_generator.insert( @@ -860,14 +925,12 @@ def checkpoint_callback(prompts): caches=[cache], samplers=[_make_sampler(args, tokenizer)], logits_processors=[_make_logits_processors(args)], - 
prompt_checkpoints=[checkpoint_position], ) batch_results[uid] = { "ctx": ctx, "cache_key": prompt[:], "rqueue": rqueue, "detokenizer": tokenizer.detokenizer, - "checkpoint": do_checkpoint, } # just making sure we don't leave a reference around del cache @@ -905,7 +968,7 @@ def checkpoint_callback(prompts): prefill_batch_size=self.cli_args.prompt_concurrency, prefill_step_size=self.cli_args.prefill_step_size, prompt_progress_callback=progress_callback, - prompt_checkpoint_callback=checkpoint_callback, + max_kv_size=self.cli_args.max_kv_size, ) unprocessed_requests.append((rqueue, request, args)) continue @@ -931,59 +994,78 @@ def checkpoint_callback(prompts): continue uids_to_remove = [] - for _ in self._time_budget: - responses = batch_generator.next() - if not responses: - break - - for r in responses: - result = batch_results[r.uid] - result["cache_key"].append(r.token) - if r.finish_reason != "stop": - result["detokenizer"].add_token(r.token) - - result["rqueue"].put( - Response( - result["detokenizer"].last_segment, - r.token, - r.logprobs[r.token].item(), - r.finish_reason, - _format_top_logprobs( - r.logprobs, args.top_logprobs, current_tokenizer - ), - ) - ) - - if r.finish_reason is not None: - result["rqueue"].put(None) - self.prompt_cache.insert_cache( - current_model_key, result["cache_key"], r.prompt_cache + try: + for _ in self._time_budget: + responses = batch_generator.next() + if not responses: + break + + for r in responses: + result = batch_results[r.uid] + result["cache_key"].append(r.token) + if r.finish_reason != "stop": + result["detokenizer"].add_token(r.token) + + result["rqueue"].put( + Response( + result["detokenizer"].last_segment, + r.token, + r.logprobs[r.token].item(), + r.finish_reason, + _format_top_logprobs( + r.logprobs, args.top_logprobs, current_tokenizer + ), + ) ) - del batch_results[r.uid] - - if result["ctx"]._should_stop: - uids_to_remove.append(r.uid) - uids_to_remove = self._share_object(uids_to_remove) - if uids_to_remove: - with mx.stream(generation_stream): - caches = batch_generator.remove( - uids_to_remove, return_prompt_caches=True - ) - for uid, prompt_cache in caches.items(): - if uid not in batch_results: - continue - result = batch_results[uid] - self.prompt_cache.insert_cache( - current_model_key, result["cache_key"], prompt_cache + if r.finish_reason is not None: + result["rqueue"].put(None) + self.prompt_cache.insert_cache( + current_model_key, + result["cache_key"], + r.prompt_cache, + ) + del batch_results[r.uid] + + if result["ctx"]._should_stop: + uids_to_remove.append(r.uid) + + uids_to_remove = self._share_object(uids_to_remove) + if uids_to_remove: + with mx.stream(generation_stream): + caches = batch_generator.remove( + uids_to_remove, return_prompt_caches=True ) - del batch_results[uid] + for uid, prompt_cache in caches.items(): + if uid not in batch_results: + continue + result = batch_results[uid] + self.prompt_cache.insert_cache( + current_model_key, result["cache_key"], prompt_cache + ) + del batch_results[uid] + except Exception as e: + logging.exception("Batched generation failed") + if is_metal_oom_error(e): + mx.clear_cache() + for result in batch_results.values(): + result["rqueue"].put(e) + result["rqueue"].put(None) + batch_results = {} + current_model = None + current_sampling = None + current_tokenizer = None + current_model_key = None + batch_generator.close() + batch_generator = None + drain_batch = False def _serve_single(self, request): rqueue, request, args = request # Define the progress callback def 
progress(tokens_processed, tokens_total): + self._check_active_memory_limit() rqueue.put((tokens_processed, tokens_total)) try: @@ -1023,16 +1105,25 @@ def progress(tokens_processed, tokens_total): logits_processors = _make_logits_processors(args) # Load the KV cache - self.prompt_cache.log_cache_stats() cache, rest = self.prompt_cache.fetch_nearest_cache( self.model_provider.model_key, prompt ) ctx.prompt_cache_count = len(prompt) - len(rest) cache_key = prompt[:] if cache is None: - cache = make_prompt_cache(self.model_provider.model) - if self.model_provider.draft_model is not None: - cache += make_prompt_cache(self.model_provider.draft_model) + cache = self._make_prompt_cache( + self.model_provider.model, self.model_provider.draft_model + ) + + admission_error = self._memory_admission_error( + cache, + len(rest) + args.max_tokens, + ) + if admission_error is not None: + raise MemoryError(admission_error) + + ncaches, nbytes = len(self.prompt_cache), self.prompt_cache.nbytes + logging.info(f"We have {ncaches} kv caches that take {nbytes/1e9:.2f} GB") # Process the prompt and generate tokens for gen in stream_generate( @@ -1046,6 +1137,7 @@ def progress(tokens_processed, tokens_total): draft_model=draft_model, num_draft_tokens=args.num_draft_tokens, prompt_progress_callback=progress, + max_kv_size=self.cli_args.max_kv_size, prefill_step_size=self.cli_args.prefill_step_size, ): rqueue.put( @@ -1074,6 +1166,8 @@ def progress(tokens_processed, tokens_total): ) except Exception as e: + if is_metal_oom_error(e): + mx.clear_cache() rqueue.put(e) def generate( @@ -1126,13 +1220,7 @@ def __init__( super().__init__(*args, **kwargs) def _set_cors_headers(self): - allowed_origins = self.response_generator.cli_args.allowed_origins - origin = self.headers.get("Origin") - if "*" in allowed_origins: - self.send_header("Access-Control-Allow-Origin", "*") - elif origin in allowed_origins: - self.send_header("Access-Control-Allow-Origin", origin) - self.send_header("Vary", "Origin") + self.send_header("Access-Control-Allow-Origin", "*") self.send_header("Access-Control-Allow-Methods", "*") self.send_header("Access-Control-Allow-Headers", "*") @@ -1168,23 +1256,7 @@ def do_POST(self): return # Fetch and parse request body - content_length = self.headers.get("Content-Length") - if content_length is None: - self._set_completion_headers(411) - self.end_headers() - self.wfile.write( - json.dumps({"error": "Content-Length header is required"}).encode() - ) - return - try: - content_length = int(content_length) - except ValueError: - self._set_completion_headers(400) - self.end_headers() - self.wfile.write( - json.dumps({"error": "Invalid Content-Length header"}).encode() - ) - return + content_length = int(self.headers["Content-Length"]) raw_body = self.rfile.read(content_length) try: self.body = json.loads(raw_body.decode()) @@ -1225,10 +1297,6 @@ def do_POST(self): self.min_p = self.body.get("min_p", self.response_generator.cli_args.min_p) self.repetition_penalty = self.body.get("repetition_penalty", 0.0) self.repetition_context_size = self.body.get("repetition_context_size", 20) - self.presence_penalty = self.body.get("presence_penalty", 0.0) - self.presence_context_size = self.body.get("presence_context_size", 20) - self.frequency_penalty = self.body.get("frequency_penalty", 0.0) - self.frequency_context_size = self.body.get("frequency_context_size", 20) self.xtc_probability = self.body.get("xtc_probability", 0.0) self.xtc_threshold = self.body.get("xtc_threshold", 0.0) self.logit_bias = 
self.body.get("logit_bias", None) @@ -1277,25 +1345,6 @@ def validate_model_parameters(self): or self.repetition_penalty < 0 ): raise ValueError("repetition_penalty must be a non-negative float") - if ( - not isinstance(self.repetition_context_size, int) - or self.repetition_context_size < 0 - ): - raise ValueError("repetition_context_size must be a non-negative integer") - if not isinstance(self.presence_penalty, (float, int)): - raise ValueError("Presence penalty must be must be a float") - if ( - not isinstance(self.presence_context_size, int) - or self.presence_context_size < 0 - ): - raise ValueError("presence_context_size must be a non-negative integer") - if not isinstance(self.frequency_penalty, (float, int)): - raise ValueError("Presence penalty must be must be a float") - if ( - not isinstance(self.frequency_context_size, int) - or self.frequency_context_size < 0 - ): - raise ValueError("frequency_context_size must be a non-negative integer") if not isinstance(self.logprobs, bool): raise ValueError("logprobs must be a boolean") @@ -1305,6 +1354,12 @@ def validate_model_parameters(self): f"top_logprobs must be between 1 and 10 but got {self.top_logprobs:,}" ) + if ( + not isinstance(self.repetition_context_size, int) + or self.repetition_context_size < 0 + ): + raise ValueError("repetition_context_size must be a non-negative integer") + if self.logit_bias is not None: if not isinstance(self.logit_bias, dict): raise ValueError("logit_bias must be a dict of int to float") @@ -1464,10 +1519,6 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]): logit_bias=self.logit_bias, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, - presence_penalty=self.presence_penalty, - presence_context_size=self.presence_context_size, - frequency_penalty=self.frequency_penalty, - frequency_context_size=self.frequency_context_size, ), stop_words=stop_words, max_tokens=self.max_tokens, @@ -1502,7 +1553,8 @@ def keepalive_callback(processed_tokens, total_tokens): progress_callback=keepalive_callback, ) except Exception as e: - self._set_completion_headers(404) + status_code = 503 if is_metal_oom_error(e) else 500 + self._set_completion_headers(status_code) self.end_headers() self.wfile.write(json.dumps({"error": f"{e}"}).encode()) return @@ -1884,7 +1936,11 @@ def run( handler_class=APIHandler, ): group = mx.distributed.init() - prompt_cache = LRUPromptCache(model_provider.cli_args.prompt_cache_size) + prompt_cache_bytes = model_provider.cli_args.prompt_cache_bytes + prompt_cache = LRUPromptCache( + model_provider.cli_args.prompt_cache_size, + prompt_cache_bytes if prompt_cache_bytes is not None else (1 << 63), + ) response_generator = ResponseGenerator(model_provider, prompt_cache) if group.rank() == 0: _run_http_server(host, port, response_generator) @@ -1916,12 +1972,6 @@ def main(): default=8080, help="Port for the HTTP server (default: 8080)", ) - parser.add_argument( - "--allowed-origins", - type=lambda x: x.split(","), - default="*", - help="Allowed origins (default: *)", - ) parser.add_argument( "--draft-model", type=str, @@ -2020,9 +2070,53 @@ def main(): ) parser.add_argument( "--prompt-cache-bytes", - type=_parse_size, + type=parse_size, help="Maximum size in bytes of the KV caches", ) + parser.add_argument( + "--max-prompt-tokens", + type=int, + default=None, + help="Maximum prompt token count accepted by the server", + ) + parser.add_argument( + "--prompt-overflow-policy", + type=str, + default="error", + 
choices=["error", "truncate"], + help="Behavior when prompt exceeds --max-prompt-tokens", + ) + parser.add_argument( + "--prompt-keep-tokens", + type=int, + default=512, + help=( + "When truncating prompts, keep this many tokens from the start and fill " + "the remainder from the end" + ), + ) + parser.add_argument( + "--max-active-kv-bytes", + type=parse_size, + help=( + "Reject requests when projected active KV memory would exceed this limit " + "(bytes or shorthand like 20G)" + ), + ) + parser.add_argument( + "--max-active-memory-bytes", + type=parse_size, + help=( + "Abort requests when current active MLX memory exceeds this limit " + "(bytes or shorthand like 30G)" + ), + ) + parser.add_argument( + "--max-kv-size", + type=int, + default=None, + help="Maximum size of the active KV cache per sequence", + ) parser.add_argument( "--pipeline", action="store_true", diff --git a/tests/test_server.py b/tests/test_server.py index c5a815e4f..34fdf703e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,14 +3,23 @@ import http import io import json +import sys import threading import unittest +from unittest.mock import patch import mlx.core as mx import requests from mlx_lm.models.cache import KVCache -from mlx_lm.server import APIHandler, LRUPromptCache, ResponseGenerator +from mlx_lm.server import ( + APIHandler, + LRUPromptCache, + ResponseGenerator, + apply_prompt_token_limit, + is_metal_oom_error, + projected_kv_bytes, +) from mlx_lm.utils import load @@ -47,7 +56,12 @@ def __init__(self, with_draft=False): "prompt_cache_size": 10, "prompt_cache_bytes": 1 << 63, "prompt_cache_total_bytes": None, - "allowed_origins": ["*"], + "max_prompt_tokens": None, + "prompt_overflow_policy": "error", + "prompt_keep_tokens": 512, + "max_active_kv_bytes": None, + "max_active_memory_bytes": None, + "max_kv_size": None, }, ) @@ -63,24 +77,20 @@ def load(self, model, adapter=None, draft_model=None): class MockCache: - def __init__(self, value, is_trimmable: bool = True): + def __init__(self, value, size=None): self.value = value - self._is_trimmable = is_trimmable + self._size = len(value) if size is None else size @property def nbytes(self): return len(self.value) + def size(self): + return self._size + def __eq__(self, other): return other.value == self.value - def is_trimmable(self): - return self._is_trimmable - - def trim(self, n): - assert self._is_trimmable - return n - class TestServer(unittest.TestCase): @classmethod @@ -446,23 +456,18 @@ def get_kv(n): c[0].update_and_fetch(*get_kv(24)) cache.insert_cache(model, t, c) - # Fetching a cache that is strictly a prefix doesn't remove it from the - # lru cache tokens = tokens + [20] * 5 c, t = cache.fetch_nearest_cache(model, tokens) k, v = c[0].state self.assertTrue((k == v).all().item()) self.assertTrue((k.flatten() == mx.arange(24)).all().item()) self.assertEqual(t, [20] * 5) - self.assertEqual(len(cache), 1) + self.assertEqual(len(cache._lru), 0) - # Inserting a trimmable cache with shared prefix removes the prefixes tokens = tokens + [30] * 3 c[0].update_and_fetch(*get_kv(8)) cache.insert_cache(model, tokens, c) - self.assertEqual(len(cache), 1) - # Fetching a cache with a shared prefix doesn't remove it either tokens = tokens[:26] + [40] * 8 c, t = cache.fetch_nearest_cache(model, tokens) k, v = c[0].state @@ -471,34 +476,23 @@ def get_kv(n): (k.flatten() == mx.concatenate([mx.arange(24), mx.arange(2)])).all().item() ) self.assertEqual(t, [40] * 8) - self.assertEqual(len(cache), 1) - - # Inserting a diverged cache actually creates 
another entry - c[0].update_and_fetch(*get_kv(8)) - cache.insert_cache(model, tokens, c) - self.assertEqual(len(cache), 2) + self.assertEqual(len(cache._lru), 1) def test_lru(self): cache = LRUPromptCache(max_size=2) model = ("test", None, None) cache.insert_cache(model, [1, 2], [MockCache("test1")]) - cache.insert_cache(model, [2, 3], [MockCache("test2")]) + cache.insert_cache(model, [1, 2], [MockCache("test1")]) c, t = cache.fetch_nearest_cache(model, [1, 2]) self.assertEqual(c, [MockCache("test1")]) self.assertEqual(t, []) - c, t = cache.fetch_nearest_cache(model, [1]) - self.assertEqual(c, [MockCache("test1")]) - self.assertEqual(t, [1]) - c, t = cache.fetch_nearest_cache(model, [1, 3, 4]) + c, t = cache.fetch_nearest_cache(model, [1, 2]) self.assertEqual(c, [MockCache("test1")]) - self.assertEqual(t, [3, 4]) - c, t = cache.fetch_nearest_cache(model, [2, 3, 4]) - self.assertEqual(c, [MockCache("test2")]) - self.assertEqual(t, [4]) - c, t = cache.fetch_nearest_cache(model, [2, 4, 5]) - self.assertEqual(c, [MockCache("test2")]) - self.assertEqual(t, [4, 5]) + self.assertEqual(t, []) + c, t = cache.fetch_nearest_cache(model, [1, 2]) + self.assertEqual(c, None) + self.assertEqual(t, [1, 2]) cache.insert_cache(model, [1, 2], [MockCache("test1")]) cache.insert_cache(model, [2, 3], [MockCache("test2")]) @@ -514,29 +508,6 @@ def test_lru(self): self.assertEqual(c, [MockCache("test3")]) self.assertEqual(t, []) - cache.insert_cache(model, [4, 5], [MockCache("test4")], checkpoint=True) - c, t = cache.fetch_nearest_cache(model, [2, 3]) - self.assertEqual(c, None) - self.assertEqual(t, [2, 3]) - c, t = cache.fetch_nearest_cache(model, [3, 4]) - self.assertEqual(c, [MockCache("test3")]) - self.assertEqual(t, []) - c, t = cache.fetch_nearest_cache(model, [4, 5]) - self.assertEqual(c, [MockCache("test4")]) - self.assertEqual(t, []) - - cache.insert_cache(model, [5, 6], [MockCache("test5")]) - cache.insert_cache(model, [6, 7], [MockCache("test6")]) - c, t = cache.fetch_nearest_cache(model, [5, 6]) - self.assertEqual(c, None) - self.assertEqual(t, [5, 6]) - c, t = cache.fetch_nearest_cache(model, [6, 7]) - self.assertEqual(c, [MockCache("test6")]) - self.assertEqual(t, []) - c, t = cache.fetch_nearest_cache(model, [4, 5]) - self.assertEqual(c, [MockCache("test4")]) - self.assertEqual(t, []) - def test_lru_bytes(self): cache = LRUPromptCache(max_size=100, max_bytes=10) model = ("test", None, None) @@ -561,5 +532,142 @@ def test_lru_bytes(self): self.assertEqual(t, [3, 4]) +class FailingResponseGenerator: + def __init__(self, exc): + self.exc = exc + self.cli_args = type( + "obj", + (object,), + { + "num_draft_tokens": 3, + "max_tokens": 32, + "temp": 0.0, + "top_p": 1.0, + "top_k": 0, + "min_p": 0.0, + "model": None, + }, + ) + + def generate(self, *args, **kwargs): + raise self.exc + + +class TestErrorStatusCodes(unittest.TestCase): + def _run_request(self, exc): + response_generator = FailingResponseGenerator(exc) + httpd = http.server.HTTPServer( + ("localhost", 0), + lambda *args, **kwargs: APIHandler(response_generator, *args, **kwargs), + ) + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.daemon = True + server_thread.start() + try: + url = f"http://localhost:{httpd.server_port}/v1/completions" + return requests.post( + url, + json={ + "model": "default_model", + "prompt": "test", + "max_tokens": 2, + }, + ) + finally: + httpd.shutdown() + httpd.server_close() + server_thread.join() + + def test_oom_maps_to_service_unavailable(self): + response = self._run_request( + 
RuntimeError( + "[METAL] Command buffer execution failed: Insufficient Memory " + "(00000008:kIOGPUCommandBufferCallbackErrorOutOfMemory)" + ) + ) + self.assertEqual(response.status_code, 503) + + def test_non_oom_maps_to_internal_server_error(self): + response = self._run_request(RuntimeError("arbitrary failure")) + self.assertEqual(response.status_code, 500) + + def test_is_metal_oom_error(self): + self.assertTrue(is_metal_oom_error(RuntimeError("out of memory"))) + self.assertTrue( + is_metal_oom_error( + RuntimeError("kIOGPUCommandBufferCallbackErrorOutOfMemory") + ) + ) + self.assertFalse(is_metal_oom_error(RuntimeError("other runtime failure"))) + + +class TestKVBudgeting(unittest.TestCase): + def test_projected_kv_bytes_without_growth(self): + cache = [MockCache("abcd", size=0)] + self.assertEqual(projected_kv_bytes(cache, 10), 4) + + def test_projected_kv_bytes_with_growth(self): + cache = [MockCache("abcdef", size=3)] + # 6 bytes over 3 tokens => 2 bytes/token + self.assertEqual(projected_kv_bytes(cache, 5), 16) + + +class TestCLIValidation(unittest.TestCase): + def test_reject_bad_prompt_overflow_policy(self): + from mlx_lm import server as server_module + + argv = [ + "mlx_lm.server", + "--prompt-overflow-policy", + "invalid", + ] + with patch.object(sys, "argv", argv): + with self.assertRaises(SystemExit): + server_module.main() + + +class TestPromptTokenLimit(unittest.TestCase): + def test_no_limit(self): + tokens = list(range(10)) + self.assertEqual( + apply_prompt_token_limit( + tokens, + max_prompt_tokens=None, + overflow_policy="error", + keep_tokens=0, + ), + tokens, + ) + + def test_error_policy(self): + with self.assertRaisesRegex( + ValueError, "Prompt exceeds max prompt token limit" + ): + apply_prompt_token_limit( + list(range(20)), + max_prompt_tokens=8, + overflow_policy="error", + keep_tokens=0, + ) + + def test_truncate_policy(self): + out = apply_prompt_token_limit( + list(range(20)), + max_prompt_tokens=8, + overflow_policy="truncate", + keep_tokens=3, + ) + self.assertEqual(out, [0, 1, 2, 15, 16, 17, 18, 19]) + + def test_truncate_policy_keep_over_cap(self): + out = apply_prompt_token_limit( + list(range(20)), + max_prompt_tokens=8, + overflow_policy="truncate", + keep_tokens=100, + ) + self.assertEqual(out, list(range(8))) + + if __name__ == "__main__": unittest.main() From 720322661e11a35558a04e42ca82fd913c9c8706 Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Sat, 28 Feb 2026 21:35:59 +0100 Subject: [PATCH 02/11] server: add KV limits, admission control, and quantization wiring --- mlx_lm/generate.py | 73 +++++++----------------------------------- mlx_lm/models/cache.py | 7 ---- 2 files changed, 11 insertions(+), 69 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index ef8dbf7bf..ab0bdb83e 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -927,11 +927,6 @@ def _merge_caches(caches): return batch_cache -def _lazy_extract_cache(cache, i): - # Generators like lambdas are late bound so we can't just use it in the loop - return (c.extract(i) for c in cache) - - class BatchGenerator: @dataclass class Response: @@ -953,9 +948,6 @@ def __init__( completion_batch_size: int = 32, prefill_batch_size: int = 8, prefill_step_size: int = 2048, - prompt_checkpoint_callback: Optional[ - Callable[[List[Tuple[int, int, List[Any]]]], None] - ] = None, prompt_progress_callback: Optional[ Callable[[List[Tuple[int, int, int]]], None] ] = None, @@ -971,7 +963,6 @@ def __init__( self.prefill_step_size = prefill_step_size self.prefill_batch_size = 
prefill_batch_size self.completion_batch_size = max(completion_batch_size, prefill_batch_size) - self.prompt_checkpoint_callback = prompt_checkpoint_callback self.prompt_progress_callback = prompt_progress_callback or (lambda *_: None) self._stats = BatchStats() self._next_count = 0 @@ -1002,16 +993,12 @@ def insert( caches=None, samplers: list | None = None, logits_processors: list | None = None, - prompt_checkpoints: list | int | None = None, ): uids = [] if max_tokens is None or isinstance(max_tokens, int): max_tokens = [max_tokens or self.max_tokens] * len(prompts) - if prompt_checkpoints is None or isinstance(prompt_checkpoints, int): - prompt_checkpoints = [prompt_checkpoints or -1] * len(prompts) - if caches is None: caches = [None] * len(prompts) for i in range(len(prompts)): @@ -1021,10 +1008,10 @@ def insert( samplers = samplers or [None] * len(prompts) logits_processors = logits_processors or [self.logits_processors] * len(prompts) - for p, m, c, s, lp, pc in zip( - prompts, max_tokens, caches, samplers, logits_processors, prompt_checkpoints + for p, m, c, s, lp in zip( + prompts, max_tokens, caches, samplers, logits_processors ): - self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp, pc)) + self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp)) uids.append(self.uid_count) self.uid_count += 1 # Sort in ascending order of length @@ -1065,28 +1052,12 @@ def prompt_cache_nbytes(self): return total def _process_prompts(self, prompts): - ( - uids, - inputs, - max_tokens, - caches, - samplers, - logits_processors, - prompt_checkpoints, - ) = zip(*prompts) + uids, inputs, max_tokens, caches, samplers, logits_processors = zip(*prompts) lengths = [len(p) for p in inputs] max_length = max(lengths) padding = [max_length - l for l in lengths] - # Get the checkpoint token as an offset from the end of each prompt. - # Then select the largest one so that we perform the checkpoint at - # least `pc` before the end. - prompt_checkpoints = [ - (l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints) - ] - prompt_checkpoint = max(1, max(prompt_checkpoints)) - self._stats.prompt_tokens += sum(lengths) tokens = [mx.array(inp) for inp in inputs] @@ -1099,10 +1070,8 @@ def _process_prompts(self, prompts): inputs = _left_pad_prompts(inputs, max_length=max_length) prompt_cache = _make_cache(self.model, padding, self.max_kv_size) - while inputs.shape[1] > prompt_checkpoint: - n_to_process = min( - self.prefill_step_size, inputs.shape[1] - prompt_checkpoint - ) + while inputs.shape[1] > 1: + n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1) self.model(inputs[:, :n_to_process], cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) inputs = inputs[:, n_to_process:] @@ -1121,22 +1090,16 @@ def _process_prompts(self, prompts): # 2. Process # 3. 
Finalize the KV caches so they are left padded again else: - last_inputs = mx.array([p[-prompt_checkpoint:] for p in inputs]) + last_inputs = mx.array([p[-1:] for p in inputs]) inputs = _right_pad_prompts(inputs, max_length=max_length) prompt_cache = _merge_caches(caches) for c in prompt_cache: - # subtract from lengths since we don't process the last - # `prompt_checkpoint` tokens during prefill - c.prepare( - lengths=[l - prompt_checkpoint for l in lengths], - right_padding=padding, - ) + # subtract one from lengths since we don't process the last token during prefill + c.prepare(lengths=[l - 1 for l in lengths], right_padding=padding) - while inputs.shape[1] > prompt_checkpoint: - n_to_process = min( - self.prefill_step_size, inputs.shape[1] - prompt_checkpoint - ) + while inputs.shape[1] > 1: + n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1) self.model(inputs[:, :n_to_process], cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) inputs = inputs[:, n_to_process:] @@ -1154,20 +1117,6 @@ def _process_prompts(self, prompts): for c in prompt_cache: c.finalize() - - # We processed L - prompt_checkpoint tokens so call the checkpoint - # callback. - if self.prompt_checkpoint_callback is not None: - self.prompt_checkpoint_callback( - [ - (uid, prompt_checkpoint, _lazy_extract_cache(prompt_cache, i)) - for i, uid in enumerate(uids) - ] - ) - # Process the remaining prompt_checkpoint-1 tokens - if prompt_checkpoint > 1: - self.model(inputs[:, : prompt_checkpoint - 1], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) mx.clear_cache() y, logprobs = self._step( diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 88fa4ad32..d8851ffcb 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -1124,10 +1124,6 @@ def _update_concat(self, keys, values): self.offset += keys.shape[2] self._offset += keys.shape[2] self._idx = self.keys.shape[2] - - # Make sure left_padding and offset are evaluated - self.keys = mx.depends(self.keys, (self.left_padding, self.offset)) - return self.keys, self.values def _update_in_place(self, keys, values): @@ -1178,9 +1174,6 @@ def _update_in_place(self, keys, values): self.offset += S self._idx += S - # Make sure left_padding and offset are evaluated - self.keys = mx.depends(self.keys, (self.left_padding, self.offset)) - # If the buffer is not full, slice off the end if self._offset < self.max_size: return ( From 3f6ab3ec78889a9ce69e49f738f8d39f3aa96352 Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Sun, 1 Mar 2026 12:46:49 +0100 Subject: [PATCH 03/11] docs: add mlx_lm.server memory control options and examples --- README.md | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/README.md b/README.md index ce71596b3..3415bed63 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,46 @@ requests that use the same context. See the [example](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/examples/chat.py) for more usage details. +### Server Memory Controls + +When using `mlx_lm.server`, these options help prevent OOM during long +multi-turn sessions: + +- `--prompt-cache-bytes`: upper bound for the LRU prompt cache memory. +- `--max-active-kv-bytes`: reject requests if projected active KV usage would + exceed this limit. +- `--max-kv-size`: fixed active KV window (rotating cache), llama.cpp-style. +- `--kv-bits`, `--kv-group-size`, `--quantized-kv-start`: KV cache + quantization controls. 
+ +Examples: + +```bash +# Fixed active-KV window (stable bounded memory, no KV quantization) +mlx_lm.server \ + --model \ + --max-kv-size 8192 \ + --prompt-cache-bytes 2G \ + --max-active-kv-bytes 8G +``` + +```bash +# KV quantization mode (batching disabled while kv-bits is enabled) +mlx_lm.server \ + --model \ + --kv-bits 4 \ + --kv-group-size 64 \ + --quantized-kv-start 0 \ + --prompt-cache-bytes 3G \ + --max-active-kv-bytes 8G +``` + +Notes: + +- `--max-kv-size` and `--kv-bits` are currently mutually exclusive. +- OOM-style failures now return HTTP `503` instead of crashing the server + process. + ### Supported Models `mlx-lm` supports thousands of LLMs available on the Hugging Face Hub. If the From dde9cebca6b2be1f197323373f6abef940fd7bd1 Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Sun, 1 Mar 2026 13:34:40 +0100 Subject: [PATCH 04/11] cache: import tree_reduce used by quantized cache nbytes --- mlx_lm/models/cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index d8851ffcb..2e77e9954 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_flatten, tree_map, tree_unflatten +from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten from .base import create_causal_mask From aa18f56360e0fbf1be0ac236303dfeca5dd9fd2a Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Sun, 1 Mar 2026 13:44:13 +0100 Subject: [PATCH 05/11] cache: make quantized nbytes robust for empty state --- mlx_lm/models/cache.py | 8 +++++++- tests/test_server.py | 6 +++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 2e77e9954..368761227 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -317,7 +317,13 @@ def empty(self): @property def nbytes(self): - return tree_reduce(lambda a, x: a + x.nbytes, (self.keys, self.values), 0) + if self.keys is None: + return 0 + return tree_reduce( + lambda a, x: a + (x.nbytes if x is not None else 0), + (self.keys, self.values), + 0, + ) class KVCache(_BaseCache): diff --git a/tests/test_server.py b/tests/test_server.py index 34fdf703e..f691f8051 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -11,7 +11,7 @@ import mlx.core as mx import requests -from mlx_lm.models.cache import KVCache +from mlx_lm.models.cache import KVCache, QuantizedKVCache from mlx_lm.server import ( APIHandler, LRUPromptCache, @@ -611,6 +611,10 @@ def test_projected_kv_bytes_with_growth(self): # 6 bytes over 3 tokens => 2 bytes/token self.assertEqual(projected_kv_bytes(cache, 5), 16) + def test_quantized_cache_nbytes_empty(self): + cache = QuantizedKVCache() + self.assertEqual(cache.nbytes, 0) + class TestCLIValidation(unittest.TestCase): def test_reject_bad_prompt_overflow_policy(self): From 041d649db986b9808e6930c00e0bf50ba8b42f0b Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Sun, 1 Mar 2026 14:09:34 +0100 Subject: [PATCH 06/11] server: enforce active memory ceiling during prefill --- README.md | 40 ---------------------------------------- tests/test_server.py | 6 +----- 2 files changed, 1 insertion(+), 45 deletions(-) diff --git a/README.md b/README.md index 3415bed63..ce71596b3 100644 --- a/README.md +++ b/README.md @@ -234,46 +234,6 @@ requests that use the same context. See the [example](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/examples/chat.py) for more usage details. 
-### Server Memory Controls - -When using `mlx_lm.server`, these options help prevent OOM during long -multi-turn sessions: - -- `--prompt-cache-bytes`: upper bound for the LRU prompt cache memory. -- `--max-active-kv-bytes`: reject requests if projected active KV usage would - exceed this limit. -- `--max-kv-size`: fixed active KV window (rotating cache), llama.cpp-style. -- `--kv-bits`, `--kv-group-size`, `--quantized-kv-start`: KV cache - quantization controls. - -Examples: - -```bash -# Fixed active-KV window (stable bounded memory, no KV quantization) -mlx_lm.server \ - --model \ - --max-kv-size 8192 \ - --prompt-cache-bytes 2G \ - --max-active-kv-bytes 8G -``` - -```bash -# KV quantization mode (batching disabled while kv-bits is enabled) -mlx_lm.server \ - --model \ - --kv-bits 4 \ - --kv-group-size 64 \ - --quantized-kv-start 0 \ - --prompt-cache-bytes 3G \ - --max-active-kv-bytes 8G -``` - -Notes: - -- `--max-kv-size` and `--kv-bits` are currently mutually exclusive. -- OOM-style failures now return HTTP `503` instead of crashing the server - process. - ### Supported Models `mlx-lm` supports thousands of LLMs available on the Hugging Face Hub. If the diff --git a/tests/test_server.py b/tests/test_server.py index f691f8051..34fdf703e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -11,7 +11,7 @@ import mlx.core as mx import requests -from mlx_lm.models.cache import KVCache, QuantizedKVCache +from mlx_lm.models.cache import KVCache from mlx_lm.server import ( APIHandler, LRUPromptCache, @@ -611,10 +611,6 @@ def test_projected_kv_bytes_with_growth(self): # 6 bytes over 3 tokens => 2 bytes/token self.assertEqual(projected_kv_bytes(cache, 5), 16) - def test_quantized_cache_nbytes_empty(self): - cache = QuantizedKVCache() - self.assertEqual(cache.nbytes, 0) - class TestCLIValidation(unittest.TestCase): def test_reject_bad_prompt_overflow_policy(self): From 195d007ba49b0755448ca0cd9ad6876f48a856d8 Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Sun, 1 Mar 2026 19:36:38 +0100 Subject: [PATCH 07/11] server: drop KV quantization scope from memory-safety PR --- mlx_lm/models/cache.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 368761227..d8851ffcb 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten +from mlx.utils import tree_flatten, tree_map, tree_unflatten from .base import create_causal_mask @@ -317,13 +317,7 @@ def empty(self): @property def nbytes(self): - if self.keys is None: - return 0 - return tree_reduce( - lambda a, x: a + (x.nbytes if x is not None else 0), - (self.keys, self.values), - 0, - ) + return tree_reduce(lambda a, x: a + x.nbytes, (self.keys, self.values), 0) class KVCache(_BaseCache): From 0482ae434e0dff1665423a6a1f54ddb07e3dd684 Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Wed, 4 Mar 2026 00:17:13 +0100 Subject: [PATCH 08/11] docs: final polish --- mlx_lm/SERVER.md | 55 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/mlx_lm/SERVER.md b/mlx_lm/SERVER.md index f38ad3dd4..353dab17c 100644 --- a/mlx_lm/SERVER.md +++ b/mlx_lm/SERVER.md @@ -153,3 +153,58 @@ list contains the following fields: - `id`: The Hugging Face repo id. - `created`: A time-stamp representing the model creation time. 
+ +### Server Memory Controls + +When using `mlx_lm.server`, these options help prevent OOM during long +multi-turn sessions: + +- `--prompt-cache-bytes`: upper limit for the LRU prompt cache memory. +- `--max-prompt-tokens`: max prompt token cap to avoid unbounded memory growth. +- `--prompt-overflow-policy`: `error` (reject) or `truncate` (drop tokens from + the beginning/middle of the prompt). +- `--prompt-keep-tokens`: with `truncate`, keep this many tokens from the beginning + of the prompt. +- `--max-active-kv-bytes`: reject requests if projected active KV usage would + exceed this limit. +- `--max-active-memory-bytes`: abort requests when current MLX active memory is + above this limit. +- `--max-kv-size`: fixed active KV window (rotating cache). This limits per-request + KV growth but can effectively reduce context window. + +Examples: + +```bash +# Fixed active-KV window (stable bounded memory) +mlx_lm.server \ + --model \ + --max-prompt-tokens 8192 \ + --prompt-overflow-policy error \ + --max-kv-size 8192 \ + --prompt-cache-bytes 2G \ + --max-active-kv-bytes 8G \ + --max-active-memory-bytes 28G +``` + +Notes: + +- `--max-prompt-tokens` is the primary control to stop memory creep across long chats. +- `--max-active-kv-bytes`, `--max-active-memory-bytes`, and `--max-kv-size` + work at different levels: + - `--max-active-kv-bytes`: projected KV-only admission control. + - `--max-active-memory-bytes`: runtime limit for all active MLX memory. + - `--max-kv-size`: hard limit on attention window in KV cache (potential quality degradation). +- Practical tuning order: + 1. Set `--max-prompt-tokens` first (for example `8192`). + 2. Set `--max-active-memory-bytes` below total RAM by ~20% to leave room for OS and other apps. + 3. Set `--max-active-kv-bytes` as a subset of that budget (~20-40% of + `--max-active-memory-bytes`). + 4. Only add `--max-kv-size` if memory still growth; start high (for example `8192` + or `16384`) and lower only if required. +- Extensively tested example for GLM-4.7.Flash-5bit running on a 36 GB machine: + - `--max-active-memory-bytes 27G` + - `--max-active-kv-bytes 6G` + - `--max-prompt-tokens 8192` +- OOM-style failures now return HTTP `503` instead of crashing the server + process. + From 7aeabdb631c81f9301001a563b8417c126ff06a4 Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Mon, 23 Mar 2026 13:09:30 +0100 Subject: [PATCH 09/11] pr948: drop unintended upstream file reverts --- mlx_lm/SERVER.md | 55 ------------------------------- mlx_lm/generate.py | 73 +++++++++++++++++++++++++++++++++++------- mlx_lm/models/cache.py | 7 ++++ 3 files changed, 69 insertions(+), 66 deletions(-) diff --git a/mlx_lm/SERVER.md b/mlx_lm/SERVER.md index 353dab17c..f38ad3dd4 100644 --- a/mlx_lm/SERVER.md +++ b/mlx_lm/SERVER.md @@ -153,58 +153,3 @@ list contains the following fields: - `id`: The Hugging Face repo id. - `created`: A time-stamp representing the model creation time. - -### Server Memory Controls - -When using `mlx_lm.server`, these options help prevent OOM during long -multi-turn sessions: - -- `--prompt-cache-bytes`: upper limit for the LRU prompt cache memory. -- `--max-prompt-tokens`: max prompt token cap to avoid unbounded memory growth. -- `--prompt-overflow-policy`: `error` (reject) or `truncate` (drop tokens from - the beginning/middle of the prompt). -- `--prompt-keep-tokens`: with `truncate`, keep this many tokens from the beginning - of the prompt. -- `--max-active-kv-bytes`: reject requests if projected active KV usage would - exceed this limit. 
-- `--max-active-memory-bytes`: abort requests when current MLX active memory is - above this limit. -- `--max-kv-size`: fixed active KV window (rotating cache). This limits per-request - KV growth but can effectively reduce context window. - -Examples: - -```bash -# Fixed active-KV window (stable bounded memory) -mlx_lm.server \ - --model \ - --max-prompt-tokens 8192 \ - --prompt-overflow-policy error \ - --max-kv-size 8192 \ - --prompt-cache-bytes 2G \ - --max-active-kv-bytes 8G \ - --max-active-memory-bytes 28G -``` - -Notes: - -- `--max-prompt-tokens` is the primary control to stop memory creep across long chats. -- `--max-active-kv-bytes`, `--max-active-memory-bytes`, and `--max-kv-size` - work at different levels: - - `--max-active-kv-bytes`: projected KV-only admission control. - - `--max-active-memory-bytes`: runtime limit for all active MLX memory. - - `--max-kv-size`: hard limit on attention window in KV cache (potential quality degradation). -- Practical tuning order: - 1. Set `--max-prompt-tokens` first (for example `8192`). - 2. Set `--max-active-memory-bytes` below total RAM by ~20% to leave room for OS and other apps. - 3. Set `--max-active-kv-bytes` as a subset of that budget (~20-40% of - `--max-active-memory-bytes`). - 4. Only add `--max-kv-size` if memory still growth; start high (for example `8192` - or `16384`) and lower only if required. -- Extensively tested example for GLM-4.7.Flash-5bit running on a 36 GB machine: - - `--max-active-memory-bytes 27G` - - `--max-active-kv-bytes 6G` - - `--max-prompt-tokens 8192` -- OOM-style failures now return HTTP `503` instead of crashing the server - process. - diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index ab0bdb83e..ef8dbf7bf 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -927,6 +927,11 @@ def _merge_caches(caches): return batch_cache +def _lazy_extract_cache(cache, i): + # Generators like lambdas are late bound so we can't just use it in the loop + return (c.extract(i) for c in cache) + + class BatchGenerator: @dataclass class Response: @@ -948,6 +953,9 @@ def __init__( completion_batch_size: int = 32, prefill_batch_size: int = 8, prefill_step_size: int = 2048, + prompt_checkpoint_callback: Optional[ + Callable[[List[Tuple[int, int, List[Any]]]], None] + ] = None, prompt_progress_callback: Optional[ Callable[[List[Tuple[int, int, int]]], None] ] = None, @@ -963,6 +971,7 @@ def __init__( self.prefill_step_size = prefill_step_size self.prefill_batch_size = prefill_batch_size self.completion_batch_size = max(completion_batch_size, prefill_batch_size) + self.prompt_checkpoint_callback = prompt_checkpoint_callback self.prompt_progress_callback = prompt_progress_callback or (lambda *_: None) self._stats = BatchStats() self._next_count = 0 @@ -993,12 +1002,16 @@ def insert( caches=None, samplers: list | None = None, logits_processors: list | None = None, + prompt_checkpoints: list | int | None = None, ): uids = [] if max_tokens is None or isinstance(max_tokens, int): max_tokens = [max_tokens or self.max_tokens] * len(prompts) + if prompt_checkpoints is None or isinstance(prompt_checkpoints, int): + prompt_checkpoints = [prompt_checkpoints or -1] * len(prompts) + if caches is None: caches = [None] * len(prompts) for i in range(len(prompts)): @@ -1008,10 +1021,10 @@ def insert( samplers = samplers or [None] * len(prompts) logits_processors = logits_processors or [self.logits_processors] * len(prompts) - for p, m, c, s, lp in zip( - prompts, max_tokens, caches, samplers, logits_processors + for p, m, c, s, lp, 
pc in zip( + prompts, max_tokens, caches, samplers, logits_processors, prompt_checkpoints ): - self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp)) + self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp, pc)) uids.append(self.uid_count) self.uid_count += 1 # Sort in ascending order of length @@ -1052,12 +1065,28 @@ def prompt_cache_nbytes(self): return total def _process_prompts(self, prompts): - uids, inputs, max_tokens, caches, samplers, logits_processors = zip(*prompts) + ( + uids, + inputs, + max_tokens, + caches, + samplers, + logits_processors, + prompt_checkpoints, + ) = zip(*prompts) lengths = [len(p) for p in inputs] max_length = max(lengths) padding = [max_length - l for l in lengths] + # Get the checkpoint token as an offset from the end of each prompt. + # Then select the largest one so that we perform the checkpoint at + # least `pc` before the end. + prompt_checkpoints = [ + (l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints) + ] + prompt_checkpoint = max(1, max(prompt_checkpoints)) + self._stats.prompt_tokens += sum(lengths) tokens = [mx.array(inp) for inp in inputs] @@ -1070,8 +1099,10 @@ def _process_prompts(self, prompts): inputs = _left_pad_prompts(inputs, max_length=max_length) prompt_cache = _make_cache(self.model, padding, self.max_kv_size) - while inputs.shape[1] > 1: - n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1) + while inputs.shape[1] > prompt_checkpoint: + n_to_process = min( + self.prefill_step_size, inputs.shape[1] - prompt_checkpoint + ) self.model(inputs[:, :n_to_process], cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) inputs = inputs[:, n_to_process:] @@ -1090,16 +1121,22 @@ def _process_prompts(self, prompts): # 2. Process # 3. Finalize the KV caches so they are left padded again else: - last_inputs = mx.array([p[-1:] for p in inputs]) + last_inputs = mx.array([p[-prompt_checkpoint:] for p in inputs]) inputs = _right_pad_prompts(inputs, max_length=max_length) prompt_cache = _merge_caches(caches) for c in prompt_cache: - # subtract one from lengths since we don't process the last token during prefill - c.prepare(lengths=[l - 1 for l in lengths], right_padding=padding) + # subtract from lengths since we don't process the last + # `prompt_checkpoint` tokens during prefill + c.prepare( + lengths=[l - prompt_checkpoint for l in lengths], + right_padding=padding, + ) - while inputs.shape[1] > 1: - n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1) + while inputs.shape[1] > prompt_checkpoint: + n_to_process = min( + self.prefill_step_size, inputs.shape[1] - prompt_checkpoint + ) self.model(inputs[:, :n_to_process], cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) inputs = inputs[:, n_to_process:] @@ -1117,6 +1154,20 @@ def _process_prompts(self, prompts): for c in prompt_cache: c.finalize() + + # We processed L - prompt_checkpoint tokens so call the checkpoint + # callback. 
+ if self.prompt_checkpoint_callback is not None: + self.prompt_checkpoint_callback( + [ + (uid, prompt_checkpoint, _lazy_extract_cache(prompt_cache, i)) + for i, uid in enumerate(uids) + ] + ) + # Process the remaining prompt_checkpoint-1 tokens + if prompt_checkpoint > 1: + self.model(inputs[:, : prompt_checkpoint - 1], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) mx.clear_cache() y, logprobs = self._step( diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index d8851ffcb..88fa4ad32 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -1124,6 +1124,10 @@ def _update_concat(self, keys, values): self.offset += keys.shape[2] self._offset += keys.shape[2] self._idx = self.keys.shape[2] + + # Make sure left_padding and offset are evaluated + self.keys = mx.depends(self.keys, (self.left_padding, self.offset)) + return self.keys, self.values def _update_in_place(self, keys, values): @@ -1174,6 +1178,9 @@ def _update_in_place(self, keys, values): self.offset += S self._idx += S + # Make sure left_padding and offset are evaluated + self.keys = mx.depends(self.keys, (self.left_padding, self.offset)) + # If the buffer is not full, slice off the end if self._offset < self.max_size: return ( From b095103d579087f2ebcf6276e937ef879e670442 Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Mon, 23 Mar 2026 13:36:19 +0100 Subject: [PATCH 10/11] Resolve upstream merge without reverting server updates --- mlx_lm/server.py | 366 ++++++++++++++++++++++++++----------------- tests/test_server.py | 147 ++++++++++------- 2 files changed, 312 insertions(+), 201 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index c63d54927..564d23575 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -2,6 +2,7 @@ import argparse import copy +import heapq import json import logging import pickle @@ -41,7 +42,7 @@ trim_prompt_cache, ) from .sample_utils import make_logits_processors, make_sampler -from .utils import load, sharded_load +from .utils import _parse_size, load, sharded_load def get_system_fingerprint(): @@ -49,18 +50,6 @@ def get_system_fingerprint(): return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" -def parse_size(x): - sizes = {"M": 1e6, "G": 1e9, "MB": 1e6, "GB": 1e9, "": 1} - split = 0 - for xi in x: - if not (xi.isdigit() or xi == "."): - break - split += 1 - digits = float(x[:split]) - size = (x[split:]).strip().upper() - return int(digits * sizes[size]) - - def is_metal_oom_error(error: Exception) -> bool: text = str(error).lower() patterns = ( @@ -103,6 +92,9 @@ def apply_prompt_token_limit( f"prompt_tokens={len(tokens)}, max_prompt_tokens={max_prompt_tokens}" ) + if overflow_policy != "truncate": + raise ValueError(f"Invalid prompt overflow policy: {overflow_policy}") + keep_tokens = max(0, min(keep_tokens, max_prompt_tokens)) tail_tokens = max_prompt_tokens - keep_tokens if tail_tokens <= 0: @@ -232,13 +224,35 @@ def process_message_content(messages): class LRUPromptCache: - @dataclass class CacheEntry: prompt_cache: List[Any] - count: int nbytes: int + class CacheOrder: + def __init__(self): + self._lru_checkpoints = deque() + self._lru = deque() + + def __len__(self): + return len(self._lru) + len(self._lru_checkpoints) + + def push(self, model, tokens, checkpoint: bool = False): + c = self._lru_checkpoints if checkpoint else self._lru + c.append((model, tokens)) + + def remove(self, model, tokens): + try: + self._lru.remove((model, tokens)) + except ValueError: + self._lru_checkpoints.remove((model, tokens)) + + def 
pop(self): + if len(self._lru) >= len(self._lru_checkpoints): + return self._lru.popleft() + else: + return self._lru_checkpoints.popleft() + @dataclass class SearchResult: model: Any @@ -251,7 +265,7 @@ def __init__(self, max_size: int = 10, max_bytes: int = 1 << 63): self.max_size = max_size self.max_bytes = max_bytes self._cache = {} - self._lru = deque() + self._lru = self.CacheOrder() self._n_bytes = 0 def __len__(self): @@ -288,7 +302,7 @@ def _search(self, model, tokens): # Check for caches that are longer longer = None common_prefix = index - if index > 0 and last_cache_index <= 0: + if index > 0: best = None stack = [(current, [])] while stack: @@ -321,32 +335,14 @@ def _delete(self, model, tokens): break del d_prev[t] - logging.debug(f"[LRUPromptCache] Removed {cache_bytes} bytes from the cache") - - def _extract(self, model, tokens): - cache_entry = self._get(model, tokens) - if cache_entry.count == 1: - self._delete(model, tokens) - self._lru.remove((model, tokens)) - return cache_entry - - cache_entry.count -= 1 - return self.CacheEntry( - copy.deepcopy(cache_entry.prompt_cache), 1, cache_entry.nbytes - ) - def fetch_nearest_cache(self, model, tokens): result = self._search(model, tokens) if result.exact is not None: - cache_entry = self._extract(result.model, result.exact) - return cache_entry.prompt_cache, [] - - if result.shorter is not None: - cache_entry = self._extract(result.model, result.shorter) - prefix_len = len(result.shorter) - return cache_entry.prompt_cache, tokens[prefix_len:] + cache_entry = self._get(result.model, result.exact) + return copy.deepcopy(cache_entry.prompt_cache), [] - if result.longer is not None: + short_length = len(result.shorter) if result.shorter is not None else 0 + if result.longer is not None and result.common_prefix > short_length: cache_entry = self._get(result.model, result.longer) if can_trim_prompt_cache(cache_entry.prompt_cache): cache = copy.deepcopy(cache_entry.prompt_cache) @@ -355,32 +351,40 @@ def fetch_nearest_cache(self, model, tokens): trim_prompt_cache(cache, num_to_trim) return cache, tokens[prefix:] + if short_length > 0: + cache_entry = self._get(result.model, result.shorter) + return copy.deepcopy(cache_entry.prompt_cache), tokens[short_length:] + return None, tokens - def insert_cache(self, model, tokens, prompt_cache): + def insert_cache(self, model, tokens, prompt_cache, checkpoint: bool = False): + is_trimmable = can_trim_prompt_cache(prompt_cache) + if model not in self._cache: self._cache[model] = {} current = self._cache[model] - for tok in tokens: + for i, tok in enumerate(tokens): if tok not in current: current[tok] = {} + if is_trimmable and "cache" in current: + self._n_bytes -= current["cache"].nbytes + del current["cache"] + self._lru.remove(model, tokens[:i]) current = current[tok] if "cache" in current: - current["cache"].count += 1 - self._lru.remove((model, tokens)) + self._lru.remove(model, tokens) else: cache_bytes = sum(c.nbytes for c in prompt_cache) - current["cache"] = self.CacheEntry(prompt_cache, 1, cache_bytes) + current["cache"] = self.CacheEntry(prompt_cache, cache_bytes) self._n_bytes += cache_bytes - logging.debug(f"[LRUPromptCache] Adding {cache_bytes} to the cache") - self._lru.append((model, tokens)) + self._lru.push(model, tokens, checkpoint=checkpoint) if len(self._lru) > self.max_size: - model, tokens = self._lru.popleft() + model, tokens = self._lru.pop() self._delete(model, tokens) while self._n_bytes > self.max_bytes and len(self._lru) > 1: - model, tokens = self._lru.popleft() + 
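For reference, the eviction rule encoded by `CacheOrder` above can be exercised on its own: regular and checkpointed entries live in separate deques, and `pop()` evicts from whichever side is larger, preferring regular entries on ties so checkpoints survive longer. A minimal standalone sketch with simplified keys, not the server's real cache entries:

```python
from collections import deque

class TinyCacheOrder:
    def __init__(self):
        self.regular = deque()
        self.checkpoints = deque()

    def push(self, key, checkpoint=False):
        (self.checkpoints if checkpoint else self.regular).append(key)

    def pop(self):
        # Evict from the larger side; ties go to the regular queue.
        if len(self.regular) >= len(self.checkpoints):
            return self.regular.popleft()
        return self.checkpoints.popleft()

order = TinyCacheOrder()
order.push("chat-1")
order.push("chat-1-think-prefix", checkpoint=True)
order.push("chat-2")
print(order.pop())  # chat-1 -- the checkpointed prefix is kept
```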
model, tokens = self._lru.pop() self._delete(model, tokens) def trim_to( @@ -390,12 +394,23 @@ def trim_to( n_bytes = max(0, n_bytes) if n_bytes is not None else 1 << 63 while len(self._lru) > n_sequences: - model, tokens = self._lru.popleft() + model, tokens = self._lru.pop() self._delete(model, tokens) while self._n_bytes > n_bytes: - model, tokens = self._lru.popleft() + model, tokens = self._lru.pop() self._delete(model, tokens) + def log_cache_stats(self): + ncaches, nbytes = len(self), self.nbytes + ntok = ( + len(self._lru._lru_checkpoints[-1][1]) + if len(self._lru._lru_checkpoints) > 0 + else 0 + ) + logging.info( + f"KV Caches: {ncaches} seq, {nbytes / 1e9:.2f} GB, latest user cache {ntok} tokens" + ) + @dataclass class ModelDescription: @@ -419,6 +434,10 @@ class LogitsProcessorArguments: logit_bias: Optional[Dict[int, float]] repetition_penalty: float repetition_context_size: int + presence_penalty: float + presence_context_size: int + frequency_penalty: float + frequency_context_size: int @dataclass @@ -647,6 +666,10 @@ def _make_logits_processors(args): args.logits.logit_bias, args.logits.repetition_penalty, args.logits.repetition_context_size, + args.logits.presence_penalty, + args.logits.presence_context_size, + args.logits.frequency_penalty, + args.logits.frequency_context_size, ) @@ -781,6 +804,24 @@ def _tokenize(self, tokenizer, request, args): keep_tokens=self.cli_args.prompt_keep_tokens, ) + def _compute_prompt_checkpoint(self, tokenizer, request, prompt): + if request.request_type != "chat": + return False, -1 + if request.messages[-1]["role"] != "user": + return False, -1 + + # Save the KV cache at the end of the prompt just before + # the think start token which will likely be removed in the + # next turn. + prompt_checkpoint = -1 + if tokenizer.has_thinking: + for i in range(1, min(11, len(prompt)) - 1, 1): + if prompt[-i] == tokenizer.think_start_id: + prompt_checkpoint = -i - 1 + break + + return True, prompt_checkpoint + def _is_batchable(self, args): if not self.model_provider.is_batchable: return False @@ -850,6 +891,18 @@ def progress_callback(info): if uid in batch_results: batch_results[uid]["rqueue"].put((min(processed, total), total)) + def checkpoint_callback(prompts): + for uid, prompt_end, cache in prompts: + rs = batch_results[uid] + if not rs["checkpoint"]: + continue + self.prompt_cache.insert_cache( + current_model_key, + rs["cache_key"][:-prompt_end], + list(cache), + checkpoint=True, + ) + if self._is_distributed: seed = mx.distributed.all_sum(mx.random.state[0]).view(mx.uint64).item() mx.random.seed(seed) @@ -898,25 +951,18 @@ def progress_callback(info): ) rqueue.put(ctx) + self.prompt_cache.log_cache_stats() cache, rest = self.prompt_cache.fetch_nearest_cache( current_model_key, prompt ) ctx.prompt_cache_count = len(prompt) - len(rest) if cache is None: - cache = self._make_prompt_cache(self.model_provider.model) - - admission_error = self._memory_admission_error( - cache, - len(rest) + args.max_tokens, - batch_generator.prompt_cache_nbytes, - ) - if admission_error is not None: - rqueue.put(MemoryError(admission_error)) - continue + cache = self._make_prompt_cache( + self.model_provider.model, self.model_provider.draft_model + ) - ncaches, nbytes = len(self.prompt_cache), self.prompt_cache.nbytes - logging.info( - f"We have {ncaches} kv caches that take {nbytes/1e9:.2f} GB" + do_checkpoint, checkpoint_position = ( + self._compute_prompt_checkpoint(tokenizer, request, prompt) ) (uid,) = batch_generator.insert( @@ -925,12 +971,14 @@ def 
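`_compute_prompt_checkpoint` above returns a negative offset so the checkpointed prefix stops before the think-start token, which the next turn typically drops. A hedged standalone sketch of that scan follows; `THINK_START` is an arbitrary placeholder id, whereas the real code uses `tokenizer.think_start_id` and only runs when `tokenizer.has_thinking` is set:

```python
THINK_START = 151667  # placeholder id for illustration only

def checkpoint_offset(prompt: list[int]) -> int:
    # Look for the think-start token among the last few prompt tokens and
    # checkpoint just before it; -1 means "no special checkpoint".
    for i in range(1, min(11, len(prompt)) - 1):
        if prompt[-i] == THINK_START:
            return -i - 1
    return -1

prompt = [10, 11, 12, THINK_START, 13]
print(checkpoint_offset(prompt))  # -3: the cached prefix ends before the think-start token
```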
progress_callback(info): caches=[cache], samplers=[_make_sampler(args, tokenizer)], logits_processors=[_make_logits_processors(args)], + prompt_checkpoints=[checkpoint_position], ) batch_results[uid] = { "ctx": ctx, "cache_key": prompt[:], "rqueue": rqueue, "detokenizer": tokenizer.detokenizer, + "checkpoint": do_checkpoint, } # just making sure we don't leave a reference around del cache @@ -968,6 +1016,7 @@ def progress_callback(info): prefill_batch_size=self.cli_args.prompt_concurrency, prefill_step_size=self.cli_args.prefill_step_size, prompt_progress_callback=progress_callback, + prompt_checkpoint_callback=checkpoint_callback, max_kv_size=self.cli_args.max_kv_size, ) unprocessed_requests.append((rqueue, request, args)) @@ -994,71 +1043,53 @@ def progress_callback(info): continue uids_to_remove = [] - try: - for _ in self._time_budget: - responses = batch_generator.next() - if not responses: - break - - for r in responses: - result = batch_results[r.uid] - result["cache_key"].append(r.token) - if r.finish_reason != "stop": - result["detokenizer"].add_token(r.token) - - result["rqueue"].put( - Response( - result["detokenizer"].last_segment, - r.token, - r.logprobs[r.token].item(), - r.finish_reason, - _format_top_logprobs( - r.logprobs, args.top_logprobs, current_tokenizer - ), - ) + for _ in self._time_budget: + responses = batch_generator.next() + if not responses: + break + + for r in responses: + result = batch_results[r.uid] + result["cache_key"].append(r.token) + if r.finish_reason != "stop": + result["detokenizer"].add_token(r.token) + + result["rqueue"].put( + Response( + result["detokenizer"].last_segment, + r.token, + r.logprobs[r.token].item(), + r.finish_reason, + _format_top_logprobs( + r.logprobs, args.top_logprobs, current_tokenizer + ), + ) + ) + + if r.finish_reason is not None: + result["rqueue"].put(None) + self.prompt_cache.insert_cache( + current_model_key, result["cache_key"], r.prompt_cache ) + del batch_results[r.uid] + + if result["ctx"]._should_stop: + uids_to_remove.append(r.uid) - if r.finish_reason is not None: - result["rqueue"].put(None) - self.prompt_cache.insert_cache( - current_model_key, - result["cache_key"], - r.prompt_cache, - ) - del batch_results[r.uid] - - if result["ctx"]._should_stop: - uids_to_remove.append(r.uid) - - uids_to_remove = self._share_object(uids_to_remove) - if uids_to_remove: - with mx.stream(generation_stream): - caches = batch_generator.remove( - uids_to_remove, return_prompt_caches=True + uids_to_remove = self._share_object(uids_to_remove) + if uids_to_remove: + with mx.stream(generation_stream): + caches = batch_generator.remove( + uids_to_remove, return_prompt_caches=True + ) + for uid, prompt_cache in caches.items(): + if uid not in batch_results: + continue + result = batch_results[uid] + self.prompt_cache.insert_cache( + current_model_key, result["cache_key"], prompt_cache ) - for uid, prompt_cache in caches.items(): - if uid not in batch_results: - continue - result = batch_results[uid] - self.prompt_cache.insert_cache( - current_model_key, result["cache_key"], prompt_cache - ) - del batch_results[uid] - except Exception as e: - logging.exception("Batched generation failed") - if is_metal_oom_error(e): - mx.clear_cache() - for result in batch_results.values(): - result["rqueue"].put(e) - result["rqueue"].put(None) - batch_results = {} - current_model = None - current_sampling = None - current_tokenizer = None - current_model_key = None - batch_generator.close() - batch_generator = None - drain_batch = False + del 
batch_results[uid] def _serve_single(self, request): rqueue, request, args = request @@ -1105,6 +1136,7 @@ def progress(tokens_processed, tokens_total): logits_processors = _make_logits_processors(args) # Load the KV cache + self.prompt_cache.log_cache_stats() cache, rest = self.prompt_cache.fetch_nearest_cache( self.model_provider.model_key, prompt ) @@ -1122,9 +1154,6 @@ def progress(tokens_processed, tokens_total): if admission_error is not None: raise MemoryError(admission_error) - ncaches, nbytes = len(self.prompt_cache), self.prompt_cache.nbytes - logging.info(f"We have {ncaches} kv caches that take {nbytes/1e9:.2f} GB") - # Process the prompt and generate tokens for gen in stream_generate( model=model, @@ -1137,8 +1166,8 @@ def progress(tokens_processed, tokens_total): draft_model=draft_model, num_draft_tokens=args.num_draft_tokens, prompt_progress_callback=progress, - max_kv_size=self.cli_args.max_kv_size, prefill_step_size=self.cli_args.prefill_step_size, + max_kv_size=self.cli_args.max_kv_size, ): rqueue.put( Response( @@ -1166,8 +1195,6 @@ def progress(tokens_processed, tokens_total): ) except Exception as e: - if is_metal_oom_error(e): - mx.clear_cache() rqueue.put(e) def generate( @@ -1220,7 +1247,13 @@ def __init__( super().__init__(*args, **kwargs) def _set_cors_headers(self): - self.send_header("Access-Control-Allow-Origin", "*") + allowed_origins = self.response_generator.cli_args.allowed_origins + origin = self.headers.get("Origin") + if "*" in allowed_origins: + self.send_header("Access-Control-Allow-Origin", "*") + elif origin in allowed_origins: + self.send_header("Access-Control-Allow-Origin", origin) + self.send_header("Vary", "Origin") self.send_header("Access-Control-Allow-Methods", "*") self.send_header("Access-Control-Allow-Headers", "*") @@ -1256,7 +1289,23 @@ def do_POST(self): return # Fetch and parse request body - content_length = int(self.headers["Content-Length"]) + content_length = self.headers.get("Content-Length") + if content_length is None: + self._set_completion_headers(411) + self.end_headers() + self.wfile.write( + json.dumps({"error": "Content-Length header is required"}).encode() + ) + return + try: + content_length = int(content_length) + except ValueError: + self._set_completion_headers(400) + self.end_headers() + self.wfile.write( + json.dumps({"error": "Invalid Content-Length header"}).encode() + ) + return raw_body = self.rfile.read(content_length) try: self.body = json.loads(raw_body.decode()) @@ -1297,6 +1346,10 @@ def do_POST(self): self.min_p = self.body.get("min_p", self.response_generator.cli_args.min_p) self.repetition_penalty = self.body.get("repetition_penalty", 0.0) self.repetition_context_size = self.body.get("repetition_context_size", 20) + self.presence_penalty = self.body.get("presence_penalty", 0.0) + self.presence_context_size = self.body.get("presence_context_size", 20) + self.frequency_penalty = self.body.get("frequency_penalty", 0.0) + self.frequency_context_size = self.body.get("frequency_context_size", 20) self.xtc_probability = self.body.get("xtc_probability", 0.0) self.xtc_threshold = self.body.get("xtc_threshold", 0.0) self.logit_bias = self.body.get("logit_bias", None) @@ -1345,6 +1398,25 @@ def validate_model_parameters(self): or self.repetition_penalty < 0 ): raise ValueError("repetition_penalty must be a non-negative float") + if ( + not isinstance(self.repetition_context_size, int) + or self.repetition_context_size < 0 + ): + raise ValueError("repetition_context_size must be a non-negative integer") + if not 
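On the client side, the new penalty parameters parsed above are plain JSON fields, and a request must now carry a valid `Content-Length` (missing yields 411, non-numeric yields 400). A hedged example, assuming the server's OpenAI-compatible chat endpoint on the default port and a placeholder model name:

```python
import json
from urllib import request

body = {
    "model": "default_model",  # placeholder
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 64,
    "presence_penalty": 0.5,
    "presence_context_size": 20,
    "frequency_penalty": 0.3,
    "frequency_context_size": 20,
}
req = request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(body).encode(),
    headers={"Content-Type": "application/json"},
)
# urllib sets Content-Length automatically for byte payloads, so this request
# passes the new 411/400 checks.
with request.urlopen(req) as resp:
    print(json.loads(resp.read()))
```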
isinstance(self.presence_penalty, (float, int)): + raise ValueError("Presence penalty must be must be a float") + if ( + not isinstance(self.presence_context_size, int) + or self.presence_context_size < 0 + ): + raise ValueError("presence_context_size must be a non-negative integer") + if not isinstance(self.frequency_penalty, (float, int)): + raise ValueError("Presence penalty must be must be a float") + if ( + not isinstance(self.frequency_context_size, int) + or self.frequency_context_size < 0 + ): + raise ValueError("frequency_context_size must be a non-negative integer") if not isinstance(self.logprobs, bool): raise ValueError("logprobs must be a boolean") @@ -1354,12 +1426,6 @@ def validate_model_parameters(self): f"top_logprobs must be between 1 and 10 but got {self.top_logprobs:,}" ) - if ( - not isinstance(self.repetition_context_size, int) - or self.repetition_context_size < 0 - ): - raise ValueError("repetition_context_size must be a non-negative integer") - if self.logit_bias is not None: if not isinstance(self.logit_bias, dict): raise ValueError("logit_bias must be a dict of int to float") @@ -1519,6 +1585,10 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]): logit_bias=self.logit_bias, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, + presence_penalty=self.presence_penalty, + presence_context_size=self.presence_context_size, + frequency_penalty=self.frequency_penalty, + frequency_context_size=self.frequency_context_size, ), stop_words=stop_words, max_tokens=self.max_tokens, @@ -1936,11 +2006,7 @@ def run( handler_class=APIHandler, ): group = mx.distributed.init() - prompt_cache_bytes = model_provider.cli_args.prompt_cache_bytes - prompt_cache = LRUPromptCache( - model_provider.cli_args.prompt_cache_size, - prompt_cache_bytes if prompt_cache_bytes is not None else (1 << 63), - ) + prompt_cache = LRUPromptCache(model_provider.cli_args.prompt_cache_size) response_generator = ResponseGenerator(model_provider, prompt_cache) if group.rank() == 0: _run_http_server(host, port, response_generator) @@ -1972,6 +2038,12 @@ def main(): default=8080, help="Port for the HTTP server (default: 8080)", ) + parser.add_argument( + "--allowed-origins", + type=lambda x: x.split(","), + default="*", + help="Allowed origins (default: *)", + ) parser.add_argument( "--draft-model", type=str, @@ -2070,7 +2142,7 @@ def main(): ) parser.add_argument( "--prompt-cache-bytes", - type=parse_size, + type=_parse_size, help="Maximum size in bytes of the KV caches", ) parser.add_argument( @@ -2097,7 +2169,7 @@ def main(): ) parser.add_argument( "--max-active-kv-bytes", - type=parse_size, + type=_parse_size, help=( "Reject requests when projected active KV memory would exceed this limit " "(bytes or shorthand like 20G)" @@ -2105,7 +2177,7 @@ def main(): ) parser.add_argument( "--max-active-memory-bytes", - type=parse_size, + type=_parse_size, help=( "Abort requests when current active MLX memory exceeds this limit " "(bytes or shorthand like 30G)" diff --git a/tests/test_server.py b/tests/test_server.py index 34fdf703e..45c1e6cc2 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -56,6 +56,7 @@ def __init__(self, with_draft=False): "prompt_cache_size": 10, "prompt_cache_bytes": 1 << 63, "prompt_cache_total_bytes": None, + "allowed_origins": ["*"], "max_prompt_tokens": None, "prompt_overflow_policy": "error", "prompt_keep_tokens": 512, @@ -77,20 +78,27 @@ def load(self, model, adapter=None, draft_model=None): class 
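The byte-sized flags registered above (`--prompt-cache-bytes`, `--max-active-kv-bytes`, `--max-active-memory-bytes`) accept shorthand like `8G` via `_parse_size` from `mlx_lm.utils`. The helper below only illustrates that shorthand and is not the server's parser:

```python
def size_to_bytes(text: str) -> int:
    # Accepts "512M", "8G", "30GB" or a bare byte count.
    units = {"": 1, "M": 10**6, "MB": 10**6, "G": 10**9, "GB": 10**9}
    digits = "".join(ch for ch in text if ch.isdigit() or ch == ".")
    unit = text[len(digits):].strip().upper()
    return int(float(digits) * units[unit])

print(size_to_bytes("8G"))    # 8000000000
print(size_to_bytes("512M"))  # 512000000
```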
MockCache: - def __init__(self, value, size=None): + def __init__(self, value, is_trimmable: bool = True): self.value = value - self._size = len(value) if size is None else size + self._is_trimmable = is_trimmable @property def nbytes(self): return len(self.value) def size(self): - return self._size + return len(self.value) def __eq__(self, other): return other.value == self.value + def is_trimmable(self): + return self._is_trimmable + + def trim(self, n): + assert self._is_trimmable + return n + class TestServer(unittest.TestCase): @classmethod @@ -456,18 +464,23 @@ def get_kv(n): c[0].update_and_fetch(*get_kv(24)) cache.insert_cache(model, t, c) + # Fetching a cache that is strictly a prefix doesn't remove it from the + # lru cache tokens = tokens + [20] * 5 c, t = cache.fetch_nearest_cache(model, tokens) k, v = c[0].state self.assertTrue((k == v).all().item()) self.assertTrue((k.flatten() == mx.arange(24)).all().item()) self.assertEqual(t, [20] * 5) - self.assertEqual(len(cache._lru), 0) + self.assertEqual(len(cache), 1) + # Inserting a trimmable cache with shared prefix removes the prefixes tokens = tokens + [30] * 3 c[0].update_and_fetch(*get_kv(8)) cache.insert_cache(model, tokens, c) + self.assertEqual(len(cache), 1) + # Fetching a cache with a shared prefix doesn't remove it either tokens = tokens[:26] + [40] * 8 c, t = cache.fetch_nearest_cache(model, tokens) k, v = c[0].state @@ -476,23 +489,34 @@ def get_kv(n): (k.flatten() == mx.concatenate([mx.arange(24), mx.arange(2)])).all().item() ) self.assertEqual(t, [40] * 8) - self.assertEqual(len(cache._lru), 1) + self.assertEqual(len(cache), 1) + + # Inserting a diverged cache actually creates another entry + c[0].update_and_fetch(*get_kv(8)) + cache.insert_cache(model, tokens, c) + self.assertEqual(len(cache), 2) def test_lru(self): cache = LRUPromptCache(max_size=2) model = ("test", None, None) cache.insert_cache(model, [1, 2], [MockCache("test1")]) - cache.insert_cache(model, [1, 2], [MockCache("test1")]) + cache.insert_cache(model, [2, 3], [MockCache("test2")]) c, t = cache.fetch_nearest_cache(model, [1, 2]) self.assertEqual(c, [MockCache("test1")]) self.assertEqual(t, []) - c, t = cache.fetch_nearest_cache(model, [1, 2]) + c, t = cache.fetch_nearest_cache(model, [1]) self.assertEqual(c, [MockCache("test1")]) - self.assertEqual(t, []) - c, t = cache.fetch_nearest_cache(model, [1, 2]) - self.assertEqual(c, None) - self.assertEqual(t, [1, 2]) + self.assertEqual(t, [1]) + c, t = cache.fetch_nearest_cache(model, [1, 3, 4]) + self.assertEqual(c, [MockCache("test1")]) + self.assertEqual(t, [3, 4]) + c, t = cache.fetch_nearest_cache(model, [2, 3, 4]) + self.assertEqual(c, [MockCache("test2")]) + self.assertEqual(t, [4]) + c, t = cache.fetch_nearest_cache(model, [2, 4, 5]) + self.assertEqual(c, [MockCache("test2")]) + self.assertEqual(t, [4, 5]) cache.insert_cache(model, [1, 2], [MockCache("test1")]) cache.insert_cache(model, [2, 3], [MockCache("test2")]) @@ -508,6 +532,29 @@ def test_lru(self): self.assertEqual(c, [MockCache("test3")]) self.assertEqual(t, []) + cache.insert_cache(model, [4, 5], [MockCache("test4")], checkpoint=True) + c, t = cache.fetch_nearest_cache(model, [2, 3]) + self.assertEqual(c, None) + self.assertEqual(t, [2, 3]) + c, t = cache.fetch_nearest_cache(model, [3, 4]) + self.assertEqual(c, [MockCache("test3")]) + self.assertEqual(t, []) + c, t = cache.fetch_nearest_cache(model, [4, 5]) + self.assertEqual(c, [MockCache("test4")]) + self.assertEqual(t, []) + + cache.insert_cache(model, [5, 6], [MockCache("test5")]) + 
cache.insert_cache(model, [6, 7], [MockCache("test6")]) + c, t = cache.fetch_nearest_cache(model, [5, 6]) + self.assertEqual(c, None) + self.assertEqual(t, [5, 6]) + c, t = cache.fetch_nearest_cache(model, [6, 7]) + self.assertEqual(c, [MockCache("test6")]) + self.assertEqual(t, []) + c, t = cache.fetch_nearest_cache(model, [4, 5]) + self.assertEqual(c, [MockCache("test4")]) + self.assertEqual(t, []) + def test_lru_bytes(self): cache = LRUPromptCache(max_size=100, max_bytes=10) model = ("test", None, None) @@ -533,25 +580,19 @@ def test_lru_bytes(self): class FailingResponseGenerator: - def __init__(self, exc): + def __init__(self, exc: Exception): self.exc = exc - self.cli_args = type( - "obj", - (object,), - { - "num_draft_tokens": 3, - "max_tokens": 32, - "temp": 0.0, - "top_p": 1.0, - "top_k": 0, - "min_p": 0.0, - "model": None, - }, - ) - def generate(self, *args, **kwargs): + def stop_and_join(self): + return None + + def generate(self, request, args, progress_callback=None): raise self.exc + @property + def cli_args(self): + return type("obj", (), {"allowed_origins": ["*"]})() + class TestErrorStatusCodes(unittest.TestCase): def _run_request(self, exc): @@ -603,27 +644,12 @@ def test_is_metal_oom_error(self): class TestKVBudgeting(unittest.TestCase): def test_projected_kv_bytes_without_growth(self): - cache = [MockCache("abcd", size=0)] - self.assertEqual(projected_kv_bytes(cache, 10), 4) + cache = [MockCache("abcd")] + self.assertEqual(projected_kv_bytes(cache, 10), 14) - def test_projected_kv_bytes_with_growth(self): - cache = [MockCache("abcdef", size=3)] - # 6 bytes over 3 tokens => 2 bytes/token - self.assertEqual(projected_kv_bytes(cache, 5), 16) - - -class TestCLIValidation(unittest.TestCase): - def test_reject_bad_prompt_overflow_policy(self): - from mlx_lm import server as server_module - - argv = [ - "mlx_lm.server", - "--prompt-overflow-policy", - "invalid", - ] - with patch.object(sys, "argv", argv): - with self.assertRaises(SystemExit): - server_module.main() + def test_projected_kv_bytes_with_no_extra(self): + cache = [MockCache("abcdef")] + self.assertEqual(projected_kv_bytes(cache, 0), 6) class TestPromptTokenLimit(unittest.TestCase): @@ -659,14 +685,27 @@ def test_truncate_policy(self): ) self.assertEqual(out, [0, 1, 2, 15, 16, 17, 18, 19]) - def test_truncate_policy_keep_over_cap(self): - out = apply_prompt_token_limit( - list(range(20)), - max_prompt_tokens=8, - overflow_policy="truncate", - keep_tokens=100, - ) - self.assertEqual(out, list(range(8))) + +class TestCLIValidation(unittest.TestCase): + def test_server_parses_new_memory_flags(self): + from mlx_lm import server as server_module + + argv = [ + "mlx_lm.server", + "--max-prompt-tokens", + "4096", + "--prompt-overflow-policy", + "truncate", + "--prompt-keep-tokens", + "512", + "--max-active-kv-bytes", + "8G", + "--max-active-memory-bytes", + "30G", + ] + with patch.object(sys, "argv", argv): + with patch("mlx_lm.server.run"): + server_module.main() if __name__ == "__main__": From 220059e935e30f8cd1b72690940ff4134dd532a4 Mon Sep 17 00:00:00 2001 From: Dmitry Ryabkov Date: Mon, 23 Mar 2026 13:40:29 +0100 Subject: [PATCH 11/11] tests: fix error status code handler stub cli args --- tests/test_server.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_server.py b/tests/test_server.py index 45c1e6cc2..fd983bf53 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -582,6 +582,19 @@ def test_lru_bytes(self): class FailingResponseGenerator: def __init__(self, 
exc: Exception): self.exc = exc + self._cli_args = type( + "obj", + (), + { + "allowed_origins": ["*"], + "num_draft_tokens": 0, + "max_tokens": 100, + "temp": 0.0, + "top_p": 1.0, + "top_k": 0, + "min_p": 0.0, + }, + )() def stop_and_join(self): return None @@ -591,7 +604,7 @@ def generate(self, request, args, progress_callback=None): @property def cli_args(self): - return type("obj", (), {"allowed_origins": ["*"]})() + return self._cli_args class TestErrorStatusCodes(unittest.TestCase):
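Since OOM-style failures now surface as HTTP 503 instead of killing the server process, callers can treat them as retryable. A hedged client-side sketch, where the endpoint, port, and backoff policy are illustrative assumptions:

```python
import json
import time
from urllib import error, request

def post_with_retry(body: dict, retries: int = 3, backoff: float = 2.0):
    data = json.dumps(body).encode()
    for attempt in range(retries):
        req = request.Request(
            "http://localhost:8080/v1/chat/completions",
            data=data,
            headers={"Content-Type": "application/json"},
        )
        try:
            with request.urlopen(req) as resp:
                return json.loads(resp.read())
        except error.HTTPError as e:
            # 503 signals a transient memory condition; back off and retry.
            if e.code == 503 and attempt < retries - 1:
                time.sleep(backoff * (attempt + 1))
                continue
            raise
```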