KV cache is the memory bottleneck for long-context inference on Apple Silicon. The existing QuantizedKVCache uses affine quantization at 4/8 bits but doesn't go lower. PolarQuant (from Google's TurboQuant paper, ICLR 2026) offers a path to a 3-bit KV cache with near-lossless quality.
The algorithm is simple: apply a fixed random orthogonal rotation to each KV vector, then quantize each coordinate with a precomputed Lloyd-Max optimal scalar quantizer. The rotation makes the coordinate distribution predictable (approximately Gaussian), so the quantizer is data-oblivious: no calibration data is needed.
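To make the two steps concrete, here is a rough numpy sketch (not the PR's code, which is pure mlx.core). The QR-based rotation, the per-vector scale, and the approximate 3-bit Lloyd-Max levels for a standard normal are all illustrative assumptions:

```python
import numpy as np

# Approximate Lloyd-Max output levels for a 3-bit quantizer of a
# standard normal variable (classic Max-1960 values, illustrative only).
LEVELS_3BIT = np.array([-2.152, -1.344, -0.756, -0.245,
                         0.245,  0.756,  1.344,  2.152])

def random_rotation(dim, seed=0):
    # Fixed random orthogonal matrix via QR of a Gaussian matrix.
    # A real implementation would likely use a fast structured transform.
    g = np.random.default_rng(seed).standard_normal((dim, dim))
    q, r = np.linalg.qr(g)
    return q * np.sign(np.diag(r))  # fix column signs

def encode(v, rot):
    # Rotate, normalize by a per-vector scale, then snap each
    # coordinate to the nearest Lloyd-Max level; store indices + scale.
    z = rot @ v
    scale = z.std() + 1e-8
    idx = np.abs(z[:, None] / scale - LEVELS_3BIT[None, :]).argmin(axis=1)
    return idx.astype(np.uint8), scale

def decode(idx, scale, rot):
    # Inverse: look up levels, rescale, rotate back.
    return rot.T @ (LEVELS_3BIT[idx] * scale)
```

Because the rotation mixes coordinates, each rotated coordinate is close to Gaussian regardless of the input distribution, which is what lets a single precomputed quantizer serve every layer and head.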
I built a standalone implementation and benchmarked it on Apple Silicon with mlx-lm models:
Quality (logit cosine similarity vs FP16 KV cache):
| Model | 3-bit | 4-bit |
|---|---|---|
| Llama 3.2-3B (head_dim=128) | 0.988 | 0.997 |
| Qwen3-4B (head_dim=128) | 0.957 | 0.995 |
| Llama 3.2-1B (head_dim=64) | 0.823 | 0.974 |
Top-1 token accuracy is 4/4 at 4-bit across all models and prompts tested.
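For clarity on the metric in the table: it is just the cosine between the logit vector produced with the FP16 cache and the one produced with the quantized cache, for the same prompt. A plain-Python helper (hypothetical, not from the repo):

```python
import math

def cosine_similarity(a, b):
    # Cosine between two logit vectors; 1.0 means the quantized
    # cache reproduces the FP16 logits up to scale.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)
```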
Memory: 4.6x compression at 3-bit, 3.8x at 4-bit (head_dim=128). Indices are bit-packed into uint32.
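A sketch of what packing 3-bit indices into uint32 words could look like; the actual layout in the PR may differ (10 indices per word with 2 padding bits is an assumption here):

```python
def pack3(indices):
    # Pack 3-bit indices (values 0..7) into 32-bit words,
    # 10 indices per word (30 bits used, 2 bits padding).
    words, cur, nbits = [], 0, 0
    for i in indices:
        cur |= (i & 0x7) << nbits
        nbits += 3
        if nbits > 29:           # word full after 10 indices
            words.append(cur)
            cur, nbits = 0, 0
    if nbits:
        words.append(cur)
    return words

def unpack3(words, n):
    # Recover the first n indices from the packed words.
    out = []
    for w in words:
        for _ in range(10):
            if len(out) == n:
                return out
            out.append(w & 0x7)
            w >>= 3
    return out
```

Note the ideal ratio for FP16 → 3-bit is 16/3 ≈ 5.3x; the per-vector scale metadata and packing padding presumably account for the gap to the measured 4.6x.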
Speed: decode currently runs at 0.5x FP16 because of dequantize-on-fetch overhead. A fused Metal kernel doing dequant + matmul in one pass would fix this, similar to how mx.quantized_matmul works for the existing QuantizedKVCache.
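To make the overhead concrete, a numpy sketch of what dequantize-on-fetch means on the attention-score path (function name and shapes are hypothetical): the full floating-point key matrix is materialized before every matmul, which is exactly the memory traffic a fused kernel would avoid by reading packed indices directly.

```python
import numpy as np

def scores_dequant_then_matmul(q, k_idx, k_scales, levels, rot):
    # Today's decode path: dequantize the whole cached K back to
    # floats (look up levels, rescale, un-rotate row-wise)...
    k = (levels[k_idx] * k_scales[:, None]) @ rot  # (seq, head_dim), materialized
    # ...then run the attention matmul. A fused kernel would compute
    # each q . k_i straight from the packed indices instead.
    return k @ q
```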
The implementation is a single file (~200 lines, pure mlx.core, no numpy) that subclasses _BaseCache. I have a draft PR at #1059 with 7 tests that pass alongside the existing 20.
Full benchmark methodology, results across 4 models, and the standalone implementation: https://github.com/rachittshah/mlx-turboquant
Paper: https://arxiv.org/abs/2504.19874
Would there be interest in adding this as an experimental cache type? Happy to adjust the approach based on feedback.