Fuse gate/up expert projections in SwitchGLU#1032
Thump604 wants to merge 3 commits into ml-explore:main
Conversation
Add a `fuse_gate_up` option to `SwitchGLU` that uses a single `gather_qmm` call with 2x hidden_dims instead of two separate calls for gate_proj and up_proj. This eliminates one kernel dispatch per MoE layer per token. Measured +5% tok/s on Qwen3.5-122B, +8.6% on Qwen3-30B, +5.1% on MiniMax M2.5, +3.8% on OLMoE (ref: ml-explore#956). Models updated: Qwen3, Qwen3.5 (all variants), Llama 4, Mixtral, MiniMax, OLMoE.
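The core idea can be sketched numerically: concatenating the gate and up weight matrices lets one matmul replace two, and the output is split afterward. This is a minimal NumPy stand-in for the `gather_qmm` path (shapes and variable names are illustrative, not the PR's actual code):

```python
import numpy as np

rng = np.random.default_rng(0)
dims, hidden = 64, 256
x = rng.standard_normal((1, dims))
w_gate = rng.standard_normal((hidden, dims))
w_up = rng.standard_normal((hidden, dims))

# Unfused: two projections, two kernel dispatches.
gate, up = x @ w_gate.T, x @ w_up.T

# Fused: one (2*hidden, dims) weight, one dispatch, then split the output.
w_fused = np.concatenate([w_gate, w_up], axis=0)
gate_f, up_f = np.split(x @ w_fused.T, 2, axis=-1)

assert np.allclose(gate, gate_f) and np.allclose(up, up_f)
```

The results are bitwise-equivalent up to floating-point tolerance; the win is purely in dispatch overhead, which is why the speedup scales with the number of MoE layers invoked per token.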
Force-pushed from 4ccf76f to 81270f9
**Update: Metal OOM on memory-constrained setup**

Testing on M2 Ultra 128GB with Qwen3.5-122B-A10B (5-bit, ~82GB weights), the fused path hits a Metal out-of-memory error. The 122B model leaves only ~46GB headroom for KV cache + Metal scratch + OS, and the single larger `gather_qmm` call needs a bigger scratch allocation than two smaller calls, between which Metal can deallocate scratch. The original benchmarks in #956 were on an M3 Ultra 512GB where scratch memory isn't a constraint. This suggests the optimization needs a memory-aware fallback: perhaps fuse only when headroom is sufficient, or cap the fused dimension. Will investigate whether making fusion configurable resolves this.
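A "fuse only when headroom is sufficient" policy could be sketched as a simple check on available memory. The threshold below is a hypothetical illustration, not logic from the PR; only the machine sizes and the ~82GB weight figure come from the comment above:

```python
def should_fuse(total_gb: float, weights_gb: float, min_headroom_gb: float = 64.0) -> bool:
    """Illustrative memory-aware fallback: fuse gate/up projections only
    when enough headroom remains for KV cache, Metal scratch, and the OS.
    The 64GB threshold is an assumed placeholder."""
    return total_gb - weights_gb >= min_headroom_gb

# M2 Ultra 128GB with ~82GB of weights leaves ~46GB: fall back to unfused.
assert should_fuse(128, 82) is False
# M3 Ultra 512GB (the #956 benchmark machine): plenty of headroom, fuse.
assert should_fuse(512, 82) is True
```

A real implementation would query the actual Metal memory budget at load time rather than hard-code a threshold; this only shows the shape of the decision.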
Read `fuse_gate_up` from the model config (default True) instead of hardcoding it. On memory-constrained systems (e.g. 122B on 128GB), set `"fuse_gate_up": false` in config.json to use separate gate/up projections, which allow Metal to deallocate scratch between kernels. Also fixes `sanitize()` in qwen3_next and qwen3_moe, which produced unfused weight names while the model expected fused names, a mismatch that would break loading from HuggingFace format.
**Update: OOM fix implemented**

Added a config-driven `fuse_gate_up` flag read from the model config, so memory-constrained setups can disable fusion. All 10 model files now support both paths, controlled by a single config field.
I tried this on some of my existing MiniMax DWQ quants (e.g. catalystsec/MiniMax-M2.5-4bit-DWQ) but unfortunately they no longer work with these changes, so you may want to look at the impact on existing (quantized) models.
Existing quantized models (e.g. MiniMax DWQ quants) have separate `gate_proj`/`up_proj` weights. Defaulting to `True` broke them: the model init creates `gate_up_proj` while the checkpoint has unfused names, and quantized weights cannot be concatenated (scales/biases mismatch). Default to `False` so existing models work unchanged. New conversions can opt in by setting `"fuse_gate_up": true` in config.json.
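The opt-in read itself is a one-liner; the important part is the default. A minimal sketch, assuming a plain dict config (the helper name is hypothetical, only the `fuse_gate_up` field comes from the PR):

```python
def read_fuse_flag(config: dict) -> bool:
    # Default False: existing quantized checkpoints with separate
    # gate_proj/up_proj weights keep loading unchanged, and only new
    # conversions that set "fuse_gate_up": true take the fused path.
    return bool(config.get("fuse_gate_up", False))

assert read_fuse_flag({}) is False                      # existing models: unchanged
assert read_fuse_flag({"fuse_gate_up": True}) is True   # new conversions: opt in
```

Reading the flag with a conservative default keeps the change backward compatible without any version checks in the loader.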
Thanks @kernelpool for the report! Fixed in e7de34b. The issue was that existing quantized checkpoints store separate `gate_proj`/`up_proj` weights, which cannot be concatenated after quantization (scales/biases mismatch). Fix: `fuse_gate_up` now defaults to `False`, so existing models load unchanged; new conversions can opt in via config.json.
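The breakage boils down to a weight-name mismatch at load time. This sketch shows the shape of a `sanitize` that fuses float weights but refuses already-quantized ones; the weight key names mirror the PR's, while the function body is an illustrative assumption, not the PR's code:

```python
import numpy as np

def sanitize(weights: dict, fuse_gate_up: bool) -> dict:
    """Fuse gate_proj/up_proj into gate_up_proj for float checkpoints.
    Quantized weights carry per-group scales/biases, so two independently
    quantized matrices cannot simply be concatenated; refuse in that case."""
    if not fuse_gate_up:
        return weights
    if any(k.endswith("gate_proj.scales") for k in weights):
        raise ValueError("cannot fuse already-quantized gate/up weights")
    out = {k: v for k, v in weights.items()
           if "gate_proj" not in k and "up_proj" not in k}
    for k in list(weights):
        if k.endswith("gate_proj.weight"):
            base = k[: -len("gate_proj.weight")]
            out[base + "gate_up_proj.weight"] = np.concatenate(
                [weights[k], weights[base + "up_proj.weight"]], axis=0
            )
    return out

w = {"experts.gate_proj.weight": np.ones((4, 2)),
     "experts.up_proj.weight": np.zeros((4, 2))}
fused = sanitize(w, fuse_gate_up=True)
assert fused["experts.gate_up_proj.weight"].shape == (8, 2)
```

With the default now `False`, existing quantized checkpoints never reach the concatenation path at all, which is the actual fix.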
@angeloskath @awni: this PR has been open since March 20 with no maintainer review. A community member (@kernelpool) found a breaking-change issue with existing quantized models, which we fixed the same day; the default is now `False`, so existing models are unaffected. Is there a concern with the approach? Happy to rework if needed. The fused path gives ~5% speedup on SwitchGLU MoE models for users who opt in via config.
Summary
Add a `fuse_gate_up` option to `SwitchGLU` that performs a single `gather_qmm` call with concatenated gate+up weights instead of two separate calls. Eliminates one kernel dispatch per MoE layer per token.

Follows the approach proposed in #956 by @BurntToastGPT: handle fusion at the model layer via `sanitize`.
sanitize.Measured Results (from #956)
Changes (10 files, 6 model families)
Core:

- `switch_layers.py`: `SwitchGLU` gets a `fuse_gate_up=False` parameter. When True, creates a single `gate_up_proj` SwitchLinear (2x hidden_dims). Forward pass auto-detects via `"gate_up_proj" in self`.

Models with already-fused checkpoint weights (stop splitting):
- `qwen3_5_moe.py`, `qwen3_vl_moe.py`: Sanitize keeps fused weights as `gate_up_proj.weight`
- `llama4.py`: Same pattern (contiguous split + swapaxes)

Models with per-expert weights (stack + concatenate):
- `olmoe.py`: Stack per-expert gate/up, concatenate into `gate_up_proj` (handles quantized weights/scales/biases)
- `mixtral.py`: Same pattern with w1/w2/w3 naming
- `minimax.py`: Same pattern (FP8 dequant then fuse)

Constructor updates:
- `qwen3_next.py`, `qwen3_moe.py`: Pass `fuse_gate_up=True`

Sharding:
- `qwen3_5.py`: Handles both fused and unfused paths
Backward compatible: `fuse_gate_up=False` (default) preserves existing behavior for all other models using SwitchGLU. GPT-OSS excluded (interleaved weights need separate handling).

Test plan