feat: add MLX backend prototype for PolarQuant, TurboQuant, and KV cache compression #52

Open
dipeshbabu wants to merge 1 commit into TheTom:main from dipeshbabu:feat/mlx-python-backend-prototype

Conversation

@dipeshbabu

This PR adds a first MLX backend for the Python TurboQuant prototype on Apple Silicon.

Scope is intentionally limited to the existing Python path:

  • add optional MLX-backed implementations of PolarQuant, QJL, TurboQuant, TurboQuantMSE, and KVCacheCompressor
  • export the MLX API from the package
  • add parity tests against the current NumPy implementation
  • add a small benchmark script for Apple Silicon KV-cache compression
  • document the optional mlx install path
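
For illustration, an optional backend like this is usually gated behind an import guard so the package still imports when MLX is missing. The sketch below is an assumption about how `MLX_AVAILABLE` and `to_numpy()` could be wired, not the code in this PR; `require_mlx` is a hypothetical helper name.

```python
import numpy as np

try:
    import mlx.core as mx  # only importable when MLX is installed
    MLX_AVAILABLE = True
except ImportError:
    mx = None
    MLX_AVAILABLE = False


def require_mlx() -> None:
    """Hypothetical helper: fail early with a clear message on MLX-only paths."""
    if not MLX_AVAILABLE:
        raise ImportError(
            "MLX is not installed; install the optional mlx extra to use this backend."
        )


def to_numpy(x) -> np.ndarray:
    """Copy an MLX array, NumPy array, or nested list into a NumPy array."""
    return np.array(x)
```

With this shape, NumPy-only users never touch `mlx`, and MLX classes can call `require_mlx()` in their constructors.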

Scope

The repo README and roadmap explicitly call out MLX support as a desired contribution, but a full runtime/inference integration would be too large for a first MLX PR.

This PR targets the part of the project that is already backend-shaped today:

  • vector rotation
  • norm extraction / rescaling
  • scalar centroid assignment
  • residual quantization
  • batch KV cache compression / decompression

That provides a clean phase-1 MLX contribution and a base for later Apple Silicon work.
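
As a rough picture of why these steps are backend-shaped, here is a minimal NumPy sketch of rotation, norm extraction, scalar centroid assignment, and reconstruction on a batch of KV-like vectors. Everything in it is an assumption for illustration: the function names are hypothetical, the uniform codebook is a placeholder rather than the codebook the project uses, and the residual stage is only indicated by computing the residual.

```python
import numpy as np


def random_rotation(dim: int, seed: int = 0) -> np.ndarray:
    """Random orthogonal matrix from the QR decomposition of a Gaussian matrix."""
    rng = np.random.default_rng(seed)
    q, r = np.linalg.qr(rng.standard_normal((dim, dim)))
    return q * np.sign(np.diag(r))  # sign fix keeps the matrix orthogonal


def make_centroids(dim: int, bits: int = 4) -> np.ndarray:
    """Placeholder codebook: uniform grid scaled to the spread of a rotated unit vector."""
    return np.linspace(-3.0, 3.0, 2 ** bits) / np.sqrt(dim)


def compress(x, rotation, centroids):
    """Extract norms, rotate the unit vectors, assign each coordinate to its nearest centroid."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    unit = x / np.maximum(norms, 1e-12)
    rotated = unit @ rotation.T
    idx = np.argmin(np.abs(rotated[..., None] - centroids), axis=-1)
    return idx.astype(np.uint8), norms


def decompress(idx, norms, rotation, centroids):
    """Look up centroids, undo the rotation, and rescale by the stored norms."""
    return (centroids[idx] @ rotation) * norms


def residual(x, x_hat):
    """What a second-stage (residual) quantizer would receive as input."""
    return x - x_hat
```

Each function is a handful of dense array operations with no Python-level loops, which is what makes a like-for-like port to another array library plausible.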

Changes

New

  • turboquant/mlx_backend.py

    • optional MLX backend
    • MLXPolarQuant
    • MLXQJL
    • MLXTurboQuant
    • MLXTurboQuantMSE
    • MLXKVCacheCompressor
    • MLX_AVAILABLE
    • to_numpy()
  • tests/test_mlx_backend.py

    • MLX vs NumPy parity tests
    • KV cache reconstruction parity
    • attention-output sanity check
    • skip cleanly when mlx is not installed
  • benchmarks/benchmark_mlx_backend.py

    • benchmark harness for NumPy vs MLX on Qwen-shaped KV tensors
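
A sketch of what an attention-output sanity check and a clean skip could look like, using only the standard library and NumPy. This is not the content of tests/test_mlx_backend.py: `toy_roundtrip` is a hypothetical stand-in for compress + decompress, the 0.5 threshold is an arbitrary loose bound, and `unittest` is used here only to keep the example dependency-free.

```python
import importlib.util
import unittest

import numpy as np

MLX_INSTALLED = importlib.util.find_spec("mlx") is not None


def attention(q, k, v):
    """Plain softmax attention over the last two axes."""
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def toy_roundtrip(x, bits=4):
    """Stand-in for compress + decompress: per-vector uniform quantization."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / (2 ** (bits - 1) - 1)
    return np.round(x / scale) * scale


class AttentionSanity(unittest.TestCase):
    def test_attention_output_close(self):
        rng = np.random.default_rng(0)
        q, k, v = (rng.standard_normal((2, 16, 64)) for _ in range(3))
        ref = attention(q, k, v)
        out = attention(q, toy_roundtrip(k), toy_roundtrip(v))
        rel = np.linalg.norm(out - ref) / np.linalg.norm(ref)
        self.assertLess(rel, 0.5)  # loose, illustrative threshold

    @unittest.skipUnless(MLX_INSTALLED, "mlx is not installed")
    def test_mlx_importable(self):
        import mlx.core as mx
        self.assertTrue(hasattr(mx, "array"))
```

The point of checking attention output rather than raw reconstruction error is that it measures the quantity downstream code consumes.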

Updated

  • turboquant/__init__.py

    • exports MLX backend symbols
  • pyproject.toml

    • adds optional mlx extra
  • README.md

    • documents optional MLX install and benchmark entry point

Validation

Validated locally in this environment with:

  • import checks
  • Python bytecode compilation
  • parity checks using a NumPy-backed shim for the MLX control flow

I could not run the real MLX backend in this environment because MLX is not available here, and pytest was not installed locally, so the included test file has not been executed yet. It is intended to run on an MLX-capable setup.
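
For readers wondering what a NumPy-backed shim means in practice, the following is one possible shape, written as a hedged sketch rather than the shim actually used for this PR. `numpy_shim`, `assign_centroids`, and `reference_assign` are hypothetical names. The shim exposes a few NumPy functions under the names array-library code would call, so the control flow can be exercised without MLX. It cannot confirm numerical behaviour of the real MLX backend.

```python
from types import SimpleNamespace

import numpy as np

# Hypothetical shim: a namespace that answers to a small set of array-function names.
numpy_shim = SimpleNamespace(
    array=np.asarray,
    abs=np.abs,
    argmin=np.argmin,
    sqrt=np.sqrt,
    sum=np.sum,
)


def assign_centroids(xp, x, centroids):
    """Backend-agnostic norm extraction and nearest-centroid assignment."""
    x = xp.array(x)
    centroids = xp.array(centroids)
    norms = xp.sqrt(xp.sum(x * x, axis=-1, keepdims=True))
    unit = x / norms
    idx = xp.argmin(xp.abs(unit[..., None] - centroids), axis=-1)
    return idx, norms


def reference_assign(x, centroids):
    """Direct NumPy reference used as the parity target."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    unit = x / norms
    return np.abs(unit[..., None] - centroids).argmin(axis=-1), norms
```

A parity check then calls `assign_centroids(numpy_shim, x, centroids)` and compares the result with `reference_assign(x, centroids)`.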

Notes

This is a Python prototype/backend contribution only.
Not included in this PR:

  • MLX-Swift integration
  • custom Metal kernels
  • end-to-end model/runtime wiring
  • fused attention integration
