Enable prefix cache reuse for hybrid models (Qwen3.5, Nemotron-H, Mamba)#1006
skcadri wants to merge 3 commits into ml-explore:main
Conversation
```python
if not self.model_provider._cache_is_trimmable:
    # For hybrid models with non-trimmable caches (ArraysCache),
    # checkpoint at the last-message boundary so the prefix
    # can be reused via the shorter-cache path.
```
_serve_single doesn't call _compute_prompt_checkpoint at all - it only runs in the batch path. is that intentional? if a hybrid model is served without batching (_is_batchable returns False), the checkpoint logic won't kick in and you'd still get 0% cache reuse on the single-serve path.
Good catch. This is a pre-existing limitation — _serve_single doesn't support checkpoints for trimmable models either (the think-token checkpoint only runs in the batch path). Adding checkpoint support there would require splitting the prefill in stream_generate, which is a larger change.
That said, _serve_single still benefits from checkpoints saved by prior batched requests — if a batched request already checkpointed the prefix, a subsequent single-serve request will find it via fetch_nearest_cache.
```python
if len(prefix_messages) > 0:
    chat_template_args = self.model_provider.cli_args.chat_template_args
    if args is not None and hasattr(args, 'chat_template_kwargs') and args.chat_template_kwargs:
        chat_template_args = chat_template_args.copy()
```
this guard chain is a bit defensive - `args` is always passed from the call site now (line 905), and `chat_template_kwargs` is set by `_parse_args`. could simplify to just `getattr(args, 'chat_template_kwargs', None)`.
Simplified in 82574b4 — now uses `getattr(args, 'chat_template_kwargs', None)`.
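For reference, the simplified guard can be sketched as a standalone function — `merge_template_args`, `cli_template_args`, and `request_args` are hypothetical stand-ins for the server's actual objects, not the real `server.py` names:

```python
from types import SimpleNamespace

def merge_template_args(cli_template_args, request_args):
    # getattr with a default replaces the hasattr/None guard chain:
    # missing attribute and falsy value are handled in one expression.
    extra = getattr(request_args, "chat_template_kwargs", None)
    if extra:
        # Merge into a new dict so the CLI defaults are never mutated.
        return {**cli_template_args, **extra}
    return cli_template_args
```

A request that carries `chat_template_kwargs` overrides the CLI defaults; one that doesn't gets the defaults back untouched.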
I tried this with parallel requests, but it seems to trigger a bug:
Hybrid models like Qwen3.5 and Nemotron-H use both KVCache (attention) and ArraysCache (SSM/recurrent) layers. Since ArraysCache stores compressed recurrent state that cannot be trimmed, can_trim_prompt_cache() returns False, blocking the "longer cache trim" path in fetch_nearest_cache() and resulting in 0% cache reuse on the divergent-suffix pattern (same system prompt, different user messages).

PR ml-explore#999 fixed this for sliding-window models (RotatingKVCache). This commit addresses the remaining non-sliding-window hybrid models by extending the existing prompt checkpoint mechanism to save the cache at the last-message boundary during prefill, enabling subsequent requests to find it via the "shorter cache" path — no trimming needed.

The boundary is found using a sentinel substitution technique: the last user message is replaced with a short dummy, both versions are tokenized, and the common prefix identifies the exact split point. This avoids issues where apply_chat_template(messages[:-1]) produces tokens that don't form a prefix of apply_chat_template(messages).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
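The sentinel-substitution boundary search described in this commit message can be sketched as a minimal standalone version — here `tokenize` stands in for `tokenizer.apply_chat_template`, and the function names are hypothetical, not the actual `server.py` helpers:

```python
def common_prefix_len(a, b):
    # Length of the longest common prefix of two token lists.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def last_message_boundary(tokenize, messages, sentinel="x"):
    """Return the token index where the last message's content begins.

    Replacing only the last message's content keeps all template
    scaffolding (role tags, separators) identical, so the common token
    prefix of the two tokenizations marks the reusable split point.
    """
    dummy = messages[:-1] + [{**messages[-1], "content": sentinel}]
    return common_prefix_len(tokenize(messages), tokenize(dummy))
```

Because both tokenizations go through the full template, the result stays correct even for templates where tokenizing `messages[:-1]` alone would produce a different prefix.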
- Left-pad `last_inputs` in `_process_prompts` to handle variable-length checkpoint offsets across batched requests (fixes ValueError with non-uniform length when parallel requests have different user message lengths)
- Simplify `chat_template_kwargs` guard to use `getattr`

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Force-pushed ec0434b to 82574b4
@dae Thanks for the repro — this was a real bug. When parallel requests with different user message lengths get batched, their checkpoint offsets differ, which raised the non-uniform-length `ValueError` you hit. Fixed in 82574b4 by left-padding `last_inputs` in `_process_prompts` so the batched inputs share a common width.

Could you try again with the latest push?
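The left-padding itself is simple — a pure-Python sketch of stacking variable-length token suffixes into one rectangular batch (the actual code works on `mx.array` inputs; `left_pad` is a hypothetical name):

```python
def left_pad(seqs, pad_id=0):
    # Left-pad token sequences to a common width so they can be
    # stacked into a rectangular batch. Left padding keeps the most
    # recent tokens aligned at the right edge, which matters when
    # each row resumes generation from a different cache offset.
    width = max(len(s) for s in seqs)
    return [[pad_id] * (width - len(s)) + list(s) for s in seqs]
```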
I don't get an error any more, but I still get strange results with concurrency. My use case is a browser extension sending a 'translate this' system prompt and then about 300 tokens. The first 'translate page' fires off about 6 parallel requests, and things perform fine as there's no cache at that point. But when I repeat it, I get output that implies some of the system prompt has been lost or corrupted. Adding `--prompt-concurrency 1` to the startup params makes the odd behaviour go away.

Edit: I've observed (less common) strange behaviour on subsequent runs even with `--prompt-concurrency 1`.
Hey! Thanks for the PR; I'm pretty interested in these fixes!
The checkpoint_callback used the batch-global prompt_end (max checkpoint offset across all batched requests) to compute the checkpoint cache key. For requests with shorter user messages, this cut into the shared prefix, storing a key that didn't match the cache state — causing cache corruption on subsequent requests.

Fix: store per-request checkpoint_position and only save checkpoints for requests whose individual offset matches the batch-global prompt_end. Requests with shorter offsets safely skip the checkpoint (they'll get their own checkpoint when they're the longest in a future batch, or via sequential processing).

Tested with parallel requests of varying user message lengths on Qwen3.5-9B: all requests return correct answers with cache hits, no corruption after repeated bursts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
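The guard this commit describes can be sketched as a tiny helper — `requests_to_checkpoint` is a hypothetical name; the real logic lives inside the batched checkpoint callback:

```python
def requests_to_checkpoint(checkpoint_positions):
    # Only requests whose own checkpoint offset equals the batch-global
    # prompt_end may save a checkpoint. A shorter request would store a
    # cache key computed past its own shared prefix, corrupting reuse
    # for the next request that looks that key up.
    prompt_end = max(checkpoint_positions)
    return [i for i, pos in enumerate(checkpoint_positions)
            if pos == prompt_end]
```

Requests filtered out here simply skip the checkpoint; they can still save one later when they are the longest in a batch of their own.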
@dae Great catch — found the root cause. The `checkpoint_callback` used the batch-global `prompt_end` (the maximum checkpoint offset across all batched requests) to compute the checkpoint cache key, so requests with shorter user messages stored a key that cut into the shared prefix and didn't match the cache state.

Fixed in c2e01bc: each request now tracks its own `checkpoint_position`, and a checkpoint is only saved when a request's own offset matches the batch-global `prompt_end`.

Stress-tested with parallel bursts of varying user message lengths on Qwen3.5-9B — all requests return correct answers with cache hits across repeated bursts, no corruption.

Could you try the latest push and see if the corruption is gone for your translation use case?
@celestial-rose Thanks for the interest! This PR specifically targets prefix cache reuse for hybrid models in the server. For parallel queries, the batch-concurrency fixes above (82574b4 and c2e01bc) should cover the behaviour you're asking about.
This solves a real problem. We run Nemotron-3-Super-120B-A12B (40 Mamba + 8 Attention + 40 MoE layers) on M2 Ultra and get 0% cache reuse on current main. The batch concurrency fix for varying-length parallel requests is also relevant for our setup. Note: this may conflict with #1030.
@Thump604 Great to hear this helps Nemotron-3-Super on M2 Ultra — that's a perfect test case with the mixed Mamba/Attention/MoE layer composition. Good catch on the #1030 conflict. Will keep an eye on it and rebase if #1030 lands first (the overlap is small).
@angeloskath Thanks for the pointer — tested #1072 with Qwen3.5-9B (hybrid: 75% ArraysCache + 25% KVCache) and it fully handles prefix caching for non-trimmable models. Both sequential cache reuse and parallel requests with varying user message lengths work correctly, including the translation scenario that @dae reported corruption on.
Our PR is redundant with #1072. Closing this — the refactored batch generation with arbitrary segment checkpointing is the cleaner solution. Thanks for the great work on #1072! |
Summary
Fixes prefix cache reuse for non-sliding-window hybrid models — the remaining case from #980 that #999 did not address.
- Hybrid models use both `KVCache` (attention) and `ArraysCache` (SSM/recurrent) per layer
- `ArraysCache` stores compressed recurrent state that cannot be trimmed — you can't remove N tokens from a hidden state
- `can_trim_prompt_cache()` returns `False` → the "longer cache trim" path in `fetch_nearest_cache()` is blocked → 0% cache reuse on divergent-suffix requests (same system prompt, different user messages)

This PR extends the existing prompt checkpoint mechanism to save the cache at the last-message boundary during prefill. Subsequent requests with the same prefix find it via the "shorter cache" path — no trimming needed.
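To make the two lookup paths concrete, here is a minimal sketch of the decision — `plan_reuse` is a hypothetical function, not the actual `fetch_nearest_cache` implementation:

```python
def plan_reuse(cached_tokens, prompt_tokens, can_trim):
    # Longest common prefix between the cached prompt and the new one.
    n = 0
    for a, b in zip(cached_tokens, prompt_tokens):
        if a != b:
            break
        n += 1
    if n == len(cached_tokens):
        # "Shorter cache" path: the cache is a prefix of the new
        # prompt, so it can be reused as-is -- no trimming required.
        return ("reuse", prompt_tokens[n:])
    if can_trim:
        # "Longer cache trim" path: cut the divergent suffix off the
        # cache. Blocked for ArraysCache, hence this PR's checkpoints.
        return ("trim_to", n)
    # Non-trimmable cache that is not a prefix: it cannot help.
    return ("recompute", prompt_tokens)
```

Checkpointing at the last-message boundary ensures the stored cache covers only the shared prefix, so divergent-suffix requests always take the first branch.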
Approach
For non-trimmable models, `_compute_prompt_checkpoint` now computes where the last user message begins using a sentinel substitution technique: the last user message is replaced with a short dummy (e.g. `"x"`), both message lists are tokenized, and the common token prefix identifies the exact split point.

This avoids a subtle issue where `apply_chat_template(messages[:-1])` produces tokens that don't form a prefix of `apply_chat_template(messages)` due to template-specific formatting (discovered during testing with Qwen3.5's template, which requires a user message and inserts role tokens differently).

For trimmable models (pure attention), the existing `<think>` token checkpoint logic is completely unchanged.

Changes
`mlx_lm/server.py` — 1 file, +60/-10 lines:

- `ModelProvider.load()`: cache the `_cache_is_trimmable` flag at model load time
- `_compute_prompt_checkpoint()`: branch on trimmability: trimmable models keep the existing `<think>` token checkpoint (unchanged); non-trimmable models checkpoint at the last-message boundary
- Pass `args` to `_compute_prompt_checkpoint` for `chat_template_kwargs` access

Test Results
Unit tests
All 20 existing tests pass (`python -m pytest tests/test_prompt_cache.py -v`).

Qwen3.5-9B (attention + linear recurrence hybrid)
Cache composition: `KVCache` + `ArraysCache`, `can_trim=False`

Agentic workload (~2.7K token system prompt, 5 divergent user messages):
Summary: 98% cache hit, 2.86x speedup, 26.78s saved across 4 warm requests.
Multi-turn test (shared `[sys, user, assistant]` prefix, different follow-up):

NVIDIA Nemotron-H 30B-A3B (attention + Mamba SSM hybrid)
Architecture:
`nemotron_h` — 52 layers, `MEMEM*EMEMEM*...` pattern. Only 6/29 cache layers are trimmable (`KVCache`); 23/29 are `ArraysCache` (Mamba SSM state). `can_trim=False`.
Summary: 98% cache hit, 3.1x speedup. Stock LM Studio (unpatched) shows 0% cache hit on the same workload.
Regression: Qwen1.5-0.5B (pure attention, trimmable)
Verified the trimmable model path is unaffected — cache trim still works correctly (69% hit via existing trim path).
Edge cases verified
How It Works
First request
`[sys, user_1]`:

- `_compute_prompt_checkpoint` finds the message boundary via sentinel comparison and checkpoints the cache at the end of `[sys]` during prefill

Second request `[sys, user_2]`:

- `fetch_nearest_cache` finds the cached `[sys]` via the shorter-cache path (existing code, line ~303)
- only `rest = user_2_tokens` remains to be prefilled

Relationship to #999
PR #999 by @Giustino98 fixed cache reuse for sliding-window models (`RotatingKVCache`) by enabling trim on wrapped ring buffers. This PR complements it by handling the non-sliding-window hybrid models (`ArraysCache`), where trimming is fundamentally impossible, using the checkpoint approach instead.

Together they resolve #980 completely.
🤖 Generated with Claude Code