diff --git a/src/paddlefleet/models/gpt/gpt_embedding.py b/src/paddlefleet/models/gpt/gpt_embedding.py
index 11f17427c..ae8a6e4eb 100644
--- a/src/paddlefleet/models/gpt/gpt_embedding.py
+++ b/src/paddlefleet/models/gpt/gpt_embedding.py
@@ -344,6 +344,77 @@ def forward(
                 [1, 0, 2, 3]
             ).contiguous()
 
+        # Precompute visual scatter indices for _deepstack_process.
+        # These indices are layer-independent and would otherwise be
+        # redundantly computed in every deepstack layer.
+        visual_update_indices = None
+        visual_gather_indices = None
+        if visual_pos_masks is not None:
+            if self.sequence_parallel:
+                try:
+                    from paddle.distributed.fleet import (
+                        get_hybrid_communicate_group,
+                    )
+
+                    hcg = get_hybrid_communicate_group()
+                    mp_rank = hcg.get_model_parallel_rank()
+                    mp_size = hcg.get_model_parallel_world_size()
+                except (ImportError, AttributeError):
+                    batch_size = visual_pos_masks.shape[0]
+                    full_seq_len = visual_pos_masks.shape[1]
+                    # decoder_input is already [S/tp, B, H] after SP scatter
+                    local_seq = decoder_input.shape[0]
+                    mp_size = (batch_size * full_seq_len) // (
+                        local_seq * batch_size
+                    )
+                    mp_rank = paddle.distributed.get_rank() % mp_size
+
+                full_seq_len = visual_pos_masks.shape[1]
+                chunk_s = full_seq_len // mp_size
+                start_s = mp_rank * chunk_s
+
+                local_mask = visual_pos_masks[:, start_s : start_s + chunk_s]
+                batch_size = visual_pos_masks.shape[0]
+
+                per_sample_total = paddle.cast(visual_pos_masks, "int32").sum(
+                    axis=1
+                )
+                per_sample_pre = (
+                    paddle.cast(visual_pos_masks[:, :start_s], "int32").sum(
+                        axis=1
+                    )
+                    if start_s > 0
+                    else paddle.zeros([batch_size], dtype="int32")
+                )
+
+                gather_indices_list = []
+                cumulative_total = 0
+                for i in range(batch_size):
+                    total_i = int(per_sample_total[i].item())
+                    pre_i = int(per_sample_pre[i].item())
+                    count_i = int(
+                        paddle.cast(local_mask[i], "int32").sum().item()
+                    )
+                    if count_i > 0:
+                        gather_indices_list.append(
+                            paddle.arange(
+                                cumulative_total + pre_i,
+                                cumulative_total + pre_i + count_i,
+                            )
+                        )
+                    cumulative_total += total_i
+
+                if gather_indices_list:
+                    visual_gather_indices = paddle.concat(gather_indices_list)
+                else:
+                    visual_gather_indices = paddle.zeros([0], dtype="int64")
+
+                visual_update_indices = paddle.nonzero(local_mask.flatten())
+            else:
+                visual_update_indices = paddle.nonzero(
+                    visual_pos_masks.flatten()
+                )
+
         preproc_output = {
             "hidden_states": decoder_input,
             "attention_mask": attention_mask,
@@ -354,6 +425,8 @@ def forward(
             "position_ids": position_ids,
             "deepstack_visual_emb": deepstack_visual_embeds,
             "visual_pos_masks": visual_pos_masks,
+            "visual_update_indices": visual_update_indices,
+            "visual_gather_indices": visual_gather_indices,
         }
         if mtp_emb_res is not None:
             assert (