Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 86 additions & 67 deletions paddleformers/transformers/qwen3_vl/modeling_fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def forward(
dict_args.pop("position_ids", None)
deepstack_visual_emb = dict_args.get("deepstack_visual_emb", None)
visual_pos_masks = dict_args.get("visual_pos_masks", None)
visual_update_indices = dict_args.get("visual_update_indices", None)
visual_gather_indices = dict_args.get("visual_gather_indices", None)

if self.full_recompute:
hidden_states = dict_args["hidden_states"]
Expand Down Expand Up @@ -191,8 +193,10 @@ def forward(
if deepstack_visual_emb and self.layer_number in range(len(deepstack_visual_emb)):
output = self._deepstack_process(
hidden_states=output,
visual_embeds=deepstack_visual_emb[self.layer_number],
visual_pos_masks=visual_pos_masks,
visual_embeds=deepstack_visual_emb[self.layer_number],
visual_update_indices=visual_update_indices,
visual_gather_indices=visual_gather_indices,
)

rst = OrderedDict()
Expand Down Expand Up @@ -241,7 +245,12 @@ def _forward_impl(
return hidden_states

def _deepstack_process(
self, hidden_states: paddle.Tensor, visual_pos_masks: paddle.Tensor, visual_embeds: paddle.Tensor
self,
hidden_states: paddle.Tensor,
visual_pos_masks: paddle.Tensor,
visual_embeds: paddle.Tensor,
visual_update_indices: paddle.Tensor | None = None,
visual_gather_indices: paddle.Tensor | None = None,
):
# SP layout is [S/tp, B, H] (seq-first); transpose to [B, S/tp, H] so that
# flatten(0,1) produces batch-first [B*S/tp, H], consistent with visual_pos_masks [B, S].
Expand All @@ -257,67 +266,80 @@ def _deepstack_process(

visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)

# Sequence Parallelism (SP) row slicing.
# visual_pos_masks is [B, S] (full sequence), hidden_states is [B*S/tp, H]
# (batch-major after transpose+flatten). We must slice along the S dimension
# (dim=1) to match the batch-major layout, NOT flatten-then-chunk which
# breaks when B > 1.
if visual_pos_masks.ndim > 1 and visual_pos_masks.shape[1] > hidden_states.shape[0] // max(
visual_pos_masks.shape[0], 1
):
# visual_pos_masks: [B, S], hidden_states: [B*S/tp, H]
try:
from paddle.distributed.fleet import get_hybrid_communicate_group
if visual_update_indices is not None:
# Fast path: use precomputed indices from GPTEmbedding.
# visual_gather_indices is not None iff SP is active.
if visual_gather_indices is not None:
if visual_gather_indices.shape[0] > 0:
visual_embeds = visual_embeds[visual_gather_indices]
else:
visual_embeds = visual_embeds[:0] # empty — this rank has no visual tokens

update_indices = visual_update_indices
else:
# Fallback: compute indices on-the-fly (e.g. when called without precomputation).
# Sequence Parallelism (SP) row slicing.
# visual_pos_masks is [B, S] (full sequence), hidden_states is [B*S/tp, H]
# (batch-major after transpose+flatten). We must slice along the S dimension
# (dim=1) to match the batch-major layout, NOT flatten-then-chunk which
# breaks when B > 1.
if visual_pos_masks.ndim > 1 and visual_pos_masks.shape[1] > hidden_states.shape[0] // max(
visual_pos_masks.shape[0], 1
):
# visual_pos_masks: [B, S], hidden_states: [B*S/tp, H]
try:
from paddle.distributed.fleet import get_hybrid_communicate_group

hcg = get_hybrid_communicate_group()
mp_rank = hcg.get_model_parallel_rank()
mp_size = hcg.get_model_parallel_world_size()
except (ImportError, AttributeError):
batch_size = visual_pos_masks.shape[0]
full_seq_len = visual_pos_masks.shape[1]
mp_size = (batch_size * full_seq_len) // hidden_states.shape[0]
mp_rank = paddle.distributed.get_rank() % mp_size

hcg = get_hybrid_communicate_group()
mp_rank = hcg.get_model_parallel_rank()
mp_size = hcg.get_model_parallel_world_size()
except (ImportError, AttributeError):
batch_size = visual_pos_masks.shape[0]
full_seq_len = visual_pos_masks.shape[1]
mp_size = (batch_size * full_seq_len) // hidden_states.shape[0]
mp_rank = paddle.distributed.get_rank() % mp_size

full_seq_len = visual_pos_masks.shape[1]
chunk_s = full_seq_len // mp_size
start_s = mp_rank * chunk_s

# Slice along S dimension: [B, S] -> [B, S/tp]
local_mask = visual_pos_masks[:, start_s : start_s + chunk_s]
batch_size = visual_pos_masks.shape[0]

# Gather per-sample visual_embeds.
# visual_embeds is ordered as [sample0_all_vis, sample1_all_vis, ...].
# Each rank only needs the visual tokens that fall within its local
# sequence chunk [start_s, start_s+chunk_s) for each sample.
per_sample_total = paddle.cast(visual_pos_masks, "int32").sum(axis=1) # [B]
per_sample_pre = (
paddle.cast(visual_pos_masks[:, :start_s], "int32").sum(axis=1)
if start_s > 0
else paddle.zeros([batch_size], dtype="int32")
) # [B]
per_sample_local = paddle.cast(local_mask, "int32").sum(axis=1) # [B]

gather_indices = []
cumulative_total = 0
for i in range(batch_size):
total_i = int(per_sample_total[i].item())
pre_i = int(per_sample_pre[i].item())
count_i = int(per_sample_local[i].item())
if count_i > 0:
gather_indices.append(paddle.arange(cumulative_total + pre_i, cumulative_total + pre_i + count_i))
cumulative_total += total_i

if gather_indices:
gather_indices = paddle.concat(gather_indices)
visual_embeds = visual_embeds[gather_indices]
else:
visual_embeds = visual_embeds[:0] # empty
chunk_s = full_seq_len // mp_size
start_s = mp_rank * chunk_s

# Slice along S dimension: [B, S] -> [B, S/tp]
local_mask = visual_pos_masks[:, start_s : start_s + chunk_s]
batch_size = visual_pos_masks.shape[0]

# Gather per-sample visual_embeds.
per_sample_total = paddle.cast(visual_pos_masks, "int32").sum(axis=1) # [B]
per_sample_pre = (
paddle.cast(visual_pos_masks[:, :start_s], "int32").sum(axis=1)
if start_s > 0
else paddle.zeros([batch_size], dtype="int32")
) # [B]
per_sample_local = paddle.cast(local_mask, "int32").sum(axis=1) # [B]

gather_indices = []
cumulative_total = 0
for i in range(batch_size):
total_i = int(per_sample_total[i].item())
pre_i = int(per_sample_pre[i].item())
count_i = int(per_sample_local[i].item())
if count_i > 0:
gather_indices.append(
paddle.arange(cumulative_total + pre_i, cumulative_total + pre_i + count_i)
)
cumulative_total += total_i

if gather_indices:
gather_indices = paddle.concat(gather_indices)
visual_embeds = visual_embeds[gather_indices]
else:
visual_embeds = visual_embeds[:0] # empty

# Flatten local mask to [B*S/tp] matching hidden_states batch-major layout
visual_pos_masks = local_mask.flatten()
elif visual_pos_masks.ndim > 1:
visual_pos_masks = visual_pos_masks.flatten()

# Flatten local mask to [B*S/tp] matching hidden_states batch-major layout
visual_pos_masks = local_mask.flatten()
elif visual_pos_masks.ndim > 1:
visual_pos_masks = visual_pos_masks.flatten()
update_indices = paddle.nonzero(visual_pos_masks)

# If TP is enabled, hidden_states has shape [..., Hidden_Dim / TP_Size],
# but visual_embeds usually has full [Hidden_Dim]. We need to slice visual_embeds column-wise.
Expand All @@ -341,16 +363,13 @@ def _deepstack_process(
visual_embeds = visual_embeds[:, start_col:end_col]

hidden_states = hidden_states.clone()
update_indices = paddle.nonzero(visual_pos_masks)
# Under SP, visual tokens are unevenly distributed across ranks. After row-slicing
# visual_pos_masks and visual_embeds to the local sequence chunk, some ranks may
# have zero visual tokens (local_visual_count == 0), producing visual_embeds with
# shape [0, H]. Guard against passing an empty updates tensor to scatter_nd_add,
# whose behavior is undefined / backend-dependent in that case.
# Under SP, visual tokens are unevenly distributed across ranks. Some ranks may
# have zero visual tokens, producing visual_embeds with shape [0, H].
# Guard against passing an empty updates tensor to scatter_nd_add.
if visual_embeds.shape[0] > 0:
hidden_states = paddle.scatter_nd_add(hidden_states, update_indices, visual_embeds)

# [Supplement 3] Restore original shape [B*S, D] -> [B, S, D] if necessary
# Restore original shape [B*S, D] -> [B, S, D] if necessary
if len(original_shape) > 2:
hidden_states = hidden_states.reshape(original_shape)
if _sp_transposed:
Expand Down
Loading