Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions mlx_lm/SERVER.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,22 @@ curl localhost:8080/v1/chat/completions \
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.

### OOM Behavior

If generation hits a Metal out-of-memory error, the server returns a
structured error instead of crashing:

- Non-streaming requests return HTTP `503` with a JSON body containing an
`error` object.
- Streaming requests return HTTP `503`, emit a final SSE error event, and then
close the stream with `data: [DONE]`.

Useful mitigation knobs:

- `--prefill-step-size` to reduce memory pressure during prompt prefill.
- `--prompt-cache-size` to limit the number of cached prompt KV states.
- `--prompt-cache-bytes` to limit total prompt cache memory.

### List Models

Use the `v1/models` endpoint to list available models:
Expand Down
234 changes: 161 additions & 73 deletions mlx_lm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ def get_system_fingerprint():
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"


def is_metal_oom_error(exc: Exception) -> bool:
    """Return ``True`` if *exc* looks like a Metal out-of-memory failure.

    Matching is a case-insensitive substring scan of the exception text, so
    every marker below must be lowercase. Callers use this to map allocation
    failures to an HTTP 503 instead of letting the server crash.

    Args:
        exc: The exception raised during generation.

    Returns:
        ``True`` when the exception message contains a known OOM marker.
    """
    message = str(exc).lower()
    # Markers observed in Metal/MLX allocation failures. Entries that were
    # substrings of other entries have been dropped as redundant:
    # "insufficient memory for buffer" is covered by "insufficient memory",
    # and the long "metal error: ... out of memory" is covered by
    # "out of memory".
    oom_markers = (
        "out of memory",
        "insufficient memory",
        "attempting to allocate",
        "maximum allowed buffer size",
        "resource exhausted",
        "failed to allocate",
    )
    return any(marker in message for marker in oom_markers)


class StopCondition(NamedTuple):
stop_met: bool
trim_length: int
Expand Down Expand Up @@ -1147,6 +1162,63 @@ def _set_stream_headers(self, status_code: int = 200):
self.send_header("Cache-Control", "no-cache")
self._set_cors_headers()

def _classify_generation_error(self, exc: Exception):
    """Map a generation-time exception to an ``(HTTP status, payload)`` pair.

    Metal OOM failures become a 503 with a machine-readable error code so
    clients can back off and retry; every other exception is reported as a
    generic 500 carrying the exception text.
    """
    if not is_metal_oom_error(exc):
        # Generic failure path: surface the raw exception text to the client.
        return 500, {
            "message": str(exc),
            "type": "internal_server_error",
            "code": "internal_generation_error",
        }
    logging.warning(
        "Metal OOM detected while serving request_id=%s: %s",
        self.request_id,
        exc,
    )
    return 503, {
        "message": "Metal out-of-memory during generation.",
        "type": "resource_exhausted_error",
        "code": "metal_out_of_memory",
    }

def _completion_error_response(
    self, status_code: int, error_payload: Dict[str, str]
):
    """Send a one-shot (non-streaming) JSON error response.

    Emits *status_code* with the standard completion headers, followed by a
    body of the form ``{"error": error_payload}``.
    """
    encoded = json.dumps({"error": error_payload}).encode()
    self._set_completion_headers(status_code)
    self.end_headers()
    self.wfile.write(encoded)
    self.wfile.flush()

def _stream_error_response(
    self,
    status_code: int,
    error_payload: Dict[str, str],
    stream_started: bool = False,
):
    """Emit a terminal SSE error event followed by the ``[DONE]`` sentinel.

    If the SSE headers have not been sent yet (``stream_started=False``),
    they are sent first with *status_code*. Write failures caused by the
    client disconnecting are swallowed — there is nobody left to notify.
    """
    if not stream_started:
        self._set_stream_headers(status_code)
        self.end_headers()
    payload = {
        "id": self.request_id,
        "system_fingerprint": self.system_fingerprint,
        "object": "error",
        "model": self.requested_model,
        "created": self.created,
        "error": error_payload,
    }
    try:
        for chunk in (f"data: {json.dumps(payload)}\n\n", "data: [DONE]\n\n"):
            self.wfile.write(chunk.encode())
        self.wfile.flush()
    except (BrokenPipeError, ConnectionResetError, OSError):
        # Client disconnected before receiving the terminal error event.
        pass

