Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/paddlefleet/models/common/language_loss/language_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,20 @@ def forward(self, logits: Tensor | list, labels: Tensor) -> Tensor:
)
assert len(logits) == self.config.num_nextn_predict_layers + 1
labels_ori = labels
lm_labels = labels[:, : -self.config.num_nextn_predict_layers]
seq_length = lm_labels.shape[1]
if self.config.context_parallel_size > 1:
label_list = paddle.split(labels, self.config.num_nextn_predict_layers + 1)
lm_labels = label_list[0]
mtp_labels = label_list[1:]
seq_length = lm_labels.shape[1]
else:
lm_labels = labels[:, : -self.config.num_nextn_predict_layers]
mtp_labels = []
seq_length = lm_labels.shape[1]
for depth in range(self.config.num_nextn_predict_layers):
labels_cur_depth = labels_ori[
:, (depth + 1) : (depth + 1 + seq_length)
]
mtp_labels.append(labels_cur_depth)

mtp_loss = []
mtp_logits = logits[1:]
Expand All @@ -262,9 +274,7 @@ def forward(self, logits: Tensor | list, labels: Tensor) -> Tensor:

for depth in range(self.config.num_nextn_predict_layers):
logits_cur_depth = mtp_logits[depth]
labels_cur_depth = labels_ori[
:, (depth + 1) : (depth + 1 + seq_length)
]
labels_cur_depth = mtp_labels[depth]
loss_cur_depth = self._forward(
logits_cur_depth,
labels_cur_depth,
Expand Down Expand Up @@ -298,9 +308,7 @@ def padding(tensor, left=False, pad_len=1):
):
for depth in range(len(mtp_logits)):
prediction_scores_cur_depth = mtp_logits[depth]
labels_cur_depth = labels_ori[
:, (depth + 1) : (depth + 1 + seq_length)
]
labels_cur_depth = mtp_labels[depth]
lossmask = (
labels_cur_depth != self.ignored_index
).cast(paddle.float32)
Expand Down
42 changes: 28 additions & 14 deletions src/paddlefleet/models/gpt/gpt_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,24 @@ def forward(
assert not self.multimodal_embedding, (
"MTP not support mm for now."
)
inputs_embeds_extra = decoder_input[
:, -self.config.num_nextn_predict_layers :, :
] # [B, S, H]
inputs_embeds = decoder_input[
:, : -self.config.num_nextn_predict_layers, :
]
inputs_embeds_ori = inputs_embeds
if self.config.context_parallel_size > 1:
# when mtp and cp are opened at the same time,
# shape of decoder_input is [(K + 1)*B, S, H]
# K is the number of num_nextn_predict_layers
tensor_list = paddle.split(decoder_input, self.config.num_nextn_predict_layers+1)
inputs_embeds = tensor_list[0] # [B, S, H]
inputs_embeds_extra = tensor_list[1:] # K * [B, S, H]
else:
# when just mtp is opened,
# shape of decoder_input is [B, S + K, H]
# K is the number of num_nextn_predict_layers
inputs_embeds_extra = decoder_input[
:, -self.config.num_nextn_predict_layers :, :
] # [B, k, H]
inputs_embeds = decoder_input[
:, : -self.config.num_nextn_predict_layers, :
] # [B, S, H]
inputs_embeds_ori = inputs_embeds
batch_size, seq_length, hidden_size = inputs_embeds.shape

if self.sequence_parallel:
Expand All @@ -163,13 +174,16 @@ def forward(
) # change to [S, B, H]
mtp_emb_res = [inputs_embeds]
for depth in range(self.config.num_nextn_predict_layers):
inputs_embeds_mtp = paddle.concat(
[
inputs_embeds_ori[:, (depth + 1) :, :],
inputs_embeds_extra[:, : (depth + 1), :],
],
axis=1,
)
if self.config.context_parallel_size > 1:
inputs_embeds_mtp = inputs_embeds_extra[depth]
else:
inputs_embeds_mtp = paddle.concat(
[
inputs_embeds_ori[:, (depth + 1) :, :],
inputs_embeds_extra[:, : (depth + 1), :],
],
axis=1,
)
if self.sequence_parallel:
inputs_embeds_mtp = inputs_embeds_mtp.reshape(
[-1, inputs_embeds_mtp.shape[-1]]
Expand Down
28 changes: 19 additions & 9 deletions src/paddlefleet/transformer/transformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,17 @@ def forward(
# process position_ids
if "position_ids" in dict_args.keys():
position_ids = dict_args["position_ids"]
decoder_ids = position_ids[
:, : -self.config.num_nextn_predict_layers
]
mtp_ids = position_ids[
:, -self.config.num_nextn_predict_layers :
]
if self.config.context_parallel_size > 1:
tensor_list = paddle.split(position_ids, self.config.num_nextn_predict_layers + 1)
decoder_ids = tensor_list[0]
mtp_ids = tensor_list[1:]
else:
decoder_ids = position_ids[
:, : -self.config.num_nextn_predict_layers
]
mtp_ids = position_ids[
:, -self.config.num_nextn_predict_layers :
]
dict_args["position_ids"] = decoder_ids

# #process attn_mask_startend_row_indices
Expand Down Expand Up @@ -500,9 +505,14 @@ def forward(
rst["hidden_states"] = hidden_states_concat

if "position_ids" in dict_args.keys():
position_ids = paddle.concat(
[dict_args["position_ids"], mtp_ids], axis=1
)
if self.config.context_parallel_size > 1:
position_ids = paddle.concat(
[dict_args["position_ids"], *mtp_ids], axis=0
)
else:
position_ids = paddle.concat(
[dict_args["position_ids"], mtp_ids], axis=1
)
dict_args["position_ids"] = position_ids

if "attn_mask_startend_row_indices" in dict_args.keys():
Expand Down
11 changes: 9 additions & 2 deletions src/paddlefleet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,22 @@ def is_paddle_min_version(version, check_equality=True):
########################


def get_batch_on_this_cp_rank(inputs):
def get_batch_on_this_cp_rank(inputs, num_nextn_predict_layers=0):
if isinstance(inputs, paddle.Tensor):
return ContextParallelScatterOp.apply(inputs, axis=-1)
elif isinstance(inputs, dict):
res = {}
keys = ["input_ids", "position_ids", "labels"]
for k, tensor in inputs.items():
if k in keys:
res[k] = ContextParallelScatterOp.apply(tensor, axis=-1)
seq_len = tensor.shape[-1]
chunk_size = seq_len - num_nextn_predict_layers
res[k] = []
for i in range(num_nextn_predict_layers+1):
tensor_chunk = tensor[:, i : i+chunk_size]
res[k].append(ContextParallelScatterOp.apply(tensor_chunk, axis=-1))
# tensor shape = [(k+1)*b, s]
res[k] = paddle.concat(res[k])
else:
res[k] = tensor
elif isinstance(inputs, list):
Expand Down
Loading