26 changes: 26 additions & 0 deletions mlx_lm/generate.py
@@ -203,6 +203,20 @@ def setup_arg_parser():
        type=int,
        default=DEFAULT_QUANTIZED_KV_START,
    )
    parser.add_argument(
        "--turbo-kv-bits",
        type=int,
        help="TurboQuant KV cache compression bits (1-4). "
        "3-bit gives 4.6x compression. Default: no compression.",
        default=None,
    )
    parser.add_argument(
        "--turbo-fp16-layers",
        type=int,
        help="Number of first/last layers to keep in FP16 "
        "when using --turbo-kv-bits. Default: 1.",
        default=1,
    )
    parser.add_argument(
        "--draft-model",
        type=str,
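A minimal sketch of how the new flags compose with the existing CLI (model path and prompt are placeholders, not from this PR):

    python -m mlx_lm.generate --model <model-path> --prompt "..." \
        --turbo-kv-bits 3 --turbo-fp16-layers 1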
@@ -300,6 +314,7 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
            prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)


def generate_step(
    prompt: mx.array,
    model: nn.Module,
@@ -313,6 +328,8 @@ def generate_step(
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    turbo_kv_bits: Optional[int] = None,
    turbo_fp16_layers: int = 1,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
    input_embeddings: Optional[mx.array] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -339,6 +356,11 @@ def generate_step(
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
            when ``kv_bits`` is non-None. Default: ``0``.
        turbo_kv_bits (int, optional): TurboQuant KV cache compression bits (1-4).
            Uses PolarQuant with Hadamard rotation. 3-bit gives 4.6x compression.
            ``None`` implies no TurboQuant. Default: ``None``.
        turbo_fp16_layers (int): Number of first/last layers to keep in FP16 when
            using TurboQuant. Default: ``1``.
        prompt_progress_callback (Callable[[int, int], None]): A callback which takes
            the number of prompt tokens processed so far and the total number of
            prompt tokens.
        input_embeddings (mx.array, optional): Input embeddings to use instead of or in
@@ -368,6 +390,8 @@ def generate_step(
        prompt_cache = cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
            turbo_kv_bits=turbo_kv_bits,
            turbo_fp16_layers=turbo_fp16_layers,
        )

    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
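At the Python level, a usage sketch, assuming the usual mlx_lm kwarg forwarding from generate through stream_generate into generate_step (model path is a placeholder):

from mlx_lm import load, generate

model, tokenizer = load("<model-path>")  # placeholder model path
text = generate(
    model,
    tokenizer,
    prompt="...",
    turbo_kv_bits=3,      # 1-4; 3-bit gives ~4.6x compression per the docstring
    turbo_fp16_layers=1,  # keep the first and last layer in FP16
)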
@@ -1526,6 +1550,8 @@ def main():
        kv_bits=args.kv_bits,
        kv_group_size=args.kv_group_size,
        quantized_kv_start=args.quantized_kv_start,
        turbo_kv_bits=args.turbo_kv_bits,
        turbo_fp16_layers=args.turbo_fp16_layers,
        draft_model=draft_model,
        num_draft_tokens=args.num_draft_tokens,
    )
50 changes: 49 additions & 1 deletion mlx_lm/models/cache.py
@@ -15,6 +15,8 @@
def make_prompt_cache(
    model: nn.Module,
    max_kv_size: Optional[int] = None,
    turbo_kv_bits: Optional[int] = None,
    turbo_fp16_layers: int = 1,
) -> List[Any]:
    """
    Construct the model's cache for use in generation.
@@ -27,11 +29,39 @@ def make_prompt_cache(
        max_kv_size (Optional[int]): If provided and the model does not have a
            ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
            size of ``max_kv_size``.
        turbo_kv_bits (Optional[int]): If provided, use TurboQuant KV cache
            compression at the given bit width (1-4). 3-bit gives 4.6x
            compression. Default: ``None`` (no compression).
        turbo_fp16_layers (int): Number of first/last layers to keep in FP16
            when using TurboQuant. Default: ``1``.
    """
    if hasattr(model, "make_cache"):
        default_cache = model.make_cache()
        if turbo_kv_bits is not None:
            # TurboQuant swaps in its own per-layer caches, so a custom model
            # cache is only usable if it is a plain KVCache.
            if not isinstance(default_cache[0], KVCache):
                raise ValueError(
                    "[TurboQuant] Incompatible cache type: "
                    f"{type(default_cache[0]).__name__}. "
                    "TurboQuant only works with standard multi-head "
                    "attention (KVCache)."
                )
        else:
            return default_cache

    num_layers = len(model.layers)

    if turbo_kv_bits is not None:
        from mlx_lm.models.turboquant_cache import TurboQuantKVCache

        # Keep the first/last `turbo_fp16_layers` layers in FP16; compress the rest.
        caches = []
        for i in range(num_layers):
            if i < turbo_fp16_layers or i >= num_layers - turbo_fp16_layers:
                caches.append(KVCache())
            else:
                caches.append(TurboQuantKVCache(bits=turbo_kv_bits))
        return caches

    if max_kv_size is not None:
        return [
            RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
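To make the layer split concrete, here is how the selection rule above assigns caches for a hypothetical 8-layer model with turbo_fp16_layers=1 (illustration only, not part of the patch):

num_layers, fp16_layers = 8, 1
kinds = [
    "KVCache" if i < fp16_layers or i >= num_layers - fp16_layers
    else "TurboQuantKVCache"
    for i in range(num_layers)
]
print(kinds)  # ['KVCache', 'TurboQuantKVCache', ..., 'TurboQuantKVCache', 'KVCache']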
@@ -76,6 +106,13 @@ def load_prompt_cache(file_name, return_metadata=False):
    arrays = tree_unflatten(list(arrays.items()))
    cache_metadata = tree_unflatten(list(cache_metadata.items()))
    info, metadata, classes = cache_metadata

    # Ensure TurboQuantKVCache is in globals for deserialization
    if "TurboQuantKVCache" in classes and "TurboQuantKVCache" not in globals():
        from mlx_lm.models.turboquant_cache import TurboQuantKVCache

        globals()["TurboQuantKVCache"] = TurboQuantKVCache

    cache = [
        globals()[c].from_state(state, meta_state)
        for c, state, meta_state in zip(classes, arrays, info)
@@ -390,6 +427,17 @@ def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
        )
        return quant_cache

    def to_turbo_quantized(self, bits: int = 3):
        from mlx_lm.models.turboquant_cache import TurboQuantKVCache

        # Re-quantize the currently cached keys/values into a fresh
        # TurboQuant cache.
        tq_cache = TurboQuantKVCache(bits=bits)
        if self.keys is not None:
            tq_cache.update_and_fetch(
                self.keys[..., : self.offset, :],
                self.values[..., : self.offset, :],
            )
        return tq_cache

    def make_mask(self, *args, **kwargs):
        return create_attention_mask(*args, offset=self.offset, **kwargs)

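A sketch of converting an already-populated FP16 cache with the new to_turbo_quantized hook (the driver code is hypothetical; the hook simply re-quantizes the stored keys/values through update_and_fetch):

from mlx_lm.models.cache import make_prompt_cache

prompt_cache = make_prompt_cache(model)  # plain FP16 KVCache per layer
# ... run the prefill with the FP16 cache ...
prompt_cache = [c.to_turbo_quantized(bits=3) for c in prompt_cache]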
229 changes: 229 additions & 0 deletions mlx_lm/models/turboquant_cache.py
@@ -0,0 +1,229 @@
"""TurboQuantKVCache: PolarQuant KV cache compression with fused Metal kernels.

Implements TurboQuant (arXiv 2504.19874, ICLR 2026) for MLX KV cache compression.
4.6x compression via randomized Hadamard rotation + Lloyd-Max quantization.
Bit-packed uint32 storage with fused Metal quantize/dequantize kernels.
"""

import mlx.core as mx

from mlx_lm.models.turboquant_kernels import packed_dequantize
from mlx_lm.models.turboquant_metal import dequant_fp16, fused_quantize
from mlx_lm.models.turboquant_packing import (
    VALS_PER_WORD,
    pack_indices,
    packed_dim,
    unpack_indices,
)
from mlx_lm.models.turboquant_rotation import random_diagonal_sign


def _compute_gaussian_codebook(bits):
    # Precomputed Lloyd-Max centroids for a unit Gaussian at 1-4 bits.
    codebooks = {
        1: [-0.7979, 0.7979],
        2: [-1.5104, -0.4528, 0.4528, 1.5104],
        3: [-2.1520, -1.3440, -0.7560, -0.2451,
            0.2451, 0.7560, 1.3440, 2.1520],
        4: [-2.7326, -2.0690, -1.6180, -1.2562,
            -0.9423, -0.6568, -0.3881, -0.1284,
            0.1284, 0.3881, 0.6568, 0.9423,
            1.2562, 1.6180, 2.0690, 2.7326],
    }
    return mx.array(codebooks[bits], dtype=mx.float32)


def _compute_boundaries(centroids):
    # Decision thresholds sit at the midpoints between adjacent centroids.
    return (centroids[:-1] + centroids[1:]) / 2.0
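These tables are the classic Lloyd-Max quantizer levels for a unit Gaussian. As a quick sanity check, independent of the PR code, the 1-bit levels are ±E[|X|] for X ~ N(0, 1):

import math

# Optimal symmetric 2-level quantizer for N(0, 1): levels at +/- sqrt(2/pi).
print(math.sqrt(2 / math.pi))  # 0.7978845608... ~= 0.7979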


class _Quantizer:
    def __init__(self, dim, bits, seed):
        self.dim = dim
        self.bits = bits
        self.signs = random_diagonal_sign(dim, seed=seed)
        self.centroids = _compute_gaussian_codebook(bits)
        self.boundaries = _compute_boundaries(self.centroids)
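The quantizer holds only the random sign diagonal; per the module docstring, the Hadamard transform itself runs inside the fused Metal kernels, which this file does not show. A pure-MLX reference of the rotation those kernels presumably implement (hadamard_rotate_ref is a hypothetical name; assumes a power-of-two trailing dim):

import math
import mlx.core as mx

def hadamard_rotate_ref(x: mx.array, signs: mx.array) -> mx.array:
    # Computes y = H_d (signs * x) / sqrt(d) with a fast Walsh-Hadamard
    # transform over the trailing dim d (must be a power of two).
    d = x.shape[-1]
    y = x * signs
    h = 1
    while h < d:
        y = y.reshape(*y.shape[:-1], d // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = mx.concatenate([a + b, a - b], axis=-1).reshape(*x.shape)
        h *= 2
    return y / math.sqrt(d)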


class TurboQuantKVCache:
    """TurboQuant KV cache — drop-in replacement for KVCache.

    Compresses KV vectors using PolarQuant (Hadamard rotation + Lloyd-Max
    codebook quantization). Stores bit-packed indices in uint32 + float32 norms.

    Uses fused Metal kernels for quantize and dequantize operations.
    Maintains an incremental decode buffer for O(1) per-step dequantization.
    """

    step = 256

    def __init__(self, bits: int = 3, seed: int = 42):
        self.quant_bits = bits
        self.seed = seed
        self.offset = 0

        # Bit-packed indices (uint32) and per-vector norms (float32).
        self.k_packed = None
        self.k_norms = None
        self.v_packed = None
        self.v_norms = None

        # Incrementally maintained dequantized buffers for decode steps.
        self._k_deq_buf = None
        self._v_deq_buf = None
        self._deq_offset = 0
        self._deq_alloc = 0

        self._k_q = None
        self._v_q = None
        self._k_dim = None
        self._v_dim = None
        self._k_pdim = None
        self._v_pdim = None

    def _ensure_quantizer(self, k_dim, v_dim):
        if self._k_q is None:
            self._k_q = _Quantizer(k_dim, self.quant_bits, self.seed)
            self._k_dim = k_dim
            self._k_pdim = packed_dim(k_dim, self.quant_bits)
        if self._v_q is None:
            self._v_q = _Quantizer(v_dim, self.quant_bits, self.seed + 1)
            self._v_dim = v_dim
            self._v_pdim = packed_dim(v_dim, self.quant_bits)
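packed_dim comes from turboquant_packing, which is not part of this diff. Assuming standard packing of dim indices at bits bits each into 32-bit words, a stand-in would be:

# Illustrative stand-in for turboquant_packing.packed_dim (assumption).
def packed_dim_sketch(dim: int, bits: int) -> int:
    return (dim * bits + 31) // 32  # ceil(dim * bits / 32) uint32 words

print(packed_dim_sketch(128, 3))  # 12 packed words per 128-dim vector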

    def _ensure_storage(self, B, H, num_new):
        prev = self.offset
        needed = prev + num_new
        if self.k_packed is None or needed > self.k_packed.shape[2]:
            # Grow in multiples of `step` to amortize reallocation.
            n = ((needed + self.step - 1) // self.step) * self.step
            new_kp = mx.zeros((B, H, n, self._k_pdim), dtype=mx.uint32)
            new_kn = mx.zeros((B, H, n), dtype=mx.float32)
            new_vp = mx.zeros((B, H, n, self._v_pdim), dtype=mx.uint32)
            new_vn = mx.zeros((B, H, n), dtype=mx.float32)
            if self.k_packed is not None:
                self.k_packed = mx.concatenate([self.k_packed[..., :prev, :], new_kp], axis=2)
                self.k_norms = mx.concatenate([self.k_norms[..., :prev], new_kn], axis=2)
                self.v_packed = mx.concatenate([self.v_packed[..., :prev, :], new_vp], axis=2)
                self.v_norms = mx.concatenate([self.v_norms[..., :prev], new_vn], axis=2)
            else:
                self.k_packed, self.k_norms = new_kp, new_kn
                self.v_packed, self.v_norms = new_vp, new_vn

    def _full_dequant(self, packed, norms, q, dim, B, H, total, dtype):
        flat_p = packed[..., :total, :].reshape(-1, packed.shape[-1])
        flat_n = norms[..., :total].reshape(-1)
        out = packed_dequantize(flat_p, flat_n, q.centroids, q.signs, dim, self.quant_bits)
        return out.reshape(B, H, total, dim).astype(dtype)

    def update_and_fetch(self, keys, values):
        B, H, S, k_dim = keys.shape
        v_dim = values.shape[3]
        self._ensure_quantizer(k_dim, v_dim)
        self._ensure_storage(B, H, S)
        prev = self.offset

        # Fused Metal quantize
        k_pk, k_nrm = fused_quantize(
            keys.reshape(-1, k_dim), self._k_q.signs, self._k_q.boundaries,
            k_dim, self.quant_bits,
        )
        k_pk = k_pk.reshape(B, H, S, self._k_pdim)
        v_pk, v_nrm = fused_quantize(
            values.reshape(-1, v_dim), self._v_q.signs, self._v_q.boundaries,
            v_dim, self.quant_bits,
        )
        v_pk = v_pk.reshape(B, H, S, self._v_pdim)

        self.k_packed[..., prev : prev + S, :] = k_pk
        self.k_norms[..., prev : prev + S] = k_nrm.reshape(B, H, S)
        self.v_packed[..., prev : prev + S, :] = v_pk
        self.v_norms[..., prev : prev + S] = v_nrm.reshape(B, H, S)
        self.offset += S
        total = self.offset

        # Incremental decode: for small decode steps, dequantize only the S
        # new tokens and append them to the running buffers.
        if S <= 4 and self._v_deq_buf is not None and self._deq_offset == prev:
            if total > self._deq_alloc:
                # Pad with zeros up to the new allocation `na` so the buffers
                # are exactly `na` long (padding by `na - self._deq_alloc`
                # leaves them short whenever `_deq_offset < _deq_alloc`).
                na = ((total + self.step - 1) // self.step) * self.step
                self._k_deq_buf = mx.concatenate(
                    [
                        self._k_deq_buf[..., : self._deq_offset, :],
                        mx.zeros((B, H, na - self._deq_offset, k_dim), dtype=keys.dtype),
                    ],
                    axis=2,
                )
                self._v_deq_buf = mx.concatenate(
                    [
                        self._v_deq_buf[..., : self._deq_offset, :],
                        mx.zeros((B, H, na - self._deq_offset, v_dim), dtype=values.dtype),
                    ],
                    axis=2,
                )
                self._deq_alloc = na

            nk = dequant_fp16(
                k_pk.reshape(-1, self._k_pdim), k_nrm, self._k_q.centroids,
                self._k_q.signs, k_dim, self.quant_bits,
            ).reshape(B, H, S, k_dim)
            nv = dequant_fp16(
                v_pk.reshape(-1, self._v_pdim), v_nrm, self._v_q.centroids,
                self._v_q.signs, v_dim, self.quant_bits,
            ).reshape(B, H, S, v_dim)
            self._k_deq_buf[..., prev:total, :] = nk
            self._v_deq_buf[..., prev:total, :] = nv
            self._deq_offset = total
            return self._k_deq_buf[..., :total, :], self._v_deq_buf[..., :total, :]

        # Full dequant (prefill): decode everything and rebuild the buffers.
        all_k = self._full_dequant(self.k_packed, self.k_norms, self._k_q, k_dim, B, H, total, keys.dtype)
        all_v = self._full_dequant(self.v_packed, self.v_norms, self._v_q, v_dim, B, H, total, values.dtype)
        alloc = ((total + self.step - 1) // self.step) * self.step
        self._k_deq_buf = mx.zeros((B, H, alloc, k_dim), dtype=keys.dtype)
        self._v_deq_buf = mx.zeros((B, H, alloc, v_dim), dtype=values.dtype)
        self._k_deq_buf[..., :total, :] = all_k
        self._v_deq_buf[..., :total, :] = all_v
        self._deq_offset = total
        self._deq_alloc = alloc
        return all_k, all_v

    def empty(self):
        return self.k_packed is None

    @property
    def nbytes(self):
        if self.k_packed is None:
            return 0
        return (
            self.k_packed[..., : self.offset, :].nbytes
            + self.v_packed[..., : self.offset, :].nbytes
            + self.k_norms[..., : self.offset].nbytes
            + self.v_norms[..., : self.offset].nbytes
        )
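For intuition on the advertised ratio, a back-of-envelope count (head_dim=64 is an assumption, not from the PR) lands close to the quoted 4.6x at 3 bits:

head_dim, bits = 64, 3                                 # head_dim assumed
fp16_bytes = head_dim * 2                              # uncompressed row
packed_bytes = ((head_dim * bits + 31) // 32) * 4 + 4  # uint32 words + fp32 norm
print(fp16_bytes / packed_bytes)                       # 128 / 28 ~= 4.57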

    @property
    def state(self):
        if self.k_packed is None:
            return []
        return [
            self.k_packed[..., : self.offset, :],
            self.k_norms[..., : self.offset],
            self.v_packed[..., : self.offset, :],
            self.v_norms[..., : self.offset],
        ]

    @state.setter
    def state(self, v):
        if not v:
            return
        self.k_packed, self.k_norms, self.v_packed, self.v_norms = v
        self.offset = self.k_packed.shape[2]

    @property
    def meta_state(self):
        return f"{self.offset},{self.quant_bits},{self.seed},{self._k_dim or 0},{self._v_dim or 0}"

    @meta_state.setter
    def meta_state(self, v):
        parts = v.split(",")
        self.offset, self.quant_bits, self.seed = int(parts[0]), int(parts[1]), int(parts[2])
        self._k_dim = int(parts[3]) or None
        self._v_dim = int(parts[4]) or None

    def is_trimmable(self):
        return True

    def trim(self, n):
        n = min(self.offset, n)
        self.offset -= n
        return n

    def size(self):
        return self.offset

    def make_mask(self, *args, **kwargs):
        from mlx_lm.models.cache import create_attention_mask

        return create_attention_mask(*args, offset=self.offset, **kwargs)

    @classmethod
    def from_state(cls, state, meta_state):
        # Bypass __init__ and rebuild everything from the serialized fields.
        obj = cls.__new__(cls)
        obj.k_packed = None
        obj.k_norms = None
        obj.v_packed = None
        obj.v_norms = None
        obj._k_deq_buf = None
        obj._v_deq_buf = None
        obj._deq_offset = 0
        obj._deq_alloc = 0
        obj._k_q = None
        obj._v_q = None
        obj._k_dim = None
        obj._v_dim = None
        obj._k_pdim = None
        obj._v_pdim = None
        obj.meta_state = meta_state
        obj.state = state
        return obj
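A round-trip sketch showing how these hooks compose with the existing cache I/O (save_prompt_cache already exists in mlx_lm.models.cache; the file name is illustrative):

from mlx_lm.models.cache import load_prompt_cache, save_prompt_cache

# `caches` as produced by make_prompt_cache(model, turbo_kv_bits=3), after prefill.
save_prompt_cache("prompt_cache.safetensors", caches)
# load_prompt_cache injects TurboQuantKVCache into globals() before resolving
# class names, so mixed KVCache / TurboQuantKVCache lists deserialize cleanly.
caches = load_prompt_cache("prompt_cache.safetensors")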