13 changes: 12 additions & 1 deletion src/paddlefleet/models/gpt/gpt_layer_specs.py
@@ -45,6 +45,9 @@
     SelfAttention,
     SelfAttentionSublayersSpec,
 )
+from paddlefleet.transformer.dsa_attention import (
+    MLASelfAttentionWithDSA,
+)
 from paddlefleet.transformer.enums import AttnMaskType
 from paddlefleet.transformer.identity_op import IdentityOp
 from paddlefleet.transformer.mlp import MLP, MLPSublayersSpec
@@ -124,12 +127,20 @@ def get_gpt_layer_local_spec(

     if multi_latent_attention:
         assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
+
+        # Decide attention class: use the DSA variant if index_n_heads is configured
+        use_dsa = (
+            config is not None
+            and getattr(config, "index_n_heads", None) is not None
+        )
+        attn_cls = MLASelfAttentionWithDSA if use_dsa else MLASelfAttention
+
         return LayerSpec(
             layer=transformer_cls,
             sublayers_spec=TransformerLayerSublayersSpec(
                 input_layernorm=layer_norm,
                 self_attn=LayerSpec(
-                    layer=MLASelfAttention,
+                    layer=attn_cls,
                     extra_kwargs={"attn_mask_type": AttnMaskType.causal},
                     sublayers_spec=MLASelfAttentionSublayersSpec(
                         q_proj=backend.column_parallel_linear(),
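As a reviewer aid, here is a minimal, self-contained sketch of the selection rule this hunk adds. Only the `use_dsa` expression mirrors the actual diff; `pick_attention_cls`, the stub classes, and the `SimpleNamespace` configs are hypothetical stand-ins, not the paddlefleet API.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the real paddlefleet attention classes.
class MLASelfAttention: ...
class MLASelfAttentionWithDSA: ...

def pick_attention_cls(config):
    # Mirrors the diff: the DSA variant is chosen only when a config is
    # provided AND it defines a non-None index_n_heads attribute.
    use_dsa = (
        config is not None
        and getattr(config, "index_n_heads", None) is not None
    )
    return MLASelfAttentionWithDSA if use_dsa else MLASelfAttention

# No config, missing attribute, or an explicit None all fall back to plain MLA.
assert pick_attention_cls(None) is MLASelfAttention
assert pick_attention_cls(SimpleNamespace()) is MLASelfAttention
assert pick_attention_cls(SimpleNamespace(index_n_heads=None)) is MLASelfAttention
# Any non-None value enables the DSA variant.
assert pick_attention_cls(SimpleNamespace(index_n_heads=4)) is MLASelfAttentionWithDSA
```

Using `getattr` with a `None` default means configs predating the `index_n_heads` field keep their current behavior, so the change is backward compatible for existing MLA specs.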