[VL] precompute visual scatter indices in GPTEmbedding #708
huangjiyi wants to merge 1 commit into PaddlePaddle:develop
Conversation
Pull request overview
This PR moves the computation of the deepstack-related visual token indices (visual_update_indices / visual_gather_indices) in the VL scenario from per-layer recomputation to a one-time precomputation in GPTEmbedding.forward(), reducing redundant work in the deepstack layers. Both the SP and non-SP paths are supported.
Changes:
- Add index precomputation logic for the visual position mask in GPTEmbedding.forward() (including the SP per-rank slice and gather index construction).
- Write the precomputed visual_update_indices / visual_gather_indices into preproc_output for reuse by the subsequent deepstack flow.
```python
per_sample_total = paddle.cast(visual_pos_masks, "int32").sum(
    axis=1
)
per_sample_pre = (
    paddle.cast(visual_pos_masks[:, :start_s], "int32").sum(
        axis=1
    )
    if start_s > 0
    else paddle.zeros([batch_size], dtype="int32")
)

gather_indices_list = []
cumulative_total = 0
for i in range(batch_size):
    total_i = int(per_sample_total[i].item())
    pre_i = int(per_sample_pre[i].item())
    count_i = int(
        paddle.cast(local_mask[i], "int32").sum().item()
    )
    if count_i > 0:
        gather_indices_list.append(
            paddle.arange(
                cumulative_total + pre_i,
                cumulative_total + pre_i + count_i,
            )
        )
    cumulative_total += total_i

if gather_indices_list:
    visual_gather_indices = paddle.concat(gather_indices_list)
else:
    visual_gather_indices = paddle.zeros([0], dtype="int64")
```
The per-sample for-loop here calls .item() repeatedly, pulling Tensors back to Python scalars. This forces CPU/GPU synchronization, breaks static-graph/export modes, and adds noticeable per-step overhead. Consider constructing the gather indices with pure Tensor operations (e.g. a vectorized computation based on cumsum/offsets over local_mask) to avoid the Python loop and the .item() calls.
Suggested change:
```python
visual_mask_int = paddle.cast(visual_pos_masks, "int32")
per_sample_total = visual_mask_int.sum(axis=1)
# Per-sample start offset in the globally flattened visual sequence:
# sample_offsets[i] = sum_{j < i} per_sample_total[j]
sample_offsets = paddle.concat(
    [
        paddle.zeros([1], dtype="int32"),
        paddle.cumsum(per_sample_total[:-1]),
    ],
    axis=0,
)
# prefix[b, s] = number of visual tokens in sample b up to and including
# position s, minus 1 (yielding a 0-based index)
prefix = paddle.cumsum(visual_mask_int, axis=1) - 1
global_idx = prefix + sample_offsets.unsqueeze(1)
# Keep only the sequence chunk owned by the current rank, and select the
# global indices at positions where local_mask is 1
local_global_idx = global_idx[:, start_s : start_s + chunk_s]
visual_gather_indices = paddle.masked_select(
    local_global_idx,
    paddle.cast(local_mask, "bool"),
).astype("int64")
```
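As a sanity check, the loop-based and cumsum-based constructions can be compared on a tiny NumPy example. This is a minimal sketch under assumed shapes (batch of 2, sequence of 8, one rank owning a 4-token chunk), not the actual Paddle code:

```python
import numpy as np

# Hypothetical mask: 2 samples, 8 positions; this rank owns [start_s, start_s + chunk_s)
visual_pos_masks = np.array(
    [[0, 1, 1, 0, 0, 1, 0, 0],
     [1, 0, 0, 1, 1, 0, 0, 1]], dtype=np.int32
)
start_s, chunk_s = 2, 4
local_mask = visual_pos_masks[:, start_s : start_s + chunk_s]

# Loop-based reference (mirrors the original per-sample construction)
per_sample_total = visual_pos_masks.sum(axis=1)
per_sample_pre = visual_pos_masks[:, :start_s].sum(axis=1)
ref, cumulative_total = [], 0
for i in range(visual_pos_masks.shape[0]):
    count_i = int(local_mask[i].sum())
    pre_i = int(per_sample_pre[i])
    ref.extend(range(cumulative_total + pre_i,
                     cumulative_total + pre_i + count_i))
    cumulative_total += int(per_sample_total[i])

# Vectorized version (mirrors the suggested cumsum/offsets approach)
sample_offsets = np.concatenate([[0], np.cumsum(per_sample_total[:-1])])
prefix = np.cumsum(visual_pos_masks, axis=1) - 1
global_idx = prefix + sample_offsets[:, None]
local_global_idx = global_idx[:, start_s : start_s + chunk_s]
vec = local_global_idx[local_mask.astype(bool)]

print(ref, vec.tolist())  # both [1, 2, 4, 5]
```

Both constructions agree because the visual tokens inside a rank's chunk occupy a contiguous run of the per-sample visual ordering, so a range starting at `offset + pre_i` and the masked-select of the prefix-sum indices pick out the same positions.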
```python
    visual_update_indices = paddle.nonzero(local_mask.flatten())
else:
    visual_update_indices = paddle.nonzero(
        visual_pos_masks.flatten()
    )
```
paddle.nonzero(local_mask.flatten()) / paddle.nonzero(visual_pos_masks.flatten()) return an index tensor of shape [N, 1] for 1D input; without a squeeze(-1), downstream use as a 1D index (e.g. in scatter/gather) can easily hit shape mismatches. Consider an explicit squeeze(-1), and state the expected index shape in a comment or the variable name.
Suggested change:
```python
    # nonzero on 1D input returns indices of shape [N, 1]; squeeze to get 1D [N] index
    visual_update_indices = paddle.nonzero(
        local_mask.flatten()
    ).squeeze(-1)
else:
    # nonzero on 1D input returns indices of shape [N, 1]; squeeze to get 1D [N] index
    visual_update_indices = paddle.nonzero(
        visual_pos_masks.flatten()
    ).squeeze(-1)
```
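The shape issue can be seen in plain NumPy, where np.argwhere mirrors paddle.nonzero's behavior of returning one row per nonzero element (a minimal sketch on a made-up mask):

```python
import numpy as np

# Hypothetical 1D mask with 3 set positions
mask = np.array([0, 1, 0, 1, 1], dtype=np.int32)

idx_2d = np.argwhere(mask)    # shape (3, 1): one row per nonzero element
idx_1d = idx_2d.squeeze(-1)   # shape (3,): usable as a plain 1D index

print(idx_2d.shape, idx_1d.tolist())  # (3, 1) [1, 3, 4]
```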
```python
chunk_s = full_seq_len // mp_size
start_s = mp_rank * chunk_s
```
In the SP branch, chunk_s = full_seq_len // mp_size silently drops trailing tokens when full_seq_len is not divisible by mp_size, so local_mask / indices no longer match the actual SP partition. Consider explicitly asserting divisibility here, or implementing a split that distributes the remainder to cover the non-divisible case.
Suggested change:
```python
# Support non-divisible sequence length by distributing the remainder
base_chunk = full_seq_len // mp_size
remainder = full_seq_len % mp_size
if mp_rank < remainder:
    # Ranks before `remainder` get one extra token
    chunk_s = base_chunk + 1
    start_s = mp_rank * (base_chunk + 1)
else:
    # Remaining ranks get the base chunk size
    chunk_s = base_chunk
    start_s = remainder * (base_chunk + 1) + (mp_rank - remainder) * base_chunk
```
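The remainder-aware split can be checked with a few lines of plain Python: for a deliberately non-divisible full_seq_len, the per-rank chunks should tile the sequence with no gaps and no overlap (the helper name below is illustrative, not part of the PR):

```python
# Minimal sketch: remainder-aware per-rank chunking of a sequence.
def chunk_for_rank(full_seq_len, mp_size, mp_rank):
    base_chunk = full_seq_len // mp_size
    remainder = full_seq_len % mp_size
    if mp_rank < remainder:
        # Ranks before `remainder` get one extra token
        return mp_rank * (base_chunk + 1), base_chunk + 1
    start = remainder * (base_chunk + 1) + (mp_rank - remainder) * base_chunk
    return start, base_chunk

full_seq_len, mp_size = 10, 4  # non-divisible on purpose
covered = []
for rank in range(mp_size):
    start_s, chunk_s = chunk_for_rank(full_seq_len, mp_size, rank)
    covered.extend(range(start_s, start_s + chunk_s))

print(covered)  # [0, 1, ..., 9]: every position exactly once
```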
```python
batch_size = visual_pos_masks.shape[0]
full_seq_len = visual_pos_masks.shape[1]
# decoder_input is already [S/tp, B, H] after SP scatter
local_seq = decoder_input.shape[0]
mp_size = (batch_size * full_seq_len) // (
    local_seq * batch_size
)
mp_rank = paddle.distributed.get_rank() % mp_size
```
In the except fallback branch, inferring the model-parallel rank via mp_rank = dist.get_rank() % mp_size is generally incorrect when DP/PP/other topologies are present (the global rank modulo mp_size is not the MP rank), leading to wrong local slice / gather index computation. Consider obtaining rank/world_size directly from the known TP/MP group (e.g. self.embedding.tp_group.rank/nranks, or paddlefleet.parallel_state's get_tensor_model_parallel_rank/world_size) instead of inferring from the global rank.
Suggested change:
```python
# Fallback: try to get tensor model parallel rank/world_size
try:
    from paddlefleet import parallel_state
    mp_rank = parallel_state.get_tensor_model_parallel_rank()
    mp_size = (
        parallel_state.get_tensor_model_parallel_world_size()
    )
except (ImportError, AttributeError):
    # Final fallback: infer mp_size from shapes and assume rank 0
    batch_size = visual_pos_masks.shape[0]
    full_seq_len = visual_pos_masks.shape[1]
    # decoder_input is already [S/tp, B, H] after SP scatter
    local_seq = decoder_input.shape[0]
    mp_size = (batch_size * full_seq_len) // (
        local_seq * batch_size
    )
    mp_rank = 0
```
```python
# Precompute visual scatter indices for _deepstack_process.
# These indices are layer-independent and would otherwise be
# redundantly computed in every deepstack layer.
visual_update_indices = None
visual_gather_indices = None
if visual_pos_masks is not None:
```
The new visual_update_indices/visual_gather_indices logic (especially the SP branch's per-rank slicing and gather index construction) currently has no unit test coverage. Consider adding assertions in the existing tests/single_card_tests/transformer/test_qwen_vl_sp.py: in the non-SP case, the indices should align with the mask; in the SP case, the expected local indices / gather indices should be produced under a mocked mp_rank/mp_size; and the empty-visual edge case should be covered.
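A hedged sketch of the kind of assertions being requested, written against a NumPy reference rather than the actual GPTEmbedding API (the helper and test names here are hypothetical, not from the PR):

```python
import numpy as np

def reference_update_indices(mask_2d):
    """Expected 1D update indices: flattened positions where the mask is set."""
    return np.nonzero(np.asarray(mask_2d).flatten())[0]

def test_non_sp_indices_align_with_mask():
    mask = np.array([[0, 1, 0], [1, 0, 1]], dtype=np.int32)
    idx = reference_update_indices(mask)
    # Every selected flattened position must actually be a visual position,
    # and the count must match the mask
    assert all(mask.flatten()[i] == 1 for i in idx)
    assert len(idx) == int(mask.sum())

def test_empty_visual_edge_case():
    mask = np.zeros((2, 3), dtype=np.int32)
    assert reference_update_indices(mask).size == 0

test_non_sp_indices_align_with_mask()
test_empty_visual_edge_case()
print("ok")
```

The SP-case assertion would additionally mock mp_rank/mp_size, slice the mask accordingly, and compare the produced local/gather indices against the loop-based reference.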
Summary
Precompute visual_update_indices and visual_gather_indices once in GPTEmbedding, instead of redundantly computing them in every deepstack layer.
Test plan