This example trains an Xception model on the AVDeepfake1M/AVDeepfake1M++ dataset for classification with video-level labels.
Ensure you have the necessary environment set up. You can create a Conda environment with the following commands:
```shell
# prepare the environment
conda create -n batfd python=3.10 -y
conda activate batfd
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=11.8 -c pytorch -c nvidia -y
pip install avdeepfake1m toml tensorboard pytorch-lightning pandas "av<14"
```

Train the BATFD or BATFD+ model using a TOML configuration file (e.g., batfd.toml or batfd_plus.toml).
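The exact schema of the configuration file is defined by the batfd.toml and batfd_plus.toml files shipped with the repository; the sketch below only illustrates the general shape of such a TOML config, and every field name in it is an assumption, not the repository's actual schema.

```toml
# Hypothetical sketch of a BATFD-style config.
# All keys here are illustrative assumptions — consult the shipped
# batfd.toml / batfd_plus.toml for the real field names.
name = "batfd"

[optimizer]
learning_rate = 1e-4

[training]
batch_size = 4
max_epochs = 100
```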
```shell
python train.py --config ./batfd.toml --data_root /path/to/AV-Deepfake1M-PlusPlus
```

If you encounter NaN losses when training BA-TFD+, this may be caused by a bug in PyTorch's self-attention ops; upgrading or switching the PyTorch version can resolve it.
- Checkpoints: Model checkpoints are saved under `./ckpt/xception/`. The last checkpoint is saved as `last.ckpt`.
- Logs: Training logs (including metrics like `train_loss`, `val_loss`, and learning rates) are saved by PyTorch Lightning, typically in a directory named `./lightning_logs/`. You can view these logs with TensorBoard (`tensorboard --logdir ./lightning_logs`).
After training, generate predictions on a dataset subset (e.g., `val`, `test`) using `infer.py`. This script saves the predictions to a JSON file, which is required for evaluation.
```shell
python infer.py --config ./batfd.toml --checkpoint /path/to/checkpoint --data_root /path/to/AV-Deepfake1M-PlusPlus --subset val
```

Then evaluate the predictions against the dataset metadata:

```shell
python evaluate.py /path/to/prediction_file /path/to/metadata_file
```
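Before passing the prediction file to evaluate.py, it can be worth sanity-checking that it is valid JSON with scores in a sensible range. The record keys below (`file`, `score`) are assumptions for illustration only; the actual schema is whatever `infer.py` writes for your configuration.

```python
import json

# Hypothetical prediction records; the real key names are defined by
# infer.py, not by this sketch.
sample = [
    {"file": "val/0001.mp4", "score": 0.93},
    {"file": "val/0002.mp4", "score": 0.07},
]

# Round-trip through JSON, as evaluate.py would read the file from disk.
predictions = json.loads(json.dumps(sample))

# Basic sanity checks: every record names a file and has a score in [0, 1].
for record in predictions:
    assert "file" in record
    assert 0.0 <= record["score"] <= 1.0

print(len(predictions))  # number of prediction records
```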