Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/paddlefleet/fusions/fused_bias_swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def bias_swiglu_impl(
"""
ori_shape = input.shape
assert len(ori_shape) in [2, 3]
input = input.view(-1, ori_shape[-1])
input = input.reshape(-1, ori_shape[-1])
if bias is not None:
output = BiasSwiGLUFunction.apply(
input, bias, fp8_input_store, cpu_offload_input
Expand All @@ -267,7 +267,7 @@ def bias_swiglu_impl(
return (
output
if len(ori_shape) == 2
else output.view(ori_shape[0], ori_shape[1], -1)
else output.reshape(ori_shape[0], ori_shape[1], -1)
)


Expand Down
18 changes: 18 additions & 0 deletions src/paddlefleet/transformer/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import TYPE_CHECKING

import paddle
import paddle.distributed as dist
import paddle.nn.functional as F

# (TODO): need adapt to flex_checkpoint
Expand Down Expand Up @@ -170,6 +171,23 @@ def __init__(
tp_group=tp_group,
)

def redistribute_expert(self, mesh, placements):
self.up_gate_proj.weight = dist.shard_tensor(
self.up_gate_proj.weight, mesh, placements
)
if self.up_gate_proj.bias is not None:
self.up_gate_proj.bias = dist.shard_tensor(
self.up_gate_proj.bias, mesh, placements
)

self.down_proj.weight = dist.shard_tensor(
self.down_proj.weight, mesh, placements
)
if self.down_proj.bias is not None:
self.down_proj.bias = dist.shard_tensor(
self.down_proj.bias, mesh, placements
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是否可能涉及还会其他层也做shard_tensor? 比如 Linear 层

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对于MoE中为MLP,这里的shard_tensor是完备的,暂未支持MoE中为其他layer的情况。

)

def forward(self, hidden_states, per_token_scale=None):
"""Perform the forward pass through the MLP block."""
# [s, b, 4 * h/p]
Expand Down
218 changes: 216 additions & 2 deletions src/paddlefleet/transformer/moe/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

from __future__ import annotations

import copy
import logging
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING

import paddle
from paddle import nn
import paddle.distributed as dist
from paddle import einsum, nn
from paddle.autograd import PyLayer
from paddle.distributed.auto_parallel.local_layer import LocalLayer
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
GatherOp,
ScatterOp,
Expand Down Expand Up @@ -91,6 +94,64 @@ def backward(ctx, grad):
return grad


class LocalGateAndDispatch(LocalLayer):
    """Per-shard gating post-processing and token dispatch.

    Runs on local tensor shards (via ``LocalLayer``) and produces the
    expert-major dispatched input plus the per-token combine weights.
    """

    def __init__(self, gate, mesh):
        # Output dist attrs: dispatched_input [e, c, h] sharded on dim 1,
        # combine_weights [s, e, c] sharded on dim 0.
        output_attrs = [
            (mesh, [dist.Shard(1)]),
            (mesh, [dist.Shard(0)]),
        ]
        # Default grad dist attrs for reshaped_input, gates_masked, mask.
        super().__init__(output_attrs, [None, None, None])
        self.gate = gate

    def forward(self, reshaped_input, gates_masked, mask):
        # Optionally renormalize the selected experts' weights to sum to 1,
        # guarding the denominator with the dtype's epsilon.
        if getattr(self.gate, "norm_topk_prob", False):
            weight_sum = paddle.sum(gates_masked, axis=-1, keepdim=True)
            eps = paddle.finfo(gates_masked.dtype).eps
            gates_masked = gates_masked / paddle.clip(weight_sum, min=eps)

        # Optionally zero out tokens that exceed per-expert capacity.
        capacity = getattr(self.gate, "capacity", None)
        if getattr(self.gate, "drop_tokens", False) and capacity is not None:
            positions = paddle.cumsum(mask, axis=0) - 1  # [seq, experts]
            mask = mask * (positions < capacity).astype(mask.dtype)

        # [num_experts, batch*seq] @ [batch*seq, hidden] -> dispatched input.
        dispatched_input = paddle.matmul(mask.T, reshaped_input)
        combine_weights = mask * gates_masked
        return dispatched_input, combine_weights


class LocalCombine(LocalLayer):
    """Weighted combine of expert outputs back into per-token outputs.

    Runs on local shards; the combined [s, m] result is sharded on dim 0.
    """

    def __init__(self, mesh):
        # One output sharded on dim 0; default grad attrs for both inputs.
        super().__init__([(mesh, [dist.Shard(0)])], [None, None])

    def forward(
        self, combine_weights, expert_output, dtype="float32", out_shape=None
    ):
        # "se,ecm->sm": weight each expert/capacity slot, sum per token.
        combined = einsum(
            "se,ecm->sm", combine_weights.cast(dtype), expert_output
        )
        if out_shape is None:
            return combined
        return combined.reshape(out_shape)


@dataclass
class MoESublayers:
"""MoE Layer Sublayers spec"""
Expand All @@ -105,6 +166,55 @@ def __init__(
sublayers: MoESublayers | None = None,
pg_collection: ProcessGroupCollection | None = None,
):
if getattr(config, "enable_auto_parallel", False):
super().__init__()
self.config = config
self.mesh = dist.fleet.auto.get_mesh()
self.expert_class = StandardMLPExpert
self.num_experts = config.n_routed_experts
self.routed_expert_config = deepcopy(config)
self.moe_intermediate_size = config.moe_intermediate_size
self.sublayers = sublayers
self.moe_use_fusion_node = False
self.moe_group = (
config.moe_group if hasattr(config, "moe_group") else "dp"
)
self.moe_mesh_dim = 0 if self.moe_group == "dp" else 1
self.expert_parallel_degree = self.mesh.get_dim_size(self.moe_group)
self.num_experts_per_device = (
self.num_experts // self.expert_parallel_degree
)

expert_args = {}
expert_args["config"] = self.routed_expert_config
expert_args["moe_intermediate_size"] = self.moe_intermediate_size
expert_args["is_expert"] = True
expert_args["mlp_spec"] = self.sublayers.mlp_spec

self.experts = nn.LayerList([])
for i in range(self.num_experts):
self.experts.append(self.expert_class(**expert_args))

self._redistribute_experts(self.experts)

self.gate = TopKRouter(config=config, pg_collection=pg_collection)
self.gate.group = None
self.is_dummy_moe = True

for p in self.gate.parameters():
p.is_gate = True
for k in self.experts:
if k is not None:
for p in k.parameters():
p.expert = not self.is_dummy_moe
p.no_sync = not self.is_dummy_moe

self.local_gate_and_dispatch = LocalGateAndDispatch(
self.gate, self.mesh
)
self.local_combine = LocalCombine(self.mesh)
return

super().__init__()
self.config = config
self.sublayers = sublayers
Expand Down Expand Up @@ -304,6 +414,110 @@ def __init__(
if self.is_mp_moe or self.is_ep_moe:
p.is_distributed = True

def _redistribute_experts(self, experts):
ep_sub_meshes = dist.auto_parallel.api.split_mesh(
self.mesh, self.moe_mesh_dim
)
for i, expert in enumerate(experts):
ep_group_id = i // self.num_experts_per_device
experts[i].redistribute_expert(
ep_sub_meshes[ep_group_id], [dist.Replicate(), dist.Replicate()]
)

    def expert_forward_auto(self, dispatched_input):
        """Run every local expert over its chunk of the dispatched tokens.

        ``dispatched_input`` is the [ep, e_local, c, h] tensor produced in
        ``forward_auto``; the per-expert outputs are stitched back into a
        single global mesh tensor with the same placements.
        """
        # Split the global tensor into one local tensor per EP sub-mesh.
        sub_mesh_tensors = dist.auto_parallel.api.moe_sub_mesh_tensors(
            dispatched_input,
            self.mesh,
            self.moe_mesh_dim,
            dispatched_input.placements,
        )
        # unbind(1) yields one [c, h] chunk per local expert; flattening the
        # nested list lines the chunks up with self.experts order.
        chunks = paddle.utils.flatten([t.unbind(1) for t in sub_mesh_tensors])

        ep_group_outputs = []
        expert_outputs = []
        for i, (chunk, expert) in enumerate(zip(chunks, self.experts)):
            chunk = chunk.contiguous()
            output, output_bias = expert(chunk)  # output_bias is unused here
            expert_outputs += [output]
            # Once every local expert of an EP group has run, re-stack its
            # outputs along the expert axis.
            if (i + 1) % self.num_experts_per_device == 0:
                ep_group_outputs += [paddle.stack(expert_outputs, axis=1)]
                expert_outputs = []

        # Reassemble the per-group outputs into one global mesh tensor.
        expert_output = dist.auto_parallel.api.moe_global_mesh_tensor(
            ep_group_outputs,
            self.mesh,
            dispatched_input.placements,
            self.moe_mesh_dim,
        )
        return expert_output

    def forward_auto(
        self,
        hidden_states: paddle.Tensor,
    ):
        """Auto-parallel MoE forward: gate, dispatch, expert, combine.

        Returns ``(combined_output, None)``; the second slot mirrors the
        bias output of the standard forward path.
        """
        batch_size, seq_len, d_model = hidden_states.shape
        # NOTE(review): only gates_masked and mask are consumed below;
        # aux_loss and z_loss from the gate are dropped on this path —
        # confirm that is intended.
        (
            capacity,
            topk_weights,
            topk_indices,
            gates_masked,
            mask,
            priorities,
            aux_loss,
            z_loss,
        ) = self.gate(hidden_states)
        # Flatten tokens: [b, s, h] -> [b*s, h].
        reshaped_input = hidden_states.reshape([-1, d_model])
        dispatched_input, combine_weights = self.local_gate_and_dispatch(
            reshaped_input, gates_masked, mask
        )
        # Remember the pre-all-to-all placements so the expert output can be
        # resharded back before combining; then shard the expert axis.
        ori_dispatched_placements = copy.deepcopy(dispatched_input.placements)
        ep_placements = copy.deepcopy(dispatched_input.placements)
        ep_placements[self.moe_mesh_dim] = dist.Shard(0)
        dispatched_input = dist.reshard(
            dispatched_input, self.mesh, ep_placements
        )

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_input = dispatched_input.reshape(
            [
                self.expert_parallel_degree,
                self.num_experts_per_device,
                -1,
                d_model,
            ]
        )

        expert_output = self.expert_forward_auto(dispatched_input)

        # Re-shape before combine - simplified dispatch produces [e, m] directly
        expert_output = expert_output.reshape(
            [
                self.expert_parallel_degree * self.num_experts_per_device,
                -1,
                d_model,
            ]
        )
        expert_output = dist.reshard(
            expert_output, self.mesh, ori_dispatched_placements
        )
        # Combine in the input dtype (hidden_states[0].dtype equals the full
        # tensor's dtype) and shape to the per-rank local output.
        combined_output = self.local_combine(
            combine_weights,
            expert_output,
            dtype=hidden_states[0].dtype,
            out_shape=hidden_states._local_shape,
        )
        combined_output = combined_output._local_value()
        return combined_output, None

def forward(
self,
hidden_states: paddle.Tensor,
):
if getattr(self.config, "enable_auto_parallel", False):
return self.forward_auto(hidden_states)
else:
return self.forward_impl(hidden_states)

def _init_expert_parallel(self):
def _parse_moe_expert_parallel(
num_experts: int, expert_model_parallel_size: int
Expand Down Expand Up @@ -641,7 +855,7 @@ def aux_loss_compute(self, args):
output = ScatterOp.apply(output)
return output

def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
def forward_impl(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
"""
Args:
hidden_states: Shape: [batch_size, seq_len, hidden_size]
Expand Down
Loading