Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 85 additions & 14 deletions examples/stable-diffusion/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use burn::tensor::Tensor;
use hf_hub::api::sync::Api;

use diffusers_burn::pipelines::stable_diffusion::{
generate_image_ddim, StableDiffusion, StableDiffusionConfig,
generate_image_ddim, generate_image_heun, generate_image_kdpm2, generate_image_kdpm2_ancestral,
generate_image_lms, StableDiffusion, StableDiffusionConfig,
};
use diffusers_burn::pipelines::weights::{
load_clip_safetensors, load_unet_safetensors, load_vae_safetensors,
Expand All @@ -46,6 +47,21 @@ enum StableDiffusionVersion {
V2_1,
}

/// Sampler chosen via the `--scheduler` command-line flag.
///
/// `run` matches on this value to decide which `build_*_scheduler` method and
/// which `generate_image_*` function are used to produce the image.
///
/// NOTE(review): because `ValueEnum` is derived, clap is expected to surface
/// the per-variant doc comments below as help text for the flag's possible
/// values — treat them as user-facing strings; confirm against clap's docs
/// before rewording.
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
enum SchedulerType {
// `#[default]` makes `SchedulerType::default()` return `Ddim`, matching the
// `default_value = "ddim"` declared on the `--scheduler` argument.
/// DDIM - Denoising Diffusion Implicit Models (default, deterministic)
#[default]
Ddim,
/// Heun - Second-order Runge-Kutta method (more accurate)
Heun,
/// LMS - Linear Multi-Step method
Lms,
/// K-DPM2 - DPM-Solver-2 variant by @crowsonkb
Kdpm2,
/// K-DPM2 Ancestral - Stochastic K-DPM2 (adds noise each step)
Kdpm2Ancestral,
}

impl StableDiffusionVersion {
fn repo_id(&self) -> &'static str {
match self {
Expand Down Expand Up @@ -135,6 +151,11 @@ struct Args {
#[arg(long, value_enum, default_value = "v1-5")]
sd_version: StableDiffusionVersion,

/// The scheduler (sampler) to use for denoising.
/// Different schedulers trade off speed vs quality.
#[arg(long, value_enum, default_value = "ddim")]
scheduler: SchedulerType,

/// Hugging Face API token for gated models (e.g., SD 2.1).
/// Can also be set via HF_TOKEN environment variable.
#[arg(long, env = "HF_TOKEN")]
Expand Down Expand Up @@ -302,10 +323,6 @@ fn run(args: Args) -> anyhow::Result<()> {
)?,
};

// Build scheduler
println!("\nBuilding DDIM scheduler with {} steps...", args.n_steps);
let scheduler = sd_config.build_ddim_scheduler::<Backend>(args.n_steps, &device);

// Build models
println!("Building CLIP text encoder...");
let clip = sd_config.build_clip_transformer::<Backend>(&device);
Expand Down Expand Up @@ -339,18 +356,72 @@ fn run(args: Args) -> anyhow::Result<()> {
println!("\nGenerating image...");
println!(" Size: {}x{}", sd_config.width, sd_config.height);
println!(" Steps: {}", args.n_steps);
println!(" Scheduler: {:?}", args.scheduler);
println!(" Guidance scale: {}", GUIDANCE_SCALE);
println!(" Seed: {}", args.seed);

let image_tensor = generate_image_ddim(
&pipeline,
&scheduler,
&tokens,
&uncond_tokens,
GUIDANCE_SCALE,
args.seed,
&device,
);
let image_tensor = match args.scheduler {
SchedulerType::Ddim => {
let scheduler = sd_config.build_ddim_scheduler::<Backend>(args.n_steps, &device);
generate_image_ddim(
&pipeline,
&scheduler,
&tokens,
&uncond_tokens,
GUIDANCE_SCALE,
args.seed,
&device,
)
}
SchedulerType::Heun => {
let mut scheduler = sd_config.build_heun_scheduler::<Backend>(args.n_steps);
generate_image_heun(
&pipeline,
&mut scheduler,
&tokens,
&uncond_tokens,
GUIDANCE_SCALE,
args.seed,
&device,
)
}
SchedulerType::Lms => {
let mut scheduler = sd_config.build_lms_scheduler::<Backend>(args.n_steps);
generate_image_lms(
&pipeline,
&mut scheduler,
&tokens,
&uncond_tokens,
GUIDANCE_SCALE,
args.seed,
&device,
)
}
SchedulerType::Kdpm2 => {
let mut scheduler = sd_config.build_kdpm2_scheduler::<Backend>(args.n_steps);
generate_image_kdpm2(
&pipeline,
&mut scheduler,
&tokens,
&uncond_tokens,
GUIDANCE_SCALE,
args.seed,
&device,
)
}
SchedulerType::Kdpm2Ancestral => {
let mut scheduler = sd_config.build_kdpm2_ancestral_scheduler::<Backend>(args.n_steps);
generate_image_kdpm2_ancestral(
&pipeline,
&mut scheduler,
&tokens,
&uncond_tokens,
GUIDANCE_SCALE,
args.seed,
&device,
)
}
};

// Save image
println!("\nSaving image to {}...", args.output);
Expand Down
Loading