
BA-TFD

This example trains the BA-TFD or BA-TFD+ model on the AV-Deepfake1M/AV-Deepfake1M++ dataset for temporal forgery localization.

Requirements

Ensure you have the necessary environment set up. You can create a Conda environment with the following commands:

# prepare the environment
conda create -n batfd python=3.10 -y
conda activate batfd
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=11.8 -c pytorch -c nvidia -y
pip install avdeepfake1m toml tensorboard pytorch-lightning pandas "av<14"
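To confirm that the environment matches the commands above, the following sketch reports installed package versions using only the standard library's importlib.metadata. It is a hypothetical helper, not part of this repository. The distribution names come from the install commands, and the PyAV ("av") check mirrors the "av<14" pin. A package installed in a way that does not register Python package metadata would be reported as MISSING even if it can be imported.

```python
from importlib import metadata

# Distribution names taken from the conda/pip commands above.
# "av" is PyAV; the README pins it to av<14.
REQUIRED = ["torch", "torchvision", "torchaudio", "avdeepfake1m", "toml",
            "tensorboard", "pytorch-lightning", "pandas", "av"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if not found."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def major_version(version):
    """Return the leading integer of a version string, e.g. '13.1.0' -> 13."""
    return int(version.split(".")[0])


if __name__ == "__main__":
    found = installed_versions(REQUIRED)
    for name, ver in found.items():
        print(f"{name:20s} {ver or 'MISSING'}")
    av_version = found["av"]
    if av_version is not None and major_version(av_version) >= 14:
        print("warning: PyAV >= 14 is installed, but the README pins av<14")
```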

Training

Train BA-TFD or BA-TFD+ by passing the corresponding TOML configuration file (e.g., batfd.toml or batfd_plus.toml) to train.py, along with the dataset root:

python train.py --config ./batfd.toml --data_root /path/to/AV-Deepfake1M-PlusPlus
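Because training is driven by the TOML file, it can help to print its settings before launching a run. The following helper is a minimal sketch and is not part of this repository. It parses the file with tomllib on Python 3.11+ and falls back to the toml package (installed above) on the Python 3.10 environment. It then flattens nested tables into dotted keys. It makes no assumptions about which keys batfd.toml or batfd_plus.toml contain.

```python
import sys
from pathlib import Path


def load_config(path):
    """Parse a TOML file into a dict (tomllib on 3.11+, else the toml package)."""
    text = Path(path).read_text(encoding="utf-8")
    if sys.version_info >= (3, 11):
        import tomllib
        return tomllib.loads(text)
    import toml  # installed by the pip command in Requirements
    return toml.loads(text)


def flatten(cfg, prefix=""):
    """Yield (dotted_key, value) pairs, descending into nested tables."""
    for key, value in cfg.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            yield from flatten(value, name + ".")
        else:
            yield name, value


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python show_config.py ./batfd.toml
    for key, value in flatten(load_config(sys.argv[1])):
        print(f"{key} = {value!r}")
```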

If you encounter NaN losses when training BA-TFD+, the cause may be a bug in PyTorch's self-attention ops. Upgrading or switching the PyTorch version may resolve it.

Output

  • Checkpoints: Model checkpoints are saved under ./ckpt/xception/. The last checkpoint is saved as last.ckpt.
  • Logs: Training logs (including metrics like train_loss, val_loss, and learning rates) are saved by PyTorch Lightning, typically in a directory named ./lightning_logs/. You can view these logs using TensorBoard (tensorboard --logdir ./lightning_logs).
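PyTorch Lightning's default TensorBoard logger writes each run to a new lightning_logs/version_N subdirectory. The following sketch is a hypothetical helper, not part of this repository. It selects the run with the highest N and compares the numbers numerically, so version_10 is ranked after version_9.

```python
import re
from pathlib import Path


def latest_version_dir(log_root="lightning_logs"):
    """Return the version_N directory with the largest N under log_root, or None."""
    root = Path(log_root)
    if not root.is_dir():
        return None
    best, best_n = None, -1
    for d in root.glob("version_*"):
        match = re.fullmatch(r"version_(\d+)", d.name)
        # Skip files and names like "version_x" that lack a numeric suffix.
        if d.is_dir() and match and int(match.group(1)) > best_n:
            best, best_n = d, int(match.group(1))
    return best


if __name__ == "__main__":
    run = latest_version_dir()
    print(run if run is not None else "no runs found under ./lightning_logs")
```

The printed path can be passed to tensorboard --logdir to view only the most recent run.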

Inference

After training, generate predictions on a dataset subset (e.g., val, test) using infer.py. This script saves the predictions to a JSON file, which is required for evaluation.

python infer.py --config ./batfd.toml --checkpoint /path/to/checkpoint --data_root /path/to/AV-Deepfake1M-PlusPlus --subset val
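This README does not document the exact schema of the JSON that infer.py writes. The following sketch assumes that the file maps each video file name to a list of [confidence, start, end] segments, a common layout for temporal localization predictions. Check your own output before relying on this assumption. The sketch loads the file and keeps the k most confident segments per video for quick inspection. Both functions are hypothetical helpers.

```python
import json


def load_predictions(path):
    """Load the prediction JSON written by infer.py."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def top_segments(predictions, k=5):
    """Keep the k highest-confidence segments per video.

    ASSUMPTION: each value is a list of [confidence, start, end] triples.
    """
    return {
        name: sorted(segments, key=lambda seg: seg[0], reverse=True)[:k]
        for name, segments in predictions.items()
    }
```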

Evaluation

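Compute metrics by comparing the JSON prediction file produced by infer.py against the metadata file for the same subset:

python evaluate.py /path/to/prediction_file /path/to/metadata_file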
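Temporal localization metrics, such as average precision at an IoU threshold, decide whether a predicted segment matches a ground-truth fake segment based on their temporal intersection-over-union (IoU). The following sketch illustrates only that building block. It is not the implementation used by evaluate.py.

```python
def temporal_iou(a, b):
    """IoU of two 1-D segments given as (start, end), e.g. in seconds."""
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0


def best_iou(pred, gt_segments):
    """Highest IoU between one predicted segment and any ground-truth segment."""
    return max((temporal_iou(pred, gt) for gt in gt_segments), default=0.0)
```

For example, the segments (1.0, 3.0) and (2.0, 4.0) overlap for 1 s within a 3 s union, so temporal_iou returns 1/3. That prediction would therefore not count as a match at an IoU threshold of 0.5.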