
[VL] precompute visual scatter indices in GPTEmbedding #708

Open
huangjiyi wants to merge 1 commit into PaddlePaddle:develop from huangjiyi:precompute-visual-scatter-indices

Conversation

@huangjiyi
Member

Summary

  • Precompute visual_update_indices and visual_gather_indices once in GPTEmbedding, instead of redundantly computing them in every deepstack layer.
  • Supports both sequence-parallel (SP) and non-SP paths, with proper per-rank slicing and gather index construction.
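The precompute-once idea can be sketched as follows (a NumPy illustration of the non-SP path; the helper name and the returned dict standing in for `preproc_output` are hypothetical, only the mask layout is taken from the PR description):

```python
import numpy as np

def precompute_visual_indices(visual_pos_masks: np.ndarray) -> dict:
    """Compute layer-independent visual token indices once (non-SP path).

    visual_pos_masks: bool array of shape [batch, seq] marking visual tokens.
    Returns flat indices into the flattened [batch * seq] token stream.
    """
    flat = visual_pos_masks.reshape(-1)
    # Positions of visual tokens in the flattened sequence
    visual_update_indices = np.flatnonzero(flat)
    return {"visual_update_indices": visual_update_indices}

# Each deepstack layer then reuses the cached indices instead of
# recomputing them from the mask on every layer.
masks = np.array([[0, 1, 1, 0], [1, 0, 0, 1]], dtype=bool)
out = precompute_visual_indices(masks)
```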

Test plan

  • Verify VL pretrain/SFT training produces identical results
  • Confirm no regression in training loss/performance

Contributor

Copilot AI left a comment


Pull request overview

This PR moves the computation of the deepstack-related visual token indices in the VL path (visual_update_indices / visual_gather_indices) out of the per-layer hot path: instead of being redundantly recomputed in every deepstack layer, they are precomputed once in GPTEmbedding.forward(), reducing overhead while supporting both the SP and non-SP paths.

Changes:

  • Added index precomputation for the visual position mask in GPTEmbedding.forward(), including SP per-rank slicing and gather index construction.
  • The precomputed visual_update_indices / visual_gather_indices are written into preproc_output for reuse by the downstream deepstack flow.

Comment on lines +379 to +410
per_sample_total = paddle.cast(visual_pos_masks, "int32").sum(
    axis=1
)
per_sample_pre = (
    paddle.cast(visual_pos_masks[:, :start_s], "int32").sum(
        axis=1
    )
    if start_s > 0
    else paddle.zeros([batch_size], dtype="int32")
)

gather_indices_list = []
cumulative_total = 0
for i in range(batch_size):
    total_i = int(per_sample_total[i].item())
    pre_i = int(per_sample_pre[i].item())
    count_i = int(
        paddle.cast(local_mask[i], "int32").sum().item()
    )
    if count_i > 0:
        gather_indices_list.append(
            paddle.arange(
                cumulative_total + pre_i,
                cumulative_total + pre_i + count_i,
            )
        )
    cumulative_total += total_i

if gather_indices_list:
    visual_gather_indices = paddle.concat(gather_indices_list)
else:
    visual_gather_indices = paddle.zeros([0], dtype="int64")

Copilot AI Apr 1, 2026


The per-sample for-loop here calls .item() repeatedly, pulling tensors back to Python scalars; this forces CPU/GPU synchronization, breaks static-graph/export paths, and adds noticeable per-step overhead. Consider constructing the gather indices with pure tensor ops (e.g. a vectorized cumsum/offset computation over local_mask) to avoid the Python loop and .item() calls.

Suggested change
visual_mask_int = paddle.cast(visual_pos_masks, "int32")
per_sample_total = visual_mask_int.sum(axis=1)
# Per-sample start offset in the globally flattened visual sequence:
# sample_offsets[i] = sum_{j < i} per_sample_total[j]
sample_offsets = paddle.concat(
    [
        paddle.zeros([1], dtype="int32"),
        paddle.cumsum(per_sample_total[:-1]),
    ],
    axis=0,
)
# prefix[b, s] = number of visual tokens in sample b up to and including
# position s, minus 1 (i.e. a 0-based index)
prefix = paddle.cumsum(visual_mask_int, axis=1) - 1
global_idx = prefix + sample_offsets.unsqueeze(1)
# Keep only the sequence slice owned by this rank, then select the
# global indices at positions where local_mask is set
local_global_idx = global_idx[:, start_s : start_s + chunk_s]
visual_gather_indices = paddle.masked_select(
    local_global_idx,
    paddle.cast(local_mask, "bool"),
).astype("int64")

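As a sanity check of the vectorization idea, here is a NumPy sketch (mirroring the Paddle snippets with NumPy ops; shapes and names follow the quoted code, the function names are hypothetical) showing that the cumsum/offset construction matches the loop-based reference:

```python
import numpy as np

def gather_indices_loop(visual_pos_masks, start_s, chunk_s):
    """Reference: per-sample Python loop, mirroring the original code."""
    batch_size = visual_pos_masks.shape[0]
    local_mask = visual_pos_masks[:, start_s : start_s + chunk_s]
    per_sample_total = visual_pos_masks.sum(axis=1)
    per_sample_pre = visual_pos_masks[:, :start_s].sum(axis=1)
    out, cumulative_total = [], 0
    for i in range(batch_size):
        count_i = int(local_mask[i].sum())
        if count_i > 0:
            base = cumulative_total + int(per_sample_pre[i])
            out.extend(range(base, base + count_i))
        cumulative_total += int(per_sample_total[i])
    return np.array(out, dtype=np.int64)

def gather_indices_vectorized(visual_pos_masks, start_s, chunk_s):
    """cumsum/offset construction: no Python loop, no host sync."""
    m = visual_pos_masks.astype(np.int64)
    per_sample_total = m.sum(axis=1)
    # Start offset of each sample in the globally flattened visual sequence
    sample_offsets = np.concatenate([[0], np.cumsum(per_sample_total[:-1])])
    # 0-based rank of each visual token within its sample
    prefix = np.cumsum(m, axis=1) - 1
    global_idx = prefix + sample_offsets[:, None]
    local = global_idx[:, start_s : start_s + chunk_s]
    local_mask = visual_pos_masks[:, start_s : start_s + chunk_s]
    return local[local_mask].astype(np.int64)

masks = np.array([[0, 1, 1, 0, 1], [1, 1, 0, 0, 1]], dtype=bool)
loop_idx = gather_indices_loop(masks, start_s=1, chunk_s=2)
vec_idx = gather_indices_vectorized(masks, start_s=1, chunk_s=2)
```

Both constructions agree because, within a contiguous per-rank slice, the visual tokens of a sample occupy consecutive positions in the flattened visual sequence.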
Comment on lines +412 to +416
    visual_update_indices = paddle.nonzero(local_mask.flatten())
else:
    visual_update_indices = paddle.nonzero(
        visual_pos_masks.flatten()
    )

Copilot AI Apr 1, 2026


paddle.nonzero(local_mask.flatten()) / paddle.nonzero(visual_pos_masks.flatten()) returns an index tensor of shape [N, 1] for 1D input; without a squeeze(-1), using it downstream as a 1D index (e.g. in scatter/gather) will easily trigger shape mismatches. Consider an explicit squeeze(-1) and documenting the expected index shape in a comment or the variable name.

Suggested change
    # nonzero on 1D input returns indices of shape [N, 1]; squeeze to get a 1D [N] index
    visual_update_indices = paddle.nonzero(
        local_mask.flatten()
    ).squeeze(-1)
else:
    # nonzero on 1D input returns indices of shape [N, 1]; squeeze to get a 1D [N] index
    visual_update_indices = paddle.nonzero(
        visual_pos_masks.flatten()
    ).squeeze(-1)

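The shape pitfall is easy to demonstrate with NumPy, whose np.argwhere behaves like paddle.nonzero on a 1D input (a sketch for illustration only):

```python
import numpy as np

flat_mask = np.array([0, 1, 0, 1, 1], dtype=bool)

# Like paddle.nonzero on a 1D tensor, np.argwhere returns one row per
# hit, i.e. shape [N, 1] -- here [[1], [3], [4]]
idx_2d = np.argwhere(flat_mask)

# Squeezing the trailing axis recovers the 1D [N] index that
# scatter/gather-style ops expect
idx_1d = idx_2d.squeeze(-1)
```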
Comment on lines +373 to +374
chunk_s = full_seq_len // mp_size
start_s = mp_rank * chunk_s

Copilot AI Apr 1, 2026


In the SP branch, chunk_s = full_seq_len // mp_size silently drops the trailing tokens whenever full_seq_len is not divisible by mp_size, so local_mask / indices diverge from the actual SP partition. Either assert divisibility explicitly here, or implement a split that assigns the remainder (e.g. extra tokens on the leading ranks) to cover the non-divisible case.

Suggested change
# Support non-divisible sequence length by distributing the remainder
base_chunk = full_seq_len // mp_size
remainder = full_seq_len % mp_size
if mp_rank < remainder:
    # Ranks before `remainder` get one extra token
    chunk_s = base_chunk + 1
    start_s = mp_rank * (base_chunk + 1)
else:
    # Remaining ranks get the base chunk size
    chunk_s = base_chunk
    start_s = remainder * (base_chunk + 1) + (mp_rank - remainder) * base_chunk

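The remainder-aware split can be checked in isolation (a pure-Python sketch of the suggested logic; the helper name is hypothetical):

```python
def sp_chunk(full_seq_len: int, mp_size: int, mp_rank: int):
    """Return (start_s, chunk_s) for this rank; the first
    `full_seq_len % mp_size` ranks each take one extra token."""
    base_chunk, remainder = divmod(full_seq_len, mp_size)
    if mp_rank < remainder:
        chunk_s = base_chunk + 1
        start_s = mp_rank * chunk_s
    else:
        chunk_s = base_chunk
        start_s = remainder * (base_chunk + 1) + (mp_rank - remainder) * base_chunk
    return start_s, chunk_s

# seq_len 10 over 4 ranks: chunks of 3, 3, 2, 2 cover every token exactly once
chunks = [sp_chunk(10, 4, r) for r in range(4)]
```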
Comment on lines +363 to +371
batch_size = visual_pos_masks.shape[0]
full_seq_len = visual_pos_masks.shape[1]
# decoder_input is already [S/tp, B, H] after SP scatter
local_seq = decoder_input.shape[0]
mp_size = (batch_size * full_seq_len) // (
    local_seq * batch_size
)
mp_rank = paddle.distributed.get_rank() % mp_size


Copilot AI Apr 1, 2026


In the except fallback branch, inferring the model-parallel rank with mp_rank = dist.get_rank() % mp_size is generally incorrect once DP/PP or other parallel topologies are present (the global rank modulo is not the MP rank), which leads to wrong local slices / gather indices. Obtain rank/world_size directly from the known TP/MP group (e.g. self.embedding.tp_group.rank/nranks, or paddlefleet.parallel_state's get_tensor_model_parallel_rank/world_size) instead of inferring them from the global rank.

Suggested change
batch_size = visual_pos_masks.shape[0]
full_seq_len = visual_pos_masks.shape[1]
# decoder_input is already [S/tp, B, H] after SP scatter
local_seq = decoder_input.shape[0]
# Preferred: obtain the tensor-model-parallel rank/world_size from the
# parallel state rather than inferring them from the global rank
try:
    from paddlefleet import parallel_state

    mp_rank = parallel_state.get_tensor_model_parallel_rank()
    mp_size = (
        parallel_state.get_tensor_model_parallel_world_size()
    )
except (ImportError, AttributeError):
    # Final fallback: infer mp_size from shapes and assume rank 0
    mp_size = (batch_size * full_seq_len) // (
        local_seq * batch_size
    )
    mp_rank = 0

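Why the modulo heuristic is layout-dependent can be shown with a tiny pure-Python sketch (the layouts and names here are illustrative assumptions, not the framework's actual rank mapping):

```python
def tp_rank(global_rank: int, tp_size: int, dp_size: int,
            tp_innermost: bool) -> int:
    """True tensor-parallel rank under two common rank layouts."""
    if tp_innermost:
        # rank = dp_idx * tp_size + tp_idx -> modulo happens to work
        return global_rank % tp_size
    # rank = tp_idx * dp_size + dp_idx -> modulo is wrong
    return global_rank // dp_size

# Outer-TP layout with tp_size=2, dp_size=4: global rank 4 has tp_idx=1,
# but `4 % tp_size == 0` would misreport it as TP rank 0.
correct = tp_rank(4, tp_size=2, dp_size=4, tp_innermost=False)
naive = 4 % 2
```

This is why querying the TP group directly, as the comment suggests, is the robust choice.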
Comment on lines +347 to +352
# Precompute visual scatter indices for _deepstack_process.
# These indices are layer-independent and would otherwise be
# redundantly computed in every deepstack layer.
visual_update_indices = None
visual_gather_indices = None
if visual_pos_masks is not None:

Copilot AI Apr 1, 2026


The new visual_update_indices/visual_gather_indices logic (especially the SP branch's per-rank slicing and gather index construction) currently has no unit-test coverage. Consider adding assertions to the existing tests/single_card_tests/transformer/test_qwen_vl_sp.py: in the non-SP case the indices align with the mask; in the SP case, under mocked mp_rank/mp_size, the expected local indices/gather indices are produced; and cover the empty-visual edge case.
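A minimal sketch of what such tests could assert (NumPy stand-ins for the Paddle ops; test names and the direct use of flatnonzero are illustrative assumptions, not the repository's actual test code):

```python
import numpy as np

def test_non_sp_indices_align_with_mask():
    # Non-SP: update indices are exactly the flat positions where the mask is set
    mask = np.array([[0, 1, 0], [1, 0, 1]], dtype=bool)
    indices = np.flatnonzero(mask.reshape(-1))
    assert indices.tolist() == [1, 3, 5]

def test_empty_visual_mask():
    # Edge case: no visual tokens yields an empty index array
    mask = np.zeros((2, 3), dtype=bool)
    assert np.flatnonzero(mask.reshape(-1)).size == 0

test_non_sp_indices_align_with_mask()
test_empty_visual_mask()
```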

Copilot generated this review using guidance from repository custom instructions.