def do_OPTIONS(self):
    """Handle HTTP OPTIONS (CORS preflight) with an empty 204 response.

    NOTE(review): assumes ``_set_completion_headers`` attaches the CORS
    headers (as ``_set_stream_headers`` does via ``_set_cors_headers``) —
    confirm against its definition.
    """
    self._set_completion_headers(204)
    self.end_headers()
Expand Down Expand Up @@ -1502,18 +1574,23 @@ def keepalive_callback(processed_tokens, total_tokens):
progress_callback=keepalive_callback,
)
except Exception as e:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(json.dumps({"error": f"{e}"}).encode())
status_code, error_payload = self._classify_generation_error(e)
if self.stream:
self._stream_error_response(
status_code, error_payload, stream_started=False
)
else:
self._completion_error_response(status_code, error_payload)
return

# Prepare the headers
stream_started = False
if self.stream:
self._set_stream_headers(200)
self.end_headers()
logging.debug("Starting stream:")
stream_started = True
else:
self._set_completion_headers(200)
logging.debug("Starting completion:")

# Variables to save the tool calls in as they are being generated by
Expand Down Expand Up @@ -1576,78 +1653,88 @@ def parse_tools(tool_calls):
# We'll finally save the reason for stopping
finish_reason = "length"
# Process the generated tokens
for gen in response:
logging.debug(gen.text)

# Gather the text in tool calling or text variables
if in_reasoning:
if gen.text == ctx.think_end:
in_reasoning = False
else:
reasoning_text += gen.text
elif ctx.has_tool_calling and gen.text == ctx.tool_call_start:
made_tool_call = True
in_tool_call = True
elif in_tool_call:
if gen.text == ctx.tool_call_end:
tool_calls.append(tool_text)
tool_text = ""
in_tool_call = False
try:
for gen in response:
logging.debug(gen.text)

# Gather the text in tool calling or text variables
if in_reasoning:
if gen.text == ctx.think_end:
in_reasoning = False
else:
reasoning_text += gen.text
elif ctx.has_tool_calling and gen.text == ctx.tool_call_start:
made_tool_call = True
in_tool_call = True
elif in_tool_call:
if gen.text == ctx.tool_call_end:
tool_calls.append(tool_text)
tool_text = ""
in_tool_call = False
else:
tool_text += gen.text
else:
tool_text += gen.text
else:
text += gen.text
segment += gen.text

# Save the token and its logprob
tokens.append(gen.token)
if args.logprobs:
token_logprobs.append(gen.logprob)

# If requested save the k top logprobs
if args.top_logprobs > 0:
top_tokens.append(gen.top_tokens)

# Check if we should stop early
stop_condition = stopping_criteria(
tokens,
ctx.eos_token_ids,
ctx.stop_token_sequences,
stop_words,
)
if stop_condition.stop_met:
finish_reason = "tool_calls" if made_tool_call else "stop"
ctx.stop()
tokens = tokens[: len(tokens) - stop_condition.trim_length]
text = text[: len(text) - stop_condition.trim_text_length]
segment = ""
break

if self.stream and not in_tool_call:
# If the end of tokens overlaps with a stop sequence, generate new
# tokens until we know if the stop sequence is hit or not
if any(
(
sequence_overlap(tokens, sequence)
for sequence in ctx.stop_token_sequences
)
):
continue
elif segment or tool_calls or reasoning_text:
response = self.generate_response(
segment,
None,
tool_calls=parse_tools(tool_calls),
reasoning_text=reasoning_text,
)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
reasoning_text = ""
text += gen.text
segment += gen.text

# Save the token and its logprob
tokens.append(gen.token)
if args.logprobs:
token_logprobs.append(gen.logprob)

# If requested save the k top logprobs
if args.top_logprobs > 0:
top_tokens.append(gen.top_tokens)

# Check if we should stop early
stop_condition = stopping_criteria(
tokens,
ctx.eos_token_ids,
ctx.stop_token_sequences,
stop_words,
)
if stop_condition.stop_met:
finish_reason = "tool_calls" if made_tool_call else "stop"
ctx.stop()
tokens = tokens[: len(tokens) - stop_condition.trim_length]
text = text[: len(text) - stop_condition.trim_text_length]
segment = ""
tool_calls = []
break

