src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py (13 changes: 8 additions, 5 deletions)
@@ -1282,26 +1282,28 @@ def forward(
                 (cache_position is not None and cache_position[0] == 0)
                 or (past_key_values is None or past_key_values.get_seq_length() == 0)
             )
-            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+            if (prefill_compiled_stage or prefill_noncompiled_stage) or rope_deltas is None:

> Contributor (on the `+` line above): Isn't it enough to move the line `self.rope_deltas = rope_deltas` above?

                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
                     video_grid_thw,
                     second_per_grid_ts=second_per_grid_ts,
                     attention_mask=attention_mask,
                 )
-                self.rope_deltas = rope_deltas
             else:
                 batch_size, seq_length, _ = inputs_embeds.shape
                 position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                 position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
                 if cache_position is not None:
-                    delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+                    delta = (cache_position[0] + rope_deltas).to(inputs_embeds.device)
                 else:
                     delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
                 delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
                 position_ids = position_ids + delta.to(position_ids.device)
 
+            if rope_deltas is not None:
+                self.rope_deltas = rope_deltas
+
         outputs = self.language_model(
             input_ids=None,
             position_ids=position_ids,
@@ -1321,7 +1323,7 @@ def forward(
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            rope_deltas=self.rope_deltas,
+            rope_deltas=rope_deltas,
         )
         return output if return_dict else output.to_tuple()
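
Taken together, these hunks make `forward` prefer a `rope_deltas` value passed in with the inputs over the copy cached on the module, while still updating `self.rope_deltas` for callers that read it. A minimal sketch of the before/after pattern, using a toy stand-in for `get_rope_index` (illustrative names, not the actual Qwen2.5-VL implementation):

```python
# Before/after sketch; `x.new_zeros(...)` stands in for get_rope_index.
import torch
from torch import nn

class StatefulStyle(nn.Module):
    """Before: deltas persist on the module, so one generation session
    can silently pick up (or clobber) another session's value."""
    def __init__(self):
        super().__init__()
        self.rope_deltas = None

    def forward(self, x, prefill=False):
        if prefill or self.rope_deltas is None:
            self.rope_deltas = x.new_zeros(x.shape[0], 1)  # stand-in for get_rope_index
        return x + self.rope_deltas

class FunctionalStyle(nn.Module):
    """After: deltas arrive with the inputs and leave with the outputs;
    the attribute is only kept in sync as a convenience."""
    def __init__(self):
        super().__init__()
        self.rope_deltas = None

    def forward(self, x, prefill=False, rope_deltas=None):
        if prefill or rope_deltas is None:
            rope_deltas = x.new_zeros(x.shape[0], 1)  # recompute on prefill
        if rope_deltas is not None:
            self.rope_deltas = rope_deltas  # mirrors the added hunk above
        return x + rope_deltas, rope_deltas
```

The `if rope_deltas is not None` guard mirrors the added hunk: the attribute stays in sync for backward compatibility, but it is no longer the source of truth inside `forward`, which should avoid stale deltas leaking between independent generation calls.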

@@ -1479,6 +1481,7 @@ def forward(
             pixel_values_videos=pixel_values_videos,
             image_grid_thw=image_grid_thw,
             video_grid_thw=video_grid_thw,
+            rope_deltas=rope_deltas,
             second_per_grid_ts=second_per_grid_ts,
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -1561,7 +1564,7 @@ def prepare_inputs_for_generation(
                 second_per_grid_ts=second_per_grid_ts,
                 attention_mask=attention_mask,
             )
-            self.model.rope_deltas = rope_deltas
+            model_inputs["rope_deltas"] = rope_deltas
         # then use the prev pre-calculated rope-deltas to get the correct position ids
         elif "position_ids" in model_inputs:
             batch_size, seq_length = model_inputs["position_ids"].shape
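
On the generation side, the freshly computed deltas now travel in `model_inputs` instead of being written onto `self.model`, so the next `forward` call receives them as an argument and returns them with its outputs. A hedged sketch of that handshake (simplified toy loop with hypothetical names, not the actual `generate` implementation):

```python
# Toy handshake showing how deltas travel with model_inputs instead of
# living on the module; the model called here is assumed to accept and
# return rope_deltas, as in the FunctionalStyle sketch above.
import torch

def prepare_inputs(input_ids, step, rope_deltas=None):
    model_inputs = {"input_ids": input_ids}
    if step == 0:
        # prefill: stand-in for self.model.get_rope_index(...)
        rope_deltas = torch.zeros(input_ids.shape[0], 1, dtype=torch.long)
    model_inputs["rope_deltas"] = rope_deltas  # travels with the inputs
    return model_inputs

def toy_generate(model, input_ids, max_new_tokens=3):
    rope_deltas = None
    for step in range(max_new_tokens):
        model_inputs = prepare_inputs(input_ids, step, rope_deltas)
        logits, rope_deltas = model(**model_inputs)   # deltas round-trip
        next_token = logits[:, -1:].argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids
```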
src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py (13 changes: 8 additions, 5 deletions)
@@ -606,26 +606,28 @@ def forward(
                 (cache_position is not None and cache_position[0] == 0)
                 or (past_key_values is None or past_key_values.get_seq_length() == 0)
             )
-            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+            if (prefill_compiled_stage or prefill_noncompiled_stage) or rope_deltas is None:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
                     video_grid_thw,
                     second_per_grid_ts=second_per_grid_ts,
                     attention_mask=attention_mask,
                 )
-                self.rope_deltas = rope_deltas
             else:
                 batch_size, seq_length, _ = inputs_embeds.shape
                 position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                 position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
                 if cache_position is not None:
-                    delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+                    delta = (cache_position[0] + rope_deltas).to(inputs_embeds.device)
                 else:
                     delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
                 delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
                 position_ids = position_ids + delta.to(position_ids.device)
 
+            if rope_deltas is not None:
+                self.rope_deltas = rope_deltas
+
         outputs = self.language_model(
             input_ids=None,
             position_ids=position_ids,
@@ -645,7 +647,7 @@ def forward(
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            rope_deltas=self.rope_deltas,
+            rope_deltas=rope_deltas,
         )
         return output if return_dict else output.to_tuple()

@@ -735,6 +737,7 @@ def forward(
             pixel_values_videos=pixel_values_videos,
             image_grid_thw=image_grid_thw,
             video_grid_thw=video_grid_thw,
+            rope_deltas=rope_deltas,
             second_per_grid_ts=second_per_grid_ts,
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -817,7 +820,7 @@ def prepare_inputs_for_generation(
                 second_per_grid_ts=second_per_grid_ts,
                 attention_mask=attention_mask,
             )
-            self.model.rope_deltas = rope_deltas
+            model_inputs["rope_deltas"] = rope_deltas
         # then use the prev pre-calculated rope-deltas to get the correct position ids
         elif "position_ids" in model_inputs:
             batch_size, seq_length = model_inputs["position_ids"].shape
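
The modular file mirrors the modeling file change for change, as transformers' modular-model tooling generates the latter from the former. For a caller, the visible effect is that the deltas can be round-tripped explicitly; the sketch below assumes the PR exposes `rope_deltas` as a `forward` kwarg and keeps it on the output object (both are visible in the diff, though the full signature is not shown), and the checkpoint name is only an example:

```python
# Hypothetical caller-side round-trip; not verified against the merged API.
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # example checkpoint
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

inputs = processor(text=["Describe the scene."], return_tensors="pt")
with torch.no_grad():
    prefill = model(**inputs, use_cache=True)           # deltas computed here
    next_ids = prefill.logits[:, -1:].argmax(dim=-1)
    decode = model(input_ids=next_ids,
                   past_key_values=prefill.past_key_values,
                   rope_deltas=prefill.rope_deltas)     # passed back explicitly
```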