
CIFAR100-Image-Classification

A collection of modified CNN models implemented in PyTorch for the CIFAR-100 classification task, including ResNet, GoogLeNet, VGG, MobileNet, ResNeXt, DenseNet, ShuffleNet, and more. Models are adapted for $32 \times 32$ inputs while preserving their core architectural logic.
I'm using PyTorch 2.10.0+cu128 in Python 3.12.0.

Structure

├── data/
│   ├── cifar100/
│   └── processed/
├── models/
│   ├── vgg16.py
│   ├── googlenet.py
│   ├── resnet18.py
│   └── ...
├── config.py
├── utils.py
├── prepare_data.py
├── dataset.py
├── train.py
├── predict.py
└── eval.py

Read the files in this order:

config.py → utils.py → prepare_data.py → dataset.py → model → train.py → eval.py → predict.py

Requirements

matplotlib==3.10.8
numpy==2.4.4
Pillow==12.2.0
torch==2.10.0+cu128
torchvision==0.25.0+cu128
tqdm==4.67.3

Dataset

The dataset comes from Kaggle: CIFAR-100. It contains 100 classes in total.
The raw training set has 50,000 images (500 per class) and the raw test set has 10,000 images (100 per class); each image is a $32 \times 32$ RGB image.




Data Preparation & Augmentation

Splitting:

To split the data, run the command -

python prepare_data.py

This will split the raw training data into 90% training (45,000 images) and 10% validation (5,000 images).
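For reference, here is a minimal sketch of such a 90/10 split. The actual prepare_data.py may implement it differently (and writes its output under data/processed/); the seed is an assumption.

```python
# Sketch of a 90/10 train/validation split; prepare_data.py's actual logic
# may differ. The seed value is an assumption, used here for reproducibility.
import torch
from torch.utils.data import random_split
from torchvision.datasets import CIFAR100

raw_train = CIFAR100(root="data/cifar100", train=True, download=True)
n_val = len(raw_train) // 10                   # 5,000 validation images
n_train = len(raw_train) - n_val               # 45,000 training images
generator = torch.Generator().manual_seed(42)  # assumed fixed seed
train_set, val_set = random_split(raw_train, [n_train, n_val], generator=generator)
print(len(train_set), len(val_set))            # 45000 5000
```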

Augmentation:

I placed data augmentation in dataset.py: random cropping, random horizontal flipping, and AutoAugment (a reinforcement-learning-based strategy that automatically searches for and applies an effective combination of augmentation policies). These augmentations enrich the diversity of the training data and improve the model's robustness and generalization. Normalization is also applied to stabilize training.
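A minimal sketch of such a pipeline, assuming torchvision's built-in AutoAugment (which ships a CIFAR-10 policy; there is no dedicated CIFAR-100 policy) and commonly cited CIFAR-100 normalization statistics, neither of which is taken from the repo:

```python
# Sketch of the augmentations described above; dataset.py's exact pipeline
# may differ. The mean/std are commonly used CIFAR-100 statistics (assumed).
import torchvision.transforms as T
from torchvision.transforms import AutoAugment, AutoAugmentPolicy

CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)

train_transform = T.Compose([
    T.RandomCrop(32, padding=4),             # random cropping
    T.RandomHorizontalFlip(),                # random horizontal flipping
    AutoAugment(AutoAugmentPolicy.CIFAR10),  # searched augmentation policy
    T.ToTensor(),
    T.Normalize(CIFAR100_MEAN, CIFAR100_STD),  # normalization
])

test_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])
```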

Train

To start training, run the command -

python train.py

I used a Cosine Annealing Schedule to adjust the learning rate during training -

scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)
  • T_max: the number of scheduler steps over which the learning rate decays to eta_min (here one step per epoch, so T_max = NUM_EPOCHS).
  • eta_min: the minimum (final) learning rate.

The learning rate $\eta_t$ is adjusted according to the following formula:

$$ \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) $$

This strategy ensures a smooth transition from the initial learning rate to the minimum.
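For concreteness, here is how the scheduler is typically stepped once per epoch. This is a sketch, not the repo's exact train.py; the model and SGD settings are illustrative.

```python
# Sketch of per-epoch cosine annealing; train.py's actual loop may differ.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

NUM_EPOCHS = 200
model = torch.nn.Linear(3 * 32 * 32, 100)  # placeholder model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)  # illustrative settings
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

for epoch in range(NUM_EPOCHS):
    # ... forward pass, loss.backward(), optimizer.step() over the train loader ...
    scheduler.step()  # anneal the learning rate once per epoch
    print(f"epoch {epoch}: lr = {scheduler.get_last_lr()[0]:.6f}")
```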

You can adjust the hyperparameters in config.py according to your own hardware (training on a GPU is recommended). I used an NVIDIA GeForce RTX 2080 Ti (11 GB VRAM).

Prediction

To test your trained model, run the command -

python predict.py

It randomly selects an image from the test set and displays it together with its ground-truth label and the model's prediction (green for correct, red for wrong).
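A rough sketch of this flow is below. The checkpoint path and whole-model loading style are assumptions; the real predict.py reads its settings from config.py.

```python
# Sketch of predict.py's behavior under assumed names; the checkpoint path
# ("checkpoint.pth", holding a whole saved model) is an assumption, and the
# real script would also normalize the input as in dataset.py.
import random
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from torchvision.datasets import CIFAR100

test_set = CIFAR100(root="data/cifar100", train=False, download=True)
img, label = test_set[random.randrange(len(test_set))]

model = torch.load("checkpoint.pth", map_location="cpu", weights_only=False)
model.eval()
with torch.no_grad():
    pred = model(T.ToTensor()(img).unsqueeze(0)).argmax(dim=1).item()

color = "green" if pred == label else "red"
plt.imshow(img)
plt.title(f"label: {test_set.classes[label]} | pred: {test_set.classes[pred]}", color=color)
plt.axis("off")
plt.show()
```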



Evaluation

To evaluate your trained model on the test set, run the command -

python eval.py

It will show the model's prediction accuracy on the test set.
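A sketch of how top-1/top-5 accuracy can be computed is below; eval.py's actual implementation may differ, and the model and data loader are assumed to be built elsewhere.

```python
# Sketch of top-1/top-5 accuracy computation; not the repo's exact eval.py.
import torch

@torch.no_grad()
def evaluate(model, loader, device="cuda"):
    model.eval()
    top1 = top5 = total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        _, pred5 = model(images).topk(5, dim=1)  # 5 highest-scoring classes, descending
        correct = pred5.eq(labels.unsqueeze(1))  # [batch, 5] match matrix
        top1 += correct[:, 0].sum().item()       # column 0 is the top prediction
        top5 += correct.any(dim=1).sum().item()  # any of the top 5 matches
        total += labels.size(0)
    return top1 / total, top5 / total
```

The table below summarizes the results for each model.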

| Model | Params | FLOPs | Epochs | Top-1 Accuracy | Top-5 Accuracy |
| --- | --- | --- | --- | --- | --- |
| SqueezeNet | 776.74K | 25.08M | 200 | 68.03% | 89.94% |
| VGG-13 | 18.95M | 295.47M | 200 | 74.55% | 92.77% |
| VGG-16 | 24.26M | 408.85M | 200 | 74.73% | 92.16% |
| VGG-19 | 29.58M | 522.23M | 200 | 74.49% | 92.41% |
| GoogLeNet | 6.08M | 483.24M | 200 | 78.60% | 95.04% |
| MobileNetV2 | 2.35M | 94.67M | 200 | 71.41% | 93.18% |
| MobileNetV3 | 4.33M | 71.16M | 200 | 73.51% | 93.59% |
| ShuffleNetV1 | 999.87K | 45.36M | 200 | 73.23% | 93.27% |
| ShuffleNetV2 | 1.36M | 47.36M | 200 | 73.54% | 93.42% |
| ResNet-18 | 11.22M | 557.94M | 200 | 78.01% | 94.27% |
| ResNet-34 | 21.33M | 1.16G | 200 | 78.04% | 94.91% |
| ResNet-50 | 23.71M | 1.31G | 200 | 79.35% | 95.44% |
| ResNet-101 | 42.70M | 2.53G | 200 | 79.18% | 94.87% |
| ResNeXt-50 | 23.18M | 1.36G | 200 | 80.36% | 95.56% |
| ResNeXt-101 | 42.33M | 2.59G | 200 | 81.91% | 96.18% |
| DenseNet-121 | 7.05M | 908.19M | 200 | 80.99% | 96.01% |
| DenseNet-169 | 12.64M | 1.08G | 200 | 80.87% | 95.72% |
| DenseNet-201 | 18.28M | 1.40G | 200 | 80.21% | 95.26% |
| WideResNet-40-4 | 8.97M | 1.30G | 200 | 80.05% | 95.59% |
| WideResNet-28-10 | 36.54M | 5.25G | 200 | 81.12% | 96.05% |
| VisionTransformer | 4.78M | 308.39M | 200 | 61.92% | 86.87% |
