diff --git a/src/paddlefleet/models/gpt/gpt_layer_specs.py b/src/paddlefleet/models/gpt/gpt_layer_specs.py
index 515ea703f..bb78b983c 100644
--- a/src/paddlefleet/models/gpt/gpt_layer_specs.py
+++ b/src/paddlefleet/models/gpt/gpt_layer_specs.py
@@ -45,6 +45,9 @@
     SelfAttention,
     SelfAttentionSublayersSpec,
 )
+from paddlefleet.transformer.dsa_attention import (
+    MLASelfAttentionWithDSA,
+)
 from paddlefleet.transformer.enums import AttnMaskType
 from paddlefleet.transformer.identity_op import IdentityOp
 from paddlefleet.transformer.mlp import MLP, MLPSublayersSpec
@@ -124,12 +127,20 @@ def get_gpt_layer_local_spec(
 
     if multi_latent_attention:
         assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
+
+        # Decide attention class: DSA variant if dsa_index_n_heads is
+        # configured.  NOTE: the Indexer reads the "dsa_"-prefixed config
+        # fields (dsa_index_n_heads / dsa_index_head_dim / dsa_index_topk),
+        # so the gate must probe the same attribute name — probing the
+        # unprefixed "index_n_heads" would always return None and silently
+        # disable DSA.
+        use_dsa = (
+            config is not None
+            and getattr(config, "dsa_index_n_heads", None) is not None
+        )
+        attn_cls = MLASelfAttentionWithDSA if use_dsa else MLASelfAttention
+
         return LayerSpec(
             layer=transformer_cls,
             sublayers_spec=TransformerLayerSublayersSpec(
                 input_layernorm=layer_norm,
                 self_attn=LayerSpec(
-                    layer=MLASelfAttention,
+                    layer=attn_cls,
                     extra_kwargs={"attn_mask_type": AttnMaskType.causal},
                     sublayers_spec=MLASelfAttentionSublayersSpec(
                         q_proj=backend.column_parallel_linear(),
diff --git a/src/paddlefleet/transformer/dsa_attention.py b/src/paddlefleet/transformer/dsa_attention.py
new file mode 100644
index 000000000..6644b9fcf
--- /dev/null
+++ b/src/paddlefleet/transformer/dsa_attention.py
@@ -0,0 +1,1211 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DeepSeek Sparse Attention (DSA) extension for Multi-Latent Attention. + +This module extends the upstream MLASelfAttention with DSA Indexer support +(DeepSeek V3.2 architecture): + - Indexer: Token scoring module that selects top-k relevant positions + - FusedDSAIndexerLoss: Fused KL-divergence loss with full manual backward + - DSAIndexerLossAutoScaler: Loss scaling helper + - MLASelfAttentionWithDSA: Subclass of MLASelfAttention with DSA integration + +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import paddle +import paddle.nn.functional as F +from paddle import Tensor +from paddle.distributed.fleet.utils import recompute + +from paddlefleet import parallel_state +from paddlefleet.models.common.embeddings.rope_utils import ( + _apply_rotary_pos_emb_bshd, +) +from paddlefleet.tensor_parallel.mappings import ( + gather_from_sequence_parallel_region, +) +from paddlefleet.transformer.enums import AttnMaskType +from paddlefleet.transformer.multi_latent_attention import ( + MLASelfAttention, + MLASelfAttentionSublayersSpec, +) + +if TYPE_CHECKING: + from paddlefleet.process_groups_config import ProcessGroupCollection + from paddlefleet.transformer.transformer_config import TransformerConfig + + +def hadamard_transform(x: Tensor, scale: float = 1.0) -> Tensor: + """Fast Walsh-Hadamard Transform using the butterfly algorithm. + + Pure Paddle implementation, equivalent to: + F.linear(x, hadamard_matrix(dim)) * scale + + Uses O(N log N) butterfly operations instead of O(N^2) matrix multiply. 
+ The Hadamard matrix is symmetric and orthogonal, so backward is the same + transform applied to grad_output (handled automatically by Paddle autograd). + + Reference: + - fast-hadamard-transform (Tri Dao): csrc/fast_hadamard_transform_cuda.cu + - PaddleFormers/paddleformers/quantization/hadamard_utils.py (matmul_hadU) + + Args: + x: Input tensor of shape (..., dim). dim must be a power of 2. + scale: Scaling factor applied to the output. + + Returns: + Hadamard-transformed tensor of the same shape. + """ + original_shape = x.shape + dim = original_shape[-1] + assert dim > 0 and (dim & (dim - 1)) == 0, ( + f"hadamard_transform requires dim to be a power of 2, got {dim}" + ) + + # Flatten batch dims: (..., dim) -> (batch, dim) + x = x.reshape([-1, dim]) + + h = 1 + while h < dim: + x = x.reshape([-1, dim // (2 * h), 2, h]) + a = x[:, :, 0, :] + b = x[:, :, 1, :] + x = paddle.stack([a + b, a - b], axis=2) + x = x.reshape([-1, dim]) + h *= 2 + + return x.reshape(original_shape) * scale + + +def rotate_activation(x: Tensor) -> Tensor: + """Apply Hadamard rotation activation. + + Reference: + https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/inference/model.py#L424-L428 + + Args: + x: Input tensor (must be bfloat16). + + Returns: + Rotated tensor. 
+ """ + assert x.dtype == paddle.bfloat16, ( + f"rotate_activation only support bf16 input, but got {x.dtype}" + ) + hidden_size = x.shape[-1] + return hadamard_transform(x, scale=hidden_size**-0.5) + + +# --------------------------------------------------------------------------- +# Unfused DSA attention (explicit bmm, supports asymmetric Q/K vs V dims) +# --------------------------------------------------------------------------- +def _unfused_dsa_attention( + query: Tensor, + key: Tensor, + value: Tensor, + combined_mask: Tensor | None, + softmax_scale: float, +) -> Tensor: + """Unfused DSA sparse attention + + Uses explicit bmm instead of flash attention to support: + - Different Q/K head_dim vs V head_dim (MLA architecture) + - Arbitrary per-token sparse masks from DSA Indexer + + Args: + query: [b, s, nhpp, qk_head_dim] + key: [b, s, nhpp, qk_head_dim] + value: [b, s, nhpp, v_head_dim] (v_head_dim may differ from qk_head_dim) + combined_mask: [b, 1, s, s] (causal + sparse index mask, -inf for masked) + softmax_scale: 1/sqrt(qk_head_dim) + + Returns: + output: [b, s, nhpp * v_head_dim] + """ + b, s, nhpp, qk_hd = query.shape + v_hd = value.shape[-1] + + # Reshape for bmm: [b*nhpp, s, hd] + q = query.transpose([0, 2, 1, 3]).reshape([b * nhpp, s, qk_hd]) + k = key.transpose([0, 2, 1, 3]).reshape([b * nhpp, s, qk_hd]) + v = value.transpose([0, 2, 1, 3]).reshape([b * nhpp, s, v_hd]) + + # Q * K^T with scale: [b*nhpp, s, s] + attn_scores = ( + paddle.bmm(q.cast("float32"), k.cast("float32").transpose([0, 2, 1])) + * softmax_scale + ) + + # Apply combined mask (causal + sparse index mask) + if combined_mask is not None: + mask = ( + combined_mask.expand([b, nhpp, s, s]) + .contiguous() + .reshape([b * nhpp, s, s]) + ) + attn_scores = attn_scores + mask.cast("float32") + + attn_weights = F.softmax(attn_scores, axis=-1) + + # Attention_weights * V: [b*nhpp, s, v_hd] + output = paddle.bmm(attn_weights.cast(v.dtype), v) + + # [b*nhpp, s, v_hd] -> [b, s, nhpp*v_hd] + 
output = ( + output.reshape([b, nhpp, s, v_hd]) + .transpose([0, 2, 1, 3]) + .reshape([b, s, nhpp * v_hd]) + ) + + return output + + +# --------------------------------------------------------------------------- +# DSA Indexer +# --------------------------------------------------------------------------- +class Indexer(paddle.nn.Layer): + """DSA Indexer: DeepSeek Sparse Attention token selection module. + + For each query token, scores all cached key positions using a lightweight + n_heads-head attention mechanism, then selects the top-k most relevant + positions for the full MLA attention computation. + + Key design notes: + - Uses non-interleaved RoPE (unlike MLA which uses interleaved) + - Uses LayerNorm (not RMSNorm) on K + - nope/pe split order: [nope | pe] + - Uses ReLU-aggregated scoring across heads + - Per-head learned importance weights via weights_proj + - weights absorbs softmax_scale + + """ + + def __init__(self, config: TransformerConfig, layer_number: int): + super().__init__() + self.config = config + + self.n_heads = config.dsa_index_n_heads + self.head_dim = config.dsa_index_head_dim + self.rope_head_dim = config.qk_rope_head_dim + self.nope_head_dim = self.head_dim - self.rope_head_dim + self.index_topk = config.dsa_index_topk + self.softmax_scale = self.head_dim**-0.5 + self.layer_number = layer_number + + # wq_b: q_lora_rank -> n_heads * head_dim (duplicated) + self.wq_b = paddle.nn.Linear( + config.q_lora_rank, + self.n_heads * self.head_dim, + bias_attr=False, + ) + + # wk: hidden_size -> head_dim (single shared K, duplicated) + self.wk = paddle.nn.Linear( + config.hidden_size, + self.head_dim, + bias_attr=False, + ) + + # k_norm: LayerNorm (NOT RMSNorm) per reference + self.k_norm = paddle.nn.LayerNorm(self.head_dim, epsilon=1e-6) + + # weights_proj: learned per-head importance [hidden -> n_heads] + self.weights_proj = paddle.nn.Linear( + config.hidden_size, + self.n_heads, + bias_attr=False, + ) + + def _apply_rope( + self, x: Tensor, 
freqs: Tensor, mscale: float = 1.0 + ) -> Tensor: + """Apply RoPE to the pe portion of x. + + Split order: [pe | nope], matching DeepSeek-V3.2 Indexer (model.py:462). + + RoPE format is controlled by config.dsa_indexer_rotary_interleaved: + - False (default): non-interleaved RoPE with half-head frequencies [θ₁,θ₂,...,θ₁,θ₂,...] + - True: interleaved RoPE with paired frequencies [θ₁,θ₁,θ₂,θ₂,...] + + Args: + x: [..., head_dim] (rope_dim + nope_dim) + freqs: RoPE frequencies + mscale: YaRN concentration factor (1.0 for plain RoPE, ~1.37 for YaRN) + """ + x_pe = x[..., : self.rope_head_dim] + x_nope = x[..., self.rope_head_dim :] + x_pe = _apply_rotary_pos_emb_bshd( + x_pe, + freqs, + rotary_interleaved=self.config.dsa_indexer_rotary_interleaved, + multi_latent_attention=False, + mscale=mscale, + ) + return paddle.concat([x_pe, x_nope], axis=-1) + + def forward_before_topk( + self, + hidden_states: Tensor, # [b, s, hidden_size] + q_latent: Tensor, # [b, s, q_lora_rank] + freqs: Tensor, + mscale: float = 1.0, + ): + """Compute q, k, weights before top-k selection.""" + bsz, seqlen, _ = hidden_states.shape + q = self.wq_b(q_latent) # [b, s, n_heads * head_dim] + q = q.reshape([bsz, seqlen, self.n_heads, self.head_dim]) + if freqs is not None: + q = self._apply_rope(q, freqs, mscale) + + k = self.wk(hidden_states) # [b, s, head_dim] + k = self.k_norm(k) + if freqs is not None: + k = self._apply_rope(k.unsqueeze(2), freqs, mscale).squeeze(2) + + # Rotate activation (Hadamard transform) + q = rotate_activation(q) + k = rotate_activation(k) + + weights = ( + self.weights_proj(hidden_states.cast("float32")) + * (self.n_heads**-0.5) + * self.softmax_scale + ) + + return q, k, weights + + def compute_index_scores( + self, + q: Tensor, # [b, s, n_heads, head_dim] + k: Tensor, # [b, t, head_dim] + weights: Tensor, # [b, s, n_heads] + mask: Tensor | None = None, + ): + """Compute index scores and select top-k.""" + q_fp32 = q.cast("float32") + k_fp32 = k.cast("float32") + + 
scores = paddle.einsum("bshd,btd->bsht", q_fp32, k_fp32) + index_scores = (weights.unsqueeze(-1) * F.relu(scores)).sum(axis=2) + + if mask is not None: + index_scores = index_scores + mask.squeeze(1) + + topk_k = min(self.index_topk, index_scores.shape[-1]) + topk_indices = paddle.topk(index_scores, k=topk_k, axis=-1)[1] + + return index_scores, topk_indices + + def forward( + self, + hidden_states: Tensor, + q_latent: Tensor, + freqs: Tensor, + attention_mask: Tensor, + mscale: float = 1.0, + ) -> tuple[Tensor, Tensor]: + """Compute DSA token importance scores and return scores + top-k indices.""" + q, k, weights = self.forward_before_topk( + hidden_states, q_latent, freqs, mscale + ) + index_scores, topk_indices = self.compute_index_scores( + q, k, weights, attention_mask + ) + return index_scores, topk_indices + + +def _compute_index_scores_fused( + q: Tensor, weights: Tensor, k: Tensor +) -> Tensor: + """Compute index scores from Indexer outputs. + + Args: + q: [sq, b, h, d] (Indexer query, after RoPE + Hadamard) + weights: [sq, b, h] (per-head importance weights) + k: [sk, b, d] (Indexer key, after RoPE + Hadamard) + + Returns: + index_scores: [b, sq, sk] + """ + # q @ k^T -> [sq, b, h, sk] + index_scores = paddle.einsum( + "sbhd,tbd->sbht", q.cast("float32"), k.cast("float32") + ) + # ReLU activation + index_scores = F.relu(index_scores) + # Weight each head: [sq, b, h, sk] * [sq, b, h, 1] -> [sq, b, h, sk] + index_scores = index_scores * weights.unsqueeze(-1) + # Sum across heads: [sq, b, h, sk] -> [sq, b, sk] + index_scores = index_scores.sum(axis=2) + # Transpose to [b, sq, sk] + index_scores = index_scores.transpose([1, 0, 2]) + return index_scores + + +def _compute_dsa_indexer_loss( + index_scores: Tensor, + topk_indices: Tensor, + query: Tensor, + key: Tensor, + softmax_scale: float, + loss_coeff: float, + sparse_loss: bool, + tp_group, +) -> Tensor: + """Compute KL divergence loss between index_scores and true attention_scores. 
+ + Args: + index_scores: [b, sq, sk] + topk_indices: [b, sq, topk] + query: [sq, b, np, hn] (MLA query, DETACHED) + key: [sk, b, np, hn] (MLA key, DETACHED) + softmax_scale: Scale coefficient after q @ k^T + loss_coeff: Coefficient for the indexer KL divergence loss + sparse_loss: Whether to apply sparse index mask + tp_group: TP process group (or None) + + Returns: + indexer_loss: scalar + """ + sq, b, np, hn = query.shape + sk = key.shape[0] + + # [sq, b, np, hn] -> [b, np, sq, hn] -> [b * np, sq, hn] + query_reshaped = query.transpose([1, 2, 0, 3]).reshape([b * np, sq, hn]) + # [sk, b, np, hn] -> [b, np, hn, sk] -> [b * np, hn, sk] + key_reshaped = key.transpose([1, 2, 3, 0]).reshape([b * np, hn, sk]) + # Compute attention scores [b * np, sq, sk] + attention_scores = ( + paddle.bmm(query_reshaped.cast("float32"), key_reshaped.cast("float32")) + * softmax_scale + ) + # Reshape to [b, np, sq, sk] + attention_scores = attention_scores.reshape([b, np, sq, sk]) + + # causal_mask [sq, sk] + causal_mask = paddle.triu( + paddle.full([sq, sk], float("-inf"), dtype="float32"), + diagonal=1, + ) + # index_mask [b, sq, sk] + index_mask = paddle.full([b, sq, sk], float("-inf"), dtype="float32") + index_mask = paddle.put_along_axis( + index_mask, + topk_indices, + paddle.zeros_like(topk_indices, dtype="float32"), + axis=-1, + ) + + # [b, np, sq, sk] + [1, 1, sq, sk] -> [b, np, sq, sk] + attention_scores = attention_scores + causal_mask.reshape([1, 1, sq, sk]) + if sparse_loss: + # [b, np, sq, sk] + [b, 1, sq, sk] -> [b, np, sq, sk] + attention_scores = attention_scores + index_mask.reshape([b, 1, sq, sk]) + # [b, sq, sk] + [b, sq, sk] -> [b, sq, sk] + index_scores = index_scores + index_mask + + # [b, np, sq, sk] -> [b, np, sq, sk] + attention_scores = F.softmax(attention_scores, axis=-1, dtype="float32") + # [b, sq, sk] -> [b, sq, sk] + index_scores = F.softmax(index_scores, axis=-1, dtype="float32") + + # Sum attention scores across heads: [b, np, sq, sk] -> [b, sq, sk] + 
attention_scores = attention_scores.sum(axis=1) + if tp_group is not None and tp_group.nranks > 1: + paddle.distributed.all_reduce( + attention_scores.contiguous(), group=tp_group + ) + # L1 normalize target on the last dimension + attention_scores = attention_scores / attention_scores.sum( + axis=-1, keepdim=True + ) + + # KL divergence: KL(target || index) = target * log(target / index) + kl_per_element = attention_scores * ( + paddle.log(attention_scores + 1e-10) - paddle.log(index_scores + 1e-10) + ) + + # [b, sq, sk] -> [b, sq] -> [1] + kl_div = kl_per_element.sum(axis=-1).mean() + indexer_loss = kl_div * loss_coeff + + return indexer_loss + + +def _bwd_fused_indexer_loss( + q: Tensor, + weights: Tensor, + k: Tensor, + query: Tensor, + key: Tensor, + topk_indices: Tensor, + softmax_scale: float, + loss_coeff: float, + sparse_loss: bool, + grad_loss: Tensor, + tp_group, +) -> tuple[Tensor, Tensor, Tensor]: + """Manual backward for fused indexer loss. + + + All tensor layouts (sequence-first): + q: [sq, b, h, d] + weights: [sq, b, h] + k: [sk, b, d] + query: [sq, b, np, hn] (MLA query) + key: [sk, b, np, hn] (MLA key) + + Returns: + grad_q: [sq, b, h, d] + grad_weights: [sq, b, h] + grad_k: [sk, b, d] + """ + # Recompute index_scores from (q, weights, k) + index_scores = _compute_index_scores_fused(q, weights, k) # [b, sq, sk] + + sq, b, np, hn = query.shape + sk = key.shape[0] + + # [sq, b, np, hn] -> [b, np, sq, hn] -> [b * np, sq, hn] + query_reshaped = query.transpose([1, 2, 0, 3]).reshape([b * np, sq, hn]) + # [sk, b, np, hn] -> [b, np, hn, sk] -> [b * np, hn, sk] + key_reshaped = key.transpose([1, 2, 3, 0]).reshape([b * np, hn, sk]) + # Compute attention scores [b * np, sq, sk] + attention_scores = ( + paddle.bmm(query_reshaped.cast("float32"), key_reshaped.cast("float32")) + * softmax_scale + ) + del query_reshaped, key_reshaped + + # Reshape to [b, np, sq, sk] + attention_scores = attention_scores.reshape([b, np, sq, sk]) + + # causal_mask [sq, sk] + 
causal_mask = paddle.triu( + paddle.full([sq, sk], float("-inf"), dtype="float32"), + diagonal=1, + ) + # index_mask [b, sq, sk] + index_mask = paddle.full([b, sq, sk], float("-inf"), dtype="float32") + index_mask = paddle.put_along_axis( + index_mask, + topk_indices, + paddle.zeros_like(topk_indices, dtype="float32"), + axis=-1, + ) + + # Apply causal mask to both attention and index scores + attention_scores = attention_scores + causal_mask.reshape([1, 1, sq, sk]) + index_scores = index_scores + causal_mask.unsqueeze(0) + del causal_mask + + if sparse_loss: + attention_scores = attention_scores + index_mask.reshape([b, 1, sq, sk]) + index_scores = index_scores + index_mask + + # Compute softmax for both + attention_scores_softmax = F.softmax( + attention_scores, axis=-1, dtype="float32" + ) + del attention_scores + + index_scores_softmax = F.softmax(index_scores, axis=-1, dtype="float32") + del index_scores + + # Sum attention scores across heads: [b, np, sq, sk] -> [b, sq, sk] + attention_scores_sum = attention_scores_softmax.sum(axis=1) + del attention_scores_softmax + + if tp_group is not None and tp_group.nranks > 1: + paddle.distributed.all_reduce( + attention_scores_sum.contiguous(), group=tp_group + ) + + # L1 normalize + attention_scores_normalized = ( + attention_scores_sum / attention_scores_sum.sum(axis=-1, keepdim=True) + ) + del attention_scores_sum + + # Backward through loss = kl_div * loss_coeff + # where kl_div = kl_per_element.sum(dim=-1).mean() + grad_kl_div = grad_loss.cast("float32") * loss_coeff # scalar + + # Backward through mean: distribute gradient equally + grad_kl_per_row = grad_kl_div / (b * sq) # scalar + + # Backward through sum(dim=-1): broadcast back to [b, sq, sk] + grad_kl_per_element = grad_kl_per_row.reshape([1, 1, 1]).expand([b, sq, sk]) + + # Backward through kl: ∂kl/∂index_softmax = -target / index_softmax + grad_index_scores_softmax = ( + -attention_scores_normalized + / (index_scores_softmax + 1e-10) + * 
grad_kl_per_element + ) + del attention_scores_normalized + + # Backward through softmax: + # ∂L/∂x = softmax * (∂L/∂softmax - sum(∂L/∂softmax * softmax)) + sum_grad = (grad_index_scores_softmax * index_scores_softmax).sum( + axis=-1, keepdim=True + ) + grad_index_scores_logits = index_scores_softmax * ( + grad_index_scores_softmax - sum_grad + ) + del index_scores_softmax, grad_index_scores_softmax, sum_grad + + # Zero out gradients for masked positions + causal_valid_mask = paddle.tril( + paddle.ones([sq, sk], dtype="bool") + ) # [sq, sk] + if sparse_loss: + index_valid_mask = index_mask == 0 # [b, sq, sk] + del index_mask + valid_mask = ( + causal_valid_mask.unsqueeze(0) & index_valid_mask + ) # [b, sq, sk] + del index_valid_mask + else: + del index_mask + valid_mask = causal_valid_mask.unsqueeze(0).expand( + [b, sq, sk] + ) # [b, sq, sk] + del causal_valid_mask + + grad_index_scores_logits = grad_index_scores_logits * valid_mask.cast( + "float32" + ) + del valid_mask + + # Transpose from [b, sq, sk] to [sq, b, sk] + grad_index_scores = grad_index_scores_logits.transpose( + [1, 0, 2] + ) # [sq, b, sk] + del grad_index_scores_logits + + # Backward through sum over heads: expand gradient + grad_weighted_scores = grad_index_scores.unsqueeze(2) # [sq, b, 1, sk] + del grad_index_scores + + # Compute forward values needed for backward (recomputation) + scores = paddle.einsum( + "sbhd,tbd->sbht", q.cast("float32"), k.cast("float32") + ) # [sq, b, h, sk] + relu_mask = scores > 0 + scores_after_relu = F.relu(scores) + del scores + + # Backward through multiplication by weights: + # ∂L/∂weights = grad * relu_scores (sum over sk) + grad_weights = (grad_weighted_scores * scores_after_relu).sum( + axis=-1 + ) # [sq, b, h] + + # ∂L/∂relu_scores = grad * weights + grad_scores_after_relu = grad_weighted_scores * weights.unsqueeze( + -1 + ) # [sq, b, h, sk] + del grad_weighted_scores, scores_after_relu + + # Backward through ReLU + grad_scores = grad_scores_after_relu * 
relu_mask.cast( + "float32" + ) # [sq, b, h, sk] + del grad_scores_after_relu, relu_mask + + # Backward through einsum 'sbhd,tbd->sbht' + # ∂L/∂q = einsum('sbht,tbd->sbhd', grad_scores, k) + grad_q = paddle.einsum( + "sbht,tbd->sbhd", grad_scores, k.cast("float32") + ) # [sq, b, h, d] + # ∂L/∂k = einsum('sbht,sbhd->tbd', grad_scores, q) + grad_k = paddle.einsum( + "sbht,sbhd->tbd", grad_scores, q.cast("float32") + ) # [sk, b, d] + del grad_scores + + return ( + grad_q.cast(q.dtype), + grad_weights.cast(weights.dtype), + grad_k.cast(k.dtype), + ) + + +class FusedDSAIndexerLoss(paddle.autograd.PyLayer): + """Fused DSA Indexer Loss: index_scores + topk + KL loss + full manual backward.""" + + _last_topk_indices: Tensor | None = None + + @staticmethod + def forward( + ctx, + q: Tensor, # [sq, b, h, d] — Indexer query output + weights: Tensor, # [sq, b, h] — Indexer per-head weights + k: Tensor, # [sk, b, d] — Indexer key output + query: Tensor, # [sq, b, np, hn] — MLA query (DETACHED) + key: Tensor, # [sk, b, np, hn] — MLA key (DETACHED) + # Non-tensor params follow (stored on ctx, not in backward returns) + softmax_scale: float = 1.0, + topk: int = 64, + loss_coeff: float = 1.0, + mask: Tensor | None = None, + sparse_loss: bool = True, + tp_group=None, + ) -> Tensor: + """Fused forward: compute index_scores, topk, and KL loss. 
+ + Args: + q: Indexer query after RoPE+Hadamard [sq, b, h, d] + weights: Per-head importance weights [sq, b, h] + k: Indexer key after RoPE+Hadamard [sk, b, d] + query: MLA query (detached) [sq, b, np, hn] + key: MLA key (detached) [sk, b, np, hn] + softmax_scale: MLA attention softmax scale + topk: Number of top-k indices to select + loss_coeff: Coefficient for KL loss + mask: Optional mask for index_scores [b, 1, sq, sk] or [1, 1, sq, sk] + sparse_loss: Whether to use sparse index mask in loss + tp_group: TP process group (or None) + + Returns: + indexer_loss: scalar KL divergence loss + """ + # Step 1: Compute index_scores from (q, weights, k) + index_scores = _compute_index_scores_fused(q, weights, k) # [b, sq, sk] + + # Step 2: Apply mask and select topk + if mask is not None: + masked_scores = index_scores + mask.squeeze(1) + else: + masked_scores = index_scores + topk_k = min(topk, masked_scores.shape[-1]) + topk_indices = paddle.topk(masked_scores, k=topk_k, axis=-1)[1] + + FusedDSAIndexerLoss._last_topk_indices = topk_indices.detach() + + # Step 3: Compute KL loss (use masked_scores) + indexer_loss = _compute_dsa_indexer_loss( + masked_scores, + topk_indices, + query, + key, + softmax_scale, + loss_coeff, + sparse_loss, + tp_group, + ) + + ctx.save_for_backward(q, weights, k, query, key, topk_indices) + ctx.softmax_scale = softmax_scale + ctx.loss_coeff = loss_coeff + ctx.sparse_loss = sparse_loss + ctx.tp_group = tp_group + + return indexer_loss + + @staticmethod + def backward(ctx, grad_loss: Tensor): + """Backward: recompute and manually backprop to (q, weights, k). + + Returns 6 gradients for the 6 Tensor inputs to forward: + q, weights, k, query, key, mask + (Paddle PyLayer only counts Tensor params, not float/int/bool/None.) 
+ """ + q, weights, k, query, key, topk_indices = ctx.saved_tensor() + + grad_q, grad_weights, grad_k = _bwd_fused_indexer_loss( + q, + weights, + k, + query, + key, + topk_indices, + ctx.softmax_scale, + ctx.loss_coeff, + ctx.sparse_loss, + grad_loss, + ctx.tp_group, + ) + + return grad_q, grad_weights, grad_k, None, None, None + + +class DSAIndexerLossAutoScaler(paddle.autograd.PyLayer): + """Attaches indexer_loss to the backward graph without changing output value.""" + + _main_loss_backward_scale: Tensor | None = None + + @staticmethod + def forward(ctx, output: Tensor, indexer_loss: Tensor) -> Tensor: + ctx.save_for_backward(indexer_loss) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + (indexer_loss,) = ctx.saved_tensor() + scale = DSAIndexerLossAutoScaler._main_loss_backward_scale + if scale is None: + scale = paddle.ones([1], dtype=indexer_loss.dtype) + scaled_grad = paddle.ones_like(indexer_loss) * scale + return grad_output, scaled_grad + + @staticmethod + def set_loss_scale(scale: Tensor): + DSAIndexerLossAutoScaler._main_loss_backward_scale = scale + + +logger = logging.getLogger(__name__) + + +class DSAIndexerLossLoggingHelper: + """Helper class for logging sparse attention indexer losses across layers and ranks.""" + + tracker: dict = {} + + @staticmethod + def save_loss_to_tracker( + loss: Tensor, + layer_number: int, + num_layers: int, + reduce_group=None, + avg_group=None, + ): + """Save the indexer loss for logging. + + Args: + loss: The loss tensor (scalar). + layer_number: Layer index of the loss, 1-indexed. + num_layers: The number of total layers. + reduce_group: The group for reducing the loss. + avg_group: The group for averaging the loss. 
+ """ + if layer_number is None: + return + + tracker = DSAIndexerLossLoggingHelper.tracker + if "values" not in tracker: + tracker["values"] = paddle.zeros([num_layers]) + tracker["values"][layer_number - 1] += loss.detach() + tracker["reduce_group"] = reduce_group + tracker["avg_group"] = avg_group + + @staticmethod + def clean_loss_in_tracker(): + """Clear the indexer losses.""" + tracker = DSAIndexerLossLoggingHelper.tracker + if "values" in tracker: + tracker["values"].zero_() + tracker["reduce_group"] = None + tracker["avg_group"] = None + + @staticmethod + def reduce_loss_in_tracker(): + """Collect and reduce the indexer losses across ranks.""" + tracker = DSAIndexerLossLoggingHelper.tracker + if "values" not in tracker: + return + values = tracker["values"] + + # PP all-reduce + pp_group = parallel_state.get_pipeline_model_parallel_group( + check_initialized=False + ) + if pp_group is not None and pp_group.nranks > 1: + paddle.distributed.all_reduce(values, group=pp_group) + + # TP reduce + if tracker.get("reduce_group") is not None: + paddle.distributed.all_reduce(values, group=tracker["reduce_group"]) + + # CP avg + if tracker.get("avg_group") is not None: + paddle.distributed.all_reduce(values, group=tracker["avg_group"]) + values /= tracker["avg_group"].nranks + + # DP avg + dp_group = parallel_state.get_data_parallel_group( + check_initialized=False + ) + if dp_group is not None and dp_group.nranks > 1: + paddle.distributed.all_reduce(values, group=dp_group) + values /= dp_group.nranks + + @staticmethod + def track_indexer_metrics( + loss_scale: float, + iteration: int, + writer=None, + total_loss_dict: dict | None = None, + ): + """Track the sparse attention indexer metrics for logging. + + Args: + loss_scale: Scale factor for the loss (e.g. 1/num_microbatches). + iteration: Current training iteration. + writer: TensorBoard writer (optional). + total_loss_dict: Dictionary to accumulate total losses (optional). 
+ """ + DSAIndexerLossLoggingHelper.reduce_loss_in_tracker() + tracker = DSAIndexerLossLoggingHelper.tracker + if "values" not in tracker: + return + + indexer_loss_values = tracker["values"] * loss_scale + num_layers = indexer_loss_values.shape[0] + avg_indexer_loss = indexer_loss_values.sum() / num_layers + + if total_loss_dict is not None: + if "indexer loss" in total_loss_dict: + total_loss_dict["indexer loss"] += avg_indexer_loss + else: + total_loss_dict["indexer loss"] = avg_indexer_loss + + if writer is not None: + writer.add_scalar( + "indexer loss", avg_indexer_loss.item(), iteration + ) + + logger.info( + "Iteration %d | indexer loss: %.6f", + iteration, + avg_indexer_loss.item(), + ) + + DSAIndexerLossLoggingHelper.clean_loss_in_tracker() + + +# --------------------------------------------------------------------------- +# MLASelfAttentionWithDSA — extends upstream MLASelfAttention +# --------------------------------------------------------------------------- +class MLASelfAttentionWithDSA(MLASelfAttention): + """MLA Self-attention with DeepSeek Sparse Attention (DSA) Indexer. + + Extends the upstream MLASelfAttention by: + 1. Reusing parent's get_query_key_value_tensors() for Q/K/V + q_compressed. + 2. Overriding forward() to run the DSA Indexer, build a sparse mask, + and use unfused bmm attention. + + The Indexer needs q_compressed (normed, from MLA down-projection) and + hidden_states in [b, s, h] format. 
The parent's get_query_key_value_tensors + now returns (query, key, value, q_compressed, kv_compressed) + """ + + def __init__( + self, + config: TransformerConfig, + sublayers_spec: MLASelfAttentionSublayersSpec, + layer_number: int, + attn_mask_type=AttnMaskType.padding, + cp_comm_type: str | None = None, + pg_collection: ProcessGroupCollection | None = None, + ): + super().__init__( + config=config, + sublayers_spec=sublayers_spec, + layer_number=layer_number, + attn_mask_type=attn_mask_type, + cp_comm_type=cp_comm_type, + pg_collection=pg_collection, + ) + + # DSA Indexer + self.indexer = Indexer(config, layer_number) + + # DSA loss config + self.dsa_indexer_loss_coeff = getattr( + config, "dsa_indexer_loss_coeff", None + ) + self.dsa_indexer_use_sparse_loss = getattr( + config, "dsa_indexer_use_sparse_loss", False + ) + + def forward( + self, + hidden_states, + attention_mask, + attn_mask_startend_row_indices: Tensor | None = None, + key_value_states=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + packed_seq_params=None, + in_recompute: bool = False, + position_ids=None, + ): + """Forward: MLA projections + DSA Indexer + sparse attention. + + Overrides the parent forward to: + 1. Reuse parent's get_query_key_value_tensors for Q/K/V + q_compressed + 2. Prepare Indexer inputs from returned q_compressed (no re-computation) + 3. Run DSA Indexer to get top-k indices + 4. Build sparse mask and use unfused bmm attention + 5. Compute DSA Indexer KL loss if enabled + """ + assert rotary_pos_emb is None, ( + "Rotary position embeddings should not be passed into MLA." + ) + assert attention_bias is None + assert rotary_pos_cos is None and rotary_pos_sin is None + assert packed_seq_params is None, ( + "packed_seq_params is not supported yet." 
+ ) + + # ===================== + # Query, Key, Value + compressed intermediates + # ===================== + query, key, value, q_compressed, kv_compressed = ( + self.get_query_key_value_tensors( + hidden_states, + key_value_states, + position_ids, + packed_seq_params, + ) + ) + + # Non-SP: batch-first [b, s, ...] + # SP: seq-first [s/tp, b, ...] + if self.config.sequence_parallel: + # SP: q_compressed [s/tp, b, q_lora_rank] -> gather -> [s, b, q_lora_rank] + # -> transpose -> [b, s, q_lora_rank] + indexer_q_latent = gather_from_sequence_parallel_region( + q_compressed, + tensor_parallel_output_grad=True, + group=self.pg_collection.tp, + ).detach() + indexer_hidden = gather_from_sequence_parallel_region( + hidden_states, + tensor_parallel_output_grad=True, + group=self.pg_collection.tp, + ).detach() + indexer_q_latent = indexer_q_latent.transpose([1, 0, 2]) + indexer_hidden = indexer_hidden.transpose([1, 0, 2]) + else: + # Non-SP: already batch-first [b, s, ...], no transpose needed + indexer_q_latent = q_compressed.detach() + indexer_hidden = hidden_states.detach() + + # Convert query/key/value to sequence-first [s, b, n, h] for DSA + if not self.config.sequence_parallel: + # Non-SP: [b, s, n, h] -> [s, b, n, h] + query = query.transpose([1, 0, 2, 3]) + key = key.transpose([1, 0, 2, 3]) + value = value.transpose([1, 0, 2, 3]) + # SP: already seq-first [s/tp, b, n, h], no transpose needed + + # indexer_hidden: [b, s, h] + # indexer_q_latent: [b, s, q_lora_rank] + # query: [s, b, n, h] + # key: [s, b, n, h] + # value: [s, b, n, h] + + # Get RoPE freqs for Indexer (non-interleaved, computed from rotary_pos_emb) + # Re-compute from self.rotary_pos_emb since parent doesn't expose it + with paddle.no_grad(): + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + hidden_states, self.config, packed_seq_params + ) + packed_seq = ( + packed_seq_params is not None + and packed_seq_params.qkv_format == "thd" + ) + assert not self.config.apply_rope_fusion, ( + "DSA 
Indexer requires unfused RoPE (apply_rope_fusion=False). " + "Fused RoPE returns cos/sin instead of freqs, which the Indexer cannot use." + ) + indexer_mscale = 1.0 + if self.config.rope_type == "rope": + indexer_freqs = self.rotary_pos_emb( + rotary_seq_len, packed_seq=packed_seq + ) + else: + indexer_freqs, indexer_mscale = self.rotary_pos_emb( + rotary_seq_len, packed_seq=packed_seq + ) + # rotary_pos_emb returns [1, seq_len, 1, dim]. + # Indexer tensors are batch-first [b, s, heads, head_dim], so keep + # freqs as 4D [1, seq_len, 1, dim] for correct broadcasting. + # Do NOT squeeze(0) — that would make it [seq_len, 1, dim] (seq-first) + # which causes _apply_rotary_pos_emb_bshd to mis-broadcast. + if indexer_freqs is not None and ( + packed_seq_params is None + or self.config.context_parallel_size == 1 + ): + # Use the actual full sequence length from gathered indexer input, + # not the potentially-sharded hidden_states.shape[0] (s/tp in SP mode). + actual_seq_len = indexer_hidden.shape[ + 1 + ] # [b, s, h] — always full seq + indexer_freqs = indexer_freqs[ + :, 0:actual_seq_len + ] # slice seq dim (dim 1) + + # MLA's YarnRotaryEmbedding generates interleaved freqs [θ₁,θ₁,θ₂,θ₂,...] + # when config.rotary_interleaved=True, but Indexer uses non-interleaved + # RoPE which expects half-half freqs [θ₁,θ₂,...,θ₁,θ₂,...]. + # Convert format here to match Indexer's rotary_interleaved=False. 
+ # if self.config.rotary_interleaved and indexer_freqs is not None: + # indexer_freqs = paddle.concat( + # [indexer_freqs[..., 0::2], indexer_freqs[..., 1::2]], + # axis=-1, + # ) + + indexer_seq_len = indexer_hidden.shape[1] # [b, s, h] — always full seq + indexer_causal_mask = paddle.triu( + paddle.full( + [indexer_seq_len, indexer_seq_len], + float("-inf"), + dtype="float32", + ), + diagonal=1, + ) # [s, s] + if attention_mask is not None: + # attention_mask: [b, 1, sq, sk] — may contain padding info beyond causal + indexer_float_mask = attention_mask.cast( + "float32" + ) + indexer_causal_mask.unsqueeze(0).unsqueeze(0) # [b, 1, s, s] + else: + indexer_float_mask = indexer_causal_mask.unsqueeze(0).unsqueeze( + 0 + ) # [1, 1, s, s] + + # Indexer forward_before_topk runs WITH gradient tracking so that + # FusedDSAIndexerLoss can backprop through (q, weights, k) to + # Indexer parameters (wq_b, wk, weights_proj). + q_idx, k_idx, weights_idx = self.indexer.forward_before_topk( + indexer_hidden, + indexer_q_latent, + indexer_freqs, + mscale=indexer_mscale, + ) + + # Convert Indexer outputs from batch-first [b, s, h, d] to + # sequence-first [s, b, h, d] + q_idx_sf = q_idx.transpose([1, 0, 2, 3]) # [sq, b, h, d] + k_idx_sf = k_idx.transpose([1, 0, 2]) # [sk, b, d] + weights_idx_sf = weights_idx.transpose([1, 0, 2]) # [sq, b, h] + + # FusedDSAIndexerLoss: compute index_scores + topk + KL loss inside PyLayer, + # with full manual backward to (q, weights, k). 
+ if self.training and self.dsa_indexer_loss_coeff is not None: + indexer_loss = FusedDSAIndexerLoss.apply( + q_idx_sf, + weights_idx_sf, + k_idx_sf, + query.detach(), + key.detach(), + self.softmax_scale, + self.indexer.index_topk, + float(self.dsa_indexer_loss_coeff), + indexer_float_mask, + bool(self.dsa_indexer_use_sparse_loss), + self.pg_collection.tp + if self.pg_collection.tp is not None + and self.pg_collection.tp.nranks > 1 + else None, + ) + topk_indices = FusedDSAIndexerLoss._last_topk_indices + else: + # Inference or no loss: compute index_scores + topk directly + index_scores, topk_indices = self.indexer.compute_index_scores( + q_idx, k_idx, weights_idx, indexer_float_mask + ) + topk_indices = topk_indices.detach() + + # ===================== + # Build sparse mask + # ===================== + seqlen = query.shape[0] # [s, b, nhpp, hd] + bsz = query.shape[1] + + index_mask = paddle.full( + [bsz, seqlen, seqlen], + fill_value=float("-inf"), + dtype="float32", + ) + zeros = paddle.zeros( + [ + topk_indices.shape[0], + topk_indices.shape[1], + topk_indices.shape[2], + ], + dtype="float32", + ) + index_mask = paddle.put_along_axis( + index_mask, topk_indices, zeros, axis=-1 + ) + # Merge causal + index into [b, s, s], then unsqueeze to [b, 1, s, s] + # causal_mask is [s, s], reuse the one built for indexer (same seqlen) + causal_mask = paddle.triu( + paddle.full([seqlen, seqlen], float("-inf"), dtype="float32"), + diagonal=1, + ) + index_mask = index_mask + causal_mask.unsqueeze(0) + combined_mask = index_mask.unsqueeze(1) # [b, 1, s, s] + + if attention_mask is not None: + combined_mask = attention_mask.cast("float32") + combined_mask + + # ===================== + # Core attention (unfused bmm for DSA) + # ===================== + query = query.transpose([1, 0, 2, 3]).contiguous() + key = key.transpose([1, 0, 2, 3]).contiguous() + value = value.transpose([1, 0, 2, 3]).contiguous() + + if self.recompute_core_attention and self.training: + core_attn_out = 
recompute( + _unfused_dsa_attention, + query, + key, + value, + combined_mask.clone() if combined_mask is not None else None, + self.softmax_scale, + ) + else: + core_attn_out = _unfused_dsa_attention( + query, + key, + value, + combined_mask, + self.softmax_scale, + ) + + # ===================== + # Output projection + # ===================== + if self.config.sequence_parallel: + core_attn_out = core_attn_out.transpose([1, 0, 2]).contiguous() + output, bias = self.o_proj(core_attn_out) + + # ===================== + # DSA Indexer KL loss (already computed by FusedDSAIndexerLoss above) + # ===================== + if self.training and self.dsa_indexer_loss_coeff is not None: + if self.dsa_indexer_loss_coeff > 0: + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss=indexer_loss, + layer_number=self.layer_number, + num_layers=self.config.num_hidden_layers, + ) + output = DSAIndexerLossAutoScaler.apply(output, indexer_loss) + + return output, bias diff --git a/src/paddlefleet/transformer/multi_latent_attention.py b/src/paddlefleet/transformer/multi_latent_attention.py index 0875c9872..d4f2252a1 100644 --- a/src/paddlefleet/transformer/multi_latent_attention.py +++ b/src/paddlefleet/transformer/multi_latent_attention.py @@ -189,7 +189,7 @@ def forward( # ===================== # Get the query, key and value tensors based on the type of attention - query, key, value = self.get_query_key_value_tensors( + query, key, value, _, _ = self.get_query_key_value_tensors( hidden_states, key_value_states, position_ids, @@ -681,7 +681,7 @@ def qkv_up_proj_and_rope_apply( q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb ) - return query, key, value + return query, key, value, q_compressed, kv_compressed def backward_dw(self) -> NoReturn: """Execute weight gradient computation""" diff --git a/src/paddlefleet/transformer/transformer_config.py b/src/paddlefleet/transformer/transformer_config.py index 2d220c133..b72225db8 100644 --- 
a/src/paddlefleet/transformer/transformer_config.py +++ b/src/paddlefleet/transformer/transformer_config.py @@ -550,6 +550,79 @@ class TransformerConfig(ModelParallelConfig): # cache_mla_latents: bool = False + #################### + # DSA (DeepSeek Sparse Attention) + #################### + + dsa_index_n_heads: int | None = None + """Number of DSA Indexer heads. None disables DSA; non-None activates + DeepSeek V3.2 sparse attention path. + + Note: This field corresponds to the HuggingFace config.json field "index_n_heads". + The mapping from HuggingFace field name to PaddleFleet internal field name is handled + by TransformerConfig.transform_rules. + """ + + dsa_index_head_dim: int = 128 + """Per-head dimension for Indexer Q/K vectors. + + Note: This field corresponds to the HuggingFace config.json field "index_head_dim". + The mapping from HuggingFace field name to PaddleFleet internal field name is handled + by TransformerConfig.transform_rules. + """ + + dsa_index_topk: int = 2048 + """Number of token positions selected by Indexer per query token. + + Note: This field corresponds to the HuggingFace config.json field "index_topk". + The mapping from HuggingFace field name to PaddleFleet internal field name is handled + by TransformerConfig.transform_rules. + """ + + dsa_indexer_loss_coeff: float | None = None + """KL loss coefficient for DSA Indexer training. None disables the KL loss. + + Note: This field corresponds to the HuggingFace config.json field "indexer_loss_coeff". + The mapping from HuggingFace field name to PaddleFleet internal field name is handled + by TransformerConfig.transform_rules. + """ + + dsa_indexer_use_sparse_loss: bool = False + """Whether to restrict DSA KL loss to top-k positions only. + + Note: This field corresponds to the HuggingFace config.json field "indexer_use_sparse_loss". + The mapping from HuggingFace field name to PaddleFleet internal field name is handled + by TransformerConfig.transform_rules. 
+ """ + + dsa_indexer_rotary_interleaved: bool = False + """ + Whether Indexer uses interleaved Rotary Position Embeddings. + + When False (default), Indexer uses non-interleaved RoPE with + half-head frequencies [θ₁,θ₂,...,θ₁,θ₂,...]. + + When True, Indexer uses interleaved RoPE with paired frequencies + [θ₁,θ₁,θ₂,θ₂,...]. + + This allows compatibility with MLA's YaRN RoPE which always generates + interleaved frequencies. + """ + + dsa_indexer_loss_coeff: float = 0.01 + """KL loss coefficient for DSA Indexer training. None disables the KL loss.""" + + # Field name mapping rules: HuggingFace config.json name -> TransformerConfig name + transform_rules = { + # DSA field mapping + "index_n_heads": "dsa_index_n_heads", + "index_head_dim": "dsa_index_head_dim", + "index_topk": "dsa_index_topk", + "indexer_loss_coeff": "dsa_indexer_loss_coeff", + "indexer_use_sparse_loss": "dsa_indexer_use_sparse_loss", + "indexer_rotary_interleaved": "dsa_indexer_rotary_interleaved", + } + @classmethod def from_config(cls, config_dict): # note(zhangweilong): if cls(),will call __post_init__ directly,but __new__ will skip some attr init .please check provider attr diff --git a/tests/single_card_tests/transformer/test_dsa_attention.py b/tests/single_card_tests/transformer/test_dsa_attention.py new file mode 100644 index 000000000..45ed2c5a3 --- /dev/null +++ b/tests/single_card_tests/transformer/test_dsa_attention.py @@ -0,0 +1,1664 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for DSA (DeepSeek Sparse Attention) module. + +Tests are organized in 4 layers: + 1. Pure functions: hadamard_transform, rotate_activation, _unfused_dsa_attention, + _compute_index_scores_fused + 2. Indexer module: forward_before_topk, compute_index_scores, backward + 3. Loss: _compute_dsa_indexer_loss, FusedDSAIndexerLoss, DSAIndexerLossAutoScaler + 4. Integration: MLASelfAttentionWithDSA forward + backward +""" + +import unittest +from unittest.mock import MagicMock, patch + +import paddle + +from paddlefleet.transformer.dot_product_attention import DotProductAttention +from paddlefleet.transformer.dsa_attention import ( + DSAIndexerLossAutoScaler, + DSAIndexerLossLoggingHelper, + FusedDSAIndexerLoss, + Indexer, + MLASelfAttentionWithDSA, + _bwd_fused_indexer_loss, + _compute_dsa_indexer_loss, + _compute_index_scores_fused, + _unfused_dsa_attention, + hadamard_transform, + rotate_activation, +) +from paddlefleet.transformer.enums import AttnMaskType +from paddlefleet.transformer.multi_latent_attention import ( + MLASelfAttentionSublayersSpec, +) +from paddlefleet.transformer.transformer_config import TransformerConfig +from paddlefleet.utils import ( + init_method_normal, + scaled_init_method_normal, +) + + +# --------------------------------------------------------------------------- +# Stub layers (same pattern as test_attention.py) +# --------------------------------------------------------------------------- +class BiasedLinear(paddle.nn.Layer): + def __init__(self, in_features, out_features, **kwargs): + super().__init__() + self.linear = paddle.nn.Linear(in_features, out_features) + + def forward(self, x): + return self.linear(x), self.linear.bias + + +class RMSNorm(paddle.nn.Layer): + def __init__(self, hidden_size, eps, **kwargs): + super().__init__() + self.weight = paddle.nn.Parameter(paddle.ones([hidden_size])) + self.eps = eps + + 
def forward(self, x): + d_norm = paddle.rsqrt(x.pow(2).mean(axis=-1, keepdim=True) + self.eps) + return x * d_norm * self.weight + + +# --------------------------------------------------------------------------- +# Helper: create DSA-compatible TransformerConfig +# --------------------------------------------------------------------------- +def _create_dsa_config( + hidden_size=256, + num_attention_heads=2, + q_lora_rank=64, + kv_lora_rank=64, + qk_nope_head_dim=32, + qk_rope_head_dim=32, + v_head_dim=64, + index_n_heads=2, + index_head_dim=128, + index_topk=16, + indexer_loss_coeff=1.0, + indexer_use_sparse_loss=False, + sequence_parallel=False, +): + config = TransformerConfig( + num_hidden_layers=2, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ) + # MLA fields + config.num_key_value_heads = num_attention_heads + config.head_dim = hidden_size // num_attention_heads + config.q_lora_rank = q_lora_rank + config.kv_lora_rank = kv_lora_rank + config.qk_nope_head_dim = qk_nope_head_dim + config.qk_rope_head_dim = qk_rope_head_dim + config.v_head_dim = v_head_dim + config.multi_latent_attention = True + + # RoPE / YaRN + config.rope_type = "yarn" + config.rope_theta = 10000.0 + config.rotary_interleaved = False + config.rotary_percent = 1.0 + config.rotary_scaling_factor = 40.0 + config.original_max_position_embeddings = 4096 + config.beta_fast = 32.0 + config.beta_slow = 1.0 + config.mscale = 1.0 + config.mscale_all_dim = 0.0 + config.apply_rope_fusion = False # DSA requires unfused RoPE + + # DSA Indexer fields + config.dsa_index_n_heads = index_n_heads + config.dsa_index_head_dim = index_head_dim + config.dsa_index_topk = index_topk + config.dsa_indexer_loss_coeff = indexer_loss_coeff + config.dsa_indexer_use_sparse_loss = indexer_use_sparse_loss + config.dsa_indexer_rotary_interleaved = False # Test default value + + # Attention generic fields + config.softmax_scale = None + config.use_bias = True + config.no_rope_freq = None + 
config.recompute_granularity = None + config.fused_single_qkv_rope = False + config.init_method = init_method_normal(0.02) + config.output_layer_init_method = scaled_init_method_normal(0.02, 1, 2.0) + config.rms_norm_eps = 1e-5 + config.context_parallel_size = 1 + config.apply_query_key_layer_scaling = False + config.sliding_window = None + config.window_attn_skip_freq = None + config.fp16 = False + config.bf16 = False + config.masked_softmax_fusion = False + config.attention_softmax_in_fp32 = True + config.attention_dropout = 0.0 + config.softmax_type = "vanilla" + config.sequence_parallel = sequence_parallel + + return config + + +def _create_sublayers_spec(): + return MLASelfAttentionSublayersSpec( + q_proj=BiasedLinear, + q_a_proj=BiasedLinear, + q_b_proj=BiasedLinear, + kv_a_proj_with_mqa=BiasedLinear, + kv_b_proj=BiasedLinear, + core_attention=DotProductAttention, + o_proj=BiasedLinear, + q_a_layernorm=RMSNorm, + kv_a_layernorm=RMSNorm, + ) + + +def _make_causal_topk_indices(b, sq, sk, topk): + """Generate topk indices that respect causality (indices <= current position).""" + indices_list = [] + for i in range(sq): + max_idx = min(i + 1, sk) + actual_topk = min(topk, max_idx) + # Pick the last `actual_topk` positions (most recent) + row_indices = paddle.arange(max_idx - actual_topk, max_idx) + if actual_topk < topk: + # Pad with the last valid index + pad = paddle.full([topk - actual_topk], max_idx - 1, dtype="int64") + row_indices = paddle.concat([row_indices, pad]) + indices_list.append(row_indices) + # [sq, topk] -> expand to [b, sq, topk] + indices = ( + paddle.stack(indices_list, axis=0).unsqueeze(0).expand([b, sq, topk]) + ) + return indices + + +# =========================================================================== +# Layer 1: Pure function tests +# =========================================================================== +class TestHadamardTransform(unittest.TestCase): + def test_output_shape(self): + x = paddle.randn([4, 8, 16]) + out = 
hadamard_transform(x) + self.assertEqual(out.shape, [4, 8, 16]) + + def test_power_of_two_assertion(self): + x = paddle.randn([4, 7]) + with self.assertRaises(AssertionError): + hadamard_transform(x) + + def test_involution(self): + """H(H(x)) = dim * x (Hadamard is involutory up to scaling).""" + dim = 16 + x = paddle.randn([3, dim], dtype="float32") + hx = hadamard_transform(x) + hhx = hadamard_transform(hx) + self.assertTrue(paddle.allclose(hhx, x * dim, atol=1e-4, rtol=1e-4)) + + def test_scale_factor(self): + x = paddle.randn([4, 8]) + out_unscaled = hadamard_transform(x) + out_scaled = hadamard_transform(x, scale=0.5) + self.assertTrue( + paddle.allclose(out_scaled, out_unscaled * 0.5, atol=1e-5) + ) + + def test_1d_input(self): + x = paddle.randn([16]) + out = hadamard_transform(x) + self.assertEqual(out.shape, [16]) + + +class TestRotateActivation(unittest.TestCase): + def test_output_shape(self): + x = paddle.randn([2, 4, 128]).cast("bfloat16") + out = rotate_activation(x) + self.assertEqual(list(out.shape), [2, 4, 128]) + self.assertEqual(out.dtype, paddle.bfloat16) + + def test_dtype_assertion(self): + x = paddle.randn([2, 4, 64], dtype="float32") + with self.assertRaises(AssertionError): + rotate_activation(x) + + +class TestUnfusedDSAAttention(unittest.TestCase): + def setUp(self): + self.b, self.s, self.nhpp = 2, 8, 4 + self.qk_hd, self.v_hd = 32, 64 + self.softmax_scale = self.qk_hd**-0.5 + + def test_output_shape(self): + query = paddle.randn([self.b, self.s, self.nhpp, self.qk_hd]) + key = paddle.randn([self.b, self.s, self.nhpp, self.qk_hd]) + value = paddle.randn([self.b, self.s, self.nhpp, self.v_hd]) + out = _unfused_dsa_attention( + query, key, value, None, self.softmax_scale + ) + self.assertEqual(out.shape, [self.b, self.s, self.nhpp * self.v_hd]) + + def test_with_causal_mask(self): + query = paddle.randn([self.b, self.s, self.nhpp, self.qk_hd]) + key = paddle.randn([self.b, self.s, self.nhpp, self.qk_hd]) + value = paddle.randn([self.b, 
self.s, self.nhpp, self.v_hd]) + causal = paddle.triu( + paddle.full([self.s, self.s], float("-inf"), dtype="float32"), + diagonal=1, + ) + mask = causal.unsqueeze(0).unsqueeze(0) # [1, 1, s, s] + out = _unfused_dsa_attention( + query, key, value, mask, self.softmax_scale + ) + self.assertEqual(out.shape, [self.b, self.s, self.nhpp * self.v_hd]) + + def test_asymmetric_dims(self): + """qk_head_dim != v_head_dim should work.""" + qk_hd, v_hd = 48, 32 + query = paddle.randn([self.b, self.s, self.nhpp, qk_hd]) + key = paddle.randn([self.b, self.s, self.nhpp, qk_hd]) + value = paddle.randn([self.b, self.s, self.nhpp, v_hd]) + out = _unfused_dsa_attention(query, key, value, None, qk_hd**-0.5) + self.assertEqual(out.shape, [self.b, self.s, self.nhpp * v_hd]) + + +class TestComputeIndexScoresFused(unittest.TestCase): + def test_output_shape(self): + sq, b, h, d = 8, 2, 4, 32 + sk = 8 + q = paddle.randn([sq, b, h, d]) + weights = paddle.randn([sq, b, h]) + k = paddle.randn([sk, b, d]) + out = _compute_index_scores_fused(q, weights, k) + self.assertEqual(out.shape, [b, sq, sk]) + + def test_nonnegative_after_relu(self): + sq, b, h, d = 8, 2, 4, 32 + q = paddle.randn([sq, b, h, d]) + # Use positive weights so that relu * positive_weights >= 0 + weights = paddle.abs(paddle.randn([sq, b, h])) + 0.1 + k = paddle.randn([sq, b, d]) + out = _compute_index_scores_fused(q, weights, k) + self.assertTrue((out >= -1e-6).all().item()) + + +# =========================================================================== +# Layer 2: Indexer module tests +# =========================================================================== +class TestIndexer(unittest.TestCase): + def setUp(self): + self.config = _create_dsa_config() + self.indexer = Indexer(self.config, layer_number=1) + self.b = 2 + self.s = 16 + + def _prepare_indexer_bf16(self): + """Convert wq_b/wk to bf16 for rotate_activation, keep weights_proj fp32.""" + self.indexer.wq_b = self.indexer.wq_b.to(dtype="bfloat16") + 
self.indexer.wk = self.indexer.wk.to(dtype="bfloat16") + self.indexer.k_norm = self.indexer.k_norm.to(dtype="bfloat16") + # weights_proj stays fp32 (code does hidden.cast("float32") before calling it) + + def test_forward_before_topk_shapes(self): + self._prepare_indexer_bf16() + hidden = paddle.randn([self.b, self.s, self.config.hidden_size]).cast( + "bfloat16" + ) + q_latent = paddle.randn([self.b, self.s, self.config.q_lora_rank]).cast( + "bfloat16" + ) + + q, k, weights = self.indexer.forward_before_topk( + hidden, q_latent, freqs=None, mscale=1.0 + ) + self.assertEqual( + list(q.shape), + [ + self.b, + self.s, + self.config.dsa_index_n_heads, + self.config.dsa_index_head_dim, + ], + ) + self.assertEqual( + list(k.shape), + [self.b, self.s, self.config.dsa_index_head_dim], + ) + self.assertEqual( + list(weights.shape), + [self.b, self.s, self.config.dsa_index_n_heads], + ) + + def test_compute_index_scores_shapes(self): + q = paddle.randn( + [ + self.b, + self.s, + self.config.dsa_index_n_heads, + self.config.dsa_index_head_dim, + ] + ) + k = paddle.randn([self.b, self.s, self.config.dsa_index_head_dim]) + weights = paddle.randn([self.b, self.s, self.config.dsa_index_n_heads]) + index_scores, topk_indices = self.indexer.compute_index_scores( + q, k, weights, mask=None + ) + self.assertEqual(list(index_scores.shape), [self.b, self.s, self.s]) + self.assertEqual( + list(topk_indices.shape), + [self.b, self.s, self.config.dsa_index_topk], + ) + + def test_topk_in_range(self): + q = paddle.randn( + [ + self.b, + self.s, + self.config.dsa_index_n_heads, + self.config.dsa_index_head_dim, + ] + ) + k = paddle.randn([self.b, self.s, self.config.dsa_index_head_dim]) + weights = paddle.randn([self.b, self.s, self.config.dsa_index_n_heads]) + _, topk_indices = self.indexer.compute_index_scores( + q, k, weights, mask=None + ) + self.assertTrue((topk_indices >= 0).all().item()) + self.assertTrue((topk_indices < self.s).all().item()) + + def test_backward(self): + """Indexer 
parameters receive gradients.""" + self._prepare_indexer_bf16() + hidden = paddle.randn([self.b, self.s, self.config.hidden_size]).cast( + "bfloat16" + ) + q_latent = paddle.randn([self.b, self.s, self.config.q_lora_rank]).cast( + "bfloat16" + ) + + q, k, weights = self.indexer.forward_before_topk( + hidden, q_latent, freqs=None, mscale=1.0 + ) + # rotate_activation requires bf16, so skip it in this unit test + # and just use the raw outputs for gradient checking. + loss = q.cast("float32").sum() + k.cast("float32").sum() + weights.sum() + loss.backward() + + for name, param in self.indexer.named_parameters(): + self.assertIsNotNone( + param.grad, f"Parameter {name} has no gradient" + ) + + +# =========================================================================== +# Layer 3: Loss tests +# =========================================================================== +class TestComputeDSAIndexerLoss(unittest.TestCase): + def setUp(self): + self.sq, self.sk = 8, 8 + self.b, self.np, self.hn = 2, 4, 32 + self.topk = 4 + self.softmax_scale = self.hn**-0.5 + self.loss_coeff = 1.0 + + def _make_inputs(self, sparse=False): + index_scores = paddle.randn([self.b, self.sq, self.sk], dtype="float32") + if sparse: + topk_indices = _make_causal_topk_indices( + self.b, self.sq, self.sk, self.topk + ) + else: + topk_indices = paddle.randint( + 0, self.sk, [self.b, self.sq, self.topk] + ).cast("int64") + query = paddle.randn( + [self.sq, self.b, self.np, self.hn], dtype="float32" + ) + key = paddle.randn([self.sk, self.b, self.np, self.hn], dtype="float32") + return index_scores, topk_indices, query, key + + def test_loss_is_scalar(self): + index_scores, topk_indices, query, key = self._make_inputs() + loss = _compute_dsa_indexer_loss( + index_scores, + topk_indices, + query, + key, + self.softmax_scale, + self.loss_coeff, + False, + None, + ) + self.assertEqual(loss.shape, []) + + def test_loss_with_sparse(self): + index_scores, topk_indices, query, key = 
self._make_inputs(sparse=True) + loss = _compute_dsa_indexer_loss( + index_scores, + topk_indices, + query, + key, + self.softmax_scale, + self.loss_coeff, + True, + None, + ) + self.assertEqual(loss.shape, []) + self.assertTrue(paddle.isfinite(loss).item()) + + def test_loss_coeff_scaling(self): + index_scores, topk_indices, query, key = self._make_inputs() + loss1 = _compute_dsa_indexer_loss( + index_scores, + topk_indices, + query, + key, + self.softmax_scale, + 1.0, + False, + None, + ) + loss2 = _compute_dsa_indexer_loss( + index_scores, + topk_indices, + query, + key, + self.softmax_scale, + 2.0, + False, + None, + ) + self.assertTrue( + paddle.allclose(loss2, loss1 * 2.0, atol=1e-4), + f"loss2={loss2.item():.6f} != 2*loss1={2 * loss1.item():.6f}", + ) + + +class TestFusedDSAIndexerLoss(unittest.TestCase): + def setUp(self): + self.sq, self.sk = 8, 8 + self.b = 2 + self.h, self.d = 4, 32 # indexer heads/dim + self.np, self.hn = 4, 64 # MLA heads/dim + self.topk = 4 + self.softmax_scale = self.hn**-0.5 + + def _make_inputs(self, with_mask=False): + q = paddle.randn([self.sq, self.b, self.h, self.d], dtype="float32") + q.stop_gradient = False + weights = paddle.randn([self.sq, self.b, self.h], dtype="float32") + weights.stop_gradient = False + k = paddle.randn([self.sk, self.b, self.d], dtype="float32") + k.stop_gradient = False + query = paddle.randn( + [self.sq, self.b, self.np, self.hn], dtype="float32" + ) + key = paddle.randn([self.sk, self.b, self.np, self.hn], dtype="float32") + if with_mask: + causal = paddle.triu( + paddle.full([self.sq, self.sk], float("-inf"), dtype="float32"), + diagonal=1, + ) + mask = causal.unsqueeze(0).unsqueeze(0) # [1, 1, sq, sk] + else: + mask = None + return q, weights, k, query, key, mask + + def test_forward_returns_scalar(self): + q, weights, k, query, key, mask = self._make_inputs(with_mask=True) + loss = FusedDSAIndexerLoss.apply( + q, + weights, + k, + query, + key, + self.softmax_scale, + self.topk, + 1.0, + mask, + 
False, + None, + ) + self.assertEqual(loss.shape, []) + self.assertTrue(paddle.isfinite(loss).item()) + + def test_topk_indices_stored(self): + FusedDSAIndexerLoss._last_topk_indices = None + q, weights, k, query, key, _ = self._make_inputs() + loss = FusedDSAIndexerLoss.apply( + q, + weights, + k, + query, + key, + self.softmax_scale, + self.topk, + 1.0, + None, + False, + None, + ) + self.assertIsNotNone(FusedDSAIndexerLoss._last_topk_indices) + self.assertEqual( + list(FusedDSAIndexerLoss._last_topk_indices.shape), + [self.b, self.sq, self.topk], + ) + + def test_backward_gradients(self): + # Pass a mask tensor so PyLayer sees 6 tensor inputs (q, weights, k, + # query, key, mask) matching the 6 return values in backward. + q, weights, k, query, key, mask = self._make_inputs(with_mask=True) + loss = FusedDSAIndexerLoss.apply( + q, + weights, + k, + query, + key, + self.softmax_scale, + self.topk, + 1.0, + mask, + False, + None, + ) + loss.backward() + + self.assertIsNotNone(q.grad) + self.assertIsNotNone(weights.grad) + self.assertIsNotNone(k.grad) + self.assertEqual(list(q.grad.shape), [self.sq, self.b, self.h, self.d]) + self.assertEqual(list(weights.grad.shape), [self.sq, self.b, self.h]) + self.assertEqual(list(k.grad.shape), [self.sk, self.b, self.d]) + self.assertTrue(paddle.isfinite(q.grad).all().item()) + self.assertTrue(paddle.isfinite(weights.grad).all().item()) + self.assertTrue(paddle.isfinite(k.grad).all().item()) + + +class TestDSAIndexerLossAutoScaler(unittest.TestCase): + def _make_non_leaf_output(self, shape): + """Create a non-leaf tensor (required by PyLayer inplace check).""" + x = paddle.randn(shape) + x.stop_gradient = False + return x + 0 # Adding 0 makes it non-leaf + + def test_forward_passthrough(self): + output = self._make_non_leaf_output([2, 8, 64]) + indexer_loss = self._make_non_leaf_output([]) + result = DSAIndexerLossAutoScaler.apply(output, indexer_loss) + self.assertEqual(list(result.shape), [2, 8, 64]) + + def 
# NOTE(review): this chunk is the tail of a larger test module delivered as a
# diff hunk. The first two methods below continue a DSAIndexerLossAutoScaler
# test class whose header (and its `_make_non_leaf_output` helper) precedes
# this chunk; they are reproduced here as a direct continuation.
def test_backward_grad_output(self):
    """Backward through the autoscaler must not raise for non-leaf outputs."""
    output = self._make_non_leaf_output([2, 8, 64])
    indexer_loss = self._make_non_leaf_output([])

    result = DSAIndexerLossAutoScaler.apply(output, indexer_loss)
    result.sum().backward()
    # output is non-leaf (x + 0), so its grad may not be retained,
    # but the computation should not error out.
    self.assertTrue(True)  # just verify no error

def test_loss_scale(self):
    """Backward with an explicit loss scale set must not raise."""
    DSAIndexerLossAutoScaler.set_loss_scale(
        paddle.to_tensor(2.0, dtype="float32")
    )
    output = self._make_non_leaf_output([2, 4])
    indexer_loss = paddle.to_tensor(1.0, dtype="float32")
    indexer_loss.stop_gradient = False
    indexer_loss = indexer_loss * 1.0  # make non-leaf

    result = DSAIndexerLossAutoScaler.apply(output, indexer_loss)
    result.sum().backward()
    self.assertTrue(True)  # verify no errors
    # Reset the class-level scale so later tests see a clean state.
    DSAIndexerLossAutoScaler._main_loss_backward_scale = None


# Short alias: the logging-helper tests below reference this class constantly.
_HELPER = DSAIndexerLossLoggingHelper


def _causal_neg_inf_mask(seq_len):
    """Return a [1, 1, seq_len, seq_len] additive causal mask (-inf above diag)."""
    tri = paddle.triu(
        paddle.full([seq_len, seq_len], float("-inf"), dtype="float32"),
        diagonal=1,
    )
    return tri.unsqueeze(0).unsqueeze(0)


# ===========================================================================
# Layer 4: MLASelfAttentionWithDSA integration tests
# ===========================================================================
class TestMLASelfAttentionWithDSA(unittest.TestCase):
    def setUp(self):
        self.config = _create_dsa_config()
        self.micro_batch_size = 2
        self.sequence_length = 32

    def _build_model(self, config=None):
        cfg = config if config is not None else self.config
        model = MLASelfAttentionWithDSA(
            cfg,
            _create_sublayers_spec(),
            layer_number=1,
            attn_mask_type=AttnMaskType.causal,
        )
        # Convert model to bf16 because rotate_activation requires bf16 input.
        # But weights_proj does hidden.cast("float32") internally and expects
        # fp32 weights, so convert it back to fp32 after the global bf16 cast.
        model = model.to(dtype="bfloat16")
        model.indexer.weights_proj = model.indexer.weights_proj.to(
            dtype="float32"
        )
        return model

    def _make_hidden(self, dtype="bfloat16"):
        shape = [
            self.micro_batch_size,
            self.sequence_length,
            self.config.hidden_size,
        ]
        return paddle.randn(shape).cast(dtype)

    def _assert_output_shape(self, output):
        self.assertEqual(output.shape[0], self.micro_batch_size)
        self.assertEqual(output.shape[1], self.sequence_length)
        self.assertEqual(output.shape[2], self.config.hidden_size)

    def test_forward_shape(self):
        model = self._build_model()
        output, bias = model(self._make_hidden(), attention_mask=None)

        self._assert_output_shape(output)
        self.assertEqual(bias.shape[0], self.config.hidden_size)

    def test_forward_with_attention_mask(self):
        model = self._build_model()
        mask = _causal_neg_inf_mask(self.sequence_length).expand(
            [
                self.micro_batch_size,
                1,
                self.sequence_length,
                self.sequence_length,
            ]
        )
        output, bias = model(self._make_hidden(), attention_mask=mask)

        self._assert_output_shape(output)

    def test_forward_training_with_loss(self):
        model = self._build_model()
        model.train()
        output, bias = model(self._make_hidden(), attention_mask=None)

        self._assert_output_shape(output)

    def test_forward_eval_mode(self):
        model = self._build_model(_create_dsa_config(indexer_loss_coeff=None))
        model.eval()
        output, bias = model(self._make_hidden(), attention_mask=None)

        self.assertEqual(output.shape[0], self.micro_batch_size)
        self.assertEqual(output.shape[1], self.sequence_length)

    def test_backward_gradients(self):
        model = self._build_model()
        model.train()
        hidden = self._make_hidden()
        hidden.stop_gradient = False
        output, bias = model(hidden, attention_mask=None)
        output.cast("float32").sum().backward()

        self.assertIsNotNone(hidden.grad)
        for name, param in model.named_parameters():
            if param.stop_gradient:
                continue
            self.assertIsNotNone(
                param.grad, f"Parameter {name} has no gradient"
            )
            self.assertTrue(
                paddle.isfinite(param.grad).all().item(),
                f"Parameter {name} has non-finite gradient",
            )

    def test_indexer_params_have_grad(self):
        model = self._build_model()
        model.train()
        hidden = self._make_hidden()
        hidden.stop_gradient = False
        output, bias = model(hidden, attention_mask=None)
        output.cast("float32").sum().backward()

        indexer_param_names = (
            "indexer.wq_b",
            "indexer.wk",
            "indexer.weights_proj",
        )
        for name, param in model.named_parameters():
            if any(iname in name for iname in indexer_param_names):
                self.assertIsNotNone(
                    param.grad,
                    f"Indexer parameter {name} has no gradient",
                )


# ===========================================================================
# Layer 5: DSAIndexerLossLoggingHelper tests
# ===========================================================================
class TestDSAIndexerLossLoggingHelperSaveLoss(unittest.TestCase):
    """Tests for DSAIndexerLossLoggingHelper.save_loss_to_tracker."""

    def setUp(self):
        _HELPER.tracker = {}

    def tearDown(self):
        _HELPER.tracker = {}

    def test_save_loss_initializes_values(self):
        """First call should create the 'values' tensor with correct size."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(0.5, dtype="float32"), layer_number=1, num_layers=4
        )
        self.assertIn("values", _HELPER.tracker)
        self.assertEqual(list(_HELPER.tracker["values"].shape), [4])

    def test_save_loss_accumulates(self):
        """Multiple saves to the same layer should accumulate."""
        for value in (0.5, 0.3):
            _HELPER.save_loss_to_tracker(
                paddle.to_tensor(value, dtype="float32"),
                layer_number=1,
                num_layers=4,
            )
        self.assertTrue(
            paddle.allclose(
                _HELPER.tracker["values"][0],
                paddle.to_tensor(0.8, dtype="float32"),
                atol=1e-5,
            )
        )

    def test_save_loss_different_layers(self):
        """Saving to different layers puts values in correct positions."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(1.0, dtype="float32"), layer_number=1, num_layers=3
        )
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(2.0, dtype="float32"), layer_number=3, num_layers=3
        )
        values = _HELPER.tracker["values"]
        self.assertAlmostEqual(values[0].item(), 1.0, places=5)
        self.assertAlmostEqual(values[1].item(), 0.0, places=5)
        self.assertAlmostEqual(values[2].item(), 2.0, places=5)

    def test_save_loss_none_layer_number_noop(self):
        """layer_number=None should be a no-op."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(1.0, dtype="float32"),
            layer_number=None,
            num_layers=4,
        )
        self.assertNotIn("values", _HELPER.tracker)

    def test_save_loss_stores_groups(self):
        """reduce_group and avg_group should be stored in the tracker."""
        mock_reduce = MagicMock()
        mock_avg = MagicMock()
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(0.1, dtype="float32"),
            layer_number=1,
            num_layers=2,
            reduce_group=mock_reduce,
            avg_group=mock_avg,
        )
        self.assertIs(_HELPER.tracker["reduce_group"], mock_reduce)
        self.assertIs(_HELPER.tracker["avg_group"], mock_avg)


class TestDSAIndexerLossLoggingHelperClean(unittest.TestCase):
    """Tests for DSAIndexerLossLoggingHelper.clean_loss_in_tracker."""

    def setUp(self):
        _HELPER.tracker = {}

    def tearDown(self):
        _HELPER.tracker = {}

    def test_clean_zeros_values(self):
        """Clean should zero out the values tensor."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(1.0, dtype="float32"), layer_number=1, num_layers=3
        )
        _HELPER.clean_loss_in_tracker()
        self.assertTrue(
            paddle.allclose(_HELPER.tracker["values"], paddle.zeros([3]))
        )

    def test_clean_resets_groups(self):
        """Clean should set reduce_group and avg_group to None."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(1.0, dtype="float32"),
            layer_number=1,
            num_layers=2,
            reduce_group=MagicMock(),
            avg_group=MagicMock(),
        )
        _HELPER.clean_loss_in_tracker()
        self.assertIsNone(_HELPER.tracker["reduce_group"])
        self.assertIsNone(_HELPER.tracker["avg_group"])

    def test_clean_empty_tracker_noop(self):
        """Clean on empty tracker should not raise."""
        _HELPER.clean_loss_in_tracker()
        self.assertIsNone(_HELPER.tracker.get("reduce_group"))


class TestDSAIndexerLossLoggingHelperReduce(unittest.TestCase):
    """Tests for DSAIndexerLossLoggingHelper.reduce_loss_in_tracker."""

    def setUp(self):
        _HELPER.tracker = {}

    def tearDown(self):
        _HELPER.tracker = {}

    def test_reduce_empty_tracker_noop(self):
        """Reduce with no 'values' should be a no-op."""
        _HELPER.reduce_loss_in_tracker()
        # Should not raise

    @patch("paddlefleet.transformer.dsa_attention.parallel_state")
    def test_reduce_no_distributed_groups(self, mock_ps):
        """Reduce with no distributed groups should keep values unchanged."""
        mock_ps.get_pipeline_model_parallel_group.return_value = None
        mock_ps.get_data_parallel_group.return_value = None

        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(2.0, dtype="float32"), layer_number=1, num_layers=2
        )
        before = _HELPER.tracker["values"].clone()
        _HELPER.reduce_loss_in_tracker()
        self.assertTrue(paddle.allclose(_HELPER.tracker["values"], before))

    @patch("paddlefleet.transformer.dsa_attention.parallel_state")
    def test_reduce_with_pp_group(self, mock_ps):
        """Reduce with PP group should call all_reduce."""
        pp_group = MagicMock()
        pp_group.nranks = 2
        mock_ps.get_pipeline_model_parallel_group.return_value = pp_group
        mock_ps.get_data_parallel_group.return_value = None

        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(1.0, dtype="float32"), layer_number=1, num_layers=2
        )
        with patch("paddle.distributed.all_reduce") as mock_all_reduce:
            _HELPER.reduce_loss_in_tracker()
            mock_all_reduce.assert_called_once()

    @patch("paddlefleet.transformer.dsa_attention.parallel_state")
    def test_reduce_with_dp_group(self, mock_ps):
        """Reduce with DP group should call all_reduce and divide by nranks."""
        mock_ps.get_pipeline_model_parallel_group.return_value = None
        dp_group = MagicMock()
        dp_group.nranks = 4
        mock_ps.get_data_parallel_group.return_value = dp_group

        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(4.0, dtype="float32"), layer_number=1, num_layers=2
        )
        with patch("paddle.distributed.all_reduce"):
            _HELPER.reduce_loss_in_tracker()
        # After DP reduce, values should be divided by nranks
        # (all_reduce is mocked, so actual value won't change, but the
        # division path is exercised)

    @patch("paddlefleet.transformer.dsa_attention.parallel_state")
    def test_reduce_with_reduce_group(self, mock_ps):
        """Reduce with TP reduce_group should call all_reduce."""
        mock_ps.get_pipeline_model_parallel_group.return_value = None
        mock_ps.get_data_parallel_group.return_value = None

        reduce_group = MagicMock()
        reduce_group.nranks = 2
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(1.0, dtype="float32"),
            layer_number=1,
            num_layers=2,
            reduce_group=reduce_group,
        )
        with patch("paddle.distributed.all_reduce") as mock_all_reduce:
            _HELPER.reduce_loss_in_tracker()
            mock_all_reduce.assert_called_once()

    @patch("paddlefleet.transformer.dsa_attention.parallel_state")
    def test_reduce_with_avg_group(self, mock_ps):
        """Reduce with avg_group should call all_reduce and divide by nranks."""
        mock_ps.get_pipeline_model_parallel_group.return_value = None
        mock_ps.get_data_parallel_group.return_value = None

        avg_group = MagicMock()
        avg_group.nranks = 3
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(3.0, dtype="float32"),
            layer_number=1,
            num_layers=2,
            avg_group=avg_group,
        )
        with patch("paddle.distributed.all_reduce") as mock_all_reduce:
            _HELPER.reduce_loss_in_tracker()
            mock_all_reduce.assert_called_once()

    @patch("paddlefleet.transformer.dsa_attention.parallel_state")
    def test_reduce_pp_group_single_rank_skipped(self, mock_ps):
        """PP group with nranks=1 should not trigger all_reduce."""
        pp_group = MagicMock()
        pp_group.nranks = 1
        mock_ps.get_pipeline_model_parallel_group.return_value = pp_group
        mock_ps.get_data_parallel_group.return_value = None

        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(1.0, dtype="float32"), layer_number=1, num_layers=2
        )
        with patch("paddle.distributed.all_reduce") as mock_all_reduce:
            _HELPER.reduce_loss_in_tracker()
            mock_all_reduce.assert_not_called()

    @patch("paddlefleet.transformer.dsa_attention.parallel_state")
    def test_reduce_dp_group_single_rank_skipped(self, mock_ps):
        """DP group with nranks=1 should not trigger all_reduce."""
        mock_ps.get_pipeline_model_parallel_group.return_value = None
        dp_group = MagicMock()
        dp_group.nranks = 1
        mock_ps.get_data_parallel_group.return_value = dp_group

        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(1.0, dtype="float32"), layer_number=1, num_layers=2
        )
        with patch("paddle.distributed.all_reduce") as mock_all_reduce:
            _HELPER.reduce_loss_in_tracker()
            mock_all_reduce.assert_not_called()


