
Refactor rotary embedding dispatch in rope_utils #713

Open
huangjiyi wants to merge 2 commits into PaddlePaddle:develop from huangjiyi:codex/refactor-rope-utils-dispatch

Conversation

@huangjiyi
Member

Summary

  • refactor rope_utils.py so the bshd/thd helpers take TransformerConfig directly and route fused vs unfused behavior in one place
  • keep the high-precision path working by restoring the fused vision kernel import, preserving the M-RoPE sequence-parallel transpose, and normalizing mscale=None to 1.0
  • restore total_seq_len compatibility for packed-sequence callers and raise an explicit error for unsupported fused+cu_seqlens usage instead of failing later with opaque runtime errors
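A minimal pure-Python sketch of the two behaviors described in the summary (the helper names here are illustrative, not the actual rope_utils API; the real code operates on paddle Tensors):

```python
# Illustrative sketch only; function names are hypothetical.

def normalize_mscale(mscale):
    """Normalize mscale=None to 1.0, as the summary describes."""
    return 1.0 if mscale is None else float(mscale)

def check_fused_packed_support(apply_rope_fusion, cu_seqlens):
    """Fail fast on the unsupported fused + cu_seqlens combination
    instead of letting the fused kernel fail later with an opaque error."""
    if apply_rope_fusion and cu_seqlens is not None:
        raise NotImplementedError(
            "apply_rope_fusion is not supported with packed sequences "
            "(cu_seqlens); disable fusion for this path."
        )
```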

Why

The refactor of apply_rotary_pos_emb introduced compatibility gaps in this shared helper:

  • attention.py still passes total_seq_len= for packed-sequence calls
  • the high-precision fused branch referenced a missing symbol/import
  • the 3D frequency transpose for sequence-parallel M-RoPE was accidentally dropped

These gaps make the helper easy to break at runtime even though the file still compiles.

Impact

This keeps the refactor in rope_utils.py while preserving existing packed-sequence callers and making unsupported fused packed-sequence usage fail fast with a clear error.

Validation

  • python -m py_compile src/paddlefleet/models/common/embeddings/rope_utils.py src/paddlefleet/transformer/attention.py src/paddlefleet/transformer/multi_latent_attention.py
  • local pre-commit hooks triggered by git commit

@huangjiyi huangjiyi marked this pull request as ready for review April 1, 2026 09:04
Copilot AI review requested due to automatic review settings April 1, 2026 09:04
@huangjiyi huangjiyi changed the title [codex] Refactor rotary embedding dispatch in rope_utils Refactor rotary embedding dispatch in rope_utils Apr 1, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR refactors the RoPE (rotary embedding) dispatch logic in rope_utils.py: the bshd/thd paths now uniformly accept TransformerConfig, fused/unfused routing is concentrated into fewer entry points, and packed-seq (total_seq_len) compatibility plus the high-precision RoPE paths are restored.

Changes:

  • Refactor _apply_rotary_pos_emb_bshd / _apply_rotary_pos_emb_thd to accept TransformerConfig and handle the apply_rope_fusion branch internally.
  • Restore/harden the packed-seq path: keep the total_seq_len logic and raise an explicit error for fused + cu_seqlens so it fails fast.
  • Attempt to restore the high-precision fused RoPE path (importing the fused vision RoPE op) and normalize mscale=None to 1.0.

fused_rotary_position_embedding as fused_rope,
)

from paddlefleet.ops import fused_apply_rotary_pos_emb_vision

Copilot AI Apr 1, 2026


This is a hard module-level import of fused_apply_rotary_pos_emb_vision; when the custom op is not compiled, on CPU, or environment requirements are not met, `from paddlefleet.ops import ...` raises ImportError outright, so merely importing this module fails. Consider a lazy import inside the high_precision_rope + apply_rope_fusion branch, raising a descriptive exception (or providing a safe fallback) when the op is unavailable.

Suggested change
from paddlefleet.ops import fused_apply_rotary_pos_emb_vision
try:
    from paddlefleet.ops import fused_apply_rotary_pos_emb_vision
except ImportError:  # the custom op may be uncompiled or unavailable on this device
    def fused_apply_rotary_pos_emb_vision(*args, **kwargs):
        raise RuntimeError(
            "fused_apply_rotary_pos_emb_vision is not available. "
            "This fused RoPE vision op requires the corresponding "
            "paddlefleet custom operator to be compiled and installed, "
            "and may only be supported on specific hardware backends."
        )

Comment on lines +135 to +143
rot_dim = freqs.shape[-1]
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
if freqs.ndim == 3:
    freqs_2d = freqs.reshape([-1, freqs.shape[-1]])
else:
    freqs_2d = freqs
freqs_half = freqs_2d[..., : freqs_2d.shape[-1] // 2]
t = fused_apply_rotary_pos_emb_vision(t, freqs_half)
return paddle.cat((t, t_pass), axis=-1)

Copilot AI Apr 1, 2026


When apply_rope_fusion and high_precision_rope are both enabled, this code accesses freqs.shape and slices t directly, but the current interface contract allows freqs=None, and t may also be a (query, key) tuple (as in the fused call from attention). This crashes deterministically at runtime under the high-precision + fused configuration. Consider explicitly validating freqs is not None and isinstance(t, Tensor) in this branch, or applying the vision RoPE to each tensor in the tuple.

Suggested change
rot_dim = freqs.shape[-1]
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
if freqs.ndim == 3:
    freqs_2d = freqs.reshape([-1, freqs.shape[-1]])
else:
    freqs_2d = freqs
freqs_half = freqs_2d[..., : freqs_2d.shape[-1] // 2]
t = fused_apply_rotary_pos_emb_vision(t, freqs_half)
return paddle.cat((t, t_pass), axis=-1)
if freqs is None:
    raise ValueError(
        "freqs must not be None when using high_precision_rope "
        "with apply_rope_fusion enabled."
    )

def _apply_vision_rope_to_single_tensor(
    t_tensor: Tensor, freqs_tensor: Tensor
) -> Tensor:
    """Apply high-precision vision RoPE to a single tensor."""
    rot_dim = freqs_tensor.shape[-1]
    t_rot, t_pass = t_tensor[..., :rot_dim], t_tensor[..., rot_dim:]
    if freqs_tensor.ndim == 3:
        freqs_2d = freqs_tensor.reshape([-1, freqs_tensor.shape[-1]])
    else:
        freqs_2d = freqs_tensor
    freqs_half = freqs_2d[..., : freqs_2d.shape[-1] // 2]
    t_rot = fused_apply_rotary_pos_emb_vision(t_rot, freqs_half)
    return paddle.cat((t_rot, t_pass), axis=-1)

if isinstance(t, Tensor):
    # Single tensor input
    return _apply_vision_rope_to_single_tensor(t, freqs)
elif isinstance(t, tuple):
    # Tuple of tensors input, e.g., (query, key)
    return tuple(
        _apply_vision_rope_to_single_tensor(t_elem, freqs)
        for t_elem in t
    )
else:
    raise TypeError(
        "Unsupported type for `t` in high_precision_rope fused path: "
        f"{type(t)!r}. Expected `Tensor` or a tuple of `Tensor`."
    )

Comment on lines +135 to +142
rot_dim = freqs.shape[-1]
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
if freqs.ndim == 3:
    freqs_2d = freqs.reshape([-1, freqs.shape[-1]])
else:
    freqs_2d = freqs
freqs_half = freqs_2d[..., : freqs_2d.shape[-1] // 2]
t = fused_apply_rotary_pos_emb_vision(t, freqs_half)

Copilot AI Apr 1, 2026


The fused vision path of high_precision_rope reshapes directly to [-1, D] when freqs.ndim == 3, but for the M-RoPE + sequence-parallel case, where freqs commonly arrives as [S, B, D], it must first be transposed to [B, S, D] to align with t = [B, S, H, D]; otherwise results are silently wrong for B > 1. Consider reusing the transpose alignment logic from the unfused branch below before reshaping.
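A pure-Python sketch of the layout issue described above (shapes assumed from the comment; the real code would use paddle transpose/reshape): flattening [S, B, D] directly interleaves batches, while transposing to [B, S, D] first keeps rows aligned with t = [B, S, H, D].

```python
def flatten_freqs_aligned(freqs_sbd):
    """freqs_sbd: nested lists laid out [S][B][D].
    Transpose to [B][S][D] before flattening, so flat row b*S + s
    corresponds to (batch b, step s)."""
    S = len(freqs_sbd)
    B = len(freqs_sbd[0])
    return [freqs_sbd[s][b] for b in range(B) for s in range(S)]

def flatten_freqs_naive(freqs_sbd):
    """Equivalent of reshape([-1, D]) on [S, B, D]:
    rows come out batch-interleaved, wrong for B > 1."""
    return [row for step in freqs_sbd for row in step]
```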

Comment on lines +131 to +134
assert not config.rotary_interleaved and mscale == 1.0, (
    "fused_apply_rotary_pos_emb_vision only supports "
    "non-interleaved mode and mscale=1.0"
)

Copilot AI Apr 1, 2026


This uses assert to guard configurations unsupported at runtime (interleaved/mscale); under Python -O, assertions are stripped, which could route execution into an unsupported kernel path and produce wrong results. Consider raising an explicit ValueError/NotImplementedError instead, stating the restriction in the message.
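A sketch of the suggested replacement (the helper name is hypothetical; the error text reuses the existing assert message):

```python
def validate_fused_vision_config(rotary_interleaved, mscale):
    # An explicit raise survives `python -O`, unlike `assert`.
    if rotary_interleaved or mscale != 1.0:
        raise NotImplementedError(
            "fused_apply_rotary_pos_emb_vision only supports "
            "non-interleaved mode and mscale=1.0"
        )
```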

Comment on lines +98 to +106
def _apply_rotary_pos_emb_bshd(
    t: Tensor | tuple[Tensor, ...],
    freqs: Tensor | None,
    config: TransformerConfig,
    cos: Tensor | None = None,
    sin: Tensor | None = None,
    mscale: float = 1.0,
    position_ids: Tensor | None = None,
) -> Tensor | tuple[Tensor, ...]:

Copilot AI Apr 1, 2026


The signature of _apply_rotary_pos_emb_bshd now requires config, but existing unit tests / internal callers in the repo use the old signature (_apply_rotary_pos_emb_bshd(t, freqs)) of this private function directly, which will raise TypeError. Either update all call sites (including tests) in the same change, or provide a compatible default / adapter layer for config here to avoid breaking existing usage.

    Returns:
        Tensor: The input tensor after applying RoPE
    """
    # Unfused path

Copilot AI Apr 1, 2026


The unfused path unconditionally uses freqs.shape[-1], but the type annotation allows freqs: Tensor | None; if a caller still passes freqs=None with apply_rope_fusion=False, this triggers an AttributeError. Consider explicitly checking freqs is not None on entry to the unfused path and raising a clear error, to avoid an opaque downstream failure.

Suggested change
# Unfused path
# Unfused path
if freqs is None:
    raise ValueError(
        "freqs must be provided for the unfused RoPE path "
        "(config.apply_rope_fusion is False)."
    )

Comment on lines 264 to +269
def _apply_rotary_pos_emb_thd(
    t: Tensor | tuple[Tensor, ...],
    cu_seqlens: Tensor,
    total_seq_len: int | None,
    freqs: Tensor | None,
    config: TransformerConfig,

Copilot AI Apr 1, 2026


_apply_rotary_pos_emb_thd is currently annotated/documented to allow t as a tuple and freqs as None, but the unfused logic accesses t.ndim / freqs.dim(); a mis-typed call fails with a raw AttributeError. Consider adding explicit input constraints at the top of the function (e.g. require isinstance(t, Tensor) and freqs is not None) with actionable error messages.
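A sketch of such an input guard (hypothetical helper name; the real check would test against paddle's Tensor type rather than duck-typing):

```python
def check_thd_inputs(t, freqs):
    """Raise actionable errors before the unfused logic touches
    t.ndim / freqs.dim()."""
    if isinstance(t, tuple):
        raise TypeError(
            "_apply_rotary_pos_emb_thd expects a single tensor for `t`; "
            "apply it to each element of the (query, key) tuple instead."
        )
    if freqs is None:
        raise ValueError("freqs must not be None in the unfused thd path.")
```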
