From 4a77e7a2bfa6d6609ee19bc614a58b10f3aed1b5 Mon Sep 17 00:00:00 2001
From: lizhenxing02
Date: Fri, 20 Mar 2026 14:01:05 +0800
Subject: [PATCH 1/2] auto_parallel_moe

---
 src/paddlefleet/fusions/fused_bias_swiglu.py |   4 +-
 src/paddlefleet/transformer/mlp.py           |  18 ++
 src/paddlefleet/transformer/moe/moe_layer.py | 218 ++++++++++++++++++-
 3 files changed, 236 insertions(+), 4 deletions(-)

diff --git a/src/paddlefleet/fusions/fused_bias_swiglu.py b/src/paddlefleet/fusions/fused_bias_swiglu.py
index 09932e267..c37720d01 100644
--- a/src/paddlefleet/fusions/fused_bias_swiglu.py
+++ b/src/paddlefleet/fusions/fused_bias_swiglu.py
@@ -256,7 +256,7 @@ def bias_swiglu_impl(
     """
     ori_shape = input.shape
     assert len(ori_shape) in [2, 3]
-    input = input.view(-1, ori_shape[-1])
+    input = input.reshape(-1, ori_shape[-1])
     if bias is not None:
         output = BiasSwiGLUFunction.apply(
             input, bias, fp8_input_store, cpu_offload_input
         )
@@ -267,7 +267,7 @@
     return (
         output
        if len(ori_shape) == 2
-        else output.view(ori_shape[0], ori_shape[1], -1)
+        else output.reshape(ori_shape[0], ori_shape[1], -1)
     )

diff --git a/src/paddlefleet/transformer/mlp.py b/src/paddlefleet/transformer/mlp.py
index 066f5daf5..a2d16c0ae 100644
--- a/src/paddlefleet/transformer/mlp.py
+++ b/src/paddlefleet/transformer/mlp.py
@@ -21,6 +21,7 @@
 from typing import TYPE_CHECKING

 import paddle
+import paddle.distributed as dist
 import paddle.nn.functional as F

 # (TODO): need adapt to flex_checkpoint
@@ -170,6 +171,23 @@ def __init__(
             tp_group=tp_group,
         )

+    def redistribute_expert(self, mesh, placements):
+        self.up_gate_proj.weight = dist.shard_tensor(
+            self.up_gate_proj.weight, mesh, placements
+        )
+        if self.up_gate_proj.bias is not None:
+            self.up_gate_proj.bias = dist.shard_tensor(
+                self.up_gate_proj.bias, mesh, placements
+            )
+
+        self.down_proj.weight = dist.shard_tensor(
+            self.down_proj.weight, mesh, placements
+        )
+        if self.down_proj.bias is not None:
+            self.down_proj.bias = dist.shard_tensor(
+                self.down_proj.bias, mesh, placements
+            )
+
     def forward(self, hidden_states, per_token_scale=None):
         """Perform the forward pass through the MLP block."""
         # [s, b, 4 * h/p]

diff --git a/src/paddlefleet/transformer/moe/moe_layer.py b/src/paddlefleet/transformer/moe/moe_layer.py
index 9a1f97214..382036fe7 100644
--- a/src/paddlefleet/transformer/moe/moe_layer.py
+++ b/src/paddlefleet/transformer/moe/moe_layer.py
@@ -14,14 +14,17 @@

 from __future__ import annotations

+import copy
 import logging
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING

 import paddle
-from paddle import nn
+import paddle.distributed as dist
+from paddle import einsum, nn
 from paddle.autograd import PyLayer
+from paddle.distributed.auto_parallel.local_layer import LocalLayer
 from paddle.distributed.fleet.utils.sequence_parallel_utils import (
     GatherOp,
     ScatterOp,
@@ -91,6 +94,64 @@ def backward(ctx, grad):
         return grad


+class LocalGateAndDispatch(LocalLayer):
+    def __init__(self, gate, mesh):
+        out_dist_attrs = [
+            (mesh, [dist.Shard(1)]),  # dispatched_input [e,c,h]
+            (mesh, [dist.Shard(0)]),  # combine_weights [s,e,c]
+        ]
+        grad_dist_attrs = [
+            None,  # reshaped_input
+            None,  # gates_masked
+            None,  # mask
+        ]
+        super().__init__(out_dist_attrs, grad_dist_attrs)
+        self.gate = gate
+
+    def forward(self, reshaped_input, gates_masked, mask):
+        # norm_topk_prob: normalize selected expert weights
+        norm_topk_prob = getattr(self.gate, "norm_topk_prob", False)
+        if norm_topk_prob:
+            gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
+            denom_s = paddle.clip(
+                gates_s, min=paddle.finfo(gates_masked.dtype).eps
+            )
+            gates_masked = gates_masked / denom_s
+
+        # drop_tokens: drop tokens exceeding capacity
+        drop_tokens = getattr(self.gate, "drop_tokens", False)
+        capacity = getattr(self.gate, "capacity", None)
+
+        if drop_tokens and capacity is not None:
+            locations = paddle.cumsum(mask, axis=0) - 1  # [seq, experts]
+            within_capacity = (locations < capacity).astype(mask.dtype)
+            mask = mask * within_capacity
+
+        dispatched_input = paddle.matmul(
+            mask.T,  # [num_experts, batch*seq]
+            reshaped_input,  # [batch*seq, hidden]
+        )
+        combine_weights = mask * gates_masked
+        return dispatched_input, combine_weights
+
+
+class LocalCombine(LocalLayer):
+    def __init__(self, mesh):
+        out_dist_attrs = [(mesh, [dist.Shard(0)])]
+        grad_dist_attrs = [None, None]
+        super().__init__(out_dist_attrs, grad_dist_attrs)
+
+    def forward(
+        self, combine_weights, expert_output, dtype="float32", out_shape=None
+    ):
+        combined_output = einsum(
+            "se,ecm->sm", combine_weights.cast(dtype), expert_output
+        )
+        if out_shape is not None:
+            combined_output = combined_output.reshape(out_shape)
+        return combined_output
+
+
 @dataclass
 class MoESublayers:
     """MoE Layer Sublayers spec"""
@@ -105,6 +166,55 @@ def __init__(
         sublayers: MoESublayers | None = None,
         pg_collection: ProcessGroupCollection | None = None,
     ):
+        if config.enable_auto_parallel:
+            super().__init__()
+            self.config = config
+            self.mesh = dist.fleet.auto.get_mesh()
+            self.expert_class = StandardMLPExpert
+            self.num_experts = config.n_routed_experts
+            self.routed_expert_config = deepcopy(config)
+            self.moe_intermediate_size = config.moe_intermediate_size
+            self.sublayers = sublayers
+            self.moe_use_fusion_node = False
+            self.moe_group = (
+                config.moe_group if hasattr(config, "moe_group") else "dp"
+            )
+            self.moe_mesh_dim = 0 if self.moe_group == "dp" else 1
+            self.expert_parallel_degree = self.mesh.get_dim_size(self.moe_group)
+            self.num_experts_per_device = (
+                self.num_experts // self.expert_parallel_degree
+            )
+
+            expert_args = {}
+            expert_args["config"] = self.routed_expert_config
+            expert_args["moe_intermediate_size"] = self.moe_intermediate_size
+            expert_args["is_expert"] = True
+            expert_args["mlp_spec"] = self.sublayers.mlp_spec
+
+            self.experts = nn.LayerList([])
+            for i in range(self.num_experts):
+                self.experts.append(self.expert_class(**expert_args))
+
+            self._redistribute_experts(self.experts)
+
+            self.gate = TopKRouter(config=config, pg_collection=pg_collection)
+            self.gate.group = None
+            self.is_dummy_moe = True
+
+            for p in self.gate.parameters():
+                p.is_gate = True
+            for k in self.experts:
+                if k is not None:
+                    for p in k.parameters():
+                        p.expert = not self.is_dummy_moe
+                        p.no_sync = not self.is_dummy_moe
+
+            self.local_gate_and_dispatch = LocalGateAndDispatch(
+                self.gate, self.mesh
+            )
+            self.local_combine = LocalCombine(self.mesh)
+            return
+
         super().__init__()
         self.config = config
         self.sublayers = sublayers
@@ -304,6 +414,110 @@ def __init__(
         if self.is_mp_moe or self.is_ep_moe:
             p.is_distributed = True

+    def _redistribute_experts(self, experts):
+        ep_sub_meshes = dist.auto_parallel.api.split_mesh(
+            self.mesh, self.moe_mesh_dim
+        )
+        for i, expert in enumerate(experts):
+            ep_group_id = i // self.num_experts_per_device
+            experts[i].redistribute_expert(
+                ep_sub_meshes[ep_group_id], [dist.Replicate(), dist.Replicate()]
+            )
+
+    def expert_forward_auto(self, dispatched_input):
+        sub_mesh_tensors = dist.auto_parallel.api.moe_sub_mesh_tensors(
+            dispatched_input,
+            self.mesh,
+            self.moe_mesh_dim,
+            dispatched_input.placements,
+        )
+        chunks = paddle.utils.flatten([t.unbind(1) for t in sub_mesh_tensors])
+
+        ep_group_outputs = []
+        expert_outputs = []
+        for i, (chunk, expert) in enumerate(zip(chunks, self.experts)):
+            chunk = chunk.contiguous()
+            output, output_bias = expert(chunk)
+            expert_outputs += [output]
+            if (i + 1) % self.num_experts_per_device == 0:
+                ep_group_outputs += [paddle.stack(expert_outputs, axis=1)]
+                expert_outputs = []
+
+        expert_output = dist.auto_parallel.api.moe_global_mesh_tensor(
+            ep_group_outputs,
+            self.mesh,
+            dispatched_input.placements,
+            self.moe_mesh_dim,
+        )
+        return expert_output
+
+    def forward_auto(
+        self,
+        hidden_states: paddle.Tensor,
+    ):
+        batch_size, seq_len, d_model = hidden_states.shape
+        (
+            capacity,
+            topk_weights,
+            topk_indices,
+            gates_masked,
+            mask,
+            priorities,
+            aux_loss,
+            z_loss,
+        ) = self.gate(hidden_states)
+        reshaped_input = hidden_states.reshape([-1, d_model])
+        dispatched_input, combine_weights = self.local_gate_and_dispatch(
+            reshaped_input, gates_masked, mask
+        )
+        ori_dispatched_placements = copy.deepcopy(dispatched_input.placements)
+        ep_placements = copy.deepcopy(dispatched_input.placements)
+        ep_placements[self.moe_mesh_dim] = dist.Shard(0)
+        dispatched_input = dist.reshard(
+            dispatched_input, self.mesh, ep_placements
+        )
+
+        # Re-shape after all-to-all: ecm -> gecm
+        dispatched_input = dispatched_input.reshape(
+            [
+                self.expert_parallel_degree,
+                self.num_experts_per_device,
+                -1,
+                d_model,
+            ]
+        )
+
+        expert_output = self.expert_forward_auto(dispatched_input)
+
+        # Re-shape before combine - simplified dispatch produces [e, m] directly
+        expert_output = expert_output.reshape(
+            [
+                self.expert_parallel_degree * self.num_experts_per_device,
+                -1,
+                d_model,
+            ]
+        )
+        expert_output = dist.reshard(
+            expert_output, self.mesh, ori_dispatched_placements
+        )
+        combined_output = self.local_combine(
+            combine_weights,
+            expert_output,
+            dtype=hidden_states[0].dtype,
+            out_shape=hidden_states._local_shape,
+        )
+        combined_output = combined_output._local_value()
+        return combined_output, None
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+    ):
+        if self.config.enable_auto_parallel:
+            return self.forward_auto(hidden_states)
+        else:
+            return self.forward_impl(hidden_states)
+
     def _init_expert_parallel(self):
         def _parse_moe_expert_parallel(
             num_experts: int, expert_model_parallel_size: int
@@ -641,7 +855,7 @@ def aux_loss_compute(self, args):
             output = ScatterOp.apply(output)
         return output

-    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+    def forward_impl(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         """
         Args:
             hidden_states: Shape: [batch_size, seq_len, hidden_size]

From 0fdabdd37d570ed4aa748d70956d96c27ebb9c08 Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Mon, 23 Mar 2026 14:36:37 +0800
Subject: [PATCH 2/2] fix enable_auto_parallel

---
 src/paddlefleet/transformer/moe/moe_layer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/paddlefleet/transformer/moe/moe_layer.py b/src/paddlefleet/transformer/moe/moe_layer.py
index 382036fe7..5e264df94 100644
--- a/src/paddlefleet/transformer/moe/moe_layer.py
+++ b/src/paddlefleet/transformer/moe/moe_layer.py
@@ -166,7 +166,7 @@ def __init__(
         sublayers: MoESublayers | None = None,
         pg_collection: ProcessGroupCollection | None = None,
     ):
-        if config.enable_auto_parallel:
+        if getattr(config, "enable_auto_parallel", False):
             super().__init__()
             self.config = config
             self.mesh = dist.fleet.auto.get_mesh()
@@ -513,7 +513,7 @@ def forward(
         self,
         hidden_states: paddle.Tensor,
     ):
-        if self.config.enable_auto_parallel:
+        if getattr(self.config, "enable_auto_parallel", False):
             return self.forward_auto(hidden_states)
         else:
             return self.forward_impl(hidden_states)
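Reviewer note (not part of the patch): LocalGateAndDispatch and LocalCombine wrap a dense
matmul/einsum formulation of MoE dispatch and combine inside LocalLayer, so the locally
computed results can be re-annotated with the mesh placements in out_dist_attrs before
forward_auto reshards them along the expert-parallel mesh dimension. The following is a
minimal single-process sketch of that arithmetic in plain paddle; the toy top-1 mask, the
identity "expert", and all variable names are illustrative assumptions, not the patch's
TopKRouter or expert MLPs.

    import paddle
    import paddle.nn.functional as F

    s, e, h = 6, 4, 8                       # tokens, experts, hidden size
    x = paddle.randn([s, h])                # reshaped_input: [s, h]
    logits = paddle.randn([s, e])
    mask = F.one_hot(paddle.argmax(logits, axis=-1), e)   # top-1 routing mask: [s, e]
    gates = F.softmax(logits, axis=-1) * mask             # gates_masked: [s, e]

    # Dispatch: tokens routed to an expert are accumulated into one row per expert.
    dispatched = paddle.matmul(mask.T, x)   # [e, h]
    combine_weights = mask * gates          # [s, e]

    # Identity "expert" keeps the sketch short; the real experts are MLP sublayers.
    expert_output = dispatched.reshape([e, 1, h])          # [e, c=1, h]

    # Combine: weight each expert output back onto the token dimension.
    combined = paddle.einsum("se,ecm->sm", combine_weights, expert_output)  # [s, h]
    print(combined.shape)                   # [6, 8]

In the patch itself the same matmul/einsum pair runs under LocalLayer, and dist.reshard
moves the dispatched tensor to Shard(0) along the moe mesh dimension before the
per-expert forward in expert_forward_auto.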