
# DigitSynthesiserCNN - UNet Diffusion Generative Model

## Description

A UNet-256 architecture trained with a diffusion objective on the MNIST dataset to generate unique handwritten digits (0-9). The model is built from scratch in PyTorch with custom residual blocks and sinusoidal timestep embeddings.

## Model Architecture

*Figure: the UNET-1024 architecture (Ronneberger et al., 2015).*

The model used in this project closely follows the design in Ronneberger's 2015 paper. One distinct difference is the size of the model: Ronneberger's implementation has 1024 channels at its deepest layer, whilst mine has only 256. This is largely due to the disparity between the two datasets. MNIST images are significantly smaller than the original paper's biomedical images, and the relationships between pixel regions are much simpler. Consequently, a reduced network capacity can still capture the underlying structure of the handwritten digits, with the added bonuses of faster training and inference and a lower risk of overfitting.
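
The repository's exact layer layout isn't reproduced here, but the two custom components named in the description, residual blocks and sinusoidal timestep embeddings, might look roughly like the following PyTorch sketch. The channel widths, normalisation choices, and `t_dim` size are illustrative assumptions, not the repo's actual values:

```python
import math
import torch
import torch.nn as nn

def sinusoidal_embedding(t, dim=128):
    """Transformer-style sinusoidal embedding of the diffusion timestep."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
    args = t[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (batch, dim)

class ResidualBlock(nn.Module):
    """Two 3x3 convs conditioned on the timestep embedding, plus a skip path."""
    def __init__(self, in_ch, out_ch, t_dim=128):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, out_ch)  # injects the time signal
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        self.act = nn.SiLU()

    def forward(self, x, t_emb):
        h = self.act(self.norm1(self.conv1(x)))
        h = h + self.t_proj(t_emb)[:, :, None, None]  # broadcast over H x W
        h = self.act(self.norm2(self.conv2(h)))
        return h + self.skip(x)

# Example: embed timestep 10 for a batch of 4 and run one block.
t_emb = sinusoidal_embedding(torch.full((4,), 10), dim=128)
block = ResidualBlock(64, 128)
out = block(torch.randn(4, 64, 32, 32), t_emb)  # -> (4, 128, 32, 32)
```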

## Results and Process

*Figures: example unique digits diffused from random noise (`UNET_64_128_256_CORRECT_SAMPLING_Ex1.png` and `UNET_64_128_256_CORRECT_SAMPLING_Ex2.png`).*

### Training

#### Mismatched Sampling Function

I ran into issues whilst training: my tests showed the model required particular input sizing, so I resized the MNIST images from 28x28 to 32x32. Unfortunately, I did not account for this change in the noise sampling function. So whilst the model was perhaps improving, the improvement was not visible because the sampling function was drawing noise at the wrong resolution.

Once I changed the sampling function to account for the resize to 32x32, there was a clear jump in the quality of the CNN's output, evident in the figure below.

*Figure: output after adjusting the sampling function to 32x32.*
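
For reference, here is a sketch of what the corrected sampling loop might look like: a standard DDPM ancestral sampler with an assumed linear beta schedule, where `model(x, t)` predicting the added noise is also an assumption (the repo's actual step count and schedule may differ). The key fix is drawing the starting noise at the training resolution of 32x32:

```python
import torch

IMG_SIZE = 32  # must match the resolution the model was trained on

@torch.no_grad()
def sample(model, n_images, n_steps=1000, device="cpu"):
    # Linear beta schedule, as in the original DDPM paper (assumed here).
    betas = torch.linspace(1e-4, 0.02, n_steps, device=device)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    # Start reverse diffusion from noise of the SAME shape the model was
    # trained on; drawing 28x28 noise here was the original mismatch.
    x = torch.randn(n_images, 1, IMG_SIZE, IMG_SIZE, device=device)
    for t in reversed(range(n_steps)):
        t_batch = torch.full((n_images,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)  # predicted noise
        coef = betas[t] / torch.sqrt(1.0 - alpha_bars[t])
        mean = (x - coef * eps) / torch.sqrt(alphas[t])
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + torch.sqrt(betas[t]) * noise
    return x
```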

#### Forward Noising Process on the MNIST Dataset

*Figure: forward diffusion applied to clean MNIST images.*

The figure above illustrates the DDPM (Denoising Diffusion Probabilistic Model) forward noising process, showing Gaussian noise applied recursively to clean MNIST images. This is a vital part of a diffusion model: the noisy images and the noise used to create them form the training pairs for the denoising network.
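
In closed form, the forward process can also jump straight to any timestep t rather than noising recursively, which is how training batches are typically built. A minimal sketch, again assuming a linear beta schedule since the README doesn't state the exact one:

```python
import torch

# Assumed linear beta schedule over 1000 steps.
betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

def forward_diffuse(x0, t, alpha_bars):
    """q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I), sampled in one shot."""
    eps = torch.randn_like(x0)              # the noise the denoiser must predict
    abar = alpha_bars[t].view(-1, 1, 1, 1)  # per-image cumulative alpha
    xt = torch.sqrt(abar) * x0 + torch.sqrt(1.0 - abar) * eps
    return xt, eps

# Example: noise a batch of four 32x32 images at random timesteps.
x0 = torch.randn(4, 1, 32, 32)
t = torch.randint(0, 1000, (4,))
xt, eps = forward_diffuse(x0, t, alpha_bars)
```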

## Resources
