Metal kernels for Flash Linear Attention on Apple Silicon (MPS). Drop-in replacement for causal-conv1d and flash-linear-attention — both CUDA-only — enabling the fast path for models like Qwen3.5 on Mac.
```
pip install mps-linear-attn
```

Or from source:

```
git clone https://github.com/mpsops/mps-linear-attention
cd mps-linear-attention
pip install -e . --no-build-isolation
```

Import before loading your model — that's it:
```python
import mps_linear_attn  # patches causal_conv1d + fla on import

from transformers import Qwen3_5ForConditionalGeneration, AutoTokenizer
import torch

model = Qwen3_5ForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3.5-0.8B", dtype=torch.float16
).to("mps")
```

No code changes needed. The fast path activates automatically.
Transformers checks for causal_conv1d and flash-linear-attention at import time to enable the fast path. Both require CUDA and fail silently on MPS, forcing a slow fallback.
This package:

- Monkey-patches `sys.modules` with MPS-compatible implementations
- Overrides `is_causal_conv1d_available()` and `is_flash_linear_attention_available()` to return `True`
- Provides Metal GPU kernels for the hot paths
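The `sys.modules` technique can be sketched in a few lines. The module and function names below are illustrative, not the package's actual internals; the real package registers MPS-backed shims for `causal_conv1d` and the `fla` submodules:

```python
import sys
import types

# Hypothetical stand-in for a CUDA-only dependency.
shim = types.ModuleType("cuda_only_lib")
shim.fast_op = lambda x: x * 2  # a real shim would dispatch to a Metal kernel

# Once registered here, any later `import cuda_only_lib` resolves to the
# shim instead of raising ImportError on a machine without CUDA.
sys.modules["cuda_only_lib"] = shim

import cuda_only_lib
print(cuda_only_lib.fast_op(21))  # 42
```

Because Transformers probes for these modules by name at import time, registering the shims before the model loads is what flips the availability checks to `True`.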
| Kernel | Description |
|---|---|
| `causal_conv1d_fwd` | Depthwise causal 1D conv + optional SiLU, fp16/fp32 |
| `causal_conv1d_update` | Sliding-window state update + conv + SiLU (decode) |
| `delta_rule_recurrent_step` | Single-token gated delta rule state update |
| `delta_rule_recurrent_fused` | Fused T-step recurrent loop (prefill), single dispatch |
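As a reference for what `causal_conv1d_fwd` computes, here is a plain NumPy sketch of a depthwise causal 1D convolution with optional SiLU — a readable model of the operation, not the Metal source:

```python
import numpy as np

def causal_conv1d_ref(x, w, bias=None, activation="silu"):
    """Depthwise causal conv: output at time t sees only inputs t-width+1..t.

    x: (channels, seqlen) input; w: (channels, width) per-channel filters.
    """
    channels, seqlen = x.shape
    width = w.shape[1]
    # Left-pad with zeros so the convolution never looks into the future.
    xp = np.concatenate([np.zeros((channels, width - 1)), x], axis=1)
    out = np.empty((channels, seqlen))
    for t in range(seqlen):
        out[:, t] = (xp[:, t:t + width] * w).sum(axis=1)
    if bias is not None:
        out += bias[:, None]
    if activation == "silu":
        out = out * (1.0 / (1.0 + np.exp(-out)))  # SiLU = x * sigmoid(x)
    return out
```

With a width-3 all-ones filter and a constant input of 1 (activation disabled), the output ramps 1, 2, 3, 3, … as the causal window fills.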
State is always kept in float32 for numerical stability. Q/K/V/beta are float16.
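The recurrence behind `delta_rule_recurrent_step` can be written in a few lines of NumPy. This is a sketch of the standard gated delta rule with the state held in float32 as noted above; the argument names are illustrative, not the kernel's signature:

```python
import numpy as np

def gated_delta_step(S, q, k, v, beta, g):
    """One token of the gated delta rule (reference sketch).

    S: (d_k, d_v) float32 recurrent state
    q, k: (d_k,) query / key; v: (d_v,) value
    beta: write strength in [0, 1]; g: decay gate in [0, 1]
    """
    S = g * S                              # decay the old state
    pred = k @ S                           # what the state currently predicts for k
    S = S + beta * np.outer(k, v - pred)   # delta-rule correction toward v
    o = q @ S                              # read out with the query
    return S, o
```

The fused prefill kernel runs this loop over all T tokens of the prompt in a single dispatch instead of one dispatch per token.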
Tested on Qwen3.5-0.8B, M1, macOS 26, PyTorch 2.7:
| Configuration | tok/s |
|---|---|
| Baseline (no fast path) | ~16 |
| `mps-linear-attn` | ~30 |
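To reproduce numbers like these, a simple wall-clock throughput helper is enough. This helper is hypothetical (not shipped with the package); pass it a closure that generates `n_tokens` with your model:

```python
import time

def tokens_per_second(generate, n_tokens, warmup=1, runs=3):
    """Estimate decode throughput as n_tokens / best wall-clock time."""
    for _ in range(warmup):   # warm up kernel compilation and caches
        generate(n_tokens)
    best = float("inf")
    for _ in range(runs):
        t0 = time.perf_counter()
        generate(n_tokens)
        best = min(best, time.perf_counter() - t0)
    return n_tokens / best
```

On MPS, end the `generate` closure with `torch.mps.synchronize()` so queued GPU work is included in the measured time.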
Supported models:

- Qwen3.5 (all sizes) — uses GatedDeltaNet linear attention layers
- Any model using `fla.ops.gated_delta_rule` or `causal_conv1d`
Requirements:

- macOS (Apple Silicon — M1 or later)
- PyTorch >= 2.0 with MPS support
- Xcode Command Line Tools