From 70904ce22fc5b55365dc637d9b2dc2187978bab9 Mon Sep 17 00:00:00 2001 From: huangjiyi <947613776@qq.com> Date: Wed, 1 Apr 2026 11:33:54 +0800 Subject: [PATCH] [Qwen3_VL] use precomputed visual scatter indices in deepstack process --- .../transformers/qwen3_vl/modeling_fleet.py | 153 ++++++++++-------- 1 file changed, 86 insertions(+), 67 deletions(-) diff --git a/paddleformers/transformers/qwen3_vl/modeling_fleet.py b/paddleformers/transformers/qwen3_vl/modeling_fleet.py index 689315470ee..02f154fbf6e 100644 --- a/paddleformers/transformers/qwen3_vl/modeling_fleet.py +++ b/paddleformers/transformers/qwen3_vl/modeling_fleet.py @@ -148,6 +148,8 @@ def forward( dict_args.pop("position_ids", None) deepstack_visual_emb = dict_args.get("deepstack_visual_emb", None) visual_pos_masks = dict_args.get("visual_pos_masks", None) + visual_update_indices = dict_args.get("visual_update_indices", None) + visual_gather_indices = dict_args.get("visual_gather_indices", None) if self.full_recompute: hidden_states = dict_args["hidden_states"] @@ -191,8 +193,10 @@ def forward( if deepstack_visual_emb and self.layer_number in range(len(deepstack_visual_emb)): output = self._deepstack_process( hidden_states=output, - visual_embeds=deepstack_visual_emb[self.layer_number], visual_pos_masks=visual_pos_masks, + visual_embeds=deepstack_visual_emb[self.layer_number], + visual_update_indices=visual_update_indices, + visual_gather_indices=visual_gather_indices, ) rst = OrderedDict() @@ -241,7 +245,12 @@ def _forward_impl( return hidden_states def _deepstack_process( - self, hidden_states: paddle.Tensor, visual_pos_masks: paddle.Tensor, visual_embeds: paddle.Tensor + self, + hidden_states: paddle.Tensor, + visual_pos_masks: paddle.Tensor, + visual_embeds: paddle.Tensor, + visual_update_indices: paddle.Tensor | None = None, + visual_gather_indices: paddle.Tensor | None = None, ): # SP layout is [S/tp, B, H] (seq-first); transpose to [B, S/tp, H] so that # flatten(0,1) produces batch-first [B*S/tp, H], consistent with visual_pos_masks [B, S]. @@ -257,67 +266,80 @@ def _deepstack_process( visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) - # Sequence Parallelism (SP) row slicing. - # visual_pos_masks is [B, S] (full sequence), hidden_states is [B*S/tp, H] - # (batch-major after transpose+flatten). We must slice along the S dimension - # (dim=1) to match the batch-major layout, NOT flatten-then-chunk which - # breaks when B > 1. - if visual_pos_masks.ndim > 1 and visual_pos_masks.shape[1] > hidden_states.shape[0] // max( - visual_pos_masks.shape[0], 1 - ): - # visual_pos_masks: [B, S], hidden_states: [B*S/tp, H] - try: - from paddle.distributed.fleet import get_hybrid_communicate_group + if visual_update_indices is not None: + # Fast path: use precomputed indices from GPTEmbedding. + # visual_gather_indices is not None iff SP is active. + if visual_gather_indices is not None: + if visual_gather_indices.shape[0] > 0: + visual_embeds = visual_embeds[visual_gather_indices] + else: + visual_embeds = visual_embeds[:0] # empty — this rank has no visual tokens + + update_indices = visual_update_indices + else: + # Fallback: compute indices on-the-fly (e.g. when called without precomputation). + # Sequence Parallelism (SP) row slicing. + # visual_pos_masks is [B, S] (full sequence), hidden_states is [B*S/tp, H] + # (batch-major after transpose+flatten). We must slice along the S dimension + # (dim=1) to match the batch-major layout, NOT flatten-then-chunk which + # breaks when B > 1. + if visual_pos_masks.ndim > 1 and visual_pos_masks.shape[1] > hidden_states.shape[0] // max( + visual_pos_masks.shape[0], 1 + ): + # visual_pos_masks: [B, S], hidden_states: [B*S/tp, H] + try: + from paddle.distributed.fleet import get_hybrid_communicate_group + + hcg = get_hybrid_communicate_group() + mp_rank = hcg.get_model_parallel_rank() + mp_size = hcg.get_model_parallel_world_size() + except (ImportError, AttributeError): + batch_size = visual_pos_masks.shape[0] + full_seq_len = visual_pos_masks.shape[1] + mp_size = (batch_size * full_seq_len) // hidden_states.shape[0] + mp_rank = paddle.distributed.get_rank() % mp_size - hcg = get_hybrid_communicate_group() - mp_rank = hcg.get_model_parallel_rank() - mp_size = hcg.get_model_parallel_world_size() - except (ImportError, AttributeError): - batch_size = visual_pos_masks.shape[0] full_seq_len = visual_pos_masks.shape[1] - mp_size = (batch_size * full_seq_len) // hidden_states.shape[0] - mp_rank = paddle.distributed.get_rank() % mp_size - - full_seq_len = visual_pos_masks.shape[1] - chunk_s = full_seq_len // mp_size - start_s = mp_rank * chunk_s - - # Slice along S dimension: [B, S] -> [B, S/tp] - local_mask = visual_pos_masks[:, start_s : start_s + chunk_s] - batch_size = visual_pos_masks.shape[0] - - # Gather per-sample visual_embeds. - # visual_embeds is ordered as [sample0_all_vis, sample1_all_vis, ...]. - # Each rank only needs the visual tokens that fall within its local - # sequence chunk [start_s, start_s+chunk_s) for each sample. - per_sample_total = paddle.cast(visual_pos_masks, "int32").sum(axis=1) # [B] - per_sample_pre = ( - paddle.cast(visual_pos_masks[:, :start_s], "int32").sum(axis=1) - if start_s > 0 - else paddle.zeros([batch_size], dtype="int32") - ) # [B] - per_sample_local = paddle.cast(local_mask, "int32").sum(axis=1) # [B] - - gather_indices = [] - cumulative_total = 0 - for i in range(batch_size): - total_i = int(per_sample_total[i].item()) - pre_i = int(per_sample_pre[i].item()) - count_i = int(per_sample_local[i].item()) - if count_i > 0: - gather_indices.append(paddle.arange(cumulative_total + pre_i, cumulative_total + pre_i + count_i)) - cumulative_total += total_i - - if gather_indices: - gather_indices = paddle.concat(gather_indices) - visual_embeds = visual_embeds[gather_indices] - else: - visual_embeds = visual_embeds[:0] # empty + chunk_s = full_seq_len // mp_size + start_s = mp_rank * chunk_s + + # Slice along S dimension: [B, S] -> [B, S/tp] + local_mask = visual_pos_masks[:, start_s : start_s + chunk_s] + batch_size = visual_pos_masks.shape[0] + + # Gather per-sample visual_embeds. + per_sample_total = paddle.cast(visual_pos_masks, "int32").sum(axis=1) # [B] + per_sample_pre = ( + paddle.cast(visual_pos_masks[:, :start_s], "int32").sum(axis=1) + if start_s > 0 + else paddle.zeros([batch_size], dtype="int32") + ) # [B] + per_sample_local = paddle.cast(local_mask, "int32").sum(axis=1) # [B] + + gather_indices = [] + cumulative_total = 0 + for i in range(batch_size): + total_i = int(per_sample_total[i].item()) + pre_i = int(per_sample_pre[i].item()) + count_i = int(per_sample_local[i].item()) + if count_i > 0: + gather_indices.append( + paddle.arange(cumulative_total + pre_i, cumulative_total + pre_i + count_i) + ) + cumulative_total += total_i + + if gather_indices: + gather_indices = paddle.concat(gather_indices) + visual_embeds = visual_embeds[gather_indices] + else: + visual_embeds = visual_embeds[:0] # empty + + # Flatten local mask to [B*S/tp] matching hidden_states batch-major layout + visual_pos_masks = local_mask.flatten() + elif visual_pos_masks.ndim > 1: + visual_pos_masks = visual_pos_masks.flatten() - # Flatten local mask to [B*S/tp] matching hidden_states batch-major layout - visual_pos_masks = local_mask.flatten() - elif visual_pos_masks.ndim > 1: - visual_pos_masks = visual_pos_masks.flatten() + update_indices = paddle.nonzero(visual_pos_masks) # If TP is enabled, hidden_states has shape [..., Hidden_Dim / TP_Size], # but visual_embeds usually has full [Hidden_Dim]. We need to slice visual_embeds column-wise. @@ -341,16 +363,13 @@ def _deepstack_process( visual_embeds = visual_embeds[:, start_col:end_col] hidden_states = hidden_states.clone() - update_indices = paddle.nonzero(visual_pos_masks) - # Under SP, visual tokens are unevenly distributed across ranks. After row-slicing - # visual_pos_masks and visual_embeds to the local sequence chunk, some ranks may - # have zero visual tokens (local_visual_count == 0), producing visual_embeds with - # shape [0, H]. Guard against passing an empty updates tensor to scatter_nd_add, - # whose behavior is undefined / backend-dependent in that case. + # Under SP, visual tokens are unevenly distributed across ranks. Some ranks may + # have zero visual tokens, producing visual_embeds with shape [0, H]. + # Guard against passing an empty updates tensor to scatter_nd_add. if visual_embeds.shape[0] > 0: hidden_states = paddle.scatter_nd_add(hidden_states, update_indices, visual_embeds) - # [Supplement 3] Restore original shape [B*S, D] -> [B, S, D] if necessary + # Restore original shape [B*S, D] -> [B, S, D] if necessary if len(original_shape) > 2: hidden_states = hidden_states.reshape(original_shape) if _sp_transposed: