diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py
index ef8dbf7bf..c47e36de7 100644
--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -203,6 +203,14 @@ def setup_arg_parser():
         type=int,
         default=DEFAULT_QUANTIZED_KV_START,
     )
+    parser.add_argument(
+        "--turbo-kv-bits",
+        type=int,
+        help="[Experimental] Number of bits for TurboQuant KV cache "
+        "compression (2-4). Uses PolarQuant for data-oblivious compression. "
+        "Default: no TurboQuant compression.",
+        default=None,
+    )
     parser.add_argument(
         "--draft-model",
         type=str,
@@ -300,6 +308,15 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
             prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
 
 
+def maybe_turboquant_kv_cache(prompt_cache, turbo_kv_bits):
+    """Convert KV caches to TurboQuant compression (experimental)."""
+    if turbo_kv_bits is None:
+        return
+    for e, c in enumerate(prompt_cache):
+        if hasattr(c, "to_turboquant") and not hasattr(c, "turbo_bits"):
+            prompt_cache[e] = c.to_turboquant(bits=turbo_kv_bits)
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
@@ -313,6 +330,7 @@ def generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
+    turbo_kv_bits: Optional[int] = None,
     prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
     input_embeddings: Optional[mx.array] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -339,6 +357,9 @@ def generate_step(
         kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
         quantized_kv_start (int): Step to begin using a quantized KV cache.
             when ``kv_bits`` is non-None. Default: ``0``.
+        turbo_kv_bits (int, optional): Number of bits for TurboQuant KV cache
+            compression (2-4). Uses PolarQuant for data-oblivious compression.
+            None implies no TurboQuant compression. Default: ``None``.
         prompt_progress_callback (Callable[[int, int], None]): A call-back which takes
             the prompt tokens processed so far and the total number of prompt tokens.
         input_embeddings (mx.array, optional): Input embeddings to use instead of or in
@@ -379,6 +400,11 @@ def generate_step(
             kv_bits=kv_bits,
         )
 
+    turboquant_cache_fn = functools.partial(
+        maybe_turboquant_kv_cache,
+        turbo_kv_bits=turbo_kv_bits,
+    )
+
     sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
 
     def _model_call(input_tokens: mx.array, input_embeddings: Optional[mx.array]):
@@ -412,6 +438,7 @@ def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None):
             logits = processor(tokens, logits)
 
         quantize_cache_fn(prompt_cache)
+        turboquant_cache_fn(prompt_cache)
 
         logprobs = logits - mx.logsumexp(logits, keepdims=True)
         sampled = sampler(logprobs)
@@ -435,6 +462,7 @@ def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None):
                 ),
             )
             quantize_cache_fn(prompt_cache)
+            turboquant_cache_fn(prompt_cache)
             mx.eval([c.state for c in prompt_cache])
             prompt_processed_tokens += n_to_process
             prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
@@ -1526,6 +1554,7 @@ def main():
         kv_bits=args.kv_bits,
         kv_group_size=args.kv_group_size,
         quantized_kv_start=args.quantized_kv_start,
+        turbo_kv_bits=args.turbo_kv_bits,
         draft_model=draft_model,
         num_draft_tokens=args.num_draft_tokens,
     )
diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py
index 88fa4ad32..60151d673 100644
--- a/mlx_lm/models/cache.py
+++ b/mlx_lm/models/cache.py
@@ -74,8 +74,18 @@ def load_prompt_cache(file_name, return_metadata=False):
     arrays = tree_unflatten(list(arrays.items()))
     cache_metadata = tree_unflatten(list(cache_metadata.items()))
     info, metadata, classes = cache_metadata
+    def _resolve_cache_class(name):
+        if name in globals():
+            return globals()[name]
+        # Lazy-load experimental cache types
+        if name == "TurboQuantKVCache":
+            from .turboquant import TurboQuantKVCache
+
+            return TurboQuantKVCache
+        raise KeyError(f"Unknown cache class: {name}")
+
     cache = [
-        globals()[c].from_state(state, meta_state)
+        _resolve_cache_class(c).from_state(state, meta_state)
         for c, state, meta_state in zip(classes, arrays, info)
     ]
     if return_metadata:
@@ -388,6 +398,25 @@ def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
         )
         return quant_cache
 
