This codebase computes the full empirical Neural Tangent Kernel (NTK) matrix for classical image classification datasets, defined entrywise by
$$K(x_i, x_j) = \nabla_\theta f(x_i; \theta)^\top \, \nabla_\theta f(x_j; \theta),$$
where $f(\cdot\,; \theta)$ denotes the network output and $\theta$ its parameters.
Support is available for five image classification datasets (using only their training data), all of which predict one of 10 classes:
- CIFAR-10 ($n = 50,000$)
- MNIST ($n = 60,000$)
- KMNIST ($n = 60,000$)
- Fashion-MNIST ($n = 60,000$)
- SVHN ($n = 73,257$)
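As a concrete illustration of the quantity being computed, the sketch below (not the repo's implementation; the tiny model, input sizes, and batch of 5 examples are purely illustrative) forms an empirical NTK block from per-sample Jacobians using `torch.func`:

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev, vmap

# Toy stand-in for a real network; 8 input features, 3 outputs.
model = nn.Sequential(nn.Flatten(), nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 3))
params = dict(model.named_parameters())

def f(p, x):
    # Functional forward pass on a single example (add/remove the batch dim).
    return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)

x = torch.randn(5, 8)  # 5 examples
# Per-parameter Jacobians, each of shape (5, num_outputs, *param.shape).
jac = vmap(jacrev(f), in_dims=(None, 0))(params, x)
# Flatten parameter dims and contract: K[i, a, j, b] = <grad f_a(x_i), grad f_b(x_j)>.
flat = [j.reshape(5, 3, -1) for j in jac.values()]
K = sum(torch.einsum("iap,jbp->iajb", j, j) for j in flat)
print(K.shape)  # NTK block over 5 examples x 3 outputs
```

Scaling this contraction to all $n$ training examples at once is exactly what produces the enormous dense matrices described below.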
Note that the complete empirical NTK matrix is an enormous dense matrix and requires significant disk space to store. Storage requirements for each dataset (independent of model size) at each datatype are as follows:
| Dataset | float16 | float32 | float64 |
|---|---|---|---|
| CIFAR-10 | 500 GB | 1.0 TB | 2.0 TB |
| MNIST | 720 GB | 1.5 TB | 2.9 TB |
| KMNIST | 720 GB | 1.5 TB | 2.9 TB |
| FMNIST | 720 GB | 1.5 TB | 2.9 TB |
| SVHN | 1.1 TB | 2.2 TB | 4.3 TB |
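The figures in the table follow from the matrix dimensions: with 10 output classes, the dense NTK matrix is $(10n) \times (10n)$. A small sketch (the helper name is ours, not part of the codebase) reproduces the table:

```python
# Estimate disk usage of the full empirical NTK matrix.
# Assumes 10 output classes, so the dense matrix is (n*10) x (n*10).
DATASETS = {"cifar": 50_000, "mnist": 60_000, "kmnist": 60_000,
            "fmnist": 60_000, "svhn": 73_257}
BYTES_PER_ELEMENT = {"float16": 2, "float32": 4, "float64": 8}

def ntk_bytes(n, dtype, num_classes=10):
    """Bytes needed to store the dense (n*C) x (n*C) NTK matrix."""
    dim = n * num_classes
    return dim * dim * BYTES_PER_ELEMENT[dtype]

for name, n in DATASETS.items():
    row = ", ".join(f"{d}: {ntk_bytes(n, d) / 1e12:.2f} TB"
                    for d in BYTES_PER_ELEMENT)
    print(f"{name}: {row}")
```

For example, MNIST in float16 comes out to $(600{,}000)^2 \cdot 2$ bytes $= 720$ GB, matching the table.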
This code is designed to run even on desktop-class GPUs for smaller ResNets (e.g. ResNet9, ResNet18). Larger models will require significantly more VRAM.
To generate empirical NTK matrices for the float16 and float32 datatypes, simply run:
```
python main.py --models {MODELNAMES} --dataset {DATASET}
```
where MODELNAMES is a space-separated (not comma-separated) list of models from:
- resnet9 (4.8M parameters)
- resnet18 (11.1M parameters)
- resnet34 (21.2M parameters)
- resnet50 (23.5M parameters)
- resnet68 (41.4M parameters)
- resnet101 (42.5M parameters)
- resnet152 (58.2M parameters)
- mobilenet (3.2M parameters)
- mobilenetv2 (2.3M parameters)
- vgg11 (9.2M parameters)
- vgg13 (9.4M parameters)
- lenet (62K parameters)
- wrn-28-2 (1.5M parameters)
- wrn-28-5 (9.1M parameters)
- wrn-28-10 (36.5M parameters)
- logistic (30K parameters)
- densenet121 (7.0M parameters)
and DATASET is one of (cifar, mnist, kmnist, fmnist, svhn). Optional arguments for training include:
- `--lr`: learning rate (default 0.1)
- `--bs`: batch size (default 64)
- `--width`: ResNet width (default 64)
- `--num_epochs`: number of epochs (default 200)
- `--repeats`: number of independently trained models (default 1)
Due to the size of the empirical NTK matrix, it is often worthwhile to instead compute the matrix for a subsample of the dataset. This can be done with the optional argument:
- `--subsample`: subsample the dataset for NTK computation (default None)
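The savings from subsampling are quadratic in the subsample size, since the matrix is dense in both dimensions. A quick back-of-the-envelope check (the helper name is ours, assuming 10 output classes and float32 storage):

```python
def ntk_gb(n, bytes_per_element=4, num_classes=10):
    """Decimal gigabytes to store the dense (n*C) x (n*C) NTK matrix."""
    dim = n * num_classes
    return dim * dim * bytes_per_element / 1e9

full = ntk_gb(50_000)  # all of CIFAR-10 in float32: 1000 GB
sub = ntk_gb(5_000)    # a 5,000-example subsample: 10 GB
print(full, sub)
```

Subsampling CIFAR-10 by 10x thus cuts storage by 100x, from 1 TB to 10 GB.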