Mtrya/dit-wikiart

Diffusion transformer trained on the Wikiart dataset

Model Description

This model is a DiT (Diffusion Transformer) trained from scratch on the Wikiart dataset. It is designed to generate art images conditioned on a given art genre and art style. Demo on Hugging Face Spaces: demo

Model Architecture

The model largely mirrors the classic DiT architecture described in the paper Scalable Diffusion Models with Transformers, with slight modifications:

  1. Replaced the ImageNet class embeddings with Wikiart genre and style embeddings;
  2. Used post-norm instead of pre-norm;
  3. Omitted the final linear layer;
  4. Replaced the sin-cos 2D positional embedding with a learned positional embedding;
  5. The models predict only the noise and do not learn sigma;
  6. Set patch_size=2 for all model variants, so the 32x32 latent grid becomes 16x16 = 256 tokens;
  7. The variants use different size settings; see modeling_dit_wikiart.py in this repository for details.
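
As a rough illustration of modifications 1 and 2, the sketch below shows how a post-norm transformer block and summed genre/style conditioning might look in PyTorch. All names here are hypothetical; the actual implementation lives in modeling_dit_wikiart.py.

import torch
import torch.nn as nn

class PostNormDiTBlock(nn.Module):
    """Post-norm variant: LayerNorm after each residual add (illustrative only)."""
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x):
        x = self.norm1(x + self.attn(x, x, x)[0])  # norm applied after the residual
        x = self.norm2(x + self.mlp(x))
        return x

class GenreStyleEmbedder(nn.Module):
    """Replaces DiT's single ImageNet class embedding with two Wikiart embeddings."""
    def __init__(self, num_genres, num_styles, hidden_size):
        super().__init__()
        self.genre = nn.Embedding(num_genres, hidden_size)
        self.style = nn.Embedding(num_styles, hidden_size)

    def forward(self, genre_ids, style_ids):
        # Conditioning vector is the sum of the two label embeddings
        return self.genre(genre_ids) + self.style(style_ids)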

The model has three variants:

  • S: small, num_blocks=8, hidden_size=384, num_heads=6, total_params=20M;
  • B: base, num_blocks=12, hidden_size=640, num_heads=10, total_params=90M;
  • L: large, num_blocks=16, hidden_size=896, num_heads=14, total_params=234M.
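
For reference, the three size settings can be summarized as a small config table (a hypothetical summary; modeling_dit_wikiart.py is authoritative):

DIT_WIKIART_CONFIGS = {
    "S": dict(num_blocks=8,  hidden_size=384, num_heads=6),   # ~20M params
    "B": dict(num_blocks=12, hidden_size=640, num_heads=10),  # ~90M params
    "L": dict(num_blocks=16, hidden_size=896, num_heads=14),  # ~234M params
}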

Training Procedure

  • dataset: all variants were trained on the ~103K-image Wikiart dataset, augmented with horizontal flips.
  • optimizer: AdamW with default settings.
  • learning rate: linear warmup over the first 1% of steps to a peak of 3e-4, then cosine decay to zero over the remaining steps (see the sketch after this list).
  • epochs and batch size:
    • S: 96 epochs with a batch size of 176,
    • B: 120 epochs with a batch size of 192,
    • L: 144 epochs with a batch size of 192.
  • device:
    • S: single RTX 4060 Ti 16G for 24 hrs,
    • B: single RTX 4060 Ti 16G for 90 hrs,
    • L: single RTX 4090D 24G for 48 hrs, followed by a single RTX 4060 Ti 16G for 100 hrs.
  • loss curve: all variants showed a steep drop in the first epoch, from above 1.0000 to around 0.2000, followed by a much slower decrease to about 0.1600 by the 20th epoch. DiT-S finished at 0.1590, DiT-B at 0.1525, and DiT-L at 0.1510. Training was stable, with no loss spikes.
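
A minimal sketch of that learning-rate schedule, assuming a per-step functional form (the actual training code may differ):

import math

def lr_at(step, total_steps, peak_lr=3e-4, warmup_frac=0.01):
    # Linear warmup over the first warmup_frac of training
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Cosine decay from peak_lr to zero over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))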

Performance and Limitations

  • The models demonstrate a basic ability to understand genres and styles and produce visually appealing paintings (at first glance).
  • Limitations include:
    • Failure to capture complex structures such as human faces and buildings.
    • Occasional mode collapse when asked to generate genres or styles rarely seen in the dataset, e.g. styles like minimalism or genres like ukiyo-e.
    • Resolution limited to 256x256.
    • Trained only on the Wikiart dataset, and therefore unable to generate out-of-scope images.

How to use it

To use the model, install the huggingface_hub library and download modeling_dit_wikiart.py (under "Files and versions" on the Hugging Face Hub) for the model definition. You can then load the model with the following code:

import torch
from modeling_dit_wikiart import DiTWikiartModel

# Load the pretrained weights from the Hugging Face Hub
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Large")
model.eval()

# The model operates in the 4x32x32 latent space of the paired VAE
num_samples = 8
noisy_latents = torch.randn(num_samples, 4, 32, 32)
predicted_noise = model(noisy_latents)
print(predicted_noise.shape)  # same shape as the input latents

The model is paired with stabilityai/sd-vae-ft-ema: it operates in that VAE's latent space, so generated latents must be decoded back into images with it.
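
For example, once your sampling loop has produced denoised latents, decoding might look like this (a sketch using the diffusers library; the 0.18215 scaling factor is the standard Stable Diffusion value and is an assumption here):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
vae.eval()

with torch.no_grad():
    # latents: (N, 4, 32, 32) tensor produced by the diffusion sampling loop
    images = vae.decode(latents / 0.18215).sample
images = (images.clamp(-1, 1) + 1) / 2  # map from [-1, 1] to [0, 1] for display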