if gen.finish_reason is not None:
finish_reason = gen.finish_reason
if self.stream and not in_tool_call:
# If the end of tokens overlaps with a stop sequence, generate new
# tokens until we know if the stop sequence is hit or not
if any(
(
sequence_overlap(tokens, sequence)
for sequence in ctx.stop_token_sequences
)
):
continue
elif segment or tool_calls or reasoning_text:
response = self.generate_response(
segment,
None,
tool_calls=parse_tools(tool_calls),
reasoning_text=reasoning_text,
)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
reasoning_text = ""
segment = ""
tool_calls = []

if gen.finish_reason is not None:
finish_reason = gen.finish_reason
except Exception as e:
status_code, error_payload = self._classify_generation_error(e)
if self.stream:
self._stream_error_response(
status_code, error_payload, stream_started=stream_started
)
else:
self._completion_error_response(status_code, error_payload)
return

# Flush any remaining tool text (e.g. when tool_call_end is empty)
if in_tool_call and tool_text:
Expand Down Expand Up @@ -1689,6 +1776,7 @@ def parse_tools(tool_calls):
indent = "\t" # Backslashes can't be inside of f-strings
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")

self._set_completion_headers(200)
# Send an additional Content-Length header when it is known
self.send_header("Content-Length", str(len(response_json)))
self.end_headers()
Expand Down
72 changes: 72 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import threading
import unittest
from unittest import mock

import mlx.core as mx
import requests
Expand Down Expand Up @@ -230,6 +231,77 @@ def test_sequence_overlap(self):
self.assertFalse(sequence_overlap([1, 2], [3, 4]))
self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3]))

def test_oom_returns_503_without_crashing_server(self):
    """A Metal OOM during generation maps to a 503 and the server stays up."""
    url = f"http://localhost:{self.port}/v1/chat/completions"
    body = {
        "model": "chat_model",
        "max_tokens": 4,
        "messages": [{"role": "user", "content": "hello"}],
    }
    oom = RuntimeError(
        "Metal error: command buffer execution failed due to out of memory"
    )

    with mock.patch.object(self.response_generator, "generate", side_effect=oom):
        response = requests.post(url, json=body)

    self.assertEqual(response.status_code, 503)
    response_body = json.loads(response.text)
    self.assertIn("error", response_body)
    self.assertEqual(response_body["error"]["code"], "metal_out_of_memory")
    self.assertEqual(response_body["error"]["type"], "resource_exhausted_error")

    # Server should continue serving requests after an OOM.
    self.assertEqual(requests.post(url, json=body).status_code, 200)

def test_streaming_oom_returns_terminal_error_event(self):
    """A streaming OOM yields a 503, an SSE error event, then ``[DONE]``."""
    url = f"http://localhost:{self.port}/v1/chat/completions"
    body = {
        "model": "chat_model",
        "max_tokens": 4,
        "stream": True,
        "messages": [{"role": "user", "content": "hello"}],
    }

    with mock.patch.object(
        self.response_generator,
        "generate",
        side_effect=RuntimeError("out of memory"),
    ):
        response = requests.post(url, json=body, stream=True)
        lines = [raw.decode("utf-8") for raw in response.iter_lines() if raw]

    self.assertEqual(response.status_code, 503)
    self.assertGreaterEqual(len(lines), 2)
    self.assertEqual(lines[-1], "data: [DONE]")

    # Strip the "data: " prefix and parse the first event as JSON.
    error_event = json.loads(lines[0][len("data: "):])
    self.assertIn("error", error_event)
    self.assertEqual(error_event["error"]["code"], "metal_out_of_memory")

def test_oom_marker_attempting_to_allocate_maps_to_503(self):
    """The 'attempting to allocate' marker alone is enough to trigger a 503."""
    url = f"http://localhost:{self.port}/v1/chat/completions"
    body = {
        "model": "chat_model",
        "max_tokens": 4,
        "messages": [{"role": "user", "content": "hello"}],
    }
    allocation_error = RuntimeError(
        "Error: attempting to allocate 12.3 GB, maximum allowed buffer size reached"
    )

    with mock.patch.object(
        self.response_generator, "generate", side_effect=allocation_error
    ):
        response = requests.post(url, json=body)

    self.assertEqual(response.status_code, 503)


class TestServerWithDraftModel(unittest.TestCase):
@classmethod
Expand Down