Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/paddlefleet/models/common/embeddings/rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def _apply_rotary_pos_emb_bshd_fp32(
with paddle.amp.auto_cast(False):
orig_t_dtype = t.dtype
t = t.astype(dtype="float32")
t_pass = t_pass.astype(dtype="float32")
rotate_t = _rotate_half(t, rotary_interleaved)
cos_ = (paddle.cos(freqs) * mscale).to(t.dtype)
sin_ = (paddle.sin(freqs) * mscale).to(t.dtype)
Expand All @@ -131,7 +130,14 @@ def _apply_rotary_pos_emb_bshd_fp32(
rotate_t.reshape_(t.shape)

t = (t * cos_) + (rotate_t * sin_)
return paddle.cat((t, t_pass), axis=-1).astype(orig_t_dtype)
skip_t_pass = t_pass.shape[-1] == 0
if not skip_t_pass:
t_pass = t_pass.astype(dtype="float32")
res = paddle.cat((t, t_pass), axis=-1).astype(orig_t_dtype)
else:
res = t.astype(orig_t_dtype)

return res


def _apply_rotary_pos_emb_bshd(
Expand Down
9 changes: 7 additions & 2 deletions src/paddlefleet/tensor_parallel/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,12 @@ def forward(
)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
output_bias = (
self.bias.clone()
if (self.skip_bias_add and self.bias is not None)
else None
)

return output, output_bias

def sharded_state_dict(
Expand Down Expand Up @@ -1320,7 +1325,7 @@ def forward(self, input_):
output_bias = None
else:
output = output_
output_bias = self.bias
output_bias = self.bias.clone() if self.bias is not None else None
return output, output_bias

def sharded_state_dict(
Expand Down
2 changes: 1 addition & 1 deletion src/paddlefleet/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def get_query_key_value_tensors(
* self.hidden_size_per_attention_head
),
)
mixed_qkv = mixed_qkv.view(*new_tensor_shape)
mixed_qkv = mixed_qkv.reshape(*new_tensor_shape)

split_arg_list = [
(
Expand Down
4 changes: 2 additions & 2 deletions src/paddlefleet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,10 @@ def prepare_input_tensors_for_wgrad_compute(grad_output, all_gathered_input):
all_gathered_input = all_gathered_input.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if grad_output.dim() == 3:
grad_output = grad_output.view(
grad_output = grad_output.reshape(
[grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]]
)
all_gathered_input = all_gathered_input.view(
all_gathered_input = all_gathered_input.reshape(
[
all_gathered_input.shape[0] * all_gathered_input.shape[1],
all_gathered_input.shape[2],
Expand Down
Loading