182 changes: 174 additions & 8 deletions mlx_lm/server.py
@@ -50,6 +50,58 @@ def get_system_fingerprint():
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"


def is_metal_oom_error(error: Exception) -> bool:
text = str(error).lower()
patterns = (
"out of memory",
"insufficient memory",
"kiogpucommandbuffercallbackerroroutofmemory",
"mps backend out of memory",
)
return any(pattern in text for pattern in patterns)
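
A small illustrative sketch of how this helper feeds the status-code choice made later in the handler (the error message below is a representative Metal failure string, not captured output):

    err = RuntimeError("[METAL] Command buffer execution failed: Insufficient Memory")
    is_metal_oom_error(err)                       # True  -> request answered with HTTP 503
    is_metal_oom_error(ValueError("bad prompt"))  # False -> request answered with HTTP 500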


def projected_kv_bytes(prompt_cache: List[Any], extra_tokens: int) -> int:
cache_bytes = sum(c.nbytes for c in prompt_cache)
if cache_bytes <= 0 or extra_tokens <= 0:
return cache_bytes

cache_tokens = max(
(c.size() for c in prompt_cache if hasattr(c, "size")), default=0
)
if cache_tokens <= 0:
return cache_bytes

bytes_per_token = cache_bytes / cache_tokens
return cache_bytes + int(bytes_per_token * extra_tokens)
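
A rough worked example of the projection above (the cache size and token counts are made-up illustration values):

    # Hypothetical: a prompt cache holding 2048 tokens in 1 GiB of KV memory,
    # with up to 512 more tokens still to be generated.
    cache_bytes = 1024**3                                  # 1,073,741,824 bytes
    bytes_per_token = cache_bytes / 2048                   # 524,288 bytes (~512 KiB) per token
    projected = cache_bytes + int(bytes_per_token * 512)   # 1,342,177,280 bytes (~1.25 GiB)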


def apply_prompt_token_limit(
tokens: List[int],
*,
max_prompt_tokens: Optional[int],
overflow_policy: str,
keep_tokens: int,
) -> List[int]:
if max_prompt_tokens is None or len(tokens) <= max_prompt_tokens:
return tokens

if overflow_policy == "error":
raise ValueError(
"Prompt exceeds max prompt token limit: "
f"prompt_tokens={len(tokens)}, max_prompt_tokens={max_prompt_tokens}"
)

if overflow_policy != "truncate":
raise ValueError(f"Invalid prompt overflow policy: {overflow_policy}")

keep_tokens = max(0, min(keep_tokens, max_prompt_tokens))
tail_tokens = max_prompt_tokens - keep_tokens
if tail_tokens <= 0:
return tokens[:max_prompt_tokens]
return tokens[:keep_tokens] + tokens[-tail_tokens:]
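
For illustration, assuming a 3000-token prompt with max_prompt_tokens=1000 and keep_tokens=512, the truncation keeps the first 512 tokens plus the last 488 and drops the middle:

    tokens = list(range(3000))
    truncated = apply_prompt_token_limit(
        tokens,
        max_prompt_tokens=1000,
        overflow_policy="truncate",
        keep_tokens=512,
    )
    assert truncated == tokens[:512] + tokens[-488:]
    assert len(truncated) == 1000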


class StopCondition(NamedTuple):
stop_met: bool
trim_length: int
@@ -722,17 +774,35 @@ def _tokenize(self, tokenizer, request, args):
if args.chat_template_kwargs:
chat_template_args = chat_template_args.copy()
chat_template_args.update(args.chat_template_kwargs)
return tokenizer.apply_chat_template(
tokens = tokenizer.apply_chat_template(
messages,
tools=tools,
add_generation_prompt=True,
tokenize=True,
**chat_template_args,
)
return apply_prompt_token_limit(
tokens,
max_prompt_tokens=self.cli_args.max_prompt_tokens,
overflow_policy=self.cli_args.prompt_overflow_policy,
keep_tokens=self.cli_args.prompt_keep_tokens,
)
else:
return tokenizer.encode(convert_chat(messages, role_mapping))
tokens = tokenizer.encode(convert_chat(messages, role_mapping))
return apply_prompt_token_limit(
tokens,
max_prompt_tokens=self.cli_args.max_prompt_tokens,
overflow_policy=self.cli_args.prompt_overflow_policy,
keep_tokens=self.cli_args.prompt_keep_tokens,
)
else:
return tokenizer.encode(request.prompt)
tokens = tokenizer.encode(request.prompt)
return apply_prompt_token_limit(
tokens,
max_prompt_tokens=self.cli_args.max_prompt_tokens,
overflow_policy=self.cli_args.prompt_overflow_policy,
keep_tokens=self.cli_args.prompt_keep_tokens,
)

def _compute_prompt_checkpoint(self, tokenizer, request, prompt):
if request.request_type != "chat":
@@ -760,6 +830,44 @@ def _is_batchable(self, args):

return True

def _make_prompt_cache(self, model, draft_model=None):
cache = make_prompt_cache(
model,
max_kv_size=self.model_provider.cli_args.max_kv_size,
)
if draft_model is not None:
cache += make_prompt_cache(
draft_model,
max_kv_size=self.model_provider.cli_args.max_kv_size,
)
return cache

def _memory_admission_error(
self, prompt_cache: List[Any], extra_tokens: int, active_bytes: int = 0
) -> Optional[str]:
limit = self.model_provider.cli_args.max_active_kv_bytes
if limit is None:
return None
projected = active_bytes + projected_kv_bytes(prompt_cache, extra_tokens)
if projected <= limit:
return None
return (
"Projected KV memory usage exceeds configured active KV limit. "
f"projected={projected} bytes, limit={limit} bytes"
)

def _check_active_memory_limit(self):
limit = self.model_provider.cli_args.max_active_memory_bytes
if limit is None:
return
active = mx.get_active_memory()
if active > limit:
raise MemoryError(
"Active MLX memory exceeded configured limit: "
f"active={active} bytes, limit={limit} bytes. "
"Consider lowering prompt length or max_tokens."
)

def _generate(self):
current_model = None
current_sampling = None
@@ -778,6 +886,7 @@ def get_next_request(timeout=None):
return self._next_request(timeout)

def progress_callback(info):
self._check_active_memory_limit()
for uid, processed, total in info:
if uid in batch_results:
batch_results[uid]["rqueue"].put((min(processed, total), total))
@@ -848,7 +957,9 @@ def checkpoint_callback(prompts):
)
ctx.prompt_cache_count = len(prompt) - len(rest)
if cache is None:
cache = make_prompt_cache(self.model_provider.model)
cache = self._make_prompt_cache(
self.model_provider.model, self.model_provider.draft_model
)

do_checkpoint, checkpoint_position = (
self._compute_prompt_checkpoint(tokenizer, request, prompt)
@@ -906,6 +1017,7 @@ def checkpoint_callback(prompts):
prefill_step_size=self.cli_args.prefill_step_size,
prompt_progress_callback=progress_callback,
prompt_checkpoint_callback=checkpoint_callback,
max_kv_size=self.cli_args.max_kv_size,
)
unprocessed_requests.append((rqueue, request, args))
continue
@@ -984,6 +1096,7 @@ def _serve_single(self, request):

# Define the progress callback
def progress(tokens_processed, tokens_total):
self._check_active_memory_limit()
rqueue.put((tokens_processed, tokens_total))

try:
@@ -1030,9 +1143,16 @@ def progress(tokens_processed, tokens_total):
ctx.prompt_cache_count = len(prompt) - len(rest)
cache_key = prompt[:]
if cache is None:
cache = make_prompt_cache(self.model_provider.model)
if self.model_provider.draft_model is not None:
cache += make_prompt_cache(self.model_provider.draft_model)
cache = self._make_prompt_cache(
self.model_provider.model, self.model_provider.draft_model
)

admission_error = self._memory_admission_error(
cache,
len(rest) + args.max_tokens,
)
if admission_error is not None:
raise MemoryError(admission_error)

# Process the prompt and generate tokens
for gen in stream_generate(
@@ -1047,6 +1167,7 @@ def progress(tokens_processed, tokens_total):
num_draft_tokens=args.num_draft_tokens,
prompt_progress_callback=progress,
prefill_step_size=self.cli_args.prefill_step_size,
max_kv_size=self.cli_args.max_kv_size,
):
rqueue.put(
Response(
@@ -1502,7 +1623,8 @@ def keepalive_callback(processed_tokens, total_tokens):
progress_callback=keepalive_callback,
)
except Exception as e:
self._set_completion_headers(404)
status_code = 503 if is_metal_oom_error(e) else 500
self._set_completion_headers(status_code)
self.end_headers()
self.wfile.write(json.dumps({"error": f"{e}"}).encode())
return
@@ -2023,6 +2145,50 @@ def main():
type=_parse_size,
help="Maximum size in bytes of the KV caches",
)
parser.add_argument(
"--max-prompt-tokens",
type=int,
default=None,
help="Maximum prompt token count accepted by the server",
)
parser.add_argument(
"--prompt-overflow-policy",
type=str,
default="error",
choices=["error", "truncate"],
help="Behavior when prompt exceeds --max-prompt-tokens",
)
parser.add_argument(
"--prompt-keep-tokens",
type=int,
default=512,
help=(
"When truncating prompts, keep this many tokens from the start and fill "
"the remainder from the end"
),
)
parser.add_argument(
"--max-active-kv-bytes",
type=_parse_size,
help=(
"Reject requests when projected active KV memory would exceed this limit "
"(bytes or shorthand like 20G)"
),
)
parser.add_argument(
"--max-active-memory-bytes",
type=_parse_size,
help=(
"Abort requests when current active MLX memory exceeds this limit "
"(bytes or shorthand like 30G)"
),
)
parser.add_argument(
"--max-kv-size",
type=int,
default=None,
help="Maximum size of the active KV cache per sequence",
)
parser.add_argument(
"--pipeline",
action="store_true",
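
Taken together, a possible launch line exercising the new flags might look like the following (the model path is a placeholder; the byte limits use the shorthand accepted by _parse_size):

    python -m mlx_lm.server --model <path-to-mlx-model> \
        --max-prompt-tokens 8192 --prompt-overflow-policy truncate --prompt-keep-tokens 512 \
        --max-active-kv-bytes 20G --max-active-memory-bytes 30G --max-kv-size 4096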