feat(nn): Add YaRN / NTK-aware RoPE for context length extension #169

@m96-chan

Description

Summary

Implement YaRN (Yet another RoPE extensioN) and NTK-aware RoPE interpolation for extending context length beyond the training context window.

Background

When LLMs must handle sequences longer than their training context, naive RoPE extrapolation degrades sharply. Several interpolation methods have been proposed:

1. Linear Interpolation (Position Interpolation)

Simple scaling of positions: pos' = pos / scale

  • Works but degrades quality at high scales
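As a concrete illustration, Position Interpolation only rescales the position index before computing the rotary angles. A minimal numpy sketch (the function name and signature here are illustrative, not the proposed kernel API):

```python
import numpy as np

def rope_init_linear(max_seq_len, head_dim, base=10000.0, scale=1.0):
    """Position Interpolation sketch: pos' = pos / scale."""
    # Standard RoPE inverse frequencies, one per pair of dims.
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    # Compress positions into the original trained range.
    pos = np.arange(max_seq_len) / scale
    angles = np.outer(pos, inv_freq)
    return np.cos(angles), np.sin(angles)
```

With `scale=2.0`, position 2 lands on the same angles as position 1 did during training, which is why quality degrades at high scales: all frequency bands, including the high-frequency ones that encode local order, are compressed equally.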

2. NTK-aware Interpolation

Scales the base frequency instead of positions:

base' = base * scale^(dim / (dim - 2))
  • Better preserves high-frequency components
  • Used in Code Llama, Mistral
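The base-rescaling formula above translates directly into code. A hedged sketch of what `rope_init_ntk_aware()` could look like (the signature mirrors the proposed API below; the real kernel may differ):

```python
import numpy as np

def rope_init_ntk_aware(max_seq_len, head_dim, base=10000.0, scale=1.0):
    """NTK-aware sketch: base' = base * scale^(dim / (dim - 2))."""
    # Rescaling the base stretches low frequencies (long wavelengths)
    # while leaving the highest-frequency dims almost untouched.
    base = base * scale ** (head_dim / (head_dim - 2))
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(np.arange(max_seq_len), inv_freq)
    return np.cos(angles), np.sin(angles)
```

Note that the first frequency (exponent 0) is unchanged by the base rescale, while the last (exponent close to 1) is slowed the most; this is the sense in which NTK-aware interpolation "better preserves high-frequency components".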

3. YaRN (Yet another RoPE extensioN)

Combines NTK with attention scaling and dimension-wise interpolation:

# Different scaling per frequency band
high_freq: no interpolation (preserves local attention)
low_freq: full interpolation
mid_freq: gradual transition
  • State-of-the-art for context extension
  • Used in Qwen2, Yi, etc.
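The band-wise scheme above can be sketched with a linear ramp between the extrapolated and interpolated frequencies, as in the YaRN paper's "NTK-by-parts" formulation. This is only an illustration of the math, not the proposed kernel; it assumes `mscale` is the 0.1 coefficient in YaRN's attention-temperature formula `0.1 * ln(scale) + 1`:

```python
import numpy as np

def rope_init_yarn(max_seq_len, head_dim, base=10000.0, scale=1.0,
                   original_max_len=4096, beta_fast=32, beta_slow=1,
                   mscale=0.1):
    """YaRN sketch: per-band ramp between extrapolation and interpolation."""
    half = head_dim // 2
    pos_freqs = base ** (np.arange(0, head_dim, 2) / head_dim)
    inv_freq_extrap = 1.0 / pos_freqs            # no interpolation
    inv_freq_interp = 1.0 / (scale * pos_freqs)  # full interpolation

    def correction_dim(n_rot):
        # Dim index whose band completes n_rot rotations over the
        # original context; used to place the ramp endpoints.
        return (head_dim * np.log(original_max_len / (n_rot * 2 * np.pi))
                / (2 * np.log(base)))

    low = max(int(np.floor(correction_dim(beta_fast))), 0)
    high = min(int(np.ceil(correction_dim(beta_slow))), half - 1)
    ramp = np.clip((np.arange(half) - low) / max(high - low, 1), 0, 1)

    # ramp == 0: high-frequency dims, keep extrapolation (local attention);
    # ramp == 1: low-frequency dims, full interpolation.
    inv_freq = inv_freq_extrap * (1 - ramp) + inv_freq_interp * ramp
    angles = np.outer(np.arange(max_seq_len), inv_freq)

    # Attention temperature ("mscale") folded into cos/sin.
    attn_scale = mscale * np.log(scale) + 1.0 if scale > 1 else 1.0
    return np.cos(angles) * attn_scale, np.sin(angles) * attn_scale
```

With `scale=1.0` the interpolated and extrapolated frequencies coincide and the function reduces to vanilla RoPE, which makes a convenient sanity check for tests.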

Proposed Implementation

Native Kernels

Extend existing RoPE with interpolation parameters:

native/ops/nn/rope/
├── rope_inplace.inl          # Existing
├── rope_yarn.inl             # NEW: YaRN interpolation
└── rope_ntk.inl              # NEW: NTK-aware interpolation

Python API

# NTK-aware RoPE frequency computation
cos, sin = rope_init_ntk_aware(
    max_seq_len=8192,
    head_dim=128,
    base=10000.0,
    scale=2.0,  # 2x context extension
)

# YaRN RoPE with dimension-wise scaling
cos, sin = rope_init_yarn(
    max_seq_len=32768,
    head_dim=128,
    base=10000.0,
    scale=4.0,
    original_max_len=4096,
    beta_fast=32,
    beta_slow=1,
    mscale=0.1,  # attention scale factor
)

# Apply (same as regular RoPE)
rope_inplace(q, k, cos, sin)

Key Parameters

| Method | Parameters |
|--------|------------|
| Linear | scale |
| NTK    | scale, base |
| YaRN   | scale, base, beta_fast, beta_slow, mscale, original_max_len |

Tasks

  • Implement rope_init_ntk_aware() - NTK frequency scaling
  • Implement rope_init_yarn() - YaRN dimension-wise interpolation
  • Add attention scaling support for YaRN (mscale)
  • Add Python bindings
  • Add Python API in ops/nn.py
  • Add tests comparing against HuggingFace
  • Update LLM model configs to support extended context
