From bc968f18711d3b3774c52ee757dfde27c2a2938c Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 11 Mar 2026 11:33:38 +0800 Subject: [PATCH 01/12] support dsa --- src/paddlefleet/models/gpt/gpt_layer_specs.py | 13 +- src/paddlefleet/transformer/dsa_attention.py | 725 ++++++++++++++++++ .../transformer/transformer_config.py | 20 + 3 files changed, 757 insertions(+), 1 deletion(-) create mode 100644 src/paddlefleet/transformer/dsa_attention.py diff --git a/src/paddlefleet/models/gpt/gpt_layer_specs.py b/src/paddlefleet/models/gpt/gpt_layer_specs.py index 21ab0aef8..7d4a8ba1c 100644 --- a/src/paddlefleet/models/gpt/gpt_layer_specs.py +++ b/src/paddlefleet/models/gpt/gpt_layer_specs.py @@ -45,6 +45,9 @@ SelfAttention, SelfAttentionSublayersSpec, ) +from paddlefleet.transformer.dsa_attention import ( + MLASelfAttentionWithDSA, +) from paddlefleet.transformer.enums import AttnMaskType from paddlefleet.transformer.identity_op import IdentityOp from paddlefleet.transformer.mlp import MLP, MLPSublayersSpec @@ -124,12 +127,20 @@ def get_gpt_layer_local_spec( if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." 
+ + # Decide attention class: DSA variant if index_n_heads is configured + use_dsa = ( + config is not None + and getattr(config, "index_n_heads", None) is not None + ) + attn_cls = MLASelfAttentionWithDSA if use_dsa else MLASelfAttention + return LayerSpec( layer=transformer_cls, sublayers_spec=TransformerLayerSublayersSpec( input_layernorm=layer_norm, self_attn=LayerSpec( - layer=MLASelfAttention, + layer=attn_cls, extra_kwargs={"attn_mask_type": AttnMaskType.causal}, sublayers_spec=MLASelfAttentionSublayersSpec( q_proj=backend.column_parallel_linear(), diff --git a/src/paddlefleet/transformer/dsa_attention.py b/src/paddlefleet/transformer/dsa_attention.py new file mode 100644 index 000000000..e0496154a --- /dev/null +++ b/src/paddlefleet/transformer/dsa_attention.py @@ -0,0 +1,725 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DeepSeek Sparse Attention (DSA) extension for Multi-Latent Attention. 
+ +This module extends the upstream MLASelfAttention with DSA Indexer support +(DeepSeek V3.2 architecture): + - Indexer: Token scoring module that selects top-k relevant positions + - DSAIndexerLoss: KL-divergence loss for Indexer training + - DSAIndexerLossAutoScaler: Loss scaling helper + - MLASelfAttentionWithDSA: Subclass of MLASelfAttention with DSA integration + +Reference: Megatron-LM/megatron/core/transformer/experimental_attention_variant/dsa.py +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import paddle +import paddle.nn.functional as F +from paddle import Tensor +from paddle.distributed.fleet.utils import recompute + +from paddlefleet.models.common.embeddings.rope_utils import ( + _apply_rotary_pos_emb_bshd, +) +from paddlefleet.tensor_parallel.mappings import ( + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region, +) +from paddlefleet.transformer.enums import AttnMaskType +from paddlefleet.transformer.multi_latent_attention import ( + MLASelfAttention, + MLASelfAttentionSublayersSpec, +) + +if TYPE_CHECKING: + from paddlefleet.process_groups_config import ProcessGroupCollection + from paddlefleet.transformer.transformer_config import TransformerConfig + + +# --------------------------------------------------------------------------- +# Unfused DSA attention (explicit bmm, supports asymmetric Q/K vs V dims) +# --------------------------------------------------------------------------- + + +def _unfused_dsa_attention( + query: Tensor, + key: Tensor, + value: Tensor, + combined_mask: Tensor | None, + softmax_scale: float, +) -> Tensor: + """Unfused DSA sparse attention (matches Megatron-Core unfused_dsa_fn). 
+ + Uses explicit bmm instead of flash attention to support: + - Different Q/K head_dim vs V head_dim (MLA architecture) + - Arbitrary per-token sparse masks from DSA Indexer + + Args: + query: [b, s, nhpp, qk_head_dim] + key: [b, s, nhpp, qk_head_dim] + value: [b, s, nhpp, v_head_dim] (v_head_dim may differ from qk_head_dim) + combined_mask: [b, 1, s, s] (causal + sparse index mask, -inf for masked) + softmax_scale: 1/sqrt(qk_head_dim) + + Returns: + output: [b, s, nhpp * v_head_dim] + """ + b, s, nhpp, qk_hd = query.shape + v_hd = value.shape[-1] + + # Reshape for bmm: [b*nhpp, s, hd] + q = query.transpose([0, 2, 1, 3]).reshape([b * nhpp, s, qk_hd]) + k = key.transpose([0, 2, 1, 3]).reshape([b * nhpp, s, qk_hd]) + v = value.transpose([0, 2, 1, 3]).reshape([b * nhpp, s, v_hd]) + + # Q * K^T with scale: [b*nhpp, s, s] + attn_scores = ( + paddle.bmm(q.cast("float32"), k.cast("float32").transpose([0, 2, 1])) + * softmax_scale + ) + + # Apply combined mask (causal + sparse index mask) + if combined_mask is not None: + mask = ( + combined_mask.expand([b, nhpp, s, s]) + .contiguous() + .reshape([b * nhpp, s, s]) + ) + attn_scores = attn_scores + mask.cast("float32") + + attn_weights = F.softmax(attn_scores, axis=-1) + + # Attention_weights * V: [b*nhpp, s, v_hd] + output = paddle.bmm(attn_weights.cast(v.dtype), v) + + # [b*nhpp, s, v_hd] -> [b, s, nhpp*v_hd] + output = ( + output.reshape([b, nhpp, s, v_hd]) + .transpose([0, 2, 1, 3]) + .reshape([b, s, nhpp * v_hd]) + ) + + return output + + +# --------------------------------------------------------------------------- +# DSA Indexer +# --------------------------------------------------------------------------- + + +class Indexer(paddle.nn.Layer): + """DSA Indexer: DeepSeek Sparse Attention token selection module. + + For each query token, scores all cached key positions using a lightweight + n_heads-head attention mechanism, then selects the top-k most relevant + positions for the full MLA attention computation. 
+ + Key design notes: + - Uses non-interleaved RoPE (unlike MLA which uses interleaved) + - Uses LayerNorm (not RMSNorm) on K + - nope/pe split order: [nope | pe] + - Uses ReLU-aggregated scoring across heads + - Per-head learned importance weights via weights_proj + - weights absorbs softmax_scale + + Reference: Megatron-LM dsa.py DSAIndexer + """ + + def __init__(self, config: TransformerConfig, layer_number: int): + super().__init__() + self.config = config + + self.n_heads = config.index_n_heads + self.head_dim = config.index_head_dim + self.rope_head_dim = config.qk_rope_head_dim + self.nope_head_dim = self.head_dim - self.rope_head_dim + self.index_topk = config.index_topk + self.softmax_scale = self.head_dim**-0.5 + self.layer_number = layer_number + + # wq_b: q_lora_rank -> n_heads * head_dim (duplicated) + self.wq_b = paddle.nn.Linear( + config.q_lora_rank, + self.n_heads * self.head_dim, + bias_attr=False, + ) + + # wk: hidden_size -> head_dim (single shared K, duplicated) + self.wk = paddle.nn.Linear( + config.hidden_size, + self.head_dim, + bias_attr=False, + ) + + # k_norm: LayerNorm (NOT RMSNorm) per reference + self.k_norm = paddle.nn.LayerNorm(self.head_dim, epsilon=1e-6) + + # weights_proj: learned per-head importance [hidden -> n_heads] + self.weights_proj = paddle.nn.Linear( + config.hidden_size, + self.n_heads, + bias_attr=False, + ) + + def _apply_rope( + self, x: Tensor, freqs: Tensor, mscale: float = 1.0 + ) -> Tensor: + """Apply non-interleaved RoPE to the pe portion of x. + + Split order: [nope | pe], matching Megatron-Core dsa.py _apply_rope. 
+ + Args: + x: [..., head_dim] (nope_dim + rope_dim) + freqs: RoPE frequencies + mscale: YaRN concentration factor (1.0 for plain RoPE, ~1.37 for YaRN) + """ + x_nope = x[..., : self.nope_head_dim] + x_pe = x[..., self.nope_head_dim :] + x_pe = _apply_rotary_pos_emb_bshd( + x_pe, + freqs, + rotary_interleaved=False, + multi_latent_attention=self.config.multi_latent_attention, + mscale=mscale, + ) + return paddle.concat([x_nope, x_pe], axis=-1) + + def forward_before_topk( + self, + hidden_states: Tensor, # [b, s, hidden_size] + q_latent: Tensor, # [b, s, q_lora_rank] + freqs: Tensor, + mscale: float = 1.0, + ): + """Compute q, k, weights before top-k selection.""" + bsz, seqlen, _ = hidden_states.shape + + q = self.wq_b(q_latent) # [b, s, n_heads * head_dim] + q = q.reshape([bsz, seqlen, self.n_heads, self.head_dim]) + if freqs is not None: + q = self._apply_rope(q, freqs, mscale) + + k = self.wk(hidden_states) # [b, s, head_dim] + k = self.k_norm(k) + if freqs is not None: + k = self._apply_rope(k.unsqueeze(2), freqs, mscale).squeeze(2) + + weights = ( + self.weights_proj(hidden_states.cast("float32")) + * (self.n_heads**-0.5) + * self.softmax_scale + ) + + return q, k, weights + + def compute_index_scores( + self, + q: Tensor, # [b, s, n_heads, head_dim] + k: Tensor, # [b, t, head_dim] + weights: Tensor, # [b, s, n_heads] + mask: Tensor | None = None, + ): + """Compute index scores and select top-k.""" + q_fp32 = q.cast("float32") + k_fp32 = k.cast("float32") + + scores = paddle.einsum("bshd,btd->bsht", q_fp32, k_fp32) + index_scores = (weights.unsqueeze(-1) * F.relu(scores)).sum(axis=2) + + if mask is not None: + index_scores = index_scores + mask.squeeze(1) + + topk_k = min(self.index_topk, index_scores.shape[-1]) + topk_indices = paddle.topk(index_scores, k=topk_k, axis=-1)[1] + + return index_scores, topk_indices + + def forward( + self, + hidden_states: Tensor, + q_latent: Tensor, + freqs: Tensor, + attention_mask: Tensor, + mscale: float = 1.0, + ) -> 
tuple[Tensor, Tensor]: + """Compute DSA token importance scores and return scores + top-k indices.""" + q, k, weights = self.forward_before_topk( + hidden_states, q_latent, freqs, mscale + ) + index_scores, topk_indices = self.compute_index_scores( + q, k, weights, attention_mask + ) + return index_scores, topk_indices + + +# --------------------------------------------------------------------------- +# DSA Indexer Loss (PyLayer) +# --------------------------------------------------------------------------- + + +class DSAIndexerLoss(paddle.autograd.PyLayer): + """Fused DSA Indexer KL-divergence loss. + + Trains the Indexer to predict which tokens receive high attention weights. + Reference: Megatron-Core dsa.py FusedDSAIndexerLoss + """ + + @staticmethod + def forward( + ctx, + index_scores: Tensor, # [b, sq, sk] + topk_indices: Tensor, # [b, sq, topk] + query: Tensor, # [b, sq, nhpp, qk_head_dim] (DETACHED) + key: Tensor, # [b, sk, nhpp, qk_head_dim] (DETACHED) + mla_softmax_scale: float, + loss_coeff: float, + sparse_loss: bool, + tp_group, + ) -> Tensor: + b, sq, sk = index_scores.shape + nhpp = query.shape[2] + + q_f = query.cast("float32") + k_f = key.cast("float32") + attention_scores = ( + paddle.einsum("bshd,bthd->bhst", q_f, k_f) * mla_softmax_scale + ) + + causal_mask = paddle.triu( + paddle.full([sq, sk], float("-inf"), dtype="float32"), + diagonal=1, + ) + attention_scores = attention_scores + causal_mask.unsqueeze([0, 1]) + + index_mask = paddle.full([b, sq, sk], float("-inf"), dtype="float32") + index_mask = paddle.put_along_axis( + index_mask, + topk_indices, + paddle.zeros_like(topk_indices, dtype="float32"), + axis=-1, + ) + + masked_index_scores = index_scores.cast( + "float32" + ) + causal_mask.unsqueeze(0) + if sparse_loss: + attention_scores = attention_scores + index_mask.unsqueeze(1) + masked_index_scores = masked_index_scores + index_mask + + attn_probs = F.softmax(attention_scores, axis=-1) + idx_probs = F.softmax(masked_index_scores, 
axis=-1) + + attn_probs_sum = attn_probs.sum(axis=1) + if tp_group is not None and tp_group.nranks > 1: + paddle.distributed.all_reduce(attn_probs_sum, group=tp_group) + + target = attn_probs_sum / ( + attn_probs_sum.sum(axis=-1, keepdim=True) + 1e-10 + ) + + kl = target * ( + paddle.log(target + 1e-10) - paddle.log(idx_probs + 1e-10) + ) + kl_div = kl.sum(axis=-1).mean() + indexer_loss = kl_div * loss_coeff + + ctx.save_for_backward( + target, idx_probs, index_mask if sparse_loss else None + ) + ctx.b = b + ctx.sq = sq + ctx.sparse_loss = sparse_loss + ctx.loss_coeff = loss_coeff + + return indexer_loss + + @staticmethod + def backward(ctx, grad_loss: Tensor): + target, idx_probs, index_mask = ctx.saved_tensor() + b, sq = ctx.b, ctx.sq + sparse_loss = ctx.sparse_loss + loss_coeff = ctx.loss_coeff + sk = target.shape[-1] + + grad_idx_probs = ( + -target + / (idx_probs + 1e-10) + * (grad_loss.cast("float32") * loss_coeff / (b * sq)) + ) + sum_grad = (grad_idx_probs * idx_probs).sum(axis=-1, keepdim=True) + grad_index_scores = idx_probs * (grad_idx_probs - sum_grad) + + causal_valid = paddle.tril(paddle.ones([sq, sk], dtype="bool")) + if sparse_loss and index_mask is not None: + valid_mask = causal_valid.unsqueeze(0) & (index_mask == 0) + else: + valid_mask = causal_valid.unsqueeze(0).expand([b, sq, sk]) + + grad_index_scores = grad_index_scores * valid_mask.cast("float32") + + # Gradients for Tensor inputs only (Paddle PyLayer convention): + # index_scores, topk_indices(None), query(None), key(None) + return grad_index_scores.cast(idx_probs.dtype), None, None, None + + +class DSAIndexerLossAutoScaler(paddle.autograd.PyLayer): + """Attaches indexer_loss to the backward graph without changing output value.""" + + _main_loss_backward_scale: Tensor | None = None + + @staticmethod + def forward(ctx, output: Tensor, indexer_loss: Tensor) -> Tensor: + ctx.save_for_backward(indexer_loss) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + 
(indexer_loss,) = ctx.saved_tensor() + scale = DSAIndexerLossAutoScaler._main_loss_backward_scale + if scale is None: + scale = paddle.ones([1], dtype=indexer_loss.dtype) + scaled_grad = paddle.ones_like(indexer_loss) * scale + return grad_output, scaled_grad + + @staticmethod + def set_loss_scale(scale: Tensor): + DSAIndexerLossAutoScaler._main_loss_backward_scale = scale + + +# --------------------------------------------------------------------------- +# MLASelfAttentionWithDSA — extends upstream MLASelfAttention +# --------------------------------------------------------------------------- + + +class MLASelfAttentionWithDSA(MLASelfAttention): + """MLA Self-attention with DeepSeek Sparse Attention (DSA) Indexer. + + Extends the upstream MLASelfAttention by: + 1. Reusing parent's get_query_key_value_tensors() for Q/K/V computation. + 2. Overriding forward() to compute Indexer inputs separately, run the DSA + Indexer, build a sparse mask, and use unfused bmm attention. + + The Indexer needs q_latent (normed q_compressed with full q_lora_rank) and + hidden_states in [b, s, h] format. Since the parent's get_query_key_value_tensors + doesn't expose these intermediates, we re-compute q_a_proj + norm for the + Indexer path. 
This is cheap because: + - q_a_proj is a small down-projection (hidden -> q_lora_rank) + - Indexer inputs are detached, so no extra backward cost + """ + + def __init__( + self, + config: TransformerConfig, + sublayers_spec: MLASelfAttentionSublayersSpec, + layer_number: int, + attn_mask_type=AttnMaskType.padding, + cp_comm_type: str | None = None, + pg_collection: ProcessGroupCollection | None = None, + ): + super().__init__( + config=config, + sublayers_spec=sublayers_spec, + layer_number=layer_number, + attn_mask_type=attn_mask_type, + cp_comm_type=cp_comm_type, + pg_collection=pg_collection, + ) + + # DSA Indexer + self.indexer = Indexer(config, layer_number) + + # DSA loss config + self.dsa_indexer_loss_coeff = getattr( + config, "indexer_loss_coeff", None + ) + self.dsa_indexer_use_sparse_loss = getattr( + config, "indexer_use_sparse_loss", False + ) + + def _compute_indexer_inputs(self, hidden_states): + """Compute Indexer's q_latent by re-running q_a_proj + norm. + + This is a lightweight re-computation of the q down-projection path. + The result is detached, so no gradient flows back through this path. + + Unlike the parent's get_query_key_value_tensors (which scatters back to + [s/tp, b, ...] for SP), the Indexer needs full-sequence tensors [s, b, ...]. + So we follow the parent's q_a_proj path but skip the scatter step. 
+ + Args: + hidden_states: [s/tp, b, h] (SP) or [s, b, h] (non-SP) + + Returns: + indexer_hidden: [b, s, h] — detached hidden_states (full seq) for Indexer + indexer_q_latent: [b, s, q_lora_rank] — detached normed q_compressed (full seq) + """ + # Re-compute q_a_proj on original hidden_states (let it handle SP internally) + with paddle.no_grad(): + if self.config.q_lora_rank is not None: + # q_a_proj (ColumnParallelLinear) handles SP all-gather internally: + # SP: [s/tp, b, h] -> [s, b, q_lora_rank/tp] + # non-SP: [s, b, h] -> [s, b, q_lora_rank/tp] + q_compressed, _ = self.q_a_proj(hidden_states) + # Gather feature dim to full q_lora_rank if sharded + if q_compressed.shape[-1] != self.config.q_lora_rank: + q_compressed = gather_from_tensor_model_parallel_region( + q_compressed + ) + # Parent would scatter_to_sequence_parallel_region here, + # but Indexer needs full sequence — skip scatter. + q_latent = self.q_a_layernorm(q_compressed) + # q_latent: [s, b, q_lora_rank] (full seq, full feature) + else: + q_latent = hidden_states + # Need full seq for the else branch too + if self.config.sequence_parallel: + q_latent = gather_from_sequence_parallel_region( + q_latent, + tensor_parallel_output_grad=True, + group=self.pg_collection.tp, + ) + + # Gather hidden_states separately for indexer_hidden + if self.config.sequence_parallel: + indexer_hidden = gather_from_sequence_parallel_region( + hidden_states, + tensor_parallel_output_grad=True, + group=self.pg_collection.tp, + ).detach() + else: + indexer_hidden = hidden_states.detach() + + # Convert [s, b, ...] -> [b, s, ...] 
+ indexer_hidden = indexer_hidden.transpose([1, 0, 2]) + indexer_q_latent = q_latent.detach().transpose([1, 0, 2]) + + return indexer_hidden, indexer_q_latent + + def forward( + self, + hidden_states, + attention_mask, + attn_mask_startend_row_indices: Tensor | None = None, + key_value_states=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + packed_seq_params=None, + in_recompute: bool = False, + position_ids=None, + ): + """Forward: MLA projections + DSA Indexer + sparse attention. + + Overrides the parent forward to: + 1. Reuse parent's get_query_key_value_tensors for Q/K/V + 2. Compute Indexer inputs separately (lightweight re-computation) + 3. Run DSA Indexer to get top-k indices + 4. Build sparse mask and use unfused bmm attention + 5. Compute DSA Indexer KL loss if enabled + """ + assert rotary_pos_emb is None, ( + "Rotary position embeddings should not be passed into MLA." + ) + assert attention_bias is None + assert rotary_pos_cos is None and rotary_pos_sin is None + + # ===================== + # Query, Key, and Value (reuse parent's MLA implementation) + # ===================== + query, key, value = self.get_query_key_value_tensors( + hidden_states, + key_value_states, + position_ids, + packed_seq_params, + ) + + # ===================== + # DSA Indexer + # ===================== + # Compute Indexer inputs: re-runs q_a_proj + norm (cheap, detached) + indexer_hidden, indexer_q_latent = self._compute_indexer_inputs( + hidden_states + ) + + # Get RoPE freqs for Indexer (non-interleaved, computed from rotary_pos_emb) + # Re-compute from self.rotary_pos_emb since parent doesn't expose it + with paddle.no_grad(): + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + hidden_states, self.config, packed_seq_params + ) + packed_seq = ( + packed_seq_params is not None + and packed_seq_params.qkv_format == "thd" + ) + assert not self.config.apply_rope_fusion, ( + "DSA Indexer requires unfused RoPE 
(apply_rope_fusion=False). " + "Fused RoPE returns cos/sin instead of freqs, which the Indexer cannot use." + ) + indexer_mscale = 1.0 + if self.config.rope_type == "rope": + indexer_freqs = self.rotary_pos_emb( + rotary_seq_len, packed_seq=packed_seq + ) + else: + indexer_freqs, indexer_mscale = self.rotary_pos_emb( + rotary_seq_len, packed_seq=packed_seq + ) + # rotary_pos_emb returns [1, seq_len, 1, dim]. + # Indexer tensors are batch-first [b, s, heads, head_dim], so keep + # freqs as 4D [1, seq_len, 1, dim] for correct broadcasting. + # Do NOT squeeze(0) — that would make it [seq_len, 1, dim] (seq-first) + # which causes _apply_rotary_pos_emb_bshd to mis-broadcast. + if indexer_freqs is not None and ( + packed_seq_params is None + or self.config.context_parallel_size == 1 + ): + # Use the actual full sequence length from gathered indexer input, + # not the potentially-sharded hidden_states.shape[0] (s/tp in SP mode). + actual_seq_len = indexer_hidden.shape[ + 1 + ] # [b, s, h] — always full seq + indexer_freqs = indexer_freqs[ + :, 0:actual_seq_len + ] # slice seq dim (dim 1) + + # Build causal float_mask for Indexer scoring, matching MG DSAttention.forward: + # MG always passes a causal-aware mask to the Indexer so that topk selection + # never picks future positions. Without this, Indexer could select future + # tokens which are then killed by the causal mask in attention → all-inf rows + # → softmax NaN. 
+ indexer_seq_len = indexer_hidden.shape[1] # [b, s, h] — always full seq + indexer_causal_mask = paddle.triu( + paddle.full( + [indexer_seq_len, indexer_seq_len], + float("-inf"), + dtype="float32", + ), + diagonal=1, + ) # [s, s] + if attention_mask is not None: + # attention_mask: [b, 1, sq, sk] — may contain padding info beyond causal + indexer_float_mask = attention_mask.cast( + "float32" + ) + indexer_causal_mask.unsqueeze(0).unsqueeze(0) # [b, 1, s, s] + else: + indexer_float_mask = indexer_causal_mask.unsqueeze(0).unsqueeze( + 0 + ) # [1, 1, s, s] + + index_scores, topk_indices = self.indexer( + indexer_hidden, + indexer_q_latent, + indexer_freqs, + indexer_float_mask, + mscale=indexer_mscale, + ) + # Detach topk_indices: int64 index tensor, no meaningful gradients. + topk_indices = topk_indices.detach() + + # ===================== + # Build sparse mask + # ===================== + if self.config.sequence_parallel: + seqlen = query.shape[0] # [s, b, nhpp, hd] + bsz = query.shape[1] + else: + bsz = query.shape[0] # [b, s, nhpp, hd] + seqlen = query.shape[1] + + index_mask = paddle.full( + [bsz, seqlen, seqlen], + fill_value=float("-inf"), + dtype="float32", + ) + zeros = paddle.zeros( + [ + topk_indices.shape[0], + topk_indices.shape[1], + topk_indices.shape[2], + ], + dtype="float32", + ) + index_mask = paddle.put_along_axis( + index_mask, topk_indices, zeros, axis=-1 + ) + # Merge causal + index into [b, s, s], then unsqueeze to [b, 1, s, s] + # causal_mask is [s, s], reuse the one built for indexer (same seqlen) + causal_mask = paddle.triu( + paddle.full([seqlen, seqlen], float("-inf"), dtype="float32"), + diagonal=1, + ) + index_mask = index_mask + causal_mask.unsqueeze(0) + combined_mask = index_mask.unsqueeze(1) # [b, 1, s, s] + + if attention_mask is not None: + combined_mask = attention_mask.cast("float32") + combined_mask + + # ===================== + # Core attention (unfused bmm for DSA) + # ===================== + if self.config.sequence_parallel: 
+ query = query.transpose([1, 0, 2, 3]).contiguous() + key = key.transpose([1, 0, 2, 3]).contiguous() + value = value.transpose([1, 0, 2, 3]).contiguous() + + if self.recompute_core_attention and self.training: + core_attn_out = recompute( + _unfused_dsa_attention, + query, + key, + value, + combined_mask.clone() if combined_mask is not None else None, + self.softmax_scale, + ) + else: + core_attn_out = _unfused_dsa_attention( + query, + key, + value, + combined_mask, + self.softmax_scale, + ) + + # ===================== + # Output projection + # ===================== + if self.config.sequence_parallel: + core_attn_out = core_attn_out.transpose([1, 0, 2]).contiguous() + output, bias = self.o_proj(core_attn_out) + + # ===================== + # DSA Indexer KL loss + # ===================== + if self.training and self.dsa_indexer_loss_coeff is not None: + indexer_loss = DSAIndexerLoss.apply( + index_scores, + topk_indices, + query.detach(), + key.detach(), + self.softmax_scale, + float(self.dsa_indexer_loss_coeff), + bool(self.dsa_indexer_use_sparse_loss), + self.pg_collection.tp + if self.pg_collection.tp.nranks > 1 + else None, + ) + output = DSAIndexerLossAutoScaler.apply(output, indexer_loss) + + return output, bias diff --git a/src/paddlefleet/transformer/transformer_config.py b/src/paddlefleet/transformer/transformer_config.py index 8ffedcfac..72a0c9a1c 100644 --- a/src/paddlefleet/transformer/transformer_config.py +++ b/src/paddlefleet/transformer/transformer_config.py @@ -529,6 +529,26 @@ class TransformerConfig(ModelParallelConfig): # cache_mla_latents: bool = False + #################### + # DSA (DeepSeek Sparse Attention) + #################### + + index_n_heads: int | None = None + """Number of DSA Indexer heads. 
None disables DSA; non-None activates + DeepSeek V3.2 sparse attention path.""" + + index_head_dim: int = 128 + """Per-head dimension for Indexer Q/K vectors.""" + + index_topk: int = 2048 + """Number of token positions selected by Indexer per query token.""" + + indexer_loss_coeff: float | None = None + """KL loss coefficient for DSA Indexer training. None disables the KL loss.""" + + indexer_use_sparse_loss: bool = False + """Whether to restrict DSA KL loss to top-k positions only.""" + @classmethod def from_config(cls, config_dict): # note(zhangweilong): if cls(),will call __post_init__ directly,but __new__ will skip some attr init .please check provider attr From d77032afdfc549ce783be1e1a381318162380d66 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 11 Mar 2026 17:27:02 +0800 Subject: [PATCH 02/12] support hadamard transform --- src/paddlefleet/transformer/dsa_attention.py | 68 ++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/src/paddlefleet/transformer/dsa_attention.py b/src/paddlefleet/transformer/dsa_attention.py index e0496154a..f80073bee 100644 --- a/src/paddlefleet/transformer/dsa_attention.py +++ b/src/paddlefleet/transformer/dsa_attention.py @@ -52,6 +52,69 @@ from paddlefleet.transformer.transformer_config import TransformerConfig +def hadamard_transform(x: Tensor, scale: float = 1.0) -> Tensor: + """Fast Walsh-Hadamard Transform using the butterfly algorithm. + + Pure Paddle implementation, equivalent to: + F.linear(x, hadamard_matrix(dim)) * scale + + Uses O(N log N) butterfly operations instead of O(N^2) matrix multiply. + The Hadamard matrix is symmetric and orthogonal, so backward is the same + transform applied to grad_output (handled automatically by Paddle autograd). + + Reference: + - fast-hadamard-transform (Tri Dao): csrc/fast_hadamard_transform_cuda.cu + - PaddleFormers/paddleformers/quantization/hadamard_utils.py (matmul_hadU) + + Args: + x: Input tensor of shape (..., dim). dim must be a power of 2. 
+ scale: Scaling factor applied to the output. + + Returns: + Hadamard-transformed tensor of the same shape. + """ + original_shape = x.shape + dim = original_shape[-1] + assert dim > 0 and (dim & (dim - 1)) == 0, ( + f"hadamard_transform requires dim to be a power of 2, got {dim}" + ) + + # Flatten batch dims: (..., dim) -> (batch, dim) + x = x.reshape([-1, dim]) + + # Butterfly: iteratively halve and compute sum/diff pairs. + # Uses paddle.stack (not in-place index assignment) to keep autograd intact. + h = 1 + while h < dim: + x = x.reshape([-1, dim // (2 * h), 2, h]) + a = x[:, :, 0, :] + b = x[:, :, 1, :] + x = paddle.stack([a + b, a - b], axis=2) + x = x.reshape([-1, dim]) + h *= 2 + + return x.reshape(original_shape) * scale + + +def rotate_activation(x: Tensor) -> Tensor: + """Apply Hadamard rotation activation. + + Reference: + https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/inference/model.py#L424-L428 + + Args: + x: Input tensor (must be bfloat16). + + Returns: + Rotated tensor. 
+ """ + assert x.dtype == paddle.bfloat16, ( + f"rotate_activation only support bf16 input, but got {x.dtype}" + ) + hidden_size = x.shape[-1] + return hadamard_transform(x, scale=hidden_size**-0.5) + + # --------------------------------------------------------------------------- # Unfused DSA attention (explicit bmm, supports asymmetric Q/K vs V dims) # --------------------------------------------------------------------------- @@ -220,6 +283,10 @@ def forward_before_topk( if freqs is not None: k = self._apply_rope(k.unsqueeze(2), freqs, mscale).squeeze(2) + # Rotate activation (Hadamard transform) + q = rotate_activation(q) + k = rotate_activation(k) + weights = ( self.weights_proj(hidden_states.cast("float32")) * (self.n_heads**-0.5) @@ -385,6 +452,7 @@ class DSAIndexerLossAutoScaler(paddle.autograd.PyLayer): @staticmethod def forward(ctx, output: Tensor, indexer_loss: Tensor) -> Tensor: + print(f"===========> indexer_loss: {indexer_loss}") ctx.save_for_backward(indexer_loss) return output From d15e7d6a025c0522a3a952b9f82c078f33c7d749 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 11 Mar 2026 20:47:14 +0800 Subject: [PATCH 03/12] return get_query_key_value_tensors tmp values --- src/paddlefleet/transformer/dsa_attention.py | 123 ++++++------------ .../transformer/multi_latent_attention.py | 4 +- 2 files changed, 42 insertions(+), 85 deletions(-) diff --git a/src/paddlefleet/transformer/dsa_attention.py b/src/paddlefleet/transformer/dsa_attention.py index f80073bee..20286c3c8 100644 --- a/src/paddlefleet/transformer/dsa_attention.py +++ b/src/paddlefleet/transformer/dsa_attention.py @@ -39,7 +39,6 @@ ) from paddlefleet.tensor_parallel.mappings import ( gather_from_sequence_parallel_region, - gather_from_tensor_model_parallel_region, ) from paddlefleet.transformer.enums import AttnMaskType from paddlefleet.transformer.multi_latent_attention import ( @@ -479,16 +478,14 @@ class MLASelfAttentionWithDSA(MLASelfAttention): """MLA Self-attention with DeepSeek 
Sparse Attention (DSA) Indexer. Extends the upstream MLASelfAttention by: - 1. Reusing parent's get_query_key_value_tensors() for Q/K/V computation. - 2. Overriding forward() to compute Indexer inputs separately, run the DSA - Indexer, build a sparse mask, and use unfused bmm attention. - - The Indexer needs q_latent (normed q_compressed with full q_lora_rank) and - hidden_states in [b, s, h] format. Since the parent's get_query_key_value_tensors - doesn't expose these intermediates, we re-compute q_a_proj + norm for the - Indexer path. This is cheap because: - - q_a_proj is a small down-projection (hidden -> q_lora_rank) - - Indexer inputs are detached, so no extra backward cost + 1. Reusing parent's get_query_key_value_tensors() for Q/K/V + q_compressed. + 2. Overriding forward() to run the DSA Indexer, build a sparse mask, + and use unfused bmm attention. + + The Indexer needs q_compressed (normed, from MLA down-projection) and + hidden_states in [b, s, h] format. The parent's get_query_key_value_tensors + now returns (query, key, value, q_compressed, kv_compressed) — aligned with + Megatron-LM — so we directly reuse q_compressed for the Indexer path. """ def __init__( @@ -520,65 +517,6 @@ def __init__( config, "indexer_use_sparse_loss", False ) - def _compute_indexer_inputs(self, hidden_states): - """Compute Indexer's q_latent by re-running q_a_proj + norm. - - This is a lightweight re-computation of the q down-projection path. - The result is detached, so no gradient flows back through this path. - - Unlike the parent's get_query_key_value_tensors (which scatters back to - [s/tp, b, ...] for SP), the Indexer needs full-sequence tensors [s, b, ...]. - So we follow the parent's q_a_proj path but skip the scatter step. 
- - Args: - hidden_states: [s/tp, b, h] (SP) or [s, b, h] (non-SP) - - Returns: - indexer_hidden: [b, s, h] — detached hidden_states (full seq) for Indexer - indexer_q_latent: [b, s, q_lora_rank] — detached normed q_compressed (full seq) - """ - # Re-compute q_a_proj on original hidden_states (let it handle SP internally) - with paddle.no_grad(): - if self.config.q_lora_rank is not None: - # q_a_proj (ColumnParallelLinear) handles SP all-gather internally: - # SP: [s/tp, b, h] -> [s, b, q_lora_rank/tp] - # non-SP: [s, b, h] -> [s, b, q_lora_rank/tp] - q_compressed, _ = self.q_a_proj(hidden_states) - # Gather feature dim to full q_lora_rank if sharded - if q_compressed.shape[-1] != self.config.q_lora_rank: - q_compressed = gather_from_tensor_model_parallel_region( - q_compressed - ) - # Parent would scatter_to_sequence_parallel_region here, - # but Indexer needs full sequence — skip scatter. - q_latent = self.q_a_layernorm(q_compressed) - # q_latent: [s, b, q_lora_rank] (full seq, full feature) - else: - q_latent = hidden_states - # Need full seq for the else branch too - if self.config.sequence_parallel: - q_latent = gather_from_sequence_parallel_region( - q_latent, - tensor_parallel_output_grad=True, - group=self.pg_collection.tp, - ) - - # Gather hidden_states separately for indexer_hidden - if self.config.sequence_parallel: - indexer_hidden = gather_from_sequence_parallel_region( - hidden_states, - tensor_parallel_output_grad=True, - group=self.pg_collection.tp, - ).detach() - else: - indexer_hidden = hidden_states.detach() - - # Convert [s, b, ...] -> [b, s, ...] - indexer_hidden = indexer_hidden.transpose([1, 0, 2]) - indexer_q_latent = q_latent.detach().transpose([1, 0, 2]) - - return indexer_hidden, indexer_q_latent - def forward( self, hidden_states, @@ -596,8 +534,8 @@ def forward( """Forward: MLA projections + DSA Indexer + sparse attention. Overrides the parent forward to: - 1. Reuse parent's get_query_key_value_tensors for Q/K/V - 2. 
Compute Indexer inputs separately (lightweight re-computation) + 1. Reuse parent's get_query_key_value_tensors for Q/K/V + q_compressed + 2. Prepare Indexer inputs from returned q_compressed (no re-computation) 3. Run DSA Indexer to get top-k indices 4. Build sparse mask and use unfused bmm attention 5. Compute DSA Indexer KL loss if enabled @@ -609,22 +547,41 @@ def forward( assert rotary_pos_cos is None and rotary_pos_sin is None # ===================== - # Query, Key, and Value (reuse parent's MLA implementation) + # Query, Key, Value + compressed intermediates (aligned with Megatron) # ===================== - query, key, value = self.get_query_key_value_tensors( - hidden_states, - key_value_states, - position_ids, - packed_seq_params, + query, key, value, q_compressed, kv_compressed = ( + self.get_query_key_value_tensors( + hidden_states, + key_value_states, + position_ids, + packed_seq_params, + ) ) # ===================== - # DSA Indexer + # DSA Indexer inputs (reuse q_compressed from parent, no re-computation) # ===================== - # Compute Indexer inputs: re-runs q_a_proj + norm (cheap, detached) - indexer_hidden, indexer_q_latent = self._compute_indexer_inputs( - hidden_states - ) + # q_compressed: [s/tp, b, q_lora_rank] (SP) or [s, b, q_lora_rank] (non-SP) + # Indexer needs full-sequence [b, s, ...] tensors, all detached. + with paddle.no_grad(): + if self.config.sequence_parallel: + indexer_q_latent = gather_from_sequence_parallel_region( + q_compressed, + tensor_parallel_output_grad=True, + group=self.pg_collection.tp, + ).detach() + indexer_hidden = gather_from_sequence_parallel_region( + hidden_states, + tensor_parallel_output_grad=True, + group=self.pg_collection.tp, + ).detach() + else: + indexer_q_latent = q_compressed.detach() + indexer_hidden = hidden_states.detach() + + # Convert [s, b, ...] -> [b, s, ...] 
+ indexer_q_latent = indexer_q_latent.transpose([1, 0, 2]) + indexer_hidden = indexer_hidden.transpose([1, 0, 2]) # Get RoPE freqs for Indexer (non-interleaved, computed from rotary_pos_emb) # Re-compute from self.rotary_pos_emb since parent doesn't expose it diff --git a/src/paddlefleet/transformer/multi_latent_attention.py b/src/paddlefleet/transformer/multi_latent_attention.py index 2df30636c..5e1b908aa 100644 --- a/src/paddlefleet/transformer/multi_latent_attention.py +++ b/src/paddlefleet/transformer/multi_latent_attention.py @@ -187,7 +187,7 @@ def forward( # ===================== # Get the query, key and value tensors based on the type of attention - query, key, value = self.get_query_key_value_tensors( + query, key, value, _, _ = self.get_query_key_value_tensors( hidden_states, key_value_states, position_ids, @@ -662,7 +662,7 @@ def qkv_up_proj_and_rope_apply( q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb ) - return query, key, value + return query, key, value, q_compressed, kv_compressed def backward_dw(self) -> NoReturn: """Execute weight gradient computation""" From a84e531268ce16ec18fd4d5737f7f2b3ec3dc60a Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Sun, 15 Mar 2026 19:00:49 +0800 Subject: [PATCH 04/12] fix --- .../embeddings/yarn_rotary_pos_embedding.py | 11 +- src/paddlefleet/transformer/dsa_attention.py | 747 ++++++++++++++---- .../transformer/multi_latent_attention.py | 2 + .../transformer/transformer_config.py | 4 +- 4 files changed, 624 insertions(+), 140 deletions(-) diff --git a/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py b/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py index 6e037c878..5443106c7 100644 --- a/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -117,10 +117,6 @@ def forward( Returns: Tensor: Embeddings after applying Yarn RoPE. 
""" - assert not self.rotary_interleaved, ( - "Yarn RoPE does not support interleaved rotary embeddings" - ) - low, high = _yarn_find_correction_range( self.beta_fast, self.beta_slow, @@ -150,7 +146,12 @@ def forward( self.scaling_factor, self.mscale, self.mscale_all_dim ) - emb = paddle.cat((freqs, freqs), axis=-1) + if not self.rotary_interleaved: + emb = paddle.cat((freqs, freqs), axis=-1) + else: + emb = paddle.stack( + (freqs.reshape((-1, 1)), freqs.reshape((-1, 1))), axis=-1 + ).reshape((freqs.shape[0], -1)) # emb [1, seq_len, 1, dim] emb = emb[None, :, None, :] return emb, _mscale diff --git a/src/paddlefleet/transformer/dsa_attention.py b/src/paddlefleet/transformer/dsa_attention.py index 20286c3c8..cbefd3013 100644 --- a/src/paddlefleet/transformer/dsa_attention.py +++ b/src/paddlefleet/transformer/dsa_attention.py @@ -18,15 +18,15 @@ This module extends the upstream MLASelfAttention with DSA Indexer support (DeepSeek V3.2 architecture): - Indexer: Token scoring module that selects top-k relevant positions - - DSAIndexerLoss: KL-divergence loss for Indexer training + - FusedDSAIndexerLoss: Fused KL-divergence loss with full manual backward - DSAIndexerLossAutoScaler: Loss scaling helper - MLASelfAttentionWithDSA: Subclass of MLASelfAttention with DSA integration -Reference: Megatron-LM/megatron/core/transformer/experimental_attention_variant/dsa.py """ from __future__ import annotations +import logging from typing import TYPE_CHECKING import paddle @@ -34,6 +34,7 @@ from paddle import Tensor from paddle.distributed.fleet.utils import recompute +from paddlefleet import parallel_state from paddlefleet.models.common.embeddings.rope_utils import ( _apply_rotary_pos_emb_bshd, ) @@ -126,7 +127,7 @@ def _unfused_dsa_attention( combined_mask: Tensor | None, softmax_scale: float, ) -> Tensor: - """Unfused DSA sparse attention (matches Megatron-Core unfused_dsa_fn). 
+ """Unfused DSA sparse attention Uses explicit bmm instead of flash attention to support: - Different Q/K head_dim vs V head_dim (MLA architecture) @@ -200,7 +201,6 @@ class Indexer(paddle.nn.Layer): - Per-head learned importance weights via weights_proj - weights absorbs softmax_scale - Reference: Megatron-LM dsa.py DSAIndexer """ def __init__(self, config: TransformerConfig, layer_number: int): @@ -244,23 +244,23 @@ def _apply_rope( ) -> Tensor: """Apply non-interleaved RoPE to the pe portion of x. - Split order: [nope | pe], matching Megatron-Core dsa.py _apply_rope. + Split order: [pe | nope], matching DeepSeek-V3.2 Indexer (model.py:462). Args: - x: [..., head_dim] (nope_dim + rope_dim) - freqs: RoPE frequencies + x: [..., head_dim] (rope_dim + nope_dim) + freqs: RoPE frequencies (must be half-half format for non-interleaved) mscale: YaRN concentration factor (1.0 for plain RoPE, ~1.37 for YaRN) """ - x_nope = x[..., : self.nope_head_dim] - x_pe = x[..., self.nope_head_dim :] + x_pe = x[..., : self.rope_head_dim] + x_nope = x[..., self.rope_head_dim :] x_pe = _apply_rotary_pos_emb_bshd( x_pe, freqs, rotary_interleaved=False, - multi_latent_attention=self.config.multi_latent_attention, + multi_latent_attention=False, mscale=mscale, ) - return paddle.concat([x_nope, x_pe], axis=-1) + return paddle.concat([x_pe, x_nope], axis=-1) def forward_before_topk( self, @@ -334,114 +334,423 @@ def forward( return index_scores, topk_indices -# --------------------------------------------------------------------------- -# DSA Indexer Loss (PyLayer) -# --------------------------------------------------------------------------- +def _compute_index_scores_fused( + q: Tensor, weights: Tensor, k: Tensor +) -> Tensor: + """Compute index scores from Indexer outputs. 
+ Args: + q: [sq, b, h, d] (Indexer query, after RoPE + Hadamard) + weights: [sq, b, h] (per-head importance weights) + k: [sk, b, d] (Indexer key, after RoPE + Hadamard) -class DSAIndexerLoss(paddle.autograd.PyLayer): - """Fused DSA Indexer KL-divergence loss. + Returns: + index_scores: [b, sq, sk] + """ + # q @ k^T -> [sq, b, h, sk] + index_scores = paddle.einsum( + "sbhd,tbd->sbht", q.cast("float32"), k.cast("float32") + ) + # ReLU activation + index_scores = F.relu(index_scores) + # Weight each head: [sq, b, h, sk] * [sq, b, h, 1] -> [sq, b, h, sk] + index_scores = index_scores * weights.unsqueeze(-1) + # Sum across heads: [sq, b, h, sk] -> [sq, b, sk] + index_scores = index_scores.sum(axis=2) + # Transpose to [b, sq, sk] + index_scores = index_scores.transpose([1, 0, 2]) + return index_scores + + +def _compute_dsa_indexer_loss( + index_scores: Tensor, + topk_indices: Tensor, + query: Tensor, + key: Tensor, + softmax_scale: float, + loss_coeff: float, + sparse_loss: bool, + tp_group, +) -> Tensor: + """Compute KL divergence loss between index_scores and true attention_scores. + + Args: + index_scores: [b, sq, sk] + topk_indices: [b, sq, topk] + query: [sq, b, np, hn] (MLA query, DETACHED) + key: [sk, b, np, hn] (MLA key, DETACHED) + softmax_scale: Scale coefficient after q @ k^T + loss_coeff: Coefficient for the indexer KL divergence loss + sparse_loss: Whether to apply sparse index mask + tp_group: TP process group (or None) - Trains the Indexer to predict which tokens receive high attention weights. 
- Reference: Megatron-Core dsa.py FusedDSAIndexerLoss + Returns: + indexer_loss: scalar """ + sq, b, np, hn = query.shape + sk = key.shape[0] + + # [sq, b, np, hn] -> [b, np, sq, hn] -> [b * np, sq, hn] + query_reshaped = query.transpose([1, 2, 0, 3]).reshape([b * np, sq, hn]) + # [sk, b, np, hn] -> [b, np, hn, sk] -> [b * np, hn, sk] + key_reshaped = key.transpose([1, 2, 3, 0]).reshape([b * np, hn, sk]) + # Compute attention scores [b * np, sq, sk] + attention_scores = ( + paddle.bmm(query_reshaped.cast("float32"), key_reshaped.cast("float32")) + * softmax_scale + ) + # Reshape to [b, np, sq, sk] + attention_scores = attention_scores.reshape([b, np, sq, sk]) - @staticmethod - def forward( - ctx, - index_scores: Tensor, # [b, sq, sk] - topk_indices: Tensor, # [b, sq, topk] - query: Tensor, # [b, sq, nhpp, qk_head_dim] (DETACHED) - key: Tensor, # [b, sk, nhpp, qk_head_dim] (DETACHED) - mla_softmax_scale: float, - loss_coeff: float, - sparse_loss: bool, - tp_group, - ) -> Tensor: - b, sq, sk = index_scores.shape - nhpp = query.shape[2] + # causal_mask [sq, sk] + causal_mask = paddle.triu( + paddle.full([sq, sk], float("-inf"), dtype="float32"), + diagonal=1, + ) + # index_mask [b, sq, sk] + index_mask = paddle.full([b, sq, sk], float("-inf"), dtype="float32") + index_mask = paddle.put_along_axis( + index_mask, + topk_indices, + paddle.zeros_like(topk_indices, dtype="float32"), + axis=-1, + ) - q_f = query.cast("float32") - k_f = key.cast("float32") - attention_scores = ( - paddle.einsum("bshd,bthd->bhst", q_f, k_f) * mla_softmax_scale + # [b, np, sq, sk] + [1, 1, sq, sk] -> [b, np, sq, sk] + attention_scores = attention_scores + causal_mask.reshape([1, 1, sq, sk]) + if sparse_loss: + # [b, np, sq, sk] + [b, 1, sq, sk] -> [b, np, sq, sk] + attention_scores = attention_scores + index_mask.reshape([b, 1, sq, sk]) + # [b, sq, sk] + [b, sq, sk] -> [b, sq, sk] + index_scores = index_scores + index_mask + + # [b, np, sq, sk] -> [b, np, sq, sk] + attention_scores = 
F.softmax(attention_scores, axis=-1, dtype="float32") + # [b, sq, sk] -> [b, sq, sk] + index_scores = F.softmax(index_scores, axis=-1, dtype="float32") + + # Sum attention scores across heads: [b, np, sq, sk] -> [b, sq, sk] + attention_scores = attention_scores.sum(axis=1) + if tp_group is not None and tp_group.nranks > 1: + paddle.distributed.all_reduce( + attention_scores.contiguous(), group=tp_group ) + # L1 normalize target on the last dimension + attention_scores = attention_scores / attention_scores.sum( + axis=-1, keepdim=True + ) - causal_mask = paddle.triu( - paddle.full([sq, sk], float("-inf"), dtype="float32"), - diagonal=1, - ) - attention_scores = attention_scores + causal_mask.unsqueeze([0, 1]) + # KL divergence: KL(target || index) = target * log(target / index) + kl_per_element = attention_scores * ( + paddle.log(attention_scores + 1e-10) - paddle.log(index_scores + 1e-10) + ) - index_mask = paddle.full([b, sq, sk], float("-inf"), dtype="float32") - index_mask = paddle.put_along_axis( - index_mask, - topk_indices, - paddle.zeros_like(topk_indices, dtype="float32"), - axis=-1, - ) + # [b, sq, sk] -> [b, sq] -> [1] + kl_div = kl_per_element.sum(axis=-1).mean() + indexer_loss = kl_div * loss_coeff - masked_index_scores = index_scores.cast( - "float32" - ) + causal_mask.unsqueeze(0) - if sparse_loss: - attention_scores = attention_scores + index_mask.unsqueeze(1) - masked_index_scores = masked_index_scores + index_mask + return indexer_loss - attn_probs = F.softmax(attention_scores, axis=-1) - idx_probs = F.softmax(masked_index_scores, axis=-1) - attn_probs_sum = attn_probs.sum(axis=1) - if tp_group is not None and tp_group.nranks > 1: - paddle.distributed.all_reduce(attn_probs_sum, group=tp_group) +def _bwd_fused_indexer_loss( + q: Tensor, + weights: Tensor, + k: Tensor, + query: Tensor, + key: Tensor, + topk_indices: Tensor, + softmax_scale: float, + loss_coeff: float, + sparse_loss: bool, + grad_loss: Tensor, + tp_group, +) -> tuple[Tensor, Tensor, 
Tensor]: + """Manual backward for fused indexer loss. - target = attn_probs_sum / ( - attn_probs_sum.sum(axis=-1, keepdim=True) + 1e-10 - ) - kl = target * ( - paddle.log(target + 1e-10) - paddle.log(idx_probs + 1e-10) + All tensor layouts (sequence-first): + q: [sq, b, h, d] + weights: [sq, b, h] + k: [sk, b, d] + query: [sq, b, np, hn] (MLA query) + key: [sk, b, np, hn] (MLA key) + + Returns: + grad_q: [sq, b, h, d] + grad_weights: [sq, b, h] + grad_k: [sk, b, d] + """ + # Recompute index_scores from (q, weights, k) + index_scores = _compute_index_scores_fused(q, weights, k) # [b, sq, sk] + + sq, b, np, hn = query.shape + sk = key.shape[0] + + # [sq, b, np, hn] -> [b, np, sq, hn] -> [b * np, sq, hn] + query_reshaped = query.transpose([1, 2, 0, 3]).reshape([b * np, sq, hn]) + # [sk, b, np, hn] -> [b, np, hn, sk] -> [b * np, hn, sk] + key_reshaped = key.transpose([1, 2, 3, 0]).reshape([b * np, hn, sk]) + # Compute attention scores [b * np, sq, sk] + attention_scores = ( + paddle.bmm(query_reshaped.cast("float32"), key_reshaped.cast("float32")) + * softmax_scale + ) + del query_reshaped, key_reshaped + + # Reshape to [b, np, sq, sk] + attention_scores = attention_scores.reshape([b, np, sq, sk]) + + # causal_mask [sq, sk] + causal_mask = paddle.triu( + paddle.full([sq, sk], float("-inf"), dtype="float32"), + diagonal=1, + ) + # index_mask [b, sq, sk] + index_mask = paddle.full([b, sq, sk], float("-inf"), dtype="float32") + index_mask = paddle.put_along_axis( + index_mask, + topk_indices, + paddle.zeros_like(topk_indices, dtype="float32"), + axis=-1, + ) + + # Apply causal mask to both attention and index scores + attention_scores = attention_scores + causal_mask.reshape([1, 1, sq, sk]) + index_scores = index_scores + causal_mask.unsqueeze(0) + del causal_mask + + if sparse_loss: + attention_scores = attention_scores + index_mask.reshape([b, 1, sq, sk]) + index_scores = index_scores + index_mask + + # Compute softmax for both + attention_scores_softmax = F.softmax( + 
attention_scores, axis=-1, dtype="float32" + ) + del attention_scores + + index_scores_softmax = F.softmax(index_scores, axis=-1, dtype="float32") + del index_scores + + # Sum attention scores across heads: [b, np, sq, sk] -> [b, sq, sk] + attention_scores_sum = attention_scores_softmax.sum(axis=1) + del attention_scores_softmax + + if tp_group is not None and tp_group.nranks > 1: + paddle.distributed.all_reduce( + attention_scores_sum.contiguous(), group=tp_group ) - kl_div = kl.sum(axis=-1).mean() - indexer_loss = kl_div * loss_coeff - ctx.save_for_backward( - target, idx_probs, index_mask if sparse_loss else None + # L1 normalize + attention_scores_normalized = ( + attention_scores_sum / attention_scores_sum.sum(axis=-1, keepdim=True) + ) + del attention_scores_sum + + # Backward through loss = kl_div * loss_coeff + # where kl_div = kl_per_element.sum(dim=-1).mean() + grad_kl_div = grad_loss.cast("float32") * loss_coeff # scalar + + # Backward through mean: distribute gradient equally + grad_kl_per_row = grad_kl_div / (b * sq) # scalar + + # Backward through sum(dim=-1): broadcast back to [b, sq, sk] + grad_kl_per_element = grad_kl_per_row.reshape([1, 1, 1]).expand([b, sq, sk]) + + # Backward through kl: ∂kl/∂index_softmax = -target / index_softmax + grad_index_scores_softmax = ( + -attention_scores_normalized + / (index_scores_softmax + 1e-10) + * grad_kl_per_element + ) + del attention_scores_normalized + + # Backward through softmax: + # ∂L/∂x = softmax * (∂L/∂softmax - sum(∂L/∂softmax * softmax)) + sum_grad = (grad_index_scores_softmax * index_scores_softmax).sum( + axis=-1, keepdim=True + ) + grad_index_scores_logits = index_scores_softmax * ( + grad_index_scores_softmax - sum_grad + ) + del index_scores_softmax, grad_index_scores_softmax, sum_grad + + # Zero out gradients for masked positions + causal_valid_mask = paddle.tril( + paddle.ones([sq, sk], dtype="bool") + ) # [sq, sk] + if sparse_loss: + index_valid_mask = index_mask == 0 # [b, sq, sk] + del 
index_mask + valid_mask = ( + causal_valid_mask.unsqueeze(0) & index_valid_mask + ) # [b, sq, sk] + del index_valid_mask + else: + del index_mask + valid_mask = causal_valid_mask.unsqueeze(0).expand( + [b, sq, sk] + ) # [b, sq, sk] + del causal_valid_mask + + grad_index_scores_logits = grad_index_scores_logits * valid_mask.cast( + "float32" + ) + del valid_mask + + # Transpose from [b, sq, sk] to [sq, b, sk] + grad_index_scores = grad_index_scores_logits.transpose( + [1, 0, 2] + ) # [sq, b, sk] + del grad_index_scores_logits + + # Backward through sum over heads: expand gradient + grad_weighted_scores = grad_index_scores.unsqueeze(2) # [sq, b, 1, sk] + del grad_index_scores + + # Compute forward values needed for backward (recomputation) + scores = paddle.einsum( + "sbhd,tbd->sbht", q.cast("float32"), k.cast("float32") + ) # [sq, b, h, sk] + relu_mask = scores > 0 + scores_after_relu = F.relu(scores) + del scores + + # Backward through multiplication by weights: + # ∂L/∂weights = grad * relu_scores (sum over sk) + grad_weights = (grad_weighted_scores * scores_after_relu).sum( + axis=-1 + ) # [sq, b, h] + + # ∂L/∂relu_scores = grad * weights + grad_scores_after_relu = grad_weighted_scores * weights.unsqueeze( + -1 + ) # [sq, b, h, sk] + del grad_weighted_scores, scores_after_relu + + # Backward through ReLU + grad_scores = grad_scores_after_relu * relu_mask.cast( + "float32" + ) # [sq, b, h, sk] + del grad_scores_after_relu, relu_mask + + # Backward through einsum 'sbhd,tbd->sbht' + # ∂L/∂q = einsum('sbht,tbd->sbhd', grad_scores, k) + grad_q = paddle.einsum( + "sbht,tbd->sbhd", grad_scores, k.cast("float32") + ) # [sq, b, h, d] + # ∂L/∂k = einsum('sbht,sbhd->tbd', grad_scores, q) + grad_k = paddle.einsum( + "sbht,sbhd->tbd", grad_scores, q.cast("float32") + ) # [sk, b, d] + del grad_scores + + return ( + grad_q.cast(q.dtype), + grad_weights.cast(weights.dtype), + grad_k.cast(k.dtype), + ) + + +class FusedDSAIndexerLoss(paddle.autograd.PyLayer): + """Fused DSA 
Indexer Loss: index_scores + topk + KL loss + full manual backward.""" + + _last_topk_indices: Tensor | None = None + + @staticmethod + def forward( + ctx, + q: Tensor, # [sq, b, h, d] — Indexer query output + weights: Tensor, # [sq, b, h] — Indexer per-head weights + k: Tensor, # [sk, b, d] — Indexer key output + query: Tensor, # [sq, b, np, hn] — MLA query (DETACHED) + key: Tensor, # [sk, b, np, hn] — MLA key (DETACHED) + # Non-tensor params follow (stored on ctx, not in backward returns) + softmax_scale: float = 1.0, + topk: int = 64, + loss_coeff: float = 1.0, + mask: Tensor | None = None, + sparse_loss: bool = True, + tp_group=None, + ) -> Tensor: + """Fused forward: compute index_scores, topk, and KL loss. + + Args: + q: Indexer query after RoPE+Hadamard [sq, b, h, d] + weights: Per-head importance weights [sq, b, h] + k: Indexer key after RoPE+Hadamard [sk, b, d] + query: MLA query (detached) [sq, b, np, hn] + key: MLA key (detached) [sk, b, np, hn] + softmax_scale: MLA attention softmax scale + topk: Number of top-k indices to select + loss_coeff: Coefficient for KL loss + mask: Optional mask for index_scores [b, 1, sq, sk] or [1, 1, sq, sk] + sparse_loss: Whether to use sparse index mask in loss + tp_group: TP process group (or None) + + Returns: + indexer_loss: scalar KL divergence loss + """ + # Step 1: Compute index_scores from (q, weights, k) + index_scores = _compute_index_scores_fused(q, weights, k) # [b, sq, sk] + + # Step 2: Apply mask and select topk + if mask is not None: + masked_scores = index_scores + mask.squeeze(1) + else: + masked_scores = index_scores + topk_k = min(topk, masked_scores.shape[-1]) + topk_indices = paddle.topk(masked_scores, k=topk_k, axis=-1)[1] + + # Store topk_indices for caller to retrieve + FusedDSAIndexerLoss._last_topk_indices = topk_indices.detach() + + # Step 3: Compute KL loss + indexer_loss = _compute_dsa_indexer_loss( + index_scores, + topk_indices, + query, + key, + softmax_scale, + loss_coeff, + sparse_loss, + 
tp_group, ) - ctx.b = b - ctx.sq = sq - ctx.sparse_loss = sparse_loss + + ctx.save_for_backward(q, weights, k, query, key, topk_indices) + ctx.softmax_scale = softmax_scale ctx.loss_coeff = loss_coeff + ctx.sparse_loss = sparse_loss + ctx.tp_group = tp_group return indexer_loss @staticmethod def backward(ctx, grad_loss: Tensor): - target, idx_probs, index_mask = ctx.saved_tensor() - b, sq = ctx.b, ctx.sq - sparse_loss = ctx.sparse_loss - loss_coeff = ctx.loss_coeff - sk = target.shape[-1] - - grad_idx_probs = ( - -target - / (idx_probs + 1e-10) - * (grad_loss.cast("float32") * loss_coeff / (b * sq)) - ) - sum_grad = (grad_idx_probs * idx_probs).sum(axis=-1, keepdim=True) - grad_index_scores = idx_probs * (grad_idx_probs - sum_grad) + """Backward: recompute and manually backprop to (q, weights, k). - causal_valid = paddle.tril(paddle.ones([sq, sk], dtype="bool")) - if sparse_loss and index_mask is not None: - valid_mask = causal_valid.unsqueeze(0) & (index_mask == 0) - else: - valid_mask = causal_valid.unsqueeze(0).expand([b, sq, sk]) - - grad_index_scores = grad_index_scores * valid_mask.cast("float32") + Returns 6 gradients for the 6 Tensor inputs to forward: + q, weights, k, query, key, mask + (Paddle PyLayer only counts Tensor params, not float/int/bool/None.) 
+ """ + q, weights, k, query, key, topk_indices = ctx.saved_tensor() + + grad_q, grad_weights, grad_k = _bwd_fused_indexer_loss( + q, + weights, + k, + query, + key, + topk_indices, + ctx.softmax_scale, + ctx.loss_coeff, + ctx.sparse_loss, + grad_loss, + ctx.tp_group, + ) - # Gradients for Tensor inputs only (Paddle PyLayer convention): - # index_scores, topk_indices(None), query(None), key(None) - return grad_index_scores.cast(idx_probs.dtype), None, None, None + # 6 Tensor inputs: q, weights, k, query, key, mask + return grad_q, grad_weights, grad_k, None, None, None class DSAIndexerLossAutoScaler(paddle.autograd.PyLayer): @@ -451,7 +760,7 @@ class DSAIndexerLossAutoScaler(paddle.autograd.PyLayer): @staticmethod def forward(ctx, output: Tensor, indexer_loss: Tensor) -> Tensor: - print(f"===========> indexer_loss: {indexer_loss}") + print(f"[DSA DEBUG] indexer_loss = {indexer_loss.item():.6f}") ctx.save_for_backward(indexer_loss) return output @@ -469,6 +778,126 @@ def set_loss_scale(scale: Tensor): DSAIndexerLossAutoScaler._main_loss_backward_scale = scale +logger = logging.getLogger(__name__) + + +class DSAIndexerLossLoggingHelper: + """Helper class for logging sparse attention indexer losses across layers and ranks.""" + + tracker: dict = {} + + @staticmethod + def save_loss_to_tracker( + loss: Tensor, + layer_number: int, + num_layers: int, + reduce_group=None, + avg_group=None, + ): + """Save the indexer loss for logging. + + Args: + loss: The loss tensor (scalar). + layer_number: Layer index of the loss, 1-indexed. + num_layers: The number of total layers. + reduce_group: The group for reducing the loss. + avg_group: The group for averaging the loss. 
+ """ + if layer_number is None: + return + + tracker = DSAIndexerLossLoggingHelper.tracker + if "values" not in tracker: + tracker["values"] = paddle.zeros([num_layers]) + tracker["values"][layer_number - 1] += loss.detach() + tracker["reduce_group"] = reduce_group + tracker["avg_group"] = avg_group + + @staticmethod + def clean_loss_in_tracker(): + """Clear the indexer losses.""" + tracker = DSAIndexerLossLoggingHelper.tracker + if "values" in tracker: + tracker["values"].zero_() + tracker["reduce_group"] = None + tracker["avg_group"] = None + + @staticmethod + def reduce_loss_in_tracker(): + """Collect and reduce the indexer losses across ranks.""" + tracker = DSAIndexerLossLoggingHelper.tracker + if "values" not in tracker: + return + values = tracker["values"] + + # PP all-reduce + pp_group = parallel_state.get_pipeline_model_parallel_group( + check_initialized=False + ) + if pp_group is not None and pp_group.nranks > 1: + paddle.distributed.all_reduce(values, group=pp_group) + + # TP reduce + if tracker.get("reduce_group") is not None: + paddle.distributed.all_reduce(values, group=tracker["reduce_group"]) + + # CP avg + if tracker.get("avg_group") is not None: + paddle.distributed.all_reduce(values, group=tracker["avg_group"]) + values /= tracker["avg_group"].nranks + + # DP avg + dp_group = parallel_state.get_data_parallel_group( + check_initialized=False + ) + if dp_group is not None and dp_group.nranks > 1: + paddle.distributed.all_reduce(values, group=dp_group) + values /= dp_group.nranks + + @staticmethod + def track_indexer_metrics( + loss_scale: float, + iteration: int, + writer=None, + total_loss_dict: dict | None = None, + ): + """Track the sparse attention indexer metrics for logging. + + Args: + loss_scale: Scale factor for the loss (e.g. 1/num_microbatches). + iteration: Current training iteration. + writer: TensorBoard writer (optional). + total_loss_dict: Dictionary to accumulate total losses (optional). 
+ """ + DSAIndexerLossLoggingHelper.reduce_loss_in_tracker() + tracker = DSAIndexerLossLoggingHelper.tracker + if "values" not in tracker: + return + + indexer_loss_values = tracker["values"] * loss_scale + num_layers = indexer_loss_values.shape[0] + avg_indexer_loss = indexer_loss_values.sum() / num_layers + + if total_loss_dict is not None: + if "indexer loss" in total_loss_dict: + total_loss_dict["indexer loss"] += avg_indexer_loss + else: + total_loss_dict["indexer loss"] = avg_indexer_loss + + if writer is not None: + writer.add_scalar( + "indexer loss", avg_indexer_loss.item(), iteration + ) + + logger.info( + "Iteration %d | indexer loss: %.6f", + iteration, + avg_indexer_loss.item(), + ) + + DSAIndexerLossLoggingHelper.clean_loss_in_tracker() + + # --------------------------------------------------------------------------- # MLASelfAttentionWithDSA — extends upstream MLASelfAttention # --------------------------------------------------------------------------- @@ -484,8 +913,7 @@ class MLASelfAttentionWithDSA(MLASelfAttention): The Indexer needs q_compressed (normed, from MLA down-projection) and hidden_states in [b, s, h] format. The parent's get_query_key_value_tensors - now returns (query, key, value, q_compressed, kv_compressed) — aligned with - Megatron-LM — so we directly reuse q_compressed for the Indexer path. 
+ now returns (query, key, value, q_compressed, kv_compressed) """ def __init__( @@ -547,7 +975,7 @@ def forward( assert rotary_pos_cos is None and rotary_pos_sin is None # ===================== - # Query, Key, Value + compressed intermediates (aligned with Megatron) + # Query, Key, Value + compressed intermediates # ===================== query, key, value, q_compressed, kv_compressed = ( self.get_query_key_value_tensors( @@ -562,26 +990,28 @@ def forward( # DSA Indexer inputs (reuse q_compressed from parent, no re-computation) # ===================== # q_compressed: [s/tp, b, q_lora_rank] (SP) or [s, b, q_lora_rank] (non-SP) - # Indexer needs full-sequence [b, s, ...] tensors, all detached. - with paddle.no_grad(): - if self.config.sequence_parallel: - indexer_q_latent = gather_from_sequence_parallel_region( - q_compressed, - tensor_parallel_output_grad=True, - group=self.pg_collection.tp, - ).detach() - indexer_hidden = gather_from_sequence_parallel_region( - hidden_states, - tensor_parallel_output_grad=True, - group=self.pg_collection.tp, - ).detach() - else: - indexer_q_latent = q_compressed.detach() - indexer_hidden = hidden_states.detach() + # Indexer needs full-sequence [b, s, ...] tensors. + # Detach inputs to prevent gradients from flowing back to the main model + # Do NOT use no_grad here — Indexer's own parameters (wq_b, wk, weights_proj) + # need gradient tracking via FusedDSAIndexerLoss manual backward. + if self.config.sequence_parallel: + indexer_q_latent = gather_from_sequence_parallel_region( + q_compressed, + tensor_parallel_output_grad=True, + group=self.pg_collection.tp, + ).detach() + indexer_hidden = gather_from_sequence_parallel_region( + hidden_states, + tensor_parallel_output_grad=True, + group=self.pg_collection.tp, + ).detach() + else: + indexer_q_latent = q_compressed.detach() + indexer_hidden = hidden_states.detach() - # Convert [s, b, ...] -> [b, s, ...] 
- indexer_q_latent = indexer_q_latent.transpose([1, 0, 2]) - indexer_hidden = indexer_hidden.transpose([1, 0, 2]) + # Convert [s, b, ...] -> [b, s, ...] + indexer_q_latent = indexer_q_latent.transpose([1, 0, 2]) + indexer_hidden = indexer_hidden.transpose([1, 0, 2]) # Get RoPE freqs for Indexer (non-interleaved, computed from rotary_pos_emb) # Re-compute from self.rotary_pos_emb since parent doesn't expose it @@ -624,6 +1054,16 @@ def forward( :, 0:actual_seq_len ] # slice seq dim (dim 1) + # MLA's YarnRotaryEmbedding generates interleaved freqs [θ₁,θ₁,θ₂,θ₂,...] + # when config.rotary_interleaved=True, but Indexer uses non-interleaved + # RoPE which expects half-half freqs [θ₁,θ₂,...,θ₁,θ₂,...]. + # Convert format here to match Indexer's rotary_interleaved=False. + if self.config.rotary_interleaved and indexer_freqs is not None: + indexer_freqs = paddle.concat( + [indexer_freqs[..., 0::2], indexer_freqs[..., 1::2]], + axis=-1, + ) + # Build causal float_mask for Indexer scoring, matching MG DSAttention.forward: # MG always passes a causal-aware mask to the Indexer so that topk selection # never picks future positions. Without this, Indexer could select future @@ -648,15 +1088,60 @@ def forward( 0 ) # [1, 1, s, s] - index_scores, topk_indices = self.indexer( + # Indexer forward_before_topk runs WITH gradient tracking so that + # FusedDSAIndexerLoss can backprop through (q, weights, k) to + # Indexer parameters (wq_b, wk, weights_proj). + q_idx, k_idx, weights_idx = self.indexer.forward_before_topk( indexer_hidden, indexer_q_latent, indexer_freqs, - indexer_float_mask, mscale=indexer_mscale, ) - # Detach topk_indices: int64 index tensor, no meaningful gradients. 
- topk_indices = topk_indices.detach() + + # Convert Indexer outputs from batch-first [b, s, h, d] to + # sequence-first [s, b, h, d] + q_idx_sf = q_idx.transpose([1, 0, 2, 3]) # [sq, b, h, d] + k_idx_sf = k_idx.transpose([1, 0, 2]) # [sk, b, d] + weights_idx_sf = weights_idx.transpose([1, 0, 2]) # [sq, b, h] + + # MLA query/key for loss: sequence-first, detached + # query is [s, b, nhpp, hd] or [b, s, nhpp, hd] depending on SP + if self.config.sequence_parallel: + # query is [s/tp, b, nhpp, hd] in SP; need full-seq [s, b, nhpp, hd] + # But for the loss we use the already-gathered full-seq versions from + # the batch-first path after transpose. + mla_query_sf = query # already [s, b, nhpp, hd] in SP mode + mla_key_sf = key + else: + # query is [b, s, nhpp, hd] → [s, b, nhpp, hd] + mla_query_sf = query.transpose([1, 0, 2, 3]) + mla_key_sf = key.transpose([1, 0, 2, 3]) + + # FusedDSAIndexerLoss: compute index_scores + topk + KL loss inside PyLayer, + # with full manual backward to (q, weights, k). 
+ if self.training and self.dsa_indexer_loss_coeff is not None: + indexer_loss = FusedDSAIndexerLoss.apply( + q_idx_sf, + weights_idx_sf, + k_idx_sf, + mla_query_sf.detach(), + mla_key_sf.detach(), + self.softmax_scale, + self.indexer.index_topk, + float(self.dsa_indexer_loss_coeff), + indexer_float_mask, + bool(self.dsa_indexer_use_sparse_loss), + self.pg_collection.tp + if self.pg_collection.tp.nranks > 1 + else None, + ) + topk_indices = FusedDSAIndexerLoss._last_topk_indices + else: + # Inference or no loss: compute index_scores + topk directly + index_scores, topk_indices = self.indexer.compute_index_scores( + q_idx, k_idx, weights_idx, indexer_float_mask + ) + topk_indices = topk_indices.detach() # ===================== # Build sparse mask @@ -730,21 +1215,15 @@ def forward( output, bias = self.o_proj(core_attn_out) # ===================== - # DSA Indexer KL loss + # DSA Indexer KL loss (already computed by FusedDSAIndexerLoss above) # ===================== if self.training and self.dsa_indexer_loss_coeff is not None: - indexer_loss = DSAIndexerLoss.apply( - index_scores, - topk_indices, - query.detach(), - key.detach(), - self.softmax_scale, - float(self.dsa_indexer_loss_coeff), - bool(self.dsa_indexer_use_sparse_loss), - self.pg_collection.tp - if self.pg_collection.tp.nranks > 1 - else None, - ) + if self.dsa_indexer_loss_coeff > 0: + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss=indexer_loss, + layer_number=self.layer_number, + num_layers=self.config.num_hidden_layers, + ) output = DSAIndexerLossAutoScaler.apply(output, indexer_loss) return output, bias diff --git a/src/paddlefleet/transformer/multi_latent_attention.py b/src/paddlefleet/transformer/multi_latent_attention.py index 5e1b908aa..03f38d924 100644 --- a/src/paddlefleet/transformer/multi_latent_attention.py +++ b/src/paddlefleet/transformer/multi_latent_attention.py @@ -107,6 +107,7 @@ def __init__( if self.config.rope_type == "rope": self.rotary_pos_emb = RotaryEmbedding( 
self.config.qk_rope_head_dim, + rotary_interleaved=self.config.rotary_interleaved, rotary_percent=self.config.rotary_percent, rotary_base=self.config.rotary_base, cp_group=self.pg_collection.cp, @@ -114,6 +115,7 @@ def __init__( elif self.config.rope_type == "yarn": self.rotary_pos_emb = YarnRotaryEmbedding( self.config.qk_rope_head_dim, + rotary_interleaved=self.config.rotary_interleaved, rotary_base=self.config.rotary_base, scaling_factor=self.config.rotary_scaling_factor, original_max_position_embeddings=self.config.original_max_position_embeddings, diff --git a/src/paddlefleet/transformer/transformer_config.py b/src/paddlefleet/transformer/transformer_config.py index 72a0c9a1c..c4529e391 100644 --- a/src/paddlefleet/transformer/transformer_config.py +++ b/src/paddlefleet/transformer/transformer_config.py @@ -656,7 +656,9 @@ def __post_init__(self): if self.moe_layer_freq: moe_layer_pattern = [ 1 if (i % self.moe_layer_freq == 0) else 0 - for i in range(self.num_hidden_layers) + for i in range( + self.num_hidden_layers - self.first_k_dense_replace + ) ] else: moe_layer_pattern = [1] * ( From 468cf8652dd72e2bae4920e06738f8ada7c74f9b Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 23 Mar 2026 21:37:31 +0800 Subject: [PATCH 05/12] fix dsa_attention --- src/paddlefleet/transformer/dsa_attention.py | 73 ++++++-------------- 1 file changed, 20 insertions(+), 53 deletions(-) diff --git a/src/paddlefleet/transformer/dsa_attention.py b/src/paddlefleet/transformer/dsa_attention.py index cbefd3013..254633672 100644 --- a/src/paddlefleet/transformer/dsa_attention.py +++ b/src/paddlefleet/transformer/dsa_attention.py @@ -82,8 +82,6 @@ def hadamard_transform(x: Tensor, scale: float = 1.0) -> Tensor: # Flatten batch dims: (..., dim) -> (batch, dim) x = x.reshape([-1, dim]) - # Butterfly: iteratively halve and compute sum/diff pairs. - # Uses paddle.stack (not in-place index assignment) to keep autograd intact. 
h = 1 while h < dim: x = x.reshape([-1, dim // (2 * h), 2, h]) @@ -118,8 +116,6 @@ def rotate_activation(x: Tensor) -> Tensor: # --------------------------------------------------------------------------- # Unfused DSA attention (explicit bmm, supports asymmetric Q/K vs V dims) # --------------------------------------------------------------------------- - - def _unfused_dsa_attention( query: Tensor, key: Tensor, @@ -184,8 +180,6 @@ def _unfused_dsa_attention( # --------------------------------------------------------------------------- # DSA Indexer # --------------------------------------------------------------------------- - - class Indexer(paddle.nn.Layer): """DSA Indexer: DeepSeek Sparse Attention token selection module. @@ -271,7 +265,6 @@ def forward_before_topk( ): """Compute q, k, weights before top-k selection.""" bsz, seqlen, _ = hidden_states.shape - q = self.wq_b(q_latent) # [b, s, n_heads * head_dim] q = q.reshape([bsz, seqlen, self.n_heads, self.head_dim]) if freqs is not None: @@ -702,12 +695,11 @@ def forward( topk_k = min(topk, masked_scores.shape[-1]) topk_indices = paddle.topk(masked_scores, k=topk_k, axis=-1)[1] - # Store topk_indices for caller to retrieve FusedDSAIndexerLoss._last_topk_indices = topk_indices.detach() - # Step 3: Compute KL loss + # Step 3: Compute KL loss (use masked_scores) indexer_loss = _compute_dsa_indexer_loss( - index_scores, + masked_scores, topk_indices, query, key, @@ -749,7 +741,6 @@ def backward(ctx, grad_loss: Tensor): ctx.tp_group, ) - # 6 Tensor inputs: q, weights, k, query, key, mask return grad_q, grad_weights, grad_k, None, None, None @@ -901,8 +892,6 @@ def track_indexer_metrics( # --------------------------------------------------------------------------- # MLASelfAttentionWithDSA — extends upstream MLASelfAttention # --------------------------------------------------------------------------- - - class MLASelfAttentionWithDSA(MLASelfAttention): """MLA Self-attention with DeepSeek Sparse Attention (DSA) 
Indexer. @@ -973,6 +962,9 @@ def forward( ) assert attention_bias is None assert rotary_pos_cos is None and rotary_pos_sin is None + assert packed_seq_params is None, ( + "packed_seq_params is not supported yet." + ) # ===================== # Query, Key, Value + compressed intermediates @@ -986,14 +978,6 @@ def forward( ) ) - # ===================== - # DSA Indexer inputs (reuse q_compressed from parent, no re-computation) - # ===================== - # q_compressed: [s/tp, b, q_lora_rank] (SP) or [s, b, q_lora_rank] (non-SP) - # Indexer needs full-sequence [b, s, ...] tensors. - # Detach inputs to prevent gradients from flowing back to the main model - # Do NOT use no_grad here — Indexer's own parameters (wq_b, wk, weights_proj) - # need gradient tracking via FusedDSAIndexerLoss manual backward. if self.config.sequence_parallel: indexer_q_latent = gather_from_sequence_parallel_region( q_compressed, @@ -1009,9 +993,15 @@ def forward( indexer_q_latent = q_compressed.detach() indexer_hidden = hidden_states.detach() - # Convert [s, b, ...] -> [b, s, ...] indexer_q_latent = indexer_q_latent.transpose([1, 0, 2]) - indexer_hidden = indexer_hidden.transpose([1, 0, 2]) + if self.config.sequence_parallel: + indexer_hidden = indexer_hidden.transpose([1, 0, 2]) + + # indexer_hidden: [b, s, h] + # indexer_q_latent: [b, s, q_lora_rank] + # query: [s, b, n, h] + # key: [s, b, n, h] + # value: [s, b, n, h] # Get RoPE freqs for Indexer (non-interleaved, computed from rotary_pos_emb) # Re-compute from self.rotary_pos_emb since parent doesn't expose it @@ -1064,11 +1054,6 @@ def forward( axis=-1, ) - # Build causal float_mask for Indexer scoring, matching MG DSAttention.forward: - # MG always passes a causal-aware mask to the Indexer so that topk selection - # never picks future positions. Without this, Indexer could select future - # tokens which are then killed by the causal mask in attention → all-inf rows - # → softmax NaN. 
indexer_seq_len = indexer_hidden.shape[1] # [b, s, h] — always full seq indexer_causal_mask = paddle.triu( paddle.full( @@ -1104,19 +1089,6 @@ def forward( k_idx_sf = k_idx.transpose([1, 0, 2]) # [sk, b, d] weights_idx_sf = weights_idx.transpose([1, 0, 2]) # [sq, b, h] - # MLA query/key for loss: sequence-first, detached - # query is [s, b, nhpp, hd] or [b, s, nhpp, hd] depending on SP - if self.config.sequence_parallel: - # query is [s/tp, b, nhpp, hd] in SP; need full-seq [s, b, nhpp, hd] - # But for the loss we use the already-gathered full-seq versions from - # the batch-first path after transpose. - mla_query_sf = query # already [s, b, nhpp, hd] in SP mode - mla_key_sf = key - else: - # query is [b, s, nhpp, hd] → [s, b, nhpp, hd] - mla_query_sf = query.transpose([1, 0, 2, 3]) - mla_key_sf = key.transpose([1, 0, 2, 3]) - # FusedDSAIndexerLoss: compute index_scores + topk + KL loss inside PyLayer, # with full manual backward to (q, weights, k). if self.training and self.dsa_indexer_loss_coeff is not None: @@ -1124,8 +1096,8 @@ def forward( q_idx_sf, weights_idx_sf, k_idx_sf, - mla_query_sf.detach(), - mla_key_sf.detach(), + query.detach(), + key.detach(), self.softmax_scale, self.indexer.index_topk, float(self.dsa_indexer_loss_coeff), @@ -1146,12 +1118,8 @@ def forward( # ===================== # Build sparse mask # ===================== - if self.config.sequence_parallel: - seqlen = query.shape[0] # [s, b, nhpp, hd] - bsz = query.shape[1] - else: - bsz = query.shape[0] # [b, s, nhpp, hd] - seqlen = query.shape[1] + seqlen = query.shape[0] # [s, b, nhpp, hd] + bsz = query.shape[1] index_mask = paddle.full( [bsz, seqlen, seqlen], @@ -1184,10 +1152,9 @@ def forward( # ===================== # Core attention (unfused bmm for DSA) # ===================== - if self.config.sequence_parallel: - query = query.transpose([1, 0, 2, 3]).contiguous() - key = key.transpose([1, 0, 2, 3]).contiguous() - value = value.transpose([1, 0, 2, 3]).contiguous() + query = 
query.transpose([1, 0, 2, 3]).contiguous() + key = key.transpose([1, 0, 2, 3]).contiguous() + value = value.transpose([1, 0, 2, 3]).contiguous() if self.recompute_core_attention and self.training: core_attn_out = recompute( From f4a89c0914f8d728651431410e5907175ca9e7d9 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Tue, 24 Mar 2026 21:56:54 +0800 Subject: [PATCH 06/12] fix --- src/paddlefleet/transformer/dsa_attention.py | 28 +++++++++++++------ .../transformer/multi_latent_attention.py | 1 + 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/paddlefleet/transformer/dsa_attention.py b/src/paddlefleet/transformer/dsa_attention.py index 254633672..420526d63 100644 --- a/src/paddlefleet/transformer/dsa_attention.py +++ b/src/paddlefleet/transformer/dsa_attention.py @@ -751,7 +751,6 @@ class DSAIndexerLossAutoScaler(paddle.autograd.PyLayer): @staticmethod def forward(ctx, output: Tensor, indexer_loss: Tensor) -> Tensor: - print(f"[DSA DEBUG] indexer_loss = {indexer_loss.item():.6f}") ctx.save_for_backward(indexer_loss) return output @@ -978,7 +977,11 @@ def forward( ) ) + # Non-SP: batch-first [b, s, ...] + # SP: seq-first [s/tp, b, ...] 
if self.config.sequence_parallel: + # SP: q_compressed [s/tp, b, q_lora_rank] -> gather -> [s, b, q_lora_rank] + # -> transpose -> [b, s, q_lora_rank] indexer_q_latent = gather_from_sequence_parallel_region( q_compressed, tensor_parallel_output_grad=True, @@ -989,13 +992,20 @@ def forward( tensor_parallel_output_grad=True, group=self.pg_collection.tp, ).detach() + indexer_q_latent = indexer_q_latent.transpose([1, 0, 2]) + indexer_hidden = indexer_hidden.transpose([1, 0, 2]) else: + # Non-SP: already batch-first [b, s, ...], no transpose needed indexer_q_latent = q_compressed.detach() indexer_hidden = hidden_states.detach() - indexer_q_latent = indexer_q_latent.transpose([1, 0, 2]) - if self.config.sequence_parallel: - indexer_hidden = indexer_hidden.transpose([1, 0, 2]) + # Convert query/key/value to sequence-first [s, b, n, h] for DSA + if not self.config.sequence_parallel: + # Non-SP: [b, s, n, h] -> [s, b, n, h] + query = query.transpose([1, 0, 2, 3]) + key = key.transpose([1, 0, 2, 3]) + value = value.transpose([1, 0, 2, 3]) + # SP: already seq-first [s/tp, b, n, h], no transpose needed # indexer_hidden: [b, s, h] # indexer_q_latent: [b, s, q_lora_rank] @@ -1048,11 +1058,11 @@ def forward( # when config.rotary_interleaved=True, but Indexer uses non-interleaved # RoPE which expects half-half freqs [θ₁,θ₂,...,θ₁,θ₂,...]. # Convert format here to match Indexer's rotary_interleaved=False. 
- if self.config.rotary_interleaved and indexer_freqs is not None: - indexer_freqs = paddle.concat( - [indexer_freqs[..., 0::2], indexer_freqs[..., 1::2]], - axis=-1, - ) + # if self.config.rotary_interleaved and indexer_freqs is not None: + # indexer_freqs = paddle.concat( + # [indexer_freqs[..., 0::2], indexer_freqs[..., 1::2]], + # axis=-1, + # ) indexer_seq_len = indexer_hidden.shape[1] # [b, s, h] — always full seq indexer_causal_mask = paddle.triu( diff --git a/src/paddlefleet/transformer/multi_latent_attention.py b/src/paddlefleet/transformer/multi_latent_attention.py index 456c5c382..76fee2269 100644 --- a/src/paddlefleet/transformer/multi_latent_attention.py +++ b/src/paddlefleet/transformer/multi_latent_attention.py @@ -116,6 +116,7 @@ def __init__( self.rotary_pos_emb = YarnRotaryEmbedding( self.config.qk_rope_head_dim, rotary_base=self.config.rope_theta, + rotary_interleaved=self.config.rotary_interleaved, scaling_factor=self.config.rotary_scaling_factor, original_max_position_embeddings=self.config.original_max_position_embeddings, beta_fast=self.config.beta_fast, From 43f4567b9f62714737089bf9ea72e47ef148741f Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 25 Mar 2026 11:27:20 +0800 Subject: [PATCH 07/12] add test --- src/paddlefleet/transformer/dsa_attention.py | 3 +- .../transformer/test_dsa_attention.py | 766 ++++++++++++++++++ 2 files changed, 768 insertions(+), 1 deletion(-) create mode 100644 tests/single_card_tests/transformer/test_dsa_attention.py diff --git a/src/paddlefleet/transformer/dsa_attention.py b/src/paddlefleet/transformer/dsa_attention.py index 420526d63..e047e7e78 100644 --- a/src/paddlefleet/transformer/dsa_attention.py +++ b/src/paddlefleet/transformer/dsa_attention.py @@ -1114,7 +1114,8 @@ def forward( indexer_float_mask, bool(self.dsa_indexer_use_sparse_loss), self.pg_collection.tp - if self.pg_collection.tp.nranks > 1 + if self.pg_collection.tp is not None + and self.pg_collection.tp.nranks > 1 else None, ) topk_indices 
= FusedDSAIndexerLoss._last_topk_indices diff --git a/tests/single_card_tests/transformer/test_dsa_attention.py b/tests/single_card_tests/transformer/test_dsa_attention.py new file mode 100644 index 000000000..7b6e838f6 --- /dev/null +++ b/tests/single_card_tests/transformer/test_dsa_attention.py @@ -0,0 +1,766 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for DSA (DeepSeek Sparse Attention) module. + +Tests are organized in 4 layers: + 1. Pure functions: hadamard_transform, rotate_activation, _unfused_dsa_attention, + _compute_index_scores_fused + 2. Indexer module: forward_before_topk, compute_index_scores, backward + 3. Loss: _compute_dsa_indexer_loss, FusedDSAIndexerLoss, DSAIndexerLossAutoScaler + 4. 
Integration: MLASelfAttentionWithDSA forward + backward +""" + +import unittest + +import paddle + +from paddlefleet.transformer.dot_product_attention import DotProductAttention +from paddlefleet.transformer.dsa_attention import ( + DSAIndexerLossAutoScaler, + FusedDSAIndexerLoss, + Indexer, + MLASelfAttentionWithDSA, + _compute_dsa_indexer_loss, + _compute_index_scores_fused, + _unfused_dsa_attention, + hadamard_transform, + rotate_activation, +) +from paddlefleet.transformer.enums import AttnMaskType +from paddlefleet.transformer.multi_latent_attention import ( + MLASelfAttentionSublayersSpec, +) +from paddlefleet.transformer.transformer_config import TransformerConfig +from paddlefleet.utils import ( + init_method_normal, + scaled_init_method_normal, +) + + +# --------------------------------------------------------------------------- +# Stub layers (same pattern as test_attention.py) +# --------------------------------------------------------------------------- +class BiasedLinear(paddle.nn.Layer): + def __init__(self, in_features, out_features, **kwargs): + super().__init__() + self.linear = paddle.nn.Linear(in_features, out_features) + + def forward(self, x): + return self.linear(x), self.linear.bias + + +class RMSNorm(paddle.nn.Layer): + def __init__(self, hidden_size, eps, **kwargs): + super().__init__() + self.weight = paddle.nn.Parameter(paddle.ones([hidden_size])) + self.eps = eps + + def forward(self, x): + d_norm = paddle.rsqrt(x.pow(2).mean(axis=-1, keepdim=True) + self.eps) + return x * d_norm * self.weight + + +# --------------------------------------------------------------------------- +# Helper: create DSA-compatible TransformerConfig +# --------------------------------------------------------------------------- +def _create_dsa_config( + hidden_size=256, + num_attention_heads=2, + q_lora_rank=64, + kv_lora_rank=64, + qk_nope_head_dim=32, + qk_rope_head_dim=32, + v_head_dim=64, + index_n_heads=2, + index_head_dim=128, + index_topk=16, + 
indexer_loss_coeff=1.0, + indexer_use_sparse_loss=False, + sequence_parallel=False, +): + config = TransformerConfig( + num_hidden_layers=2, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ) + # MLA fields + config.num_key_value_heads = num_attention_heads + config.head_dim = hidden_size // num_attention_heads + config.q_lora_rank = q_lora_rank + config.kv_lora_rank = kv_lora_rank + config.qk_nope_head_dim = qk_nope_head_dim + config.qk_rope_head_dim = qk_rope_head_dim + config.v_head_dim = v_head_dim + config.multi_latent_attention = True + + # RoPE / YaRN + config.rope_type = "yarn" + config.rope_theta = 10000.0 + config.rotary_interleaved = False + config.rotary_percent = 1.0 + config.rotary_scaling_factor = 40.0 + config.original_max_position_embeddings = 4096 + config.beta_fast = 32.0 + config.beta_slow = 1.0 + config.mscale = 1.0 + config.mscale_all_dim = 0.0 + config.apply_rope_fusion = False # DSA requires unfused RoPE + + # DSA Indexer fields + config.index_n_heads = index_n_heads + config.index_head_dim = index_head_dim + config.index_topk = index_topk + config.indexer_loss_coeff = indexer_loss_coeff + config.indexer_use_sparse_loss = indexer_use_sparse_loss + + # Attention generic fields + config.softmax_scale = None + config.use_bias = True + config.no_rope_freq = None + config.recompute_granularity = None + config.fused_single_qkv_rope = False + config.init_method = init_method_normal(0.02) + config.output_layer_init_method = scaled_init_method_normal(0.02, 1, 2.0) + config.rms_norm_eps = 1e-5 + config.context_parallel_size = 1 + config.apply_query_key_layer_scaling = False + config.sliding_window = None + config.window_attn_skip_freq = None + config.fp16 = False + config.bf16 = False + config.masked_softmax_fusion = False + config.attention_softmax_in_fp32 = True + config.attention_dropout = 0.0 + config.softmax_type = "vanilla" + config.sequence_parallel = sequence_parallel + + return config + + +def _create_sublayers_spec(): 
+ return MLASelfAttentionSublayersSpec( + q_proj=BiasedLinear, + q_a_proj=BiasedLinear, + q_b_proj=BiasedLinear, + kv_a_proj_with_mqa=BiasedLinear, + kv_b_proj=BiasedLinear, + core_attention=DotProductAttention, + o_proj=BiasedLinear, + q_a_layernorm=RMSNorm, + kv_a_layernorm=RMSNorm, + ) + + +def _make_causal_topk_indices(b, sq, sk, topk): + """Generate topk indices that respect causality (indices <= current position).""" + indices_list = [] + for i in range(sq): + max_idx = min(i + 1, sk) + actual_topk = min(topk, max_idx) + # Pick the last `actual_topk` positions (most recent) + row_indices = paddle.arange(max_idx - actual_topk, max_idx) + if actual_topk < topk: + # Pad with the last valid index + pad = paddle.full([topk - actual_topk], max_idx - 1, dtype="int64") + row_indices = paddle.concat([row_indices, pad]) + indices_list.append(row_indices) + # [sq, topk] -> expand to [b, sq, topk] + indices = ( + paddle.stack(indices_list, axis=0).unsqueeze(0).expand([b, sq, topk]) + ) + return indices + + +# =========================================================================== +# Layer 1: Pure function tests +# =========================================================================== +class TestHadamardTransform(unittest.TestCase): + def test_output_shape(self): + x = paddle.randn([4, 8, 16]) + out = hadamard_transform(x) + self.assertEqual(out.shape, [4, 8, 16]) + + def test_power_of_two_assertion(self): + x = paddle.randn([4, 7]) + with self.assertRaises(AssertionError): + hadamard_transform(x) + + def test_involution(self): + """H(H(x)) = dim * x (Hadamard is involutory up to scaling).""" + dim = 16 + x = paddle.randn([3, dim], dtype="float32") + hx = hadamard_transform(x) + hhx = hadamard_transform(hx) + self.assertTrue(paddle.allclose(hhx, x * dim, atol=1e-4, rtol=1e-4)) + + def test_scale_factor(self): + x = paddle.randn([4, 8]) + out_unscaled = hadamard_transform(x) + out_scaled = hadamard_transform(x, scale=0.5) + self.assertTrue( + 
paddle.allclose(out_scaled, out_unscaled * 0.5, atol=1e-5) + ) + + def test_1d_input(self): + x = paddle.randn([16]) + out = hadamard_transform(x) + self.assertEqual(out.shape, [16]) + + +class TestRotateActivation(unittest.TestCase): + def test_output_shape(self): + x = paddle.randn([2, 4, 128]).cast("bfloat16") + out = rotate_activation(x) + self.assertEqual(list(out.shape), [2, 4, 128]) + self.assertEqual(out.dtype, paddle.bfloat16) + + def test_dtype_assertion(self): + x = paddle.randn([2, 4, 64], dtype="float32") + with self.assertRaises(AssertionError): + rotate_activation(x) + + +class TestUnfusedDSAAttention(unittest.TestCase): + def setUp(self): + self.b, self.s, self.nhpp = 2, 8, 4 + self.qk_hd, self.v_hd = 32, 64 + self.softmax_scale = self.qk_hd**-0.5 + + def test_output_shape(self): + query = paddle.randn([self.b, self.s, self.nhpp, self.qk_hd]) + key = paddle.randn([self.b, self.s, self.nhpp, self.qk_hd]) + value = paddle.randn([self.b, self.s, self.nhpp, self.v_hd]) + out = _unfused_dsa_attention( + query, key, value, None, self.softmax_scale + ) + self.assertEqual(out.shape, [self.b, self.s, self.nhpp * self.v_hd]) + + def test_with_causal_mask(self): + query = paddle.randn([self.b, self.s, self.nhpp, self.qk_hd]) + key = paddle.randn([self.b, self.s, self.nhpp, self.qk_hd]) + value = paddle.randn([self.b, self.s, self.nhpp, self.v_hd]) + causal = paddle.triu( + paddle.full([self.s, self.s], float("-inf"), dtype="float32"), + diagonal=1, + ) + mask = causal.unsqueeze(0).unsqueeze(0) # [1, 1, s, s] + out = _unfused_dsa_attention( + query, key, value, mask, self.softmax_scale + ) + self.assertEqual(out.shape, [self.b, self.s, self.nhpp * self.v_hd]) + + def test_asymmetric_dims(self): + """qk_head_dim != v_head_dim should work.""" + qk_hd, v_hd = 48, 32 + query = paddle.randn([self.b, self.s, self.nhpp, qk_hd]) + key = paddle.randn([self.b, self.s, self.nhpp, qk_hd]) + value = paddle.randn([self.b, self.s, self.nhpp, v_hd]) + out = 
_unfused_dsa_attention(query, key, value, None, qk_hd**-0.5) + self.assertEqual(out.shape, [self.b, self.s, self.nhpp * v_hd]) + + +class TestComputeIndexScoresFused(unittest.TestCase): + def test_output_shape(self): + sq, b, h, d = 8, 2, 4, 32 + sk = 8 + q = paddle.randn([sq, b, h, d]) + weights = paddle.randn([sq, b, h]) + k = paddle.randn([sk, b, d]) + out = _compute_index_scores_fused(q, weights, k) + self.assertEqual(out.shape, [b, sq, sk]) + + def test_nonnegative_after_relu(self): + sq, b, h, d = 8, 2, 4, 32 + q = paddle.randn([sq, b, h, d]) + # Use positive weights so that relu * positive_weights >= 0 + weights = paddle.abs(paddle.randn([sq, b, h])) + 0.1 + k = paddle.randn([sq, b, d]) + out = _compute_index_scores_fused(q, weights, k) + self.assertTrue((out >= -1e-6).all().item()) + + +# =========================================================================== +# Layer 2: Indexer module tests +# =========================================================================== +class TestIndexer(unittest.TestCase): + def setUp(self): + self.config = _create_dsa_config() + self.indexer = Indexer(self.config, layer_number=1) + self.b = 2 + self.s = 16 + + def _prepare_indexer_bf16(self): + """Convert wq_b/wk to bf16 for rotate_activation, keep weights_proj fp32.""" + self.indexer.wq_b = self.indexer.wq_b.to(dtype="bfloat16") + self.indexer.wk = self.indexer.wk.to(dtype="bfloat16") + self.indexer.k_norm = self.indexer.k_norm.to(dtype="bfloat16") + # weights_proj stays fp32 (code does hidden.cast("float32") before calling it) + + def test_forward_before_topk_shapes(self): + self._prepare_indexer_bf16() + hidden = paddle.randn([self.b, self.s, self.config.hidden_size]).cast( + "bfloat16" + ) + q_latent = paddle.randn([self.b, self.s, self.config.q_lora_rank]).cast( + "bfloat16" + ) + + q, k, weights = self.indexer.forward_before_topk( + hidden, q_latent, freqs=None, mscale=1.0 + ) + self.assertEqual( + list(q.shape), + [ + self.b, + self.s, + 
self.config.index_n_heads, + self.config.index_head_dim, + ], + ) + self.assertEqual( + list(k.shape), + [self.b, self.s, self.config.index_head_dim], + ) + self.assertEqual( + list(weights.shape), + [self.b, self.s, self.config.index_n_heads], + ) + + def test_compute_index_scores_shapes(self): + q = paddle.randn( + [ + self.b, + self.s, + self.config.index_n_heads, + self.config.index_head_dim, + ] + ) + k = paddle.randn([self.b, self.s, self.config.index_head_dim]) + weights = paddle.randn([self.b, self.s, self.config.index_n_heads]) + index_scores, topk_indices = self.indexer.compute_index_scores( + q, k, weights, mask=None + ) + self.assertEqual(list(index_scores.shape), [self.b, self.s, self.s]) + self.assertEqual( + list(topk_indices.shape), + [self.b, self.s, self.config.index_topk], + ) + + def test_topk_in_range(self): + q = paddle.randn( + [ + self.b, + self.s, + self.config.index_n_heads, + self.config.index_head_dim, + ] + ) + k = paddle.randn([self.b, self.s, self.config.index_head_dim]) + weights = paddle.randn([self.b, self.s, self.config.index_n_heads]) + _, topk_indices = self.indexer.compute_index_scores( + q, k, weights, mask=None + ) + self.assertTrue((topk_indices >= 0).all().item()) + self.assertTrue((topk_indices < self.s).all().item()) + + def test_backward(self): + """Indexer parameters receive gradients.""" + self._prepare_indexer_bf16() + hidden = paddle.randn([self.b, self.s, self.config.hidden_size]).cast( + "bfloat16" + ) + q_latent = paddle.randn([self.b, self.s, self.config.q_lora_rank]).cast( + "bfloat16" + ) + + q, k, weights = self.indexer.forward_before_topk( + hidden, q_latent, freqs=None, mscale=1.0 + ) + # rotate_activation requires bf16, so skip it in this unit test + # and just use the raw outputs for gradient checking. 
+ loss = q.cast("float32").sum() + k.cast("float32").sum() + weights.sum() + loss.backward() + + for name, param in self.indexer.named_parameters(): + self.assertIsNotNone( + param.grad, f"Parameter {name} has no gradient" + ) + + +# =========================================================================== +# Layer 3: Loss tests +# =========================================================================== +class TestComputeDSAIndexerLoss(unittest.TestCase): + def setUp(self): + self.sq, self.sk = 8, 8 + self.b, self.np, self.hn = 2, 4, 32 + self.topk = 4 + self.softmax_scale = self.hn**-0.5 + self.loss_coeff = 1.0 + + def _make_inputs(self, sparse=False): + index_scores = paddle.randn([self.b, self.sq, self.sk], dtype="float32") + if sparse: + topk_indices = _make_causal_topk_indices( + self.b, self.sq, self.sk, self.topk + ) + else: + topk_indices = paddle.randint( + 0, self.sk, [self.b, self.sq, self.topk] + ).cast("int64") + query = paddle.randn( + [self.sq, self.b, self.np, self.hn], dtype="float32" + ) + key = paddle.randn([self.sk, self.b, self.np, self.hn], dtype="float32") + return index_scores, topk_indices, query, key + + def test_loss_is_scalar(self): + index_scores, topk_indices, query, key = self._make_inputs() + loss = _compute_dsa_indexer_loss( + index_scores, + topk_indices, + query, + key, + self.softmax_scale, + self.loss_coeff, + False, + None, + ) + self.assertEqual(loss.shape, []) + + def test_loss_with_sparse(self): + index_scores, topk_indices, query, key = self._make_inputs(sparse=True) + loss = _compute_dsa_indexer_loss( + index_scores, + topk_indices, + query, + key, + self.softmax_scale, + self.loss_coeff, + True, + None, + ) + self.assertEqual(loss.shape, []) + self.assertTrue(paddle.isfinite(loss).item()) + + def test_loss_coeff_scaling(self): + index_scores, topk_indices, query, key = self._make_inputs() + loss1 = _compute_dsa_indexer_loss( + index_scores, + topk_indices, + query, + key, + self.softmax_scale, + 1.0, + False, + None, 
+ ) + loss2 = _compute_dsa_indexer_loss( + index_scores, + topk_indices, + query, + key, + self.softmax_scale, + 2.0, + False, + None, + ) + self.assertTrue( + paddle.allclose(loss2, loss1 * 2.0, atol=1e-4), + f"loss2={loss2.item():.6f} != 2*loss1={2 * loss1.item():.6f}", + ) + + +class TestFusedDSAIndexerLoss(unittest.TestCase): + def setUp(self): + self.sq, self.sk = 8, 8 + self.b = 2 + self.h, self.d = 4, 32 # indexer heads/dim + self.np, self.hn = 4, 64 # MLA heads/dim + self.topk = 4 + self.softmax_scale = self.hn**-0.5 + + def _make_inputs(self, with_mask=False): + q = paddle.randn([self.sq, self.b, self.h, self.d], dtype="float32") + q.stop_gradient = False + weights = paddle.randn([self.sq, self.b, self.h], dtype="float32") + weights.stop_gradient = False + k = paddle.randn([self.sk, self.b, self.d], dtype="float32") + k.stop_gradient = False + query = paddle.randn( + [self.sq, self.b, self.np, self.hn], dtype="float32" + ) + key = paddle.randn([self.sk, self.b, self.np, self.hn], dtype="float32") + if with_mask: + causal = paddle.triu( + paddle.full([self.sq, self.sk], float("-inf"), dtype="float32"), + diagonal=1, + ) + mask = causal.unsqueeze(0).unsqueeze(0) # [1, 1, sq, sk] + else: + mask = None + return q, weights, k, query, key, mask + + def test_forward_returns_scalar(self): + q, weights, k, query, key, mask = self._make_inputs(with_mask=True) + loss = FusedDSAIndexerLoss.apply( + q, + weights, + k, + query, + key, + self.softmax_scale, + self.topk, + 1.0, + mask, + False, + None, + ) + self.assertEqual(loss.shape, []) + self.assertTrue(paddle.isfinite(loss).item()) + + def test_topk_indices_stored(self): + FusedDSAIndexerLoss._last_topk_indices = None + q, weights, k, query, key, _ = self._make_inputs() + loss = FusedDSAIndexerLoss.apply( + q, + weights, + k, + query, + key, + self.softmax_scale, + self.topk, + 1.0, + None, + False, + None, + ) + self.assertIsNotNone(FusedDSAIndexerLoss._last_topk_indices) + self.assertEqual( + 
list(FusedDSAIndexerLoss._last_topk_indices.shape), + [self.b, self.sq, self.topk], + ) + + def test_backward_gradients(self): + # Pass a mask tensor so PyLayer sees 6 tensor inputs (q, weights, k, + # query, key, mask) matching the 6 return values in backward. + q, weights, k, query, key, mask = self._make_inputs(with_mask=True) + loss = FusedDSAIndexerLoss.apply( + q, + weights, + k, + query, + key, + self.softmax_scale, + self.topk, + 1.0, + mask, + False, + None, + ) + loss.backward() + + self.assertIsNotNone(q.grad) + self.assertIsNotNone(weights.grad) + self.assertIsNotNone(k.grad) + self.assertEqual(list(q.grad.shape), [self.sq, self.b, self.h, self.d]) + self.assertEqual(list(weights.grad.shape), [self.sq, self.b, self.h]) + self.assertEqual(list(k.grad.shape), [self.sk, self.b, self.d]) + self.assertTrue(paddle.isfinite(q.grad).all().item()) + self.assertTrue(paddle.isfinite(weights.grad).all().item()) + self.assertTrue(paddle.isfinite(k.grad).all().item()) + + +class TestDSAIndexerLossAutoScaler(unittest.TestCase): + def _make_non_leaf_output(self, shape): + """Create a non-leaf tensor (required by PyLayer inplace check).""" + x = paddle.randn(shape) + x.stop_gradient = False + return x + 0 # Adding 0 makes it non-leaf + + def test_forward_passthrough(self): + output = self._make_non_leaf_output([2, 8, 64]) + indexer_loss = self._make_non_leaf_output([]) + result = DSAIndexerLossAutoScaler.apply(output, indexer_loss) + self.assertEqual(list(result.shape), [2, 8, 64]) + + def test_backward_grad_output(self): + output = self._make_non_leaf_output([2, 8, 64]) + indexer_loss = self._make_non_leaf_output([]) + + result = DSAIndexerLossAutoScaler.apply(output, indexer_loss) + loss = result.sum() + loss.backward() + # output is non-leaf (x + 0), so its grad may not be retained, + # but the computation should not error out. 
+ self.assertTrue(True) # Just verify no error + + def test_loss_scale(self): + DSAIndexerLossAutoScaler.set_loss_scale( + paddle.to_tensor(2.0, dtype="float32") + ) + output = self._make_non_leaf_output([2, 4]) + indexer_loss = paddle.to_tensor(1.0, dtype="float32") + indexer_loss.stop_gradient = False + indexer_loss = indexer_loss * 1.0 # Make non-leaf + + result = DSAIndexerLossAutoScaler.apply(output, indexer_loss) + loss = result.sum() + loss.backward() + # Verify no errors + self.assertTrue(True) + # Reset + DSAIndexerLossAutoScaler._main_loss_backward_scale = None + + +# =========================================================================== +# Layer 4: MLASelfAttentionWithDSA integration tests +# =========================================================================== +class TestMLASelfAttentionWithDSA(unittest.TestCase): + def setUp(self): + self.config = _create_dsa_config() + self.micro_batch_size = 2 + self.sequence_length = 32 + + def _build_model(self, config=None): + cfg = config or self.config + model = MLASelfAttentionWithDSA( + cfg, + _create_sublayers_spec(), + layer_number=1, + attn_mask_type=AttnMaskType.causal, + ) + # Convert model to bf16 because rotate_activation requires bf16 input. + # But weights_proj does hidden.cast("float32") internally and expects + # fp32 weights, so convert it back to fp32 after the global bf16 cast. 
+ model = model.to(dtype="bfloat16") + model.indexer.weights_proj = model.indexer.weights_proj.to( + dtype="float32" + ) + return model + + def _make_hidden(self, dtype="bfloat16"): + return paddle.randn( + [ + self.micro_batch_size, + self.sequence_length, + self.config.hidden_size, + ], + ).cast(dtype) + + def test_forward_shape(self): + model = self._build_model() + hidden = self._make_hidden() + output, bias = model(hidden, attention_mask=None) + + self.assertEqual(output.shape[0], self.micro_batch_size) + self.assertEqual(output.shape[1], self.sequence_length) + self.assertEqual(output.shape[2], self.config.hidden_size) + self.assertEqual(bias.shape[0], self.config.hidden_size) + + def test_forward_with_attention_mask(self): + model = self._build_model() + hidden = self._make_hidden() + causal = paddle.triu( + paddle.full( + [self.sequence_length, self.sequence_length], + float("-inf"), + dtype="float32", + ), + diagonal=1, + ) + mask = ( + causal.unsqueeze(0) + .unsqueeze(0) + .expand( + [ + self.micro_batch_size, + 1, + self.sequence_length, + self.sequence_length, + ] + ) + ) + output, bias = model(hidden, attention_mask=mask) + + self.assertEqual(output.shape[0], self.micro_batch_size) + self.assertEqual(output.shape[1], self.sequence_length) + self.assertEqual(output.shape[2], self.config.hidden_size) + + def test_forward_training_with_loss(self): + model = self._build_model() + model.train() + hidden = self._make_hidden() + output, bias = model(hidden, attention_mask=None) + + self.assertEqual(output.shape[0], self.micro_batch_size) + self.assertEqual(output.shape[1], self.sequence_length) + self.assertEqual(output.shape[2], self.config.hidden_size) + + def test_forward_eval_mode(self): + config = _create_dsa_config(indexer_loss_coeff=None) + model = self._build_model(config) + model.eval() + hidden = self._make_hidden() + output, bias = model(hidden, attention_mask=None) + + self.assertEqual(output.shape[0], self.micro_batch_size) + 
self.assertEqual(output.shape[1], self.sequence_length) + + def test_backward_gradients(self): + model = self._build_model() + model.train() + hidden = self._make_hidden() + hidden.stop_gradient = False + output, bias = model(hidden, attention_mask=None) + loss = output.cast("float32").sum() + loss.backward() + + self.assertIsNotNone(hidden.grad) + for name, param in model.named_parameters(): + if not param.stop_gradient: + self.assertIsNotNone( + param.grad, f"Parameter {name} has no gradient" + ) + self.assertTrue( + paddle.isfinite(param.grad).all().item(), + f"Parameter {name} has non-finite gradient", + ) + + def test_indexer_params_have_grad(self): + model = self._build_model() + model.train() + hidden = self._make_hidden() + hidden.stop_gradient = False + output, bias = model(hidden, attention_mask=None) + loss = output.cast("float32").sum() + loss.backward() + + indexer_param_names = [ + "indexer.wq_b", + "indexer.wk", + "indexer.weights_proj", + ] + for name, param in model.named_parameters(): + for iname in indexer_param_names: + if iname in name: + self.assertIsNotNone( + param.grad, + f"Indexer parameter {name} has no gradient", + ) + + +if __name__ == "__main__": + unittest.main() From 957c8767df35a2e1aac39ef9e59b8b64a91afd3f Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Wed, 25 Mar 2026 15:18:21 +0800 Subject: [PATCH 08/12] Update yarn_rotary_pos_embedding.py --- .../models/common/embeddings/yarn_rotary_pos_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py b/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py index bdef4b4a3..5443106c7 100644 --- a/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -150,8 +150,8 @@ def forward( emb = paddle.cat((freqs, freqs), axis=-1) else: emb = 
paddle.stack( - (freqs.view(-1, 1), freqs.view(-1, 1)), axis=-1 - ).view(freqs.shape[0], -1) + (freqs.reshape((-1, 1)), freqs.reshape((-1, 1))), axis=-1 + ).reshape((freqs.shape[0], -1)) # emb [1, seq_len, 1, dim] emb = emb[None, :, None, :] return emb, _mscale From 79b00d6f5c5fd1566232165c0fbfffee689ca6cf Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Wed, 25 Mar 2026 15:19:05 +0800 Subject: [PATCH 09/12] Update yarn_rotary_pos_embedding.py --- .../models/common/embeddings/yarn_rotary_pos_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py b/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py index 5443106c7..bdef4b4a3 100644 --- a/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/src/paddlefleet/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -150,8 +150,8 @@ def forward( emb = paddle.cat((freqs, freqs), axis=-1) else: emb = paddle.stack( - (freqs.reshape((-1, 1)), freqs.reshape((-1, 1))), axis=-1 - ).reshape((freqs.shape[0], -1)) + (freqs.view(-1, 1), freqs.view(-1, 1)), axis=-1 + ).view(freqs.shape[0], -1) # emb [1, seq_len, 1, dim] emb = emb[None, :, None, :] return emb, _mscale From 7f779ac3ee2dacfb26d287988edd2445767b33c3 Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Wed, 25 Mar 2026 15:20:47 +0800 Subject: [PATCH 10/12] Update multi_latent_attention.py --- src/paddlefleet/transformer/multi_latent_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/paddlefleet/transformer/multi_latent_attention.py b/src/paddlefleet/transformer/multi_latent_attention.py index f20b5e221..d4f2252a1 100644 --- a/src/paddlefleet/transformer/multi_latent_attention.py +++ b/src/paddlefleet/transformer/multi_latent_attention.py @@ -117,7 +117,6 @@ def __init__( self.config.qk_rope_head_dim, 
rotary_interleaved=self.config.rotary_interleaved, rotary_base=self.config.rope_theta, - rotary_interleaved=self.config.rotary_interleaved, scaling_factor=self.config.rotary_scaling_factor, original_max_position_embeddings=self.config.original_max_position_embeddings, beta_fast=self.config.beta_fast, From 69c0f4eeab596a74f2a269c04c5aee66c68741e9 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 25 Mar 2026 19:12:30 +0800 Subject: [PATCH 11/12] add tests --- .../transformer/test_dsa_attention.py | 897 ++++++++++++++++++ 1 file changed, 897 insertions(+) diff --git a/tests/single_card_tests/transformer/test_dsa_attention.py b/tests/single_card_tests/transformer/test_dsa_attention.py index 7b6e838f6..a91a0e546 100644 --- a/tests/single_card_tests/transformer/test_dsa_attention.py +++ b/tests/single_card_tests/transformer/test_dsa_attention.py @@ -23,15 +23,18 @@ """ import unittest +from unittest.mock import MagicMock, patch import paddle from paddlefleet.transformer.dot_product_attention import DotProductAttention from paddlefleet.transformer.dsa_attention import ( DSAIndexerLossAutoScaler, + DSAIndexerLossLoggingHelper, FusedDSAIndexerLoss, Indexer, MLASelfAttentionWithDSA, + _bwd_fused_indexer_loss, _compute_dsa_indexer_loss, _compute_index_scores_fused, _unfused_dsa_attention, @@ -762,5 +765,899 @@ def test_indexer_params_have_grad(self): ) +# =========================================================================== +# Layer 5: DSAIndexerLossLoggingHelper tests +# =========================================================================== +class TestDSAIndexerLossLoggingHelperSaveLoss(unittest.TestCase): + """Tests for DSAIndexerLossLoggingHelper.save_loss_to_tracker.""" + + def setUp(self): + DSAIndexerLossLoggingHelper.tracker = {} + + def tearDown(self): + DSAIndexerLossLoggingHelper.tracker = {} + + def test_save_loss_initializes_values(self): + """First call should create the 'values' tensor with correct size.""" + loss = paddle.to_tensor(0.5, 
dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=4 + ) + self.assertIn("values", DSAIndexerLossLoggingHelper.tracker) + self.assertEqual( + list(DSAIndexerLossLoggingHelper.tracker["values"].shape), [4] + ) + + def test_save_loss_accumulates(self): + """Multiple saves to the same layer should accumulate.""" + loss1 = paddle.to_tensor(0.5, dtype="float32") + loss2 = paddle.to_tensor(0.3, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss1, layer_number=1, num_layers=4 + ) + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss2, layer_number=1, num_layers=4 + ) + self.assertTrue( + paddle.allclose( + DSAIndexerLossLoggingHelper.tracker["values"][0], + paddle.to_tensor(0.8, dtype="float32"), + atol=1e-5, + ) + ) + + def test_save_loss_different_layers(self): + """Saving to different layers puts values in correct positions.""" + loss1 = paddle.to_tensor(1.0, dtype="float32") + loss2 = paddle.to_tensor(2.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss1, layer_number=1, num_layers=3 + ) + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss2, layer_number=3, num_layers=3 + ) + values = DSAIndexerLossLoggingHelper.tracker["values"] + self.assertAlmostEqual(values[0].item(), 1.0, places=5) + self.assertAlmostEqual(values[1].item(), 0.0, places=5) + self.assertAlmostEqual(values[2].item(), 2.0, places=5) + + def test_save_loss_none_layer_number_noop(self): + """layer_number=None should be a no-op.""" + loss = paddle.to_tensor(1.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=None, num_layers=4 + ) + self.assertNotIn("values", DSAIndexerLossLoggingHelper.tracker) + + def test_save_loss_stores_groups(self): + """reduce_group and avg_group should be stored in the tracker.""" + loss = paddle.to_tensor(0.1, dtype="float32") + mock_reduce = MagicMock() + mock_avg = MagicMock() + 
DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, + layer_number=1, + num_layers=2, + reduce_group=mock_reduce, + avg_group=mock_avg, + ) + self.assertIs( + DSAIndexerLossLoggingHelper.tracker["reduce_group"], mock_reduce + ) + self.assertIs( + DSAIndexerLossLoggingHelper.tracker["avg_group"], mock_avg + ) + + +class TestDSAIndexerLossLoggingHelperClean(unittest.TestCase): + """Tests for DSAIndexerLossLoggingHelper.clean_loss_in_tracker.""" + + def setUp(self): + DSAIndexerLossLoggingHelper.tracker = {} + + def tearDown(self): + DSAIndexerLossLoggingHelper.tracker = {} + + def test_clean_zeros_values(self): + """Clean should zero out the values tensor.""" + loss = paddle.to_tensor(1.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=3 + ) + DSAIndexerLossLoggingHelper.clean_loss_in_tracker() + values = DSAIndexerLossLoggingHelper.tracker["values"] + self.assertTrue(paddle.allclose(values, paddle.zeros([3]))) + + def test_clean_resets_groups(self): + """Clean should set reduce_group and avg_group to None.""" + loss = paddle.to_tensor(1.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, + layer_number=1, + num_layers=2, + reduce_group=MagicMock(), + avg_group=MagicMock(), + ) + DSAIndexerLossLoggingHelper.clean_loss_in_tracker() + self.assertIsNone(DSAIndexerLossLoggingHelper.tracker["reduce_group"]) + self.assertIsNone(DSAIndexerLossLoggingHelper.tracker["avg_group"]) + + def test_clean_empty_tracker_noop(self): + """Clean on empty tracker should not raise.""" + DSAIndexerLossLoggingHelper.clean_loss_in_tracker() + self.assertIsNone( + DSAIndexerLossLoggingHelper.tracker.get("reduce_group") + ) + + +class TestDSAIndexerLossLoggingHelperReduce(unittest.TestCase): + """Tests for DSAIndexerLossLoggingHelper.reduce_loss_in_tracker.""" + + def setUp(self): + DSAIndexerLossLoggingHelper.tracker = {} + + def tearDown(self): + DSAIndexerLossLoggingHelper.tracker = {} + + def 
test_reduce_empty_tracker_noop(self): + """Reduce with no 'values' should be a no-op.""" + DSAIndexerLossLoggingHelper.reduce_loss_in_tracker() + # Should not raise + + @patch("paddlefleet.transformer.dsa_attention.parallel_state") + def test_reduce_no_distributed_groups(self, mock_ps): + """Reduce with no distributed groups should keep values unchanged.""" + mock_ps.get_pipeline_model_parallel_group.return_value = None + mock_ps.get_data_parallel_group.return_value = None + + loss = paddle.to_tensor(2.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=2 + ) + original_values = DSAIndexerLossLoggingHelper.tracker["values"].clone() + DSAIndexerLossLoggingHelper.reduce_loss_in_tracker() + self.assertTrue( + paddle.allclose( + DSAIndexerLossLoggingHelper.tracker["values"], + original_values, + ) + ) + + @patch("paddlefleet.transformer.dsa_attention.parallel_state") + def test_reduce_with_pp_group(self, mock_ps): + """Reduce with PP group should call all_reduce.""" + pp_group = MagicMock() + pp_group.nranks = 2 + mock_ps.get_pipeline_model_parallel_group.return_value = pp_group + mock_ps.get_data_parallel_group.return_value = None + + loss = paddle.to_tensor(1.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=2 + ) + with patch("paddle.distributed.all_reduce") as mock_all_reduce: + DSAIndexerLossLoggingHelper.reduce_loss_in_tracker() + mock_all_reduce.assert_called_once() + + @patch("paddlefleet.transformer.dsa_attention.parallel_state") + def test_reduce_with_dp_group(self, mock_ps): + """Reduce with DP group should call all_reduce and divide by nranks.""" + mock_ps.get_pipeline_model_parallel_group.return_value = None + dp_group = MagicMock() + dp_group.nranks = 4 + mock_ps.get_data_parallel_group.return_value = dp_group + + loss = paddle.to_tensor(4.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, 
num_layers=2 + ) + with patch("paddle.distributed.all_reduce"): + DSAIndexerLossLoggingHelper.reduce_loss_in_tracker() + # After DP reduce, values should be divided by nranks + # (all_reduce is mocked, so actual value won't change, but the + # division path is exercised) + + @patch("paddlefleet.transformer.dsa_attention.parallel_state") + def test_reduce_with_reduce_group(self, mock_ps): + """Reduce with TP reduce_group should call all_reduce.""" + mock_ps.get_pipeline_model_parallel_group.return_value = None + mock_ps.get_data_parallel_group.return_value = None + + loss = paddle.to_tensor(1.0, dtype="float32") + reduce_group = MagicMock() + reduce_group.nranks = 2 + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, + layer_number=1, + num_layers=2, + reduce_group=reduce_group, + ) + with patch("paddle.distributed.all_reduce") as mock_all_reduce: + DSAIndexerLossLoggingHelper.reduce_loss_in_tracker() + mock_all_reduce.assert_called_once() + + @patch("paddlefleet.transformer.dsa_attention.parallel_state") + def test_reduce_with_avg_group(self, mock_ps): + """Reduce with avg_group should call all_reduce and divide by nranks.""" + mock_ps.get_pipeline_model_parallel_group.return_value = None + mock_ps.get_data_parallel_group.return_value = None + + avg_group = MagicMock() + avg_group.nranks = 3 + loss = paddle.to_tensor(3.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, + layer_number=1, + num_layers=2, + avg_group=avg_group, + ) + with patch("paddle.distributed.all_reduce") as mock_all_reduce: + DSAIndexerLossLoggingHelper.reduce_loss_in_tracker() + mock_all_reduce.assert_called_once() + + @patch("paddlefleet.transformer.dsa_attention.parallel_state") + def test_reduce_pp_group_single_rank_skipped(self, mock_ps): + """PP group with nranks=1 should not trigger all_reduce.""" + pp_group = MagicMock() + pp_group.nranks = 1 + mock_ps.get_pipeline_model_parallel_group.return_value = pp_group + 
mock_ps.get_data_parallel_group.return_value = None + + loss = paddle.to_tensor(1.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=2 + ) + with patch("paddle.distributed.all_reduce") as mock_all_reduce: + DSAIndexerLossLoggingHelper.reduce_loss_in_tracker() + mock_all_reduce.assert_not_called() + + @patch("paddlefleet.transformer.dsa_attention.parallel_state") + def test_reduce_dp_group_single_rank_skipped(self, mock_ps): + """DP group with nranks=1 should not trigger all_reduce.""" + mock_ps.get_pipeline_model_parallel_group.return_value = None + dp_group = MagicMock() + dp_group.nranks = 1 + mock_ps.get_data_parallel_group.return_value = dp_group + + loss = paddle.to_tensor(1.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=2 + ) + with patch("paddle.distributed.all_reduce") as mock_all_reduce: + DSAIndexerLossLoggingHelper.reduce_loss_in_tracker() + mock_all_reduce.assert_not_called() + + +class TestDSAIndexerLossLoggingHelperTrackMetrics(unittest.TestCase): + """Tests for DSAIndexerLossLoggingHelper.track_indexer_metrics.""" + + def setUp(self): + DSAIndexerLossLoggingHelper.tracker = {} + + def tearDown(self): + DSAIndexerLossLoggingHelper.tracker = {} + + @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker") + def test_track_metrics_empty_tracker_noop(self, mock_reduce): + """With no values, track_indexer_metrics should be a no-op after reduce.""" + DSAIndexerLossLoggingHelper.track_indexer_metrics( + loss_scale=1.0, iteration=10 + ) + mock_reduce.assert_called_once() + + @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker") + def test_track_metrics_logs_loss(self, mock_reduce): + """track_indexer_metrics should log the averaged indexer loss.""" + loss = paddle.to_tensor(2.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=2 + ) + with self.assertLogs( + 
"paddlefleet.transformer.dsa_attention", level="INFO" + ) as cm: + DSAIndexerLossLoggingHelper.track_indexer_metrics( + loss_scale=1.0, iteration=42 + ) + log_output = "\n".join(cm.output) + self.assertIn("42", log_output) + self.assertIn("indexer loss", log_output) + + @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker") + def test_track_metrics_updates_total_loss_dict(self, mock_reduce): + """track_indexer_metrics should add 'indexer loss' to total_loss_dict.""" + loss = paddle.to_tensor(4.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=2 + ) + total_loss_dict = {} + DSAIndexerLossLoggingHelper.track_indexer_metrics( + loss_scale=0.5, iteration=1, total_loss_dict=total_loss_dict + ) + self.assertIn("indexer loss", total_loss_dict) + # loss=4.0 at layer 1 only, num_layers=2 + # avg = (4.0 * 0.5) / 2 = 1.0 + expected = 1.0 + self.assertAlmostEqual( + total_loss_dict["indexer loss"].item(), expected, places=4 + ) + + @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker") + def test_track_metrics_accumulates_total_loss_dict(self, mock_reduce): + """Calling track_indexer_metrics twice should accumulate in total_loss_dict.""" + loss1 = paddle.to_tensor(2.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss1, layer_number=1, num_layers=1 + ) + total_loss_dict = {} + DSAIndexerLossLoggingHelper.track_indexer_metrics( + loss_scale=1.0, iteration=1, total_loss_dict=total_loss_dict + ) + first_value = total_loss_dict["indexer loss"].item() + + # Save and track again + loss2 = paddle.to_tensor(3.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss2, layer_number=1, num_layers=1 + ) + DSAIndexerLossLoggingHelper.track_indexer_metrics( + loss_scale=1.0, iteration=2, total_loss_dict=total_loss_dict + ) + self.assertAlmostEqual( + total_loss_dict["indexer loss"].item(), + first_value + 3.0, + places=4, + ) + + 
@patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker") + def test_track_metrics_with_writer(self, mock_reduce): + """track_indexer_metrics should call writer.add_scalar when provided.""" + loss = paddle.to_tensor(1.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=1 + ) + mock_writer = MagicMock() + DSAIndexerLossLoggingHelper.track_indexer_metrics( + loss_scale=1.0, iteration=100, writer=mock_writer + ) + mock_writer.add_scalar.assert_called_once() + args = mock_writer.add_scalar.call_args + self.assertEqual(args[0][0], "indexer loss") + self.assertEqual(args[0][2], 100) + + @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker") + def test_track_metrics_cleans_tracker(self, mock_reduce): + """track_indexer_metrics should clean the tracker after logging.""" + loss = paddle.to_tensor(1.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=2 + ) + DSAIndexerLossLoggingHelper.track_indexer_metrics( + loss_scale=1.0, iteration=1 + ) + # After track_indexer_metrics, values should be zeroed + values = DSAIndexerLossLoggingHelper.tracker["values"] + self.assertTrue(paddle.allclose(values, paddle.zeros([2]))) + + @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker") + def test_track_metrics_loss_scale_applied(self, mock_reduce): + """loss_scale should be multiplied into the loss values.""" + loss = paddle.to_tensor(6.0, dtype="float32") + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss, layer_number=1, num_layers=1 + ) + total_loss_dict = {} + DSAIndexerLossLoggingHelper.track_indexer_metrics( + loss_scale=0.25, iteration=1, total_loss_dict=total_loss_dict + ) + # 6.0 * 0.25 / 1 = 1.5 + self.assertAlmostEqual( + total_loss_dict["indexer loss"].item(), 1.5, places=4 + ) + + +# =========================================================================== +# Layer 6: Additional coverage for Indexer and helper functions 
+# =========================================================================== +class TestIndexerForward(unittest.TestCase): + """Test Indexer.forward (the combined forward_before_topk + compute_index_scores).""" + + def setUp(self): + self.config = _create_dsa_config() + self.indexer = Indexer(self.config, layer_number=1) + self.b = 2 + self.s = 16 + + def _prepare_indexer_bf16(self): + self.indexer.wq_b = self.indexer.wq_b.to(dtype="bfloat16") + self.indexer.wk = self.indexer.wk.to(dtype="bfloat16") + self.indexer.k_norm = self.indexer.k_norm.to(dtype="bfloat16") + + def test_forward_returns_scores_and_indices(self): + """Indexer.forward should return (index_scores, topk_indices).""" + self._prepare_indexer_bf16() + hidden = paddle.randn([self.b, self.s, self.config.hidden_size]).cast( + "bfloat16" + ) + q_latent = paddle.randn([self.b, self.s, self.config.q_lora_rank]).cast( + "bfloat16" + ) + causal = paddle.triu( + paddle.full([self.s, self.s], float("-inf"), dtype="float32"), + diagonal=1, + ) + mask = causal.unsqueeze(0).unsqueeze(0) + + index_scores, topk_indices = self.indexer.forward( + hidden, q_latent, freqs=None, attention_mask=mask, mscale=1.0 + ) + self.assertEqual(list(index_scores.shape), [self.b, self.s, self.s]) + self.assertEqual( + list(topk_indices.shape), + [self.b, self.s, self.config.index_topk], + ) + + def test_forward_no_mask(self): + """Indexer.forward should work with mask=None.""" + self._prepare_indexer_bf16() + hidden = paddle.randn([self.b, self.s, self.config.hidden_size]).cast( + "bfloat16" + ) + q_latent = paddle.randn([self.b, self.s, self.config.q_lora_rank]).cast( + "bfloat16" + ) + index_scores, topk_indices = self.indexer.forward( + hidden, q_latent, freqs=None, attention_mask=None, mscale=1.0 + ) + self.assertEqual(list(index_scores.shape), [self.b, self.s, self.s]) + + +class TestIndexerComputeScoresWithMask(unittest.TestCase): + """Test mask handling in Indexer.compute_index_scores.""" + + def setUp(self): + self.config 
= _create_dsa_config(index_topk=4) + self.indexer = Indexer(self.config, layer_number=1) + self.b = 2 + self.s = 8 + + def test_mask_zeros_future_positions(self): + """With causal mask, topk indices should not exceed current position.""" + q = paddle.randn( + [ + self.b, + self.s, + self.config.index_n_heads, + self.config.index_head_dim, + ] + ) + k = paddle.randn([self.b, self.s, self.config.index_head_dim]) + weights = paddle.abs( + paddle.randn([self.b, self.s, self.config.index_n_heads]) + ) + causal = paddle.triu( + paddle.full([self.s, self.s], float("-inf"), dtype="float32"), + diagonal=1, + ) + mask = causal.unsqueeze(0).unsqueeze(0) + + index_scores, topk_indices = self.indexer.compute_index_scores( + q, k, weights, mask=mask + ) + # For a later position (e.g., position 5), all topk indices should be <= 5 + pos = 5 + later_token_indices = topk_indices[:, pos, :] + self.assertTrue( + (later_token_indices <= pos).all().item(), + f"topk indices at position {pos} exceed causal bound", + ) + + +class TestComputeIndexScoresFusedAdditional(unittest.TestCase): + """Additional tests for _compute_index_scores_fused.""" + + def test_matches_unfused(self): + """Fused scores should match unfused Indexer.compute_index_scores logic.""" + sq, b, h, d = 4, 2, 2, 16 + q = paddle.randn([sq, b, h, d], dtype="float32") + weights = paddle.randn([sq, b, h], dtype="float32") + k = paddle.randn([sq, b, d], dtype="float32") + + fused_scores = _compute_index_scores_fused(q, weights, k) # [b, sq, sk] + + # Manual unfused computation + scores = paddle.einsum("sbhd,tbd->sbht", q, k) + relu_scores = paddle.nn.functional.relu(scores) + weighted = relu_scores * weights.unsqueeze(-1) + summed = weighted.sum(axis=2) # [sq, b, sk] + unfused_scores = summed.transpose([1, 0, 2]) # [b, sq, sk] + + self.assertTrue( + paddle.allclose(fused_scores, unfused_scores, atol=1e-5), + "Fused and unfused scores do not match", + ) + + +class TestComputeDSAIndexerLossWithMask(unittest.TestCase): + 
"""Additional edge cases for _compute_dsa_indexer_loss.""" + + def test_loss_is_non_negative(self): + """KL divergence should be non-negative.""" + sq, sk, b, np, hn = 4, 4, 2, 2, 16 + topk = 2 + index_scores = paddle.randn([b, sq, sk], dtype="float32") + topk_indices = paddle.randint(0, sk, [b, sq, topk]).cast("int64") + query = paddle.randn([sq, b, np, hn], dtype="float32") + key = paddle.randn([sk, b, np, hn], dtype="float32") + + loss = _compute_dsa_indexer_loss( + index_scores, + topk_indices, + query, + key, + softmax_scale=hn**-0.5, + loss_coeff=1.0, + sparse_loss=False, + tp_group=None, + ) + # KL divergence can be slightly negative due to numerical noise, + # but should be very close to non-negative + self.assertTrue( + loss.item() > -0.1, f"Loss is too negative: {loss.item()}" + ) + self.assertTrue(paddle.isfinite(loss).item()) + + +class TestBwdFusedIndexerLoss(unittest.TestCase): + """Tests for _bwd_fused_indexer_loss manual backward.""" + + def test_backward_shapes(self): + """Manual backward should return gradients with correct shapes.""" + sq, b, h, d = 4, 2, 2, 16 + np, hn = 4, 32 + topk = 2 + q = paddle.randn([sq, b, h, d], dtype="float32") + weights = paddle.randn([sq, b, h], dtype="float32") + k = paddle.randn([sq, b, d], dtype="float32") + query = paddle.randn([sq, b, np, hn], dtype="float32") + key = paddle.randn([sq, b, np, hn], dtype="float32") + topk_indices = paddle.randint(0, sq, [b, sq, topk]).cast("int64") + grad_loss = paddle.to_tensor(1.0, dtype="float32") + + grad_q, grad_weights, grad_k = _bwd_fused_indexer_loss( + q, + weights, + k, + query, + key, + topk_indices, + softmax_scale=hn**-0.5, + loss_coeff=1.0, + sparse_loss=False, + grad_loss=grad_loss, + tp_group=None, + ) + self.assertEqual(list(grad_q.shape), [sq, b, h, d]) + self.assertEqual(list(grad_weights.shape), [sq, b, h]) + self.assertEqual(list(grad_k.shape), [sq, b, d]) + + def test_backward_finite(self): + """All gradients from manual backward should be finite.""" + sq, 
b, h, d = 4, 2, 2, 16 + np, hn = 4, 32 + topk = 2 + q = paddle.randn([sq, b, h, d], dtype="float32") + weights = paddle.randn([sq, b, h], dtype="float32") + k = paddle.randn([sq, b, d], dtype="float32") + query = paddle.randn([sq, b, np, hn], dtype="float32") + key = paddle.randn([sq, b, np, hn], dtype="float32") + topk_indices = _make_causal_topk_indices(b, sq, sq, topk) + grad_loss = paddle.to_tensor(1.0, dtype="float32") + + grad_q, grad_weights, grad_k = _bwd_fused_indexer_loss( + q, + weights, + k, + query, + key, + topk_indices, + softmax_scale=hn**-0.5, + loss_coeff=1.0, + sparse_loss=False, + grad_loss=grad_loss, + tp_group=None, + ) + self.assertTrue(paddle.isfinite(grad_q).all().item()) + self.assertTrue(paddle.isfinite(grad_weights).all().item()) + self.assertTrue(paddle.isfinite(grad_k).all().item()) + + def test_backward_with_sparse_loss(self): + """Manual backward should work with sparse_loss=True.""" + sq, b, h, d = 4, 2, 2, 16 + np, hn = 4, 32 + topk = 2 + q = paddle.randn([sq, b, h, d], dtype="float32") + weights = paddle.randn([sq, b, h], dtype="float32") + k = paddle.randn([sq, b, d], dtype="float32") + query = paddle.randn([sq, b, np, hn], dtype="float32") + key = paddle.randn([sq, b, np, hn], dtype="float32") + topk_indices = _make_causal_topk_indices(b, sq, sq, topk) + grad_loss = paddle.to_tensor(1.0, dtype="float32") + + grad_q, grad_weights, grad_k = _bwd_fused_indexer_loss( + q, + weights, + k, + query, + key, + topk_indices, + softmax_scale=hn**-0.5, + loss_coeff=1.0, + sparse_loss=True, + grad_loss=grad_loss, + tp_group=None, + ) + self.assertTrue(paddle.isfinite(grad_q).all().item()) + self.assertTrue(paddle.isfinite(grad_weights).all().item()) + self.assertTrue(paddle.isfinite(grad_k).all().item()) + + +class TestFusedDSAIndexerLossNoMask(unittest.TestCase): + """Test FusedDSAIndexerLoss with no mask (mask=None path).""" + + def setUp(self): + self.sq, self.sk = 8, 8 + self.b = 2 + self.h, self.d = 4, 32 + self.np, self.hn = 4, 64 + 
self.topk = 4
+        self.softmax_scale = self.hn**-0.5
+
+    def test_forward_no_mask(self):
+        q = paddle.randn([self.sq, self.b, self.h, self.d], dtype="float32")
+        q.stop_gradient = False
+        weights = paddle.randn([self.sq, self.b, self.h], dtype="float32")
+        weights.stop_gradient = False
+        k = paddle.randn([self.sk, self.b, self.d], dtype="float32")
+        k.stop_gradient = False
+        query = paddle.randn(
+            [self.sq, self.b, self.np, self.hn], dtype="float32"
+        )
+        key = paddle.randn([self.sk, self.b, self.np, self.hn], dtype="float32")
+
+        loss = FusedDSAIndexerLoss.apply(
+            q,
+            weights,
+            k,
+            query,
+            key,
+            self.softmax_scale,
+            self.topk,
+            1.0,
+            None,  # no mask
+            False,
+            None,
+        )
+        self.assertEqual(loss.shape, [])
+        self.assertTrue(paddle.isfinite(loss).item())
+
+
+class TestFusedDSAIndexerLossSparseLoss(unittest.TestCase):
+    """Test FusedDSAIndexerLoss with sparse_loss=True."""
+
+    def setUp(self):
+        self.sq, self.sk = 8, 8
+        self.b = 2
+        self.h, self.d = 4, 32
+        self.np, self.hn = 4, 64
+        self.topk = 4
+        self.softmax_scale = self.hn**-0.5
+
+    def test_forward_sparse_loss(self):
+        q = paddle.randn([self.sq, self.b, self.h, self.d], dtype="float32")
+        q.stop_gradient = False
+        weights = paddle.randn([self.sq, self.b, self.h], dtype="float32")
+        weights.stop_gradient = False
+        k = paddle.randn([self.sk, self.b, self.d], dtype="float32")
+        k.stop_gradient = False
+        query = paddle.randn(
+            [self.sq, self.b, self.np, self.hn], dtype="float32"
+        )
+        key = paddle.randn([self.sk, self.b, self.np, self.hn], dtype="float32")
+        causal = paddle.triu(
+            paddle.full([self.sq, self.sk], float("-inf"), dtype="float32"),
+            diagonal=1,
+        )
+        mask = causal.unsqueeze(0).unsqueeze(0)
+
+        loss = FusedDSAIndexerLoss.apply(
+            q,
+            weights,
+            k,
+            query,
+            key,
+            self.softmax_scale,
+            self.topk,
+            1.0,
+            mask,
+            True,  # sparse_loss=True
+            None,  # tp_group
+        )
+        self.assertEqual(loss.shape, [])
+        self.assertTrue(paddle.isfinite(loss).item())
+
+    def test_backward_sparse_loss(self):
+        q = 
paddle.randn([self.sq, self.b, self.h, self.d], dtype="float32") + q.stop_gradient = False + weights = paddle.randn([self.sq, self.b, self.h], dtype="float32") + weights.stop_gradient = False + k = paddle.randn([self.sk, self.b, self.d], dtype="float32") + k.stop_gradient = False + query = paddle.randn( + [self.sq, self.b, self.np, self.hn], dtype="float32" + ) + key = paddle.randn([self.sk, self.b, self.np, self.hn], dtype="float32") + causal = paddle.triu( + paddle.full([self.sq, self.sk], float("-inf"), dtype="float32"), + diagonal=1, + ) + mask = causal.unsqueeze(0).unsqueeze(0) + + loss = FusedDSAIndexerLoss.apply( + q, + weights, + k, + query, + key, + self.softmax_scale, + self.topk, + 1.0, + mask, + True, + None, + ) + loss.backward() + self.assertIsNotNone(q.grad) + self.assertIsNotNone(weights.grad) + self.assertIsNotNone(k.grad) + self.assertTrue(paddle.isfinite(q.grad).all().item()) + self.assertTrue(paddle.isfinite(weights.grad).all().item()) + self.assertTrue(paddle.isfinite(k.grad).all().item()) + + +class TestDSAIndexerLossAutoScalerAdditional(unittest.TestCase): + """Additional tests for DSAIndexerLossAutoScaler edge cases.""" + + def _make_non_leaf_output(self, shape): + x = paddle.randn(shape) + x.stop_gradient = False + return x + 0 + + def test_backward_without_loss_scale(self): + """When _main_loss_backward_scale is None, backward should use ones.""" + DSAIndexerLossAutoScaler._main_loss_backward_scale = None + output = self._make_non_leaf_output([2, 4]) + indexer_loss = paddle.to_tensor(1.0, dtype="float32") + indexer_loss.stop_gradient = False + indexer_loss = indexer_loss * 1.0 + + result = DSAIndexerLossAutoScaler.apply(output, indexer_loss) + loss = result.sum() + loss.backward() + # Should not error; the None-scale path creates ones + self.assertTrue(True) + + def test_set_loss_scale_stores_value(self): + """set_loss_scale should store the scale tensor.""" + scale = paddle.to_tensor(3.14, dtype="float32") + 
DSAIndexerLossAutoScaler.set_loss_scale(scale) + stored = DSAIndexerLossAutoScaler._main_loss_backward_scale + self.assertIsNotNone(stored) + self.assertAlmostEqual(stored.item(), 3.14, places=4) + DSAIndexerLossAutoScaler._main_loss_backward_scale = None + + def test_forward_preserves_value(self): + """Forward should return output with the same values (passthrough).""" + x = paddle.randn([3, 5]) + x.stop_gradient = False + output = x + 0 + indexer_loss = paddle.to_tensor(0.0, dtype="float32") + indexer_loss.stop_gradient = False + indexer_loss = indexer_loss + 0 + + result = DSAIndexerLossAutoScaler.apply(output, indexer_loss) + self.assertTrue(paddle.allclose(result, output, atol=1e-7)) + + +class TestMLASelfAttentionWithDSASparseLoss(unittest.TestCase): + """Integration test for MLASelfAttentionWithDSA with sparse_loss enabled.""" + + def setUp(self): + self.config = _create_dsa_config(indexer_use_sparse_loss=True) + self.micro_batch_size = 2 + self.sequence_length = 32 + + def _build_model(self, config=None): + cfg = config or self.config + model = MLASelfAttentionWithDSA( + cfg, + _create_sublayers_spec(), + layer_number=1, + attn_mask_type=AttnMaskType.causal, + ) + model = model.to(dtype="bfloat16") + model.indexer.weights_proj = model.indexer.weights_proj.to( + dtype="float32" + ) + return model + + def _make_hidden(self, dtype="bfloat16"): + return paddle.randn( + [ + self.micro_batch_size, + self.sequence_length, + self.config.hidden_size, + ] + ).cast(dtype) + + def test_forward_with_sparse_loss(self): + model = self._build_model() + model.train() + hidden = self._make_hidden() + output, bias = model(hidden, attention_mask=None) + self.assertEqual(output.shape[0], self.micro_batch_size) + self.assertEqual(output.shape[1], self.sequence_length) + + def test_backward_with_sparse_loss(self): + model = self._build_model() + model.train() + hidden = self._make_hidden() + hidden.stop_gradient = False + output, bias = model(hidden, attention_mask=None) + loss = 
output.cast("float32").sum() + loss.backward() + self.assertIsNotNone(hidden.grad) + + +class TestMLASelfAttentionWithDSAZeroLossCoeff(unittest.TestCase): + """Test MLASelfAttentionWithDSA with indexer_loss_coeff=0.""" + + def setUp(self): + self.config = _create_dsa_config(indexer_loss_coeff=0.0) + self.micro_batch_size = 2 + self.sequence_length = 32 + + def _build_model(self): + model = MLASelfAttentionWithDSA( + self.config, + _create_sublayers_spec(), + layer_number=1, + attn_mask_type=AttnMaskType.causal, + ) + model = model.to(dtype="bfloat16") + model.indexer.weights_proj = model.indexer.weights_proj.to( + dtype="float32" + ) + return model + + def test_forward_zero_loss_coeff(self): + """With loss_coeff=0, the model should still compute loss but skip logging.""" + model = self._build_model() + model.train() + hidden = paddle.randn( + [ + self.micro_batch_size, + self.sequence_length, + self.config.hidden_size, + ] + ).cast("bfloat16") + DSAIndexerLossLoggingHelper.tracker = {} + output, bias = model(hidden, attention_mask=None) + self.assertEqual(output.shape[0], self.micro_batch_size) + # loss_coeff=0 means save_loss_to_tracker is NOT called (coeff <= 0 check) + # Actually the code checks `if self.dsa_indexer_loss_coeff > 0` + self.assertNotIn("values", DSAIndexerLossLoggingHelper.tracker) + DSAIndexerLossLoggingHelper.tracker = {} + + if __name__ == "__main__": unittest.main() From 3fe98d9900cb49153ff94fa154b8eaa7c227180e Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Fri, 27 Mar 2026 19:30:14 +0800 Subject: [PATCH 12/12] rename config name --- src/paddlefleet/transformer/dsa_attention.py | 20 +++--- .../transformer/transformer_config.py | 71 ++++++++++++++++--- .../transformer/test_dsa_attention.py | 47 ++++++------ 3 files changed, 98 insertions(+), 40 deletions(-) diff --git a/src/paddlefleet/transformer/dsa_attention.py b/src/paddlefleet/transformer/dsa_attention.py index e047e7e78..6644b9fcf 100644 --- 
a/src/paddlefleet/transformer/dsa_attention.py +++ b/src/paddlefleet/transformer/dsa_attention.py @@ -201,11 +201,11 @@ def __init__(self, config: TransformerConfig, layer_number: int): super().__init__() self.config = config - self.n_heads = config.index_n_heads - self.head_dim = config.index_head_dim + self.n_heads = config.dsa_index_n_heads + self.head_dim = config.dsa_index_head_dim self.rope_head_dim = config.qk_rope_head_dim self.nope_head_dim = self.head_dim - self.rope_head_dim - self.index_topk = config.index_topk + self.index_topk = config.dsa_index_topk self.softmax_scale = self.head_dim**-0.5 self.layer_number = layer_number @@ -236,13 +236,17 @@ def __init__(self, config: TransformerConfig, layer_number: int): def _apply_rope( self, x: Tensor, freqs: Tensor, mscale: float = 1.0 ) -> Tensor: - """Apply non-interleaved RoPE to the pe portion of x. + """Apply RoPE to the pe portion of x. Split order: [pe | nope], matching DeepSeek-V3.2 Indexer (model.py:462). + RoPE format is controlled by config.dsa_indexer_rotary_interleaved: + - False (default): non-interleaved RoPE with half-head frequencies [θ₁,θ₂,...,θ₁,θ₂,...] + - True: interleaved RoPE with paired frequencies [θ₁,θ₁,θ₂,θ₂,...] 
+ Args: x: [..., head_dim] (rope_dim + nope_dim) - freqs: RoPE frequencies (must be half-half format for non-interleaved) + freqs: RoPE frequencies mscale: YaRN concentration factor (1.0 for plain RoPE, ~1.37 for YaRN) """ x_pe = x[..., : self.rope_head_dim] @@ -250,7 +254,7 @@ def _apply_rope( x_pe = _apply_rotary_pos_emb_bshd( x_pe, freqs, - rotary_interleaved=False, + rotary_interleaved=self.config.dsa_indexer_rotary_interleaved, multi_latent_attention=False, mscale=mscale, ) @@ -927,10 +931,10 @@ def __init__( # DSA loss config self.dsa_indexer_loss_coeff = getattr( - config, "indexer_loss_coeff", None + config, "dsa_indexer_loss_coeff", None ) self.dsa_indexer_use_sparse_loss = getattr( - config, "indexer_use_sparse_loss", False + config, "dsa_indexer_use_sparse_loss", False ) def forward( diff --git a/src/paddlefleet/transformer/transformer_config.py b/src/paddlefleet/transformer/transformer_config.py index 9a76e4c73..b72225db8 100644 --- a/src/paddlefleet/transformer/transformer_config.py +++ b/src/paddlefleet/transformer/transformer_config.py @@ -554,21 +554,74 @@ class TransformerConfig(ModelParallelConfig): # DSA (DeepSeek Sparse Attention) #################### - index_n_heads: int | None = None + dsa_index_n_heads: int | None = None """Number of DSA Indexer heads. None disables DSA; non-None activates - DeepSeek V3.2 sparse attention path.""" + DeepSeek V3.2 sparse attention path. - index_head_dim: int = 128 - """Per-head dimension for Indexer Q/K vectors.""" + Note: This field corresponds to the HuggingFace config.json field "index_n_heads". + The mapping from HuggingFace field name to PaddleFleet internal field name is handled + by TransformerConfig.transform_rules. + """ + + dsa_index_head_dim: int = 128 + """Per-head dimension for Indexer Q/K vectors. + + Note: This field corresponds to the HuggingFace config.json field "index_head_dim". 
+    The mapping from HuggingFace field name to PaddleFleet internal field name is handled
+    by TransformerConfig.transform_rules.
+    """
+
+    dsa_index_topk: int = 2048
+    """Number of token positions selected by Indexer per query token.
+
+    Note: This field corresponds to the HuggingFace config.json field "index_topk".
+    The mapping from HuggingFace field name to PaddleFleet internal field name is handled
+    by TransformerConfig.transform_rules.
+    """
+
+    dsa_indexer_loss_coeff: float | None = None
+    """KL loss coefficient for DSA Indexer training. None disables the KL loss.
+
+    Note: This field corresponds to the HuggingFace config.json field "indexer_loss_coeff".
+    The mapping from HuggingFace field name to PaddleFleet internal field name is handled
+    by TransformerConfig.transform_rules.
+    """
+
+    dsa_indexer_use_sparse_loss: bool = False
+    """Whether to restrict DSA KL loss to top-k positions only.
-    index_topk: int = 2048
-    """Number of token positions selected by Indexer per query token."""
+
+    Note: This field corresponds to the HuggingFace config.json field "indexer_use_sparse_loss".
+    The mapping from HuggingFace field name to PaddleFleet internal field name is handled
+    by TransformerConfig.transform_rules.
+    """
+
+    dsa_indexer_rotary_interleaved: bool = False
+    """
+    Whether Indexer uses interleaved Rotary Position Embeddings.
+
+    When False (default), Indexer uses non-interleaved RoPE with
+    half-head frequencies [θ₁,θ₂,...,θ₁,θ₂,...].
+
+    When True, Indexer uses interleaved RoPE with paired frequencies
+    [θ₁,θ₁,θ₂,θ₂,...].
+
+    This allows compatibility with MLA's YaRN RoPE which always generates
+    interleaved frequencies.
+    """
-    indexer_loss_coeff: float | None = None
+    dsa_indexer_loss_coeff: float | None = None
     """KL loss coefficient for DSA Indexer training.
None disables the KL loss.""" - indexer_use_sparse_loss: bool = False - """Whether to restrict DSA KL loss to top-k positions only.""" + # Field name mapping rules: HuggingFace config.json name -> TransformerConfig name + transform_rules = { + # DSA field mapping + "index_n_heads": "dsa_index_n_heads", + "index_head_dim": "dsa_index_head_dim", + "index_topk": "dsa_index_topk", + "indexer_loss_coeff": "dsa_indexer_loss_coeff", + "indexer_use_sparse_loss": "dsa_indexer_use_sparse_loss", + "indexer_rotary_interleaved": "dsa_indexer_rotary_interleaved", + } @classmethod def from_config(cls, config_dict): diff --git a/tests/single_card_tests/transformer/test_dsa_attention.py b/tests/single_card_tests/transformer/test_dsa_attention.py index a91a0e546..45ed2c5a3 100644 --- a/tests/single_card_tests/transformer/test_dsa_attention.py +++ b/tests/single_card_tests/transformer/test_dsa_attention.py @@ -122,11 +122,12 @@ def _create_dsa_config( config.apply_rope_fusion = False # DSA requires unfused RoPE # DSA Indexer fields - config.index_n_heads = index_n_heads - config.index_head_dim = index_head_dim - config.index_topk = index_topk - config.indexer_loss_coeff = indexer_loss_coeff - config.indexer_use_sparse_loss = indexer_use_sparse_loss + config.dsa_index_n_heads = index_n_heads + config.dsa_index_head_dim = index_head_dim + config.dsa_index_topk = index_topk + config.dsa_indexer_loss_coeff = indexer_loss_coeff + config.dsa_indexer_use_sparse_loss = indexer_use_sparse_loss + config.dsa_indexer_rotary_interleaved = False # Test default value # Attention generic fields config.softmax_scale = None @@ -328,17 +329,17 @@ def test_forward_before_topk_shapes(self): [ self.b, self.s, - self.config.index_n_heads, - self.config.index_head_dim, + self.config.dsa_index_n_heads, + self.config.dsa_index_head_dim, ], ) self.assertEqual( list(k.shape), - [self.b, self.s, self.config.index_head_dim], + [self.b, self.s, self.config.dsa_index_head_dim], ) self.assertEqual( 
list(weights.shape), - [self.b, self.s, self.config.index_n_heads], + [self.b, self.s, self.config.dsa_index_n_heads], ) def test_compute_index_scores_shapes(self): @@ -346,19 +347,19 @@ def test_compute_index_scores_shapes(self): [ self.b, self.s, - self.config.index_n_heads, - self.config.index_head_dim, + self.config.dsa_index_n_heads, + self.config.dsa_index_head_dim, ] ) - k = paddle.randn([self.b, self.s, self.config.index_head_dim]) - weights = paddle.randn([self.b, self.s, self.config.index_n_heads]) + k = paddle.randn([self.b, self.s, self.config.dsa_index_head_dim]) + weights = paddle.randn([self.b, self.s, self.config.dsa_index_n_heads]) index_scores, topk_indices = self.indexer.compute_index_scores( q, k, weights, mask=None ) self.assertEqual(list(index_scores.shape), [self.b, self.s, self.s]) self.assertEqual( list(topk_indices.shape), - [self.b, self.s, self.config.index_topk], + [self.b, self.s, self.config.dsa_index_topk], ) def test_topk_in_range(self): @@ -366,12 +367,12 @@ def test_topk_in_range(self): [ self.b, self.s, - self.config.index_n_heads, - self.config.index_head_dim, + self.config.dsa_index_n_heads, + self.config.dsa_index_head_dim, ] ) - k = paddle.randn([self.b, self.s, self.config.index_head_dim]) - weights = paddle.randn([self.b, self.s, self.config.index_n_heads]) + k = paddle.randn([self.b, self.s, self.config.dsa_index_head_dim]) + weights = paddle.randn([self.b, self.s, self.config.dsa_index_n_heads]) _, topk_indices = self.indexer.compute_index_scores( q, k, weights, mask=None ) @@ -1193,7 +1194,7 @@ def test_forward_returns_scores_and_indices(self): self.assertEqual(list(index_scores.shape), [self.b, self.s, self.s]) self.assertEqual( list(topk_indices.shape), - [self.b, self.s, self.config.index_topk], + [self.b, self.s, self.config.dsa_index_topk], ) def test_forward_no_mask(self): @@ -1226,13 +1227,13 @@ def test_mask_zeros_future_positions(self): [ self.b, self.s, - self.config.index_n_heads, - self.config.index_head_dim, 
+ self.config.dsa_index_n_heads, + self.config.dsa_index_head_dim, ] ) - k = paddle.randn([self.b, self.s, self.config.index_head_dim]) + k = paddle.randn([self.b, self.s, self.config.dsa_index_head_dim]) weights = paddle.abs( - paddle.randn([self.b, self.s, self.config.index_n_heads]) + paddle.randn([self.b, self.s, self.config.dsa_index_n_heads]) ) causal = paddle.triu( paddle.full([self.s, self.s], float("-inf"), dtype="float32"),