Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ Ratio: 1.000 exactly
- **NumPy** >= 1.24, **SciPy** >= 1.10
- **cmake** + C/C++ compiler (for llama.cpp build)
- **Xcode Command Line Tools** (macOS Metal build)
- **Optional**: `mlx` for the Apple Silicon MLX backend prototype
- **Optional**: `torch`, `transformers`, `accelerate` (~4GB download, for real model validation)

### Install the Python Prototype
Expand All @@ -330,6 +331,12 @@ pip install -e ".[dev]"
python3 -m pytest tests/ -v
```

On Apple Silicon, install the optional MLX backend dependencies with:

```bash
pip install -e ".[dev,mlx]"
```

### Run the Demo

```bash
Expand Down Expand Up @@ -450,6 +457,7 @@ turboquant/
tests/ # 14 test files, 500+ tests
benchmarks/
├── demo.py # Quick compression demo
├── benchmark_mlx_backend.py # NumPy vs MLX backend benchmark
├── run_benchmark.py # Server-based benchmark runner
├── benchmark_results.md # Full benchmark report
├── benchmark_llama.sh # llama.cpp benchmark script
Expand Down
96 changes: 96 additions & 0 deletions benchmarks/benchmark_mlx_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Benchmark the optional MLX backend against the NumPy prototype.

Usage:
python3 benchmarks/benchmark_mlx_backend.py
"""

from __future__ import annotations

import time

import numpy as np

from turboquant import KVCacheCompressor, MLXKVCacheCompressor, MLX_AVAILABLE
from turboquant.mlx_backend import to_numpy


# KV-cache geometry used to size the synthetic benchmark tensors.
# Only num_layers / num_kv_heads / head_dim affect tensor shapes; "name" is
# display-only.
QWEN_27B = {
    "name": "Qwen 3.5 27B (dense)",
    "num_layers": 28,
    "num_kv_heads": 8,
    "head_dim": 128,
}

# NOTE(review): this MoE config is numerically identical to the dense config
# above — confirm the 35B-A3B model really shares the same KV-cache geometry,
# otherwise the two benchmark rows measure the same workload twice.
QWEN_MOE = {
    "name": "Qwen 3.5 35B-A3B (MoE)",
    "num_layers": 28,
    "num_kv_heads": 8,
    "head_dim": 128,
}


def simulate_kv_cache(config: dict, seq_len: int, seed: int = 42) -> tuple[np.ndarray, np.ndarray]:
    """Build a deterministic synthetic (K, V) cache pair for one model config.

    The tensors have shape (num_layers, num_kv_heads, seq_len, head_dim) and
    are scaled by 1/sqrt(head_dim) to mimic attention-scale activations.
    """
    gen = np.random.default_rng(seed)
    dims = (
        config["num_layers"],
        config["num_kv_heads"],
        seq_len,
        config["head_dim"],
    )
    attn_scale = 1.0 / np.sqrt(config["head_dim"])
    # Keys are drawn first, then values, so a fixed seed reproduces both.
    keys = gen.standard_normal(dims) * attn_scale
    values = gen.standard_normal(dims) * attn_scale
    return keys, values


def run_backend(name: str, compressor, k_cache: np.ndarray, v_cache: np.ndarray):
    """Time one compress/decompress round trip and measure reconstruction MSE.

    Returns (elapsed_seconds, k_mse, v_mse). For the "mlx" backend the lazily
    evaluated outputs are forced with mx.eval so the wall-clock time reflects
    the actual computation.
    """
    start = time.perf_counter()
    payload = compressor.compress(k_cache, v_cache)
    k_rec, v_rec = compressor.decompress(payload)

    if name == "mlx" and MLX_AVAILABLE:
        import mlx.core as mx

        # MLX builds a lazy graph; evaluate before stopping the clock.
        mx.eval(k_rec, v_rec)

    duration = time.perf_counter() - start

    # to_numpy is assumed to pass plain ndarrays through for the NumPy
    # backend — TODO confirm against turboquant.mlx_backend.
    k_arr = to_numpy(k_rec)
    v_arr = to_numpy(v_rec)

    k_mse = np.mean(np.square(k_cache - k_arr))
    v_mse = np.mean(np.square(v_cache - v_arr))
    return duration, k_mse, v_mse


def main():
    """Run the NumPy-vs-MLX compressor benchmark over two model configs."""
    if not MLX_AVAILABLE:
        raise SystemExit("MLX is not available. Install `mlx` on an MLX-supported system to run this benchmark.")

    banner = "=" * 70
    print(banner)
    print("TURBOQUANT MLX BACKEND BENCHMARK")
    print(banner)

    for cfg in (QWEN_27B, QWEN_MOE):
        print(f"\n{cfg['name']}")
        for length in (512, 2048):
            print(f"  seq_len={length}")
            k_cache, v_cache = simulate_kv_cache(cfg, length)

            baseline = KVCacheCompressor(head_dim=cfg["head_dim"], k_bits=3, v_bits=3)
            accelerated = MLXKVCacheCompressor(
                head_dim=cfg["head_dim"],
                k_bits=3,
                v_bits=3,
                dtype="float32",
            )

            np_time, np_kmse, np_vmse = run_backend("numpy", baseline, k_cache, v_cache)
            mx_time, mx_kmse, mx_vmse = run_backend("mlx", accelerated, k_cache, v_cache)

            print(
                f"    NumPy: {np_time:.2f}s, "
                f"K MSE={np_kmse:.8f}, V MSE={np_vmse:.8f}"
            )
            print(
                f"    MLX:   {mx_time:.2f}s, "
                f"K MSE={mx_kmse:.8f}, V MSE={mx_vmse:.8f}"
            )


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ dev = [
"pytest>=7.0",
"pytest-cov>=4.0",
]
mlx = [
"mlx>=0.31",
]
bench = [
"matplotlib",
]
Expand Down
137 changes: 137 additions & 0 deletions tests/test_mlx_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""Parity tests for the optional MLX backend."""

