Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/paddlefleet/models/gpt/gpt_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,9 @@ def forward(
[1, 0, 2, 3]
).contiguous()

if paddle.core._has_grad():
decoder_input.stop_gradient = False # Prevent errors in recompute_pylayer during LoRA training caused by base_weight lacking gradients.

preproc_output = {
"hidden_states": decoder_input,
"attention_mask": attention_mask,
Expand Down
117 changes: 110 additions & 7 deletions src/paddlefleet/transformer/moe/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,60 @@ def backward(self, out_grad, unzipped_probs, a2a_async_fn=None):
probs_grad = paddle.concat(probs_grad, axis=0)
return out_grad, probs_grad

def _lora_weight_grad(
    self, dw, lora_A, lora_B, scaling, grad_attr="main_grad"
):
    """
    Accumulate LoRA adapter gradients from the effective-weight gradient.

    Given ``dw``, the gradient w.r.t. the effective weight
    ``w + lora_A @ lora_B * scaling``, compute per-expert gradients for the
    two LoRA factors and accumulate them into each parameter's gradient
    buffer.

    Args:
        dw: ``[E, in_features, out_features]`` gradient of the effective
            weight (E = number of experts).
        lora_A: ``[E, in_features, r]`` LoRA down-projection factors.
        lora_B: ``[E, r, out_features]`` LoRA up-projection factors.
        scaling: scalar LoRA scaling factor applied to both factor grads.
        grad_attr: name of the fp32 accumulation buffer attribute on the
            parameter (default ``"main_grad"``). When the parameter does
            not carry that attribute, falls back to ``param.grad``.

    Math:
        d_lora_B = lora_A^T @ dw * scaling  -> [E, r, out_features]
        d_lora_A = dw @ lora_B^T * scaling  -> [E, in_features, r]
    """
    # Do the batched matmul reductions in fp32 to avoid bf16 precision loss.
    dw_f32 = dw.cast("float32")
    # d_lora_B: [E, r, out] = [E, r, in] @ [E, in, out]
    d_lora_B = (
        paddle.bmm(lora_A.cast("float32").transpose([0, 2, 1]), dw_f32)
        * scaling
    )
    # d_lora_A: [E, in, r] = [E, in, out] @ [E, out, r]
    d_lora_A = (
        paddle.bmm(dw_f32, lora_B.cast("float32").transpose([0, 2, 1]))
        * scaling
    )

    # Debug aid: log gradient norms for the first few backward passes only,
    # so steady-state training is not slowed by the device->host syncs that
    # float(...) forces.
    if not hasattr(self, "_lora_grad_log_count"):
        self._lora_grad_log_count = 0
    if self._lora_grad_log_count < 3:
        self._lora_grad_log_count += 1
        import logging as _logging

        _log = _logging.getLogger(__name__)
        _log.info(
            f"[LORA GRAD EP] step={self._lora_grad_log_count}: "
            f"dw norm={float(dw_f32.norm()):.6f} "
            f"d_lora_A norm={float(d_lora_A.norm()):.6f} amax={float(d_lora_A.abs().max()):.6f} "
            f"d_lora_B norm={float(d_lora_B.norm()):.6f} amax={float(d_lora_B.abs().max()):.6f}"
        )

    def _accumulate(param, dgrad):
        # Match the parameter's storage dtype first (mirrors the dense
        # weight-grad path), then accumulate into an fp32 buffer.
        dgrad = dgrad.cast(param.dtype)
        if hasattr(param, grad_attr):
            # Fused/sharded optimizers keep a dedicated fp32 accumulator
            # (selected via ``grad_attr``, e.g. ``main_grad``); create it
            # lazily on first accumulation.
            buf = getattr(param, grad_attr)
            if buf is None:
                buf = paddle.zeros(param.shape, dtype=paddle.float32)
                setattr(param, grad_attr, buf)
            buf.add_(dgrad.cast(paddle.float32))
        else:
            # Plain eager path: accumulate into ``param.grad``.
            if param.grad is None:
                param.grad = paddle.zeros(param.shape, dtype=paddle.float32)
            param.grad.add_(dgrad.cast(paddle.float32))

    _accumulate(lora_A, d_lora_A)
    _accumulate(lora_B, d_lora_B)

def backward_impl_bf16(self, out_grad, unzipped_probs, a2a_async_fn=None):
"""
backward_impl_bf16
Expand All @@ -1184,9 +1238,30 @@ def backward_impl_bf16(self, out_grad, unzipped_probs, a2a_async_fn=None):
raise NotImplementedError(
"bf16 fuse node do not support a2a_async_fn currently"
)
# Detect LoRA on grouped_gemm_experts
_ge = (
getattr(self, "grouped_gemm_experts", None)
if self.moe_grouped_gemm
else None
)
_has_lora = (
_ge is not None
and hasattr(_ge, "get_delta_weight")
and not getattr(_ge, "disable_lora", False)
and not getattr(_ge, "merged", False)
)

if self.moe_grouped_gemm:
expert_w2 = self.grouped_gemm_experts.weight2
expert_w1 = self.grouped_gemm_experts.weight1
if _has_lora:
expert_w1 = _ge.weight1 + _ge.get_delta_weight(
_ge.weight1_lora_A, _ge.weight1_lora_B
)
expert_w2 = _ge.weight2 + _ge.get_delta_weight(
_ge.weight2_lora_A, _ge.weight2_lora_B
)
else:
expert_w1 = self.grouped_gemm_experts.weight1
expert_w2 = self.grouped_gemm_experts.weight2
else:
expert_w2 = [
x.down_proj.weight for x in self.experts if x is not None
Expand All @@ -1207,12 +1282,40 @@ def backward_impl_bf16(self, out_grad, unzipped_probs, a2a_async_fn=None):
del o1
self.o1 = None

# dw1
self.bf16_weight_grad(do1, self.input, expert_w1)
self.input = None
# dw1 / lora grads for w1
if _has_lora and self.moe_grouped_gemm:
# compute dw_eff into a temporary tensor instead of accumulating to frozen weight
if self.input is not None:
_input = self.input
elif self.dequant_input and self.input_fp8 is not None:
_input = paddle.incubate.nn.functional.fused_act_dequant(
self.input_fp8, self.input_scale
)
else:
_input = None
if _input is not None and _input.shape[0] > 0:
dw1 = paddle.incubate.nn.functional.batched_gemm(
_input, do1, self.tokens_per_expert, trans_lhs=True
)
self._lora_weight_grad(
dw1, _ge.weight1_lora_A, _ge.weight1_lora_B, _ge.scaling
)
self.input = None
else:
self.bf16_weight_grad(do1, self.input, expert_w1)
self.input = None

# dw2
self.bf16_weight_grad(out_grad, o2_s, expert_w2)
# dw2 / lora grads for w2
if _has_lora and self.moe_grouped_gemm:
if o2_s is not None and o2_s.shape[0] > 0:
dw2 = paddle.incubate.nn.functional.batched_gemm(
o2_s, out_grad, self.tokens_per_expert, trans_lhs=True
)
self._lora_weight_grad(
dw2, _ge.weight2_lora_A, _ge.weight2_lora_B, _ge.scaling
)
else:
self.bf16_weight_grad(out_grad, o2_s, expert_w2)

# dx
dx = self.bwd_gate_up_input_bf16(do1, expert_w1)
Expand Down
Loading