This model is a DiT (Diffusion Transformer) trained from scratch on the Wikiart dataset. It is designed to generate art images conditioned on art genre and art style. Demo on Hugging Face Spaces: demo
The model largely mirrors the classic DiT architecture described in the paper Scalable Diffusion Models with Transformers, with slight modifications:
- Replaced the ImageNet class embeddings with Wikiart genre and style embeddings;
- Used post-norm instead of pre-norm (see the sketch after this list);
- Omitted the final linear layer;
- Replaced the 2D sin-cos positional embedding with a learned positional embedding;
- The models predict only the noise and do not learn sigma;
- Set patch_size=2 for all model variants (so a 4x32x32 latent becomes a 16x16 = 256-token sequence);
- The variants differ in size settings; see modeling_dit_wikiart.py in this repository for details.
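For illustration, here is a minimal sketch of the post-norm ordering, where LayerNorm is applied after each residual sum rather than before the sublayer. The class and attribute names are hypothetical, and the real blocks in modeling_dit_wikiart.py additionally carry timestep, genre, and style conditioning:

```python
import torch
import torch.nn as nn

class PostNormBlock(nn.Module):
    """Hypothetical post-norm transformer block (not the actual module)."""
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Post-norm: sublayer -> residual add -> LayerNorm
        # (pre-norm would instead be x + sublayer(norm(x))).
        x = self.norm1(x + self.attn(x, x, x, need_weights=False)[0])
        x = self.norm2(x + self.mlp(x))
        return x
```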
The model has three variants:
- S: small, num_blocks=8, hidden_size=384, num_heads=6, total_params=20M;
- B: base, num_blocks=12, hidden_size=640, num_heads=10, total_params=90M;
- L: large, num_blocks=16, hidden_size=896, num_heads=14, total_params=234M.
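The same size settings, expressed as a configuration mapping for quick reference (the dict and its key names are illustrative; the authoritative values live in modeling_dit_wikiart.py):

```python
# Illustrative summary of the variant settings listed above.
DIT_WIKIART_VARIANTS = {
    "S": dict(num_blocks=8, hidden_size=384, num_heads=6),    # ~20M params
    "B": dict(num_blocks=12, hidden_size=640, num_heads=10),  # ~90M params
    "L": dict(num_blocks=16, hidden_size=896, num_heads=14),  # ~234M params
}
```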
- dataset: all model variants were trained on the ~103K-image Wikiart dataset, with data augmentation by horizontal flipping.
- optimizer: AdamW with default settings.
- learning rate: linear warmup over the first 1% of steps to a peak of 3e-4, then cosine decay to zero over the remaining steps (see the schedule sketch after this list).
- epochs and batch size:
  - S: 96 epochs with a batch size of 176;
  - B: 120 epochs with a batch size of 192;
  - L: 144 epochs with a batch size of 192.
- device:
  - S: a single RTX 4060 Ti 16G for 24 hrs;
  - B: a single RTX 4060 Ti 16G for 90 hrs;
  - L: a single RTX 4090D 24G for 48 hrs, followed by a single RTX 4060 Ti 16G for 100 hrs.
- loss curve: all variants showed a sharp drop in loss during the first epoch, from above 1.0000 to around 0.2000, followed by a much slower decrease to about 0.1600 by the 20th epoch. DiT-S finally reached 0.1590, DiT-B 0.1525, and DiT-L 0.1510. Training was stable, with no loss spikes.
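A minimal sketch of the warmup-plus-cosine schedule described above, wired through PyTorch's LambdaLR (the function name and wiring are my own; the actual training script is not published here):

```python
import math

from torch.optim.lr_scheduler import LambdaLR

def warmup_cosine_schedule(optimizer, total_steps: int, warmup_fraction: float = 0.01):
    """Linear warmup over the first 1% of steps, then cosine decay to zero.
    Assumes the optimizer was constructed with lr set to the 3e-4 peak."""
    warmup_steps = max(1, int(warmup_fraction * total_steps))

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / warmup_steps  # linear ramp from 0 to peak
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay to 0

    return LambdaLR(optimizer, lr_lambda)

# Hypothetical usage:
#   optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
#   scheduler = warmup_cosine_schedule(optimizer, total_steps=len(loader) * epochs)
#   ... call scheduler.step() once per optimizer step.
```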
- The models demonstrate a basic ability to understand genres and styles and to produce visually appealing paintings (at first glance).
- Limitations include:
  - Failure to capture complex structures such as human faces and buildings;
  - Occasional mode collapse when asked to generate genres or styles rarely seen in the dataset, for example the Minimalism style and the ukiyo-e genre;
  - Resolution limited to 256x256;
  - Trained only on the Wikiart dataset, and therefore unable to generate out-of-scope images.
To use the model, install the huggingface_hub library and download modeling_dit_wikiart.py from "Files and versions" for the model definition. After that, you can run the model with the following code:
```python
import torch

from modeling_dit_wikiart import DiTWikiartModel

model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Large")
model.eval()

num_samples = 8
noisy_latents = torch.randn(num_samples, 4, 32, 32)
with torch.no_grad():
    predicted_noise = model(noisy_latents)
print(predicted_noise)
```

The model is paired with stabilityai/sd-vae-ft-ema.
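Since the model operates in that VAE's latent space, sampled latents must be decoded back to pixel space. A minimal sketch using diffusers' AutoencoderKL, assuming `latents` holds fully denoised 4x32x32 samples produced by a sampling loop around the model's noise predictions (the loop itself is not shown):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
vae.eval()

# Placeholder: in practice, `latents` comes from a DDPM/DDIM-style sampling
# loop driven by the DiT's predicted noise.
latents = torch.randn(8, 4, 32, 32)

with torch.no_grad():
    # SD-style VAEs scale latents by 0.18215 at encode time; undo that here.
    images = vae.decode(latents / 0.18215).sample  # (8, 3, 256, 256), roughly in [-1, 1]
images = (images.clamp(-1, 1) + 1) / 2  # map to [0, 1] for saving or viewing
```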