Grokking Experiments

Mechanistic interpretability of delayed generalization in neural networks

Python 3.11+ PyTorch Streamlit Matplotlib License: MIT

A replication and visualization suite for "Progress Measures for Grokking via Mechanistic Interpretability" (Nanda, Chan, Lieberum, Smith & Steinhardt, ICLR 2023). Trains a minimal transformer on modular arithmetic, observes the grokking phase transition, and dissects the learned trigonometric algorithm using Fourier analysis.

Grokking animation — from memorization to generalization
The grokking story in one animation: a transformer slowly replaces memorization with an elegant trigonometric algorithm.


Why Modular Arithmetic?

Modular addition is the Drosophila of grokking research. Just as fruit flies became the model organism for genetics — small, fast-reproducing, genetically tractable — the task a + b mod p has become the model problem for studying how neural networks transition from memorization to generalization. It is:

  • Simple enough to train in minutes on a single GPU
  • Structured enough that the learned algorithm has a known closed-form solution (trigonometric)
  • Rich enough to exhibit the grokking phenomenology: delayed generalization, phase transitions, and emergent circuits

This project replicates key results from Nanda et al. (2023) and extends them with 23 visualizations and an interactive Streamlit dashboard for exploring every stage of the process.


The Grokking Phenomenon

A neural network groks when it first memorizes its training data — achieving near-perfect training accuracy — then, thousands of epochs later, suddenly generalizes to unseen data (Power et al., 2022). The delay can be 10-100x the time needed to memorize.

What happens inside the network during that long plateau?

Nanda et al. (2023) showed that even while test accuracy is stuck at chance, measurable progress is occurring beneath the surface. The network is slowly replacing a memorized lookup table with an elegant trigonometric algorithm: it learns to embed inputs as points on a circle, compute cosine/sine at a handful of key frequencies, and sum the results to produce the correct output. Weight decay is the driving force — regularization pressure steadily simplifies the internal representation until the algorithmic solution becomes cheaper than brute-force memorization.
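The single-frequency core of that algorithm can be sketched numerically. The identity cos(w(a + b − c)) = cos(w(a + b))cos(wc) + sin(w(a + b))sin(wc) peaks exactly at c = (a + b) mod p; the real network sums several key frequencies, but one frequency (k = 14 is an illustrative choice) already suffices:

```python
import numpy as np

p = 113                      # modulus used throughout this repo
k = 14                       # one key frequency (illustrative choice)
w = 2 * np.pi * k / p

def score(a, b, c):
    """Logit for candidate output c via cos(w*(a + b - c)), which expands
    into the cos/sin products the network assembles from its embeddings."""
    return np.cos(w * (a + b - c))

a, b = 37, 91
logits = np.array([score(a, b, c) for c in range(p)])
print(int(logits.argmax()))  # (37 + 91) % 113 = 15
```

Because gcd(k, p) = 1 for prime p, the cosine equals 1 only when c ≡ a + b (mod p), so the argmax is unique.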

This project makes that hidden progress visible.


Key Results

Training Curves — The Classic Grokking S-Curve

Training curves

Train accuracy hits ~100% by epoch 300, but test accuracy stays at chance until ~3,000 epochs — then suddenly snaps to generalization.

Progress Measures — Hidden Progress During the Plateau

Progress measures

The four progress measures from Nanda et al. reveal that the network is quietly restructuring even while test accuracy shows no improvement.
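One sparsity measure this repo tracks (see the Expected Results table) is a Gini coefficient over per-frequency Fourier power. A minimal self-contained sketch — the helper name `gini` is ours, not necessarily the repo's API:

```python
import numpy as np

def gini(x):
    """Gini coefficient of a nonnegative vector: 0 = uniform, -> 1 = sparse.
    Applied to per-frequency Fourier power, rising Gini means the network
    is concentrating onto a few key frequencies."""
    x = np.sort(np.asarray(x, dtype=float))
    n = len(x)
    cum = np.cumsum(x)
    return (n + 1 - 2 * (cum / cum[-1]).sum()) / n

print(round(gini([1, 1, 1, 1]), 3))   # 0.0  (uniform power, no sparsity)
print(round(gini([0, 0, 0, 4]), 3))   # 0.75 (power concentrated, sparse)
```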

Frequency Spectrum — Key Frequencies Emerge

Frequency spectrum

The Fourier decomposition reveals which frequencies dominate — the network learns to rely on a sparse set of trigonometric components.

Fourier Heatmap — 2D Component Norms

Fourier heatmap

A 2D view of Fourier component norms across all frequency pairs, showing the sparse structure of the learned algorithm.

Embedding Circles — Circular Geometry

Embedding circles

Learned embeddings arrange inputs as evenly-spaced points on circles — one circle per key frequency — encoding the cyclic group structure of mod p.

Attention Patterns — What Heads Learn

Attention patterns

Attention heads learn interpretable patterns: uniformly attending to both operands so the MLP can compute their sum.

Weight Decay Sweep

Weight decay sweep animation — phase transition across 7 models
Sweeping weight decay from 0.01 to 5.0 reveals a sharp phase transition: too little regularization delays grokking indefinitely, too much prevents learning altogether, and a narrow optimal range (wd ≈ 1.0) produces rapid generalization with clean Fourier sparsity.

Train Fraction Sweep

Train fraction sweep animation
Sweeping train fraction from 5% to 70% reveals a data threshold for grokking: below 30% the model never generalizes within 40K epochs, while above 30% grokking accelerates dramatically — from 8,350 epochs at 30% to just 400 at 70%.

Learning Rate Sweep

Learning rate sweep animation
Sweeping learning rate from 1e-4 to 1e-2 reveals that LR controls grokking speed across two orders of magnitude: lr=3e-3 groks in just 1,500 epochs vs. 8,350 at the default 1e-3, while lr=1e-4 never groks and lr=1e-2 groks fast but catastrophically collapses at epoch 18,400.

Operation Sweep

Operation sweep animation
Sweeping across five modular operations — addition, subtraction, multiplication, a² + b², and a³ + ab — reveals whether grokking is universal across algebraic structures and whether different operations produce different internal Fourier representations.

Depth Sweep

Depth sweep animation
Sweeping network depth from 1 to 3 layers across all five operations reveals that deeper networks are not universally better: multiplication uniquely benefits from depth (grokking 3× faster at L=3), while addition and a² + b² become unstable or fail entirely at L=3, and a³ + ab never groks at any depth.

Effective Regularization Sweep

Effective regularization (wd × lr) unification animation
Testing whether grokking dynamics are controlled purely by the product eff_wd = wd × lr (the effective weight-decay per AdamW step). Each animated panel overlays three (wd, lr) decompositions that share the same eff_wd — if curves overlap, unification holds. The eff_wd=1e-3 group (baseline wd=1.0/lr=1e-3, plus wd=2.0/lr=5e-4 and wd=0.5/lr=2e-3) provides the cleanest three-way test of this mechanistic prediction.
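The intuition behind the product: AdamW's decoupled update multiplies each weight by (1 − lr·wd) before the gradient step, so per-step shrinkage depends only on lr × wd. A tiny sketch of the decompositions above:

```python
# AdamW applies decoupled decay: w <- w * (1 - lr * wd) each step,
# so per-step shrinkage depends only on the product lr * wd.
eff_wd = 1e-3
decompositions = [(1.0, 1e-3), (2.0, 5e-4), (0.5, 2e-3)]  # (wd, lr) from the sweep
for wd, lr in decompositions:
    factor = 1.0 - lr * wd
    print(f"wd={wd}, lr={lr}: per-step decay factor {factor}")  # 0.999 for all three
```

If grokking were governed solely by this factor, the three curves in each panel should coincide; any divergence isolates the lr-dependence of the gradient step itself.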

Prime p Sweep

Prime p sweep animation
Sweeping the group size p ∈ {7, 11, 13, 17, 23, 31, 43, 59, 67, 89, 97, 113} reveals a sharp phase transition: models with p ≤ 43 (~550 training pairs) never grok despite perfect training accuracy, while p ≥ 59 (~1,000+ pairs) always grok — and counterintuitively, larger p groks faster (grokking epoch ∝ p^(−1.74)). The number of key Fourier frequencies is universally 5 across all primes.

Fourier Deep Dive — evolution, spectra, and embedding Fourier structure

How the Fourier representation builds up over training, the frequency spectrum at convergence, and the Fourier structure of the learned embeddings.

Fourier evolution
Fourier component norms over training
Fourier spectrum strip
Frequency spectrum at convergence
Embedding Fourier
Fourier structure of learned embeddings

Neuron Analysis — activation grids, logit maps, and frequency spectra

How individual MLP neurons respond to inputs, what they contribute to the output logits, and which frequencies each neuron encodes.

Neuron activation grids
Per-neuron activation patterns over input pairs
Neuron logit map
Per-neuron contribution to output logits
Neuron frequency spectrum heatmap
Frequency content of each neuron

Logit Analysis — full vs. restricted logits, 3D surfaces, and per-sample loss

Comparing the full logit output against the restricted (key-frequency-only) reconstruction, the 3D surface of correct-class logits, and per-sample loss over training.

Logit heatmap comparison
Full vs. restricted logit heatmaps
Correct logit surface
3D surface of correct-class logit values
Per-sample loss heatmap
Per-sample loss evolution over training

Weight Matrices — heatmaps and evolution over training

The structure of learned weight matrices and how they evolve across checkpoints during training.

Weight heatmaps
Weight matrix structure after grokking
Weight evolution
Weight matrices across training checkpoints

Training Trajectories — PCA evolution, weight-space paths, and neuron clusters

Low-dimensional projections of how embeddings, weights, and neuron representations evolve through the memorization-to-generalization transition.

Embedding PCA evolution
Embedding space via PCA over training
Weight trajectory PCA
Parameter-space trajectory via PCA
Neuron clusters
Neuron clustering by frequency preference

Visualizations

23 visualization functions organized into 8 categories, covering key stages of the grokking story:

Category Plots What you see
Training Dynamics 3 Loss/accuracy curves, progress measures, phase boundaries
Fourier Analysis 5 Frequency spectra, 2D heatmaps, temporal evolution, embedding spectra
Attention Patterns 2 Average and per-input attention maps
Embedding Geometry 2 Polar circle plots, neuron frequency clusters
Weight Matrices 2 Heatmaps and checkpoint evolution
Neuron Analysis 3 Activation grids, logit maps, frequency spectrum heatmaps
Logit Analysis 3 Full vs. restricted logit comparison, 3D surfaces, per-sample loss
Trajectories 3 PCA evolution, parameter-space paths, synchronized animation

Visualizations include interpretation guidance: what to look for, what normal results look like, and what it means when something looks wrong. See the Visualization Guide for the full reference.


Model & Task

Task a + b mod 113 (all 113² = 12,769 pairs)
Split 30% train / 70% test
Input 3 tokens: [a, b, =]
Architecture 1-layer transformer, d_model = 128, 4 heads, d_mlp = 512
Activation ReLU (no LayerNorm)
Optimizer AdamW (lr = 1e-3, weight decay = 1.0)
Training Full-batch, 40K epochs

The choice of p = 113 (prime) ensures a clean cyclic group structure. Weight decay of 1.0 — far higher than typical — is crucial: it provides the regularization pressure that forces the transition from memorization to the trigonometric algorithm.
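A hedged sketch of the data setup described above — the repo's actual loader in `src/data/` may differ in details such as seeding and tokenization of the `=` symbol:

```python
import numpy as np

p = 113
rng = np.random.default_rng(0)   # seed is our assumption, not the repo's

# Enumerate all p^2 input pairs, label with (a + b) mod p, split 30/70.
pairs = np.array([(a, b) for a in range(p) for b in range(p)])
labels = (pairs[:, 0] + pairs[:, 1]) % p

perm = rng.permutation(len(pairs))
n_train = int(0.30 * len(pairs))
train_idx, test_idx = perm[:n_train], perm[n_train:]
print(len(pairs), len(train_idx), len(test_idx))   # 12769 3830 8939
```

With only ~3,830 training examples for a 12,769-entry addition table, memorization alone cannot explain test performance — which is what makes the eventual generalization diagnostic.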


Quick Start

# Install
conda activate kripke
pip install -e .

# Smoke test (100 epochs, ~30 seconds)
python scripts/train_single.py --config configs/default.yaml --max-epochs 100

# Full replication (40K epochs, ~20 min on GPU)
python scripts/train_single.py --config configs/nanda_replication.yaml

# Analysis and figures
python scripts/analyze_run.py --run-dir results/<run_id>
python scripts/generate_figures.py --run-dir results/<run_id>

# Interactive dashboard
streamlit run dashboard/app.py

See Getting Started for detailed setup instructions.


Expected Results

Milestone Typical epoch
Train accuracy reaches ~100% ~300
Gini coefficient begins rising ~500
Weight norm peaks then declines ~1,000
Test accuracy begins climbing ~3,000
Test accuracy exceeds 95% ~5,000

Key frequencies after grokking: approximately {14, 35, 41, 42, 52} for p = 113.
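To check the key frequencies on your own run, one possible sketch — it assumes a checkpoint exposing an embedding matrix `W_E` of shape (p, d_model); a random stand-in is used here so the snippet runs standalone:

```python
import numpy as np

p, d_model = 113, 128
W_E = np.random.default_rng(0).normal(size=(p, d_model))  # stand-in embeddings

# Power of each nontrivial frequency k in the Fourier basis over Z_p;
# after grokking, ~5 frequencies should dominate this spectrum.
a = np.arange(p)
power = {}
for k in range(1, (p - 1) // 2 + 1):
    c = np.cos(2 * np.pi * k * a / p) @ W_E
    s = np.sin(2 * np.pi * k * a / p) @ W_E
    power[k] = float((c ** 2 + s ** 2).sum())

top5 = sorted(power, key=power.get, reverse=True)[:5]
print(sorted(top5))   # for a grokked run, roughly {14, 35, 41, 42, 52}
```

The random stand-in produces an essentially flat spectrum; on a grokked checkpoint the top five frequencies should carry the large majority of the total power.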


Project Structure

configs/              Hyperparameter configs (default, nanda_replication, sweeps)
src/
  data/               Modular arithmetic data generation
  models/             GrokkingTransformer + activation hooks
  training/           Full-batch trainer + checkpointing
  analysis/           Fourier analysis, progress measures, neuron analysis
  viz/                23 visualization functions across 9 modules
scripts/              CLI entry points (train, analyze, generate figures)
dashboard/            Streamlit interactive explorer
tests/                Unit tests
docs/                 Documentation

Documentation

Document Description
Grokking Overview The science: grokking, the trig algorithm, training phases
Visualization Guide All 23 plots with interpretation guidance
Analysis Pipeline Fourier decomposition, progress measures, neuron analysis
Getting Started Installation, training, figure generation, dashboard
References Annotated bibliography

References

  • Power, A., Burda, Y., Edwards, H., Babuschkin, I., & Misra, V. (2022). Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets. ICLR 2022 Spotlight. arXiv:2201.02177
  • Nanda, N., Chan, L., Lieberum, T., Smith, J., & Steinhardt, J. (2023). Progress Measures for Grokking via Mechanistic Interpretability. ICLR 2023. arXiv:2301.05217
  • Zhong, Z., Liu, Z., Tegmark, M., & Andreas, J. (2023). The Clock and the Pizza: Two Stories in Mechanistic Explanation of Neural Networks. NeurIPS 2023. arXiv:2306.17844
