-
Notifications
You must be signed in to change notification settings - Fork 58
[VL] precompute visual scatter indices in GPTEmbedding #708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -344,6 +344,77 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [1, 0, 2, 3] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ).contiguous() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Precompute visual scatter indices for _deepstack_process. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # These indices are layer-independent and would otherwise be | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # redundantly computed in every deepstack layer. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| visual_update_indices = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| visual_gather_indices = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if visual_pos_masks is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.sequence_parallel: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from paddle.distributed.fleet import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_hybrid_communicate_group, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hcg = get_hybrid_communicate_group() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mp_rank = hcg.get_model_parallel_rank() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mp_size = hcg.get_model_parallel_world_size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except (ImportError, AttributeError): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size = visual_pos_masks.shape[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| full_seq_len = visual_pos_masks.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # decoder_input is already [S/tp, B, H] after SP scatter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_seq = decoder_input.shape[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mp_size = (batch_size * full_seq_len) // ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_seq * batch_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mp_rank = paddle.distributed.get_rank() % mp_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+363
to
+371
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size = visual_pos_masks.shape[0] | |
| full_seq_len = visual_pos_masks.shape[1] | |
| # decoder_input is already [S/tp, B, H] after SP scatter | |
| local_seq = decoder_input.shape[0] | |
| mp_size = (batch_size * full_seq_len) // ( | |
| local_seq * batch_size | |
| ) | |
| mp_rank = paddle.distributed.get_rank() % mp_size | |
| # Fallback: try to get tensor model parallel rank/world_size | |
| try: | |
| from paddlefleet import parallel_state | |
| mp_rank = parallel_state.get_tensor_model_parallel_rank() | |
| mp_size = ( | |
| parallel_state.get_tensor_model_parallel_world_size() | |
| ) | |
| except (ImportError, AttributeError): | |
| # Final fallback: infer mp_size from shapes and assume rank 0 | |
| batch_size = visual_pos_masks.shape[0] | |
| full_seq_len = visual_pos_masks.shape[1] | |
| # decoder_input is already [S/tp, B, H] after SP scatter | |
| local_seq = decoder_input.shape[0] | |
| mp_size = (batch_size * full_seq_len) // ( | |
| local_seq * batch_size | |
| ) | |
| mp_rank = 0 |
Copilot
AI
Apr 1, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SP 分支里 chunk_s = full_seq_len // mp_size 会在 full_seq_len 不能被 mp_size 整除时静默丢弃尾部 token,导致 local_mask / indices 与真实的 SP 切分不一致。建议在这里显式 assert 可整除,或实现“最后一段带 remainder”的切分逻辑以覆盖非整除场景。
| chunk_s = full_seq_len // mp_size | |
| start_s = mp_rank * chunk_s | |
| # Support non-divisible sequence length by distributing the remainder | |
| base_chunk = full_seq_len // mp_size | |
| remainder = full_seq_len % mp_size | |
| if mp_rank < remainder: | |
| # Ranks before `remainder` get one extra token | |
| chunk_s = base_chunk + 1 | |
| start_s = mp_rank * (base_chunk + 1) | |
| else: | |
| # Remaining ranks get the base_chunk size | |
| chunk_s = base_chunk | |
| start_s = remainder * (base_chunk + 1) + (mp_rank - remainder) * base_chunk |
Copilot
AI
Apr 1, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的 per-sample for-loop 里多次调用 .item() 把 Tensor 拉回 Python 标量,会触发 CPU/GPU 同步并破坏静态图/导出(同时也会在每个 step 带来显著性能开销)。建议用纯 Tensor 方式构造 gather 索引(例如基于 local_mask 的 cumsum/offsets 向量化计算),避免 Python 循环与 .item()。
| per_sample_total = paddle.cast(visual_pos_masks, "int32").sum( | |
| axis=1 | |
| ) | |
| per_sample_pre = ( | |
| paddle.cast(visual_pos_masks[:, :start_s], "int32").sum( | |
| axis=1 | |
| ) | |
| if start_s > 0 | |
| else paddle.zeros([batch_size], dtype="int32") | |
| ) | |
| gather_indices_list = [] | |
| cumulative_total = 0 | |
| for i in range(batch_size): | |
| total_i = int(per_sample_total[i].item()) | |
| pre_i = int(per_sample_pre[i].item()) | |
| count_i = int( | |
| paddle.cast(local_mask[i], "int32").sum().item() | |
| ) | |
| if count_i > 0: | |
| gather_indices_list.append( | |
| paddle.arange( | |
| cumulative_total + pre_i, | |
| cumulative_total + pre_i + count_i, | |
| ) | |
| ) | |
| cumulative_total += total_i | |
| if gather_indices_list: | |
| visual_gather_indices = paddle.concat(gather_indices_list) | |
| else: | |
| visual_gather_indices = paddle.zeros([0], dtype="int64") | |
| visual_mask_int = paddle.cast(visual_pos_masks, "int32") | |
| per_sample_total = visual_mask_int.sum(axis=1) | |
| # 计算每个样本在全局打平 visual 序列中的起始偏移: | |
| # sample_offsets[i] = sum_{j < i} per_sample_total[j] | |
| sample_offsets = paddle.concat( | |
| [ | |
| paddle.zeros([1], dtype="int32"), | |
| paddle.cumsum(per_sample_total[:-1]), | |
| ], | |
| axis=0, | |
| ) | |
| # prefix[b, s] = 样本 b 在位置 s 及之前的 visual 数量 - 1(得到 0-based index) | |
| prefix = paddle.cumsum(visual_mask_int, axis=1) - 1 | |
| global_idx = prefix + sample_offsets.unsqueeze(1) | |
| # 仅保留当前 rank 负责的序列片段,并在 local_mask 为 1 的位置选出对应全局索引 | |
| local_global_idx = global_idx[:, start_s : start_s + chunk_s] | |
| visual_gather_indices = paddle.masked_select( | |
| local_global_idx, | |
| paddle.cast(local_mask, "bool"), | |
| ).astype("int64") |
Copilot
AI
Apr 1, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle.nonzero(local_mask.flatten())/paddle.nonzero(visual_pos_masks.flatten()) 对 1D 输入会返回形状为 [N, 1] 的索引张量;当前未做 squeeze(-1),很容易在下游当作 1D index 使用(例如 scatter/gather)时触发 shape 不匹配。建议显式 squeeze(-1) 并在注释/变量名中明确期望的 index 形状。
| visual_update_indices = paddle.nonzero(local_mask.flatten()) | |
| else: | |
| visual_update_indices = paddle.nonzero( | |
| visual_pos_masks.flatten() | |
| ) | |
| # nonzero on 1D input returns indices of shape [N, 1]; squeeze to get 1D [N] index | |
| visual_update_indices = paddle.nonzero( | |
| local_mask.flatten() | |
| ).squeeze(-1) | |
| else: | |
| # nonzero on 1D input returns indices of shape [N, 1]; squeeze to get 1D [N] index | |
| visual_update_indices = paddle.nonzero( | |
| visual_pos_masks.flatten() | |
| ).squeeze(-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
新增的
visual_update_indices/visual_gather_indices逻辑(尤其是 SP 分支的 per-rank 切分与 gather index 构造)目前没有单测覆盖。建议在现有的tests/single_card_tests/transformer/test_qwen_vl_sp.py中补充断言:非 SP 场景下 indices 与 mask 对齐;SP 场景下在 mock 的 mp_rank/mp_size 下能生成预期的 local indices/gather indices,并覆盖空 visual 的边界情况。