|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +""" |
| 4 | +Test modular OAI Triton MoE |
| 5 | +""" |
| 6 | + |
| 7 | +import pytest |
| 8 | +import torch |
| 9 | + |
| 10 | +from vllm.utils.import_utils import has_triton_kernels |
| 11 | + |
| 12 | +if not has_triton_kernels(): |
| 13 | + pytest.skip( |
| 14 | + "triton_kernels not found, skipping all related tests", |
| 15 | + allow_module_level=True, |
| 16 | + ) |
| 17 | + |
| 18 | +from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig |
| 19 | +from triton_kernels.numerics import InFlexData |
| 20 | +from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp |
| 21 | +from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor |
| 22 | +from triton_kernels.tensor_details import layout |
| 23 | +from triton_kernels.testing import assert_close |
| 24 | + |
| 25 | +from vllm.config import VllmConfig, set_current_vllm_config |
| 26 | +from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config |
| 27 | +from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( |
| 28 | + OAITritonExperts, |
| 29 | + UnfusedOAITritonExperts, |
| 30 | +) |
| 31 | +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel |
| 32 | +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( |
| 33 | + MoEPrepareAndFinalizeNoEP, |
| 34 | +) |
| 35 | +from vllm.model_executor.layers.utils import shuffle_weight |
| 36 | +from vllm.platforms import current_platform |
| 37 | + |
| 38 | +MNK = [ |
| 39 | + (1, 512, 384), |
| 40 | + (1, 2880, 2880), |
| 41 | + (2, 512, 384), |
| 42 | + (2, 2880, 2880), |
| 43 | + (16, 2880, 2880), |
| 44 | +] |
| 45 | + |
| 46 | + |
| 47 | +def unshuffle_weight(w: torch.Tensor): |
| 48 | + first = w[..., ::2] |
| 49 | + second = w[..., 1::2] |
| 50 | + return torch.concat((first, second), dim=-1) |
| 51 | + |
| 52 | + |
| 53 | +def make_weights(dtype, k, n, e): |
| 54 | + w1 = torch.randn((e, k, 2 * n), dtype=dtype, device="cuda") |
| 55 | + w1_bias = torch.randn((e, 2 * n), dtype=dtype, device="cuda") |
| 56 | + |
| 57 | + w2 = torch.randn((e, n, k), dtype=dtype, device="cuda") |
| 58 | + w2_bias = torch.randn((e, k), dtype=dtype, device="cuda") |
| 59 | + |
| 60 | + w1_tri = w1.clone() |
| 61 | + w2_tri = w2.clone() |
| 62 | + |
| 63 | + w1_bias_tri = w1_bias.clone() |
| 64 | + w2_bias_tri = w2_bias.clone() |
| 65 | + w1_bias_tri = w1_bias_tri.to(torch.float32) |
| 66 | + w2_bias_tri = w2_bias_tri.to(torch.float32) |
| 67 | + |
| 68 | + # shuffle weights |
| 69 | + w1_tri = shuffle_weight(w1_tri) |
| 70 | + w1_bias_tri = shuffle_weight(w1_bias_tri) |
| 71 | + |
| 72 | + # quant triton_weights |
| 73 | + w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1) |
| 74 | + w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1) |
| 75 | + w1 = unshuffle_weight(w1) |
| 76 | + |
| 77 | + w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1) |
| 78 | + w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1) |
| 79 | + |
| 80 | + num_warps = 8 |
| 81 | + w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) |
| 82 | + w_scale_layout, w_scale_layout_opts = ( |
| 83 | + layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps) |
| 84 | + ) |
| 85 | + |
| 86 | + w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts) |
| 87 | + w1_scale_tri = convert_layout( |
| 88 | + wrap_torch_tensor(w1_scale_tri), |
| 89 | + w_scale_layout, |
| 90 | + **w_scale_layout_opts, |
| 91 | + ) |
| 92 | + |
| 93 | + w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts) |
| 94 | + w2_scale_tri = convert_layout( |
| 95 | + wrap_torch_tensor(w2_scale_tri), |
| 96 | + w_scale_layout, |
| 97 | + **w_scale_layout_opts, |
| 98 | + ) |
| 99 | + |
| 100 | + w1_precision_config = PrecisionConfig( |
| 101 | + weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData()) |
| 102 | + ) |
| 103 | + w2_precision_config = PrecisionConfig( |
| 104 | + weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData()) |
| 105 | + ) |
| 106 | + |
| 107 | + return ( |
| 108 | + w1, |
| 109 | + w2, |
| 110 | + w1_bias, |
| 111 | + w2_bias, |
| 112 | + w1_tri, |
| 113 | + w2_tri, |
| 114 | + w1_bias_tri, |
| 115 | + w2_bias_tri, |
| 116 | + w1_precision_config, |
| 117 | + w2_precision_config, |
| 118 | + ) |
| 119 | + |
| 120 | + |
| 121 | +def swiglu(x, alpha: float = 1.702, limit: float = 1.0): |
| 122 | + # Note we add an extra bias of 1 to the linear layer |
| 123 | + x_glu, x_linear = torch.chunk(x, 2, dim=-1) |
| 124 | + if limit is not None: |
| 125 | + x_glu = x_glu.clamp(max=limit) |
| 126 | + out_glu = x_glu * torch.sigmoid(alpha * x_glu) |
| 127 | + if limit is not None: |
| 128 | + x_linear = x_linear.clamp(min=-limit, max=limit) |
| 129 | + return out_glu * (x_linear + 1) |
| 130 | + |
| 131 | + |
| 132 | +def torch_moe_impl( |
| 133 | + hidden_states: torch.Tensor, # (M, K) |
| 134 | + w1: torch.Tensor, # (E, K, 2N) |
| 135 | + w2: torch.Tensor, # (E, N, K) |
| 136 | + w1_bias: torch.Tensor, # (E, 2N) |
| 137 | + w2_bias: torch.Tensor, # (E, K) |
| 138 | + topk_weights: torch.Tensor, # (M, topk) |
| 139 | + topk_ids: torch.Tensor, # (M, topk) |
| 140 | +): |
| 141 | + w1 = w1[topk_ids, ...] |
| 142 | + w1_bias = w1_bias[topk_ids, ...] |
| 143 | + hidden_states = torch.einsum("bekc,bk->bec", w1, hidden_states) + w1_bias |
| 144 | + hidden_states = swiglu(hidden_states, limit=7) |
| 145 | + |
| 146 | + w2 = w2[topk_ids, ...] |
| 147 | + w2_bias = w2_bias[topk_ids, ...] |
| 148 | + hidden_states = torch.einsum("bekc,bek->bec", w2, hidden_states) + w2_bias |
| 149 | + |
| 150 | + # Weighted sum of experts |
| 151 | + hidden_states = torch.einsum("bec,be->bc", hidden_states, topk_weights) |
| 152 | + return hidden_states |
| 153 | + |
| 154 | + |
| 155 | +def oai_triton_moe_impl( |
| 156 | + x: torch.Tensor, |
| 157 | + w1: torch.Tensor, |
| 158 | + w2: torch.Tensor, |
| 159 | + w1_scale: "PrecisionConfig", |
| 160 | + w2_scale: "PrecisionConfig", |
| 161 | + w1_bias: torch.Tensor | None, |
| 162 | + w2_bias: torch.Tensor | None, |
| 163 | + num_experts: int, |
| 164 | + topk_weights: torch.Tensor, |
| 165 | + topk_ids: torch.Tensor, |
| 166 | + unfused: bool = False, |
| 167 | +) -> torch.Tensor: |
| 168 | + quant_config = mxfp4_w4a16_moe_quant_config( |
| 169 | + w1_bias=w1_bias, |
| 170 | + w2_bias=w2_bias, |
| 171 | + w1_scale=w1_scale, |
| 172 | + w2_scale=w2_scale, |
| 173 | + ) |
| 174 | + |
| 175 | + if unfused: |
| 176 | + fused_experts = UnfusedOAITritonExperts(quant_config) |
| 177 | + else: |
| 178 | + fused_experts = OAITritonExperts(quant_config) |
| 179 | + |
| 180 | + mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts) |
| 181 | + |
| 182 | + return mk.forward( |
| 183 | + hidden_states=x, |
| 184 | + w1=w1, |
| 185 | + w2=w2, |
| 186 | + topk_weights=topk_weights, |
| 187 | + topk_ids=topk_ids, |
| 188 | + inplace=True, |
| 189 | + activation="swigluoai", |
| 190 | + global_num_experts=num_experts, |
| 191 | + expert_map=None, |
| 192 | + apply_router_weight_on_input=False, |
| 193 | + ) |
| 194 | + |
| 195 | + |
| 196 | +@pytest.mark.skipif( |
| 197 | + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." |
| 198 | +) |
| 199 | +@pytest.mark.parametrize("dtype", [torch.bfloat16]) |
| 200 | +@pytest.mark.parametrize("m,n,k", MNK) |
| 201 | +@pytest.mark.parametrize("num_experts", [32, 128]) |
| 202 | +@pytest.mark.parametrize("topk", [4]) |
| 203 | +@pytest.mark.parametrize("unfused", [True, False]) |
| 204 | +def test_oai_triton_moe( |
| 205 | + dtype: torch.dtype, |
| 206 | + m: int, |
| 207 | + n: int, |
| 208 | + k: int, |
| 209 | + num_experts: int, |
| 210 | + topk: int, |
| 211 | + unfused: bool, |
| 212 | +): |
| 213 | + current_platform.seed_everything(0) |
| 214 | + ( |
| 215 | + w1, |
| 216 | + w2, |
| 217 | + w1_bias, |
| 218 | + w2_bias, |
| 219 | + w1_tri, |
| 220 | + w2_tri, |
| 221 | + w1_bias_tri, |
| 222 | + w2_bias_tri, |
| 223 | + w1_precision_config, |
| 224 | + w2_precision_config, |
| 225 | + ) = make_weights(dtype, k, n, num_experts) |
| 226 | + |
| 227 | + x = torch.randn((m, k), dtype=dtype, device="cuda") |
| 228 | + router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype) |
| 229 | + topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True) |
| 230 | + topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) |
| 231 | + |
| 232 | + with set_current_vllm_config(VllmConfig()): |
| 233 | + out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids) |
| 234 | + |
| 235 | + out = oai_triton_moe_impl( |
| 236 | + x, |
| 237 | + w1_tri, |
| 238 | + w2_tri, |
| 239 | + w1_precision_config, |
| 240 | + w2_precision_config, |
| 241 | + w1_bias_tri, |
| 242 | + w2_bias_tri, |
| 243 | + num_experts, |
| 244 | + topk_weights, |
| 245 | + topk_ids, |
| 246 | + unfused, |
| 247 | + ) |
| 248 | + |
| 249 | + assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005) |
0 commit comments