+    def to_turboquant(self, bits: int = 4):
+        """Convert to TurboQuant compressed cache (experimental).
+
+        Uses PolarQuant for data-oblivious KV cache compression at 2-4 bits.
+        See :class:`~mlx_lm.models.turboquant.TurboQuantKVCache` for details.
+
+        Args:
+            bits (int): Quantization bits per coordinate (2-4). Default: ``4``.
+        """
+        from .turboquant import TurboQuantKVCache
+
+        tq_cache = TurboQuantKVCache(bits=bits)
+        if self.keys is not None:
+            tq_cache.update_and_fetch(
+                self.keys[..., : self.offset, :],
+                self.values[..., : self.offset, :],
+            )
+        return tq_cache
+
     def make_mask(self, *args, **kwargs):
         return create_attention_mask(*args, offset=self.offset, **kwargs)
 
diff --git a/mlx_lm/models/turboquant.py b/mlx_lm/models/turboquant.py
new file mode 100644
index 000000000..1844318ef
--- /dev/null
+++ b/mlx_lm/models/turboquant.py
@@ -0,0 +1,249 @@
+"""
+TurboQuant KV cache compression (experimental).
+
+PolarQuant from "TurboQuant: Redefining AI Efficiency with Extreme
+Compression" (Google, ICLR 2026, https://arxiv.org/abs/2504.19874).
+
+Data-oblivious KV cache quantization at 2-4 bits per coordinate via
+random orthogonal rotation followed by Lloyd-Max optimal scalar
+quantization. No calibration data needed.
+"""
+
+import math
+
+import mlx.core as mx
+
+from .cache import _BaseCache, create_attention_mask
+
+# fmt: off
+# Lloyd-Max optimal centroids and boundaries for N(0,1).
+# Scaled by 1/sqrt(head_dim) at runtime.
+_CENTROIDS = {
+    2: [-1.5104, -0.4528, 0.4528, 1.5104],
+    3: [-2.1519, -1.3439, -0.7560, -0.2451, 0.2451, 0.7560, 1.3439, 2.1519],
+    4: [-2.7331, -2.0698, -1.6189, -1.2570, -0.9431, -0.6573,
+        -0.3884, -0.1285, 0.1285, 0.3884, 0.6573, 0.9431,
+        1.2570, 1.6189, 2.0698, 2.7331],
+}
+_BOUNDARIES = {
+    2: [-5.0, -0.9816, 0.0, 0.9816, 5.0],
+    3: [-5.0, -1.7479, -1.0499, -0.5005, 0.0, 0.5005, 1.0499, 1.7479, 5.0],
+    4: [-5.0, -2.4015, -1.8443, -1.4380, -1.1001, -0.8002,
+        -0.5229, -0.2585, 0.0, 0.2585, 0.5229, 0.8002,
+        1.1001, 1.4380, 1.8443, 2.4015, 5.0],
+}
+# fmt: on
+
+
+def _rotation_matrix(dim, seed=42):
+    """Haar-distributed random orthogonal matrix via QR of Gaussian."""
+    key = mx.random.key(seed)
+    g = mx.random.normal(shape=(dim, dim), key=key)
+    q, r = mx.linalg.qr(g, stream=mx.cpu)
+    sign = mx.sign(mx.diag(r))
+    sign = mx.where(sign == 0, 1, sign)
+    return q * sign
+
+
+def _load_codebook(bits, dim):
+    s = 1.0 / math.sqrt(dim)
+    c = mx.array(_CENTROIDS[bits], dtype=mx.float32) * s
+    b = mx.array(_BOUNDARIES[bits], dtype=mx.float32) * s
+    return c, b
+
+
+def _quantize(vectors, rotation_t, boundaries):
+    norms = mx.linalg.norm(vectors, axis=-1, keepdims=True)
+    rotated = (vectors / mx.maximum(norms, 1e-8)) @ rotation_t
+    inner = boundaries[1:-1]
+    indices = mx.zeros(rotated.shape, dtype=mx.uint8)
+    for b in range(inner.shape[0]):
+        indices = indices + (rotated > inner[b]).astype(mx.uint8)
+    return indices, norms
+
+
+def _dequantize(indices, norms, rotation, centroids):
+    return centroids[indices] @ rotation * norms
+
+
+def _pack(indices, bits):
+    """Pack b-bit indices into uint32."""
+    shape = indices.shape
+    dim = shape[-1]
+    vpi = 32 // bits
+    n_packed = (dim + vpi - 1) // vpi
+    pad_size = n_packed * vpi - dim
+    if pad_size > 0:
+        indices = mx.concatenate(
+            [indices, mx.zeros((*shape[:-1], pad_size), dtype=indices.dtype)],
+            axis=-1,
+        )
+    reshaped = indices.reshape(*shape[:-1], n_packed, vpi).astype(mx.uint32)
+    shifts = mx.arange(vpi, dtype=mx.uint32) * bits
+    shifted = reshaped << shifts
+    packed = shifted[..., 0]
+    for i in range(1, vpi):
+        packed = packed | shifted[..., i]
+    return packed
+
+
+def _unpack(packed, bits, dim):
+    """Unpack uint32 back to b-bit indices."""
+    shape = packed.shape
+    vpi = 32 // bits
+    mask = (1 << bits) - 1
+    shifts = mx.arange(vpi, dtype=mx.uint32) * bits
+    extracted = (packed[..., None] >> shifts) & mask
+    return extracted.reshape(*shape[:-1], shape[-1] * vpi)[..., :dim].astype(mx.uint8)
+
+
+class TurboQuantKVCache(_BaseCache):
+    """KV cache compressed with PolarQuant (experimental).
+
+    Data-oblivious compression: random orthogonal rotation maps KV vectors
+    to coordinates with a known Gaussian distribution, then Lloyd-Max
+    optimal scalar quantizers compress each coordinate independently.
+    Bit-packed into uint32 for storage, dequantized on fetch.
+
+    Args:
+        bits (int): Bits per coordinate (2, 3, or 4). Default: ``4``.
+    """
+
+    step = 256
+
+    def __init__(self, bits: int = 4):
+        if bits not in (2, 3, 4):
+            raise ValueError(f"bits must be 2, 3, or 4, got {bits}")
+        self.turbo_bits = bits
+        self.offset = 0
+        self._head_dim = None
+        self._k_indices = None
+        self._k_norms = None
+        self._v_indices = None
+        self._v_norms = None
+        self._centroids = None
+        self._boundaries = None
+        self._rotation = None
+        self._rotation_t = None
+
+    def _init_codebook(self, head_dim):
+        self._head_dim = head_dim
+        self._centroids, self._boundaries = _load_codebook(
+            self.turbo_bits, head_dim
+        )
+        self._rotation = _rotation_matrix(head_dim)
+        self._rotation_t = self._rotation.T
+
+    def update_and_fetch(self, keys, values):
+        B, n_kv_heads, num_steps, head_dim = keys.shape
+        prev = self.offset
+
+        if self._centroids is None:
+            self._init_codebook(head_dim)
+
+        k_idx, k_norms = _quantize(keys, self._rotation_t, self._boundaries)
+        v_idx, v_norms = _quantize(values, self._rotation_t, self._boundaries)
+        pk = _pack(k_idx, self.turbo_bits)
+        pv = _pack(v_idx, self.turbo_bits)
+
+        if self._k_indices is None or (prev + num_steps) > self._k_indices.shape[2]:
+            self._expand(B, n_kv_heads, num_steps, keys.dtype, pk.shape[-1])
+
+        self._k_indices[..., prev : prev + num_steps, :] = pk
+        self._k_norms[..., prev : prev + num_steps, :] = k_norms
+        self._v_indices[..., prev : prev + num_steps, :] = pv
+        self._v_norms[..., prev : prev + num_steps, :] = v_norms
+        self.offset += num_steps
+
+        all_k = _dequantize(
+            _unpack(self._k_indices[..., :self.offset, :], self.turbo_bits, head_dim),
+            self._k_norms[..., :self.offset, :],
+            self._rotation,
+            self._centroids,
+        )
+        all_v = _dequantize(
+            _unpack(self._v_indices[..., :self.offset, :], self.turbo_bits, head_dim),
+            self._v_norms[..., :self.offset, :],
+            self._rotation,
+            self._centroids,
+        )
+        return all_k, all_v
+
+    def _expand(self, B, n_kv_heads, new_steps, dtype, packed_dim):
+        alloc = ((self.step + new_steps - 1) // self.step) * self.step
+        shape = (B, n_kv_heads, alloc)
+
+        def _new():
+            return (
+                mx.zeros((*shape, packed_dim), dtype=mx.uint32),
+                mx.zeros((*shape, 1), dtype=dtype),
+                mx.zeros((*shape, packed_dim), dtype=mx.uint32),
+                mx.zeros((*shape, 1), dtype=dtype),
+            )
+
+        if self._k_indices is not None and self.offset > 0:
+            old = (
+                self._k_indices[..., :self.offset, :],
+                self._k_norms[..., :self.offset, :],
+                self._v_indices[..., :self.offset, :],
+                self._v_norms[..., :self.offset, :],
+            )
+            self._k_indices, self._k_norms, self._v_indices, self._v_norms = (
+                mx.concatenate([o, n], axis=2) for o, n in zip(old, _new())
+            )
+        else:
+            self._k_indices, self._k_norms, self._v_indices, self._v_norms = _new()
+
+    def size(self):
+        return self.offset
+
+    @property
+    def state(self):
+        if self._k_indices is None:
+            return []
+        return [
+            self._k_indices[..., :self.offset, :],
+            self._k_norms[..., :self.offset, :],
+            self._v_indices[..., :self.offset, :],
+            self._v_norms[..., :self.offset, :],
+        ]
+
+    @state.setter
+    def state(self, v):
+        if v is not None and v:
+            self._k_indices, self._k_norms, self._v_indices, self._v_norms = v
+            self.offset = self._k_indices.shape[2]
+
+    @property
+    def meta_state(self):
+        return tuple(map(str, (self.offset, self.turbo_bits, self._head_dim or 0)))
+
+    @meta_state.setter
+    def meta_state(self, v):
+        self.offset, self.turbo_bits = int(v[0]), int(v[1])
+        head_dim = int(v[2])
+        if head_dim > 0:
+            self._init_codebook(head_dim)
+
+    def is_trimmable(self):
+        return True
+
+    def trim(self, n):
+        n = min(self.offset, n)
+        self.offset -= n
+        return n
+
+    def make_mask(self, *args, **kwargs):
+        return create_attention_mask(*args, offset=self.offset, **kwargs)
+
+    def empty(self):
+        return self._k_indices is None
+
+    @property
+    def nbytes(self):
+        if self._k_indices is None:
+            return 0
+        return sum(
+            a[..., :self.offset, :].nbytes
+            for a in (self._k_indices, self._k_norms, self._v_indices, self._v_norms)
+        )
diff --git a/tests/test_prompt_cache.py b/tests/test_prompt_cache.py
index 05dcd7dc4..aadae56b0 100644
--- a/tests/test_prompt_cache.py
+++ b/tests/test_prompt_cache.py
@@ -23,6 +23,7 @@
     save_prompt_cache,
     trim_prompt_cache,
 )
