Implementation of the VQ-VAE model from https://arxiv.org/abs/1711.00937 ("Neural Discrete Representation Learning").
A nice visual representation of the closely related VQ-GAN architecture (the same idea, but with a GAN in place of the VAE).
VQ-VAE utilizes discrete latent variables and a novel training approach inspired by vector quantization (VQ). Unlike traditional VAEs where prior and posterior distributions are typically Gaussian, VQ-VAE employs categorical distributions. Samples drawn from these categorical distributions serve as indices into an embedding table. These retrieved embeddings are then fed into the decoder network.
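The codebook lookup described above can be sketched in a few lines of NumPy. This is a simplified illustration, not the repository's actual code; the names `codebook` and `z_e` are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical codebook: K embeddings of dimension D.
K, D = 8, 4
codebook = rng.normal(size=(K, D))

# Continuous encoder outputs z_e for N latent positions.
N = 5
z_e = rng.normal(size=(N, D))

# Vector quantization: for each z_e, pick the nearest codebook entry.
dists = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # (N, K)
indices = dists.argmin(axis=1)  # discrete latent codes
z_q = codebook[indices]         # quantized vectors fed to the decoder

print(indices.shape, z_q.shape)  # (5,) (5, 4)
```

In the real model this lookup is non-differentiable, so training uses a straight-through gradient estimator together with codebook and commitment losses, as described in the paper.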
After successfully training the VQ-VAE, an autoregressive model (e.g., a Transformer) can be trained as a prior over the discrete latent codes; sampling from this prior and decoding the sampled codes yields new images.
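To illustrate how a trained prior is used at generation time, here is a toy sketch of ancestral sampling over discrete latent indices. `toy_prior_logits` is a hypothetical stand-in for the trained Transformer, not the repository's API:

```python
import numpy as np

rng = np.random.default_rng(1)
K, L = 8, 16  # codebook size, latent sequence length

def toy_prior_logits(prefix):
    """Stand-in for a trained autoregressive prior: returns logits over
    the K codes given the indices sampled so far (random here)."""
    return rng.normal(size=K)

indices = []
for t in range(L):
    logits = toy_prior_logits(indices)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()  # softmax over the K codes
    indices.append(int(rng.choice(K, p=probs)))

# `indices` is a full sequence of discrete latents; in the real pipeline
# they would be mapped through the codebook and passed to the decoder.
print(len(indices))  # 16
```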
Install the repository to get all required dependencies:

```bash
pip install .
```

The following commands assume you are training on the MNIST dataset with default arguments. You may need to modify the arguments for other datasets or custom configurations.
To train the VQ-VAE model:

```bash
python train_vqvae.py --save-plots-dir YOUR_VQVAE_SAVE_DIRECTORY
```

Replace YOUR_VQVAE_SAVE_DIRECTORY with your desired path for saving plots and model checkpoints.
To train a Transformer model as an unconditional prior over the learned latent variables:

```bash
python train_prior.py --trained-vqvae-path PATH_TO_SAVED_VQVAE --save-plots-dir YOUR_PRIOR_SAVE_DIRECTORY
```

Replace PATH_TO_SAVED_VQVAE with the path to your trained VQ-VAE model and YOUR_PRIOR_SAVE_DIRECTORY with the output directory for the prior.
To train a Transformer model as a class-conditioned prior over the latent variables:

```bash
python train_conditioned_prior.py --trained-vqvae-path PATH_TO_SAVED_VQVAE --save-plots-dir YOUR_CONDITIONED_PRIOR_SAVE_DIRECTORY
```

Replace PATH_TO_SAVED_VQVAE accordingly and YOUR_CONDITIONED_PRIOR_SAVE_DIRECTORY with the output directory for the conditioned prior.
To generate samples using a trained VQ-VAE and a (conditioned) Transformer prior:

```bash
python generate_examples.py --trained-vqvae-path SAVED_VQVAE_PATH --trained-prior-path SAVED_PRIOR_PATH
```

Ensure SAVED_VQVAE_PATH and SAVED_PRIOR_PATH point to your trained models. Add the --conditionned flag to use the trained class-conditioned prior model.
To switch to Fashion MNIST or CIFAR10, modify the --dataset argument in the training scripts. You will also need to update the --trained-vqvae-path, --trained-prior-path (if applicable), and --save-plots-dir arguments to reflect the paths relevant to the new dataset.
For convenience, the launch.sh script automates the sequential training of a VQ-VAE, an unconditional Transformer prior, and a class-conditioned Transformer prior. You will need to manually set the DATASET variable in the script.
```bash
bash launch.sh
```

Note: This script is computationally intensive; a GPU is highly recommended for efficient execution.
Given the default arguments in launch.sh, let's see how the training did!
For each dataset, the following figures are produced:

- VQ-VAE training curves
- VQ-VAE reconstruction samples
- Histogram of VQ-VAE codebook usage
- VQ-VAE sample generation from random latents
- Sampling with the trained prior model (unconditional)
- Sampling with the trained prior model (class-conditioned)