Refactor rotary embedding dispatch in rope_utils #713
huangjiyi wants to merge 2 commits into PaddlePaddle:develop from
Conversation
Pull request overview
This PR refactors the RoPE (rotary embedding) dispatch logic in rope_utils.py: the bshd/thd paths now take TransformerConfig directly, fused/unfused routing is consolidated into fewer entry points, and packed-seq (total_seq_len) compatibility plus the high-precision RoPE paths are restored.
Changes:
- Refactor _apply_rotary_pos_emb_bshd/_apply_rotary_pos_emb_thd to accept TransformerConfig and handle the apply_rope_fusion branch internally.
- Restore/harden the packed-seq path: keep the total_seq_len logic and raise an explicit fail-fast error for fused + cu_seqlens.
- Attempt to restore the high-precision fused RoPE path (introducing the fused vision RoPE op) and normalize mscale=None to 1.0.
| fused_rotary_position_embedding as fused_rope, | ||
| ) | ||
| from paddlefleet.ops import fused_apply_rotary_pos_emb_vision |
This is a module-level hard import of fused_apply_rotary_pos_emb_vision; when the custom op is not compiled, the code runs on CPU, or the environment requirements are not met, from paddlefleet.ops import ... raises ImportError, so merely importing this module fails. Consider moving the import into the high_precision_rope + apply_rope_fusion branch as a lazy import, and raising a descriptive exception (or providing a safe fallback) when the op is unavailable.
| from paddlefleet.ops import fused_apply_rotary_pos_emb_vision | |
| try: | |
| from paddlefleet.ops import fused_apply_rotary_pos_emb_vision | |
| except ImportError:  # the custom op may be uncompiled or unavailable on this device | |
| def fused_apply_rotary_pos_emb_vision(*args, **kwargs): | |
| raise RuntimeError( | |
| "fused_apply_rotary_pos_emb_vision is not available. " | |
| "This fused RoPE vision op requires the corresponding " | |
| "paddlefleet custom operator to be compiled and installed, " | |
| "and may only be supported on specific hardware backends." | |
| ) |
| rot_dim = freqs.shape[-1] | ||
| t, t_pass = t[..., :rot_dim], t[..., rot_dim:] | ||
| if freqs.ndim == 3: | ||
| freqs_2d = freqs.reshape([-1, freqs.shape[-1]]) | ||
| else: | ||
| freqs_2d = freqs | ||
| freqs_half = freqs_2d[..., : freqs_2d.shape[-1] // 2] | ||
| t = fused_apply_rotary_pos_emb_vision(t, freqs_half) | ||
| return paddle.cat((t, t_pass), axis=-1) |
When apply_rope_fusion and high_precision_rope are both enabled, this code accesses freqs.shape and slices t as a tensor directly, but the current interface contract allows freqs=None, and t may also be a (query, key) tuple (as in the fused call from attention). This produces a deterministic runtime crash under the high-precision + fused configuration. Consider explicitly checking freqs is not None and isinstance(t, Tensor) in this branch, or applying vision RoPE to each tensor in the tuple separately.
| rot_dim = freqs.shape[-1] | |
| t, t_pass = t[..., :rot_dim], t[..., rot_dim:] | |
| if freqs.ndim == 3: | |
| freqs_2d = freqs.reshape([-1, freqs.shape[-1]]) | |
| else: | |
| freqs_2d = freqs | |
| freqs_half = freqs_2d[..., : freqs_2d.shape[-1] // 2] | |
| t = fused_apply_rotary_pos_emb_vision(t, freqs_half) | |
| return paddle.cat((t, t_pass), axis=-1) | |
| if freqs is None: | |
| raise ValueError( | |
| "freqs must not be None when using high_precision_rope " | |
| "with apply_rope_fusion enabled." | |
| ) | |
| def _apply_vision_rope_to_single_tensor( | |
| t_tensor: Tensor, freqs_tensor: Tensor | |
| ) -> Tensor: | |
| """Apply high-precision vision RoPE to a single tensor.""" | |
| rot_dim = freqs_tensor.shape[-1] | |
| t_rot, t_pass = t_tensor[..., :rot_dim], t_tensor[..., rot_dim:] | |
| if freqs_tensor.ndim == 3: | |
| freqs_2d = freqs_tensor.reshape([-1, freqs_tensor.shape[-1]]) | |
| else: | |
| freqs_2d = freqs_tensor | |
| freqs_half = freqs_2d[..., : freqs_2d.shape[-1] // 2] | |
| t_rot = fused_apply_rotary_pos_emb_vision(t_rot, freqs_half) | |
| return paddle.cat((t_rot, t_pass), axis=-1) | |
| if isinstance(t, Tensor): | |
| # Single tensor input | |
| return _apply_vision_rope_to_single_tensor(t, freqs) | |
| elif isinstance(t, tuple): | |
| # Tuple of tensors input, e.g., (query, key) | |
| return tuple( | |
| _apply_vision_rope_to_single_tensor(t_elem, freqs) | |
| for t_elem in t | |
| ) | |
| else: | |
| raise TypeError( | |
| "Unsupported type for `t` in high_precision_rope fused path: " | |
| f"{type(t)!r}. Expected `Tensor` or a tuple of `Tensor`." | |
| ) |
| rot_dim = freqs.shape[-1] | ||
| t, t_pass = t[..., :rot_dim], t[..., rot_dim:] | ||
| if freqs.ndim == 3: | ||
| freqs_2d = freqs.reshape([-1, freqs.shape[-1]]) | ||
| else: | ||
| freqs_2d = freqs | ||
| freqs_half = freqs_2d[..., : freqs_2d.shape[-1] // 2] | ||
| t = fused_apply_rotary_pos_emb_vision(t, freqs_half) |
The high_precision_rope fused vision path calls reshape([-1, D]) directly when freqs.ndim == 3, but in the common M-RoPE + sequence-parallel case where freqs has shape [S, B, D] it must first be transposed to [B, S, D] to align with t of shape [B, S, H, D]; otherwise results are silently wrong whenever B > 1. Consider reusing the transpose alignment logic from the unfused branch below before reshaping.
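The layout mismatch described here can be demonstrated with plain NumPy; the shapes below are illustrative stand-ins for the real freqs/t tensors, not the actual kernel code:

```python
import numpy as np

# Hypothetical sequence-parallel layout: freqs arrives as [S, B, D], while t
# is laid out as [B, S, H, D]. Flattening freqs without transposing keeps
# S-major row order, which no longer lines up with t once B > 1.
S, B, D = 3, 2, 4
freqs = np.arange(S * B * D).reshape(S, B, D)

# Naive: reshape([-1, D]) yields rows (s0,b0), (s0,b1), (s1,b0), ...
naive = freqs.reshape(-1, D)

# Aligned: transpose to [B, S, D] first, so rows are (b0,s0), (b0,s1), ...
aligned = freqs.transpose(1, 0, 2).reshape(-1, D)

# The two flattenings silently diverge for any batch size > 1.
assert not np.array_equal(naive, aligned)
# After transposing, the first S rows are exactly batch 0's frequencies.
assert np.array_equal(aligned[:S], freqs[:, 0, :])
```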
| assert not config.rotary_interleaved and mscale == 1.0, ( | ||
| "fused_apply_rotary_pos_emb_vision only supports " | ||
| "non-interleaved mode and mscale=1.0" | ||
| ) |
This uses assert to reject configurations the kernel does not support at runtime (interleaved mode / mscale); under python -O assertions are stripped, so execution can silently reach the unsupported kernel path and produce incorrect results. Raise ValueError/NotImplementedError explicitly instead, and state the constraint in the message.
| def _apply_rotary_pos_emb_bshd( | ||
| t: Tensor | tuple[Tensor, ...], | ||
| freqs: Tensor | None, | ||
| config: TransformerConfig, | ||
| cos: Tensor | None = None, | ||
| sin: Tensor | None = None, | ||
| mscale: float = 1.0, | ||
| ) -> Tensor: | ||
| position_ids: Tensor | None = None, | ||
| ) -> Tensor | tuple[Tensor, ...]: |
The signature of _apply_rotary_pos_emb_bshd now requires config, but existing unit tests and internal callers in the repository still invoke this private function with the old signature (_apply_rotary_pos_emb_bshd(t, freqs)), which will raise TypeError. Either update all call sites (including tests) in the same change, or provide a backward-compatible default / adapter for config so existing usage keeps working.
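One possible shape for the compatibility adapter, sketched with a hypothetical `resolve_config` helper; the fallback flag names mirror the diff above but the defaults are assumptions, not the real TransformerConfig:

```python
from types import SimpleNamespace

# Assumed legacy defaults for old-style calls like
# _apply_rotary_pos_emb_bshd(t, freqs) that pass no config at all.
_LEGACY_DEFAULT_CONFIG = SimpleNamespace(
    apply_rope_fusion=False,
    rotary_interleaved=False,
)

def resolve_config(config=None):
    """Return the caller's config, or legacy defaults when none is given.

    New callers pass TransformerConfig explicitly; old call sites (including
    existing tests) keep working without a signature change.
    """
    return config if config is not None else _LEGACY_DEFAULT_CONFIG
```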
| Returns: | ||
| Tensor: The input tensor after applying RoPE | ||
| """ | ||
| # Unfused path |
The unfused path unconditionally uses freqs.shape[-1], yet the type annotation allows freqs: Tensor | None; if a caller still passes freqs=None with apply_rope_fusion=False, this triggers an AttributeError. Explicitly check freqs is not None when entering the unfused path and raise a clear error, so the failure is not obscure.
| # Unfused path | |
| # Unfused path | |
| if freqs is None: | |
| raise ValueError( | |
| "freqs must be provided for the unfused RoPE path " | |
| "(config.apply_rope_fusion is False)." | |
| ) |
| def _apply_rotary_pos_emb_thd( | ||
| t: Tensor, | ||
| t: Tensor | tuple[Tensor, ...], | ||
| cu_seqlens: Tensor, | ||
| total_seq_len: int | None, | ||
| freqs: Tensor, | ||
| rotary_interleaved: bool = False, | ||
| multi_latent_attention: bool = False, | ||
| freqs: Tensor | None, | ||
| config: TransformerConfig, |
The current annotation/docs of _apply_rotary_pos_emb_thd allow t to be a tuple and freqs to be None, but the unfused logic accesses t.ndim / freqs.dim(); a mistaken caller gets a raw AttributeError. Add explicit input constraints at the top of the function (e.g. require isinstance(t, Tensor) and freqs is not None) with an actionable error message.
Summary
- Refactor rope_utils.py so the bshd/thd helpers take TransformerConfig directly and route fused vs unfused behavior in one place
- Normalize mscale=None to 1.0
- Restore total_seq_len compatibility for packed-sequence callers and raise an explicit error for unsupported fused + cu_seqlens usage instead of failing later with opaque runtime errors

Why
The refactor of apply_rotary_pos_emb introduced compatibility gaps in this shared helper:
- attention.py still passes total_seq_len= for packed-sequence calls

These make the helper easy to break at runtime even though the file still compiles.

Impact
This keeps the refactor in rope_utils.py while preserving existing packed-sequence callers and making unsupported fused packed-sequence usage fail fast with a clear error.

Validation
- python -m py_compile src/paddlefleet/models/common/embeddings/rope_utils.py src/paddlefleet/transformer/attention.py src/paddlefleet/transformer/multi_latent_attention.py
- git commit