From 607a99f87a2e493f46115b7219830e298f6e03e3 Mon Sep 17 00:00:00 2001
From: sevenan2 <65183308+sevenan2@users.noreply.github.com>
Date: Wed, 18 Mar 2026 21:39:46 +0800
Subject: [PATCH] fix auto-parallel (#587)

---
 src/paddlefleet/models/common/embeddings/rope_utils.py | 10 ++++++++--
 src/paddlefleet/tensor_parallel/layers.py              |  9 +++++++--
 src/paddlefleet/transformer/attention.py               |  2 +-
 src/paddlefleet/utils.py                               |  4 ++--
 4 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/src/paddlefleet/models/common/embeddings/rope_utils.py b/src/paddlefleet/models/common/embeddings/rope_utils.py
index 1dce7d39e..23046db94 100644
--- a/src/paddlefleet/models/common/embeddings/rope_utils.py
+++ b/src/paddlefleet/models/common/embeddings/rope_utils.py
@@ -117,7 +117,6 @@ def _apply_rotary_pos_emb_bshd_fp32(
     with paddle.amp.auto_cast(False):
         orig_t_dtype = t.dtype
         t = t.astype(dtype="float32")
-        t_pass = t_pass.astype(dtype="float32")
         rotate_t = _rotate_half(t, rotary_interleaved)
         cos_ = (paddle.cos(freqs) * mscale).to(t.dtype)
         sin_ = (paddle.sin(freqs) * mscale).to(t.dtype)
@@ -131,7 +130,14 @@ def _apply_rotary_pos_emb_bshd_fp32(
             rotate_t.reshape_(t.shape)
         t = (t * cos_) + (rotate_t * sin_)
 
-    return paddle.cat((t, t_pass), axis=-1).astype(orig_t_dtype)
+    skip_t_pass = t_pass.shape[-1] == 0
+    if not skip_t_pass:
+        t_pass = t_pass.astype(dtype="float32")
+        res = paddle.cat((t, t_pass), axis=-1).astype(orig_t_dtype)
+    else:
+        res = t.astype(orig_t_dtype)
+
+    return res
 
 
 def _apply_rotary_pos_emb_bshd(
diff --git a/src/paddlefleet/tensor_parallel/layers.py b/src/paddlefleet/tensor_parallel/layers.py
index f932ea111..99ad3e5c6 100644
--- a/src/paddlefleet/tensor_parallel/layers.py
+++ b/src/paddlefleet/tensor_parallel/layers.py
@@ -1066,7 +1066,12 @@ def forward(
             )
         else:
             output = output_parallel
-        output_bias = self.bias if self.skip_bias_add else None
+        output_bias = (
+            self.bias.clone()
+            if (self.skip_bias_add and self.bias is not None)
+            else None
+        )
+
         return output, output_bias
 
     def sharded_state_dict(
@@ -1320,7 +1325,7 @@ def forward(self, input_):
             output_bias = None
         else:
             output = output_
-            output_bias = self.bias
+            output_bias = self.bias.clone() if self.bias is not None else None
         return output, output_bias
 
     def sharded_state_dict(
diff --git a/src/paddlefleet/transformer/attention.py b/src/paddlefleet/transformer/attention.py
index 7d6e526da..aa5ffccc1 100644
--- a/src/paddlefleet/transformer/attention.py
+++ b/src/paddlefleet/transformer/attention.py
@@ -495,7 +495,7 @@ def get_query_key_value_tensors(
                 * self.hidden_size_per_attention_head
             ),
         )
-        mixed_qkv = mixed_qkv.view(*new_tensor_shape)
+        mixed_qkv = mixed_qkv.reshape(*new_tensor_shape)
 
         split_arg_list = [
             (
diff --git a/src/paddlefleet/utils.py b/src/paddlefleet/utils.py
index 617086270..a41def207 100644
--- a/src/paddlefleet/utils.py
+++ b/src/paddlefleet/utils.py
@@ -231,10 +231,10 @@ def prepare_input_tensors_for_wgrad_compute(grad_output, all_gathered_input):
         all_gathered_input = all_gathered_input.contiguous()
     # Convert the tensor shapes to 2D for execution compatibility
     if grad_output.dim() == 3:
-        grad_output = grad_output.view(
+        grad_output = grad_output.reshape(
             [grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]]
         )
-        all_gathered_input = all_gathered_input.view(
+        all_gathered_input = all_gathered_input.reshape(
             [
                 all_gathered_input.shape[0] * all_gathered_input.shape[1],
                 all_gathered_input.shape[2],
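
Notes (illustrative, not part of the commit message):

The rope_utils.py hunks guard the tail concat: when the rotary embedding
spans the full head dim, the pass-through slice t_pass has a trailing
dimension of 0, and casting and concatenating that empty tensor can trip
shape inference under auto-parallel. A minimal sketch of the same guard
outside the codebase, assuming plain dense Paddle tensors; the standalone
helper name concat_rotary_tail is made up here, and paddle.concat (the
stock Paddle API) stands in for the patch's paddle.cat:

    import paddle

    def concat_rotary_tail(t, t_pass, orig_dtype):
        # When the rotary dim covers the whole head dim, t_pass is
        # empty (last dim == 0); skip the cast and concat entirely.
        if t_pass.shape[-1] == 0:
            return t.astype(orig_dtype)
        t_pass = t_pass.astype("float32")
        return paddle.concat([t, t_pass], axis=-1).astype(orig_dtype)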
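
The two layers.py hunks return self.bias.clone() instead of the parameter
itself when the bias add is deferred to the caller, and additionally guard
against bias being None. Handing out the raw parameter lets downstream
code (e.g. a fused bias+activation step) mutate it in place through the
returned handle, which is unsafe once the parameter is a distributed
tensor. A toy sketch of the pattern under those assumptions; the class
SkipBiasLinear is hypothetical, not a layer from the repo:

    import paddle

    class SkipBiasLinear(paddle.nn.Layer):
        def __init__(self, in_f, out_f, skip_bias_add=True):
            super().__init__()
            self.weight = self.create_parameter([in_f, out_f])
            self.bias = self.create_parameter([out_f], is_bias=True)
            self.skip_bias_add = skip_bias_add

        def forward(self, x):
            out = paddle.matmul(x, self.weight)
            if not self.skip_bias_add and self.bias is not None:
                out = out + self.bias
            # Clone so the caller cannot mutate the parameter through
            # the returned handle; also handle bias=None.
            out_bias = (
                self.bias.clone()
                if (self.skip_bias_add and self.bias is not None)
                else None
            )
            return out, out_bias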
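
The attention.py and utils.py hunks swap Tensor.view for Tensor.reshape.
view requires a compatible in-memory layout and returns an alias into the
same storage, which the distributed tensors used by auto-parallel may not
support; reshape accepts any layout and is free to copy. A sketch of the
utils.py flattening under that assumption, with a made-up helper name
flatten_leading_dims and dense tensors:

    import paddle

    def flatten_leading_dims(x):
        # Collapse [s, b, h] -> [s * b, h] the way
        # prepare_input_tensors_for_wgrad_compute does, using reshape
        # so non-contiguous or distributed inputs also work.
        if x.dim() == 3:
            x = x.reshape([x.shape[0] * x.shape[1], x.shape[2]])
        return x

    x = paddle.randn([4, 2, 8])
    assert flatten_leading_dims(x).shape == [4 * 2, 8]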