From 4fd9bef87bff152e82f6f88ed43603c9453b62e6 Mon Sep 17 00:00:00 2001
From: Waynezee
Date: Thu, 2 Apr 2026 16:36:02 +0800
Subject: [PATCH 1/2] skip compute when no tokens need processing

---
 src/paddlefleet/transformer/moe/moe_layer.py | 124 ++++++++++---------
 1 file changed, 65 insertions(+), 59 deletions(-)

diff --git a/src/paddlefleet/transformer/moe/moe_layer.py b/src/paddlefleet/transformer/moe/moe_layer.py
index 4d49fd96b..8b83e6a34 100644
--- a/src/paddlefleet/transformer/moe/moe_layer.py
+++ b/src/paddlefleet/transformer/moe/moe_layer.py
@@ -442,67 +442,73 @@ def fusion_moe_forward(
 
         if self.using_sonic_moe:
             T = dispatched_hidden_states.shape[0]
-            K = self.num_experts_per_tok
-            stream_id = paddle.device.cuda.current_stream().cuda_stream
-            topk_scores = filter_scores(
-                dispatched_probs,
-                dispatched_indices,
-            )
-            expert_frequency, expert_frequency_offset = count_cumsum(
-                dispatched_indices, self.num_experts_per_device, do_cumsum=True
-            )
-            activation_type = ActivationType("swiglu")
-
-            (
-                expert_frequency_offset,
-                x_gather_idx,
-                s_scatter_idx,
-                s_reverse_scatter_idx,
-                num_activated_expert_per_token_offset,
-            ) = fused_expert_parallel_TC_topk_router_metadata(
-                dispatched_indices,
-                expert_frequency_offset,
-                K,
-            )
+            # no tokens need to be processed
+            if T == 0:
+                hidden_states = dispatched_hidden_states
+            else:
+                K = self.num_experts_per_tok
+                stream_id = paddle.device.cuda.current_stream().cuda_stream
+                topk_scores = filter_scores(
+                    dispatched_probs,
+                    dispatched_indices,
+                )
+                expert_frequency, expert_frequency_offset = count_cumsum(
+                    dispatched_indices,
+                    self.num_experts_per_device,
+                    do_cumsum=True,
+                )
+                activation_type = ActivationType("swiglu")
+
+                (
+                    expert_frequency_offset,
+                    x_gather_idx,
+                    s_scatter_idx,
+                    s_reverse_scatter_idx,
+                    num_activated_expert_per_token_offset,
+                ) = fused_expert_parallel_TC_topk_router_metadata(
+                    dispatched_indices,
+                    expert_frequency_offset,
+                    K,
+                )
 
-            TK = s_scatter_idx.shape[0]
-            is_varlen_K = True
-            w1 = self.grouped_gemm_experts.weight1
-            y1, z = _UpProjection.apply(
-                dispatched_hidden_states,
-                w1.permute(1, 2, 0),
-                None,
-                expert_frequency_offset,
-                TK,
-                K,
-                stream_id,
-                x_gather_idx,
-                s_scatter_idx,
-                s_reverse_scatter_idx,
-                num_activated_expert_per_token_offset,
-                is_varlen_K,
-                activation_type,
-                is_inference_mode_enabled=False,
-            )
+                TK = s_scatter_idx.shape[0]
+                is_varlen_K = True
+                w1 = self.grouped_gemm_experts.weight1
+                y1, z = _UpProjection.apply(
+                    dispatched_hidden_states,
+                    w1.permute(1, 2, 0),
+                    None,
+                    expert_frequency_offset,
+                    TK,
+                    K,
+                    stream_id,
+                    x_gather_idx,
+                    s_scatter_idx,
+                    s_reverse_scatter_idx,
+                    num_activated_expert_per_token_offset,
+                    is_varlen_K,
+                    activation_type,
+                    is_inference_mode_enabled=False,
+                )
 
-            w2 = self.grouped_gemm_experts.weight2
-            hidden_states = _DownProjection.apply(
-                y1,
-                z,
-                w2.permute(1, 2, 0),
-                None,
-                topk_scores,
-                expert_frequency_offset,
-                T,
-                K,
-                stream_id,
-                x_gather_idx,
-                s_scatter_idx,
-                s_reverse_scatter_idx,
-                num_activated_expert_per_token_offset,
-                is_varlen_K,
-                activation_type,
-            )
+                w2 = self.grouped_gemm_experts.weight2
+                hidden_states = _DownProjection.apply(
+                    y1,
+                    z,
+                    w2.permute(1, 2, 0),
+                    None,
+                    topk_scores,
+                    expert_frequency_offset,
+                    T,
+                    K,
+                    stream_id,
+                    x_gather_idx,
+                    s_scatter_idx,
+                    s_reverse_scatter_idx,
+                    num_activated_expert_per_token_offset,
+                    is_varlen_K,
+                    activation_type,
+                )
         else:
             hidden_states = FusionMoePyLayer.apply(
                 dispatched_hidden_states,
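For context, a minimal self-contained sketch of the early-exit pattern that
[PATCH 1/2] introduces. `expert_forward` here is a hypothetical stand-in for
the sonic-MoE pipeline (filter_scores, count_cumsum, the grouped-GEMM up/down
projections); only the `T == 0` guard mirrors the actual change.

    import paddle

    def expert_forward(hidden_states):
        # placeholder for the real expert computation
        return hidden_states * 2.0

    def fusion_moe_forward(dispatched_hidden_states):
        T = dispatched_hidden_states.shape[0]
        if T == 0:
            # no tokens routed to this rank: skip all expert kernels
            return dispatched_hidden_states
        return expert_forward(dispatched_hidden_states)

    print(fusion_moe_forward(paddle.empty([0, 128])).shape)  # [0, 128]
    print(fusion_moe_forward(paddle.ones([4, 128])).shape)   # [4, 128]
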
From 5d6a27cd838bac3069919a8a4ad0dcdea212bb66 Mon Sep 17 00:00:00 2001
From: Waynezee
Date: Thu, 2 Apr 2026 21:23:56 +0800
Subject: [PATCH 2/2] skip compute when no tokens need processing

---
 src/paddlefleet/transformer/moe/moe_layer.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/paddlefleet/transformer/moe/moe_layer.py b/src/paddlefleet/transformer/moe/moe_layer.py
index 8b83e6a34..b2d6ccfc4 100644
--- a/src/paddlefleet/transformer/moe/moe_layer.py
+++ b/src/paddlefleet/transformer/moe/moe_layer.py
@@ -60,6 +60,16 @@
     unpermute,
 )
 
+class FakeMLP(PyLayer):
+    @staticmethod
+    def forward(ctx, hidden_states):
+        ctx.size = hidden_states.shape[-1]
+        ctx.dtype = hidden_states.dtype
+        return paddle.empty([0, ctx.size], dtype=ctx.dtype)
+    @staticmethod
+    def backward(ctx, grad):
+        return paddle.empty([0, ctx.size], dtype=ctx.dtype)
+
 class GradDtypeGuard(PyLayer):
     """Guard the grad's dtype if different from input's dtype."""
 
@@ -444,7 +454,7 @@ def fusion_moe_forward(
             T = dispatched_hidden_states.shape[0]
             # no tokens need to be processed
             if T == 0:
-                hidden_states = dispatched_hidden_states
+                hidden_states = FakeMLP.apply(dispatched_hidden_states)
             else:
                 K = self.num_experts_per_tok
                 stream_id = paddle.device.cuda.current_stream().cuda_stream
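For reference, the FakeMLP pass-through from [PATCH 2/2] as a runnable,
self-contained sketch with descriptive comments. The 128-wide float32 input
and the final shape check are illustrative assumptions, not part of the
patch. Note that paddle.empty takes no requires_grad argument (that is a
PyTorch idiom), so the sketch relies on PyLayer's own gradient wiring.

    import paddle
    from paddle.autograd import PyLayer

    class FakeMLP(PyLayer):
        @staticmethod
        def forward(ctx, hidden_states):
            # record hidden size and dtype so backward can fabricate
            # a correctly shaped empty gradient
            ctx.size = hidden_states.shape[-1]
            ctx.dtype = hidden_states.dtype
            return paddle.empty([0, ctx.size], dtype=ctx.dtype)

        @staticmethod
        def backward(ctx, grad):
            # zero tokens were processed, so the gradient w.r.t. the
            # input is an empty [0, hidden] tensor as well
            return paddle.empty([0, ctx.size], dtype=ctx.dtype)

    # zero tokens routed to this rank: shapes stay [0, hidden] and no
    # expert kernels are launched
    x = paddle.empty([0, 128], dtype="float32")
    x.stop_gradient = False
    y = FakeMLP.apply(x)
    print(y.shape)  # [0, 128]

Routing the empty tensor through a PyLayer instead of returning it directly,
as the first patch did, presumably gives the T == 0 branch the same
forward/backward structure as the expert path, so autograd sees a
well-defined boundary in both cases.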