29 changes: 29 additions & 0 deletions mlx_lm/generate.py
@@ -203,6 +203,14 @@ def setup_arg_parser():
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
parser.add_argument(
"--turbo-kv-bits",
type=int,
help="[Experimental] Number of bits for TurboQuant KV cache "
"compression (2-4). Uses PolarQuant for data-oblivious compression. "
"Default: no TurboQuant compression.",
default=None,
)
parser.add_argument(
"--draft-model",
type=str,
@@ -300,6 +308,15 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)


def maybe_turboquant_kv_cache(prompt_cache, turbo_kv_bits):
"""Convert KV caches to TurboQuant compression (experimental)."""
if turbo_kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if hasattr(c, "to_turboquant") and not hasattr(c, "turbo_bits"):
prompt_cache[e] = c.to_turboquant(bits=turbo_kv_bits)


def generate_step(
prompt: mx.array,
model: nn.Module,
@@ -313,6 +330,7 @@ def generate_step(
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
turbo_kv_bits: Optional[int] = None,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
input_embeddings: Optional[mx.array] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -339,6 +357,9 @@
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache
when ``kv_bits`` is non-None. Default: ``0``.
turbo_kv_bits (int, optional): Number of bits for TurboQuant KV cache
compression (2-4). Uses PolarQuant for data-oblivious compression.
None implies no TurboQuant compression. Default: ``None``.
prompt_progress_callback (Callable[[int, int], None]): A call-back which takes the
prompt tokens processed so far and the total number of prompt tokens.
input_embeddings (mx.array, optional): Input embeddings to use instead of or in
@@ -379,6 +400,11 @@
kv_bits=kv_bits,
)

turboquant_cache_fn = functools.partial(
maybe_turboquant_kv_cache,
turbo_kv_bits=turbo_kv_bits,
)

sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

def _model_call(input_tokens: mx.array, input_embeddings: Optional[mx.array]):
@@ -412,6 +438,7 @@ def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None):
logits = processor(tokens, logits)

quantize_cache_fn(prompt_cache)
turboquant_cache_fn(prompt_cache)

logprobs = logits - mx.logsumexp(logits, keepdims=True)
sampled = sampler(logprobs)
@@ -435,6 +462,7 @@ def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None):
),
)
quantize_cache_fn(prompt_cache)
turboquant_cache_fn(prompt_cache)
mx.eval([c.state for c in prompt_cache])
prompt_processed_tokens += n_to_process
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
@@ -1526,6 +1554,7 @@ def main():
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
turbo_kv_bits=args.turbo_kv_bits,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
)
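Note: a minimal usage sketch for the new option, not part of this diff. It assumes the existing mlx_lm entry points (`load`, `stream_generate`) and that extra keyword arguments are forwarded from `stream_generate` to `generate_step`, as in the current code; the model id below is only a placeholder.

```python
# Sketch: exercising --turbo-kv-bits / turbo_kv_bits (assumed usage, not part of this PR).
#
# CLI (uses the flag added above):
#   python -m mlx_lm.generate --model <model-path> --prompt "..." --turbo-kv-bits 4
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # placeholder model id

# turbo_kv_bits is forwarded to generate_step, which converts each compatible
# cache entry via maybe_turboquant_kv_cache during prompt processing and decoding.
for response in stream_generate(
    model, tokenizer, "Explain KV caches.", max_tokens=128, turbo_kv_bits=4
):
    print(response.text, end="", flush=True)
```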
31 changes: 30 additions & 1 deletion mlx_lm/models/cache.py
@@ -74,8 +74,18 @@ def load_prompt_cache(file_name, return_metadata=False):
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
info, metadata, classes = cache_metadata
def _resolve_cache_class(name):
if name in globals():
return globals()[name]
# Lazy-load experimental cache types
if name == "TurboQuantKVCache":
from .turboquant import TurboQuantKVCache

return TurboQuantKVCache
raise KeyError(f"Unknown cache class: {name}")

cache = [
globals()[c].from_state(state, meta_state)
_resolve_cache_class(c).from_state(state, meta_state)
for c, state, meta_state in zip(classes, arrays, info)
]
if return_metadata:
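Context for the lazy lookup above: a prompt cache saved with a TurboQuantKVCache entry records the class name "TurboQuantKVCache" in its metadata, so `load_prompt_cache` must be able to resolve that name without importing `turboquant` at module import time. A hedged round-trip sketch, assuming the existing `make_prompt_cache` / `save_prompt_cache` helpers in `mlx_lm.models.cache` and a placeholder model id:

```python
# Sketch: save/load round trip with a TurboQuant-compressed cache (assumed usage).
from mlx_lm import load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # placeholder model id
cache = make_prompt_cache(model)
# ... run a prefill that populates `cache`, then compress the KVCache entries:
cache = [c.to_turboquant(bits=4) if hasattr(c, "to_turboquant") else c for c in cache]

save_prompt_cache("prompt.safetensors", cache)
# The metadata names the cache classes; _resolve_cache_class imports TurboQuantKVCache on demand.
restored = load_prompt_cache("prompt.safetensors")
```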
@@ -388,6 +398,25 @@ def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
)
return quant_cache

def to_turboquant(self, bits: int = 4):
"""Convert to TurboQuant compressed cache (experimental).

Uses PolarQuant for data-oblivious KV cache compression at 2-4 bits.
See :class:`~mlx_lm.models.turboquant.TurboQuantKVCache` for details.

Args:
bits (int): Quantization bits per coordinate (2-4). Default: ``4``.
"""
from .turboquant import TurboQuantKVCache

tq_cache = TurboQuantKVCache(bits=bits)
if self.keys is not None:
tq_cache.update_and_fetch(
self.keys[..., : self.offset, :],
self.values[..., : self.offset, :],
)
return tq_cache

def make_mask(self, *args, **kwargs):
return create_attention_mask(*args, offset=self.offset, **kwargs)

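A minimal sketch of the conversion path this method adds, mirroring `maybe_turboquant_kv_cache` in `generate.py`; the shapes are illustrative only.

```python
# Sketch: converting a populated KVCache to TurboQuant compression (assumed usage).
import mlx.core as mx
from mlx_lm.models.cache import KVCache

kv = KVCache()
# Fake a prefill with shape (batch, n_kv_heads, seq_len, head_dim).
keys = mx.random.normal((1, 8, 64, 128))
values = mx.random.normal((1, 8, 64, 128))
kv.update_and_fetch(keys, values)

tq = kv.to_turboquant(bits=4)  # re-quantizes the first kv.offset positions
print(tq.offset == kv.offset)  # True: the compressed cache covers the same tokens
# Compressed footprint vs. the full-precision keys/values it replaces.
full = kv.keys[..., : kv.offset, :].nbytes + kv.values[..., : kv.offset, :].nbytes
print(tq.nbytes, full)
```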
249 changes: 249 additions & 0 deletions mlx_lm/models/turboquant.py
@@ -0,0 +1,249 @@
"""
TurboQuant KV cache compression (experimental).

Implements PolarQuant from "TurboQuant: Redefining AI Efficiency with
Extreme Compression" (Google, ICLR 2026, https://arxiv.org/abs/2504.19874).

Data-oblivious KV cache quantization at 2-4 bits per coordinate via
random orthogonal rotation followed by Lloyd-Max optimal scalar
quantization. No calibration data needed.
"""

import math

import mlx.core as mx

from .cache import _BaseCache, create_attention_mask

# fmt: off
# Lloyd-Max optimal centroids and boundaries for N(0,1).
# Scaled by 1/sqrt(head_dim) at runtime.
_CENTROIDS = {
2: [-1.5104, -0.4528, 0.4528, 1.5104],
3: [-2.1519, -1.3439, -0.7560, -0.2451, 0.2451, 0.7560, 1.3439, 2.1519],
4: [-2.7331, -2.0698, -1.6189, -1.2570, -0.9431, -0.6573,
-0.3884, -0.1285, 0.1285, 0.3884, 0.6573, 0.9431,
1.2570, 1.6189, 2.0698, 2.7331],
}
_BOUNDARIES = {
2: [-5.0, -0.9816, 0.0, 0.9816, 5.0],
3: [-5.0, -1.7479, -1.0499, -0.5005, 0.0, 0.5005, 1.0499, 1.7479, 5.0],
4: [-5.0, -2.4015, -1.8443, -1.4380, -1.1001, -0.8002,
-0.5229, -0.2585, 0.0, 0.2585, 0.5229, 0.8002,
1.1001, 1.4380, 1.8443, 2.4015, 5.0],
}
# fmt: on


def _rotation_matrix(dim, seed=42):
"""Haar-distributed random orthogonal matrix via QR of Gaussian."""
key = mx.random.key(seed)
g = mx.random.normal(shape=(dim, dim), key=key)
q, r = mx.linalg.qr(g, stream=mx.cpu)
sign = mx.sign(mx.diag(r))
sign = mx.where(sign == 0, 1, sign)
return q * sign


def _load_codebook(bits, dim):
s = 1.0 / math.sqrt(dim)
c = mx.array(_CENTROIDS[bits], dtype=mx.float32) * s
b = mx.array(_BOUNDARIES[bits], dtype=mx.float32) * s
return c, b


def _quantize(vectors, rotation_t, boundaries):
norms = mx.linalg.norm(vectors, axis=-1, keepdims=True)
rotated = (vectors / mx.maximum(norms, 1e-8)) @ rotation_t
inner = boundaries[1:-1]
indices = mx.zeros(rotated.shape, dtype=mx.uint8)
for b in range(inner.shape[0]):
indices = indices + (rotated > inner[b]).astype(mx.uint8)
return indices, norms


def _dequantize(indices, norms, rotation, centroids):
return (centroids[indices] @ rotation) * norms


def _pack(indices, bits):
"""Pack b-bit indices into uint32."""
shape = indices.shape
dim = shape[-1]
vpi = 32 // bits
n_packed = (dim + vpi - 1) // vpi
pad_size = n_packed * vpi - dim
if pad_size > 0:
indices = mx.concatenate(
[indices, mx.zeros((*shape[:-1], pad_size), dtype=indices.dtype)],
axis=-1,
)
reshaped = indices.reshape(*shape[:-1], n_packed, vpi).astype(mx.uint32)
shifts = mx.arange(vpi, dtype=mx.uint32) * bits
shifted = reshaped << shifts
packed = shifted[..., 0]
for i in range(1, vpi):
packed = packed | shifted[..., i]
return packed


def _unpack(packed, bits, dim):
"""Unpack uint32 back to b-bit indices."""
shape = packed.shape
vpi = 32 // bits
mask = (1 << bits) - 1
shifts = mx.arange(vpi, dtype=mx.uint32) * bits
extracted = (packed[..., None] >> shifts) & mask
return extracted.reshape(*shape[:-1], shape[-1] * vpi)[..., :dim].astype(mx.uint8)


class TurboQuantKVCache(_BaseCache):
"""KV cache compressed with PolarQuant (experimental).

Data-oblivious compression: random orthogonal rotation maps KV vectors
to coordinates with a known Gaussian distribution, then Lloyd-Max
optimal scalar quantizers compress each coordinate independently.
Bit-packed into uint32 for storage, dequantized on fetch.

Args:
bits (int): Bits per coordinate (2, 3, or 4). Default: ``4``.
"""

step = 256

def __init__(self, bits: int = 4):
if bits not in (2, 3, 4):
raise ValueError(f"bits must be 2, 3, or 4, got {bits}")
self.turbo_bits = bits
self.offset = 0
self._head_dim = None
self._k_indices = None
self._k_norms = None
self._v_indices = None
self._v_norms = None
self._centroids = None
self._boundaries = None
self._rotation = None
self._rotation_t = None

def _init_codebook(self, head_dim):
self._head_dim = head_dim
self._centroids, self._boundaries = _load_codebook(
self.turbo_bits, head_dim
)
self._rotation = _rotation_matrix(head_dim)
self._rotation_t = self._rotation.T

def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, head_dim = keys.shape
prev = self.offset

if self._centroids is None:
self._init_codebook(head_dim)

k_idx, k_norms = _quantize(keys, self._rotation_t, self._boundaries)
v_idx, v_norms = _quantize(values, self._rotation_t, self._boundaries)
pk = _pack(k_idx, self.turbo_bits)
pv = _pack(v_idx, self.turbo_bits)

if self._k_indices is None or (prev + num_steps) > self._k_indices.shape[2]:
self._expand(B, n_kv_heads, num_steps, keys.dtype, pk.shape[-1])

self._k_indices[..., prev : prev + num_steps, :] = pk
self._k_norms[..., prev : prev + num_steps, :] = k_norms
self._v_indices[..., prev : prev + num_steps, :] = pv
self._v_norms[..., prev : prev + num_steps, :] = v_norms
self.offset += num_steps

all_k = _dequantize(
_unpack(self._k_indices[..., :self.offset, :], self.turbo_bits, head_dim),
self._k_norms[..., :self.offset, :],
self._rotation,
self._centroids,
)
all_v = _dequantize(
_unpack(self._v_indices[..., :self.offset, :], self.turbo_bits, head_dim),
self._v_norms[..., :self.offset, :],
self._rotation,
self._centroids,
)
return all_k, all_v

def _expand(self, B, n_kv_heads, new_steps, dtype, packed_dim):
alloc = ((new_steps + self.step - 1) // self.step) * self.step
shape = (B, n_kv_heads, alloc)

def _new():
return (
mx.zeros((*shape, packed_dim), dtype=mx.uint32),
mx.zeros((*shape, 1), dtype=dtype),
mx.zeros((*shape, packed_dim), dtype=mx.uint32),
mx.zeros((*shape, 1), dtype=dtype),
)

if self._k_indices is not None and self.offset > 0:
old = (
self._k_indices[..., :self.offset, :],
self._k_norms[..., :self.offset, :],
self._v_indices[..., :self.offset, :],
self._v_norms[..., :self.offset, :],
)
self._k_indices, self._k_norms, self._v_indices, self._v_norms = (
mx.concatenate([o, n], axis=2) for o, n in zip(old, _new())
)
else:
self._k_indices, self._k_norms, self._v_indices, self._v_norms = _new()

def size(self):
return self.offset

@property
def state(self):
if self._k_indices is None:
return []
return [
self._k_indices[..., :self.offset, :],
self._k_norms[..., :self.offset, :],
self._v_indices[..., :self.offset, :],
self._v_norms[..., :self.offset, :],
]

@state.setter
def state(self, v):
if v is not None and v:
self._k_indices, self._k_norms, self._v_indices, self._v_norms = v
self.offset = self._k_indices.shape[2]

@property
def meta_state(self):
return tuple(map(str, (self.offset, self.turbo_bits, self._head_dim or 0)))

@meta_state.setter
def meta_state(self, v):
self.offset, self.turbo_bits = int(v[0]), int(v[1])
head_dim = int(v[2])
if head_dim > 0:
self._init_codebook(head_dim)

def is_trimmable(self):
return True

def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n

def make_mask(self, *args, **kwargs):
return create_attention_mask(*args, offset=self.offset, **kwargs)

def empty(self):
return self._k_indices is None

@property
def nbytes(self):
if self._k_indices is None:
return 0
return sum(
a[..., :self.offset, :].nbytes
for a in (self._k_indices, self._k_norms, self._v_indices, self._v_norms)
)
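For a quick sanity check of the helpers above, a self-contained round-trip sketch on random vectors; it uses only functions defined in this file and reports the relative reconstruction error.

```python
# Sketch: PolarQuant round trip with the module's helpers (not part of this diff).
import mlx.core as mx
from mlx_lm.models.turboquant import (
    _dequantize, _load_codebook, _pack, _quantize, _rotation_matrix, _unpack,
)

head_dim, bits = 128, 4
x = mx.random.normal((1, 8, 16, head_dim))  # (batch, n_kv_heads, steps, head_dim)

rotation = _rotation_matrix(head_dim)
centroids, boundaries = _load_codebook(bits, head_dim)

indices, norms = _quantize(x, rotation.T, boundaries)  # per-coordinate codes + norms
packed = _pack(indices, bits)                          # bit-packed into uint32
x_hat = _dequantize(_unpack(packed, bits, head_dim), norms, rotation, centroids)

rel_err = mx.sqrt(mx.sum((x - x_hat) ** 2) / mx.sum(x ** 2))
print(f"{bits}-bit relative reconstruction error: {rel_err.item():.4f}")
```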