73 changes: 73 additions & 0 deletions src/paddlefleet/models/gpt/gpt_embedding.py
@@ -344,6 +344,77 @@ def forward(
[1, 0, 2, 3]
).contiguous()

# Precompute visual scatter indices for _deepstack_process.
# These indices are layer-independent and would otherwise be
# redundantly computed in every deepstack layer.
visual_update_indices = None
visual_gather_indices = None
if visual_pos_masks is not None:
Comment on lines +347 to +352 (Copilot AI, Apr 1, 2026)

The new visual_update_indices / visual_gather_indices logic (in particular the per-rank slicing and gather-index construction in the SP branch) has no unit-test coverage yet. Consider adding assertions to the existing tests/single_card_tests/transformer/test_qwen_vl_sp.py: in the non-SP case, check that the indices line up with the mask; in the SP case, check that the expected local indices / gather indices are produced under mocked mp_rank/mp_size; and cover the empty-visual edge case.

Copilot generated this review using guidance from repository custom instructions.
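A minimal sketch of the non-SP assertion the comment asks for, with NumPy standing in for Paddle tensors (the helper name and shapes are illustrative, not from the PR):

```python
import numpy as np

def expected_update_indices(visual_pos_masks):
    """Reference oracle: flattened positions where the visual mask is set.

    Mirrors paddle.nonzero(mask.flatten()).squeeze(-1) in pure NumPy.
    """
    return np.flatnonzero(visual_pos_masks.reshape(-1))

# Non-SP case: indices must align with the mask itself.
mask = np.array([[0, 1, 1, 0],
                 [1, 0, 0, 1]], dtype=bool)
idx = expected_update_indices(mask)
assert np.array_equal(idx, np.array([1, 2, 4, 7]))

# Empty-visual edge case: no indices at all.
empty = np.zeros((2, 4), dtype=bool)
assert expected_update_indices(empty).size == 0
```

A real test would compare this oracle against the indices the forward pass actually produces.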
if self.sequence_parallel:
try:
from paddle.distributed.fleet import (
get_hybrid_communicate_group,
)

hcg = get_hybrid_communicate_group()
mp_rank = hcg.get_model_parallel_rank()
mp_size = hcg.get_model_parallel_world_size()
except (ImportError, AttributeError):
batch_size = visual_pos_masks.shape[0]
full_seq_len = visual_pos_masks.shape[1]
# decoder_input is already [S/tp, B, H] after SP scatter
local_seq = decoder_input.shape[0]
mp_size = (batch_size * full_seq_len) // (
local_seq * batch_size
)
mp_rank = paddle.distributed.get_rank() % mp_size

Comment on lines +363 to +371 (Copilot AI, Apr 1, 2026)

In the except fallback branch, inferring the model-parallel rank with mp_rank = dist.get_rank() % mp_size is generally incorrect once DP/PP/other topologies are present (the global rank modulo mp_size is not the MP rank), so the local slice / gather indices come out wrong. Consider reading rank/world_size directly from the known TP/MP group (e.g. self.embedding.tp_group.rank/nranks, or paddlefleet.parallel_state's get_tensor_model_parallel_rank/world_size) instead of deriving them from the global rank.

Suggested change
batch_size = visual_pos_masks.shape[0]
full_seq_len = visual_pos_masks.shape[1]
# decoder_input is already [S/tp, B, H] after SP scatter
local_seq = decoder_input.shape[0]
mp_size = (batch_size * full_seq_len) // (
local_seq * batch_size
)
mp_rank = paddle.distributed.get_rank() % mp_size
# Fallback: try to get tensor model parallel rank/world_size
try:
from paddlefleet import parallel_state
mp_rank = parallel_state.get_tensor_model_parallel_rank()
mp_size = (
parallel_state.get_tensor_model_parallel_world_size()
)
except (ImportError, AttributeError):
# Final fallback: infer mp_size from shapes and assume rank 0
batch_size = visual_pos_masks.shape[0]
full_seq_len = visual_pos_masks.shape[1]
# decoder_input is already [S/tp, B, H] after SP scatter
local_seq = decoder_input.shape[0]
mp_size = (batch_size * full_seq_len) // (
local_seq * batch_size
)
mp_rank = 0

full_seq_len = visual_pos_masks.shape[1]
chunk_s = full_seq_len // mp_size
start_s = mp_rank * chunk_s
Comment on lines +373 to +374 (Copilot AI, Apr 1, 2026)

In the SP branch, chunk_s = full_seq_len // mp_size silently drops trailing tokens whenever full_seq_len is not divisible by mp_size, so local_mask and the derived indices diverge from the real SP split. Consider asserting divisibility explicitly here, or implementing a split that carries the remainder so the non-divisible case is covered.

Suggested change
chunk_s = full_seq_len // mp_size
start_s = mp_rank * chunk_s
# Support non-divisible sequence length by distributing the remainder
base_chunk = full_seq_len // mp_size
remainder = full_seq_len % mp_size
if mp_rank < remainder:
# Ranks before `remainder` get one extra token
chunk_s = base_chunk + 1
start_s = mp_rank * (base_chunk + 1)
else:
# Remaining ranks get the base_chunk size
chunk_s = base_chunk
start_s = remainder * (base_chunk + 1) + (mp_rank - remainder) * base_chunk

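The remainder-distributing split in the suggestion can be checked in isolation. This pure-Python sketch (hypothetical helper name) verifies that the per-rank chunks tile the full sequence contiguously with no gaps:

```python
def sp_chunk(full_seq_len, mp_size, mp_rank):
    """Return (start, size) of this rank's slice, remainder spread over early ranks."""
    base, rem = divmod(full_seq_len, mp_size)
    if mp_rank < rem:
        # Ranks below `rem` get one extra token, as in the suggestion above.
        return mp_rank * (base + 1), base + 1
    return rem * (base + 1) + (mp_rank - rem) * base, base

# 10 tokens over 4 ranks: sizes 3,3,2,2, covering positions 0..9 exactly once.
chunks = [sp_chunk(10, 4, r) for r in range(4)]
assert [size for _, size in chunks] == [3, 3, 2, 2]
covered = sorted(i for start, size in chunks for i in range(start, start + size))
assert covered == list(range(10))
```

The divisible case degenerates to the original chunk_s = full_seq_len // mp_size behavior.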

local_mask = visual_pos_masks[:, start_s : start_s + chunk_s]
batch_size = visual_pos_masks.shape[0]

per_sample_total = paddle.cast(visual_pos_masks, "int32").sum(
axis=1
)
per_sample_pre = (
paddle.cast(visual_pos_masks[:, :start_s], "int32").sum(
axis=1
)
if start_s > 0
else paddle.zeros([batch_size], dtype="int32")
)

gather_indices_list = []
cumulative_total = 0
for i in range(batch_size):
total_i = int(per_sample_total[i].item())
pre_i = int(per_sample_pre[i].item())
count_i = int(
paddle.cast(local_mask[i], "int32").sum().item()
)
if count_i > 0:
gather_indices_list.append(
paddle.arange(
cumulative_total + pre_i,
cumulative_total + pre_i + count_i,
)
)
cumulative_total += total_i

if gather_indices_list:
visual_gather_indices = paddle.concat(gather_indices_list)
else:
visual_gather_indices = paddle.zeros([0], dtype="int64")
Comment on lines +379 to +410 (Copilot AI, Apr 1, 2026)

The per-sample for-loop here calls .item() repeatedly, pulling tensors back to Python scalars; that forces CPU/GPU synchronization, breaks static-graph/export, and adds noticeable per-step overhead. Consider building the gather indices with pure tensor ops (e.g. a vectorized cumsum/offsets computation over local_mask) so the Python loop and .item() calls go away.

Suggested change
per_sample_total = paddle.cast(visual_pos_masks, "int32").sum(
axis=1
)
per_sample_pre = (
paddle.cast(visual_pos_masks[:, :start_s], "int32").sum(
axis=1
)
if start_s > 0
else paddle.zeros([batch_size], dtype="int32")
)
gather_indices_list = []
cumulative_total = 0
for i in range(batch_size):
total_i = int(per_sample_total[i].item())
pre_i = int(per_sample_pre[i].item())
count_i = int(
paddle.cast(local_mask[i], "int32").sum().item()
)
if count_i > 0:
gather_indices_list.append(
paddle.arange(
cumulative_total + pre_i,
cumulative_total + pre_i + count_i,
)
)
cumulative_total += total_i
if gather_indices_list:
visual_gather_indices = paddle.concat(gather_indices_list)
else:
visual_gather_indices = paddle.zeros([0], dtype="int64")
visual_mask_int = paddle.cast(visual_pos_masks, "int32")
per_sample_total = visual_mask_int.sum(axis=1)
# Compute each sample's start offset in the globally flattened visual sequence:
# sample_offsets[i] = sum_{j < i} per_sample_total[j]
sample_offsets = paddle.concat(
[
paddle.zeros([1], dtype="int32"),
paddle.cumsum(per_sample_total[:-1]),
],
axis=0,
)
# prefix[b, s] = number of visual tokens in sample b at positions <= s, minus 1 (0-based index)
prefix = paddle.cumsum(visual_mask_int, axis=1) - 1
global_idx = prefix + sample_offsets.unsqueeze(1)
# Keep only the slice owned by this rank, then pick the global indices where local_mask is set
local_global_idx = global_idx[:, start_s : start_s + chunk_s]
visual_gather_indices = paddle.masked_select(
local_global_idx,
paddle.cast(local_mask, "bool"),
).astype("int64")

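The cumsum-based construction in the suggestion can be cross-checked against the original loop. A NumPy model of both (NumPy stands in for Paddle here; shapes follow the diff):

```python
import numpy as np

def gather_indices_loop(mask, start_s, chunk_s):
    # Original per-sample Python loop, modeled in NumPy.
    out, cumulative = [], 0
    for row in mask:
        pre = int(row[:start_s].sum())
        count = int(row[start_s:start_s + chunk_s].sum())
        out.extend(range(cumulative + pre, cumulative + pre + count))
        cumulative += int(row.sum())
    return np.array(out, dtype=np.int64)

def gather_indices_vectorized(mask, start_s, chunk_s):
    # Vectorized version: per-sample offsets plus per-position prefix sums.
    m = mask.astype(np.int64)
    offsets = np.concatenate([[0], np.cumsum(m.sum(axis=1))[:-1]])
    prefix = np.cumsum(m, axis=1) - 1           # 0-based index within each sample
    global_idx = prefix + offsets[:, None]
    local = global_idx[:, start_s:start_s + chunk_s]
    return local[mask[:, start_s:start_s + chunk_s]]

mask = np.array([[0, 1, 1, 0, 1, 0],
                 [1, 1, 0, 0, 0, 1]], dtype=bool)
a = gather_indices_loop(mask, start_s=2, chunk_s=3)
b = gather_indices_vectorized(mask, start_s=2, chunk_s=3)
assert np.array_equal(a, b)
```

The vectorized form has no data-dependent Python control flow, so the same structure ported to Paddle ops should stay export-friendly.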

visual_update_indices = paddle.nonzero(local_mask.flatten())
else:
visual_update_indices = paddle.nonzero(
visual_pos_masks.flatten()
)
Comment on lines +412 to +416 (Copilot AI, Apr 1, 2026)

paddle.nonzero(local_mask.flatten()) / paddle.nonzero(visual_pos_masks.flatten()) returns an index tensor of shape [N, 1] for 1D input; without a squeeze(-1) it is easy to hit shape mismatches downstream when the result is used as a 1D index (e.g. in scatter/gather). Consider applying an explicit squeeze(-1) and documenting the expected index shape in the comment/variable name.

Suggested change
visual_update_indices = paddle.nonzero(local_mask.flatten())
else:
visual_update_indices = paddle.nonzero(
visual_pos_masks.flatten()
)
# nonzero on 1D input returns indices of shape [N, 1]; squeeze to get 1D [N] index
visual_update_indices = paddle.nonzero(
local_mask.flatten()
).squeeze(-1)
else:
# nonzero on 1D input returns indices of shape [N, 1]; squeeze to get 1D [N] index
visual_update_indices = paddle.nonzero(
visual_pos_masks.flatten()
).squeeze(-1)

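NumPy's `argwhere` exhibits the same [N, 1] shape pitfall the comment describes for `paddle.nonzero` on 1D input (NumPy is used here purely for illustration):

```python
import numpy as np

flat = np.array([0, 1, 0, 1, 1], dtype=bool)

# argwhere on a 1D array returns shape [N, 1], like paddle.nonzero.
idx_2d = np.argwhere(flat)
assert idx_2d.shape == (3, 1)

# Squeezing the trailing axis gives the 1D [N] index downstream gathers expect.
idx_1d = idx_2d.squeeze(-1)
assert idx_1d.shape == (3,)
assert np.array_equal(flat[idx_1d], np.array([True, True, True]))
```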

preproc_output = {
"hidden_states": decoder_input,
"attention_mask": attention_mask,
@@ -354,6 +425,8 @@ def forward(
"position_ids": position_ids,
"deepstack_visual_emb": deepstack_visual_embeds,
"visual_pos_masks": visual_pos_masks,
"visual_update_indices": visual_update_indices,
"visual_gather_indices": visual_gather_indices,
}
if mtp_emb_res is not None:
assert (