From 103a4acd1ad890e646a60b8e8336021f6c7fa4e0 Mon Sep 17 00:00:00 2001 From: Yuan Lik Xun Date: Wed, 25 Mar 2026 15:23:39 +0800 Subject: [PATCH 1/3] Add failing reproduction for merge() OOB when prompt exceeds max_size BatchRotatingKVCache.merge() crashes when a constituent RotatingKVCache received a prompt longer than max_size on its first fill. _update_concat stores every token without trimming on first fill (trimming is deferred to the next call), leaving _idx == prompt_len > max_size. merge() used _idx as the output-slice width, writing prompt_len tokens into a max_size-wide buffer -> out-of-bounds write. Reproducer: RotatingKVCache(max_size=70) fed a 128-token prefill raises an index error when merge() is called. Signed-off-by: Yuan Lik Xun --- tests/test_prompt_cache.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_prompt_cache.py b/tests/test_prompt_cache.py index 05dcd7dc4..6eebb9319 100644 --- a/tests/test_prompt_cache.py +++ b/tests/test_prompt_cache.py @@ -673,5 +673,19 @@ def test_window_mask_with_full_kv_cache(self): self.assertTrue(mx.array_equal(mask, expected)) +class TestBatchRotatingKVCacheMerge(unittest.TestCase): + def test_merge_output_shape_is_bounded_by_max_size_when_prompt_length_exceeds_max_size( + self, + ): + """merge() output sequence dimension must equal max_size, not prompt length.""" + cache = RotatingKVCache(max_size=70) + cache.update_and_fetch(mx.ones((1, 8, 128, 64)), mx.ones((1, 8, 128, 64))) + + merged = BatchRotatingKVCache.merge([cache]) + + self.assertEqual(merged.keys.shape[2], 70) + self.assertEqual(merged.values.shape[2], 70) + + if __name__ == "__main__": unittest.main() From c45130ee374899b49639a8b438e53cfb4fe6aa52 Mon Sep 17 00:00:00 2001 From: Yuan Lik Xun Date: Wed, 25 Mar 2026 15:24:04 +0800 Subject: [PATCH 2/3] Fix BatchRotatingKVCache.merge() OOB write when prompt exceeds max_size _update_concat defers trimming on first fill, so after a prompt longer than max_size, _idx == prompt_len > max_size. Using _idx as the output-slice width writes past the end of the max_size-wide buffer. c.size() (= min(offset, max_size)) is the correct width. The slice is taken from the tail because _temporal_order returns tokens oldest-first; the sliding window must retain the most-recent n, not the oldest. Signed-off-by: Yuan Lik Xun --- mlx_lm/models/cache.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 88fa4ad32..c5b360fa5 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -1361,8 +1361,11 @@ def merge(cls, caches): for i, (p, c) in enumerate(zip(padding, caches)): if c.keys is None: continue - keys[i : i + 1, :, p : p + c._idx] = c._temporal_order(c.keys) - values[i : i + 1, :, p : p + c._idx] = c._temporal_order(c.values) + n = c.size() + ordered_k = c._temporal_order(c.keys) + ordered_v = c._temporal_order(c.values) + keys[i : i + 1, :, p : p + n] = ordered_k[..., -n:, :] + values[i : i + 1, :, p : p + n] = ordered_v[..., -n:, :] cache = cls(caches[0].max_size, padding) cache.keys = keys From cc56f4cbbe70902e0fc0fea041a5d20ee5137bea Mon Sep 17 00:00:00 2001 From: Yuan Lik Xun Date: Wed, 25 Mar 2026 15:24:35 +0800 Subject: [PATCH 3/3] Add regression tests for merge() large-prefill and rotation paths Three tests cover the fix: - Shape is capped at max_size when prompt exceeds max_size - Most-recent tokens land in the merged cache, not the oldest - Ring buffer is rolled into temporal order after autoregressive wrap-around Signed-off-by: Yuan Lik Xun --- tests/test_prompt_cache.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_prompt_cache.py b/tests/test_prompt_cache.py index 6eebb9319..20d3ab88d 100644 --- a/tests/test_prompt_cache.py +++ b/tests/test_prompt_cache.py @@ -686,6 +686,44 @@ def test_merge_output_shape_is_bounded_by_max_size_when_prompt_length_exceeds_ma self.assertEqual(merged.keys.shape[2], 70) self.assertEqual(merged.values.shape[2], 70) + def test_merge_contains_most_recent_tokens_when_prompt_length_exceeds_max_size( + self, + ): + """merge() must place the most-recent max_size tokens into the merged cache.""" + max_size, seq_len, H, D = 4, 8, 2, 16 + vals = mx.arange(1, seq_len + 1, dtype=mx.float32).reshape(1, 1, seq_len, 1) + keys = mx.broadcast_to(vals, (1, H, seq_len, D)) + cache = RotatingKVCache(max_size=max_size) + cache.update_and_fetch(keys, keys) + expected_first_slot_value = float(seq_len - max_size + 1) + + merged = BatchRotatingKVCache.merge([cache]) + mx.eval(merged.keys) + + self.assertAlmostEqual( + merged.keys[0, 0, 0, 0].item(), expected_first_slot_value + ) + + def test_merge_after_rotation_preserves_temporal_order(self): + """merge() must roll the ring buffer into temporal order after autoregressive wrap-around.""" + max_size, H, D = 4, 2, 8 + cache = RotatingKVCache(max_size=max_size) + cache.update_and_fetch( + mx.ones((1, H, max_size, D)), mx.ones((1, H, max_size, D)) + ) + for _ in range(2): + cache.update_and_fetch( + mx.full((1, H, 1, D), 2.0), mx.full((1, H, 1, D), 2.0) + ) + expected_first_slot_value = 1.0 + + merged = BatchRotatingKVCache.merge([cache]) + mx.eval(merged.keys) + + self.assertAlmostEqual( + merged.keys[0, 0, 0, 0].item(), expected_first_slot_value + ) + if __name__ == "__main__": unittest.main()