134 changes: 75 additions & 59 deletions src/paddlefleet/transformer/moe/moe_layer.py
@@ -60,6 +60,16 @@
     unpermute,
 )
 
+class FakeMLP(PyLayer):
+    """No-op expert MLP used when a rank receives zero tokens.
+
+    Returns an empty activation of the correct width and dtype so the
+    expert branch stays in the autograd graph without doing any compute.
+    """
+
+    @staticmethod
+    def forward(ctx, hidden_states):
+        ctx.size = hidden_states.shape[-1]
+        ctx.dtype = hidden_states.dtype
+        return paddle.empty([0, ctx.size], dtype=ctx.dtype, requires_grad=True)
+
+    @staticmethod
+    def backward(ctx, grad):
+        # The gradient w.r.t. the (empty) input is itself empty.
+        return paddle.empty([0, ctx.size], dtype=ctx.dtype, requires_grad=True)
+

 class GradDtypeGuard(PyLayer):
     """Guard the grad's dtype if different from input's dtype."""
@@ -442,67 +452,73 @@ def fusion_moe_forward(

         if self.using_sonic_moe:
             T = dispatched_hidden_states.shape[0]
-            K = self.num_experts_per_tok
-            stream_id = paddle.device.cuda.current_stream().cuda_stream
-            topk_scores = filter_scores(
-                dispatched_probs,
-                dispatched_indices,
-            )
-            expert_frequency, expert_frequency_offset = count_cumsum(
-                dispatched_indices, self.num_experts_per_device, do_cumsum=True
-            )
-            activation_type = ActivationType("swiglu")
-
-            (
-                expert_frequency_offset,
-                x_gather_idx,
-                s_scatter_idx,
-                s_reverse_scatter_idx,
-                num_activated_expert_per_token_offset,
-            ) = fused_expert_parallel_TC_topk_router_metadata(
-                dispatched_indices,
-                expert_frequency_offset,
-                K,
-            )
+            # no tokens need to be processed
+            if T == 0:
+                hidden_states = FakeMLP.apply(dispatched_hidden_states)
+            else:
+                K = self.num_experts_per_tok
+                stream_id = paddle.device.cuda.current_stream().cuda_stream
+                topk_scores = filter_scores(
+                    dispatched_probs,
+                    dispatched_indices,
+                )
+                expert_frequency, expert_frequency_offset = count_cumsum(
+                    dispatched_indices,
+                    self.num_experts_per_device,
+                    do_cumsum=True,
+                )
+                activation_type = ActivationType("swiglu")
+
+                (
+                    expert_frequency_offset,
+                    x_gather_idx,
+                    s_scatter_idx,
+                    s_reverse_scatter_idx,
+                    num_activated_expert_per_token_offset,
+                ) = fused_expert_parallel_TC_topk_router_metadata(
+                    dispatched_indices,
+                    expert_frequency_offset,
+                    K,
+                )

-            TK = s_scatter_idx.shape[0]
-            is_varlen_K = True
-            w1 = self.grouped_gemm_experts.weight1
-            y1, z = _UpProjection.apply(
-                dispatched_hidden_states,
-                w1.permute(1, 2, 0),
-                None,
-                expert_frequency_offset,
-                TK,
-                K,
-                stream_id,
-                x_gather_idx,
-                s_scatter_idx,
-                s_reverse_scatter_idx,
-                num_activated_expert_per_token_offset,
-                is_varlen_K,
-                activation_type,
-                is_inference_mode_enabled=False,
-            )
+                TK = s_scatter_idx.shape[0]
+                is_varlen_K = True
+                w1 = self.grouped_gemm_experts.weight1
+                y1, z = _UpProjection.apply(
+                    dispatched_hidden_states,
+                    w1.permute(1, 2, 0),
+                    None,
+                    expert_frequency_offset,
+                    TK,
+                    K,
+                    stream_id,
+                    x_gather_idx,
+                    s_scatter_idx,
+                    s_reverse_scatter_idx,
+                    num_activated_expert_per_token_offset,
+                    is_varlen_K,
+                    activation_type,
+                    is_inference_mode_enabled=False,
+                )

-            w2 = self.grouped_gemm_experts.weight2
-            hidden_states = _DownProjection.apply(
-                y1,
-                z,
-                w2.permute(1, 2, 0),
-                None,
-                topk_scores,
-                expert_frequency_offset,
-                T,
-                K,
-                stream_id,
-                x_gather_idx,
-                s_scatter_idx,
-                s_reverse_scatter_idx,
-                num_activated_expert_per_token_offset,
-                is_varlen_K,
-                activation_type,
-            )
+                w2 = self.grouped_gemm_experts.weight2
+                hidden_states = _DownProjection.apply(
+                    y1,
+                    z,
+                    w2.permute(1, 2, 0),
+                    None,
+                    topk_scores,
+                    expert_frequency_offset,
+                    T,
+                    K,
+                    stream_id,
+                    x_gather_idx,
+                    s_scatter_idx,
+                    s_reverse_scatter_idx,
+                    num_activated_expert_per_token_offset,
+                    is_varlen_K,
+                    activation_type,
+                )
         else:
             hidden_states = FusionMoePyLayer.apply(
                 dispatched_hidden_states,
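
Aside for context (not part of the PR; `freq`, `offset`, and `gather_idx` are illustrative names, and the built-ins below are toy stand-ins for the fused kernels): the router metadata consumed by the grouped GEMMs above boils down to per-expert token counts, their prefix sums, and a permutation that makes tokens expert-contiguous. A minimal Paddle sketch of those three quantities:

import paddle

num_experts = 4
# toy top-1 expert assignment per token
dispatched_indices = paddle.to_tensor([2, 0, 2, 1, 0])

# per-expert token counts, analogous to expert_frequency
freq = paddle.bincount(dispatched_indices, minlength=num_experts)

# exclusive prefix sum: where each expert's tokens start after grouping,
# analogous to expert_frequency_offset (count_cumsum with do_cumsum=True)
offset = paddle.concat([paddle.zeros([1], dtype=freq.dtype), paddle.cumsum(freq)])

# permutation that gathers tokens into expert-contiguous order,
# analogous to x_gather_idx
gather_idx = paddle.argsort(dispatched_indices)

print(freq.tolist())    # [2, 1, 2, 0]
print(offset.tolist())  # [0, 2, 3, 5, 5]

With tokens grouped this way, each expert's rows are contiguous, which is what lets _UpProjection and _DownProjection run one grouped GEMM per device instead of a loop over experts.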