class TestDSAIndexerLossLoggingHelperTrackMetrics(unittest.TestCase):
    """Tests for DSAIndexerLossLoggingHelper.track_indexer_metrics."""

    def setUp(self):
        _HELPER.tracker = {}

    def tearDown(self):
        _HELPER.tracker = {}

    @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker")
    def test_track_metrics_empty_tracker_noop(self, mock_reduce):
        """With no values, track_indexer_metrics should be a no-op after reduce."""
        _HELPER.track_indexer_metrics(loss_scale=1.0, iteration=10)
        mock_reduce.assert_called_once()

    @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker")
    def test_track_metrics_logs_loss(self, mock_reduce):
        """track_indexer_metrics should log the averaged indexer loss."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(2.0, dtype="float32"), layer_number=1, num_layers=2
        )
        with self.assertLogs(
            "paddlefleet.transformer.dsa_attention", level="INFO"
        ) as cm:
            _HELPER.track_indexer_metrics(loss_scale=1.0, iteration=42)
        log_output = "\n".join(cm.output)
        self.assertIn("42", log_output)
        self.assertIn("indexer loss", log_output)

    @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker")
    def test_track_metrics_updates_total_loss_dict(self, mock_reduce):
        """track_indexer_metrics should add 'indexer loss' to total_loss_dict."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(4.0, dtype="float32"), layer_number=1, num_layers=2
        )
        total_loss_dict = {}
        _HELPER.track_indexer_metrics(
            loss_scale=0.5, iteration=1, total_loss_dict=total_loss_dict
        )
        self.assertIn("indexer loss", total_loss_dict)
        # loss=4.0 at layer 1 only, num_layers=2:
        # avg = (4.0 * 0.5) / 2 = 1.0
        self.assertAlmostEqual(
            total_loss_dict["indexer loss"].item(), 1.0, places=4
        )

    @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker")
    def test_track_metrics_accumulates_total_loss_dict(self, mock_reduce):
        """Calling track_indexer_metrics twice should accumulate in total_loss_dict."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(2.0, dtype="float32"), layer_number=1, num_layers=1
        )
        total_loss_dict = {}
        _HELPER.track_indexer_metrics(
            loss_scale=1.0, iteration=1, total_loss_dict=total_loss_dict
        )
        first_value = total_loss_dict["indexer loss"].item()

        # Save and track again
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(3.0, dtype="float32"), layer_number=1, num_layers=1
        )
        _HELPER.track_indexer_metrics(
            loss_scale=1.0, iteration=2, total_loss_dict=total_loss_dict
        )
        self.assertAlmostEqual(
            total_loss_dict["indexer loss"].item(),
            first_value + 3.0,
            places=4,
        )

    @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker")
    def test_track_metrics_with_writer(self, mock_reduce):
        """track_indexer_metrics should call writer.add_scalar when provided."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(1.0, dtype="float32"), layer_number=1, num_layers=1
        )
        mock_writer = MagicMock()
        _HELPER.track_indexer_metrics(
            loss_scale=1.0, iteration=100, writer=mock_writer
        )
        mock_writer.add_scalar.assert_called_once()
        args = mock_writer.add_scalar.call_args
        self.assertEqual(args[0][0], "indexer loss")
        self.assertEqual(args[0][2], 100)

    @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker")
    def test_track_metrics_cleans_tracker(self, mock_reduce):
        """track_indexer_metrics should clean the tracker after logging."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(1.0, dtype="float32"), layer_number=1, num_layers=2
        )
        _HELPER.track_indexer_metrics(loss_scale=1.0, iteration=1)
        # After track_indexer_metrics, values should be zeroed
        self.assertTrue(
            paddle.allclose(_HELPER.tracker["values"], paddle.zeros([2]))
        )

    @patch.object(DSAIndexerLossLoggingHelper, "reduce_loss_in_tracker")
    def test_track_metrics_loss_scale_applied(self, mock_reduce):
        """loss_scale should be multiplied into the loss values."""
        _HELPER.save_loss_to_tracker(
            paddle.to_tensor(6.0, dtype="float32"), layer_number=1, num_layers=1
        )
        total_loss_dict = {}
        _HELPER.track_indexer_metrics(
            loss_scale=0.25, iteration=1, total_loss_dict=total_loss_dict
        )
        # 6.0 * 0.25 / 1 = 1.5
        self.assertAlmostEqual(
            total_loss_dict["indexer loss"].item(), 1.5, places=4
        )


# ===========================================================================
# Layer 6: Additional coverage for Indexer and helper functions
# ===========================================================================
class TestIndexerForward(unittest.TestCase):
    """Test Indexer.forward (the combined forward_before_topk + compute_index_scores)."""

    def setUp(self):
        self.config = _create_dsa_config()
        self.indexer = Indexer(self.config, layer_number=1)
        self.b = 2
        self.s = 16

    def _prepare_indexer_bf16(self):
        # rotate_activation requires bf16; weights_proj stays fp32 (it casts
        # its input internally), so only these sublayers are converted.
        self.indexer.wq_b = self.indexer.wq_b.to(dtype="bfloat16")
        self.indexer.wk = self.indexer.wk.to(dtype="bfloat16")
        self.indexer.k_norm = self.indexer.k_norm.to(dtype="bfloat16")

    def _make_inputs(self):
        hidden = paddle.randn([self.b, self.s, self.config.hidden_size]).cast(
            "bfloat16"
        )
        q_latent = paddle.randn(
            [self.b, self.s, self.config.q_lora_rank]
        ).cast("bfloat16")
        return hidden, q_latent

    def test_forward_returns_scores_and_indices(self):
        """Indexer.forward should return (index_scores, topk_indices)."""
        self._prepare_indexer_bf16()
        hidden, q_latent = self._make_inputs()
        mask = _causal_neg_inf_mask(self.s)

        index_scores, topk_indices = self.indexer.forward(
            hidden, q_latent, freqs=None, attention_mask=mask, mscale=1.0
        )
        self.assertEqual(list(index_scores.shape), [self.b, self.s, self.s])
        self.assertEqual(
            list(topk_indices.shape),
            [self.b, self.s, self.config.dsa_index_topk],
        )

    def test_forward_no_mask(self):
        """Indexer.forward should work with mask=None."""
        self._prepare_indexer_bf16()
        hidden, q_latent = self._make_inputs()
        index_scores, topk_indices = self.indexer.forward(
            hidden, q_latent, freqs=None, attention_mask=None, mscale=1.0
        )
        self.assertEqual(list(index_scores.shape), [self.b, self.s, self.s])


class TestIndexerComputeScoresWithMask(unittest.TestCase):
    """Test mask handling in Indexer.compute_index_scores."""

    def setUp(self):
        self.config = _create_dsa_config(index_topk=4)
        self.indexer = Indexer(self.config, layer_number=1)
        self.b = 2
        self.s = 8

    def test_mask_zeros_future_positions(self):
        """With causal mask, topk indices should not exceed current position."""
        q = paddle.randn(
            [
                self.b,
                self.s,
                self.config.dsa_index_n_heads,
                self.config.dsa_index_head_dim,
            ]
        )
        k = paddle.randn([self.b, self.s, self.config.dsa_index_head_dim])
        weights = paddle.abs(
            paddle.randn([self.b, self.s, self.config.dsa_index_n_heads])
        )
        mask = _causal_neg_inf_mask(self.s)

        index_scores, topk_indices = self.indexer.compute_index_scores(
            q, k, weights, mask=mask
        )
        # For a later position (e.g. position 5), all topk indices must be <= 5
        pos = 5
        self.assertTrue(
            (topk_indices[:, pos, :] <= pos).all().item(),
            f"topk indices at position {pos} exceed causal bound",
        )


class TestComputeIndexScoresFusedAdditional(unittest.TestCase):
    """Additional tests for _compute_index_scores_fused."""

    def test_matches_unfused(self):
        """Fused scores should match unfused Indexer.compute_index_scores logic."""
        sq, b, h, d = 4, 2, 2, 16
        q = paddle.randn([sq, b, h, d], dtype="float32")
        weights = paddle.randn([sq, b, h], dtype="float32")
        k = paddle.randn([sq, b, d], dtype="float32")

        fused_scores = _compute_index_scores_fused(q, weights, k)  # [b, sq, sk]

        # Manual unfused computation
        raw = paddle.einsum("sbhd,tbd->sbht", q, k)
        weighted = paddle.nn.functional.relu(raw) * weights.unsqueeze(-1)
        unfused_scores = weighted.sum(axis=2).transpose([1, 0, 2])  # [b, sq, sk]

        self.assertTrue(
            paddle.allclose(fused_scores, unfused_scores, atol=1e-5),
            "Fused and unfused scores do not match",
        )


class TestComputeDSAIndexerLossWithMask(unittest.TestCase):
    """Additional edge cases for _compute_dsa_indexer_loss."""

    def test_loss_is_non_negative(self):
        """KL divergence should be (approximately) non-negative."""
        sq, sk, b, nh, hn = 4, 4, 2, 2, 16
        topk = 2
        index_scores = paddle.randn([b, sq, sk], dtype="float32")
        topk_indices = paddle.randint(0, sk, [b, sq, topk]).cast("int64")
        query = paddle.randn([sq, b, nh, hn], dtype="float32")
        key = paddle.randn([sk, b, nh, hn], dtype="float32")

        loss = _compute_dsa_indexer_loss(
            index_scores,
            topk_indices,
            query,
            key,
            softmax_scale=hn**-0.5,
            loss_coeff=1.0,
            sparse_loss=False,
            tp_group=None,
        )
        # KL divergence can be slightly negative due to numerical noise,
        # but should be very close to non-negative
        self.assertTrue(
            loss.item() > -0.1, f"Loss is too negative: {loss.item()}"
        )
        self.assertTrue(paddle.isfinite(loss).item())


class TestBwdFusedIndexerLoss(unittest.TestCase):
    """Tests for _bwd_fused_indexer_loss manual backward."""

    def _make_inputs(self):
        sq, b, h, d = 4, 2, 2, 16
        nh, hn = 4, 32
        q = paddle.randn([sq, b, h, d], dtype="float32")
        weights = paddle.randn([sq, b, h], dtype="float32")
        k = paddle.randn([sq, b, d], dtype="float32")
        query = paddle.randn([sq, b, nh, hn], dtype="float32")
        key = paddle.randn([sq, b, nh, hn], dtype="float32")
        return sq, b, h, d, hn, q, weights, k, query, key

    def test_backward_shapes(self):
        """Manual backward should return gradients with correct shapes."""
        sq, b, h, d, hn, q, weights, k, query, key = self._make_inputs()
        topk_indices = paddle.randint(0, sq, [b, sq, 2]).cast("int64")

        grad_q, grad_weights, grad_k = _bwd_fused_indexer_loss(
            q,
            weights,
            k,
            query,
            key,
            topk_indices,
            softmax_scale=hn**-0.5,
            loss_coeff=1.0,
            sparse_loss=False,
            grad_loss=paddle.to_tensor(1.0, dtype="float32"),
            tp_group=None,
        )
        self.assertEqual(list(grad_q.shape), [sq, b, h, d])
        self.assertEqual(list(grad_weights.shape), [sq, b, h])
        self.assertEqual(list(grad_k.shape), [sq, b, d])

    def _check_finite_grads(self, sparse_loss):
        sq, b, h, d, hn, q, weights, k, query, key = self._make_inputs()
        topk_indices = _make_causal_topk_indices(b, sq, sq, 2)

        grad_q, grad_weights, grad_k = _bwd_fused_indexer_loss(
            q,
            weights,
            k,
            query,
            key,
            topk_indices,
            softmax_scale=hn**-0.5,
            loss_coeff=1.0,
            sparse_loss=sparse_loss,
            grad_loss=paddle.to_tensor(1.0, dtype="float32"),
            tp_group=None,
        )
        self.assertTrue(paddle.isfinite(grad_q).all().item())
        self.assertTrue(paddle.isfinite(grad_weights).all().item())
        self.assertTrue(paddle.isfinite(grad_k).all().item())

    def test_backward_finite(self):
        """All gradients from manual backward should be finite."""
        self._check_finite_grads(sparse_loss=False)

    def test_backward_with_sparse_loss(self):
        """Manual backward should work with sparse_loss=True."""
        self._check_finite_grads(sparse_loss=True)


class TestFusedDSAIndexerLossNoMask(unittest.TestCase):
    """Test FusedDSAIndexerLoss with no mask (mask=None path)."""

    def setUp(self):
        self.sq, self.sk = 8, 8
        self.b = 2
        self.h, self.d = 4, 32
        self.np, self.hn = 4, 64
        self.topk = 4
        self.softmax_scale = self.hn**-0.5

    def test_forward_no_mask(self):
        q = paddle.randn([self.sq, self.b, self.h, self.d], dtype="float32")
        q.stop_gradient = False
        weights = paddle.randn([self.sq, self.b, self.h], dtype="float32")
        weights.stop_gradient = False
        k = paddle.randn([self.sk, self.b, self.d], dtype="float32")
        k.stop_gradient = False
        query = paddle.randn(
            [self.sq, self.b, self.np, self.hn], dtype="float32"
        )
        key = paddle.randn([self.sk, self.b, self.np, self.hn], dtype="float32")

        loss = FusedDSAIndexerLoss.apply(
            q,
            weights,
            k,
            query,
            key,
            self.softmax_scale,
            self.topk,
            1.0,
            None,  # no mask
            False,
            None,
        )
        self.assertEqual(loss.shape, [])
        self.assertTrue(paddle.isfinite(loss).item())


class TestFusedDSAIndexerLossSparseLoss(unittest.TestCase):
    """Test FusedDSAIndexerLoss with sparse_loss=True."""

    def setUp(self):
        self.sq, self.sk = 8, 8
        self.b = 2
        self.h, self.d = 4, 32
        self.np, self.hn = 4, 64
        self.topk = 4
        self.softmax_scale = self.hn**-0.5

    def _apply_loss(self):
        q = paddle.randn([self.sq, self.b, self.h, self.d], dtype="float32")
        q.stop_gradient = False
        weights = paddle.randn([self.sq, self.b, self.h], dtype="float32")
        weights.stop_gradient = False
        k = paddle.randn([self.sk, self.b, self.d], dtype="float32")
        k.stop_gradient = False
        query = paddle.randn(
            [self.sq, self.b, self.np, self.hn], dtype="float32"
        )
        key = paddle.randn([self.sk, self.b, self.np, self.hn], dtype="float32")
        mask = _causal_neg_inf_mask(self.sq)

        loss = FusedDSAIndexerLoss.apply(
            q,
            weights,
            k,
            query,
            key,
            self.softmax_scale,
            self.topk,
            1.0,
            mask,
            True,  # sparse_loss=True
            None,
        )
        return loss, q, weights, k

    def test_forward_sparse_loss(self):
        loss, q, weights, k = self._apply_loss()
        self.assertEqual(loss.shape, [])
        self.assertTrue(paddle.isfinite(loss).item())

    def test_backward_sparse_loss(self):
        loss, q, weights, k = self._apply_loss()
        loss.backward()
        self.assertIsNotNone(q.grad)
        self.assertIsNotNone(weights.grad)
        self.assertIsNotNone(k.grad)
        self.assertTrue(paddle.isfinite(q.grad).all().item())
        self.assertTrue(paddle.isfinite(weights.grad).all().item())
        self.assertTrue(paddle.isfinite(k.grad).all().item())


class TestDSAIndexerLossAutoScalerAdditional(unittest.TestCase):
    """Additional tests for DSAIndexerLossAutoScaler edge cases."""

    def _make_non_leaf_output(self, shape):
        x = paddle.randn(shape)
        x.stop_gradient = False
        return x + 0

    def test_backward_without_loss_scale(self):
        """When _main_loss_backward_scale is None, backward should use ones."""
        DSAIndexerLossAutoScaler._main_loss_backward_scale = None
        output = self._make_non_leaf_output([2, 4])
        indexer_loss = paddle.to_tensor(1.0, dtype="float32")
        indexer_loss.stop_gradient = False
        indexer_loss = indexer_loss * 1.0

        result = DSAIndexerLossAutoScaler.apply(output, indexer_loss)
        result.sum().backward()
        # Should not error; the None-scale path creates ones
        self.assertTrue(True)

    def test_set_loss_scale_stores_value(self):
        """set_loss_scale should store the scale tensor."""
        DSAIndexerLossAutoScaler.set_loss_scale(
            paddle.to_tensor(3.14, dtype="float32")
        )
        stored = DSAIndexerLossAutoScaler._main_loss_backward_scale
        self.assertIsNotNone(stored)
        self.assertAlmostEqual(stored.item(), 3.14, places=4)
        DSAIndexerLossAutoScaler._main_loss_backward_scale = None

    def test_forward_preserves_value(self):
        """Forward should return output with the same values (passthrough)."""
        x = paddle.randn([3, 5])
        x.stop_gradient = False
        output = x + 0
        indexer_loss = paddle.to_tensor(0.0, dtype="float32")
        indexer_loss.stop_gradient = False
        indexer_loss = indexer_loss + 0

        result = DSAIndexerLossAutoScaler.apply(output, indexer_loss)
        self.assertTrue(paddle.allclose(result, output, atol=1e-7))


class TestMLASelfAttentionWithDSASparseLoss(unittest.TestCase):
    """Integration test for MLASelfAttentionWithDSA with sparse_loss enabled."""

    def setUp(self):
        self.config = _create_dsa_config(indexer_use_sparse_loss=True)
        self.micro_batch_size = 2
        self.sequence_length = 32

    def _build_model(self, config=None):
        cfg = config if config is not None else self.config
        model = MLASelfAttentionWithDSA(
            cfg,
            _create_sublayers_spec(),
            layer_number=1,
            attn_mask_type=AttnMaskType.causal,
        )
        # bf16 everywhere except weights_proj, which expects fp32 weights.
        model = model.to(dtype="bfloat16")
        model.indexer.weights_proj = model.indexer.weights_proj.to(
            dtype="float32"
        )
        return model

    def _make_hidden(self, dtype="bfloat16"):
        shape = [
            self.micro_batch_size,
            self.sequence_length,
            self.config.hidden_size,
        ]
        return paddle.randn(shape).cast(dtype)

    def test_forward_with_sparse_loss(self):
        model = self._build_model()
        model.train()
        output, bias = model(self._make_hidden(), attention_mask=None)
        self.assertEqual(output.shape[0], self.micro_batch_size)
        self.assertEqual(output.shape[1], self.sequence_length)

    def test_backward_with_sparse_loss(self):
        model = self._build_model()
        model.train()
        hidden = self._make_hidden()
        hidden.stop_gradient = False
        output, bias = model(hidden, attention_mask=None)
        output.cast("float32").sum().backward()
        self.assertIsNotNone(hidden.grad)


class TestMLASelfAttentionWithDSAZeroLossCoeff(unittest.TestCase):
    """Test MLASelfAttentionWithDSA with indexer_loss_coeff=0."""

    def setUp(self):
        self.config = _create_dsa_config(indexer_loss_coeff=0.0)
        self.micro_batch_size = 2
        self.sequence_length = 32

    def _build_model(self):
        model = MLASelfAttentionWithDSA(
            self.config,
            _create_sublayers_spec(),
            layer_number=1,
            attn_mask_type=AttnMaskType.causal,
        )
        model = model.to(dtype="bfloat16")
        model.indexer.weights_proj = model.indexer.weights_proj.to(
            dtype="float32"
        )
        return model

    def test_forward_zero_loss_coeff(self):
        """With loss_coeff=0, forward works and loss logging is skipped.

        The implementation guards logging with `if self.dsa_indexer_loss_coeff
        > 0`, so save_loss_to_tracker is never called and the tracker stays
        empty.
        """
        model = self._build_model()
        model.train()
        hidden = paddle.randn(
            [
                self.micro_batch_size,
                self.sequence_length,
                self.config.hidden_size,
            ]
        ).cast("bfloat16")
        DSAIndexerLossLoggingHelper.tracker = {}
        output, bias = model(hidden, attention_mask=None)
        self.assertEqual(output.shape[0], self.micro_batch_size)
        self.assertNotIn("values", DSAIndexerLossLoggingHelper.tracker)
        DSAIndexerLossLoggingHelper.tracker = {}


if __name__ == "__main__":
    unittest.main()