From 77f0d7c32dd906b79dfd84508dd0584c64d22142 Mon Sep 17 00:00:00 2001
From: Yichen Zhang
Date: Thu, 2 Apr 2026 15:35:28 +0800
Subject: [PATCH] fix dtype mismatch in block attention residuals under AMP

---
 src/paddlefleet/transformer/block_attn_res.py    | 3 +++
 src/paddlefleet/transformer/transformer_layer.py | 5 +++++
 2 files changed, 8 insertions(+)

diff --git a/src/paddlefleet/transformer/block_attn_res.py b/src/paddlefleet/transformer/block_attn_res.py
index eeef9a516..e7b547192 100644
--- a/src/paddlefleet/transformer/block_attn_res.py
+++ b/src/paddlefleet/transformer/block_attn_res.py
@@ -132,4 +132,7 @@ def forward(self, partial_block: Tensor, blocks: list[Tensor]) -> Tensor:
         # Equivalent to einsum("n b s, n b s d -> b s d", weights, V)
         h = (weights.unsqueeze(-1) * V).sum(axis=0)
 
+        if partial_block is not None and h.dtype != partial_block.dtype:
+            h = h.to(partial_block.dtype)
+
         return h
diff --git a/src/paddlefleet/transformer/transformer_layer.py b/src/paddlefleet/transformer/transformer_layer.py
index ca921aad8..a6a7ac259 100644
--- a/src/paddlefleet/transformer/transformer_layer.py
+++ b/src/paddlefleet/transformer/transformer_layer.py
@@ -568,6 +568,11 @@ def _forward_impl(
         )
 
         # Accumulate attn output into partial_block
+        if (
+            partial_block is not None
+            and partial_block.dtype != hidden_states.dtype
+        ):
+            partial_block = partial_block.to(hidden_states.dtype)
         partial_block = (
             partial_block + hidden_states
             if partial_block is not None
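
Reviewer note (not part of the patch): under Paddle AMP, the attention output is
typically produced in float16 while the running residual accumulator is kept in
float32, so the unguarded add mixes dtypes. Below is a minimal standalone sketch
of the guard applied in block_attn_res.py; the tensor shapes and the use of
paddle.rand/paddle.zeros are illustrative assumptions, not code from this
repository:

    import paddle

    # float32 residual accumulator, float16 attention output, as typically
    # produced under AMP where matmuls run in half precision
    partial_block = paddle.zeros([2, 8, 16], dtype="float32")
    h = paddle.rand([2, 8, 16]).astype("float16")

    # The guard from the patch: align dtypes before accumulating; otherwise
    # the mixed-dtype add either raises an error or promotes implicitly,
    # depending on the Paddle version.
    if partial_block is not None and h.dtype != partial_block.dtype:
        h = h.to(partial_block.dtype)

    partial_block = partial_block + h
    print(partial_block.dtype)  # paddle.float32

Casting h up to the accumulator's dtype keeps the residual sum in float32,
avoiding the loss of small per-block contributions to half-precision rounding.
The transformer_layer.py hunk aligns in the opposite direction (accumulator to
hidden_states), but in both places the operands are made to match before the add.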