From 7f3175e953ddc7056c3d4b19585411789dba1af6 Mon Sep 17 00:00:00 2001 From: huangjiyi <947613776@qq.com> Date: Wed, 1 Apr 2026 17:02:08 +0800 Subject: [PATCH 1/2] Refactor rotary embedding dispatch in rope_utils --- .../models/common/embeddings/rope_utils.py | 298 +++++++++--------- 1 file changed, 157 insertions(+), 141 deletions(-) diff --git a/src/paddlefleet/models/common/embeddings/rope_utils.py b/src/paddlefleet/models/common/embeddings/rope_utils.py index 089471ea5..e59899945 100644 --- a/src/paddlefleet/models/common/embeddings/rope_utils.py +++ b/src/paddlefleet/models/common/embeddings/rope_utils.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +from contextlib import nullcontext from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -27,6 +28,7 @@ fused_rotary_position_embedding as fused_rope, ) +from paddlefleet.ops import fused_apply_rotary_pos_emb_vision from paddlefleet.utils import get_pg_rank, get_pg_size logger = logging.getLogger(__name__) @@ -93,78 +95,72 @@ def get_unsqueeze_dim(t, freqs): return 2 if t.shape[1] == seq_len else 1 -def _apply_rotary_pos_emb_bshd_fp32( - t: Tensor, - t_pass: Tensor, - freqs: Tensor, - rotary_interleaved: bool = False, +def _apply_rotary_pos_emb_bshd( + t: Tensor | tuple[Tensor, ...], + freqs: Tensor | None, + config: TransformerConfig, + cos: Tensor | None = None, + sin: Tensor | None = None, mscale: float = 1.0, -) -> Tensor: + position_ids: Tensor | None = None, +) -> Tensor | tuple[Tensor, ...]: """Apply rotary positional embedding to input tensor T. check https://kexue.fm/archives/8265 for detailed formulas Args: - t (Tensor): Input tensor T is of shape [seq_length, ... , dim] - t_pass (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] + t (Tensor | tuple[Tensor, ...]): Input tensor T is of shape [seq_length, ... , dim], + or a tuple of tensors (e.g. (query, key)) for the fused path. + freqs (Tensor | None): Rotary Positional embedding tensor freq is of shape + [seq_length, ..., dim]. Can be None when using fused path. + config (TransformerConfig): Transformer configuration providing + apply_rope_fusion, rotary_interleaved, multi_latent_attention, + high_precision_rope, rope_theta, sequence_parallel. + cos (Tensor | None): Pre-computed cosine values (for fused path). + sin (Tensor | None): Pre-computed sine values (for fused path). + mscale (float): Scaling factor for rotary embedding. + position_ids (Tensor | None): Position indices (for fused path). Returns: - Tensor: The input tensor after applying RoPE + Tensor | tuple[Tensor, ...]: The input tensor(s) after applying RoPE. """ + mscale = 1.0 if mscale is None else mscale - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - with paddle.amp.auto_cast(False): - orig_t_dtype = t.dtype - t = t.astype(dtype="float32") - rotate_t = _rotate_half(t, rotary_interleaved) - cos_ = (paddle.cos(freqs) * mscale).to(t.dtype) - sin_ = (paddle.sin(freqs) * mscale).to(t.dtype) - - if len(cos_.shape) < len(t.shape): - # [b,s,h]->[b,s,1,h] - unsqueeze_dim = get_unsqueeze_dim(t, cos_) - cos_.unsqueeze_(unsqueeze_dim) - sin_.unsqueeze_(unsqueeze_dim) - if len(rotate_t.shape) < len(t.shape): - rotate_t.reshape_(t.shape) - - t = (t * cos_) + (rotate_t * sin_) - skip_t_pass = t_pass.shape[-1] == 0 - if not skip_t_pass: - t_pass = t_pass.astype(dtype="float32") - res = paddle.cat((t, t_pass), axis=-1).astype(orig_t_dtype) + if config.apply_rope_fusion: + if config.high_precision_rope: + assert not config.rotary_interleaved and mscale == 1.0, ( + "fused_apply_rotary_pos_emb_vision only supports " + "non-interleaved mode and mscale=1.0" + ) + rot_dim = freqs.shape[-1] + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + if freqs.ndim == 3: + freqs_2d = freqs.reshape([-1, freqs.shape[-1]]) + else: + freqs_2d = freqs + freqs_half = freqs_2d[..., : freqs_2d.shape[-1] // 2] + t = fused_apply_rotary_pos_emb_vision(t, freqs_half) + return paddle.cat((t, t_pass), axis=-1) else: - res = t.astype(orig_t_dtype) - - return res - - -def _apply_rotary_pos_emb_bshd( - t: Tensor, - freqs: Tensor, - rotary_interleaved: bool = False, - multi_latent_attention: bool = False, - mscale: float = 1.0, - high_precision_rope: bool = False, -) -> Tensor: - """Apply rotary positional embedding to input tensor T. - - check https://kexue.fm/archives/8265 for detailed formulas - - Args: - t (Tensor): Input tensor T is of shape [seq_length, ... , dim] - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] + assert isinstance(t, tuple), ( + "The input for fused_rope should be a tuple of tensors" + ) + return fused_rope( + *t, + sin=sin, + cos=cos, + rotary_emb_base=config.rope_theta, + position_ids=position_ids, + use_neox_rotary_style=config.rotary_interleaved, + time_major=config.sequence_parallel, + ) - Returns: - Tensor: The input tensor after applying RoPE - """ + # Unfused path rot_dim = freqs.shape[-1] - # For M-RoPE with sequence parallel, freqs may be [S, B, D] while t is [B, S, H, D]. - # When the first two dims are swapped (same product but different order), transpose - # freqs to align with t's [batch, seq] layout. A plain reshape would silently - # reinterpret the memory without reordering data, giving wrong results for B > 1. + # For M-RoPE with sequence parallel, freqs may be [S, B, D] while t is + # [B, S, H, D]. When the first two dims are swapped (same product but + # different order), transpose freqs to align with t's [batch, seq] layout. if freqs.ndim == 3: t_d0, t_d1 = t.shape[0], t.shape[1] f_d0, f_d1 = freqs.shape[0], freqs.shape[1] @@ -174,26 +170,44 @@ def _apply_rotary_pos_emb_bshd( # ideally t_pass is empty so rotary pos embedding is applied to all tensor t t, t_pass = t[..., :rot_dim], t[..., rot_dim:] - if high_precision_rope: - return _apply_rotary_pos_emb_bshd_fp32( - t, - t_pass=t_pass, - freqs=freqs, - rotary_interleaved=rotary_interleaved, - mscale=mscale, - ) - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - cos_ = (paddle.cos(freqs) * mscale).to(t.dtype) - sin_ = (paddle.sin(freqs) * mscale).to(t.dtype) - if len(cos_.shape) < len(t.shape): - # [b,s,h]->[b,s,1,h] - unsqueeze_dim = get_unsqueeze_dim(t, cos_) - cos_.unsqueeze_(unsqueeze_dim) - sin_.unsqueeze_(unsqueeze_dim) + if config.multi_latent_attention: + x1 = t[..., 0::2] + x2 = t[..., 1::2] + t = paddle.cat((x1, x2), axis=-1) + + # For high_precision_rope, cast to float32 and disable auto_cast to ensure + # numerical stability in the rotary computation. + orig_t_dtype = t.dtype + ctx = ( + paddle.amp.auto_cast(False) + if config.high_precision_rope + else nullcontext() + ) + with ctx: + if config.high_precision_rope: + t = t.astype(dtype="float32") + t_pass = t_pass.astype(dtype="float32") - t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_) - return paddle.cat((t, t_pass), axis=-1) + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + cos_ = (paddle.cos(freqs) * mscale).to(t.dtype) + sin_ = (paddle.sin(freqs) * mscale).to(t.dtype) + if len(cos_.shape) < len(t.shape): + # [b,s,h]->[b,s,1,h] + unsqueeze_dim = get_unsqueeze_dim(t, cos_) + cos_.unsqueeze_(unsqueeze_dim) + sin_.unsqueeze_(unsqueeze_dim) + + rotate_t = _rotate_half(t, config.rotary_interleaved) + if len(rotate_t.shape) < len(t.shape): + rotate_t.reshape_(t.shape) + + t = (t * cos_) + (rotate_t * sin_) + result = paddle.cat((t, t_pass), axis=-1) + + if config.high_precision_rope: + result = result.astype(orig_t_dtype) + return result def _get_thd_freqs_on_this_cp_rank( @@ -248,31 +262,48 @@ def _get_thd_freqs_on_this_cp_rank( def _apply_rotary_pos_emb_thd( - t: Tensor, + t: Tensor | tuple[Tensor, ...], cu_seqlens: Tensor, total_seq_len: int | None, - freqs: Tensor, - rotary_interleaved: bool = False, - multi_latent_attention: bool = False, + freqs: Tensor | None, + config: TransformerConfig, + cos: Tensor | None = None, + sin: Tensor | None = None, mscale: float = 1.0, cp_group: Group = None, - high_precision_rope: bool = False, -) -> Tensor: + position_ids: Tensor | None = None, +) -> Tensor | tuple[Tensor, ...]: """A baseline implementation of applying RoPE for `thd` format. Args: - t (Tensor): Input tensor T is of shape [t, h, d] + t (Tensor | tuple[Tensor, ...]): Input tensor T is of shape [t, h, d], + or a tuple of tensors (e.g. (query, key)) for the fused path. cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, - with shape [b + 1] and dtype paddle.int32. - total_seq_len (int | None): The actual total sequence length before padding. - When cu_seqlens uses a padded version, this provides the true total length - for correct frequency tensor selection. If None, falls back to cu_seqlens[-1]. - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] - cp_group (Group): The context parallel group + with shape [b + 1] and dtype paddle.int32. + total_seq_len (int | None): The actual total sequence length before + padding. When cu_seqlens uses a padded version, this provides the + true total length for correct frequency tensor selection. If None, + falls back to cu_seqlens[-1]. + freqs (Tensor | None): Rotary Positional embedding tensor freq is of shape + [max_s, 1, 1, d]. Can be None when using fused path. + config (TransformerConfig): Transformer configuration providing + apply_rope_fusion, rotary_interleaved, multi_latent_attention, + high_precision_rope, rope_theta, sequence_parallel. + cos (Tensor | None): Pre-computed cosine values (for fused path). + sin (Tensor | None): Pre-computed sine values (for fused path). + mscale (float): Scaling factor for rotary embedding. + cp_group (Group): The context parallel group. + position_ids (Tensor | None): Position indices (for fused path). Returns: - Tensor: Shape [t, h, d]. The input tensor after applying RoPE. + Tensor | tuple[Tensor, ...]: Shape [t, h, d]. The input tensor(s) after + applying RoPE. """ + if config.apply_rope_fusion: + raise NotImplementedError( + "cu_seqlens is not supported when using fused_rope" + ) + cp_size = get_pg_size(cp_group) cp_rank = get_pg_rank(cp_group) @@ -293,10 +324,8 @@ def _apply_rotary_pos_emb_thd( return _apply_rotary_pos_emb_bshd( t, freqs, - rotary_interleaved=rotary_interleaved, - multi_latent_attention=multi_latent_attention, + config, mscale=mscale, - high_precision_rope=high_precision_rope, ) seqlens = ((cu_seqlens[1:] - cu_seqlens[:-1]) // cp_size).tolist() # Build packed freqs in one pass, then apply once to the whole packed tensor @@ -313,14 +342,12 @@ def _apply_rotary_pos_emb_thd( ) freqs_packed = paddle.cat(freq_slices, axis=1) - # [b,seq,num_heads,head_dim] + # [seq,bs,num_heads,head_dim] return _apply_rotary_pos_emb_bshd( t, freqs_packed, - rotary_interleaved=rotary_interleaved, - multi_latent_attention=multi_latent_attention, + config, mscale=mscale, - high_precision_rope=high_precision_rope, ) else: # CASE 2: Traditional mapping without offsets @@ -338,16 +365,14 @@ def _apply_rotary_pos_emb_thd( return _apply_rotary_pos_emb_bshd( t, freqs_packed, - rotary_interleaved=rotary_interleaved, - multi_latent_attention=multi_latent_attention, + config, mscale=mscale, - high_precision_rope=high_precision_rope, ) def apply_rotary_pos_emb( - t: Tensor, - freqs: Tensor, + t: Tensor | tuple[Tensor, ...], + freqs: Tensor | None, cos: Tensor | None, sin: Tensor | None, config: TransformerConfig, @@ -359,50 +384,40 @@ def apply_rotary_pos_emb( ): """ Reroute to the appropriate apply_rotary_pos_emb function depending on - fused/unfused kernels, or bshd (conventional) / thd (packed seq) format + bshd (conventional) / thd (packed seq) format. + + The fused/unfused decision is handled internally by each format-specific + function based on config.apply_rope_fusion. Args: - t (Tensor): Input tensor - freqs (Tensor): Rotary positional embedding frequencies - cos (Tensor | None): Pre-computed cosine values of freqs (used for fused implementation) - sin (Tensor | None): Pre-computed sine values of freqs (used for fused implementation) - config (TransformerConfig): Transformer configuration - cu_seqlens (Tensor | None): Cumulative sequence lengths - total_seq_len (int | None): The actual total sequence length before padding. - Used in thd format to correctly select frequency tensor when cu_seqlens - is padded. If None, falls back to cu_seqlens[-1]. - mscale (float): Scaling factor - cp_group (Group): Context parallel group + t (Tensor | tuple[Tensor, ...]): Input tensor, or a tuple of tensors + (e.g. (query, key)) for the fused path. + freqs (Tensor | None): Rotary positional embedding frequencies. + Can be None when using fused path. + cos (Tensor | None): Pre-computed cosine values of freqs (used for + fused implementation). + sin (Tensor | None): Pre-computed sine values of freqs (used for + fused implementation). + config (TransformerConfig): Transformer configuration. + cu_seqlens (Tensor | None): Cumulative sequence lengths. + total_seq_len (int | None): The actual total sequence length before + padding. Used in thd format to correctly select frequency tensor + when cu_seqlens is padded. If None, falls back to cu_seqlens[-1]. + mscale (float): Scaling factor. + cp_group (Group): Context parallel group. + position_ids (Tensor | None): Position indices. """ - if config.apply_rope_fusion: - # Paddle fused_rope not support cu_seqlens or cp_group - if cu_seqlens: - raise NotImplementedError( - "cu_seqlens not be supported when using fused_rope" - ) - else: - assert isinstance(t, tuple), ( - "The input for fused_rope should be a tuple of tensors" - ) - return fused_rope( - *t, - sin=sin, - cos=cos, - rotary_emb_base=config.rope_theta, - position_ids=position_ids, - use_neox_rotary_style=config.rotary_interleaved, - time_major=config.sequence_parallel, - ) + mscale = 1.0 if mscale is None else mscale - # use unfused implementation if cu_seqlens is None: return _apply_rotary_pos_emb_bshd( t, freqs, - rotary_interleaved=config.rotary_interleaved, - multi_latent_attention=config.multi_latent_attention, + config, + cos=cos, + sin=sin, mscale=mscale, - high_precision_rope=config.high_precision_rope, + position_ids=position_ids, ) else: return _apply_rotary_pos_emb_thd( @@ -410,9 +425,10 @@ def apply_rotary_pos_emb( cu_seqlens, total_seq_len, freqs, - rotary_interleaved=config.rotary_interleaved, - multi_latent_attention=config.multi_latent_attention, + config, + cos=cos, + sin=sin, mscale=mscale, cp_group=cp_group, - high_precision_rope=config.high_precision_rope, + position_ids=position_ids, ) From 7166f7b400ad6acfb91a19ad5c9a782dae351699 Mon Sep 17 00:00:00 2001 From: huangjiyi <947613776@qq.com> Date: Wed, 1 Apr 2026 17:17:09 +0800 Subject: [PATCH 2/2] Adjust rotary embedding dispatch guards --- .../models/common/embeddings/rope_utils.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/paddlefleet/models/common/embeddings/rope_utils.py b/src/paddlefleet/models/common/embeddings/rope_utils.py index e59899945..988623122 100644 --- a/src/paddlefleet/models/common/embeddings/rope_utils.py +++ b/src/paddlefleet/models/common/embeddings/rope_utils.py @@ -128,12 +128,24 @@ def _apply_rotary_pos_emb_bshd( if config.apply_rope_fusion: if config.high_precision_rope: + rot_dim = freqs.shape[-1] + + # For M-RoPE with sequence parallel, freqs may be [S, B, D] while + # t is [B, S, H, D]. Align freqs before splitting rotary/pass dims. + if freqs.ndim == 3: + t_d0, t_d1 = t.shape[0], t.shape[1] + f_d0, f_d1 = freqs.shape[0], freqs.shape[1] + if ( + t_d0 != f_d0 or t_d1 != f_d1 + ) and t_d0 * t_d1 == f_d0 * f_d1: + freqs = freqs.transpose([1, 0, 2]).contiguous() + + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] assert not config.rotary_interleaved and mscale == 1.0, ( "fused_apply_rotary_pos_emb_vision only supports " "non-interleaved mode and mscale=1.0" ) - rot_dim = freqs.shape[-1] - t, t_pass = t[..., :rot_dim], t[..., rot_dim:] if freqs.ndim == 3: freqs_2d = freqs.reshape([-1, freqs.shape[-1]]) else: @@ -155,7 +167,6 @@ def _apply_rotary_pos_emb_bshd( time_major=config.sequence_parallel, ) - # Unfused path rot_dim = freqs.shape[-1] # For M-RoPE with sequence parallel, freqs may be [S, B, D] while t is