From 31410002d37592d55d8fdc91090f80555f46e1da Mon Sep 17 00:00:00 2001 From: Dipesh Babu Date: Tue, 31 Mar 2026 02:55:02 -0400 Subject: [PATCH] feat: add MLX backend prototype for TurboQuant Python path ---
Review notes (free-form text after the three-dash separator; not part of the commit message):
 * turboquant/__init__.py now imports turboquant.mlx_backend unconditionally. Importing the package on a machine without MLX therefore relies on the try/except ImportError guard around `import mlx.core` in mlx_backend.py, with the failure deferred to _require_mlx().
 * _resolve_dtype() returns getattr(mx, dtype) unchanged, so any attribute name of mlx.core is accepted, not only dtype names; a non-dtype value would surface later as an error from mx.array(...).
 * The NumPy-parity tests build the MLX quantizers with dtype="float64" (only test_attention_output_remains_reasonable uses "float32"). NOTE(review): MLX documents float64 as CPU-only -- confirm the tests run on a CPU device/stream.
 * benchmarks/benchmark_mlx_backend.py: QWEN_27B and QWEN_MOE differ only in "name"; both use num_layers=28, num_kv_heads=8, head_dim=128.
 * MLXKVCacheCompressor.memory_stats() adds 32 norm bits per K vector but none per V vector, although compress() also stores v_norms. Presumably this mirrors the NumPy KVCacheCompressor -- verify before changing either side.
 README.md | 8 + benchmarks/benchmark_mlx_backend.py | 96 +++++++ pyproject.toml | 3 + tests/test_mlx_backend.py | 137 ++++++++++ turboquant/__init__.py | 23 +- turboquant/mlx_backend.py | 388 ++++++++++++++++++++++++++++ 6 files changed, 654 insertions(+), 1 deletion(-) create mode 100644 benchmarks/benchmark_mlx_backend.py create mode 100644 tests/test_mlx_backend.py create mode 100644 turboquant/mlx_backend.py diff --git a/README.md b/README.md index 2f799246e..480b8705f 100644 --- a/README.md +++ b/README.md @@ -316,6 +316,7 @@ Ratio: 1.000 exactly - **NumPy** >= 1.24, **SciPy** >= 1.10 - **cmake** + C/C++ compiler (for llama.cpp build) - **Xcode Command Line Tools** (macOS Metal build) +- **Optional**: `mlx` for the Apple Silicon MLX backend prototype - **Optional**: `torch`, `transformers`, `accelerate` (~4GB download, for real model validation) ### Install the Python Prototype @@ -330,6 +331,12 @@ pip install -e ".[dev]" python3 -m pytest tests/ -v ``` +On Apple Silicon, install the optional MLX backend dependencies with: + +```bash +pip install -e ".[dev,mlx]" +``` + ### Run the Demo ```bash @@ -450,6 +457,7 @@ turboquant/ tests/ # 14 test files, 500+ tests benchmarks/ ├── demo.py # Quick compression demo +├── benchmark_mlx_backend.py # NumPy vs MLX backend benchmark ├── run_benchmark.py # Server-based benchmark runner ├── benchmark_results.md # Full benchmark report ├── benchmark_llama.sh # llama.cpp benchmark script diff --git a/benchmarks/benchmark_mlx_backend.py b/benchmarks/benchmark_mlx_backend.py new file mode 100644 index 000000000..93a66664a --- /dev/null +++ b/benchmarks/benchmark_mlx_backend.py @@ -0,0 +1,96 @@ +"""Benchmark the optional MLX backend against the NumPy prototype. 
