MPVAE: Disentangled Variational Autoencoder based Multi-Label Classification with Covariance-Aware Multivariate Probit Model
Junwen Bai, Shufeng Kong, Carla Gomes
IJCAI-PRICAI 2020
In this paper, we propose the Multivariate Probit based Variational AutoEncoder (MPVAE), which 1) aligns the label embedding subspace with the feature embedding subspace and 2) handles the correlations between labels via the classic Multivariate Probit model. MPVAE improves both embedding-space learning and label-correlation encoding. Furthermore, using a β-VAE brings disentanglement effects and can improve performance compared to a vanilla VAE.
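The label-correlation part of the multivariate probit idea can be illustrated with a small NumPy sketch. It assumes a latent vector z = mu + R eps + eta with eps ~ N(0, I_k) and eta ~ N(0, I_L), so that the label covariance is R R^T + I and y_j = 1[z_j > 0]. Conditioning on eps gives P(y_j = 1 | eps) = Phi(mu_j + (R eps)_j), which can be averaged over samples of eps. The parameterization and the names mp_label_probs and std_normal_cdf are assumptions for illustration; this is not the repository's TensorFlow code.

```python
import math

import numpy as np


def std_normal_cdf(x):
    """Elementwise standard normal CDF Phi(x), computed with math.erf."""
    erf = np.vectorize(math.erf)
    return 0.5 * (1.0 + erf(np.asarray(x, dtype=float) / math.sqrt(2.0)))


def mp_label_probs(mu, R, n_samples=2000, seed=0):
    """Monte Carlo estimate of the marginals P(y_j = 1) under a multivariate probit.

    Assumed (hypothetical) latent model:
        z = mu + R @ eps + eta,  eps ~ N(0, I_k),  eta ~ N(0, I_L)
    so Cov(z) = R R^T + I and y_j = 1[z_j > 0]. Conditioning on eps,
    P(y_j = 1 | eps) = Phi(mu_j + (R @ eps)_j), averaged over the samples.
    """
    rng = np.random.default_rng(seed)
    _, k = R.shape
    eps = rng.standard_normal((n_samples, k))   # (S, k) shared noise samples
    r = mu[None, :] + eps @ R.T                 # (S, L) conditional probit means
    return std_normal_cdf(r).mean(axis=0)       # (L,) averaged marginals


mu = np.array([0.5, -1.0, 0.0])                      # toy per-label means
R = np.array([[0.8, 0.0], [0.4, 0.3], [0.0, 1.0]])   # toy covariance factor
mc = mp_label_probs(mu, R, n_samples=20000)
# Closed-form marginal for comparison: z_j ~ N(mu_j, 1 + ||R_j||^2)
exact = std_normal_cdf(mu / np.sqrt(1.0 + (R ** 2).sum(axis=1)))
print(mc, exact)
```

The marginals checked here do not show the coupling between labels; that comes from the off-diagonal entries of R R^T and only appears in joint probabilities such as P(y_1 = 1, y_2 = 1).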
- Python 3.7+
- TensorFlow 1.15.0
- NumPy 1.17.3
- scikit-learn (sklearn) 0.22.1
Older versions might work as well.
A PyTorch implementation of MPVAE can be found here.
Clone this repository to your local machine with git clone.
All datasets can be downloaded from Google Drive or Baidu Drive.
The downloaded datasets are already in the format expected by the code. Each dataset is stored as 4 npy files: one contains the data entries, and the other three are lists of indices for the train, validation, and test splits. For example, the mirflickr dataset has 4 npy files: mirflickr_data.npy, mirflickr_train_idx.npy, mirflickr_val_idx.npy, mirflickr_test_idx.npy.
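This split layout can be read with a few lines of NumPy. The sketch below assumes only the file naming shown above and that each *_idx.npy file holds integer row indices into the matching *_data.npy array. The helper load_split is hypothetical, and since this README does not describe how each data entry divides into features and labels, the sketch does not separate them.

```python
import os

import numpy as np


def load_split(data_dir, name="mirflickr", allow_pickle=False):
    """Load <name>_data.npy and index it with the train/val/test index files.

    Hypothetical helper: assumes the <name>_<split>_idx.npy naming and that
    each index file holds row indices into the data array.
    """
    data = np.load(os.path.join(data_dir, name + "_data.npy"),
                   allow_pickle=allow_pickle)
    splits = {}
    for split in ("train", "val", "test"):
        idx = np.load(os.path.join(data_dir, "%s_%s_idx.npy" % (name, split)),
                      allow_pickle=allow_pickle)
        splits[split] = data[idx]  # fancy indexing selects this split's rows
    return splits
```

For example, load_split("path/to/mirflickr") (a placeholder path) would return a dict with "train", "val", and "test" arrays. Pass allow_pickle=True only if the files turn out to store Python objects, and only for files you trust.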
We use mirflickr as the running example here. Detailed descriptions of the FLAGS can be found in config.py.
To train the model, run the following script:
./run_train_mirflickr.sh
If the flag write_to_test_sh is set to True and the path to the test script is specified with the flag test_sh_path, the best validation checkpoint is automatically written into run_test_mirflickr.sh.
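This README does not show how the checkpoint path ends up in the test script. One plausible mechanism is a text substitution on a command-line argument, sketched below. Everything here is hypothetical: the helper name, the checkpoint_path flag, and the --flag=value syntax are assumptions, not the repository's actual code or flags.

```python
import re


def write_checkpoint_to_test_sh(test_sh_path, checkpoint_path, flag="checkpoint_path"):
    """Point a test script's --<flag>=... argument at checkpoint_path.

    Hypothetical helper: the flag name and the '--flag=value' form are
    assumptions for illustration. Returns the number of substitutions made.
    """
    with open(test_sh_path) as f:
        script = f.read()
    pattern = re.compile(r"(--%s=)\S+" % re.escape(flag))
    # A function replacement avoids backslash handling in the new path.
    new_script, n = pattern.subn(lambda m: m.group(1) + checkpoint_path, script)
    if n == 0:
        raise ValueError("no --%s=... argument found in %s" % (flag, test_sh_path))
    with open(test_sh_path, "w") as f:
        f.write(new_script)
    return n
```

Rewriting the script in place keeps the train and test steps decoupled: the test script stays a plain shell file that can also be edited by hand.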
To test the model, run the following script:
./run_test_mirflickr.sh
The default hyperparameters should give reasonably good results.
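This README does not list the metrics that run_test_mirflickr.sh reports. Multi-label classifiers are commonly compared with example-based, micro-averaged, and macro-averaged F1 after thresholding the predicted probabilities. The NumPy sketch below computes all three, assuming a 0.5 threshold and counting an F1 with an empty denominator as 0. The function multilabel_f1 is hypothetical, and the repository's evaluation code may use different thresholds or edge-case conventions. scikit-learn's f1_score with average="micro", "macro", or "samples" computes the corresponding library versions.

```python
import numpy as np


def _f1(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN), defined as 0 when the denominator is 0."""
    tp, fp, fn = (np.asarray(a, dtype=float) for a in (tp, fp, fn))
    denom = 2 * tp + fp + fn
    return np.where(denom > 0, 2 * tp / np.maximum(denom, 1e-12), 0.0)


def multilabel_f1(y_true, probs, threshold=0.5):
    """Micro, macro, and example-based F1 for (n_samples, n_labels) inputs."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(probs) >= threshold   # binarize predicted probabilities
    tp = y_true & y_pred
    fp = ~y_true & y_pred
    fn = y_true & ~y_pred
    return {
        # micro: pool counts over every (sample, label) pair
        "micro_f1": float(_f1(tp.sum(), fp.sum(), fn.sum())),
        # macro: F1 per label (column), then average over labels
        "macro_f1": float(_f1(tp.sum(0), fp.sum(0), fn.sum(0)).mean()),
        # example-based: F1 per sample (row), then average over samples
        "example_f1": float(_f1(tp.sum(1), fp.sum(1), fn.sum(1)).mean()),
    }


y_true = [[1, 0, 1], [0, 1, 0], [1, 1, 0]]
probs = [[0.9, 0.2, 0.4], [0.1, 0.8, 0.7], [0.6, 0.3, 0.1]]
print(multilabel_f1(y_true, probs))
```

Macro F1 weights rare labels as heavily as frequent ones, so it usually drops more than micro F1 when a model ignores infrequent labels.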
Tuned hyperparameters for most datasets can be found in the scripts under the scripts folder.
If you have any questions, feel free to open an issue.
If you find our paper interesting or use the datasets we collected, please cite our paper:
@inproceedings{bai2021disentangled,
title={Disentangled variational autoencoder based multi-label classification with covariance-aware multivariate probit model},
author={Bai, Junwen and Kong, Shufeng and Gomes, Carla},
booktitle={Proceedings of the Twenty-Ninth International Conference on International Joint Conferences on Artificial Intelligence},
pages={4313--4321},
year={2021}
}
