
mps-linear-attention

Metal kernels for Flash Linear Attention on Apple Silicon (MPS). A drop-in replacement for causal-conv1d and flash-linear-attention (both CUDA-only), enabling the fast path for models like Qwen3.5 on Mac.

Install

pip install mps-linear-attn

Or from source:

git clone https://github.com/mpsops/mps-linear-attention
cd mps-linear-attention
pip install -e . --no-build-isolation

Usage

Import the package before loading your model; that's it:

import mps_linear_attn  # patches causal_conv1d + fla on import

from transformers import Qwen3_5ForConditionalGeneration, AutoTokenizer
import torch

model = Qwen3_5ForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3.5-0.8B", dtype=torch.float16
).to("mps")

No code changes needed. The fast path activates automatically.

What it does

Transformers checks for causal_conv1d and flash-linear-attention at import time to enable the fast path. Both require CUDA and fail silently on MPS, forcing a slow fallback.

This package:

  1. Monkey-patches sys.modules with MPS-compatible implementations
  2. Overrides is_causal_conv1d_available() and is_flash_linear_attention_available() to return True
  3. Provides Metal GPU kernels for the hot paths

Metal kernels

| Kernel | Description |
| --- | --- |
| causal_conv1d_fwd | Depthwise causal 1D conv + optional SiLU, fp16/fp32 |
| causal_conv1d_update | Sliding-window state update + conv + SiLU (decode) |
| delta_rule_recurrent_step | Single-token gated delta rule state update |
| delta_rule_recurrent_fused | Fused T-step recurrent loop (prefill), single dispatch |

State is always float32 for numerical precision. Q/K/V/beta are float16.
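For reference, the operation causal_conv1d_fwd computes can be sketched in NumPy (a reference implementation for illustration only; the function name, argument layout, and filter-tap ordering are assumptions, not the package API):

```python
import numpy as np

def causal_conv1d_ref(x, w, bias=None, silu=True):
    """Depthwise causal 1D convolution with optional SiLU.
    Illustrative reference, not the package API.
    x: (channels, seqlen); w: (channels, K); bias: (channels,)."""
    C, T = x.shape
    K = w.shape[1]
    xp = np.pad(x, ((0, 0), (K - 1, 0)))  # left-pad only: no look-ahead
    y = np.empty((C, T))
    for c in range(C):
        # depthwise: each channel sees only its own K-tap filter
        y[c] = np.correlate(xp[c], w[c], mode="valid")
    if bias is not None:
        y += bias[:, None]
    if silu:
        y = y / (1.0 + np.exp(-y))  # SiLU: y * sigmoid(y)
    return y
```

The left-only padding is what makes the conv causal: output t depends only on inputs up to t, which is why the decode-time variant can run on a small sliding window of past tokens.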

Performance

Tested on Qwen3.5-0.8B, M1, macOS 26, PyTorch 2.7:

| Configuration | tok/s |
| --- | --- |
| Baseline (no fast path) | ~16 |
| mps-linear-attn | ~30 |

Supported models

  • Qwen3.5 (all sizes), which uses GatedDeltaNet linear attention layers
  • Any model using fla.ops.gated_delta_rule or causal_conv1d
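For intuition, one commonly published formulation of the gated delta rule recurrence, as a NumPy sketch (the actual kernel may differ in details such as gating shape and normalization; the function name is illustrative):

```python
import numpy as np

def gated_delta_step(S, q, k, v, beta, g):
    """One recurrent step of the gated delta rule (illustrative sketch).
    S: (d_v, d_k) running state, kept in float32 as in the kernels;
    q, k: (d_k,); v: (d_v,); beta (write strength), g (decay): scalars."""
    S = g * S                             # gated decay of the state
    S = S - beta * np.outer(S @ k, k)     # delta rule: erase the old value stored under k
    S = S + beta * np.outer(v, k)         # write the new value under k
    o = S @ q                             # read out against the query
    return S, o
```

The erase-then-write structure is what distinguishes the delta rule from a plain linear-attention accumulator: rewriting the same key replaces its value instead of summing into it.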

Requirements

  • macOS on Apple Silicon (M1 or later)
  • PyTorch >= 2.0 with MPS support
  • Xcode Command Line Tools
