From 6a9f54ce68b184affeb302130d20654cc1df2915 Mon Sep 17 00:00:00 2001
From: zhimding
Date: Sun, 28 Sep 2025 11:15:38 +0000
Subject: [PATCH 1/2] update aiter fused_moe interface

---
 .../layers/quantization/mxfp4.py              | 64 +++++++------------
 1 file changed, 22 insertions(+), 42 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 9d09c46245aa..4ee0ad68dfd8 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -63,7 +63,7 @@ def should_use_flashinfer_mxfp4():
 
 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4:
     import aiter
-    from aiter.fused_moe import fused_topk, moe_sorting
+    from aiter.fused_moe import fused_moe, fused_topk, moe_sorting
     from aiter.ops.shuffle import shuffle_mxfp4_weight, shuffle_mxfp4_scale
 
 class Mxfp4Config(QuantizationConfig):
@@ -690,51 +690,31 @@ def apply(
             token_num = x.shape[0]
             BLOCKM = 16 if token_num < 2048 else 32
             topk_weights, topk_ids = fused_topk(x, router_logits, top_k, True)
-            sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_out = moe_sorting(
-                topk_ids,
-                topk_weights,
-                self.num_experts,
-                x.shape[1],
-                torch.bfloat16,
-                BLOCKM
-            )
-            _, n1, k1 = self.w13_weight_aiter_tensor.shape
-            _, k2, n2 = self.w2_weight_aiter_tensor.shape
-            D = n2 if k2 == k1 else n2*2
-            cktile_moe_out1 = torch.empty((token_num, top_k, D), dtype=torch.bfloat16, device=x.device)
-            aiter.moe_cktile2stages_gemm1(
+            return fused_moe(
                 x,
                 self.w13_weight_aiter_tensor,
-                cktile_moe_out1,
-                sorted_ids,
-                sorted_expert_ids,
-                num_valid_ids,
-                top_k,
-                self.intermediate_pad // 64 * 64 * 2,
-                self.hidden_pad // 128 * 128, # k_pad_zeros
-                None, # sorted_weights
-                None,
-                self.w13_scale_aiter_tensor,
-                self.w13_bias_aiter_tensor,
-                BLOCKM, # block_size
-            )
-            aiter.moe_cktile2stages_gemm2(
-                cktile_moe_out1,
                 self.w2_weight_aiter_tensor,
-                moe_out,
-                sorted_ids,
-                sorted_expert_ids,
-                num_valid_ids,
-                top_k,
-                self.hidden_pad // 64 * 64, # n_pad_zeros
-                self.intermediate_pad // 128 * 128,
-                sorted_weights, # sorted_weights
-                None,
-                self.w2_scale_aiter_tensor,
-                layer.w2_bias,
-                BLOCKM, # block_size
+                topk_weights,
+                topk_ids,
+                expert_mask=None,
+                activation=aiter.ActivationType.Swiglu,
+                quant_type=aiter.QuantType.per_1x32,
+                doweight_stage1=False,
+                w1_scale=self.w13_scale_aiter_tensor,
+                w2_scale=self.w2_scale_aiter_tensor,
+                a1_scale=None,
+                a2_scale=None,
+                block_size_M=BLOCKM,
+                num_local_tokens=None,
+                moe_sorting_dispatch_policy=0,
+                dtype=None,
+                n_pad_zeros=self.intermediate_pad // 64 * 64 * 2,
+                k_pad_zeros=self.hidden_pad // 128 * 128,
+                n_pad_zeros2=self.hidden_pad // 64 * 64,
+                k_pad_zeros2=self.intermediate_pad // 128 * 128,
+                bias1=self.w13_bias_aiter_tensor,
+                bias2=layer.w2_bias,
             )
-            return moe_out
 
         from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
             triton_kernel_moe_forward)

From a5b8510887fea92782635ac92784134f4fa2ba7b Mon Sep 17 00:00:00 2001
From: zhimding
Date: Mon, 29 Sep 2025 02:58:15 +0000
Subject: [PATCH 2/2] update

---
 vllm/model_executor/layers/quantization/mxfp4.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 4ee0ad68dfd8..b2e34c286773 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -708,10 +708,8 @@ def apply(
                 num_local_tokens=None,
                 moe_sorting_dispatch_policy=0,
                 dtype=None,
-                n_pad_zeros=self.intermediate_pad // 64 * 64 * 2,
-                k_pad_zeros=self.hidden_pad // 128 * 128,
-                n_pad_zeros2=self.hidden_pad // 64 * 64,
-                k_pad_zeros2=self.intermediate_pad // 128 * 128,
+                hidden_pad=self.hidden_pad,
+                intermediate_pad=self.intermediate_pad,
                 bias1=self.w13_bias_aiter_tensor,
                 bias2=layer.w2_bias,
             )
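
PATCH 2/2 drops the pad-zeros arithmetic from mxfp4.py and passes the raw
hidden_pad / intermediate_pad through to aiter's fused_moe, which is then
expected to derive the per-gemm paddings itself. Below is a minimal sketch of
the arithmetic that PATCH 1/2 still performed on the vLLM side; the concrete
pad values are illustrative assumptions (the real ones depend on the model's
hidden and intermediate sizes), and the *2 on the first gemm is presumably
because w13 packs the gate and up projections:

    # Illustrative pad values only; real values come from the layer config.
    hidden_pad, intermediate_pad = 192, 320

    # Pad-zeros arguments as computed in PATCH 1/2, floored to the kernel's
    # alignment multiples:
    n_pad_zeros = intermediate_pad // 64 * 64 * 2    # gemm1 N padding
    k_pad_zeros = hidden_pad // 128 * 128            # gemm1 K padding
    n_pad_zeros2 = hidden_pad // 64 * 64             # gemm2 N padding
    k_pad_zeros2 = intermediate_pad // 128 * 128     # gemm2 K padding

    print(n_pad_zeros, k_pad_zeros, n_pad_zeros2, k_pad_zeros2)  # 640 128 192 256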