+from mlx_lm.models.turboquant import TurboQuantKVCache
 from mlx_lm.utils import load
 
 HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
@@ -672,6 +673,132 @@ def test_window_mask_with_full_kv_cache(self):
         expected = create_causal_mask(1, offset=32, window_size=4)
         self.assertTrue(mx.array_equal(mask, expected))
 
+    def test_turboquant_cache_basic(self):
+        """Test TurboQuantKVCache update_and_fetch, offset, shapes."""
+        for bits in [2, 3, 4]:
+            c = TurboQuantKVCache(bits=bits)
+            self.assertTrue(c.empty())
+            self.assertEqual(c.size(), 0)
+
+            # Prefill
+            k = mx.random.uniform(shape=(1, 8, 10, 64))
+            v = mx.random.uniform(shape=(1, 8, 10, 64))
+            dk, dv = c.update_and_fetch(k, v)
+            mx.eval(dk, dv)
+
+            self.assertEqual(c.size(), 10)
+            self.assertFalse(c.empty())
+            self.assertEqual(dk.shape, (1, 8, 10, 64))
+            self.assertEqual(dv.shape, (1, 8, 10, 64))
+
+            # Decode step
+            k2 = mx.random.uniform(shape=(1, 8, 1, 64))
+            v2 = mx.random.uniform(shape=(1, 8, 1, 64))
+            dk2, dv2 = c.update_and_fetch(k2, v2)
+            mx.eval(dk2, dv2)
+
+            self.assertEqual(c.size(), 11)
+            self.assertEqual(dk2.shape, (1, 8, 11, 64))
+
+    def test_turboquant_cache_quality(self):
+        """Test that TurboQuant dequantized output is close to input."""
+        c = TurboQuantKVCache(bits=4)
+        k = mx.random.normal(shape=(1, 8, 32, 128))
+        v = mx.random.normal(shape=(1, 8, 32, 128))
+        dk, dv = c.update_and_fetch(k, v)
+        mx.eval(dk, dv, k, v)
+
+        # Cosine similarity per vector should be high at 4-bit
+        cos_k = mx.mean(
+            mx.sum(k * dk, axis=-1)
+            / (mx.linalg.norm(k, axis=-1) * mx.linalg.norm(dk, axis=-1) + 1e-8)
+        )
+        mx.eval(cos_k)
+        self.assertGreater(float(cos_k), 0.95)
+
+    def test_turboquant_cache_trim(self):
+        """Test TurboQuantKVCache trim."""
+        c = TurboQuantKVCache(bits=3)
+        k = mx.random.uniform(shape=(1, 4, 8, 64))
+        c.update_and_fetch(k, k)
+        self.assertEqual(c.size(), 8)
+        self.assertTrue(c.is_trimmable())
+
+        trimmed = c.trim(3)
+        self.assertEqual(trimmed, 3)
+        self.assertEqual(c.size(), 5)
+
+    def test_turboquant_save_load(self):
+        """Test TurboQuantKVCache save/load roundtrip."""
+        cache_file = os.path.join(self.test_dir, "turbo_cache.safetensors")
+
+        cache = [TurboQuantKVCache(bits=3) for _ in range(4)]
+        for c in cache:
+            x = mx.random.uniform(shape=(1, 8, 10, 64))
+            c.update_and_fetch(x, x)
+
+        save_prompt_cache(cache_file, cache)
+        loaded_cache = load_prompt_cache(cache_file)
+        self.assertEqual(len(cache), len(loaded_cache))
+
+        for c, lc in zip(cache, loaded_cache):
+            self.assertEqual(c.offset, lc.offset)
+            self.assertEqual(c.turbo_bits, lc.turbo_bits)
+
+    def test_turboquant_to_turboquant(self):
+        """Test KVCache.to_turboquant conversion."""
+        kv_cache = KVCache()
+        k = mx.random.normal(shape=(1, 8, 16, 128))
+        v = mx.random.normal(shape=(1, 8, 16, 128))
+        kv_cache.update_and_fetch(k, v)
+
+        tq_cache = kv_cache.to_turboquant(bits=4)
+        self.assertEqual(tq_cache.size(), 16)
+        self.assertFalse(tq_cache.empty())
+
+    def test_turboquant_with_model(self):
+        """Test TurboQuantKVCache with actual model generation."""
+        num_layers = len(self.model.layers)
+        args = self.model.args
+        head_dim = getattr(
+            args, "head_dim", args.hidden_size // args.num_attention_heads
+        )
+
+        # FP16 baseline
+        fp16_cache = [KVCache() for _ in range(num_layers)]
+        prompt = mx.array([[1, 2, 3, 4, 5]])
+        logits_fp16 = self.model(prompt, cache=fp16_cache)
+        mx.eval(logits_fp16)
+
+        # TurboQuant
+        tq_cache = [TurboQuantKVCache(bits=4) for _ in range(num_layers)]
+        logits_tq = self.model(prompt, cache=tq_cache)
+        mx.eval(logits_tq)
+
+        self.assertEqual(logits_fp16.shape, logits_tq.shape)
+
+        # Logit cosine similarity should be reasonable
+        l1 = logits_fp16[0, -1].astype(mx.float32)
+        l2 = logits_tq[0, -1].astype(mx.float32)
+        cos = float(
+            mx.sum(l1 * l2) / (mx.linalg.norm(l1) * mx.linalg.norm(l2) + 1e-8)
+        )
+        self.assertGreater(cos, 0.9)
+
+    def test_turboquant_nbytes(self):
+        """Test TurboQuantKVCache memory is less than FP16."""
+        kv = KVCache()
+        tq = TurboQuantKVCache(bits=3)
+
+        k = mx.random.uniform(shape=(1, 8, 256, 128))
+        v = mx.random.uniform(shape=(1, 8, 256, 128))
+
+        kv.update_and_fetch(k, v)
+        tq.update_and_fetch(k, v)
+        mx.eval(kv.keys, tq._k_indices)
+
+        self.assertLess(tq.nbytes, kv.nbytes)
+
 
 if __name__ == "__main__":
     unittest.main()