from __future__ import annotations

import numpy as np
import pytest

from turboquant.kv_cache import KVCacheCompressor
from turboquant.mlx_backend import (
MLXKVCacheCompressor,
MLXPolarQuant,
MLXTurboQuant,
MLX_AVAILABLE,
to_numpy,
)
from turboquant.polar_quant import PolarQuant
from turboquant.turboquant import TurboQuant


mlx = pytest.importorskip("mlx.core", reason="MLX backend tests require the optional mlx package")


@pytest.mark.skipif(not MLX_AVAILABLE, reason="MLX is not available")
class TestMLXPolarQuant:
    """MLXPolarQuant must reproduce the NumPy PolarQuant reference exactly."""

    def test_quantize_matches_numpy(self):
        dim = 128
        gen = np.random.default_rng(7)
        data = gen.standard_normal((8, dim))

        reference = PolarQuant(d=dim, bit_width=3, seed=42, norm_correction=True)
        candidate = MLXPolarQuant(d=dim, bit_width=3, seed=42, norm_correction=True, dtype="float64")

        ref_idx, ref_norms = reference.quantize(data)
        cand_idx, cand_norms = candidate.quantize(data)

        # Indices must match bit-for-bit; norms only up to float64 rounding.
        np.testing.assert_array_equal(to_numpy(cand_idx), ref_idx)
        np.testing.assert_allclose(to_numpy(cand_norms), ref_norms, atol=1e-10)

        ref_recon = reference.dequantize(ref_idx, ref_norms)
        cand_recon = to_numpy(candidate.dequantize(cand_idx, cand_norms))
        np.testing.assert_allclose(cand_recon, ref_recon, atol=1e-8)

    def test_single_vector_round_trip(self):
        dim = 64
        gen = np.random.default_rng(11)
        vec = gen.standard_normal(dim)

        quantizer = MLXPolarQuant(d=dim, bit_width=2, seed=99, dtype="float64")
        codes, norms = quantizer.quantize(vec)
        recon = to_numpy(quantizer.dequantize(codes, norms))

        assert recon.shape == vec.shape
        # Reconstruction must beat the trivial all-zeros predictor.
        assert np.mean((vec - recon) ** 2) < np.mean(vec ** 2)


