
Commit a491b09

xyang16 and jeejeelee authored
[LoRA] Support FusedMoE LoRA Triton kernel for mxfp4 (#29708)
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: Xin Yang <105740670+xyang16@users.noreply.github.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent b9d0504 commit a491b09

File tree: 4 files changed, +439 / -11 lines changed
New test file · Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test modular OAI Triton MoE
"""

import pytest
import torch

from vllm.utils.import_utils import has_triton_kernels

if not has_triton_kernels():
    pytest.skip(
        "triton_kernels not found, skipping all related tests",
        allow_module_level=True,
    )

from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    OAITritonExperts,
    UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.platforms import current_platform

MNK = [
    (1, 512, 384),
    (1, 2880, 2880),
    (2, 512, 384),
    (2, 2880, 2880),
    (16, 2880, 2880),
]


def unshuffle_weight(w: torch.Tensor):
    first = w[..., ::2]
    second = w[..., 1::2]
    return torch.concat((first, second), dim=-1)


def make_weights(dtype, k, n, e):
    w1 = torch.randn((e, k, 2 * n), dtype=dtype, device="cuda")
    w1_bias = torch.randn((e, 2 * n), dtype=dtype, device="cuda")

    w2 = torch.randn((e, n, k), dtype=dtype, device="cuda")
    w2_bias = torch.randn((e, k), dtype=dtype, device="cuda")

    w1_tri = w1.clone()
    w2_tri = w2.clone()

    w1_bias_tri = w1_bias.clone()
    w2_bias_tri = w2_bias.clone()
    w1_bias_tri = w1_bias_tri.to(torch.float32)
    w2_bias_tri = w2_bias_tri.to(torch.float32)

    # shuffle weights
    w1_tri = shuffle_weight(w1_tri)
    w1_bias_tri = shuffle_weight(w1_bias_tri)

    # quant triton_weights
    w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
    w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
    w1 = unshuffle_weight(w1)

    w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
    w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)

    num_warps = 8
    w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
    w_scale_layout, w_scale_layout_opts = (
        layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
    )

    w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts)
    w1_scale_tri = convert_layout(
        wrap_torch_tensor(w1_scale_tri),
        w_scale_layout,
        **w_scale_layout_opts,
    )

    w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts)
    w2_scale_tri = convert_layout(
        wrap_torch_tensor(w2_scale_tri),
        w_scale_layout,
        **w_scale_layout_opts,
    )

    w1_precision_config = PrecisionConfig(
        weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
    )
    w2_precision_config = PrecisionConfig(
        weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
    )

    return (
        w1,
        w2,
        w1_bias,
        w2_bias,
        w1_tri,
        w2_tri,
        w1_bias_tri,
        w2_bias_tri,
        w1_precision_config,
        w2_precision_config,
    )


def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
    # Note we add an extra bias of 1 to the linear layer
    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        x_glu = x_glu.clamp(max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    if limit is not None:
        x_linear = x_linear.clamp(min=-limit, max=limit)
    return out_glu * (x_linear + 1)


def torch_moe_impl(
    hidden_states: torch.Tensor,  # (M, K)
    w1: torch.Tensor,  # (E, K, 2N)
    w2: torch.Tensor,  # (E, N, K)
    w1_bias: torch.Tensor,  # (E, 2N)
    w2_bias: torch.Tensor,  # (E, K)
    topk_weights: torch.Tensor,  # (M, topk)
    topk_ids: torch.Tensor,  # (M, topk)
):
    w1 = w1[topk_ids, ...]
    w1_bias = w1_bias[topk_ids, ...]
    hidden_states = torch.einsum("bekc,bk->bec", w1, hidden_states) + w1_bias
    hidden_states = swiglu(hidden_states, limit=7)

    w2 = w2[topk_ids, ...]
    w2_bias = w2_bias[topk_ids, ...]
    hidden_states = torch.einsum("bekc,bek->bec", w2, hidden_states) + w2_bias

    # Weighted sum of experts
    hidden_states = torch.einsum("bec,be->bc", hidden_states, topk_weights)
    return hidden_states


def oai_triton_moe_impl(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: "PrecisionConfig",
    w2_scale: "PrecisionConfig",
    w1_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
    num_experts: int,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    unfused: bool = False,
) -> torch.Tensor:
    quant_config = mxfp4_w4a16_moe_quant_config(
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )

    if unfused:
        fused_experts = UnfusedOAITritonExperts(quant_config)
    else:
        fused_experts = OAITritonExperts(quant_config)

    mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)

    return mk.forward(
        hidden_states=x,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation="swigluoai",
        global_num_experts=num_experts,
        expert_map=None,
        apply_router_weight_on_input=False,
    )


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("m,n,k", MNK)
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [4])
@pytest.mark.parametrize("unfused", [True, False])
def test_oai_triton_moe(
    dtype: torch.dtype,
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    unfused: bool,
):
    current_platform.seed_everything(0)
    (
        w1,
        w2,
        w1_bias,
        w2_bias,
        w1_tri,
        w2_tri,
        w1_bias_tri,
        w2_bias_tri,
        w1_precision_config,
        w2_precision_config,
    ) = make_weights(dtype, k, n, num_experts)

    x = torch.randn((m, k), dtype=dtype, device="cuda")
    router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    with set_current_vllm_config(VllmConfig()):
        out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)

        out = oai_triton_moe_impl(
            x,
            w1_tri,
            w2_tri,
            w1_precision_config,
            w2_precision_config,
            w1_bias_tri,
            w2_bias_tri,
            num_experts,
            topk_weights,
            topk_ids,
            unfused,
        )

    assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
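
In make_weights above, the triton-side copy of w1 is shuffled with shuffle_weight before mxfp4 quantization, while the torch reference is rebuilt by upcasting that quantized tensor and undoing the shuffle with unshuffle_weight, so both paths carry the same quantization error in their own column layout. A minimal standalone sketch of that round trip, assuming shuffle_weight interleaves the two halves of the last dimension column-wise; interleave_halves is a hypothetical stand-in, not the vLLM helper:

# Sketch only: interleave_halves assumes shuffle_weight interleaves the two
# halves of the last dim (e.g. gate/up columns); it is not the vLLM function.
import torch


def interleave_halves(w: torch.Tensor) -> torch.Tensor:
    first, second = torch.chunk(w, 2, dim=-1)
    out = torch.empty_like(w)
    out[..., 0::2] = first   # even columns <- first half
    out[..., 1::2] = second  # odd columns  <- second half
    return out


def unshuffle_weight(w: torch.Tensor) -> torch.Tensor:
    # Same de-interleave as the helper in the test above.
    return torch.concat((w[..., ::2], w[..., 1::2]), dim=-1)


w = torch.randn(4, 8, 16)
assert torch.equal(unshuffle_weight(interleave_halves(w)), w)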

vllm/lora/layers/fused_moe.py

Lines changed: 26 additions & 9 deletions
@@ -20,15 +20,24 @@
     _get_config_dtype_str,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    modular_marlin_fused_moe,
+    MarlinExperts,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    modular_triton_fused_moe,
+    TritonExperts,
     try_get_optimal_moe_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
     FusedMoEModularMethod,
 )
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    UnfusedOAITritonExperts,
+)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel,
+)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)

 from .utils import _get_lora_device

@@ -114,15 +123,23 @@ def _inject_lora_into_fused_moe(self):
         self.base_layer.ensure_moe_quant_config_init()
         quant_config = self.base_layer.quant_method.moe_quant_config

-        m_fused_moe_fn = (
-            modular_triton_fused_moe(
-                quant_config, shared_experts=self.base_layer.shared_experts
+        prepare_finalize = MoEPrepareAndFinalizeNoEP()
+        m_fused_moe_fn = FusedMoEModularKernel(
+            prepare_finalize,
+            self.base_layer.quant_method.select_gemm_impl(
+                prepare_finalize, self.base_layer
+            ),
+            self.base_layer.shared_experts,
+            getattr(self.base_layer, "shared_experts_stream", None),
+        )
+        if quant_config.use_mxfp4_w4a16:
+            assert isinstance(
+                m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
             )
-            if not quant_config.use_mxfp4_w4a16
-            else modular_marlin_fused_moe(
-                quant_config, shared_experts=self.base_layer.shared_experts
+        else:
+            assert isinstance(
+                m_fused_moe_fn.fused_experts, (MarlinExperts, TritonExperts)
             )
-        )

         def fwd_decorator(layer, func):
             def wrapper(*args, **kwargs):
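
The LoRA path above now builds the same modular composition that the new test exercises: a non-expert-parallel prepare/finalize stage combined with whichever experts implementation the layer's quant method selects. A minimal sketch of that construction, based only on the calls visible in this diff; the make_lora_moe_kernel name and its bare layer parameter are illustrative, not vLLM API:

# Illustrative wrapper around the construction shown in the diff above; the
# function name and its `layer` parameter are assumptions, not vLLM API.
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)


def make_lora_moe_kernel(layer) -> FusedMoEModularKernel:
    # MoEPrepareAndFinalizeNoEP covers the non-expert-parallel case.
    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    # The layer's quant method picks the experts implementation
    # (e.g. TritonExperts, MarlinExperts, UnfusedOAITritonExperts).
    experts = layer.quant_method.select_gemm_impl(prepare_finalize, layer)
    # Compose both stages, forwarding shared experts and their optional stream.
    return FusedMoEModularKernel(
        prepare_finalize,
        experts,
        layer.shared_experts,
        getattr(layer, "shared_experts_stream", None),
    )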
