SwinIR for SISR task using PyTorch

This project implements a SwinIR (Shifted Window Transformer for Image Restoration) model for the SISR (Single Image Super-Resolution) task. The primary goal is to upscale low-resolution (LR) images by a given factor (2x, 4x, or 8x) to produce super-resolution (SR) images with high fidelity and perceptual quality. This project focuses on the classical image super-resolution task.

This implementation is based on the paper SwinIR: Image Restoration Using Swin Transformer.

Demonstration

The following images compare standard bicubic interpolation with the output of the SwinIR model.

Baboon comparison image Butterfly comparison image Bird comparison image Man comparison image PPT3 comparison image

Key Features

  • The architecture utilizes the Swin Transformer as a deep feature extraction backbone to leverage shifted window-based self-attention.
  • Residual Swin Transformer Blocks (RSTB) are employed to facilitate deep feature extraction with residual connections for improved gradient flow.
  • The model effectively captures long-range dependencies and local context through the hierarchical nature of the Swin Transformer.
  • Shifted window mechanisms allow for cross-window connections, significantly enhancing the modeling power compared to standard partition-based methods.
  • A shallow feature extraction module and a high-quality reconstruction module complement the Transformer core to restore fine image details.
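
To make the shifted-window mechanism above concrete, here is a minimal sketch of window partitioning and the cyclic shift (function names are illustrative; the full implementation in models.py also includes the window attention itself, relative position bias, and the attention mask for shifted windows):

    import torch

    def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
        # (B, H, W, C) -> (num_windows * B, window_size * window_size, C)
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)

    def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int) -> torch.Tensor:
        # Inverse of window_partition, back to (B, H, W, C).
        B = windows.shape[0] // (H * W // window_size // window_size)
        x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
        return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)

    def shifted_window_pass(x: torch.Tensor, window_size: int, shift_size: int) -> torch.Tensor:
        # One (shifted-)window round trip; the attention computation is omitted here.
        _, H, W, _ = x.shape
        if shift_size > 0:
            # The cyclic shift lets pixels near window borders attend across windows.
            x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
        windows = window_partition(x, window_size)
        # ... window-based multi-head self-attention would run on `windows` here ...
        x = window_reverse(windows, window_size, H, W)
        if shift_size > 0:
            x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
        return x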

Datasets

Training

The model is trained on the DF2K (DIV2K + Flickr2K) dataset. The prepare_data.py script is used to crop HR images to ensure divisibility by the scaling factor and to generate corresponding LR images using MATLAB-like bicubic downsampling. During training, the SRDataset class in dataset.py further processes these images by extracting random patches and applying augmentations like horizontal/vertical flips and rotations.
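
A simplified sketch of the patch extraction and augmentation step (the SRDataset class in dataset.py may differ in details such as patch sizes and tensor conversion):

    import random
    import torch

    def random_paired_patch(hr: torch.Tensor, lr: torch.Tensor, lr_patch: int, scale: int):
        # Crop aligned random patches from an LR/HR pair, both shaped (C, H, W).
        _, lr_h, lr_w = lr.shape
        x = random.randint(0, lr_w - lr_patch)
        y = random.randint(0, lr_h - lr_patch)
        lr_crop = lr[:, y:y + lr_patch, x:x + lr_patch]
        hr_crop = hr[:, y * scale:(y + lr_patch) * scale, x * scale:(x + lr_patch) * scale]
        return hr_crop, lr_crop

    def augment(hr: torch.Tensor, lr: torch.Tensor):
        # Apply the same random flip/rotation to both images of the pair.
        if random.random() < 0.5:                                   # horizontal flip
            hr, lr = hr.flip(-1), lr.flip(-1)
        if random.random() < 0.5:                                   # vertical flip
            hr, lr = hr.flip(-2), lr.flip(-2)
        if random.random() < 0.5:                                   # 90-degree rotation
            hr, lr = torch.rot90(hr, 1, (-2, -1)), torch.rot90(lr, 1, (-2, -1))
        return hr, lr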

Validation

The DIV2K_valid dataset is used for validation.

Testing

The test.py script is configured to evaluate the trained model on standard benchmark datasets: Set5, Set14, BSDS100, Urban100, and Manga109.

Project Structure

.
├── checkpoints/             # Model weights (.safetensors) and training states (.pth)
├── images/                  # Inference inputs and outputs
├── config.py                # Hyperparameters and file paths
├── dataset.py               # SRDataset class and image augmentations
├── prepare_data.py          # Script for generating HR and LR pairs from raw datasets 
├── inference.py             # Inference pipeline
├── models.py                # SwinIR model architecture definition
├── test.py                  # Testing pipeline
├── trainer.py               # Trainer class for model training
├── train.py                 # Training pipeline
└── utils.py                 # Utility functions

Configuration

All hyperparameters, paths, and training settings can be configured in the config.py file.

Explanation of some settings:

  • LOAD_CHECKPOINT: Set to True to resume training from the specified checkpoint (for train.py).
  • LOAD_BEST_CHECKPOINT: Set to True to resume training from the best checkpoint (for train.py).
  • TRAIN_DATASET_PATH: Path to the training directory containing HR and LR subfolders generated by prepare_data.py.
  • VAL_DATASET_PATH: Path to the validation directory containing HR and LR subfolders generated by prepare_data.py.
  • TEST_DATASETS_PATHS: List of paths to test datasets prepared by the prepare_data.py script.
  • DEV_MODE: Set to True to use a 10% subset of the training data for quick testing.
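
For orientation, a config.py using these settings could look roughly like this (the values below are illustrative, not the defaults used for the released checkpoints):

    # config.py -- illustrative values only
    LOAD_CHECKPOINT = False          # resume training from the specified checkpoint
    LOAD_BEST_CHECKPOINT = False     # resume training from the best checkpoint

    TRAIN_DATASET_PATH = "data/DF2K"
    VAL_DATASET_PATH = "data/DIV2K_valid"
    TEST_DATASETS_PATHS = [
        "data/Set5", "data/Set14", "data/BSDS100", "data/Urban100", "data/Manga109",
    ]
    BEST_CHECKPOINT_DIR_PATH = "checkpoints/best"

    DEV_MODE = False                 # True -> train on a 10% subset for quick testing
    USE_WANDB = False                # True -> log metrics and samples to Weights & Biases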

Setting Up and Running the Project

1. Installation

  1. Clone the repository:

     git clone https://github.com/ash1ra/SwinIR.git
     cd SwinIR

  2. Create a virtual environment (.venv) and install dependencies:

     uv sync

  3. Activate the virtual environment:

     # On Windows
     .venv\Scripts\activate

     # On Unix or macOS
     source .venv/bin/activate

2. Data Preparation

  1. Download the DIV2K datasets (Train Data (HR images) and Validation Data (HR images)).

  2. Download the standard benchmark datasets (Set5, Set14, BSDS100, Urban100) and the Manga109 dataset.

  3. Organize your raw data containing original high-resolution images:

    data/
    ├── DF2K/
    │   ├── 1.jpg
    │   └── ...
    ├── DIV2K_valid/
    │   ├── 1.jpg
    │   └── ...
    ├── Set5/
    │   ├── baboon.png
    │   └── ...
    ├── Set14/
    │   └── ...
    ...
    

    or

    data/
    ├── DF2K.txt
    ├── DIV2K_valid.txt
    ├── Set5.txt
    ├── Set14.txt
    ...
    
  4. Run prepare_data.py to generate the training/validation pairs. The script creates HR and LR_x{scaling factor} directories within your dataset path, which are required for training (a simplified sketch of this step follows the list).

  5. Update the paths (TRAIN_DATASET_PATH, etc.) in config.py to point to these newly created directories.
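
For reference, the HR/LR pair generation in step 4 boils down to something like the following (a simplified sketch using Pillow's bicubic filter; prepare_data.py uses MATLAB-like bicubic downsampling and may handle edge cases differently):

    from pathlib import Path
    from PIL import Image

    def make_pair(src_path: Path, out_dir: Path, scale: int) -> None:
        # Crop the HR image so its size is divisible by `scale`, then downsample.
        img = Image.open(src_path).convert("RGB")
        w, h = img.size
        w, h = w - w % scale, h - h % scale
        hr = img.crop((0, 0, w, h))
        lr = hr.resize((w // scale, h // scale), Image.BICUBIC)

        (out_dir / "HR").mkdir(parents=True, exist_ok=True)
        (out_dir / f"LR_x{scale}").mkdir(parents=True, exist_ok=True)
        hr.save(out_dir / "HR" / src_path.name)
        lr.save(out_dir / f"LR_x{scale}" / src_path.name)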

3. Training

  1. Adjust parameters in config.py as needed.
  2. Weights & Biases (WandB) Integration: To track your experiments, set USE_WANDB = True in config.py. The trainer will log loss, learning rate, and visual samples automatically.
  3. Run the training script:
    python train.py
  4. Training progress will be logged to the console and to a file in the logs/ directory.
  5. Checkpoints will be saved in checkpoints/.
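
With USE_WANDB enabled, the logging amounts to standard wandb calls along these lines (the project name and values below are placeholders; the exact metric names and logging cadence in trainer.py may differ):

    import wandb

    # Illustration only: hypothetical project name and dummy values.
    run = wandb.init(project="swinir-sisr", mode="offline",
                     config={"batch_size": 32, "scale": 4})
    for iteration in range(1, 4):
        loss = 1.0 / iteration                                   # placeholder loss value
        wandb.log({"train/loss": loss, "train/lr": 2e-4}, step=iteration)
    run.finish()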

4. Testing

To evaluate the model's performance on the test datasets:

  1. Ensure the BEST_CHECKPOINT_DIR_PATH in config.py points to your trained model (e.g., checkpoints/best).
  2. Run the test script:
    python test.py
  3. The script will print the average PSNR and SSIM for each dataset.

5. Inference

The inference script supports command-line flags to specify parameters without editing the configuration file. To upscale a single image:

python inference.py -i images/input.png -o images/output.png --scale 4 --comparison

Available flags:

  • -i, --input: Path to the input image.
  • -o, --output: Path to save the result.
  • -s, --scale: Upscaling factor.
  • -ts, --tile_size: Tile size (e.g., 512) for processing large images in overlapping tiles to avoid VRAM issues (see the tiling sketch after this list).
  • -to, --tile_overlap: Overlap between tiles (default: 32).
  • -c, --comparison: Generate a side-by-side comparison with the original and bicubic upscaling.
  • -v, --vertical: Stack comparison images vertically.
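
As a rough illustration of how --tile_size and --tile_overlap are meant to work, the sketch below splits the input into overlapping tiles, upscales each tile, and averages the overlaps (inference.py may handle padding, borders, and window-size constraints differently):

    import torch

    @torch.no_grad()
    def tiled_upscale(model, lr: torch.Tensor, scale: int, tile: int = 512, overlap: int = 32) -> torch.Tensor:
        # lr: (1, C, H, W) tensor; the model is assumed to map (1, C, h, w) -> (1, C, h*scale, w*scale).
        _, c, h, w = lr.shape
        out = torch.zeros(1, c, h * scale, w * scale, device=lr.device)
        weight = torch.zeros_like(out)
        stride = tile - overlap
        for y in range(0, h, stride):
            y0 = min(y, max(h - tile, 0))
            for x in range(0, w, stride):
                x0 = min(x, max(w - tile, 0))
                sr_patch = model(lr[:, :, y0:y0 + tile, x0:x0 + tile])
                ys, xs = y0 * scale, x0 * scale
                out[:, :, ys:ys + sr_patch.shape[2], xs:xs + sr_patch.shape[3]] += sr_patch
                weight[:, :, ys:ys + sr_patch.shape[2], xs:xs + sr_patch.shape[3]] += 1
                if x0 + tile >= w:
                    break
            if y0 + tile >= h:
                break
        return out / weight  # average the overlapping regions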

Training Results

SwinIR model training metrics

The model was trained for 100,000 iterations with a batch size of 32 on an NVIDIA RTX 4060 Ti (8 GB), which took nearly 24 hours. The training dataset consisted of 3450 images from the DF2K dataset. The remaining hyperparameters are specified in the config.py file. The final model is the one with the highest PSNR on the validation set.

Benchmark Evaluation (4x Upscaling)

The final model (checkpoints/best) was evaluated on standard benchmark datasets. Metrics are calculated on the Y-channel after shaving 4px (the scaling factor) from the border.
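
For reference, this evaluation convention corresponds roughly to the following sketch (assuming BT.601 luma conversion on images in [0, 255]; the exact rounding and SSIM computation in utils.py may differ):

    import numpy as np

    def rgb_to_y(img: np.ndarray) -> np.ndarray:
        # BT.601 luma (Y) channel from an RGB image with values in [0, 255].
        return 16.0 + (65.481 * img[..., 0] + 128.553 * img[..., 1] + 24.966 * img[..., 2]) / 255.0

    def psnr_y(sr: np.ndarray, hr: np.ndarray, shave: int) -> float:
        # PSNR on the Y channel after shaving `shave` pixels (> 0) from each border.
        sr_y = rgb_to_y(sr.astype(np.float64))[shave:-shave, shave:-shave]
        hr_y = rgb_to_y(hr.astype(np.float64))[shave:-shave, shave:-shave]
        mse = np.mean((sr_y - hr_y) ** 2)
        return float(10.0 * np.log10(255.0 ** 2 / mse))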

PSNR (dB) / SSIM Comparison

Dataset     SwinIR (this project)    SwinIR (paper)
Set5        32.64 / 0.9019           32.92 / 0.9044
Set14       28.98 / 0.7908           29.09 / 0.7950
BSDS100     27.81 / 0.7449           27.92 / 0.7489
Urban100    27.01 / 0.8134           27.45 / 0.8254
Manga109    31.58 / 0.9207           32.03 / 0.9260

Note: Differences in results are primarily due to training constraints; I utilized a patch size of 48×48 instead of 64×64 and trained the model for 100,000 iterations compared to the original 500,000 iterations. Additionally, the learning rate in this implementation was decayed five times more frequently.

Visual Comparisons

The following images compare standard bicubic interpolation with the output of the SwinIR model. I selected a variety of images where the difference in results is clearly visible, including anime images and photographs.

Comparison image 1 Comparison image 2 Comparison image 3 Comparison image 4 Comparison image 5

Acknowledgements

This implementation is based on the paper SwinIR: Image Restoration Using Swin Transformer

@misc{liang2021swinirimagerestorationusing,
      title={SwinIR: Image Restoration Using Swin Transformer}, 
      author={Jingyun Liang and Jiezhang Cao and Guolei Sun and Kai Zhang and Luc Van Gool and Radu Timofte},
      year={2021},
      eprint={2108.10257},
      archivePrefix={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2108.10257}, 
}

and on the paper Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.

@misc{liu2021swintransformerhierarchicalvision,
      title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, 
      author={Ze Liu and Yutong Lin and Yue Cao and Han Hu and Yixuan Wei and Zheng Zhang and Stephen Lin and Baining Guo},
      year={2021},
      eprint={2103.14030},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2103.14030}, 
}

DIV2K dataset citation:

@InProceedings{Timofte_2018_CVPR_Workshops,
  author = {Timofte, Radu and Gu, Shuhang and Wu, Jiqing and Van Gool, Luc and Zhang, Lei and Yang, Ming-Hsuan and Haris, Muhammad and others},
  title = {NTIRE 2018 Challenge on Single Image Super-Resolution: Methods and Results},
  booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
  month = {June},
  year = {2018}
}

Manga109 dataset citation:

@article{mtap_matsui_2017,
    author={Yusuke Matsui and Kota Ito and Yuji Aramaki and Azuma Fujimoto and Toru Ogawa and Toshihiko Yamasaki and Kiyoharu Aizawa},
    title={Sketch-based Manga Retrieval using Manga109 Dataset},
    journal={Multimedia Tools and Applications},
    volume={76},
    number={20},
    pages={21811--21838},
    doi={10.1007/s11042-016-4020-z},
    year={2017}
}

@article{multimedia_aizawa_2020,
    author={Kiyoharu Aizawa and Azuma Fujimoto and Atsushi Otsubo and Toru Ogawa and Yusuke Matsui and Koki Tsubota and Hikaru Ikuta},
    title={Building a Manga Dataset ``Manga109'' with Annotations for Multimedia Applications},
    journal={IEEE MultiMedia},
    volume={27},
    number={2},
    pages={8--18},
    doi={10.1109/mmul.2020.2987895},
    year={2020}
}

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
