Summary
Implement ALiBi (Attention with Linear Biases) as an alternative to RoPE for positional encoding.
ALiBi adds a linear bias to attention scores based on query-key distance, enabling better length extrapolation than absolute or rotary embeddings.
Background
ALiBi was introduced in "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation" (Press et al., 2022).
Key advantages:
- No learned positional embeddings
- Better length extrapolation than RoPE
- Simple implementation (bias added to attention scores)
- Used in BLOOM, MPT, and other models
Formula
attention_scores[i, j] = Q[i] @ K[j]^T - m * |i - j|
Where m is a head-specific slope. For num_heads a power of two it is computed as:
m_h = 2^(-8 * h / num_heads)  # for h in [1, num_heads]
(for other head counts, the paper derives slopes by interpolating this geometric sequence)
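As a concrete check of the slope formula, the slopes can be computed in a few lines (`alibi_slopes` here is an illustrative name, not the final API):

```python
def alibi_slopes(num_heads: int) -> list[float]:
    # m_h = 2^(-8 * h / num_heads) for h in [1, num_heads].
    # Assumes num_heads is a power of two, as in the original paper.
    return [2.0 ** (-8.0 * h / num_heads) for h in range(1, num_heads + 1)]
```

For 8 heads this yields the geometric sequence 2^-1, 2^-2, ..., 2^-8, so later heads receive a weaker distance penalty and can attend further back.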
Proposed Implementation
Native Kernels
native/ops/nn/
└── alibi/
    ├── alibi_bias.inl     # Bias computation
    └── alibi_kernels.cuh  # CUDA kernels
Python API
# Initialize ALiBi slopes (precomputed per model)
alibi_slopes = alibi_init_slopes(num_heads) # -> [num_heads]
# Apply ALiBi bias to attention scores (fused with SDPA)
sdpa_causal_alibi(Q, K, V, alibi_slopes)
# Or standalone bias computation
alibi_bias = alibi_compute_bias(seq_len, num_heads, alibi_slopes)
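For reference, the standalone bias computation above could be sketched in NumPy; this is a sketch of the intended semantics, not the native kernel:

```python
import numpy as np

def alibi_compute_bias(seq_len: int, num_heads: int, slopes) -> np.ndarray:
    # Distance matrix |i - j|, shape [seq_len, seq_len].
    pos = np.arange(seq_len)
    dist = np.abs(pos[:, None] - pos[None, :])
    # bias[h, i, j] = -m_h * |i - j|, shape [num_heads, seq_len, seq_len].
    return -np.asarray(slopes, dtype=np.float64)[:, None, None] * dist
```

The bias is zero on the diagonal (a query attending to its own position) and grows more negative with distance, scaled per head by the precomputed slope.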
Integration with SDPA
Option A: Fused kernel sdpa_causal_alibi()
Option B: Separate bias tensor added before softmax
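Option B can be illustrated end to end with a plain NumPy SDPA. `sdpa_with_alibi` is a hypothetical reference function for clarity, not the proposed fused kernel:

```python
import numpy as np

def sdpa_with_alibi(Q, K, V, bias):
    # Q, K, V: [num_heads, seq_len, head_dim]
    # bias:    [num_heads, seq_len, seq_len], e.g. -m_h * |i - j|
    head_dim = Q.shape[-1]
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim) + bias
    # Causal mask: a query may not attend to future positions.
    seq_len = Q.shape[1]
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over the key dimension.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V
```

The fused Option A kernel would compute the same result without materializing the [num_heads, seq_len, seq_len] bias tensor, which is the main memory argument for fusion.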
Tasks
- alibi_init_slopes() - compute head-specific slopes
- alibi_compute_bias() - compute bias matrix [num_heads, seq_len, seq_len]
- sdpa_causal_alibi() kernel
- Python API in ops/nn.py
References
- Press, O., Smith, N. A., & Lewis, M. (2022). "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation." ICLR 2022.