From 34dc14bcc55f0ec3c06de66feb303a43764f4be7 Mon Sep 17 00:00:00 2001 From: Sheldon Aristide Date: Sat, 21 Mar 2026 11:07:56 -0400 Subject: [PATCH 1/4] Handle Metal OOM in server without crashing process Classify generation failures in mlx_lm.server and return structured errors instead of crashing or misreporting as 404. - Detect Metal/MLX OOM errors and map them to HTTP 503 - Map other generation exceptions to HTTP 500 - Return structured JSON error payloads for non-stream responses - Emit terminal SSE error event + [DONE] for stream responses - Keep server alive after OOM - Defer non-stream 200 headers until success response is ready - Add OOM regression tests (stream + non-stream) in test_server.py - Document OOM behavior and mitigation knobs in SERVER.md --- mlx_lm/SERVER.md | 16 ++++ mlx_lm/server.py | 222 +++++++++++++++++++++++++++++-------------- tests/test_server.py | 53 +++++++++++ 3 files changed, 218 insertions(+), 73 deletions(-) diff --git a/mlx_lm/SERVER.md b/mlx_lm/SERVER.md index f38ad3dd4..61072917d 100644 --- a/mlx_lm/SERVER.md +++ b/mlx_lm/SERVER.md @@ -140,6 +140,22 @@ curl localhost:8080/v1/chat/completions \ - `completion_tokens`: The number of tokens generated. - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields. +### OOM Behavior + +If generation hits a Metal out-of-memory error, the server now returns a +structured error instead of crashing: + +- Non-streaming requests return HTTP `503` with a JSON body containing an + `error` object. +- Streaming requests return HTTP `503`, emit a final SSE error event, and then + close the stream with `data: [DONE]`. + +Useful mitigation knobs: + +- `--prefill-step-size` to reduce memory pressure during prompt prefill. +- `--prompt-cache-size` to limit the number of cached prompt KV states. +- `--prompt-cache-bytes` to limit total prompt cache memory. + ### List Models Use the `v1/models` endpoint to list available models: diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 7fc91fa2a..3e6cb9bca 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -50,6 +50,18 @@ def get_system_fingerprint(): return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" +def is_metal_oom_error(exc: Exception) -> bool: + message = str(exc).lower() + oom_markers = ( + "out of memory", + "insufficient memory", + "resource exhausted", + "failed to allocate", + "metal error: command buffer execution failed due to out of memory", + ) + return any(marker in message for marker in oom_markers) + + class StopCondition(NamedTuple): stop_met: bool trim_length: int @@ -1147,6 +1159,56 @@ def _set_stream_headers(self, status_code: int = 200): self.send_header("Cache-Control", "no-cache") self._set_cors_headers() + def _classify_generation_error(self, exc: Exception): + if is_metal_oom_error(exc): + return ( + 503, + { + "message": "Metal out-of-memory during generation.", + "type": "resource_exhausted_error", + "code": "metal_out_of_memory", + }, + ) + return ( + 500, + { + "message": str(exc), + "type": "internal_server_error", + "code": "internal_generation_error", + }, + ) + + def _completion_error_response(self, status_code: int, error_payload: Dict[str, str]): + self._set_completion_headers(status_code) + self.end_headers() + self.wfile.write(json.dumps({"error": error_payload}).encode()) + self.wfile.flush() + + def _stream_error_response( + self, + status_code: int, + error_payload: Dict[str, str], + stream_started: bool = False, + ): + if not stream_started: + self._set_stream_headers(status_code) + self.end_headers() + event = { + "id": self.request_id, + "system_fingerprint": self.system_fingerprint, + "object": "error", + "model": self.requested_model, + "created": self.created, + "error": error_payload, + } + try: + self.wfile.write(f"data: {json.dumps(event)}\n\n".encode()) + self.wfile.write("data: [DONE]\n\n".encode()) + self.wfile.flush() + except (BrokenPipeError, ConnectionResetError, OSError): + # Client disconnected before receiving the terminal error event. + pass + def do_OPTIONS(self): self._set_completion_headers(204) self.end_headers() @@ -1502,18 +1564,21 @@ def keepalive_callback(processed_tokens, total_tokens): progress_callback=keepalive_callback, ) except Exception as e: - self._set_completion_headers(404) - self.end_headers() - self.wfile.write(json.dumps({"error": f"{e}"}).encode()) + status_code, error_payload = self._classify_generation_error(e) + if self.stream: + self._stream_error_response(status_code, error_payload, stream_started=False) + else: + self._completion_error_response(status_code, error_payload) return # Prepare the headers + stream_started = False if self.stream: self._set_stream_headers(200) self.end_headers() logging.debug("Starting stream:") + stream_started = True else: - self._set_completion_headers(200) logging.debug("Starting completion:") # Variables to save the tool calls in as they are being generated by @@ -1576,78 +1641,88 @@ def parse_tools(tool_calls): # Well finally save the reason for stopping finish_reason = "length" # Process the generated tokens - for gen in response: - logging.debug(gen.text) - - # Gather the text in tool calling or text variables - if in_reasoning: - if gen.text == ctx.think_end: - in_reasoning = False - else: - reasoning_text += gen.text - elif ctx.has_tool_calling and gen.text == ctx.tool_call_start: - made_tool_call = True - in_tool_call = True - elif in_tool_call: - if gen.text == ctx.tool_call_end: - tool_calls.append(tool_text) - tool_text = "" - in_tool_call = False + try: + for gen in response: + logging.debug(gen.text) + + # Gather the text in tool calling or text variables + if in_reasoning: + if gen.text == ctx.think_end: + in_reasoning = False + else: + reasoning_text += gen.text + elif ctx.has_tool_calling and gen.text == ctx.tool_call_start: + made_tool_call = True + in_tool_call = True + elif in_tool_call: + if gen.text == ctx.tool_call_end: + tool_calls.append(tool_text) + tool_text = "" + in_tool_call = False + else: + tool_text += gen.text else: - tool_text += gen.text - else: - text += gen.text - segment += gen.text - - # Save the token and its logprob - tokens.append(gen.token) - if args.logprobs: - token_logprobs.append(gen.logprob) - - # If requested save the k top logprobs - if args.top_logprobs > 0: - top_tokens.append(gen.top_tokens) - - # Check if we should stop early - stop_condition = stopping_criteria( - tokens, - ctx.eos_token_ids, - ctx.stop_token_sequences, - stop_words, - ) - if stop_condition.stop_met: - finish_reason = "tool_calls" if made_tool_call else "stop" - ctx.stop() - tokens = tokens[: len(tokens) - stop_condition.trim_length] - text = text[: len(text) - stop_condition.trim_text_length] - segment = "" - break - - if self.stream and not in_tool_call: - # If the end of tokens overlaps with a stop sequence, generate new - # tokens until we know if the stop sequence is hit or not - if any( - ( - sequence_overlap(tokens, sequence) - for sequence in ctx.stop_token_sequences - ) - ): - continue - elif segment or tool_calls or reasoning_text: - response = self.generate_response( - segment, - None, - tool_calls=parse_tools(tool_calls), - reasoning_text=reasoning_text, - ) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() - reasoning_text = "" + text += gen.text + segment += gen.text + + # Save the token and its logprob + tokens.append(gen.token) + if args.logprobs: + token_logprobs.append(gen.logprob) + + # If requested save the k top logprobs + if args.top_logprobs > 0: + top_tokens.append(gen.top_tokens) + + # Check if we should stop early + stop_condition = stopping_criteria( + tokens, + ctx.eos_token_ids, + ctx.stop_token_sequences, + stop_words, + ) + if stop_condition.stop_met: + finish_reason = "tool_calls" if made_tool_call else "stop" + ctx.stop() + tokens = tokens[: len(tokens) - stop_condition.trim_length] + text = text[: len(text) - stop_condition.trim_text_length] segment = "" - tool_calls = [] + break - if gen.finish_reason is not None: - finish_reason = gen.finish_reason + if self.stream and not in_tool_call: + # If the end of tokens overlaps with a stop sequence, generate new + # tokens until we know if the stop sequence is hit or not + if any( + ( + sequence_overlap(tokens, sequence) + for sequence in ctx.stop_token_sequences + ) + ): + continue + elif segment or tool_calls or reasoning_text: + response = self.generate_response( + segment, + None, + tool_calls=parse_tools(tool_calls), + reasoning_text=reasoning_text, + ) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() + reasoning_text = "" + segment = "" + tool_calls = [] + + if gen.finish_reason is not None: + finish_reason = gen.finish_reason + except Exception as e: + status_code, error_payload = self._classify_generation_error(e) + if self.stream: + self._stream_error_response( + status_code, error_payload, stream_started=stream_started + ) + else: + self._completion_error_response(status_code, error_payload) + return # Flush any remaining tool text (e.g. when tool_call_end is empty) if in_tool_call and tool_text: @@ -1689,6 +1764,7 @@ def parse_tools(tool_calls): indent = "\t" # Backslashes can't be inside of f-strings logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") + self._set_completion_headers(200) # Send an additional Content-Length header when it is known self.send_header("Content-Length", str(len(response_json))) self.end_headers() diff --git a/tests/test_server.py b/tests/test_server.py index c5a815e4f..787dc1d6c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,6 +5,7 @@ import json import threading import unittest +from unittest import mock import mlx.core as mx import requests @@ -230,6 +231,58 @@ def test_sequence_overlap(self): self.assertFalse(sequence_overlap([1, 2], [3, 4])) self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3])) + def test_oom_returns_503_without_crashing_server(self): + url = f"http://localhost:{self.port}/v1/chat/completions" + post_data = { + "model": "chat_model", + "max_tokens": 4, + "messages": [{"role": "user", "content": "hello"}], + } + + with mock.patch.object( + self.response_generator, + "generate", + side_effect=RuntimeError( + "Metal error: command buffer execution failed due to out of memory" + ), + ): + response = requests.post(url, json=post_data) + + self.assertEqual(response.status_code, 503) + response_body = json.loads(response.text) + self.assertIn("error", response_body) + self.assertEqual(response_body["error"]["code"], "metal_out_of_memory") + self.assertEqual(response_body["error"]["type"], "resource_exhausted_error") + + # Server should continue serving requests after an OOM. + next_response = requests.post(url, json=post_data) + self.assertEqual(next_response.status_code, 200) + + def test_streaming_oom_returns_terminal_error_event(self): + url = f"http://localhost:{self.port}/v1/chat/completions" + post_data = { + "model": "chat_model", + "max_tokens": 4, + "stream": True, + "messages": [{"role": "user", "content": "hello"}], + } + + with mock.patch.object( + self.response_generator, + "generate", + side_effect=RuntimeError("out of memory"), + ): + response = requests.post(url, json=post_data, stream=True) + lines = [line.decode("utf-8") for line in response.iter_lines() if line] + + self.assertEqual(response.status_code, 503) + self.assertGreaterEqual(len(lines), 2) + self.assertEqual(lines[-1], "data: [DONE]") + + error_event = json.loads(lines[0][6:]) + self.assertIn("error", error_event) + self.assertEqual(error_event["error"]["code"], "metal_out_of_memory") + class TestServerWithDraftModel(unittest.TestCase): @classmethod From 06ce874c6c20b22bd15a394ed9394a0832bfd844 Mon Sep 17 00:00:00 2001 From: Sheldon Aristide Date: Sat, 21 Mar 2026 11:12:50 -0400 Subject: [PATCH 2/4] Format server OOM handling with pre-commit black --- mlx_lm/server.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 3e6cb9bca..3df741265 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -1178,7 +1178,9 @@ def _classify_generation_error(self, exc: Exception): }, ) - def _completion_error_response(self, status_code: int, error_payload: Dict[str, str]): + def _completion_error_response( + self, status_code: int, error_payload: Dict[str, str] + ): self._set_completion_headers(status_code) self.end_headers() self.wfile.write(json.dumps({"error": error_payload}).encode()) @@ -1566,7 +1568,9 @@ def keepalive_callback(processed_tokens, total_tokens): except Exception as e: status_code, error_payload = self._classify_generation_error(e) if self.stream: - self._stream_error_response(status_code, error_payload, stream_started=False) + self._stream_error_response( + status_code, error_payload, stream_started=False + ) else: self._completion_error_response(status_code, error_payload) return From 4713cfb14387d49a6edb38ab998d0b2000c86640 Mon Sep 17 00:00:00 2001 From: Sheldon Aristide Date: Sat, 21 Mar 2026 11:14:23 -0400 Subject: [PATCH 3/4] Broaden OOM markers and log classified Metal OOMs --- mlx_lm/server.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 3df741265..acd7f307c 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -55,6 +55,7 @@ def is_metal_oom_error(exc: Exception) -> bool: oom_markers = ( "out of memory", "insufficient memory", + "insufficient memory for buffer", "resource exhausted", "failed to allocate", "metal error: command buffer execution failed due to out of memory", @@ -1161,6 +1162,11 @@ def _set_stream_headers(self, status_code: int = 200): def _classify_generation_error(self, exc: Exception): if is_metal_oom_error(exc): + logging.warning( + "Metal OOM detected while serving request_id=%s: %s", + self.request_id, + exc, + ) return ( 503, { From 53dc99ef86c2e73673594b4d7a06d4995443b6fc Mon Sep 17 00:00:00 2001 From: Sheldon Aristide Date: Wed, 25 Mar 2026 11:17:28 -0400 Subject: [PATCH 4/4] Catch additional Metal OOM error strings - Add marker coverage for 'attempting to allocate' and 'maximum allowed buffer size' - Add regression test to ensure these variants map to HTTP 503 --- mlx_lm/server.py | 2 ++ tests/test_server.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index acd7f307c..c71810be6 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -56,6 +56,8 @@ def is_metal_oom_error(exc: Exception) -> bool: "out of memory", "insufficient memory", "insufficient memory for buffer", + "attempting to allocate", + "maximum allowed buffer size", "resource exhausted", "failed to allocate", "metal error: command buffer execution failed due to out of memory", diff --git a/tests/test_server.py b/tests/test_server.py index 787dc1d6c..3c2c6103b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -283,6 +283,25 @@ def test_streaming_oom_returns_terminal_error_event(self): self.assertIn("error", error_event) self.assertEqual(error_event["error"]["code"], "metal_out_of_memory") + def test_oom_marker_attempting_to_allocate_maps_to_503(self): + url = f"http://localhost:{self.port}/v1/chat/completions" + post_data = { + "model": "chat_model", + "max_tokens": 4, + "messages": [{"role": "user", "content": "hello"}], + } + + with mock.patch.object( + self.response_generator, + "generate", + side_effect=RuntimeError( + "Error: attempting to allocate 12.3 GB, maximum allowed buffer size reached" + ), + ): + response = requests.post(url, json=post_data) + + self.assertEqual(response.status_code, 503) + class TestServerWithDraftModel(unittest.TestCase): @classmethod