diff --git a/src/paddlefleet/transformer/moe/moe_layer.py b/src/paddlefleet/transformer/moe/moe_layer.py
index 4d49fd96b..b2d6ccfc4 100644
--- a/src/paddlefleet/transformer/moe/moe_layer.py
+++ b/src/paddlefleet/transformer/moe/moe_layer.py
@@ -60,6 +60,16 @@
     unpermute,
 )
 
+class FakeMLP(PyLayer):
+    @staticmethod
+    def forward(ctx, hidden_states):
+        ctx.size = hidden_states.shape[-1]
+        ctx.dtype = hidden_states.dtype
+        return paddle.empty([0, ctx.size], dtype=ctx.dtype, requires_grad=True)
+    @staticmethod
+    def backward(ctx, grad):
+        return paddle.empty([0, ctx.size], dtype=ctx.dtype, requires_grad=True)
+
 class GradDtypeGuard(PyLayer):
     """Guard the grad's dtype if different from input's dtype."""
 
@@ -442,67 +452,73 @@ def fusion_moe_forward(
         if self.using_sonic_moe:
             T = dispatched_hidden_states.shape[0]
 
-            K = self.num_experts_per_tok
-            stream_id = paddle.device.cuda.current_stream().cuda_stream
-            topk_scores = filter_scores(
-                dispatched_probs,
-                dispatched_indices,
-            )
-            expert_frequency, expert_frequency_offset = count_cumsum(
-                dispatched_indices, self.num_experts_per_device, do_cumsum=True
-            )
-            activation_type = ActivationType("swiglu")
-
-            (
-                expert_frequency_offset,
-                x_gather_idx,
-                s_scatter_idx,
-                s_reverse_scatter_idx,
-                num_activated_expert_per_token_offset,
-            ) = fused_expert_parallel_TC_topk_router_metadata(
-                dispatched_indices,
-                expert_frequency_offset,
-                K,
-            )
+            # no tokens need to be processed
+            if T == 0:
+                hidden_states = FakeMLP.apply(dispatched_hidden_states)
+            else:
+                K = self.num_experts_per_tok
+                stream_id = paddle.device.cuda.current_stream().cuda_stream
+                topk_scores = filter_scores(
+                    dispatched_probs,
+                    dispatched_indices,
+                )
+                expert_frequency, expert_frequency_offset = count_cumsum(
+                    dispatched_indices,
+                    self.num_experts_per_device,
+                    do_cumsum=True,
+                )
+                activation_type = ActivationType("swiglu")
+
+                (
+                    expert_frequency_offset,
+                    x_gather_idx,
+                    s_scatter_idx,
+                    s_reverse_scatter_idx,
+                    num_activated_expert_per_token_offset,
+                ) = fused_expert_parallel_TC_topk_router_metadata(
+                    dispatched_indices,
+                    expert_frequency_offset,
+                    K,
+                )
 
-            TK = s_scatter_idx.shape[0]
-            is_varlen_K = True
-            w1 = self.grouped_gemm_experts.weight1
-            y1, z = _UpProjection.apply(
-                dispatched_hidden_states,
-                w1.permute(1, 2, 0),
-                None,
-                expert_frequency_offset,
-                TK,
-                K,
-                stream_id,
-                x_gather_idx,
-                s_scatter_idx,
-                s_reverse_scatter_idx,
-                num_activated_expert_per_token_offset,
-                is_varlen_K,
-                activation_type,
-                is_inference_mode_enabled=False,
-            )
+                TK = s_scatter_idx.shape[0]
+                is_varlen_K = True
+                w1 = self.grouped_gemm_experts.weight1
+                y1, z = _UpProjection.apply(
+                    dispatched_hidden_states,
+                    w1.permute(1, 2, 0),
+                    None,
+                    expert_frequency_offset,
+                    TK,
+                    K,
+                    stream_id,
+                    x_gather_idx,
+                    s_scatter_idx,
+                    s_reverse_scatter_idx,
+                    num_activated_expert_per_token_offset,
+                    is_varlen_K,
+                    activation_type,
+                    is_inference_mode_enabled=False,
+                )
 
-            w2 = self.grouped_gemm_experts.weight2
-            hidden_states = _DownProjection.apply(
-                y1,
-                z,
-                w2.permute(1, 2, 0),
-                None,
-                topk_scores,
-                expert_frequency_offset,
-                T,
-                K,
-                stream_id,
-                x_gather_idx,
-                s_scatter_idx,
-                s_reverse_scatter_idx,
-                num_activated_expert_per_token_offset,
-                is_varlen_K,
-                activation_type,
-            )
+                w2 = self.grouped_gemm_experts.weight2
+                hidden_states = _DownProjection.apply(
+                    y1,
+                    z,
+                    w2.permute(1, 2, 0),
+                    None,
+                    topk_scores,
+                    expert_frequency_offset,
+                    T,
+                    K,
+                    stream_id,
+                    x_gather_idx,
+                    s_scatter_idx,
+                    s_reverse_scatter_idx,
+                    num_activated_expert_per_token_offset,
+                    is_varlen_K,
+                    activation_type,
+                )
         else:
             hidden_states = FusionMoePyLayer.apply(
                 dispatched_hidden_states,