
# TurboMLX

Drop-in KV cache compression for Gemma 4 (and friends) on Apple Silicon.

TurboMLX brings PolarQuant (Google, ICLR 2026) — a data-oblivious KV cache quantization scheme — to MLX, with first-class support for Gemma 4's MatFormer architecture: dual head_dim, hybrid sliding/global attention, and cross-layer KV sharing.

Forked from rachittshah/mlx-turboquant and substantially upgraded for Gemma 4. Maintained at Smilefounder/TurboMLX.

## Gemma 4 — the flagship model

Gemma 4 (Google, 2026) is the most architecturally complex KV cache target TurboMLX supports — and the one it's most tuned for. A single `make_turboquant_cache(model)` call handles all of it:

| Gemma 4 feature | Config field | What TurboMLX does |
| --- | --- | --- |
| Dual head_dim (256 sliding, 512 global) | `head_dim`, `global_head_dim` | Per-layer head_dim in the cache plan — each layer gets its own PolarQuant rotation sized correctly |
| Hybrid attention (sliding + global interleaved) | `layer_types` | `RotatingTurboQuantKVCache` for sliding layers (circular buffer sized to `sliding_window`, matching mlx-lm's `RotatingKVCache`); `TurboQuantKVCache` for global layers |
| Cross-layer KV sharing (last N layers reuse earlier caches) | `num_kv_shared_layers` | Builds caches only for `num_hidden_layers − num_kv_shared_layers` layers — no wasted memory, no duplicated compression work |
| Tied Wk/Wv weights on global layers (26B / 31B) | `attention_k_eq_v` | Detected, but K/V storage is not collapsed. In mlx-lm, RoPE is applied to K but not V, so the K and V tensors reaching the cache are distinct (cos ≈ 0.85–0.95, not 1.0); K and V are compressed independently for correctness |
| MatFormer (42 hidden layers → 24 distinct KV caches) | derived | End-to-end validated: 20 rotating sliding + 4 global, matching Gemma 4's native layout exactly |

### Gemma 4 results

Validated end-to-end on `unsloth/gemma-4-E4B-it-UD-MLX-4bit` with mlx-lm ≥ 0.31.2 (which ships the `gemma4` / `gemma4_text` model classes). Averaged over 4 prompts:

| Metric | 3-bit | 3.5-bit | 4-bit |
| --- | --- | --- | --- |
| Logit cosine vs FP16 | 0.977 | 0.989 | 0.995 |
| Top-1 match | 3/4 | 4/4 | 4/4 |

Memory @ 1024 tokens: FP16 hybrid baseline 57,344 KB → TurboMLX 3-bit 12,024 KB (4.8x) / 4-bit 14,712 KB (3.9x). Measured like-for-like against a real hybrid cache (20 RotatingKVCache + 4 KVCache), not an unrealistic all-global baseline.
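The compression ratios above follow directly from the reported sizes; a quick arithmetic sanity check:

```python
fp16_kb = 57_344        # FP16 hybrid baseline @ 1024 tokens
turbo_3bit_kb = 12_024  # TurboMLX 3-bit
turbo_4bit_kb = 14_712  # TurboMLX 4-bit

ratio_3bit = round(fp16_kb / turbo_3bit_kb, 1)  # 4.8x
ratio_4bit = round(fp16_kb / turbo_4bit_kb, 1)  # 3.9x
```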

Decode speed: Gemma 4 E4B has the highest relative decode throughput in the suite at 0.69× FP16. Cross-layer KV sharing (`num_kv_shared_layers=18`) means 18 of 42 layers skip the compress/dequantize path entirely, so the per-token overhead is amortized over a smaller fraction of the stack.

31B Dense status: Validated end-to-end on unsloth/gemma-4-31b-it-UD-MLX-4bit (60 layers, dual head_dim 256/512, attention_k_eq_v=True). First-token logit cosine vs FP16 = 0.989 @ 4-bit, all three smoke prompts produce coherent output matching FP16 baseline. Tied Wk/Wv weights are detected but not used to share K/V storage — see the row above for why.

## Other supported models

TurboMLX works on any mlx-lm model. Tested with 4 prompts each:

| Model | head_dim | 3-bit cosine | 4-bit cosine | 3-bit compression | Top-1 @ 4-bit |
| --- | --- | --- | --- | --- | --- |
| Gemma 4-E4B | 256/512 | 0.977 | 0.995 | 4.8x | 4/4 |
| Gemma 3-4B | 256 | 0.992 | 0.997 | 4.7x | 4/4 |
| Gemma 3-1B | 256 | 0.954 | 0.992 | 4.8x | 4/4 |
| Llama 3.2-3B | 128 | 0.988 | 0.997 | 4.6x | 4/4 |
| Llama 3.2-1B | 64 | 0.823 | 0.974 | 4.0x | 4/4 |
| Qwen3-4B | 128 | 0.957 | 0.995 | 4.6x | 4/4 |
| Qwen3-1.7B | 128 | 0.128 | 0.949 | 4.6x | 4/4 |

Cosine = logit cosine similarity vs FP16 KV cache. See REPORT.md for full methodology and memory/speed numbers.

## Install

```shell
git clone https://github.com/Smilefounder/TurboMLX.git
cd TurboMLX
uv sync --dev
```

Upstream: rachittshah/mlx-turboquant.

## Usage

### Gemma 4 (recommended)

```python
import mlx.core as mx
from mlx_lm import load
from turbomlx.integration import make_turboquant_cache

model, tokenizer = load("unsloth/gemma-4-E4B-it-UD-MLX-4bit")

# One call handles everything: hybrid sliding/global layers, dual head_dim
# (256 sliding / 512 global), and cross-layer KV sharing. Tied Wk/Wv on
# the 26B/31B variants is detected but K and V are still compressed
# independently — RoPE makes them distinct at the cache call site.
cache = make_turboquant_cache(model, bits=3)

tokens = mx.array([tokenizer.encode("Explain MatFormer in one sentence.")])
logits = model(tokens, cache=cache)
```

That's the entire integration surface for Gemma 4. No manual layer-type lists, no head_dim plumbing — make_turboquant_cache inspects model.config and builds the correct per-layer mix of:

  • TurboQuantKVCache for full_attention (global) layers
  • RotatingTurboQuantKVCache for sliding_attention layers (circular buffer sized to sliding_window)
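As a rough sketch of that dispatch logic (the `Config` stand-in class and plan tuples here are hypothetical; the real builder reads `model.config` and returns actual cache objects):

```python
from dataclasses import dataclass

@dataclass
class Config:  # minimal stand-in for model.config (hypothetical)
    num_hidden_layers: int
    num_kv_shared_layers: int
    head_dim: int          # sliding layers
    global_head_dim: int   # global layers
    sliding_window: int
    layer_types: list

def build_cache_plan(cfg, bits=3):
    # Only the first num_hidden_layers - num_kv_shared_layers layers own a
    # cache; the trailing shared layers reuse earlier ones.
    n_caches = cfg.num_hidden_layers - cfg.num_kv_shared_layers
    plan = []
    for layer_type in cfg.layer_types[:n_caches]:
        if layer_type == "sliding_attention":
            plan.append(("rotating", bits, cfg.head_dim, cfg.sliding_window))
        else:  # "full_attention"
            plan.append(("global", bits, cfg.global_head_dim, None))
    return plan

# Gemma 4 E4B-like layout: 42 layers, 18 shared, sliding/global interleaved
# 5:1 -> 24 caches total: 20 rotating sliding + 4 global.
cfg = Config(42, 18, 256, 512, 1024,
             (["sliding_attention"] * 5 + ["full_attention"]) * 7)
plan = build_cache_plan(cfg)
```

The window size and interleave pattern above are illustrative; the point is the count: 42 − 18 = 24 caches, matching the MatFormer row in the feature table.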

### Other models (Llama, Qwen, Gemma 3)

Same entry point — it auto-detects the architecture:

```python
model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
cache = make_turboquant_cache(model, bits=3)
```

### Direct instantiation (advanced)

For plain (non-hybrid) models you can also build caches by hand:

```python
from turbomlx import TurboQuantKVCache

cache = [TurboQuantKVCache(bits=3, head_dim=128) for _ in range(len(model.layers))]
```

Supported bit widths: 2, 3, 3.5, 4. Fractional bits use channel-split (half at ceil, half at floor).
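The channel-split bookkeeping is simple; a minimal sketch (the helper name is illustrative, not TurboMLX's API):

```python
import math

def channel_split(head_dim: int, bits: float) -> dict:
    """Map each integer bit width to the number of channels quantized at it."""
    hi, lo = math.ceil(bits), math.floor(bits)
    if hi == lo:                 # integer width: all channels the same
        return {hi: head_dim}
    n_hi = head_dim // 2         # half the channels at ceil(bits)...
    return {hi: n_hi, lo: head_dim - n_hi}  # ...half at floor(bits)

plan = channel_split(128, 3.5)
avg_bits = sum(b * n for b, n in plan.items()) / 128  # averages to 3.5
```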

## How It Works

  1. Normalize each KV vector and store its norm
  2. Rotate by a fixed random orthogonal matrix (data-oblivious — same matrix for all inputs)
  3. After rotation, coordinates follow a known Gaussian distribution
  4. Quantize each coordinate using precomputed Lloyd-Max optimal codebooks
  5. Bit-pack indices into uint32 for storage (e.g., 10 3-bit values per uint32)
  6. On fetch: unpack → lookup centroids → inverse rotate → rescale

No calibration data is needed, and the scheme is near information-theoretically optimal: its distortion is within 2.7x of the lower bound.
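The numbered steps above can be sketched in NumPy. This is a toy sketch, not TurboMLX's implementation: the 3-bit codebook below is the standard Lloyd-Max quantizer for N(0, 1), and the bit-packing step is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 128  # head_dim

# Step 2: fixed random orthogonal matrix (data-oblivious, same for all inputs).
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))

# Step 4: 3-bit Lloyd-Max centroids for N(0, 1).
codebook = np.array([-2.152, -1.344, -0.756, -0.245,
                      0.245,  0.756,  1.344,  2.152])

def compress(v):
    norm = np.linalg.norm(v)            # step 1: store the norm
    r = np.sqrt(d) * (Q @ (v / norm))   # steps 2-3: rotated unit-vector
                                        # coordinates are approximately N(0, 1)
    idx = np.abs(r[:, None] - codebook).argmin(axis=1)  # nearest centroid
    return idx.astype(np.uint8), norm   # step 5 would bit-pack idx into uint32

def decompress(idx, norm):
    # Step 6: lookup centroids -> inverse rotate -> rescale.
    return norm * (Q.T @ (codebook[idx] / np.sqrt(d)))

v = rng.standard_normal(d)
v_hat = decompress(*compress(v))
cos = float(v @ v_hat / (np.linalg.norm(v) * np.linalg.norm(v_hat)))
```

Even this toy version reconstructs the vector with a high cosine from 3 bits per coordinate plus one stored norm, which is the essence of PolarQuant.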

## Benchmarks

```shell
uv run python benchmarks/bench_quality.py       # Quality: cosine sim, top-k
uv run python benchmarks/bench_memory_speed.py  # Memory + speed
uv run python benchmarks/bench_full.py          # Full suite across 4 models
uv run python tests/test_core.py                # Unit tests
```

## Architecture

```
turbomlx/
├── codebooks.py        # Lloyd-Max codebook loader (precomputed for N(0,1))
├── polar_quant.py      # PolarQuant: rotate + quantize + dequantize
├── qjl.py              # QJL residual correction (for future fused attention)
├── turbo_quant.py      # Combined compressor + fractional bit support
├── packing.py          # Vectorized bit-packing into uint32
├── cache.py            # TurboQuantKVCache (global / full-attention layers)
├── rotating_cache.py   # RotatingTurboQuantKVCache (sliding-window layers)
├── attention.py        # Custom attention with QJL correction
└── integration.py      # make_turboquant_cache + SDPA monkey-patch
```

## Upstream PR

This is a standalone proof-of-concept. The plan is to upstream into ml-explore/mlx-lm as a new cache type. See PR_PLAN.md for the integration strategy.

## License

MIT
