
Vector Quantized VAE

Implementation of the VQ-VAE model from "Neural Discrete Representation Learning" (https://arxiv.org/abs/1711.00937).

A nice visual representation of a VQ-GAN (the same pipeline, but with a GAN in place of the VAE).

VQ-VAE utilizes discrete latent variables and a novel training approach inspired by vector quantization (VQ). Unlike traditional VAEs where prior and posterior distributions are typically Gaussian, VQ-VAE employs categorical distributions. Samples drawn from these categorical distributions serve as indices into an embedding table. These retrieved embeddings are then fed into the decoder network.
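To make this concrete, here is a minimal PyTorch-style sketch of the quantization bottleneck (an illustrative sketch only, not this repository's exact module; the codebook size, embedding dimension, and commitment weight are assumptions, with beta = 0.25 taken from the paper):

```python
import torch
import torch.nn as nn

class VectorQuantizer(nn.Module):
    """Illustrative VQ bottleneck: nearest-neighbour codebook lookup."""

    def __init__(self, num_embeddings=512, embedding_dim=64, beta=0.25):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.codebook.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)
        self.beta = beta  # commitment loss weight

    def forward(self, z_e):
        # z_e: encoder output of shape (B, D, H, W); flatten to (B*H*W, D)
        B, D, H, W = z_e.shape
        flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)
        # index of the nearest codebook vector for every spatial position
        indices = torch.cdist(flat, self.codebook.weight).argmin(dim=1)
        z_q = self.codebook(indices).view(B, H, W, D).permute(0, 3, 1, 2)
        # codebook loss pulls embeddings toward encoder outputs; commitment
        # loss keeps the encoder close to its assigned codebook vector
        vq_loss = ((z_q - z_e.detach()) ** 2).mean() \
            + self.beta * ((z_q.detach() - z_e) ** 2).mean()
        # straight-through estimator: gradients bypass the non-differentiable argmin
        z_q = z_e + (z_q - z_e).detach()
        return z_q, indices.view(B, H, W), vq_loss
```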

After successfully training the VQ-VAE, an autoregressive model (e.g., a Transformer) can be trained on the discrete latent variable distribution $p(z)$. This allows for the generation of higher-quality samples.
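Sketching that second stage (the encoder, quantizer, and Transformer interfaces below are hypothetical names, not this repo's API): the frozen VQ-VAE maps each image to a grid of codebook indices, which is flattened into a token sequence and modeled with next-token prediction:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def images_to_tokens(vqvae, images):
    # encode and quantize; 'indices' is a (B, H, W) grid of codebook ids
    z_e = vqvae.encoder(images)
    _, indices, _ = vqvae.quantizer(z_e)
    return indices.flatten(1)  # (B, H*W) token sequences

def prior_loss(transformer, tokens):
    # standard autoregressive cross-entropy over the latent tokens
    logits = transformer(tokens[:, :-1])  # (B, T-1, num_embeddings)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           tokens[:, 1:].reshape(-1))
```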

Usage Instructions

1. Setup

Install the repository to get all required dependencies:

pip install .

2. Training Models

The following commands assume you are training on the MNIST dataset with default arguments. You may need to modify the arguments for other datasets or custom configurations.

A. Train the VQ-VAE

To train the VQ-VAE model:

python train_vqvae.py --save-plots-dir YOUR_VQVAE_SAVE_DIRECTORY

Replace YOUR_VQVAE_SAVE_DIRECTORY with your desired path to save plots and model checkpoints.

B. Train an Unconditional Transformer Prior

To train a Transformer model as an unconditional prior over the learned latent variables $z$:

python train_prior.py --trained-vqvae-path PATH_TO_SAVED_VQVAE --save-plots-dir YOUR_PRIOR_SAVE_DIRECTORY

Replace PATH_TO_SAVED_VQVAE with the path to your trained VQ-VAE model and YOUR_PRIOR_SAVE_DIRECTORY with the directory for the prior's outputs.

C. Train a Class-Conditioned Transformer Prior

To train a Transformer model as a class-conditioned prior over the latent variables $z$:

python train_conditioned_prior.py --trained-vqvae-path PATH_TO_SAVED_VQVAE --save-plots-dir YOUR_CONDITIONED_PRIOR_SAVE_DIRECTORY

Replace PATH_TO_SAVED_VQVAE accordingly and set YOUR_CONDITIONED_PRIOR_SAVE_DIRECTORY to the directory for the conditioned prior's outputs.
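A common way to implement class conditioning (the token layout here is an assumption, not necessarily what this script does) is to prepend a class token to each latent sequence, extending the earlier sketch:

```python
import torch
import torch.nn.functional as F

def conditioned_prior_loss(transformer, tokens, labels, num_embeddings):
    # give each class its own token id beyond the codebook range
    class_tok = (labels + num_embeddings).unsqueeze(1)  # (B, 1), dtype long
    seq = torch.cat([class_tok, tokens], dim=1)         # class token leads
    logits = transformer(seq[:, :-1])
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           seq[:, 1:].reshape(-1))
```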

3. Generating Samples

To generate samples using a trained VQ-VAE and a (class-conditioned or unconditional) Transformer prior:

python generate_examples.py --trained-vqvae-path SAVED_VQVAE_PATH --trained-prior-path SAVED_PRIOR_PATH

Ensure SAVED_VQVAE_PATH and SAVED_PRIOR_PATH point to your trained models. Add the --conditionned flag to sample with the trained class-conditioned prior.
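Under the hood, generation amounts to sampling latent tokens from the prior and decoding them with the frozen VQ-VAE; a hypothetical sketch (all attribute names are assumptions):

```python
import torch

@torch.no_grad()
def generate(vqvae, prior, num_samples, grid_hw, start_token, temperature=1.0):
    H, W = grid_hw
    # seed each sequence with a start token, then sample H*W latent tokens
    tokens = torch.full((num_samples, 1), start_token, dtype=torch.long)
    for _ in range(H * W):
        logits = prior(tokens)[:, -1] / temperature        # next-token logits
        next_tok = torch.multinomial(logits.softmax(-1), num_samples=1)
        tokens = torch.cat([tokens, next_tok], dim=1)
    latents = tokens[:, 1:]                                # drop the seed token
    z_q = vqvae.quantizer.codebook(latents).view(num_samples, H, W, -1)
    return vqvae.decoder(z_q.permute(0, 3, 1, 2))          # decoded images
```

For the class-conditioned prior, the seed token would be the class token rather than a generic start token.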

4. Using Different Datasets (Fashion MNIST, CIFAR10)

To switch to Fashion MNIST or CIFAR10, modify the --dataset argument in the training scripts. You will also need to update the --trained-vqvae-path, --trained-prior-path (if applicable), and --save-plots-dir arguments to reflect the paths relevant to the new dataset.
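For example (DATASET_NAME is a placeholder; check the --dataset choices in the script's argument parser for the exact identifiers):

python train_vqvae.py --dataset DATASET_NAME --save-plots-dir YOUR_VQVAE_SAVE_DIRECTORY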

5. Automated Training Script

For convenience, the launch.sh script sequentially automates the training of a VQ-VAE, an unconditional Transformer prior, and a class-conditioned Transformer prior. Set the DATASET bash variable in the script before running it.

bash launch.sh

Note: This script is computationally intensive, and a GPU is highly recommended for efficient execution.

Results

Given the default arguments in launch.sh, let's see how the training went!

MNIST

VQ-VAE training

VQ-VAE training curves

VQ-VAE reconstruction samples

VQ-VAE codebook usage histogram

VQ-VAE sample generation from random latents

Training the (un)conditioned prior models

Transformer Prior loss curves (Unconditional)

Transformer Prior loss curves (Class-Conditioned)

Sampling with trained prior model (Unconditional)

Sampling with trained prior model (Class-Conditioned)

Fashion MNIST

VQ-VAE training

VQ-VAE training curves

VQ-VAE reconstruction samples

VQ-VAE codebook usage histogram

VQ-VAE sample generation from random latents

Training the (un)conditioned prior models

Transformer Prior loss curves (Unconditional)

Transformer Prior loss curves (Class-Conditioned)

Sampling with trained prior model (Unconditional)

Sampling with trained prior model (Class-Conditioned)

CIFAR10

VQ-VAE training

VQ-VAE training curves

VQ-VAE reconstruction samples

VQ-VAE codebook usage histogram

VQ-VAE sample generation from random latents

Training the (un)conditioned prior models

Transformer Prior loss curves (Unconditional)

Transformer Prior loss curves (Class-Conditioned)

Sampling with trained prior model (Unconditional)

Sampling with trained prior model (Class-Conditioned)
