diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 7fc91fa2a..564d23575 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -50,6 +50,58 @@ def get_system_fingerprint(): return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" +def is_metal_oom_error(error: Exception) -> bool: + text = str(error).lower() + patterns = ( + "out of memory", + "insufficient memory", + "kiogpucommandbuffercallbackerroroutofmemory", + "mps backend out of memory", + ) + return any(pattern in text for pattern in patterns) + + +def projected_kv_bytes(prompt_cache: List[Any], extra_tokens: int) -> int: + cache_bytes = sum(c.nbytes for c in prompt_cache) + if cache_bytes <= 0 or extra_tokens <= 0: + return cache_bytes + + cache_tokens = max( + (c.size() for c in prompt_cache if hasattr(c, "size")), default=0 + ) + if cache_tokens <= 0: + return cache_bytes + + bytes_per_token = cache_bytes / cache_tokens + return cache_bytes + int(bytes_per_token * extra_tokens) + + +def apply_prompt_token_limit( + tokens: List[int], + *, + max_prompt_tokens: Optional[int], + overflow_policy: str, + keep_tokens: int, +) -> List[int]: + if max_prompt_tokens is None or len(tokens) <= max_prompt_tokens: + return tokens + + if overflow_policy == "error": + raise ValueError( + "Prompt exceeds max prompt token limit: " + f"prompt_tokens={len(tokens)}, max_prompt_tokens={max_prompt_tokens}" + ) + + if overflow_policy != "truncate": + raise ValueError(f"Invalid prompt overflow policy: {overflow_policy}") + + keep_tokens = max(0, min(keep_tokens, max_prompt_tokens)) + tail_tokens = max_prompt_tokens - keep_tokens + if tail_tokens <= 0: + return tokens[:max_prompt_tokens] + return tokens[:keep_tokens] + tokens[-tail_tokens:] + + class StopCondition(NamedTuple): stop_met: bool trim_length: int @@ -722,17 +774,35 @@ def _tokenize(self, tokenizer, request, args): if args.chat_template_kwargs: chat_template_args = chat_template_args.copy() chat_template_args.update(args.chat_template_kwargs) - return tokenizer.apply_chat_template( + tokens = tokenizer.apply_chat_template( messages, tools=tools, add_generation_prompt=True, tokenize=True, **chat_template_args, ) + return apply_prompt_token_limit( + tokens, + max_prompt_tokens=self.cli_args.max_prompt_tokens, + overflow_policy=self.cli_args.prompt_overflow_policy, + keep_tokens=self.cli_args.prompt_keep_tokens, + ) else: - return tokenizer.encode(convert_chat(messages, role_mapping)) + tokens = tokenizer.encode(convert_chat(messages, role_mapping)) + return apply_prompt_token_limit( + tokens, + max_prompt_tokens=self.cli_args.max_prompt_tokens, + overflow_policy=self.cli_args.prompt_overflow_policy, + keep_tokens=self.cli_args.prompt_keep_tokens, + ) else: - return tokenizer.encode(request.prompt) + tokens = tokenizer.encode(request.prompt) + return apply_prompt_token_limit( + tokens, + max_prompt_tokens=self.cli_args.max_prompt_tokens, + overflow_policy=self.cli_args.prompt_overflow_policy, + keep_tokens=self.cli_args.prompt_keep_tokens, + ) def _compute_prompt_checkpoint(self, tokenizer, request, prompt): if request.request_type != "chat": @@ -760,6 +830,44 @@ def _is_batchable(self, args): return True + def _make_prompt_cache(self, model, draft_model=None): + cache = make_prompt_cache( + model, + max_kv_size=self.model_provider.cli_args.max_kv_size, + ) + if draft_model is not None: + cache += make_prompt_cache( + draft_model, + max_kv_size=self.model_provider.cli_args.max_kv_size, + ) + return cache + + def _memory_admission_error( + self, prompt_cache: List[Any], 
extra_tokens: int, active_bytes: int = 0 + ) -> Optional[str]: + limit = self.model_provider.cli_args.max_active_kv_bytes + if limit is None: + return None + projected = active_bytes + projected_kv_bytes(prompt_cache, extra_tokens) + if projected <= limit: + return None + return ( + "Projected KV memory usage exceeds configured active KV limit. " + f"projected={projected} bytes, limit={limit} bytes" + ) + + def _check_active_memory_limit(self): + limit = self.model_provider.cli_args.max_active_memory_bytes + if limit is None: + return + active = mx.get_active_memory() + if active > limit: + raise MemoryError( + "Active MLX memory exceeded configured limit: " + f"active={active} bytes, limit={limit} bytes. " + "Consider lowering prompt length or max_tokens." + ) + def _generate(self): current_model = None current_sampling = None @@ -778,6 +886,7 @@ def get_next_request(timeout=None): return self._next_request(timeout) def progress_callback(info): + self._check_active_memory_limit() for uid, processed, total in info: if uid in batch_results: batch_results[uid]["rqueue"].put((min(processed, total), total)) @@ -848,7 +957,9 @@ def checkpoint_callback(prompts): ) ctx.prompt_cache_count = len(prompt) - len(rest) if cache is None: - cache = make_prompt_cache(self.model_provider.model) + cache = self._make_prompt_cache( + self.model_provider.model, self.model_provider.draft_model + ) do_checkpoint, checkpoint_position = ( self._compute_prompt_checkpoint(tokenizer, request, prompt) @@ -906,6 +1017,7 @@ def checkpoint_callback(prompts): prefill_step_size=self.cli_args.prefill_step_size, prompt_progress_callback=progress_callback, prompt_checkpoint_callback=checkpoint_callback, + max_kv_size=self.cli_args.max_kv_size, ) unprocessed_requests.append((rqueue, request, args)) continue @@ -984,6 +1096,7 @@ def _serve_single(self, request): # Define the progress callback def progress(tokens_processed, tokens_total): + self._check_active_memory_limit() rqueue.put((tokens_processed, tokens_total)) try: @@ -1030,9 +1143,16 @@ def progress(tokens_processed, tokens_total): ctx.prompt_cache_count = len(prompt) - len(rest) cache_key = prompt[:] if cache is None: - cache = make_prompt_cache(self.model_provider.model) - if self.model_provider.draft_model is not None: - cache += make_prompt_cache(self.model_provider.draft_model) + cache = self._make_prompt_cache( + self.model_provider.model, self.model_provider.draft_model + ) + + admission_error = self._memory_admission_error( + cache, + len(rest) + args.max_tokens, + ) + if admission_error is not None: + raise MemoryError(admission_error) # Process the prompt and generate tokens for gen in stream_generate( @@ -1047,6 +1167,7 @@ def progress(tokens_processed, tokens_total): num_draft_tokens=args.num_draft_tokens, prompt_progress_callback=progress, prefill_step_size=self.cli_args.prefill_step_size, + max_kv_size=self.cli_args.max_kv_size, ): rqueue.put( Response( @@ -1502,7 +1623,8 @@ def keepalive_callback(processed_tokens, total_tokens): progress_callback=keepalive_callback, ) except Exception as e: - self._set_completion_headers(404) + status_code = 503 if is_metal_oom_error(e) else 500 + self._set_completion_headers(status_code) self.end_headers() self.wfile.write(json.dumps({"error": f"{e}"}).encode()) return @@ -2023,6 +2145,50 @@ def main(): type=_parse_size, help="Maximum size in bytes of the KV caches", ) + parser.add_argument( + "--max-prompt-tokens", + type=int, + default=None, + help="Maximum prompt token count accepted by the server", + ) + 
parser.add_argument( + "--prompt-overflow-policy", + type=str, + default="error", + choices=["error", "truncate"], + help="Behavior when prompt exceeds --max-prompt-tokens", + ) + parser.add_argument( + "--prompt-keep-tokens", + type=int, + default=512, + help=( + "When truncating prompts, keep this many tokens from the start and fill " + "the remainder from the end" + ), + ) + parser.add_argument( + "--max-active-kv-bytes", + type=_parse_size, + help=( + "Reject requests when projected active KV memory would exceed this limit " + "(bytes or shorthand like 20G)" + ), + ) + parser.add_argument( + "--max-active-memory-bytes", + type=_parse_size, + help=( + "Abort requests when current active MLX memory exceeds this limit " + "(bytes or shorthand like 30G)" + ), + ) + parser.add_argument( + "--max-kv-size", + type=int, + default=None, + help="Maximum size of the active KV cache per sequence", + ) parser.add_argument( "--pipeline", action="store_true", diff --git a/tests/test_server.py b/tests/test_server.py index c5a815e4f..fd983bf53 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,14 +3,23 @@ import http import io import json +import sys import threading import unittest +from unittest.mock import patch import mlx.core as mx import requests from mlx_lm.models.cache import KVCache -from mlx_lm.server import APIHandler, LRUPromptCache, ResponseGenerator +from mlx_lm.server import ( + APIHandler, + LRUPromptCache, + ResponseGenerator, + apply_prompt_token_limit, + is_metal_oom_error, + projected_kv_bytes, +) from mlx_lm.utils import load @@ -48,6 +57,12 @@ def __init__(self, with_draft=False): "prompt_cache_bytes": 1 << 63, "prompt_cache_total_bytes": None, "allowed_origins": ["*"], + "max_prompt_tokens": None, + "prompt_overflow_policy": "error", + "prompt_keep_tokens": 512, + "max_active_kv_bytes": None, + "max_active_memory_bytes": None, + "max_kv_size": None, }, ) @@ -71,6 +86,9 @@ def __init__(self, value, is_trimmable: bool = True): def nbytes(self): return len(self.value) + def size(self): + return len(self.value) + def __eq__(self, other): return other.value == self.value @@ -561,5 +579,147 @@ def test_lru_bytes(self): self.assertEqual(t, [3, 4]) +class FailingResponseGenerator: + def __init__(self, exc: Exception): + self.exc = exc + self._cli_args = type( + "obj", + (), + { + "allowed_origins": ["*"], + "num_draft_tokens": 0, + "max_tokens": 100, + "temp": 0.0, + "top_p": 1.0, + "top_k": 0, + "min_p": 0.0, + }, + )() + + def stop_and_join(self): + return None + + def generate(self, request, args, progress_callback=None): + raise self.exc + + @property + def cli_args(self): + return self._cli_args + + +class TestErrorStatusCodes(unittest.TestCase): + def _run_request(self, exc): + response_generator = FailingResponseGenerator(exc) + httpd = http.server.HTTPServer( + ("localhost", 0), + lambda *args, **kwargs: APIHandler(response_generator, *args, **kwargs), + ) + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.daemon = True + server_thread.start() + try: + url = f"http://localhost:{httpd.server_port}/v1/completions" + return requests.post( + url, + json={ + "model": "default_model", + "prompt": "test", + "max_tokens": 2, + }, + ) + finally: + httpd.shutdown() + httpd.server_close() + server_thread.join() + + def test_oom_maps_to_service_unavailable(self): + response = self._run_request( + RuntimeError( + "[METAL] Command buffer execution failed: Insufficient Memory " + "(00000008:kIOGPUCommandBufferCallbackErrorOutOfMemory)" + ) + ) + 
self.assertEqual(response.status_code, 503) + + def test_non_oom_maps_to_internal_server_error(self): + response = self._run_request(RuntimeError("arbitrary failure")) + self.assertEqual(response.status_code, 500) + + def test_is_metal_oom_error(self): + self.assertTrue(is_metal_oom_error(RuntimeError("out of memory"))) + self.assertTrue( + is_metal_oom_error( + RuntimeError("kIOGPUCommandBufferCallbackErrorOutOfMemory") + ) + ) + self.assertFalse(is_metal_oom_error(RuntimeError("other runtime failure"))) + + +class TestKVBudgeting(unittest.TestCase): + def test_projected_kv_bytes_without_growth(self): + cache = [MockCache("abcd")] + self.assertEqual(projected_kv_bytes(cache, 10), 14) + + def test_projected_kv_bytes_with_no_extra(self): + cache = [MockCache("abcdef")] + self.assertEqual(projected_kv_bytes(cache, 0), 6) + + +class TestPromptTokenLimit(unittest.TestCase): + def test_no_limit(self): + tokens = list(range(10)) + self.assertEqual( + apply_prompt_token_limit( + tokens, + max_prompt_tokens=None, + overflow_policy="error", + keep_tokens=0, + ), + tokens, + ) + + def test_error_policy(self): + with self.assertRaisesRegex( + ValueError, "Prompt exceeds max prompt token limit" + ): + apply_prompt_token_limit( + list(range(20)), + max_prompt_tokens=8, + overflow_policy="error", + keep_tokens=0, + ) + + def test_truncate_policy(self): + out = apply_prompt_token_limit( + list(range(20)), + max_prompt_tokens=8, + overflow_policy="truncate", + keep_tokens=3, + ) + self.assertEqual(out, [0, 1, 2, 15, 16, 17, 18, 19]) + + +class TestCLIValidation(unittest.TestCase): + def test_server_parses_new_memory_flags(self): + from mlx_lm import server as server_module + + argv = [ + "mlx_lm.server", + "--max-prompt-tokens", + "4096", + "--prompt-overflow-policy", + "truncate", + "--prompt-keep-tokens", + "512", + "--max-active-kv-bytes", + "8G", + "--max-active-memory-bytes", + "30G", + ] + with patch.object(sys, "argv", argv): + with patch("mlx_lm.server.run"): + server_module.main() + + if __name__ == "__main__": unittest.main()
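
Usage sketch (not part of the diff): the snippet below exercises the two new helpers with the same values used by the added tests. apply_prompt_token_limit and projected_kv_bytes come from mlx_lm/server.py above; _FakeLayerCache is a hypothetical stand-in that exposes only the nbytes/size() attributes projected_kv_bytes reads.

from mlx_lm.server import apply_prompt_token_limit, projected_kv_bytes

# Truncation keeps `keep_tokens` from the head of the prompt (e.g. the system
# prompt) and fills the remaining budget from the tail of the conversation.
out = apply_prompt_token_limit(
    list(range(20)),
    max_prompt_tokens=8,
    overflow_policy="truncate",
    keep_tokens=3,
)
assert out == [0, 1, 2, 15, 16, 17, 18, 19]


class _FakeLayerCache:
    # Hypothetical stand-in exposing the attributes projected_kv_bytes uses.
    def __init__(self, tokens: int, bytes_per_token: int):
        self.nbytes = tokens * bytes_per_token
        self._tokens = tokens

    def size(self) -> int:
        return self._tokens


# 4 cached tokens at 1 byte each, projecting 10 more tokens -> 14 bytes,
# matching the admission check used before stream_generate in _serve_single.
assert projected_kv_bytes([_FakeLayerCache(4, 1)], extra_tokens=10) == 14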