## Summary

Implement YaRN (Yet another RoPE extensioN) and NTK-aware RoPE interpolation methods for extending context length beyond the training length.
## Background

When LLMs need to handle sequences longer than their training context, naive RoPE fails: positions beyond the training range produce rotation angles the model has never seen. Several interpolation methods have been proposed:
### 1. Linear Interpolation (Position Interpolation)

Simple scaling of positions: `pos' = pos / scale`

- Works, but quality degrades at high scale factors
### 2. NTK-aware Interpolation

Scales the base frequency instead of positions: `base' = base * scale^(dim / (dim - 2))`

- Better preserves high-frequency components
- Used in Code Llama and Mistral
### 3. YaRN (Yet another RoPE extensioN)

Combines NTK with attention scaling and dimension-wise interpolation, applying a different scaling per frequency band:

- High-frequency dims: no interpolation (preserves local attention)
- Low-frequency dims: full interpolation
- Mid-frequency dims: gradual transition between the two
- State-of-the-art for context extension
- Used in Qwen2, Yi, etc.
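To make the difference concrete, here is a small NumPy comparison of how linear and NTK-aware scaling transform the RoPE frequency spectrum (`rope_inv_freq` is a helper defined here for illustration, not an existing API):

```python
import numpy as np

def rope_inv_freq(head_dim, base=10000.0):
    # Standard RoPE inverse frequencies: base^(-2i/d) for i in [0, d/2)
    return base ** (-np.arange(0, head_dim, 2) / head_dim)

head_dim, base, scale = 128, 10000.0, 4.0

# Linear (position) interpolation: every frequency is divided by scale
linear = rope_inv_freq(head_dim, base) / scale

# NTK-aware: scale the base instead, base' = base * scale^(d/(d-2))
ntk = rope_inv_freq(head_dim, base * scale ** (head_dim / (head_dim - 2)))

# The highest-frequency component (i=0) is squashed by linear
# interpolation but left untouched by NTK-aware scaling; the
# lowest-frequency component ends up (nearly) identical under both.
print(linear[0], ntk[0])  # 0.25 1.0
```

This is why NTK-aware scaling "better preserves high-frequency components": the base adjustment is chosen so that only the low frequencies absorb the full interpolation.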
## Proposed Implementation

### Native Kernels

Extend the existing RoPE kernels with interpolation parameters:

```
native/ops/nn/rope/
├── rope_inplace.inl   # Existing
├── rope_yarn.inl      # NEW: YaRN interpolation
└── rope_ntk.inl       # NEW: NTK-aware interpolation
```
### Python API

```python
# NTK-aware RoPE frequency computation
cos, sin = rope_init_ntk_aware(
    max_seq_len=8192,
    head_dim=128,
    base=10000.0,
    scale=2.0,  # 2x context extension
)

# YaRN RoPE with dimension-wise scaling
cos, sin = rope_init_yarn(
    max_seq_len=32768,
    head_dim=128,
    base=10000.0,
    scale=4.0,
    original_max_len=4096,
    beta_fast=32,
    beta_slow=1,
    mscale=0.1,  # attention scale factor
)

# Apply (same as regular RoPE)
rope_inplace(q, k, cos, sin)
```
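A minimal NumPy sketch of how `rope_init_ntk_aware` could compute the tables. The `(max_seq_len, head_dim/2)` cos/sin layout is an assumption; the real kernels may duplicate the tables across the full `head_dim`:

```python
import numpy as np

def rope_init_ntk_aware(max_seq_len, head_dim, base=10000.0, scale=1.0):
    # NTK-aware base adjustment: base' = base * scale^(d/(d-2))
    base = base * scale ** (head_dim / (head_dim - 2))
    # Standard RoPE frequency table built from the adjusted base
    inv_freq = base ** (-np.arange(0, head_dim, 2) / head_dim)
    t = np.arange(max_seq_len)
    freqs = np.outer(t, inv_freq)  # (max_seq_len, head_dim/2)
    return np.cos(freqs), np.sin(freqs)

cos, sin = rope_init_ntk_aware(max_seq_len=8192, head_dim=128, scale=2.0)
print(cos.shape)  # (8192, 64)
```

Everything downstream of the base adjustment is identical to standard RoPE initialization, which is why `rope_inplace` can be reused unchanged.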
### Key Parameters

| Method | Parameters |
|--------|------------|
| Linear | `scale` |
| NTK    | `scale`, `base` |
| YaRN   | `scale`, `base`, `beta_fast`, `beta_slow`, `mscale`, `original_max_len` |
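`beta_fast` and `beta_slow` control where the per-dimension transition happens: they are rotation counts over the original context that mark the boundaries of the ramp. A NumPy sketch following the YaRN paper's formulation (`yarn_inv_freq` and its internal names are illustrative, not an existing API):

```python
import numpy as np

def yarn_inv_freq(head_dim, base, scale, original_max_len,
                  beta_fast=32, beta_slow=1):
    # Illustrative sketch of YaRN's dimension-wise interpolation
    inv_freq = base ** (-np.arange(0, head_dim, 2) / head_dim)

    def dim_for_rotations(num_rot):
        # Dimension index whose frequency completes `num_rot` full
        # rotations over the original context window
        return (head_dim * np.log(original_max_len / (num_rot * 2 * np.pi))
                / (2 * np.log(base)))

    low = max(int(np.floor(dim_for_rotations(beta_fast))), 0)
    high = min(int(np.ceil(dim_for_rotations(beta_slow))), head_dim // 2 - 1)

    # Ramp: 0 below `low` (keep original frequency), 1 above `high`
    # (fully interpolate), linear transition in between
    ramp = np.clip((np.arange(head_dim // 2) - low) / max(high - low, 1), 0, 1)
    return inv_freq * (1 - ramp) + (inv_freq / scale) * ramp

inv_freq = yarn_inv_freq(head_dim=128, base=10000.0, scale=4.0,
                         original_max_len=4096)
# YaRN additionally scales the attention logits; the paper suggests
# mscale(s) = 0.1 * ln(s) + 1 as the temperature adjustment.
```

Dimensions rotating faster than `beta_fast` times over the original window keep their frequency, those rotating fewer than `beta_slow` times are fully divided by `scale`, and the ramp blends everything in between.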
## Tasks

- [ ] `rope_init_ntk_aware()` - NTK frequency scaling
- [ ] `rope_init_yarn()` - YaRN dimension-wise interpolation
- [ ] Attention scaling (`mscale`)
- [ ] Python API in `ops/nn.py`
## References