From 2983356007ad5e66a58170f30e96c8416ab79090 Mon Sep 17 00:00:00 2001 From: Lcysabcu <19556234170@163.com> Date: Wed, 1 Apr 2026 13:45:40 +0800 Subject: [PATCH 1/2] fix moe lora gemm --- src/paddlefleet/models/gpt/gpt_embedding.py | 3 + src/paddlefleet/transformer/moe/fp8_utils.py | 94 ++++++++++++++++++-- 2 files changed, 90 insertions(+), 7 deletions(-) diff --git a/src/paddlefleet/models/gpt/gpt_embedding.py b/src/paddlefleet/models/gpt/gpt_embedding.py index 11f17427c..01c0575e2 100644 --- a/src/paddlefleet/models/gpt/gpt_embedding.py +++ b/src/paddlefleet/models/gpt/gpt_embedding.py @@ -344,6 +344,9 @@ def forward( [1, 0, 2, 3] ).contiguous() + if paddle.core._has_grad(): + decoder_input.stop_gradient = False #Prevent errors in recompute_pylayer during LoRA training caused by base_weight lacking gradients. + preproc_output = { "hidden_states": decoder_input, "attention_mask": attention_mask, diff --git a/src/paddlefleet/transformer/moe/fp8_utils.py b/src/paddlefleet/transformer/moe/fp8_utils.py index 98fd35503..8398d06be 100644 --- a/src/paddlefleet/transformer/moe/fp8_utils.py +++ b/src/paddlefleet/transformer/moe/fp8_utils.py @@ -1176,6 +1176,49 @@ def backward(self, out_grad, unzipped_probs, a2a_async_fn=None): probs_grad = paddle.concat(probs_grad, axis=0) return out_grad, probs_grad + def _lora_weight_grad(self, dw, lora_A, lora_B, scaling, grad_attr="main_grad"): + """ + Given dw (gradient w.r.t. effective weight = w + lora_A @ lora_B * scaling), + compute and accumulate gradients for lora_A and lora_B. + dw shape: [E, in_features, out_features] + lora_A: [E, in_features, r] + lora_B: [E, r, out_features] + d_lora_B = lora_A.transpose(1,2) @ dw * scaling -> [E, r, out_features] + d_lora_A = dw @ lora_B.transpose(1,2) * scaling -> [E, in_features, r] + """ + dw_f32 = dw.cast("float32") + # d_lora_B: [E, r, out] = [E, r, in] @ [E, in, out] + d_lora_B = paddle.bmm(lora_A.cast("float32").transpose([0, 2, 1]), dw_f32) * scaling + # d_lora_A: [E, in, r] = [E, in, out] @ [E, out, r] + d_lora_A = paddle.bmm(dw_f32, lora_B.cast("float32").transpose([0, 2, 1])) * scaling + + if not hasattr(self, "_lora_grad_log_count"): + self._lora_grad_log_count = 0 + if self._lora_grad_log_count < 3: + self._lora_grad_log_count += 1 + import logging as _logging + _log = _logging.getLogger(__name__) + _log.info( + f"[LORA GRAD EP] step={self._lora_grad_log_count}: " + f"dw norm={float(dw_f32.norm()):.6f} " + f"d_lora_A norm={float(d_lora_A.norm()):.6f} amax={float(d_lora_A.abs().max()):.6f} " + f"d_lora_B norm={float(d_lora_B.norm()):.6f} amax={float(d_lora_B.abs().max()):.6f}" + ) + + def _accumulate(param, dgrad): + dgrad = dgrad.cast(param.dtype) + if hasattr(param, "main_grad"): + if param.main_grad is None: + param.main_grad = paddle.zeros(param.shape, dtype=paddle.float32) + param.main_grad.add_(dgrad.cast(paddle.float32)) + else: + if param.grad is None: + param.grad = paddle.zeros(param.shape, dtype=paddle.float32) + param.grad.add_(dgrad.cast(paddle.float32)) + + _accumulate(lora_A, d_lora_A) + _accumulate(lora_B, d_lora_B) + def backward_impl_bf16(self, out_grad, unzipped_probs, a2a_async_fn=None): """ backward_impl_bf16 @@ -1184,9 +1227,22 @@ def backward_impl_bf16(self, out_grad, unzipped_probs, a2a_async_fn=None): raise NotImplementedError( "bf16 fuse node do not support a2a_async_fn currently" ) + # Detect LoRA on grouped_gemm_experts + _ge = getattr(self, "grouped_gemm_experts", None) if self.moe_grouped_gemm else None + _has_lora = ( + _ge is not None + and hasattr(_ge, 
"get_delta_weight") + and not getattr(_ge, "disable_lora", False) + and not getattr(_ge, "merged", False) + ) + if self.moe_grouped_gemm: - expert_w2 = self.grouped_gemm_experts.weight2 - expert_w1 = self.grouped_gemm_experts.weight1 + if _has_lora: + expert_w1 = _ge.weight1 + _ge.get_delta_weight(_ge.weight1_lora_A, _ge.weight1_lora_B) + expert_w2 = _ge.weight2 + _ge.get_delta_weight(_ge.weight2_lora_A, _ge.weight2_lora_B) + else: + expert_w1 = self.grouped_gemm_experts.weight1 + expert_w2 = self.grouped_gemm_experts.weight2 else: expert_w2 = [ x.down_proj.weight for x in self.experts if x is not None @@ -1207,12 +1263,36 @@ def backward_impl_bf16(self, out_grad, unzipped_probs, a2a_async_fn=None): del o1 self.o1 = None - # dw1 - self.bf16_weight_grad(do1, self.input, expert_w1) - self.input = None + # dw1 / lora grads for w1 + if _has_lora and self.moe_grouped_gemm: + # compute dw_eff into a temporary tensor instead of accumulating to frozen weight + if self.input is not None: + _input = self.input + elif self.dequant_input and self.input_fp8 is not None: + _input = paddle.incubate.nn.functional.fused_act_dequant( + self.input_fp8, self.input_scale + ) + else: + _input = None + if _input is not None and _input.shape[0] > 0: + dw1 = paddle.incubate.nn.functional.batched_gemm( + _input, do1, self.tokens_per_expert, trans_lhs=True + ) + self._lora_weight_grad(dw1, _ge.weight1_lora_A, _ge.weight1_lora_B, _ge.scaling) + self.input = None + else: + self.bf16_weight_grad(do1, self.input, expert_w1) + self.input = None - # dw2 - self.bf16_weight_grad(out_grad, o2_s, expert_w2) + # dw2 / lora grads for w2 + if _has_lora and self.moe_grouped_gemm: + if o2_s is not None and o2_s.shape[0] > 0: + dw2 = paddle.incubate.nn.functional.batched_gemm( + o2_s, out_grad, self.tokens_per_expert, trans_lhs=True + ) + self._lora_weight_grad(dw2, _ge.weight2_lora_A, _ge.weight2_lora_B, _ge.scaling) + else: + self.bf16_weight_grad(out_grad, o2_s, expert_w2) # dx dx = self.bwd_gate_up_input_bf16(do1, expert_w1) From aa7e913fff18c6fe3745cfe7279b84e6fd1930e6 Mon Sep 17 00:00:00 2001 From: Lcysabcu <19556234170@163.com> Date: Wed, 1 Apr 2026 17:32:07 +0800 Subject: [PATCH 2/2] fix codestyle --- src/paddlefleet/models/gpt/gpt_embedding.py | 2 +- src/paddlefleet/transformer/moe/fp8_utils.py | 41 +++++++++++++++----- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/paddlefleet/models/gpt/gpt_embedding.py b/src/paddlefleet/models/gpt/gpt_embedding.py index 01c0575e2..bc0291d34 100644 --- a/src/paddlefleet/models/gpt/gpt_embedding.py +++ b/src/paddlefleet/models/gpt/gpt_embedding.py @@ -345,7 +345,7 @@ def forward( ).contiguous() if paddle.core._has_grad(): - decoder_input.stop_gradient = False #Prevent errors in recompute_pylayer during LoRA training caused by base_weight lacking gradients. + decoder_input.stop_gradient = False # Prevent errors in recompute_pylayer during LoRA training caused by base_weight lacking gradients. 
preproc_output = { "hidden_states": decoder_input, diff --git a/src/paddlefleet/transformer/moe/fp8_utils.py b/src/paddlefleet/transformer/moe/fp8_utils.py index 8398d06be..d324883ac 100644 --- a/src/paddlefleet/transformer/moe/fp8_utils.py +++ b/src/paddlefleet/transformer/moe/fp8_utils.py @@ -1176,7 +1176,9 @@ def backward(self, out_grad, unzipped_probs, a2a_async_fn=None): probs_grad = paddle.concat(probs_grad, axis=0) return out_grad, probs_grad - def _lora_weight_grad(self, dw, lora_A, lora_B, scaling, grad_attr="main_grad"): + def _lora_weight_grad( + self, dw, lora_A, lora_B, scaling, grad_attr="main_grad" + ): """ Given dw (gradient w.r.t. effective weight = w + lora_A @ lora_B * scaling), compute and accumulate gradients for lora_A and lora_B. @@ -1188,15 +1190,22 @@ def _lora_weight_grad(self, dw, lora_A, lora_B, scaling, grad_attr="main_grad"): """ dw_f32 = dw.cast("float32") # d_lora_B: [E, r, out] = [E, r, in] @ [E, in, out] - d_lora_B = paddle.bmm(lora_A.cast("float32").transpose([0, 2, 1]), dw_f32) * scaling + d_lora_B = ( + paddle.bmm(lora_A.cast("float32").transpose([0, 2, 1]), dw_f32) + * scaling + ) # d_lora_A: [E, in, r] = [E, in, out] @ [E, out, r] - d_lora_A = paddle.bmm(dw_f32, lora_B.cast("float32").transpose([0, 2, 1])) * scaling + d_lora_A = ( + paddle.bmm(dw_f32, lora_B.cast("float32").transpose([0, 2, 1])) + * scaling + ) if not hasattr(self, "_lora_grad_log_count"): self._lora_grad_log_count = 0 if self._lora_grad_log_count < 3: self._lora_grad_log_count += 1 import logging as _logging + _log = _logging.getLogger(__name__) _log.info( f"[LORA GRAD EP] step={self._lora_grad_log_count}: " @@ -1209,7 +1218,9 @@ def _accumulate(param, dgrad): dgrad = dgrad.cast(param.dtype) if hasattr(param, "main_grad"): if param.main_grad is None: - param.main_grad = paddle.zeros(param.shape, dtype=paddle.float32) + param.main_grad = paddle.zeros( + param.shape, dtype=paddle.float32 + ) param.main_grad.add_(dgrad.cast(paddle.float32)) else: if param.grad is None: @@ -1228,7 +1239,11 @@ def backward_impl_bf16(self, out_grad, unzipped_probs, a2a_async_fn=None): "bf16 fuse node do not support a2a_async_fn currently" ) # Detect LoRA on grouped_gemm_experts - _ge = getattr(self, "grouped_gemm_experts", None) if self.moe_grouped_gemm else None + _ge = ( + getattr(self, "grouped_gemm_experts", None) + if self.moe_grouped_gemm + else None + ) _has_lora = ( _ge is not None and hasattr(_ge, "get_delta_weight") @@ -1238,8 +1253,12 @@ def backward_impl_bf16(self, out_grad, unzipped_probs, a2a_async_fn=None): if self.moe_grouped_gemm: if _has_lora: - expert_w1 = _ge.weight1 + _ge.get_delta_weight(_ge.weight1_lora_A, _ge.weight1_lora_B) - expert_w2 = _ge.weight2 + _ge.get_delta_weight(_ge.weight2_lora_A, _ge.weight2_lora_B) + expert_w1 = _ge.weight1 + _ge.get_delta_weight( + _ge.weight1_lora_A, _ge.weight1_lora_B + ) + expert_w2 = _ge.weight2 + _ge.get_delta_weight( + _ge.weight2_lora_A, _ge.weight2_lora_B + ) else: expert_w1 = self.grouped_gemm_experts.weight1 expert_w2 = self.grouped_gemm_experts.weight2 @@ -1278,7 +1297,9 @@ def backward_impl_bf16(self, out_grad, unzipped_probs, a2a_async_fn=None): dw1 = paddle.incubate.nn.functional.batched_gemm( _input, do1, self.tokens_per_expert, trans_lhs=True ) - self._lora_weight_grad(dw1, _ge.weight1_lora_A, _ge.weight1_lora_B, _ge.scaling) + self._lora_weight_grad( + dw1, _ge.weight1_lora_A, _ge.weight1_lora_B, _ge.scaling + ) self.input = None else: self.bf16_weight_grad(do1, self.input, expert_w1) @@ -1290,7 +1311,9 @@ def backward_impl_bf16(self, 
out_grad, unzipped_probs, a2a_async_fn=None): dw2 = paddle.incubate.nn.functional.batched_gemm( o2_s, out_grad, self.tokens_per_expert, trans_lhs=True ) - self._lora_weight_grad(dw2, _ge.weight2_lora_A, _ge.weight2_lora_B, _ge.scaling) + self._lora_weight_grad( + dw2, _ge.weight2_lora_A, _ge.weight2_lora_B, _ge.scaling + ) else: self.bf16_weight_grad(out_grad, o2_s, expert_w2)
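
A minimal standalone sketch of the gradient decomposition that _lora_weight_grad relies on, checked against Paddle autograd with plain paddle.bmm. It follows the shapes stated in the docstring (lora_A: [E, in, r], lora_B: [E, r, out], dw: [E, in, out]) and only illustrates the math, not the fused grouped-GEMM / batched_gemm path in the patch; the sizes and scaling value below are arbitrary stand-ins for the check.

import paddle

# Shapes follow the _lora_weight_grad docstring: E experts,
# lora_A: [E, in_features, r], lora_B: [E, r, out_features],
# dw: [E, in_features, out_features]. Sizes and scaling are arbitrary.
E, in_f, r, out_f, scaling = 2, 8, 4, 6, 0.5

lora_A = paddle.randn([E, in_f, r])
lora_B = paddle.randn([E, r, out_f])
lora_A.stop_gradient = False
lora_B.stop_gradient = False

# dw stands in for the gradient w.r.t. the effective weight
# w_eff = w + lora_A @ lora_B * scaling (w itself stays frozen under LoRA).
dw = paddle.randn([E, in_f, out_f])

# Reference: let autograd differentiate the LoRA delta term.
delta = paddle.bmm(lora_A, lora_B) * scaling        # [E, in, out]
(delta * dw).sum().backward()

# Manual decomposition used by the patch.
d_lora_B = paddle.bmm(lora_A.transpose([0, 2, 1]), dw) * scaling  # [E, r, out]
d_lora_A = paddle.bmm(dw, lora_B.transpose([0, 2, 1])) * scaling  # [E, in, r]

assert paddle.allclose(d_lora_A, lora_A.grad, atol=1e-5)
assert paddle.allclose(d_lora_B, lora_B.grad, atol=1e-5)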