The training part of the code has been updated; the search part will be updated soon.
Evolutionary neural architecture search for generative adversarial networks (GANs) has demonstrated promising performance for generating high-quality images. However, two challenges persist: long search times and unstable search results. To alleviate these problems, this paper proposes a sampling- and clustering-based neural architecture search algorithm for GANs, named SCGAN, which significantly improves search efficiency and enhances generation quality. SCGAN introduces two strategies. First, a constraint sampling strategy limits the parameter capacity of candidate architectures by computing their size and discarding those that exceed a reasonable parameter threshold. Second, a clustering selection strategy is applied in each architecture iteration, integrating a decomposition selection mechanism and a hierarchical clustering mechanism to further improve search stability. Extensive experiments on the CIFAR-10 and STL-10 datasets demonstrate that SCGAN requires only 0.4 GPU days to find a promising GAN architecture in a vast search space of approximately 10$^{15}$ networks. Our best-found GAN outperformed those obtained by other neural architecture search methods, with performance metrics of (IS=9.68±0.06, FID=5.54) on CIFAR-10 and (IS=12.12±0.13, FID=12.54) on STL-10.
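For intuition, here is a minimal Python sketch of the two strategies described above. It is an illustrative assumption, not the actual SCGAN implementation: the function names, the parameter budget, and the choice of score features are all our own.

```python
# Illustrative sketch only -- names and thresholds are assumptions,
# not the actual SCGAN code.
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

MAX_PARAMS = 15_000_000  # assumed parameter budget for a generator


def count_params(model):
    """Total number of trainable parameters of a torch.nn.Module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def constrained_sample(sample_arch, build_generator, n):
    """Constraint sampling: keep drawing architectures and discard any
    whose parameter count exceeds the budget."""
    kept = []
    while len(kept) < n:
        arch = sample_arch()  # draw a random genotype (hypothetical)
        if count_params(build_generator(arch)) <= MAX_PARAMS:
            kept.append(arch)
    return kept


def cluster_select(scores, n_clusters):
    """Clustering selection: group candidates by their evaluation
    scores (e.g. rows of [IS, -FID]) with hierarchical clustering and
    keep the best candidate of each cluster."""
    labels = fcluster(linkage(scores, method="average"),
                      t=n_clusters, criterion="maxclust")
    selected = []
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]  # members of this cluster
        selected.append(idx[np.argmax(scores[idx].sum(axis=1))])
    return selected
```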
The environment of SCGAN is somewhat complex: training and searching are PyTorch-based, but part of the evaluation needs to call the TensorFlow 2 API. For easier reading, we provide both English and Chinese tutorials.
SCGAN's search environment requires PyTorch 2.0 or later and TensorFlow 2.12 or later.
- python=3.11
- pytorch=2.0.1
- tensorflow=2.12.0
- tensorflow-gan=2.1.0
Since installing and configuring PyTorch and TensorFlow side by side can be difficult, we have prepared a pre-configured environment package on Baidu Cloud.
Link: https://pan.baidu.com/s/1I_3zXugfGJAg6l5PEdsV_w Access Code: 83gs
After downloading, simply extract it into your Anaconda envs directory (e.g., '/home/user/anaconda3/envs/'):
cd /home/user/anaconda3/envs
tar -xvf torch.tar
Then, activate the relevant environment.
conda activate torch
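To confirm that the environment works, you can run a quick sanity check. This snippet is our own convenience check, not part of the repo; it only verifies that both frameworks are importable in the same interpreter:

```python
# Quick environment sanity check (not part of SCGAN itself).
import torch
import tensorflow as tf
import tensorflow_gan  # noqa: F401 -- required by the evaluation step

print("PyTorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("TensorFlow:", tf.__version__)
print("tensorflow-gan imported successfully")
```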
cd SCGAN/search/
In SCGAN, we use the CIFAR-10 and STL-10 datasets for evaluation. By default, they are stored in ./datasets/cifar10 and ./datasets/stl10. You can download them manually or let the data code download them automatically.
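If you prefer automatic download, a torchvision call such as the following works (a convenience sketch; the repo's own data code may differ, and the STL-10 split choice is our assumption):

```python
# Download CIFAR-10 and STL-10 into the default locations via
# torchvision (a convenience sketch; the repo's data code may differ).
from torchvision import datasets

datasets.CIFAR10(root="./datasets/cifar10", train=True, download=True)
datasets.STL10(root="./datasets/stl10", split="unlabeled", download=True)
```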
You also need to download the relevant data from the EAGAN repository: https://github.com/marsggbo/EAGAN
cd SCGAN/search/
mkdir fid_stat
mkdir tmp
cd SCGAN/train/
mkdir fid_stat
mkdir tmp
Run the architecture search:
cd SCGAN/search/
bash train_search_gen.sh
We used the training environment provided by MMD-AdversarialNAS and found that networks trained with the MMD loss performed well. In this step, you only need to replace the training architecture of MMD-AdversarialNAS with the one found by SCGAN.
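As a rough illustration of what this swap looks like: the training entry point selects an architecture definition by name via the --genotypes_exp flag used below. A hypothetical loader might resemble the following; the "exps" module path is an assumption about the repo layout, not the verified SCGAN structure.

```python
# Hypothetical illustration of how --genotypes_exp selects an
# architecture; the "exps" module path is an assumption, not the
# verified SCGAN layout.
import importlib


def load_genotype(exp_name: str):
    """Import the architecture definition named on the command line,
    e.g. load_genotype("arch_cifar10")."""
    return importlib.import_module(f"exps.{exp_name}")
```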
cd SCGAN/train/
python MGPU_train_arch.py --gpu_ids 0 --num_workers 8 --gen_bs 128 --dis_bs 128 --dataset cifar10 --bottom_width 4 --img_size 32 --max_epoch_G 500 --n_critic 1 --arch arch_cifar10 --draw_arch False --genotypes_exp arch_cifar10 --latent_dim 120 --gf_dim 256 --df_dim 512 --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --val_freq 5 --num_eval_imgs 50000 --exp_name arch_train_cifar10
python MGPU_train_arch.py --gpu_ids 0 --num_workers 8 --gen_bs 128 --dis_bs 128 --dataset stl10 --bottom_width 6 --img_size 48 --max_epoch_G 500 --n_critic 1 --arch arch_cifar10 --draw_arch False --genotypes_exp arch_cifar10 --latent_dim 120 --gf_dim 256 --df_dim 512 --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --val_freq 5 --num_eval_imgs 50000 --exp_name arch_train_stl10
Alternatively, run the provided STL-10 script:
bash train_arch_stl10.sh
If you have any questions, please email me. I will respond as soon as I have time. Thank you for your attention.
Some of the code is built upon:
1. EAGAN
2. AlphaGAN
3. Inception Score code from OpenAI's Improved GAN (official).
4. FID Score code and CIFAR-10 statistics file (official).
Thanks to them for their great work!