diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py
index 88fa4ad32..c5b360fa5 100644
--- a/mlx_lm/models/cache.py
+++ b/mlx_lm/models/cache.py
@@ -1361,8 +1361,11 @@ def merge(cls, caches):
         for i, (p, c) in enumerate(zip(padding, caches)):
             if c.keys is None:
                 continue
-            keys[i : i + 1, :, p : p + c._idx] = c._temporal_order(c.keys)
-            values[i : i + 1, :, p : p + c._idx] = c._temporal_order(c.values)
+            n = c.size()  # copy length now comes from size() instead of c._idx (size() is defined outside this diff)
+            ordered_k = c._temporal_order(c.keys)
+            ordered_v = c._temporal_order(c.values)
+            keys[i : i + 1, :, p : p + n] = ordered_k[..., -n:, :]  # last n entries on the second-to-last axis
+            values[i : i + 1, :, p : p + n] = ordered_v[..., -n:, :]  # NOTE(review): -n: selects the whole axis if n == 0 - confirm n > 0 here
 
         cache = cls(caches[0].max_size, padding)
         cache.keys = keys
diff --git a/tests/test_prompt_cache.py b/tests/test_prompt_cache.py
index 05dcd7dc4..20d3ab88d 100644
--- a/tests/test_prompt_cache.py
+++ b/tests/test_prompt_cache.py
@@ -673,5 +673,57 @@ def test_window_mask_with_full_kv_cache(self):
         self.assertTrue(mx.array_equal(mask, expected))
 
 
+class TestBatchRotatingKVCacheMerge(unittest.TestCase):  # tests for BatchRotatingKVCache.merge()
+    def test_merge_output_shape_is_bounded_by_max_size_when_prompt_length_exceeds_max_size(
+        self,
+    ):
+        """merge() output sequence dimension must equal max_size, not prompt length."""
+        cache = RotatingKVCache(max_size=70)
+        cache.update_and_fetch(mx.ones((1, 8, 128, 64)), mx.ones((1, 8, 128, 64)))  # 128 on axis 2 vs. max_size=70
+
+        merged = BatchRotatingKVCache.merge([cache])
+
+        self.assertEqual(merged.keys.shape[2], 70)
+        self.assertEqual(merged.values.shape[2], 70)
+
+    def test_merge_contains_most_recent_tokens_when_prompt_length_exceeds_max_size(
+        self,
+    ):
+        """merge() must place the most-recent max_size tokens into the merged cache."""
+        max_size, seq_len, H, D = 4, 8, 2, 16
+        vals = mx.arange(1, seq_len + 1, dtype=mx.float32).reshape(1, 1, seq_len, 1)  # values 1..seq_len along axis 2
+        keys = mx.broadcast_to(vals, (1, H, seq_len, D))
+        cache = RotatingKVCache(max_size=max_size)
+        cache.update_and_fetch(keys, keys)  # the same array is passed as both keys and values
+        expected_first_slot_value = float(seq_len - max_size + 1)  # 5.0, the first of the last max_size values (5..8)
+
+        merged = BatchRotatingKVCache.merge([cache])
+        mx.eval(merged.keys)
+
+        self.assertAlmostEqual(  # only element [0, 0, 0, 0] is checked
+            merged.keys[0, 0, 0, 0].item(), expected_first_slot_value
+        )
+
+    def test_merge_after_rotation_preserves_temporal_order(self):
+        """merge() must roll the ring buffer into temporal order after autoregressive wrap-around."""
+        max_size, H, D = 4, 2, 8
+        cache = RotatingKVCache(max_size=max_size)
+        cache.update_and_fetch(  # first update: max_size positions, all 1.0
+            mx.ones((1, H, max_size, D)), mx.ones((1, H, max_size, D))
+        )
+        for _ in range(2):  # then two single-position updates, all 2.0
+            cache.update_and_fetch(
+                mx.full((1, H, 1, D), 2.0), mx.full((1, H, 1, D), 2.0)
+            )
+        expected_first_slot_value = 1.0  # NOTE(review): assumes the 2.0 updates evict the oldest 1.0 entries - confirm
+
+        merged = BatchRotatingKVCache.merge([cache])
+        mx.eval(merged.keys)
+
+        self.assertAlmostEqual(  # only element [0, 0, 0, 0] is checked
+            merged.keys[0, 0, 0, 0].item(), expected_first_slot_value
+        )
+
+
 if __name__ == "__main__":
     unittest.main()