diff --git a/tests/test_kv_cache.py b/tests/test_kv_cache.py
index 00e13a7d8..914667fb4 100644
--- a/tests/test_kv_cache.py
+++ b/tests/test_kv_cache.py
@@ -102,11 +102,25 @@ def test_memory_stats(self):
         compressor = KVCacheCompressor(head_dim=128, k_bits=3, v_bits=3)
         stats = compressor.memory_stats(seq_len=1024, num_layers=32, num_heads=32)
 
-        # K: 3 bits/val + norm overhead, V: 3 bits/val
-        # Ratio vs fp16 (16 bits): 16 / ((3+3)/2 + overhead) ≈ 2.5-3x
-        assert stats["compression_ratio"] > 2.0
+        # Combined K/V fp16 baseline: 32 bits/value pair.
+        # For 3-bit K and V, actual stored metadata gives ~4.74x compression.
+        assert stats["compression_ratio"] > 4.0
         assert stats["compressed_mb"] < stats["original_mb"]
 
+    def test_memory_stats_exact_accounting(self):
+        """Memory stats should match the actual K/V storage layout."""
+        compressor = KVCacheCompressor(head_dim=128, k_bits=3, v_bits=3)
+        stats = compressor.memory_stats(seq_len=1, num_layers=1, num_heads=1)
+
+        # One K vector and one V vector at head_dim=128:
+        # - Original fp16 K/V pair: 128 * 2 bytes * 2 tensors = 512 bytes
+        # - K compressed: 128 * 3 bits + 64 bits of norms = 448 bits = 56 bytes
+        # - V compressed: 128 * 3 bits + 32-bit norm = 416 bits = 52 bytes
+        # - Total compressed = 108 bytes
+        assert stats["original_mb"] == pytest.approx(512 / 1024 / 1024)
+        assert stats["compressed_mb"] == pytest.approx(108 / 1024 / 1024)
+        assert stats["compression_ratio"] == pytest.approx(512 / 108)
+
     def test_metadata_stored(self):
         """Compressed cache should store correct metadata."""
         compressor = KVCacheCompressor(head_dim=64, k_bits=3, v_bits=3)
diff --git a/turboquant/kv_cache.py b/turboquant/kv_cache.py
index 80c61f9cf..ce222666a 100644
--- a/turboquant/kv_cache.py
+++ b/turboquant/kv_cache.py
@@ -156,12 +156,18 @@ def memory_stats(self, seq_len: int, num_layers: int, num_heads: int) -> dict:
         Returns dict with original_mb, compressed_mb, ratio.
         """
         n_vectors = num_layers * num_heads * seq_len
-        original_bytes = n_vectors * self.head_dim * 2  # fp16
-
-        # K: b bits per coord + 32-bit norm
-        k_bits_total = n_vectors * (self.head_dim * self.k_bits + 32)
-        # V: b bits per coord (no norm needed for MSE-only)
-        v_bits_total = n_vectors * self.head_dim * self.v_bits
+        # Original KV cache stores both K and V in fp16.
+        original_bytes = n_vectors * self.head_dim * 2 * 2
+
+        # K uses full TurboQuant:
+        # - d * k_bits total quantized bits
+        # - 32-bit vector norm
+        # - 32-bit residual norm
+        k_bits_total = n_vectors * (self.head_dim * self.k_bits + 64)
+        # V uses MSE-only PolarQuant:
+        # - d * v_bits quantized bits
+        # - 32-bit vector norm
+        v_bits_total = n_vectors * (self.head_dim * self.v_bits + 32)
 
         compressed_bytes = (k_bits_total + v_bits_total) / 8