# A Detector Ensembling Swin-Transformer and CLIP
This project implements an ensemble model for detecting AI-generated images, combining a fine-tuned Swin-Transformer and a CLIP-based feature classifier. The Swin-Transformer is fine-tuned for image classification, while CLIP extracts robust features that are classified using a custom neural network. The final prediction is an ensemble of both models' outputs.
To run this project, install the following Python packages:
```bash
pip install torch torchvision timm
pip install git+https://github.com/openai/CLIP.git
```

Additional dependencies (automatically installed with the above): `numpy`, `scikit-learn`, `pillow`, `tqdm`.
Ensure you have a CUDA-enabled GPU for optimal performance, though the code supports CPU execution as well.
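Device selection can follow the usual PyTorch pattern, so the same code runs on GPU or CPU:

```python
import torch

# Use the GPU when available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```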
- **Dataset Preparation**: Modify the `dataset_path` variable in the code to point to your dataset directory:

  ```python
  dataset_path = "./AIGC-Detection-Dataset"
  ```

  The dataset should have the following structure:

  ```
  AIGC-Detection-Dataset/
  ├── train/
  │   ├── 0_real/
  │   └── 1_fake/
  └── val/
      ├── 0_real/
      └── 1_fake/
  ```

- **Pretrained Models**: The code uses pretrained weights for Swin-Transformer (`swinv2_small_window16_256`) and CLIP (`ViT-L/14@336px`), which are downloaded automatically via `timm` and `clip`.
The Swin-Transformer is fine-tuned on the dataset with specific layers unfrozen for training. Key configurations include:
- Model: `swinv2_small_window16_256`
- Batch Size: 64
- Epochs: 100
- Optimizer: AdamW (lr=1e-4, weight_decay=1e-4)
- Scheduler: CosineAnnealingWarmRestarts
- Loss: CrossEntropyLoss with label smoothing (0.05)
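The configuration above maps onto a standard PyTorch training setup. In the sketch below, a placeholder module stands in for the Swin backbone, and the scheduler's restart period `T_0` is an assumed value not given in this README:

```python
import torch
import torch.nn as nn

# Placeholder standing in for the timm Swin-Transformer backbone.
model = nn.Linear(256, 2)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# T_0 (epochs until the first warm restart) is an assumption.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
```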
- Run the fine-tuning script (provided in the code).
- Models are saved per epoch in `Fine_Tuned_Swin_Models/`.
- The final model is selected based on the largest validation loss among epochs 50-100 with validation accuracy > 0.99. This is saved as `swin_model.pth`.
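The selection rule can be sketched in a few lines of plain Python; the `(epoch, val_accuracy, val_loss)` record format here is hypothetical:

```python
def select_checkpoint(history):
    """Among epochs 50-100 with validation accuracy > 0.99,
    return the epoch with the LARGEST validation loss."""
    candidates = [(epoch, acc, loss) for epoch, acc, loss in history
                  if 50 <= epoch <= 100 and acc > 0.99]
    if not candidates:
        return None
    return max(candidates, key=lambda record: record[2])[0]
```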
Validation accuracy reflects intra-domain performance, which may not generalize across domains. A model with slightly lower intra-domain accuracy (and higher loss) might generalize better in cross-domain scenarios.
CLIP (ViT-L/14@336px) extracts image features, which are then classified using a custom neural network.
- Input: Training images
- Augmentations:
  - Padding to 336x336
  - Center crop to 336x336
  - Horizontal flip
  - TenCrop (four corners, center, and their flipped versions)
- Output: Augmented images saved in `train_augmented/`
- Features are extracted using CLIP's `encode_image` method.
- Features are scaled using `StandardScaler` (saved as `trained_scaler.pkl`).
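Feature scaling and persistence might look like the sketch below; the random 768-dimensional array is a placeholder standing in for real CLIP embeddings:

```python
import pickle
import numpy as np
from sklearn.preprocessing import StandardScaler

# Placeholder features; in the real pipeline these come from CLIP's encode_image.
features = np.random.randn(100, 768).astype(np.float32)

scaler = StandardScaler().fit(features)
with open("trained_scaler.pkl", "wb") as f:
    pickle.dump(scaler, f)  # reused at inference time

scaled = scaler.transform(features)
```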
- Architecture: `ComplexClassifier(input_dim=768, hidden_dim=512, output_dim=1)`
- Optimizer: SGD (lr=0.9, momentum=0.99, weight_decay=1e-4)
- Loss: BCEWithLogitsLoss
- Epochs: 100
- Selection: Model with the highest validation accuracy is saved as `model.pth`.
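The README does not spell out the layers inside `ComplexClassifier`; the sketch below is one plausible MLP that matches the stated dimensions, optimizer, and loss:

```python
import torch
import torch.nn as nn

class ComplexClassifier(nn.Module):
    # Hypothetical layer layout; only the dimensions come from the README.
    def __init__(self, input_dim=768, hidden_dim=512, output_dim=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)

model = ComplexClassifier()
optimizer = torch.optim.SGD(model.parameters(), lr=0.9,
                            momentum=0.99, weight_decay=1e-4)
criterion = nn.BCEWithLogitsLoss()
```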
Since CLIP is not fine-tuned for this task, validation accuracy is assumed to correlate strongly with test set performance.
The final prediction combines outputs from both models:
- CLIP Probability: Weight = 0.489
- Swin Probability: Weight = 0.511
- Threshold: Combined score > 0.5 indicates an AI-generated image.
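The weighted vote can be expressed directly:

```python
# Combine the two models' fake-probabilities with the weights above.
def ensemble_predict(clip_prob, swin_prob,
                     w_clip=0.489, w_swin=0.511, threshold=0.5):
    score = w_clip * clip_prob + w_swin * swin_prob
    return score > threshold  # True -> AI-generated
```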
See the testing section for implementation details.
The provided test function evaluates the ensemble model:
```python
def test(model, swin_model, test_dataset_path):
    # Load models, dataset, and compute predictions
    # Returns accuracy
```

- `data_loader`: Custom dataset class to preprocess images and extract CLIP features.
- `ComplexClassifier`: Loads `model.pth`.
- Swin-Transformer: Loads `swin_model.pth`.
- Usage:

  ```python
  test_dataset_path = "./AIGC-Detection-Dataset/val"
  accuracy = test(model, swin_model, test_dataset_path)
  print(f"Test Accuracy: {accuracy:.4f}")
  ```
- Customize `data_loader` and `model` based on your dataset structure.
- Ensure `trained_scaler.pkl`, `model.pth`, and `swin_model.pth` are in the working directory.