@pytest.mark.skipif(not MLX_AVAILABLE, reason="MLX is not available")
class TestMLXTurboQuant:
    """MLXTurboQuant must agree with the NumPy TurboQuant reference."""

    def test_turboquant_matches_numpy(self):
        dim = 128
        gen = np.random.default_rng(21)
        data = gen.standard_normal((6, dim))

        reference = TurboQuant(d=dim, bit_width=3, seed=42, norm_correction=True)
        candidate = MLXTurboQuant(d=dim, bit_width=3, seed=42, norm_correction=True, dtype="float64")

        ref_pack = reference.quantize(data)
        cand_pack = candidate.quantize(data)

        # Discrete fields match exactly; continuous fields up to fp rounding.
        np.testing.assert_array_equal(to_numpy(cand_pack.mse_indices), ref_pack.mse_indices)
        np.testing.assert_array_equal(to_numpy(cand_pack.qjl_signs), ref_pack.qjl_signs)
        np.testing.assert_allclose(to_numpy(cand_pack.vector_norms), ref_pack.vector_norms, atol=1e-10)
        np.testing.assert_allclose(
            to_numpy(cand_pack.residual_norms),
            ref_pack.residual_norms,
            atol=1e-10,
        )

        ref_recon = reference.dequantize(ref_pack)
        cand_recon = to_numpy(candidate.dequantize(cand_pack))
        np.testing.assert_allclose(cand_recon, ref_recon, atol=1e-8)


@pytest.mark.skipif(not MLX_AVAILABLE, reason="MLX is not available")
class TestMLXKVCache:
    """End-to-end parity and quality checks for MLXKVCacheCompressor."""

    def test_kv_cache_matches_numpy(self):
        gen = np.random.default_rng(42)
        keys = gen.standard_normal((2, 3, 8, 64))
        values = gen.standard_normal((2, 3, 8, 64))

        reference = KVCacheCompressor(head_dim=64, k_bits=3, v_bits=3)
        candidate = MLXKVCacheCompressor(head_dim=64, k_bits=3, v_bits=3, dtype="float64")

        ref_k, ref_v = reference.decompress(reference.compress(keys, values))
        cand_k, cand_v = candidate.decompress(candidate.compress(keys, values))

        np.testing.assert_allclose(to_numpy(cand_k), ref_k, atol=1e-8)
        np.testing.assert_allclose(to_numpy(cand_v), ref_v, atol=1e-8)

    def test_attention_output_remains_reasonable(self):
        head_dim = 64
        seq_len = 16
        gen = np.random.default_rng(123)

        query = gen.standard_normal((1, head_dim))
        keys = gen.standard_normal((seq_len, head_dim))
        values = gen.standard_normal((seq_len, head_dim))

        compressor = MLXKVCacheCompressor(head_dim=head_dim, k_bits=3, v_bits=3, dtype="float32")
        packed = compressor.compress(
            keys[np.newaxis, np.newaxis, :, :],
            values[np.newaxis, np.newaxis, :, :],
        )
        k_rec, v_rec = compressor.decompress(packed)
        k_rec = to_numpy(k_rec)[0, 0]
        v_rec = to_numpy(v_rec)[0, 0]

        def attend(k, v):
            scores = query @ k.T / np.sqrt(head_dim)
            return _softmax(scores) @ v

        out_ref = attend(keys, values)
        out_cmp = attend(k_rec, v_rec)

        # Compressed attention output should point roughly the same way.
        cosine = np.dot(out_ref.ravel(), out_cmp.ravel()) / (
            np.linalg.norm(out_ref) * np.linalg.norm(out_cmp)
        )
        assert cosine > 0.5


def _softmax(x):
e = np.exp(x - np.max(x, axis=-1, keepdims=True))
return e / np.sum(e, axis=-1, keepdims=True)
23 changes: 22 additions & 1 deletion turboquant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,26 @@
from turboquant.qjl import QJL
from turboquant.turboquant import TurboQuant, TurboQuantMSE, CompressedVector
from turboquant.kv_cache import KVCacheCompressor
from turboquant.mlx_backend import (
MLX_AVAILABLE,
MLXKVCacheCompressor,
MLXPolarQuant,
MLXQJL,
MLXTurboQuant,
MLXTurboQuantMSE,
)

__all__ = ["PolarQuant", "QJL", "TurboQuant", "TurboQuantMSE", "CompressedVector", "KVCacheCompressor"]
__all__ = [
"PolarQuant",
"QJL",
"TurboQuant",
"TurboQuantMSE",
"CompressedVector",
"KVCacheCompressor",
"MLX_AVAILABLE",
"MLXPolarQuant",
"MLXQJL",
"MLXTurboQuant",
"MLXTurboQuantMSE",
"MLXKVCacheCompressor",
]
Loading