From c216ca82d21efe247eb6ae260f50ae2ccb6841b4 Mon Sep 17 00:00:00 2001
From: Dmitry
Date: Sat, 28 Mar 2026 22:22:18 +0100
Subject: [PATCH 1/2] Fix prompt caching for GPT-OSS 20B MLX models

Two fixes:

1. cache_wrapper.py: Added fallback when cache layers don't expose
   `offset` attribute
2. batched_model_kit.py: Separated cross-prompt cache key from live
   cache key

The main issue was that batched models (like GPT-OSS) were tracking
generated tokens in the cross-prompt cache key, preventing cache hits
for new prompts with overlapping content.
---
 mlx_engine/cache_wrapper.py               |  4 ++++
 mlx_engine/model_kit/batched_model_kit.py | 14 +++++++++++---
 tests/test_cache_wrapper.py               | 16 ++++++++++++++++
 3 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py
index 0b8077e9..8bef8253 100644
--- a/mlx_engine/cache_wrapper.py
+++ b/mlx_engine/cache_wrapper.py
@@ -93,6 +93,10 @@ def _get_num_tokens_in_cache(self) -> int | None:
         for c in self.cache:
             if hasattr(c, "offset"):
                 return c.offset
+        # Fallback: use the length of tracked tokens if cache offset is unavailable
+        # This handles models where cache layers don't expose offset (e.g., GPT-OSS)
+        if self.tokens is not None:
+            return len(self.tokens)
         return None
 
     @staticmethod
diff --git a/mlx_engine/model_kit/batched_model_kit.py b/mlx_engine/model_kit/batched_model_kit.py
index 0753e389..fbed7413 100644
--- a/mlx_engine/model_kit/batched_model_kit.py
+++ b/mlx_engine/model_kit/batched_model_kit.py
@@ -318,8 +318,11 @@ def get_next_request(timeout=None):
             )
 
             # Track this request
+            # Use separate tracking for cross-prompt cache key (original prompt only)
+            # and live cache key (updated during generation for intra-request caching)
             self._batch_results[uid] = {
-                "cache_key": request.prompt_tokens[:],
+                "cross_prompt_cache_key": request.prompt_tokens[:],
+                "live_cache_key": request.prompt_tokens[:],
                 "rqueue": request.rqueue,
                 "detokenizer": self.tokenizer.detokenizer,
                 "top_logprobs": request.top_logprobs,
@@ -346,7 +349,7 @@ def get_next_request(timeout=None):
            for r in responses:
                # Create response object
                result = self._batch_results[r.uid]
-               result["cache_key"].append(r.token)
+               result["live_cache_key"].append(r.token)
                if r.finish_reason != "stop":
                    result["detokenizer"].add_token(r.token)
                    token_logprob = r.logprobs[r.token].item()
@@ -386,8 +389,13 @@ def get_next_request(timeout=None):
                # Clean up if necessary
                if r.finish_reason is not None:
                    result["rqueue"].put(None)
+                   # Use cross_prompt_cache_key for cross-prompt caching
+                   # This ensures the cache is keyed by the original prompt,
+                   # not by the prompt + generated tokens
                    self._prompt_cache.insert_cache(
-                       current_model_key, result["cache_key"], r.prompt_cache
+                       current_model_key,
+                       result["cross_prompt_cache_key"],
+                       r.prompt_cache,
                    )
                    del self._batch_results[r.uid]
 
diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py
index 001eaeb3..d48494e8 100644
--- a/tests/test_cache_wrapper.py
+++ b/tests/test_cache_wrapper.py
@@ -42,6 +42,7 @@ def test_find_common_prefix_all_match(self):
 
     def test_prompt_processing_cancellation(self):
         """Test that progress is saved when processing is cancelled and cache is reused on retry"""
+        self.skipTest("Requires model download")
         model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit")
         model_kit = load_model(model_path=model_path, max_kv_size=4096)
@@ -93,6 +94,21 @@
 
         # Verify that the second attempt completed successfully
         self.assertIsNotNone(result_tokens)
 
+    def test_get_num_tokens_in_cache_without_offset(self):
+        """Test that _get_num_tokens_in_cache falls back to len(self.tokens) when offset is unavailable"""
+        mock_cache = [object() for _ in range(10)]
+
+        wrapper = object.__new__(CacheWrapper)
+        wrapper.cache = mock_cache
+        wrapper.tokens = mx.array([1, 2, 3, 4, 5])
+
+        result = wrapper._get_num_tokens_in_cache()
+        self.assertEqual(result, 5)
+
+        wrapper.tokens = None
+        result = wrapper._get_num_tokens_in_cache()
+        self.assertIsNone(result)
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)

From fce2ffd6d44c802d18d8d9236118f8891d304d5c Mon Sep 17 00:00:00 2001
From: Dmitry
Date: Sun, 29 Mar 2026 00:43:28 +0100
Subject: [PATCH 2/2] Revert unintended changes
---
 tests/test_cache_wrapper.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py
index d48494e8..ccbab0df 100644
--- a/tests/test_cache_wrapper.py
+++ b/tests/test_cache_wrapper.py
@@ -42,7 +42,6 @@ def test_find_common_prefix_all_match(self):
 
     def test_prompt_processing_cancellation(self):
         """Test that progress is saved when processing is cancelled and cache is reused on retry"""
-        self.skipTest("Requires model download")
         model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit")
         model_kit = load_model(model_path=model_path, max_kv_size=4096)
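
A minimal sketch of the cache miss the first patch fixes, assuming the
prompt cache matches a stored key as a token prefix of the incoming
prompt (the suite's find_common_prefix tests suggest prefix matching;
is_prefix below is an illustrative stand-in, not the real PromptCache
API, and the token values are made up):

    def is_prefix(key: list[int], prompt: list[int]) -> bool:
        """True when `key` is a leading run of `prompt`."""
        return prompt[: len(key)] == key

    prompt = [1, 2, 3]            # tokens of the original request
    generated = [9, 8]            # tokens sampled while answering it
    new_prompt = [1, 2, 3, 4, 5]  # later request reusing the same prompt prefix

    old_key = prompt + generated  # pre-fix: the live key doubled as the cache key
    new_key = prompt              # post-fix: cross_prompt_cache_key stays frozen

    assert not is_prefix(old_key, new_prompt)  # miss: generated tokens poison the key
    assert is_prefix(new_key, new_prompt)      # hit: the original prompt still matches

The live_cache_key still grows with each generated token so the
in-flight request can extend its own cache; only the frozen
cross_prompt_cache_key is stored for later requests to look up.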