diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 0b8077e9..8bef8253 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -93,6 +93,10 @@ def _get_num_tokens_in_cache(self) -> int | None: for c in self.cache: if hasattr(c, "offset"): return c.offset + # Fallback: use the length of tracked tokens if cache offset is unavailable + # This handles models where cache layers don't expose offset (e.g., GPT-OSS) + if self.tokens is not None: + return len(self.tokens) return None @staticmethod diff --git a/mlx_engine/model_kit/batched_model_kit.py b/mlx_engine/model_kit/batched_model_kit.py index 0753e389..fbed7413 100644 --- a/mlx_engine/model_kit/batched_model_kit.py +++ b/mlx_engine/model_kit/batched_model_kit.py @@ -318,8 +318,11 @@ def get_next_request(timeout=None): ) # Track this request + # Use separate tracking for cross-prompt cache key (original prompt only) + # and live cache key (updated during generation for intra-request caching) self._batch_results[uid] = { - "cache_key": request.prompt_tokens[:], + "cross_prompt_cache_key": request.prompt_tokens[:], + "live_cache_key": request.prompt_tokens[:], "rqueue": request.rqueue, "detokenizer": self.tokenizer.detokenizer, "top_logprobs": request.top_logprobs, @@ -346,7 +349,7 @@ def get_next_request(timeout=None): for r in responses: # Create response object result = self._batch_results[r.uid] - result["cache_key"].append(r.token) + result["live_cache_key"].append(r.token) if r.finish_reason != "stop": result["detokenizer"].add_token(r.token) token_logprob = r.logprobs[r.token].item() @@ -386,8 +389,13 @@ def get_next_request(timeout=None): # Clean up if necessary if r.finish_reason is not None: result["rqueue"].put(None) + # Use cross_prompt_cache_key for cross-prompt caching + # This ensures the cache is keyed by the original prompt, + # not by the prompt + generated tokens self._prompt_cache.insert_cache( - current_model_key, result["cache_key"], r.prompt_cache + current_model_key, + result["cross_prompt_cache_key"], + r.prompt_cache, ) del self._batch_results[r.uid] diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index 001eaeb3..ccbab0df 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -93,6 +93,21 @@ def test_prompt_processing_cancellation(self): # Verify that the second attempt completed successfully self.assertIsNotNone(result_tokens) + def test_get_num_tokens_in_cache_without_offset(self): + """Test that _get_num_tokens_in_cache falls back to len(self.tokens) when offset is unavailable""" + mock_cache = [object() for _ in range(10)] + + wrapper = object.__new__(CacheWrapper) + wrapper.cache = mock_cache + wrapper.tokens = mx.array([1, 2, 3, 4, 5]) + + result = wrapper._get_num_tokens_in_cache() + self.assertEqual(result, 5) + + wrapper.tokens = None + result = wrapper._get_num_tokens_in_cache() + self.assertIsNone(result) + if __name__ == "__main__": unittest.main(verbosity=2)