Drop-in KV cache compression for Gemma 4 (and friends) on Apple Silicon.
TurboMLX brings PolarQuant (Google, ICLR 2026) — a data-oblivious KV cache quantization scheme — to MLX, with first-class support for Gemma 4's MatFormer architecture: dual head_dim, hybrid sliding/global attention, and cross-layer KV sharing.
Forked from rachittshah/mlx-turboquant and substantially upgraded for Gemma 4. Maintained at Smilefounder/TurboMLX.
Gemma 4 (Google, 2026) is the most architecturally complex KV cache target TurboMLX supports — and the one it's most tuned for. A single `make_turboquant_cache(model)` call handles all of it:
| Gemma 4 feature | Config field | What TurboMLX does |
|---|---|---|
| Dual head_dim (256 sliding, 512 global) | `head_dim`, `global_head_dim` | Per-layer head_dim in the cache plan — each layer gets its own PolarQuant rotation sized correctly |
| Hybrid attention (sliding + global interleaved) | `layer_types` | `RotatingTurboQuantKVCache` for sliding layers (circular buffer sized to `sliding_window`, matching mlx-lm's `RotatingKVCache`); `TurboQuantKVCache` for global layers |
| Cross-layer KV sharing (last N layers reuse earlier caches) | `num_kv_shared_layers` | Builds caches only for `num_hidden_layers − num_kv_shared_layers` layers — no wasted memory, no duplicated compression work |
| Tied Wk/Wv weights on global layers (26B / 31B) | `attention_k_eq_v` | Detected, but K/V storage is not collapsed. In mlx-lm, RoPE is applied to K but not V, so the K and V tensors reaching the cache are distinct (cos ≈ 0.85–0.95, not 1.0); K and V are compressed independently for correctness |
| MatFormer (42 hidden layers → 24 distinct KV caches) | derived | End-to-end validated: 20 rotating sliding + 4 global, matching Gemma 4's native layout exactly |
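The table's config fields compose into a per-layer cache plan. A hypothetical sketch of that derivation (the real logic lives in `turbomlx/integration.py`; the function, the dict-based config, and the 5:1 sliding/global interleave below are illustrative assumptions, not the library's code):

```python
# Hypothetical sketch (not the library's actual code) of deriving a per-layer
# cache plan from a Gemma-4-style config.
def plan_caches(cfg):
    # Shared layers at the top of the stack reuse earlier caches, so only the
    # first num_hidden_layers - num_kv_shared_layers layers get their own cache.
    n_cached = cfg["num_hidden_layers"] - cfg["num_kv_shared_layers"]
    plan = []
    for layer_type in cfg["layer_types"][:n_cached]:
        if layer_type == "full_attention":
            plan.append(("TurboQuantKVCache", cfg["global_head_dim"]))
        else:
            plan.append(("RotatingTurboQuantKVCache", cfg["head_dim"]))
    return plan

cfg = {
    "num_hidden_layers": 42,
    "num_kv_shared_layers": 18,
    # illustrative 5:1 sliding/global interleave
    "layer_types": (["sliding_attention"] * 5 + ["full_attention"]) * 7,
    "head_dim": 256,
    "global_head_dim": 512,
}
plan = plan_caches(cfg)  # 24 caches: 20 rotating sliding + 4 global
```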
Validated end-to-end on unsloth/gemma-4-E4B-it-UD-MLX-4bit with mlx-lm ≥ 0.31.2 (which ships the gemma4 / gemma4_text model classes). Averaged over 4 prompts:
| Metric | 3-bit | 3.5-bit | 4-bit |
|---|---|---|---|
| Logit cosine vs FP16 | 0.977 | 0.989 | 0.995 |
| Top-1 match | 3/4 | 4/4 | 4/4 |
Memory @ 1024 tokens: FP16 hybrid baseline 57,344 KB → TurboMLX 3-bit 12,024 KB (4.8x) / 4-bit 14,712 KB (3.9x). Measured like-for-like against a real hybrid cache (20 RotatingKVCache + 4 KVCache), not an unrealistic all-global baseline.
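The measured ratios sit below the naive 16/3 ≈ 5.3× because each quantized vector also carries metadata such as its stored norm. A back-of-envelope sketch under assumed per-vector overheads (one FP16 norm per head_dim-sized vector; the library's exact storage layout may differ):

```python
# Back-of-envelope compression ratio vs FP16, assuming packed b-bit indices
# plus one FP16 norm per head_dim-sized vector (assumed layout, for intuition).
def compression_ratio(bits, head_dim):
    fp16_bits = 16 * head_dim          # baseline: 16 bits per value
    quant_bits = bits * head_dim + 16  # packed indices + one FP16 norm
    return fp16_bits / quant_bits

r3 = compression_ratio(3, 256)  # a bit above 5x before padding/allocator overhead
r4 = compression_ratio(4, 256)  # close to 4x, in line with the measured 3.9x
```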
Decode speed: Gemma 4 E4B has the best relative decode throughput in the suite at 0.69× FP16. Cross-layer KV sharing (`num_kv_shared_layers=18`) means 18 of 42 layers skip the compress/dequantize path entirely, so the per-token overhead is amortized over a smaller fraction of the stack.
31B Dense status: Validated end-to-end on unsloth/gemma-4-31b-it-UD-MLX-4bit (60 layers, dual head_dim 256/512, attention_k_eq_v=True). First-token logit cosine vs FP16 = 0.989 @ 4-bit, all three smoke prompts produce coherent output matching FP16 baseline. Tied Wk/Wv weights are detected but not used to share K/V storage — see the row above for why.
TurboMLX works on any mlx-lm model. Tested with 4 prompts each:
| Model | head_dim | 3-bit cosine | 4-bit cosine | 3-bit compression | Top-1 @4-bit |
|---|---|---|---|---|---|
| Gemma 4-E4B | 256/512 | 0.977 | 0.995 | 4.8x | 4/4 |
| Gemma 3-4B | 256 | 0.992 | 0.997 | 4.7x | 4/4 |
| Gemma 3-1B | 256 | 0.954 | 0.992 | 4.8x | 4/4 |
| Llama 3.2-3B | 128 | 0.988 | 0.997 | 4.6x | 4/4 |
| Llama 3.2-1B | 64 | 0.823 | 0.974 | 4.0x | 4/4 |
| Qwen3-4B | 128 | 0.957 | 0.995 | 4.6x | 4/4 |
| Qwen3-1.7B | 128 | 0.128 | 0.949 | 4.6x | 4/4 |
Cosine = logit cosine similarity vs FP16 KV cache. See REPORT.md for full methodology and memory/speed numbers.
```shell
git clone https://github.com/Smilefounder/TurboMLX.git
cd TurboMLX
uv sync --dev
```

Upstream: rachittshah/mlx-turboquant.
```python
import mlx.core as mx
from mlx_lm import load
from turbomlx.integration import make_turboquant_cache

model, tokenizer = load("unsloth/gemma-4-E4B-it-UD-MLX-4bit")

# One call handles everything: hybrid sliding/global layers, dual head_dim
# (256 sliding / 512 global), and cross-layer KV sharing. Tied Wk/Wv on
# the 26B/31B variants is detected but K and V are still compressed
# independently — RoPE makes them distinct at the cache call site.
cache = make_turboquant_cache(model, bits=3)

tokens = mx.array([tokenizer.encode("Explain MatFormer in one sentence.")])
logits = model(tokens, cache=cache)
```

That's the entire integration surface for Gemma 4. No manual layer-type lists, no head_dim plumbing — `make_turboquant_cache` inspects `model.config` and builds the correct per-layer mix of:

- `TurboQuantKVCache` for `full_attention` (global) layers
- `RotatingTurboQuantKVCache` for `sliding_attention` layers (circular buffer sized to `sliding_window`)
Same entry point — it auto-detects the architecture:
```python
model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
cache = make_turboquant_cache(model, bits=3)
```

For plain (non-hybrid) models you can also build caches by hand:

```python
from turbomlx import TurboQuantKVCache

cache = [TurboQuantKVCache(bits=3, head_dim=128) for _ in range(len(model.layers))]
```

Supported bit widths: 2, 3, 3.5, 4. Fractional bits use a channel split (half the channels at ceil(bits), half at floor(bits)).
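A small sketch of the channel-split arithmetic behind fractional bit widths (the function name and the 50/50 split are illustrative assumptions, not the library's API):

```python
import math

# Illustrative: at bits=3.5, half the channels quantize at ceil(3.5) = 4 bits
# and half at floor(3.5) = 3 bits, averaging 3.5 bits per stored value.
def split_channels(head_dim, bits):
    hi, lo = math.ceil(bits), math.floor(bits)
    n_hi = head_dim // 2      # channels quantized at ceil(bits)
    n_lo = head_dim - n_hi    # channels quantized at floor(bits)
    avg = (n_hi * hi + n_lo * lo) / head_dim
    return (n_hi, hi), (n_lo, lo), avg

(n_hi, hi), (n_lo, lo), avg = split_channels(128, 3.5)  # 64 @ 4-bit + 64 @ 3-bit
```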
- Normalize each KV vector and store its norm
- Rotate by a fixed random orthogonal matrix (data-oblivious — same matrix for all inputs)
- After rotation, coordinates follow a known Gaussian distribution
- Quantize each coordinate using precomputed Lloyd-Max optimal codebooks
- Bit-pack indices into uint32 for storage (e.g., 10 3-bit values per uint32)
- On fetch: unpack → lookup centroids → inverse rotate → rescale
No calibration data is needed, and the scheme is near the information-theoretic optimum (within 2.7× of the lower bound).
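The six steps above can be sketched end to end in NumPy. This is an illustrative stand-in, not the library's implementation: a uniform grid replaces the precomputed Lloyd-Max codebooks, and the coordinates of a rotated unit vector (≈ N(0, 1/d)) are rescaled by √d to match the codebook's N(0, 1) scale.

```python
import numpy as np

rng = np.random.default_rng(0)
d, bits = 64, 3
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))  # fixed random orthogonal rotation
codebook = np.linspace(-2.5, 2.5, 2 ** bits)      # stand-in for Lloyd-Max centroids

def compress(v):
    norm = np.linalg.norm(v)                       # 1) normalize, store the norm
    r = (v / norm) @ Q * np.sqrt(d)                # 2-3) rotate; coords ~ N(0, 1)
    idx = np.abs(r[:, None] - codebook).argmin(1)  # 4) nearest-centroid quantization
    return idx.astype(np.uint32), norm

def pack_uint32(idx, bits=3):
    per = 32 // bits                               # 5) e.g. 10 3-bit values per uint32
    pad = (-len(idx)) % per
    idx = np.concatenate([idx, np.zeros(pad, np.uint32)])
    shifts = np.arange(per, dtype=np.uint32) * bits
    return (idx.reshape(-1, per) << shifts).sum(1, dtype=np.uint64).astype(np.uint32)

def decompress(idx, norm):
    return codebook[idx] / np.sqrt(d) @ Q.T * norm  # 6) lookup, inverse rotate, rescale

v = rng.standard_normal(d)
idx, norm = compress(v)
packed = pack_uint32(idx)
v_hat = decompress(idx, norm)
cos = v @ v_hat / (np.linalg.norm(v) * np.linalg.norm(v_hat))  # close to 1
```

Because the rotation is a fixed orthogonal matrix and the codebook is built once for N(0, 1), nothing here depends on the input data, which is what "data-oblivious" buys: no calibration pass before serving.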
```shell
uv run python benchmarks/bench_quality.py       # Quality: cosine sim, top-k
uv run python benchmarks/bench_memory_speed.py  # Memory + speed
uv run python benchmarks/bench_full.py          # Full suite across 4 models
uv run python tests/test_core.py                # Unit tests
```

```
turbomlx/
├── codebooks.py       # Lloyd-Max codebook loader (precomputed for N(0,1))
├── polar_quant.py     # PolarQuant: rotate + quantize + dequantize
├── qjl.py             # QJL residual correction (for future fused attention)
├── turbo_quant.py     # Combined compressor + fractional bit support
├── packing.py         # Vectorized bit-packing into uint32
├── cache.py           # TurboQuantKVCache (global / full-attention layers)
├── rotating_cache.py  # RotatingTurboQuantKVCache (sliding-window layers)
├── attention.py       # Custom attention with QJL correction
└── integration.py     # make_turboquant_cache + SDPA monkey-patch
```
This is a standalone proof-of-concept. The plan is to upstream into ml-explore/mlx-lm as a new cache type. See PR_PLAN.md for the integration strategy.
MIT