+ +Usage: + python3 benchmarks/benchmark_mlx_backend.py +""" + +from __future__ import annotations + +import time + +import numpy as np + +from turboquant import KVCacheCompressor, MLXKVCacheCompressor, MLX_AVAILABLE +from turboquant.mlx_backend import to_numpy + + +QWEN_27B = { + "name": "Qwen 3.5 27B (dense)", + "num_layers": 28, + "num_kv_heads": 8, + "head_dim": 128, +} + +QWEN_MOE = { + "name": "Qwen 3.5 35B-A3B (MoE)", + "num_layers": 28, + "num_kv_heads": 8, + "head_dim": 128, +} + + +def simulate_kv_cache(config: dict, seq_len: int, seed: int = 42) -> tuple[np.ndarray, np.ndarray]: + rng = np.random.default_rng(seed) + shape = (config["num_layers"], config["num_kv_heads"], seq_len, config["head_dim"]) + scale = 1.0 / np.sqrt(config["head_dim"]) + k_cache = rng.standard_normal(shape) * scale + v_cache = rng.standard_normal(shape) * scale + return k_cache, v_cache + + +def run_backend(name: str, compressor, k_cache: np.ndarray, v_cache: np.ndarray): + t0 = time.perf_counter() + compressed = compressor.compress(k_cache, v_cache) + k_hat, v_hat = compressor.decompress(compressed) + + if MLX_AVAILABLE and name == "mlx": + import mlx.core as mx + + mx.eval(k_hat, v_hat) + + elapsed = time.perf_counter() - t0 + k_hat_np = to_numpy(k_hat) + v_hat_np = to_numpy(v_hat) + + k_mse = np.mean((k_cache - k_hat_np) ** 2) + v_mse = np.mean((v_cache - v_hat_np) ** 2) + return elapsed, k_mse, v_mse + + +def main(): + if not MLX_AVAILABLE: + raise SystemExit("MLX is not available. 
Install `mlx` on an MLX-supported system to run this benchmark.") + + print("=" * 70) + print("TURBOQUANT MLX BACKEND BENCHMARK") + print("=" * 70) + + for config in (QWEN_27B, QWEN_MOE): + print(f"\n{config['name']}") + for seq_len in (512, 2048): + print(f" seq_len={seq_len}") + k_cache, v_cache = simulate_kv_cache(config, seq_len) + + numpy_compressor = KVCacheCompressor(head_dim=config["head_dim"], k_bits=3, v_bits=3) + mlx_compressor = MLXKVCacheCompressor( + head_dim=config["head_dim"], + k_bits=3, + v_bits=3, + dtype="float32", + ) + + numpy_elapsed, numpy_k_mse, numpy_v_mse = run_backend("numpy", numpy_compressor, k_cache, v_cache) + mlx_elapsed, mlx_k_mse, mlx_v_mse = run_backend("mlx", mlx_compressor, k_cache, v_cache) + + print( + f" NumPy: {numpy_elapsed:.2f}s, " + f"K MSE={numpy_k_mse:.8f}, V MSE={numpy_v_mse:.8f}" + ) + print( + f" MLX: {mlx_elapsed:.2f}s, " + f"K MSE={mlx_k_mse:.8f}, V MSE={mlx_v_mse:.8f}" + ) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 930a81363..e25253a04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,9 @@ dev = [ "pytest>=7.0", "pytest-cov>=4.0", ] +mlx = [ + "mlx>=0.31", +] bench = [ "matplotlib", ] diff --git a/tests/test_mlx_backend.py b/tests/test_mlx_backend.py new file mode 100644 index 000000000..c1148f26f --- /dev/null +++ b/tests/test_mlx_backend.py @@ -0,0 +1,137 @@ +"""Parity tests for the optional MLX backend.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from turboquant.kv_cache import KVCacheCompressor +from turboquant.mlx_backend import ( + MLXKVCacheCompressor, + MLXPolarQuant, + MLXTurboQuant, + MLX_AVAILABLE, + to_numpy, +) +from turboquant.polar_quant import PolarQuant +from turboquant.turboquant import TurboQuant + + +mlx = pytest.importorskip("mlx.core", reason="MLX backend tests require the optional mlx package") + + +@pytest.mark.skipif(not MLX_AVAILABLE, reason="MLX is not available") +class TestMLXPolarQuant: 
+ def test_quantize_matches_numpy(self): + d = 128 + rng = np.random.default_rng(7) + x = rng.standard_normal((8, d)) + + numpy_pq = PolarQuant(d=d, bit_width=3, seed=42, norm_correction=True) + mlx_pq = MLXPolarQuant(d=d, bit_width=3, seed=42, norm_correction=True, dtype="float64") + + numpy_idx, numpy_norms = numpy_pq.quantize(x) + mlx_idx, mlx_norms = mlx_pq.quantize(x) + mlx_idx_np = to_numpy(mlx_idx) + mlx_norms_np = to_numpy(mlx_norms) + + np.testing.assert_array_equal(mlx_idx_np, numpy_idx) + np.testing.assert_allclose(mlx_norms_np, numpy_norms, atol=1e-10) + + numpy_recon = numpy_pq.dequantize(numpy_idx, numpy_norms) + mlx_recon = to_numpy(mlx_pq.dequantize(mlx_idx, mlx_norms)) + np.testing.assert_allclose(mlx_recon, numpy_recon, atol=1e-8) + + def test_single_vector_round_trip(self): + d = 64 + rng = np.random.default_rng(11) + x = rng.standard_normal(d) + + mlx_pq = MLXPolarQuant(d=d, bit_width=2, seed=99, dtype="float64") + idx, norms = mlx_pq.quantize(x) + x_hat = to_numpy(mlx_pq.dequantize(idx, norms)) + + assert x_hat.shape == x.shape + assert np.mean((x - x_hat) ** 2) < np.mean(x ** 2) + + +@pytest.mark.skipif(not MLX_AVAILABLE, reason="MLX is not available") +class TestMLXTurboQuant: + def test_turboquant_matches_numpy(self): + d = 128 + rng = np.random.default_rng(21) + x = rng.standard_normal((6, d)) + + numpy_tq = TurboQuant(d=d, bit_width=3, seed=42, norm_correction=True) + mlx_tq = MLXTurboQuant(d=d, bit_width=3, seed=42, norm_correction=True, dtype="float64") + + numpy_compressed = numpy_tq.quantize(x) + mlx_compressed = mlx_tq.quantize(x) + + np.testing.assert_array_equal(to_numpy(mlx_compressed.mse_indices), numpy_compressed.mse_indices) + np.testing.assert_allclose(to_numpy(mlx_compressed.vector_norms), numpy_compressed.vector_norms, atol=1e-10) + np.testing.assert_array_equal(to_numpy(mlx_compressed.qjl_signs), numpy_compressed.qjl_signs) + np.testing.assert_allclose( + to_numpy(mlx_compressed.residual_norms), + 
numpy_compressed.residual_norms, + atol=1e-10, + ) + + numpy_recon = numpy_tq.dequantize(numpy_compressed) + mlx_recon = to_numpy(mlx_tq.dequantize(mlx_compressed)) + np.testing.assert_allclose(mlx_recon, numpy_recon, atol=1e-8) + + +@pytest.mark.skipif(not MLX_AVAILABLE, reason="MLX is not available") +class TestMLXKVCache: + def test_kv_cache_matches_numpy(self): + rng = np.random.default_rng(42) + k = rng.standard_normal((2, 3, 8, 64)) + v = rng.standard_normal((2, 3, 8, 64)) + + numpy_compressor = KVCacheCompressor(head_dim=64, k_bits=3, v_bits=3) + mlx_compressor = MLXKVCacheCompressor(head_dim=64, k_bits=3, v_bits=3, dtype="float64") + + numpy_compressed = numpy_compressor.compress(k, v) + mlx_compressed = mlx_compressor.compress(k, v) + + numpy_k, numpy_v = numpy_compressor.decompress(numpy_compressed) + mlx_k, mlx_v = mlx_compressor.decompress(mlx_compressed) + mlx_k = to_numpy(mlx_k) + mlx_v = to_numpy(mlx_v) + + np.testing.assert_allclose(mlx_k, numpy_k, atol=1e-8) + np.testing.assert_allclose(mlx_v, numpy_v, atol=1e-8) + + def test_attention_output_remains_reasonable(self): + head_dim = 64 + seq_len = 16 + rng = np.random.default_rng(123) + + q = rng.standard_normal((1, head_dim)) + k = rng.standard_normal((seq_len, head_dim)) + v = rng.standard_normal((seq_len, head_dim)) + + compressor = MLXKVCacheCompressor(head_dim=head_dim, k_bits=3, v_bits=3, dtype="float32") + compressed = compressor.compress(k[np.newaxis, np.newaxis, :, :], v[np.newaxis, np.newaxis, :, :]) + k_hat, v_hat = compressor.decompress(compressed) + k_hat = to_numpy(k_hat)[0, 0] + v_hat = to_numpy(v_hat)[0, 0] + + scores_orig = q @ k.T / np.sqrt(head_dim) + attn_orig = _softmax(scores_orig) + out_orig = attn_orig @ v + + scores_comp = q @ k_hat.T / np.sqrt(head_dim) + attn_comp = _softmax(scores_comp) + out_comp = attn_comp @ v_hat + + cosine = np.dot(out_orig.ravel(), out_comp.ravel()) / ( + np.linalg.norm(out_orig) * np.linalg.norm(out_comp) + ) + assert cosine > 0.5 + + +def 
_softmax(x): + e = np.exp(x - np.max(x, axis=-1, keepdims=True)) + return e / np.sum(e, axis=-1, keepdims=True) diff --git a/turboquant/__init__.py b/turboquant/__init__.py index a7ebc4961..798c93088 100644 --- a/turboquant/__init__.py +++ b/turboquant/__init__.py @@ -4,5 +4,26 @@ from turboquant.qjl import QJL from turboquant.turboquant import TurboQuant, TurboQuantMSE, CompressedVector from turboquant.kv_cache import KVCacheCompressor +from turboquant.mlx_backend import ( + MLX_AVAILABLE, + MLXKVCacheCompressor, + MLXPolarQuant, + MLXQJL, + MLXTurboQuant, + MLXTurboQuantMSE, +) -__all__ = ["PolarQuant", "QJL", "TurboQuant", "TurboQuantMSE", "CompressedVector", "KVCacheCompressor"] +__all__ = [ + "PolarQuant", + "QJL", + "TurboQuant", + "TurboQuantMSE", + "CompressedVector", + "KVCacheCompressor", + "MLX_AVAILABLE", + "MLXPolarQuant", + "MLXQJL", + "MLXTurboQuant", + "MLXTurboQuantMSE", + "MLXKVCacheCompressor", +] diff --git a/turboquant/mlx_backend.py b/turboquant/mlx_backend.py new file mode 100644 index 000000000..26580c88a --- /dev/null +++ b/turboquant/mlx_backend.py @@ -0,0 +1,388 @@ +"""Optional MLX backend for the Python TurboQuant prototype. + +This module mirrors the existing NumPy prototype with MLX arrays for the +compute-heavy paths. It is intentionally self-contained and import-safe on +systems where MLX is unavailable. 
+""" + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from turboquant.codebook import optimal_centroids +from turboquant.kv_cache import CompressedKVCache +from turboquant.qjl import QJL_CONST +from turboquant.rotation import random_rotation_dense +from turboquant.turboquant import CompressedVector + +try: # pragma: no cover - exercised indirectly on MLX-capable systems + import mlx.core as mx +except ImportError: # pragma: no cover - expected on non-Apple CI/dev boxes + mx = None + + +MLX_AVAILABLE = mx is not None + + +def _require_mlx() -> None: + if not MLX_AVAILABLE: + raise ImportError( + "MLX backend requires the optional 'mlx' package. " + "Install it on an MLX-supported system with `pip install mlx`." + ) + + +def _resolve_dtype(dtype: str): + _require_mlx() + try: + return getattr(mx, dtype) + except AttributeError as exc: + raise ValueError(f"Unsupported MLX dtype: {dtype}") from exc + + +def _is_mlx_array(value: Any) -> bool: + if not MLX_AVAILABLE: + return False + + array_type = getattr(mx, "array", None) + if isinstance(array_type, type): + return isinstance(value, array_type) + + # Fallback for binding implementations where `mx.array` is callable but not a + # direct Python type object. 
+ value_type = type(value) + return value_type.__module__.startswith("mlx.") and value_type.__name__ == "array" + + +def _to_mx_array(value: Any, dtype=None): + _require_mlx() + if _is_mlx_array(value): + return value.astype(dtype) if dtype is not None and value.dtype != dtype else value + return mx.array(value, dtype=dtype) + + +def to_numpy(value: Any) -> np.ndarray: + """Convert MLX arrays to NumPy arrays without requiring callers to branch.""" + if _is_mlx_array(value): + return np.array(value) + return np.asarray(value) + + +class MLXPolarQuant: + """MLX implementation of PolarQuant using the NumPy-calibrated codebook.""" + + def __init__( + self, + d: int, + bit_width: int, + seed: int = 42, + norm_correction: bool = True, + dtype: str = "float32", + ): + _require_mlx() + self.d = d + self.bit_width = bit_width + self.n_centroids = 1 << bit_width + self.norm_correction = norm_correction + self.dtype = _resolve_dtype(dtype) + + rng = np.random.default_rng(seed) + rotation = random_rotation_dense(d, rng) + centroids = optimal_centroids(bit_width, d) + boundaries = (centroids[:-1] + centroids[1:]) / 2.0 + + self.rotation = mx.array(rotation, dtype=self.dtype) + self.rotation_t = mx.transpose(self.rotation) + self.centroids = mx.array(centroids, dtype=self.dtype) + self.boundaries = mx.array(boundaries, dtype=self.dtype) + + def quantize(self, x: Any): + x_arr = _to_mx_array(x, self.dtype) + single = x_arr.ndim == 1 + if single: + x_arr = x_arr.reshape(1, self.d) + if x_arr.ndim != 2 or x_arr.shape[1] != self.d: + raise ValueError(f"Expected shape ({self.d},) or (batch, {self.d}), got {x_arr.shape}") + + norms = mx.linalg.norm(x_arr, axis=1) + safe_norms = mx.where(norms > 0, norms, 1.0) + x_normalized = x_arr / safe_norms.reshape((-1, 1)) + + y = mx.matmul(x_normalized, self.rotation_t) + gt = (y[..., None] > self.boundaries.reshape((1, 1, self.boundaries.shape[0]))).astype(mx.int32) + indices = gt.sum(axis=-1).astype(mx.uint32) + + if single: + return 
mx.squeeze(indices, axis=0), mx.squeeze(norms, axis=0) + return indices, norms + + def dequantize(self, indices: Any, norms: Any): + indices_arr = _to_mx_array(indices) + norms_arr = _to_mx_array(norms, self.dtype) + + single = indices_arr.ndim == 1 + if single: + indices_arr = indices_arr.reshape(1, self.d) + norms_arr = norms_arr.reshape((1,)) + if indices_arr.ndim != 2 or indices_arr.shape[1] != self.d: + raise ValueError( + f"Expected index shape ({self.d},) or (batch, {self.d}), got {indices_arr.shape}" + ) + + y_hat = mx.take(self.centroids, indices_arr) + + if self.norm_correction: + y_hat_norms = mx.linalg.norm(y_hat, axis=1, keepdims=True) + y_hat_norms = mx.where(y_hat_norms > 1e-10, y_hat_norms, 1.0) + y_hat = y_hat / y_hat_norms + + x_hat_unit = mx.matmul(y_hat, self.rotation) + x_hat = x_hat_unit * norms_arr.reshape((-1, 1)) + + return mx.squeeze(x_hat, axis=0) if single else x_hat + + def quantize_and_residual(self, x: Any): + indices, norms = self.quantize(x) + x_hat = self.dequantize(indices, norms) + residual = _to_mx_array(x, self.dtype) - x_hat + return indices, norms, residual + + +class MLXQJL: + """MLX implementation of the 1-bit QJL residual stage.""" + + def __init__(self, d: int, seed: int = 123, dtype: str = "float32"): + _require_mlx() + self.d = d + self.dtype = _resolve_dtype(dtype) + rng = np.random.default_rng(seed) + self.S = mx.array(rng.standard_normal((d, d)), dtype=self.dtype) + self.S_t = mx.transpose(self.S) + + def quantize(self, r: Any): + r_arr = _to_mx_array(r, self.dtype) + single = r_arr.ndim == 1 + if single: + r_arr = r_arr.reshape(1, self.d) + if r_arr.ndim != 2 or r_arr.shape[1] != self.d: + raise ValueError(f"Expected shape ({self.d},) or (batch, {self.d}), got {r_arr.shape}") + + norms = mx.linalg.norm(r_arr, axis=1) + projected = mx.matmul(r_arr, self.S_t) + signs = mx.where(projected >= 0, 1, -1).astype(mx.int8) + + if single: + return mx.squeeze(signs, axis=0), mx.squeeze(norms, axis=0) + return signs, norms + + 
def dequantize(self, signs: Any, norms: Any): + signs_arr = _to_mx_array(signs) + norms_arr = _to_mx_array(norms, self.dtype) + + single = signs_arr.ndim == 1 + if single: + signs_arr = signs_arr.reshape(1, self.d) + norms_arr = norms_arr.reshape((1,)) + if signs_arr.ndim != 2 or signs_arr.shape[1] != self.d: + raise ValueError( + f"Expected sign shape ({self.d},) or (batch, {self.d}), got {signs_arr.shape}" + ) + + reconstructed = mx.matmul(signs_arr.astype(self.dtype), self.S) + scale = norms_arr * (QJL_CONST / self.d) + reconstructed = reconstructed * scale.reshape((-1, 1)) + + return mx.squeeze(reconstructed, axis=0) if single else reconstructed + + +class MLXTurboQuant: + """Full MLX TurboQuant path for K-cache parity with the NumPy prototype.""" + + def __init__( + self, + d: int, + bit_width: int, + seed: int = 42, + norm_correction: bool = True, + dtype: str = "float32", + ): + if bit_width < 2: + raise ValueError( + "TurboQuant requires bit_width >= 2 (1 bit PolarQuant + 1 bit QJL). " + "For 1-bit, use QJL directly." 
+ ) + + self.d = d + self.bit_width = bit_width + self.polar_quant = MLXPolarQuant( + d, + bit_width=bit_width - 1, + seed=seed, + norm_correction=norm_correction, + dtype=dtype, + ) + self.qjl = MLXQJL(d, seed=seed + 1000, dtype=dtype) + + def quantize(self, x: Any) -> CompressedVector: + mse_indices, vector_norms, residual = self.polar_quant.quantize_and_residual(x) + qjl_signs, residual_norms = self.qjl.quantize(residual) + return CompressedVector( + mse_indices=mse_indices, + vector_norms=vector_norms, + qjl_signs=qjl_signs, + residual_norms=residual_norms, + bit_width=self.bit_width, + ) + + def dequantize(self, compressed: CompressedVector): + x_mse = self.polar_quant.dequantize(compressed.mse_indices, compressed.vector_norms) + x_qjl = self.qjl.dequantize(compressed.qjl_signs, compressed.residual_norms) + return x_mse + x_qjl + + def compressed_size_bits(self, n_vectors: int) -> int: + per_vector = self.d * self.bit_width + norms = 32 + return n_vectors * (per_vector + norms) + + def compression_ratio(self, original_bits_per_value: int = 16) -> float: + original_per_vector = self.d * original_bits_per_value + compressed_per_vector = self.d * self.bit_width + 32 + return original_per_vector / compressed_per_vector + + +class MLXTurboQuantMSE: + """MSE-only MLX TurboQuant path for V-cache compression.""" + + def __init__( + self, + d: int, + bit_width: int, + seed: int = 42, + norm_correction: bool = True, + dtype: str = "float32", + ): + self.d = d + self.bit_width = bit_width + self.polar_quant = MLXPolarQuant( + d, + bit_width=bit_width, + seed=seed, + norm_correction=norm_correction, + dtype=dtype, + ) + + def quantize(self, x: Any): + return self.polar_quant.quantize(x) + + def dequantize(self, indices: Any, norms: Any): + return self.polar_quant.dequantize(indices, norms) + + +class MLXKVCacheCompressor: + """MLX equivalent of the NumPy KV cache compressor.""" + + def __init__( + self, + head_dim: int, + k_bits: int = 3, + v_bits: int = 3, + seed: int = 
42, + norm_correction: bool = True, + dtype: str = "float32", + ): + self.head_dim = head_dim + self.k_bits = k_bits + self.v_bits = v_bits + self.dtype = dtype + + self.k_quantizer = MLXTurboQuant( + head_dim, + bit_width=k_bits, + seed=seed, + norm_correction=norm_correction, + dtype=dtype, + ) + self.v_quantizer = MLXTurboQuantMSE( + head_dim, + bit_width=v_bits, + seed=seed + 500, + norm_correction=norm_correction, + dtype=dtype, + ) + + def compress(self, k_cache: Any, v_cache: Any) -> CompressedKVCache: + k_cache_arr = _to_mx_array(k_cache, _resolve_dtype(self.dtype)) + v_cache_arr = _to_mx_array(v_cache, _resolve_dtype(self.dtype)) + + if k_cache_arr.shape != v_cache_arr.shape: + raise ValueError(f"K/V cache shapes must match, got {k_cache_arr.shape} and {v_cache_arr.shape}") + if k_cache_arr.ndim != 4: + raise ValueError(f"Expected KV cache shape (layers, heads, seq, dim), got {k_cache_arr.shape}") + + num_layers, num_heads, seq_len, head_dim = k_cache_arr.shape + if head_dim != self.head_dim: + raise ValueError(f"Expected head_dim={self.head_dim}, got {head_dim}") + + result = CompressedKVCache( + num_layers=num_layers, + num_heads=num_heads, + seq_len=seq_len, + head_dim=head_dim, + k_bit_width=self.k_bits, + v_bit_width=self.v_bits, + ) + + for layer in range(num_layers): + k_layer = [] + v_layer_idx = [] + v_layer_norms = [] + for head in range(num_heads): + k_compressed = self.k_quantizer.quantize(k_cache_arr[layer, head]) + v_indices, v_norms = self.v_quantizer.quantize(v_cache_arr[layer, head]) + k_layer.append(k_compressed) + v_layer_idx.append(v_indices) + v_layer_norms.append(v_norms) + + result.k_compressed.append(k_layer) + result.v_indices.append(v_layer_idx) + result.v_norms.append(v_layer_norms) + + return result + + def decompress(self, compressed: CompressedKVCache): + k_layers = [] + v_layers = [] + + for layer in range(compressed.num_layers): + k_heads = [] + v_heads = [] + for head in range(compressed.num_heads): + 
k_heads.append(self.k_quantizer.dequantize(compressed.k_compressed[layer][head])) + v_heads.append( + self.v_quantizer.dequantize( + compressed.v_indices[layer][head], + compressed.v_norms[layer][head], + ) + ) + + k_layers.append(mx.stack(k_heads, axis=0)) + v_layers.append(mx.stack(v_heads, axis=0)) + + return mx.stack(k_layers, axis=0), mx.stack(v_layers, axis=0) + + def memory_stats(self, seq_len: int, num_layers: int, num_heads: int) -> dict[str, float]: + n_vectors = num_layers * num_heads * seq_len + original_bytes = n_vectors * self.head_dim * 2 + k_bits_total = n_vectors * (self.head_dim * self.k_bits + 32) + v_bits_total = n_vectors * self.head_dim * self.v_bits + compressed_bytes = (k_bits_total + v_bits_total) / 8 + + return { + "original_mb": original_bytes / 1024 / 1024, + "compressed_mb": compressed_bytes / 1024 / 1024, + "compression_ratio": original_bytes / compressed_bytes, + "k_bits_per_value": self.k_bits, + "v_bits_per_value": self.v_bits, + }