diff --git a/src/paddlefleet/models/common/language_loss/language_loss.py b/src/paddlefleet/models/common/language_loss/language_loss.py index 70aeb7ee2..bf360a393 100644 --- a/src/paddlefleet/models/common/language_loss/language_loss.py +++ b/src/paddlefleet/models/common/language_loss/language_loss.py @@ -248,8 +248,20 @@ def forward(self, logits: Tensor | list, labels: Tensor) -> Tensor: ) assert len(logits) == self.config.num_nextn_predict_layers + 1 labels_ori = labels - lm_labels = labels[:, : -self.config.num_nextn_predict_layers] - seq_length = lm_labels.shape[1] + if self.config.context_parallel_size > 1: + label_list = paddle.split(labels, self.config.num_nextn_predict_layers + 1) + lm_labels = label_list[0] + mtp_labels = label_list[1:] + seq_length = lm_labels.shape[1] + else: + lm_labels = labels[:, : -self.config.num_nextn_predict_layers] + mtp_labels = [] + seq_length = lm_labels.shape[1] + for depth in range(self.config.num_nextn_predict_layers): + labels_cur_depth = labels_ori[ + :, (depth + 1) : (depth + 1 + seq_length) + ] + mtp_labels.append(labels_cur_depth) mtp_loss = [] mtp_logits = logits[1:] @@ -262,9 +274,7 @@ def forward(self, logits: Tensor | list, labels: Tensor) -> Tensor: for depth in range(self.config.num_nextn_predict_layers): logits_cur_depth = mtp_logits[depth] - labels_cur_depth = labels_ori[ - :, (depth + 1) : (depth + 1 + seq_length) - ] + labels_cur_depth = mtp_labels[depth] loss_cur_depth = self._forward( logits_cur_depth, labels_cur_depth, @@ -298,9 +308,7 @@ def padding(tensor, left=False, pad_len=1): ): for depth in range(len(mtp_logits)): prediction_scores_cur_depth = mtp_logits[depth] - labels_cur_depth = labels_ori[ - :, (depth + 1) : (depth + 1 + seq_length) - ] + labels_cur_depth = mtp_labels[depth] lossmask = ( labels_cur_depth != self.ignored_index ).cast(paddle.float32) diff --git a/src/paddlefleet/models/gpt/gpt_embedding.py b/src/paddlefleet/models/gpt/gpt_embedding.py index 11f17427c..4bea04da2 100644 --- a/src/paddlefleet/models/gpt/gpt_embedding.py +++ b/src/paddlefleet/models/gpt/gpt_embedding.py @@ -142,13 +142,24 @@ def forward( assert not self.multimodal_embedding, ( "MTP not support mm for now." ) - inputs_embeds_extra = decoder_input[ - :, -self.config.num_nextn_predict_layers :, : - ] # [B, S, H] - inputs_embeds = decoder_input[ - :, : -self.config.num_nextn_predict_layers, : - ] - inputs_embeds_ori = inputs_embeds + if self.config.context_parallel_size > 1: + # when mtp and cp are opened at the same time, + # shape of decoder_input is [(K + 1)*B, S, H] + # K is the number of num_nextn_predict_layers + tensor_list = paddle.split(decoder_input, self.config.num_nextn_predict_layers+1) + inputs_embeds = tensor_list[0] # [B, S, H] + inputs_embeds_extra = tensor_list[1:] # K * [B, S, H] + else: + # when just mtp is opened, + # shape of decoder_input is [B, S + K, H] + # K is the number of num_nextn_predict_layers + inputs_embeds_extra = decoder_input[ + :, -self.config.num_nextn_predict_layers :, : + ] # [B, k, H] + inputs_embeds = decoder_input[ + :, : -self.config.num_nextn_predict_layers, : + ] # [B, S, H] + inputs_embeds_ori = inputs_embeds batch_size, seq_length, hidden_size = inputs_embeds.shape if self.sequence_parallel: @@ -163,13 +174,16 @@ def forward( ) # change to [S, B, H] mtp_emb_res = [inputs_embeds] for depth in range(self.config.num_nextn_predict_layers): - inputs_embeds_mtp = paddle.concat( - [ - inputs_embeds_ori[:, (depth + 1) :, :], - inputs_embeds_extra[:, : (depth + 1), :], - ], - axis=1, - ) + if self.config.context_parallel_size > 1: + inputs_embeds_mtp = inputs_embeds_extra[depth] + else: + inputs_embeds_mtp = paddle.concat( + [ + inputs_embeds_ori[:, (depth + 1) :, :], + inputs_embeds_extra[:, : (depth + 1), :], + ], + axis=1, + ) if self.sequence_parallel: inputs_embeds_mtp = inputs_embeds_mtp.reshape( [-1, inputs_embeds_mtp.shape[-1]] diff --git a/src/paddlefleet/transformer/transformer_layer.py b/src/paddlefleet/transformer/transformer_layer.py index ca921aad8..dd12a8df8 100644 --- a/src/paddlefleet/transformer/transformer_layer.py +++ b/src/paddlefleet/transformer/transformer_layer.py @@ -412,12 +412,17 @@ def forward( # process position_ids if "position_ids" in dict_args.keys(): position_ids = dict_args["position_ids"] - decoder_ids = position_ids[ - :, : -self.config.num_nextn_predict_layers - ] - mtp_ids = position_ids[ - :, -self.config.num_nextn_predict_layers : - ] + if self.config.context_parallel_size > 1: + tensor_list = paddle.split(position_ids, self.config.num_nextn_predict_layers + 1) + decoder_ids = tensor_list[0] + mtp_ids = tensor_list[1:] + else: + decoder_ids = position_ids[ + :, : -self.config.num_nextn_predict_layers + ] + mtp_ids = position_ids[ + :, -self.config.num_nextn_predict_layers : + ] dict_args["position_ids"] = decoder_ids # #process attn_mask_startend_row_indices @@ -500,9 +505,14 @@ def forward( rst["hidden_states"] = hidden_states_concat if "position_ids" in dict_args.keys(): - position_ids = paddle.concat( - [dict_args["position_ids"], mtp_ids], axis=1 - ) + if self.config.context_parallel_size > 1: + position_ids = paddle.concat( + [dict_args["position_ids"], *mtp_ids], axis=0 + ) + else: + position_ids = paddle.concat( + [dict_args["position_ids"], mtp_ids], axis=1 + ) dict_args["position_ids"] = position_ids if "attn_mask_startend_row_indices" in dict_args.keys(): diff --git a/src/paddlefleet/utils.py b/src/paddlefleet/utils.py index a41def207..110a979e4 100644 --- a/src/paddlefleet/utils.py +++ b/src/paddlefleet/utils.py @@ -267,7 +267,7 @@ def is_paddle_min_version(version, check_equality=True): ######################## -def get_batch_on_this_cp_rank(inputs): +def get_batch_on_this_cp_rank(inputs, num_nextn_predict_layers=0): if isinstance(inputs, paddle.Tensor): return ContextParallelScatterOp.apply(inputs, axis=-1) elif isinstance(inputs, dict): @@ -275,7 +275,14 @@ def get_batch_on_this_cp_rank(inputs): keys = ["input_ids", "position_ids", "labels"] for k, tensor in inputs.items(): if k in keys: - res[k] = ContextParallelScatterOp.apply(tensor, axis=-1) + seq_len = tensor.shape[-1] + chunk_size = seq_len - num_nextn_predict_layers + res[k] = [] + for i in range(num_nextn_predict_layers+1): + tensor_chunk = tensor[:, i : i+chunk_size] + res[k].append(ContextParallelScatterOp.apply(tensor_chunk, axis=-1)) + # tensor shape = [(k+1)*b, s] + res[k] = paddle.concat(res[k]) else: res[k] = tensor elif isinstance(inputs, list):