diff --git a/Cargo.toml b/Cargo.toml index 197a6fa..9b9c690 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [features] default = ["burn/default", "std"] -std = ["burn/std"] +std = ["burn/std", "regex", "thiserror", "burn-store", "safetensors", "memmap2"] # Backend accelerate = ["burn/accelerate"] @@ -17,11 +17,26 @@ wgpu = ["burn/wgpu"] [dependencies] burn = { version = "0.20.1", default-features = false } +burn-store = { version = "0.20.1", optional = true, features = ["safetensors"] } num-traits = { version = "0.2.18", default-features = false } +regex = { version = "1.10", optional = true } +safetensors = { version = "0.4", optional = true } +memmap2 = { version = "0.9", optional = true } serde = { version = "1.0.197", default-features = false, features = [ "derive", "alloc", ] } +thiserror = { version = "1.0", optional = true } -[patch.crates-io] -#burn = { git = "https://github.com/tracel-ai/burn" } +[[example]] +name = "stable-diffusion" +required-features = ["std"] + +[dev-dependencies] +anyhow = "1.0" +clap = { version = "4.4", features = ["derive", "env"] } +dirs = "5.0" +flate2 = "1.0" +hf-hub = "0.3" +image = "0.25" +ureq = "2.9" diff --git a/README.md b/README.md index b78b31b..018bab7 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# diffusers-burn: A diffusers API in Rust/Burn +# diffusers-burn -> **⚠️ This is still in development - contributors welcome!** +Stable Diffusion in Rust using [Burn](https://github.com/burn-rs/burn). Supports SD 1.5 and 2.1. -The `diffusers-burn` crate is a conversion of [diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using [burn](https://github.com/burn-rs/burn) rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1, as well as Stable Diffusion XL 1.0. +Based on [diffusers-rs](https://github.com/LaurentMazare/diffusers-rs).
@@ -18,29 +18,39 @@ _[Blaze](https://runblaze.dev) supports this project by providing ultra-fast App
-## Feature Flags +## Quick Start -This crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling -the default `std` feature. +```bash +# Using wgpu backend (default, works on most GPUs) +cargo run --release --features wgpu --example stable-diffusion -- \ + --prompt "A photo of a rusty robot on a beach" -* `std` - enables the standard library. Enabled by default. -* `wgpu` - uses ndarray as the backend. Enabled by default when none specified and `std`. -* `ndarray` - uses ndarray as the backend. -* `ndarray-no-std` - uses ndarray-no-std as the backend. Enabled by default when none and `#![no_std]`. -* `ndarray-blas-accelerate` - uses ndarray with Accelerate framework (macOS only). -* `torch` - uses torch as the backend. +# Using torch backend +cargo run --release --features torch --example stable-diffusion -- \ + --prompt "A photo of a rusty robot on a beach" -## Community +# SD 2.1 at 768x768 +cargo run --release --features torch --example stable-diffusion -- \ + --sd-version v2-1 \ + --prompt "A majestic lion on a cliff at sunset" +``` + +## Backends + +| Feature | Backend | Notes | +|---------|---------|-------| +| `wgpu` | WebGPU | Cross-platform GPU support | +| `torch` | LibTorch | Requires libtorch | +| `ndarray` | ndarray | CPU only, pure Rust | -If you are excited about the project or want to contribute, don't hesitate to join our [Discord](https://discord.gg/UHtSgF6j5J)! -We try to be as welcoming as possible to everybody from any background. We're still building this out, but you can ask your questions there! +## no_std Support -## Status +This crate supports `#![no_std]` with `alloc` by disabling the default `std` feature. + +## Community -diffusers-burn is currently in active development, and is not yet complete. +Join our [Discord](https://discord.gg/UHtSgF6j5J) if you want to contribute or have questions! 
## License -diffusers-burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0). -See [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for details. Opening a pull -request is assumed to signal agreement with these licensing terms. +MIT or Apache-2.0, at your option. diff --git a/examples/stable-diffusion/Cargo.toml b/examples/stable-diffusion/Cargo.toml new file mode 100644 index 0000000..5e4d89f --- /dev/null +++ b/examples/stable-diffusion/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "stable-diffusion-example" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "stable-diffusion" +path = "main.rs" + +[features] +default = ["wgpu"] +torch = ["diffusers-burn/torch", "burn/tch"] +ndarray = ["diffusers-burn/ndarray", "burn/ndarray"] +wgpu = ["diffusers-burn/wgpu", "burn/wgpu"] + +[dependencies] +diffusers-burn = { path = "../..", features = ["std"], default-features = false } +burn = { version = "0.20.1", default-features = false } +clap = { version = "4", features = ["derive"] } +image = "0.25" +anyhow = "1.0" +hf-hub = "0.4" +ureq = "2" +flate2 = "1" +dirs = "6" diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs new file mode 100644 index 0000000..50a5dbe --- /dev/null +++ b/examples/stable-diffusion/main.rs @@ -0,0 +1,367 @@ +#![recursion_limit = "256"] +// Stable Diffusion Example +// +// This example generates images from text prompts using Stable Diffusion. +// Weights are automatically downloaded from Hugging Face Hub on first run. +// +// # Usage +// +// cargo run --release -- --prompt "a photo of a cat" +// +// The BPE vocabulary file will be downloaded automatically if not present. 
+ +use std::fs; +use std::io::Read; +use std::path::PathBuf; + +use clap::{Parser, ValueEnum}; + +use burn::tensor::Tensor; +use hf_hub::api::sync::Api; + +use diffusers_burn::pipelines::stable_diffusion::{ + generate_image_ddim, StableDiffusion, StableDiffusionConfig, +}; +use diffusers_burn::pipelines::weights::{ + load_clip_safetensors, load_unet_safetensors, load_vae_safetensors, +}; +use diffusers_burn::transformers::{SimpleTokenizer, SimpleTokenizerConfig}; + +const GUIDANCE_SCALE: f64 = 7.5; + +#[cfg(feature = "wgpu")] +type Backend = burn::backend::Wgpu; + +#[cfg(feature = "torch")] +type Backend = burn::backend::LibTorch; + +#[cfg(feature = "ndarray")] +type Backend = burn::backend::NdArray; + +#[derive(Debug, Clone, Copy, ValueEnum)] +enum StableDiffusionVersion { + #[value(name = "v1-5")] + V1_5, + #[value(name = "v2-1")] + V2_1, +} + +impl StableDiffusionVersion { + fn repo_id(&self) -> &'static str { + match self { + StableDiffusionVersion::V1_5 => "runwayml/stable-diffusion-v1-5", + // Use community repo (ungated) instead of stabilityai/stable-diffusion-2-1 (gated) + StableDiffusionVersion::V2_1 => "sd2-community/stable-diffusion-2-1", + } + } + + fn clip_repo_id(&self) -> &'static str { + match self { + // SD 1.5 uses OpenAI's CLIP + StableDiffusionVersion::V1_5 => "openai/clip-vit-large-patch14", + // SD 2.1 community repo has CLIP in text_encoder subdirectory + StableDiffusionVersion::V2_1 => "sd2-community/stable-diffusion-2-1", + } + } + + fn clip_weights_file(&self) -> &'static str { + match self { + // SD 1.5 uses standalone CLIP model + StableDiffusionVersion::V1_5 => "model.safetensors", + // SD 2.1 has CLIP in text_encoder subdirectory + StableDiffusionVersion::V2_1 => "text_encoder/model.safetensors", + } + } + + fn tokenizer_config(&self) -> SimpleTokenizerConfig { + match self { + StableDiffusionVersion::V1_5 => SimpleTokenizerConfig::v1_5(), + StableDiffusionVersion::V2_1 => SimpleTokenizerConfig::v2_1(), + } + } +} + +#[derive(Parser)] 
+#[command(author, version, about = "Generate images with Stable Diffusion", long_about = None)]
+struct Args {
+    /// The prompt to be used for image generation.
+    #[arg(
+        long,
+        default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
+    )]
+    prompt: String,
+
+    /// The negative prompt (what to avoid in the image).
+    #[arg(long, default_value = "")]
+    negative_prompt: String,
+
+    /// The height in pixels of the generated image.
+    #[arg(long)]
+    height: Option<usize>,
+
+    /// The width in pixels of the generated image.
+    #[arg(long)]
+    width: Option<usize>,
+
+    /// The UNet weight file, in .safetensors format (auto-downloaded if not specified).
+    #[arg(long, value_name = "FILE")]
+    unet_weights: Option<String>,
+
+    /// The CLIP weight file, in .safetensors format (auto-downloaded if not specified).
+    #[arg(long, value_name = "FILE")]
+    clip_weights: Option<String>,
+
+    /// The VAE weight file, in .safetensors format (auto-downloaded if not specified).
+    #[arg(long, value_name = "FILE")]
+    vae_weights: Option<String>,
+
+    /// The file specifying the vocabulary to use for tokenization (auto-downloaded if not specified).
+    #[arg(long, value_name = "FILE")]
+    vocab_file: Option<String>,
+
+    /// The number of steps to run the diffusion for.
+    #[arg(long, default_value_t = 30)]
+    n_steps: usize,
+
+    /// The random seed to be used for the generation.
+    #[arg(long, default_value_t = 32)]
+    seed: u64,
+
+    /// The name of the final image to generate.
+    #[arg(long, value_name = "FILE", default_value = "output.png")]
+    output: String,
+
+    /// The Stable Diffusion version to use.
+    #[arg(long, value_enum, default_value = "v1-5")]
+    sd_version: StableDiffusionVersion,
+
+    /// Hugging Face API token for gated models (e.g., SD 2.1).
+    /// Can also be set via HF_TOKEN environment variable.
+    #[arg(long, env = "HF_TOKEN")]
+    hf_token: Option<String>,
+}
+
+/// Downloads a file from Hugging Face Hub if not already cached.
+fn download_hf_file(repo_id: &str, filename: &str, token: Option<&str>) -> anyhow::Result<PathBuf> {
+    println!(" Downloading {} from {}...", filename, repo_id);
+
+    let api = match token {
+        Some(t) => hf_hub::api::sync::ApiBuilder::new()
+            .with_token(Some(t.to_string()))
+            .build()?,
+        None => Api::new()?,
+    };
+
+    let repo = api.model(repo_id.to_string());
+    match repo.get(filename) {
+        Ok(path) => {
+            println!(" Cached at: {}", path.display());
+            Ok(path)
+        }
+        Err(e) => {
+            // Provide helpful error message for 401 errors
+            let err_str = e.to_string();
+            if err_str.contains("401") {
+                anyhow::bail!(
+                    "Authentication required for {}.\n\
+                    This model requires accepting the license at https://huggingface.co/{}\n\
+                    Then provide your HF token via --hf-token or HF_TOKEN environment variable.\n\
+                    Get your token at: https://huggingface.co/settings/tokens",
+                    repo_id,
+                    repo_id
+                );
+            }
+            Err(e.into())
+        }
+    }
+}
+
+/// Downloads the BPE vocabulary file from OpenAI's GitHub.
+fn download_bpe_vocab() -> anyhow::Result<PathBuf> {
+    // Cache in HF cache directory for consistency
+    let cache_dir = dirs::cache_dir()
+        .unwrap_or_else(|| PathBuf::from("."))
+        .join("huggingface")
+        .join("clip");
+    fs::create_dir_all(&cache_dir)?;
+
+    let vocab_path = cache_dir.join("bpe_simple_vocab_16e6.txt");
+
+    if vocab_path.exists() {
+        println!(" Using cached vocabulary at: {}", vocab_path.display());
+        return Ok(vocab_path);
+    }
+
+    println!(" Downloading BPE vocabulary from OpenAI GitHub...");
+    let url = "https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz";
+
+    // Download gzipped file
+    let response = ureq::get(url).call()?;
+    let mut gz_data = Vec::new();
+    response.into_reader().read_to_end(&mut gz_data)?;
+
+    // Decompress
+    let mut decoder = flate2::read::GzDecoder::new(&gz_data[..]);
+    let mut content = String::new();
+    decoder.read_to_string(&mut content)?;
+
+    // Save to cache
+    fs::write(&vocab_path, &content)?;
+    println!(" Cached at: {}", vocab_path.display());
+
+    Ok(vocab_path)
+}
+
+fn tensor_to_image(tensor: Tensor<Backend, 4>) -> image::RgbImage {
+    // tensor shape: [1, 3, height, width], values in [0, 1]
+    let [_, _, height, width] = tensor.dims();
+
+    // Convert to [0, 255] range
+    let tensor = tensor * 255.0;
+    let data: Vec<f32> = tensor.into_data().to_vec().unwrap();
+
+    // Create image buffer
+    let mut img = image::RgbImage::new(width as u32, height as u32);
+
+    for y in 0..height {
+        for x in 0..width {
+            // Round and clamp to [0, 255] for proper u8 conversion (matches PyTorch's to_kind(Uint8))
+            let r = data[0 * height * width + y * width + x]
+                .round()
+                .clamp(0.0, 255.0) as u8;
+            let g = data[1 * height * width + y * width + x]
+                .round()
+                .clamp(0.0, 255.0) as u8;
+            let b = data[2 * height * width + y * width + x]
+                .round()
+                .clamp(0.0, 255.0) as u8;
+            img.put_pixel(x as u32, y as u32, image::Rgb([r, g, b]));
+        }
+    }
+
+    img
+}
+
+fn run(args: Args) -> anyhow::Result<()> {
+    let device = Default::default();
+
+    // Build configuration
+    println!("Building configuration for {:?}...", args.sd_version);
+    let sd_config = match args.sd_version {
+        StableDiffusionVersion::V1_5 => StableDiffusionConfig::v1_5(None, args.height, args.width),
+        StableDiffusionVersion::V2_1 => StableDiffusionConfig::v2_1(None, args.height, args.width),
+    };
+
+    // Download or use provided vocab file
+    // The tokenizer expects the OpenAI CLIP BPE vocabulary format
+    println!("\nPreparing vocabulary file...");
+    let vocab_path = match &args.vocab_file {
+        Some(path) => PathBuf::from(path),
+        None => download_bpe_vocab()?,
+    };
+
+    // Build tokenizer
+    println!("\nLoading tokenizer from {}...", vocab_path.display());
+    let tokenizer = SimpleTokenizer::new(&vocab_path, args.sd_version.tokenizer_config())?;
+
+    // Tokenize prompts
+    println!("Tokenizing prompt: \"{}\"", args.prompt);
+    let tokens = tokenizer.encode(&args.prompt)?;
+    let uncond_tokens = tokenizer.encode(&args.negative_prompt)?;
+    println!(" Prompt tokens: {} tokens", tokens.len());
+    println!(" Negative prompt tokens: {} tokens", uncond_tokens.len());
+
+    // Download weights
+    println!("\nPreparing model weights...");
+    let hf_token = args.hf_token.as_deref();
+
+    let clip_weights = match &args.clip_weights {
+        Some(path) => PathBuf::from(path),
+        None => download_hf_file(
+            args.sd_version.clip_repo_id(),
+            args.sd_version.clip_weights_file(),
+            hf_token,
+        )?,
+    };
+
+    let vae_weights = match &args.vae_weights {
+        Some(path) => PathBuf::from(path),
+        None => download_hf_file(
+            args.sd_version.repo_id(),
+            "vae/diffusion_pytorch_model.safetensors",
+            hf_token,
+        )?,
+    };
+
+    let unet_weights = match &args.unet_weights {
+        Some(path) => PathBuf::from(path),
+        None => download_hf_file(
+            args.sd_version.repo_id(),
+            "unet/diffusion_pytorch_model.safetensors",
+            hf_token,
+        )?,
+    };
+
+    // Build scheduler
+    println!("\nBuilding DDIM scheduler with {} steps...", args.n_steps);
+    let scheduler = sd_config.build_ddim_scheduler::<Backend>(args.n_steps, &device);
+
+    // Build models
+    println!("Building CLIP text encoder...");
+    let clip = sd_config.build_clip_transformer::<Backend>(&device);
+
+    println!("Building VAE...");
+    let vae = sd_config.build_vae::<Backend>(&device);
+
+    println!("Building UNet...");
+    let unet = sd_config.build_unet::<Backend>(&device, 4);
+
+    // Load weights
+    println!("\nLoading CLIP weights...");
+    let clip = load_clip_safetensors::<Backend>(clip, &clip_weights, &device)?;
+
+    println!("Loading VAE weights...");
+    let vae = load_vae_safetensors::<Backend>(vae, &vae_weights, &device)?;
+
+    println!("Loading UNet weights...");
+    let unet = load_unet_safetensors::<Backend>(unet, &unet_weights, &device)?;
+
+    // Assemble pipeline
+    let pipeline = StableDiffusion {
+        clip,
+        vae,
+        unet,
+        width: sd_config.width,
+        height: sd_config.height,
+    };
+
+    // Generate image
+    println!("\nGenerating image...");
+    println!(" Size: {}x{}", sd_config.width, sd_config.height);
+    println!(" Steps: {}", args.n_steps);
+    println!(" Guidance scale: {}", GUIDANCE_SCALE);
+    println!(" Seed: {}",
args.seed); + + let image_tensor = generate_image_ddim( + &pipeline, + &scheduler, + &tokens, + &uncond_tokens, + GUIDANCE_SCALE, + args.seed, + &device, + ); + + // Save image + println!("\nSaving image to {}...", args.output); + let img = tensor_to_image(image_tensor); + img.save(&args.output)?; + + println!("Done!"); + Ok(()) +} + +fn main() -> anyhow::Result<()> { + let args = Args::parse(); + run(args) +} diff --git a/src/lib.rs b/src/lib.rs index 1e2bea9..8e12c68 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ pub mod models; pub mod pipelines; +pub mod schedulers; pub mod transformers; pub mod utils; diff --git a/src/models/attention.rs b/src/models/attention.rs index 2326fb5..254a753 100644 --- a/src/models/attention.rs +++ b/src/models/attention.rs @@ -154,11 +154,9 @@ impl CrossAttention { fn reshape_batch_dim_to_heads(&self, xs: Tensor) -> Tensor { let [batch_size, seq_len, dim] = xs.dims(); - let output = xs - .reshape([batch_size / self.n_heads, self.n_heads, seq_len, dim]) + xs.reshape([batch_size / self.n_heads, self.n_heads, seq_len, dim]) .swap_dims(1, 2) - .reshape([batch_size / self.n_heads, seq_len, dim * self.n_heads]); - output + .reshape([batch_size / self.n_heads, seq_len, dim * self.n_heads]) } fn sliced_attention( @@ -314,7 +312,13 @@ impl Proj { fn forward(&self, xs: Tensor) -> Tensor { match self { Proj::Conv2d(conv) => conv.forward(xs), - Proj::Linear(linear) => linear.forward(xs), + Proj::Linear(linear) => { + // For linear projection, we need to permute from [batch, channels, h, w] + // to [batch, h, w, channels], apply linear, then permute back + let xs = xs.swap_dims(1, 2).swap_dims(2, 3); // [batch, h, w, channels] + let xs = linear.forward(xs); + xs.swap_dims(2, 3).swap_dims(1, 2) // [batch, channels, h, w] + } } } } @@ -325,7 +329,7 @@ pub struct SpatialTransformer { norm: GroupNorm, proj_in: Proj, transformer_blocks: Vec>, - proj_out: nn::conv::Conv2d, + proj_out: Proj, } impl SpatialTransformerConfig { @@ -352,8 
+356,13 @@ impl SpatialTransformerConfig { transformer_blocks.push(tb) } - let proj_out = - nn::conv::Conv2dConfig::new([d_inner, self.in_channels], [1, 1]).init(device); + let proj_out = if self.use_linear_projection { + Proj::Linear(nn::LinearConfig::new(d_inner, self.in_channels).init(device)) + } else { + Proj::Conv2d( + nn::conv::Conv2dConfig::new([d_inner, self.in_channels], [1, 1]).init(device), + ) + }; SpatialTransformer { norm, @@ -655,4 +664,32 @@ mod tests { Tolerance::rel_abs(1e-3, 1e-3), ) } + + /// Test GELU activation matches diffusers-rs (tch gelu("none")) + /// Reference values from diffusers-rs v0.3.1 + #[test] + fn test_gelu_matches_diffusers_rs() { + let device = Default::default(); + let xs: Tensor = Tensor::from_data( + TensorData::from([-2.0f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]), + &device, + ); + + let result = gelu(xs); + + // Reference values from diffusers-rs: gelu("none") + // [-0.04550028, -0.15865526, -0.15426877, 0.0, 0.34573123, 0.8413447, 1.9544997] + result.into_data().assert_approx_eq::( + &TensorData::from([ + -0.04550028, + -0.15865526, + -0.15426877, + 0.0, + 0.34573123, + 0.8413447, + 1.9544997, + ]), + Tolerance::rel_abs(1e-4, 1e-4), + ); + } } diff --git a/src/models/controlnet.rs b/src/models/controlnet.rs new file mode 100644 index 0000000..c0023a0 --- /dev/null +++ b/src/models/controlnet.rs @@ -0,0 +1,582 @@ +//! ControlNet Model +//! +//! ControlNet is a neural network structure to control diffusion models by adding extra conditions. +//! 
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/controlnet.py + +use alloc::vec; +use alloc::vec::Vec; +use burn::config::Config; +use burn::module::Module; +use burn::nn::conv::{Conv2d, Conv2dConfig}; +use burn::nn::PaddingConfig2d; +use burn::tensor::activation::silu; +use burn::tensor::backend::Backend; +use burn::tensor::Tensor; + +use super::embeddings::{get_timestep_embedding, TimestepEmbedding, TimestepEmbeddingConfig}; +use super::unet_2d::{BlockConfig, UNetDownBlock}; +use super::unet_2d_blocks::{ + CrossAttnDownBlock2DConfig, DownBlock2DConfig, UNetMidBlock2DCrossAttn, + UNetMidBlock2DCrossAttnConfig, +}; + +/// ControlNet conditioning embedding module. +/// +/// Processes the conditioning image (e.g., canny edges, depth map) into +/// an embedding that can be added to the UNet features. +#[derive(Module, Debug)] +pub struct ControlNetConditioningEmbedding { + conv_in: Conv2d, + conv_out: Conv2d, + blocks: Vec<(Conv2d, Conv2d)>, +} + +/// Configuration for ControlNetConditioningEmbedding. +#[derive(Config, Debug)] +pub struct ControlNetConditioningEmbeddingConfig { + /// Output channels for the conditioning embedding. + conditioning_embedding_channels: usize, + /// Input channels from the conditioning image. + #[config(default = 3)] + conditioning_channels: usize, + /// Channel progression for the embedding blocks. + block_out_channels: Vec, +} + +impl ControlNetConditioningEmbeddingConfig { + /// Initialize the ControlNetConditioningEmbedding module. 
+ pub fn init(&self, device: &B::Device) -> ControlNetConditioningEmbedding { + let b_channels = self.block_out_channels[0]; + let bl_channels = *self.block_out_channels.last().unwrap(); + + let conv_in = Conv2dConfig::new([self.conditioning_channels, b_channels], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + let conv_out = + Conv2dConfig::new([bl_channels, self.conditioning_embedding_channels], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + let blocks = (0..self.block_out_channels.len() - 1) + .map(|i| { + let channel_in = self.block_out_channels[i]; + let channel_out = self.block_out_channels[i + 1]; + + let c1 = Conv2dConfig::new([channel_in, channel_in], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + let c2 = Conv2dConfig::new([channel_in, channel_out], [3, 3]) + .with_stride([2, 2]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + (c1, c2) + }) + .collect(); + + ControlNetConditioningEmbedding { + conv_in, + conv_out, + blocks, + } + } +} + +impl ControlNetConditioningEmbedding { + /// Forward pass through the conditioning embedding. + pub fn forward(&self, xs: Tensor) -> Tensor { + let mut xs = silu(self.conv_in.forward(xs)); + + for (c1, c2) in &self.blocks { + xs = silu(c1.forward(xs)); + xs = silu(c2.forward(xs)); + } + + self.conv_out.forward(xs) + } +} + +/// Configuration for the ControlNet model. +#[derive(Config, Debug)] +pub struct ControlNetConfig { + /// Whether to flip sin to cos in timestep embedding. + #[config(default = true)] + pub flip_sin_to_cos: bool, + /// Frequency shift for timestep embedding. + #[config(default = 0.0)] + pub freq_shift: f64, + /// Configuration for each block. + pub blocks: Vec, + /// Output channels for conditioning embedding blocks. + pub conditioning_embedding_out_channels: Vec, + /// Number of ResNet layers per block. 
+ #[config(default = 2)] + pub layers_per_block: usize, + /// Padding for downsampling convolutions. + #[config(default = 1)] + pub downsample_padding: usize, + /// Scale factor for mid block. + #[config(default = 1.0)] + pub mid_block_scale_factor: f64, + /// Number of groups for group normalization. + #[config(default = 32)] + pub norm_num_groups: usize, + /// Epsilon for normalization layers. + #[config(default = 1e-5)] + pub norm_eps: f64, + /// Dimension of cross-attention context. + #[config(default = 768)] + pub cross_attention_dim: usize, + /// Whether to use linear projection in attention. + #[config(default = false)] + pub use_linear_projection: bool, +} + +impl Default for ControlNetConfig { + fn default() -> Self { + Self { + flip_sin_to_cos: true, + freq_shift: 0.0, + blocks: vec![ + BlockConfig::new(320) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(640) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(1280) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(1280) + .with_use_cross_attn(false) + .with_attention_head_dim(8), + ], + conditioning_embedding_out_channels: vec![16, 32, 96, 256], + layers_per_block: 2, + downsample_padding: 1, + mid_block_scale_factor: 1.0, + norm_num_groups: 32, + norm_eps: 1e-5, + cross_attention_dim: 768, + use_linear_projection: false, + } + } +} + +/// ControlNet model for adding spatial conditioning to diffusion models. +/// +/// ControlNet copies the weights of a pretrained UNet and adds zero convolution +/// layers to inject conditioning information. The output consists of residuals +/// that are added to the corresponding layers of the UNet. 
+#[derive(Module, Debug)] +pub struct ControlNet { + conv_in: Conv2d, + controlnet_mid_block: Conv2d, + controlnet_cond_embedding: ControlNetConditioningEmbedding, + time_embedding: TimestepEmbedding, + down_blocks: Vec>, + controlnet_down_blocks: Vec>, + mid_block: UNetMidBlock2DCrossAttn, + #[module(skip)] + time_proj_channels: usize, + #[module(skip)] + flip_sin_to_cos: bool, + #[module(skip)] + freq_shift: f64, +} + +impl ControlNetConfig { + /// Initialize the ControlNet model. + pub fn init(&self, in_channels: usize, device: &B::Device) -> ControlNet { + let n_blocks = self.blocks.len(); + let b_channels = self.blocks[0].out_channels; + let bl_channels = self.blocks.last().unwrap().out_channels; + let bl_attention_head_dim = self.blocks.last().unwrap().attention_head_dim; + let time_embed_dim = b_channels * 4; + + // Time embeddings + let time_embedding = TimestepEmbeddingConfig::new(b_channels, time_embed_dim).init(device); + + // Input convolution + let conv_in = Conv2dConfig::new([in_channels, b_channels], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + // ControlNet mid block (1x1 conv, zero initialized in practice) + let controlnet_mid_block = + Conv2dConfig::new([bl_channels, bl_channels], [1, 1]).init(device); + + // Conditioning embedding + let controlnet_cond_embedding = ControlNetConditioningEmbeddingConfig::new( + b_channels, + self.conditioning_embedding_out_channels.clone(), + ) + .init(device); + + // Down blocks + let down_blocks: Vec> = (0..n_blocks) + .map(|i| { + let block_config = &self.blocks[i]; + let out_channels = block_config.out_channels; + let attention_head_dim = block_config.attention_head_dim; + + let in_ch = if i > 0 { + self.blocks[i - 1].out_channels + } else { + b_channels + }; + + let db_config = DownBlock2DConfig::new(in_ch, out_channels) + .with_temb_channels(Some(time_embed_dim)) + .with_n_layers(self.layers_per_block) + .with_resnet_eps(self.norm_eps) + 
.with_resnet_groups(self.norm_num_groups) + .with_add_downsample(i < n_blocks - 1) + .with_downsample_padding(self.downsample_padding); + + if block_config.use_cross_attn { + let config = CrossAttnDownBlock2DConfig::new(in_ch, out_channels, db_config) + .with_temb_channels(Some(time_embed_dim)) + .with_attn_num_head_channels(attention_head_dim) + .with_cross_attention_dim(self.cross_attention_dim) + .with_use_linear_projection(self.use_linear_projection); + UNetDownBlock::CrossAttn(config.init(device)) + } else { + UNetDownBlock::Basic(db_config.init(device)) + } + }) + .collect(); + + // Mid block + let mid_config = UNetMidBlock2DCrossAttnConfig::new(bl_channels) + .with_temb_channels(Some(time_embed_dim)) + .with_resnet_eps(self.norm_eps) + .with_output_scale_factor(self.mid_block_scale_factor) + .with_cross_attn_dim(self.cross_attention_dim) + .with_attn_num_head_channels(bl_attention_head_dim) + .with_resnet_groups(Some(self.norm_num_groups)) + .with_use_linear_projection(self.use_linear_projection); + let mid_block = mid_config.init(device); + + // ControlNet down blocks (1x1 convs, zero initialized in practice) + let mut controlnet_down_blocks = + vec![Conv2dConfig::new([b_channels, b_channels], [1, 1]).init(device)]; + + for (i, block) in self.blocks.iter().enumerate() { + let out_channels = block.out_channels; + for _ in 0..self.layers_per_block { + controlnet_down_blocks + .push(Conv2dConfig::new([out_channels, out_channels], [1, 1]).init(device)); + } + if i + 1 != self.blocks.len() { + controlnet_down_blocks + .push(Conv2dConfig::new([out_channels, out_channels], [1, 1]).init(device)); + } + } + + ControlNet { + conv_in, + controlnet_mid_block, + controlnet_cond_embedding, + time_embedding, + down_blocks, + controlnet_down_blocks, + mid_block, + time_proj_channels: b_channels, + flip_sin_to_cos: self.flip_sin_to_cos, + freq_shift: self.freq_shift, + } + } +} + +impl ControlNet { + /// Forward pass through ControlNet. 
+ /// + /// # Arguments + /// * `xs` - Noisy input tensor [batch, channels, height, width] + /// * `timestep` - Current diffusion timestep + /// * `encoder_hidden_states` - Encoder hidden states for cross-attention [batch, seq_len, dim] + /// * `controlnet_cond` - Conditioning image [batch, 3, height, width] + /// * `conditioning_scale` - Scale factor for the conditioning (typically 1.0) + /// + /// # Returns + /// A tuple of (down_block_residuals, mid_block_residual) to be added to the UNet. + pub fn forward( + &self, + xs: Tensor, + timestep: f64, + encoder_hidden_states: Tensor, + controlnet_cond: Tensor, + conditioning_scale: f64, + ) -> (Vec>, Tensor) { + let [bsize, _channels, _height, _width] = xs.dims(); + let device = xs.device(); + + // 1. Time embedding + let timesteps: Tensor = Tensor::full([bsize], timestep as f32, &device); + let emb = get_timestep_embedding( + timesteps, + self.time_proj_channels, + self.flip_sin_to_cos, + self.freq_shift, + ); + let emb = self.time_embedding.forward(emb); + + // 2. Pre-process + let xs = self.conv_in.forward(xs); + let controlnet_cond = self.controlnet_cond_embedding.forward(controlnet_cond); + let xs = xs + controlnet_cond; + + // 3. Down blocks + let mut down_block_res_xs = vec![xs.clone()]; + let mut xs = xs; + for down_block in self.down_blocks.iter() { + let (new_xs, res_xs) = match down_block { + UNetDownBlock::Basic(b) => b.forward(xs, Some(emb.clone())), + UNetDownBlock::CrossAttn(b) => { + b.forward(xs, Some(emb.clone()), Some(encoder_hidden_states.clone())) + } + }; + down_block_res_xs.extend(res_xs); + xs = new_xs; + } + + // 4. Mid block + let xs = self + .mid_block + .forward(xs, Some(emb.clone()), Some(encoder_hidden_states.clone())); + + // 5. 
ControlNet blocks - apply 1x1 convs and scale + let controlnet_down_block_res_xs: Vec> = self + .controlnet_down_blocks + .iter() + .enumerate() + .map(|(i, block)| block.forward(down_block_res_xs[i].clone()) * conditioning_scale) + .collect(); + + let mid_block_res = self.controlnet_mid_block.forward(xs) * conditioning_scale; + + (controlnet_down_block_res_xs, mid_block_res) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + use alloc::string::{String, ToString}; + use burn::module::{Module, ModuleMapper, Param}; + use burn::tensor::{Shape, TensorData}; + + #[test] + fn test_controlnet_conditioning_embedding_shape() { + let device = Default::default(); + + let config = ControlNetConditioningEmbeddingConfig::new(320, vec![16, 32, 96, 256]); + let embedding = config.init::(&device); + + // Input: [batch=1, channels=3, height=64, width=64] + let xs: Tensor = Tensor::zeros([1, 3, 64, 64], &device); + let output = embedding.forward(xs); + + // Output should have 320 channels and be downsampled by 2^3 = 8 + // 64 / 8 = 8 + assert_eq!(output.shape(), Shape::from([1, 320, 8, 8])); + } + + #[test] + fn test_controlnet_output_shape() { + let device = Default::default(); + + // Create a small ControlNet for testing + // Note: conditioning_embedding_out_channels length determines downsampling factor + // With [16] (1 element), there's no downsampling in the conditioning embedding + let config = ControlNetConfig { + blocks: vec![ + BlockConfig::new(32) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(64) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + ], + conditioning_embedding_out_channels: vec![16], // Single element = no downsampling + layers_per_block: 1, + norm_num_groups: 32, + cross_attention_dim: 64, + ..Default::default() + }; + + let controlnet = config.init::(4, &device); + + // Input: [batch=1, channels=4, height=32, width=32] + let xs: Tensor = Tensor::zeros([1, 4, 32, 32], &device); + // 
Encoder hidden states: [batch=1, seq_len=8, dim=64] + let encoder_hidden_states: Tensor = Tensor::zeros([1, 8, 64], &device); + // Conditioning image: [batch=1, channels=3, height=32, width=32] + let controlnet_cond: Tensor = Tensor::zeros([1, 3, 32, 32], &device); + + let (down_residuals, mid_residual) = + controlnet.forward(xs, 1.0, encoder_hidden_states, controlnet_cond, 1.0); + + // Should have residuals for each down block output + // With 2 blocks and 1 layer each, plus initial conv_in: + // - 1 from conv_in (32 channels) + // - 1 from block 0 resnet (32 channels) + // - 1 from block 0 downsample (32 channels) + // - 1 from block 1 resnet (64 channels) + // Total: 4 residuals + assert_eq!(down_residuals.len(), 4); + + // Mid block residual should match mid block channels + assert_eq!(mid_residual.dims()[1], 64); + } + + /// A ModuleMapper that sets weights to one value and biases to another. + struct WeightBiasMapper<'a, B: Backend> { + weight_val: f32, + bias_val: f32, + device: &'a B::Device, + current_field: String, + } + + impl<'a, B: Backend> ModuleMapper for WeightBiasMapper<'a, B> { + fn enter_module(&mut self, name: &str, _container_type: &str) { + self.current_field = name.to_string(); + } + + fn map_float( + &mut self, + tensor: Param>, + ) -> Param> { + let shape = tensor.shape(); + let dims: [usize; D] = shape.dims(); + + let is_bias = self.current_field.contains("bias") || self.current_field == "beta"; + let val = if is_bias { + self.bias_val + } else { + self.weight_val + }; + + Param::initialized(tensor.id, Tensor::full(dims, val, self.device)) + } + } + + fn set_weights_and_biases( + controlnet: ControlNet, + weight_val: f32, + bias_val: f32, + device: &B::Device, + ) -> ControlNet { + let mut mapper = WeightBiasMapper:: { + weight_val, + bias_val, + device, + current_field: String::new(), + }; + controlnet.map(&mut mapper) + } + + /// Test ControlNet with fixed weights matches diffusers-rs + /// Note: Skipped for wgpu due to GPU 
floating-point precision differences + #[test] + #[cfg(not(feature = "wgpu"))] + fn test_controlnet_fixed_weights_matches_diffusers_rs() { + let device = Default::default(); + + // Create a small ControlNet matching the diffusers-rs test config + let config = ControlNetConfig { + blocks: vec![ + BlockConfig::new(32) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(64) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + ], + conditioning_embedding_out_channels: vec![16], + layers_per_block: 1, + norm_num_groups: 32, + cross_attention_dim: 64, + ..Default::default() + }; + + let controlnet = config.init::(4, &device); + + // Set weights to 0.1, biases to 0.0 + let controlnet = set_weights_and_biases(controlnet, 0.1, 0.0, &device); + + // Input tensor [batch=1, channels=4, height=32, width=32] + let input_data: Vec = (0..(4 * 32 * 32)) + .map(|i| i as f32 / (4.0 * 32.0 * 32.0)) + .collect(); + let xs: Tensor = + Tensor::from_data(TensorData::from(input_data.as_slice()), &device); + let xs: Tensor = xs.reshape([1, 4, 32, 32]); + + // Encoder hidden states [batch=1, seq_len=8, dim=64] + let enc_data: Vec = (0..(8 * 64)).map(|i| i as f32 / (8.0 * 64.0)).collect(); + let encoder_hidden_states: Tensor = + Tensor::from_data(TensorData::from(enc_data.as_slice()), &device); + let encoder_hidden_states: Tensor = + encoder_hidden_states.reshape([1, 8, 64]); + + // Conditioning image [batch=1, channels=3, height=32, width=32] + let cond_data: Vec = (0..(3 * 32 * 32)) + .map(|i| i as f32 / (3.0 * 32.0 * 32.0)) + .collect(); + let controlnet_cond: Tensor = + Tensor::from_data(TensorData::from(cond_data.as_slice()), &device); + let controlnet_cond: Tensor = controlnet_cond.reshape([1, 3, 32, 32]); + + let (down_residuals, mid_residual) = + controlnet.forward(xs, 1.0, encoder_hidden_states, controlnet_cond, 1.0); + + // Reference values from diffusers-rs + assert_eq!(down_residuals.len(), 4); + + // Down residual 0: shape=[1, 32, 32, 32], 
mean=51.52621078491211 + let mean0: f32 = down_residuals[0].clone().mean().into_scalar(); + assert!( + (mean0 - 51.526210).abs() < 0.1, + "Down residual 0 mean mismatch: actual={}, expected=51.526210", + mean0 + ); + + // Down residual 1: shape=[1, 32, 32, 32], mean=156.5426025390625 + let mean1: f32 = down_residuals[1].clone().mean().into_scalar(); + assert!( + (mean1 - 156.54260).abs() < 0.1, + "Down residual 1 mean mismatch: actual={}, expected=156.54260", + mean1 + ); + + // Down residual 2: shape=[1, 32, 16, 16], mean=4349.87353515625 + let mean2: f32 = down_residuals[2].clone().mean().into_scalar(); + assert!( + (mean2 - 4349.8735).abs() < 1.0, + "Down residual 2 mean mismatch: actual={}, expected=4349.8735", + mean2 + ); + + // Down residual 3: shape=[1, 64, 16, 16], mean=28677.99609375 + let mean3: f32 = down_residuals[3].clone().mean().into_scalar(); + assert!( + (mean3 - 28677.996).abs() < 10.0, + "Down residual 3 mean mismatch: actual={}, expected=28677.996", + mean3 + ); + + // Mid residual: shape=[1, 64, 16, 16], mean=29518.357421875 + let mid_mean: f32 = mid_residual.clone().mean().into_scalar(); + assert!( + (mid_mean - 29518.357).abs() < 10.0, + "Mid residual mean mismatch: actual={}, expected=29518.357", + mid_mean + ); + } +} diff --git a/src/models/embeddings.rs b/src/models/embeddings.rs index 83b663e..071770e 100644 --- a/src/models/embeddings.rs +++ b/src/models/embeddings.rs @@ -6,7 +6,6 @@ use burn::nn::{Linear, LinearConfig}; use burn::tensor::activation::silu; use burn::tensor::backend::Backend; use burn::tensor::Tensor; -use core::marker::PhantomData; #[cfg(not(feature = "std"))] #[allow(unused_imports)] @@ -35,49 +34,49 @@ impl TimestepEmbeddingConfig { } impl TimestepEmbedding { - fn forward(&self, xs: Tensor) -> Tensor { + pub fn forward(&self, xs: Tensor) -> Tensor { let xs = silu(self.linear_1.forward(xs)); self.linear_2.forward(xs) } } -#[derive(Module, Debug)] -pub struct Timesteps { +/// Computes sinusoidal timestep embeddings. 
+/// +/// This is a pure function with no learnable parameters. +/// It generates positional embeddings for diffusion timesteps using sinusoidal encoding. +/// +/// # Arguments +/// * `timesteps` - 1D tensor of timestep values +/// * `num_channels` - Number of embedding channels to produce +/// * `flip_sin_to_cos` - If true, output is [cos, sin]; if false, output is [sin, cos] +/// * `downscale_freq_shift` - Frequency shift applied to the exponent denominator +/// +/// # Returns +/// A 2D tensor of shape [batch_size, num_channels] containing the timestep embeddings. +pub fn get_timestep_embedding( + timesteps: Tensor, num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64, - _backend: PhantomData, -} - -impl Timesteps { - pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self { - Self { - num_channels, - flip_sin_to_cos, - - downscale_freq_shift, - _backend: PhantomData, - } - } - - pub fn forward(&self, xs: Tensor) -> Tensor { - let half_dim = self.num_channels / 2; - let exponent = Tensor::arange(0..half_dim as i64, &xs.device()).float() * -f64::ln(10000.); - let exponent = exponent / (half_dim as f64 - self.downscale_freq_shift); - let emb = exponent.exp(); - // emb = timesteps[:, None].float() * emb[None, :] - let emb: Tensor = xs.unsqueeze_dim(D1) * emb.unsqueeze(); - let emb: Tensor = if self.flip_sin_to_cos { - Tensor::cat(vec![emb.clone().cos(), emb.clone().sin()], D1) - } else { - Tensor::cat(vec![emb.clone().sin(), emb.clone().cos()], D1) - }; - - if self.num_channels % 2 == 1 { - pad_with_zeros(emb, D1 - 1, 0, 1) - } else { - emb - } +) -> Tensor { + let half_dim = num_channels / 2; + let exponent = + Tensor::arange(0..half_dim as i64, ×teps.device()).float() * -f64::ln(10000.); + let exponent = exponent / (half_dim as f64 - downscale_freq_shift); + let emb = exponent.exp(); + // emb = timesteps[:, None].float() * emb[None, :] + let emb: Tensor = timesteps.unsqueeze_dim(1) * emb.unsqueeze(); + // 
Concatenate along the last dimension (-1 in PyTorch, which is dim 1 for 2D tensor) + let emb: Tensor = if flip_sin_to_cos { + Tensor::cat(vec![emb.clone().cos(), emb.clone().sin()], 1) + } else { + Tensor::cat(vec![emb.clone().sin(), emb.clone().cos()], 1) + }; + + if num_channels % 2 == 1 { + pad_with_zeros(emb, 1, 0, 1) + } else { + emb } } @@ -87,49 +86,93 @@ mod tests { use crate::TestBackend; use burn::tensor::{Shape, TensorData, Tolerance}; + /// Test get_timestep_embedding with even channels - validated against diffusers-rs v0.3.1 #[test] #[cfg(not(feature = "torch"))] fn test_timesteps_even_channels() { let device = Default::default(); - let timesteps = Timesteps::::new(4, true, 0.); let xs: Tensor = Tensor::from_data(TensorData::from([1., 2., 3., 4.]), &device); - let emb: Tensor = timesteps.forward(xs); + let emb: Tensor = get_timestep_embedding(xs, 4, true, 0.); assert_eq!(emb.shape(), Shape::from([4, 4])); + // Reference values from diffusers-rs v0.3.1: emb.into_data().assert_approx_eq::( &TensorData::from([ - [0.5403, 1.0000, 0.8415, 0.0100], - [-0.4161, 0.9998, 0.9093, 0.0200], - [-0.9900, 0.9996, 0.1411, 0.0300], - [-0.6536, 0.9992, -0.7568, 0.0400], + [0.5403023, 0.99995, 0.84147096, 0.009999833], + [-0.41614684, 0.9998, 0.9092974, 0.019998666], + [-0.9899925, 0.99955004, 0.14112, 0.0299955], + [-0.6536436, 0.9992001, -0.7568025, 0.039989334], ]), - Tolerance::rel_abs(1e-3, 1e-3), + Tolerance::rel_abs(1e-4, 1e-4), ); } + /// Test get_timestep_embedding with odd channels (padding) - validated against diffusers-rs v0.3.1 #[test] #[cfg(not(feature = "torch"))] fn test_timesteps_odd_channels() { let device = Default::default(); - let timesteps = Timesteps::::new(5, true, 0.); + let xs: Tensor = Tensor::from_data(TensorData::from([1., 2., 3.]), &device); + + let emb: Tensor = get_timestep_embedding(xs, 5, true, 0.); + + assert_eq!(emb.shape(), Shape::from([3, 5])); + // Reference values from diffusers-rs v0.3.1: + emb.into_data().assert_approx_eq::( + 
&TensorData::from([ + [0.5403023, 0.99995, 0.84147096, 0.009999833, 0.0], + [-0.41614684, 0.9998, 0.9092974, 0.019998666, 0.0], + [-0.9899925, 0.99955004, 0.14112, 0.0299955, 0.0], + ]), + Tolerance::rel_abs(1e-4, 1e-4), + ); + } + + /// Test get_timestep_embedding with flip_sin_to_cos=false - validated against diffusers-rs v0.3.1 + #[test] + #[cfg(not(feature = "torch"))] + fn test_timesteps_no_flip() { + let device = Default::default(); let xs: Tensor = - Tensor::from_data(TensorData::from([1., 2., 3., 4., 5.]), &device); + Tensor::from_data(TensorData::from([1., 2., 3., 4.]), &device); + + let emb: Tensor = get_timestep_embedding(xs, 4, false, 0.); + + assert_eq!(emb.shape(), Shape::from([4, 4])); + // Reference values from diffusers-rs v0.3.1 with flip_sin_to_cos=false: + emb.into_data().assert_approx_eq::( + &TensorData::from([ + [0.84147096, 0.009999833, 0.5403023, 0.99995], + [0.9092974, 0.019998666, -0.41614684, 0.9998], + [0.14112, 0.0299955, -0.9899925, 0.99955004], + [-0.7568025, 0.039989334, -0.6536436, 0.9992001], + ]), + Tolerance::rel_abs(1e-4, 1e-4), + ); + } + + /// Test get_timestep_embedding with downscale_freq_shift - validated against diffusers-rs v0.3.1 + #[test] + #[cfg(not(feature = "torch"))] + fn test_timesteps_with_downscale() { + let device = Default::default(); + let xs: Tensor = + Tensor::from_data(TensorData::from([1., 2., 3., 4.]), &device); - let emb: Tensor = timesteps.forward(xs); + let emb: Tensor = get_timestep_embedding(xs, 4, true, 1.0); - assert_eq!(emb.shape(), Shape::from([6, 4])); + assert_eq!(emb.shape(), Shape::from([4, 4])); + // Reference values from diffusers-rs v0.3.1 with downscale_freq_shift=1.0: emb.into_data().assert_approx_eq::( &TensorData::from([ - [0.5403, 1.0000, 0.8415, 0.0100], - [-0.4161, 0.9998, 0.9093, 0.0200], - [-0.9900, 0.9996, 0.1411, 0.0300], - [-0.6536, 0.9992, -0.7568, 0.0400], - [0.2837, 0.9988, -0.9589, 0.0500], - [0.0000, 0.0000, 0.0000, 0.0000], + [0.5403023, 1.0, 0.84147096, 9.999999e-5], + 
[-0.41614684, 1.0, 0.9092974, 0.00019999998], + [-0.9899925, 0.99999994, 0.14112, 0.00029999996], + [-0.6536436, 0.99999994, -0.7568025, 0.00039999996], ]), - Tolerance::rel_abs(1e-3, 1e-3), + Tolerance::rel_abs(1e-4, 1e-4), ); } } diff --git a/src/models/mod.rs b/src/models/mod.rs index 5d7234d..743d60f 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -3,6 +3,9 @@ //! A collection of models to be used in a diffusion loop. pub mod attention; +pub mod controlnet; pub mod embeddings; pub mod resnet; +pub mod unet_2d; pub mod unet_2d_blocks; +pub mod vae; diff --git a/src/models/resnet.rs b/src/models/resnet.rs index 57c1e31..2c7e47c 100644 --- a/src/models/resnet.rs +++ b/src/models/resnet.rs @@ -120,7 +120,8 @@ impl ResnetBlock2D { mod tests { use super::*; use crate::TestBackend; - use burn::tensor::{Distribution, Shape}; + use alloc::vec::Vec; + use burn::tensor::{Distribution, Shape, TensorData, Tolerance}; #[test] fn test_resnet_block_2d_no_temb() { @@ -142,4 +143,327 @@ mod tests { assert_eq!(output.shape(), Shape::from([2, 128, 64, 64])); } + + /// Test SiLU activation matches diffusers-rs (tch silu) + /// Reference values from diffusers-rs v0.3.1 + #[test] + fn test_silu_matches_diffusers_rs() { + let device = Default::default(); + let xs: Tensor = Tensor::from_data( + TensorData::from([-2.0f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]), + &device, + ); + + let result = silu(xs); + + // Reference values from diffusers-rs: tensor.silu() + result.into_data().assert_approx_eq::( + &TensorData::from([ + -0.23840584, + -0.26894143, + -0.18877034, + 0.0, + 0.31122968, + 0.7310586, + 1.761594, + ]), + Tolerance::rel_abs(1e-4, 1e-4), + ); + } + + /// Test GroupNorm matches diffusers-rs + /// Reference values from diffusers-rs v0.3.1 with weight=1, bias=0 + #[test] + fn test_group_norm_matches_diffusers_rs() { + let device = Default::default(); + + // Create GroupNorm: 2 groups, 4 channels + let norm = GroupNormConfig::new(2, 4) + .with_epsilon(1e-6) + 
.init::(&device); + + // Set weight to 1 and bias to 0 (default initialization) + // GroupNorm in Burn initializes gamma=1, beta=0 by default + + // Input: [batch=1, channels=4, height=2, width=2] with sequential values 1-16 + let xs: Tensor = Tensor::from_data( + TensorData::from([ + 1.0f32, 2.0, 3.0, 4.0, // channel 0 + 5.0, 6.0, 7.0, 8.0, // channel 1 + 9.0, 10.0, 11.0, 12.0, // channel 2 + 13.0, 14.0, 15.0, 16.0, // channel 3 + ]), + &device, + ); + let xs: Tensor = xs.reshape([1, 4, 2, 2]); + + let result = norm.forward(xs); + + // Reference values from diffusers-rs GroupNorm + let result_flat = result.flatten::<1>(0, 3); + result_flat.into_data().assert_approx_eq::( + &TensorData::from([ + -1.5275251, + -1.0910892, + -0.65465355, + -0.21821785, + 0.21821797, + 0.65465367, + 1.0910894, + 1.5275251, + -1.5275252, + -1.0910892, + -0.65465355, + -0.21821785, + 0.21821785, + 0.65465355, + 1.0910892, + 1.527525, + ]), + Tolerance::rel_abs(1e-4, 1e-4), + ); + } + + /// Helper function to set all weights in a Conv2d to a constant value + fn set_conv2d_weights( + conv: Conv2d, + weight_val: f32, + bias_val: f32, + device: &B::Device, + ) -> Conv2d { + let weight_shape = conv.weight.shape(); + let [out_ch, in_ch, kh, kw] = weight_shape.dims(); + + // Use Param::map to transform the weight tensor + let new_weight = conv + .weight + .map(|_| Tensor::full([out_ch, in_ch, kh, kw], weight_val, device)); + + let new_bias = conv + .bias + .map(|b| b.map(|_| Tensor::full([out_ch], bias_val, device))); + + Conv2d { + weight: new_weight, + bias: new_bias, + stride: conv.stride, + kernel_size: conv.kernel_size, + dilation: conv.dilation, + groups: conv.groups, + padding: conv.padding, + } + } + + /// Helper function to set GroupNorm weights + fn set_group_norm_weights( + norm: GroupNorm, + gamma_val: f32, + beta_val: f32, + device: &B::Device, + ) -> GroupNorm { + let num_channels = norm.num_channels; + + let new_gamma = norm + .gamma + .map(|g| g.map(|_| 
Tensor::full([num_channels], gamma_val, device))); + + let new_beta = norm + .beta + .map(|b| b.map(|_| Tensor::full([num_channels], beta_val, device))); + + GroupNorm { + gamma: new_gamma, + beta: new_beta, + num_groups: norm.num_groups, + num_channels: norm.num_channels, + epsilon: norm.epsilon, + affine: norm.affine, + } + } + + /// Helper function to set Linear weights + fn set_linear_weights( + linear: Linear, + weight_val: f32, + bias_val: f32, + device: &B::Device, + ) -> Linear { + let weight_shape = linear.weight.shape(); + let [d_input, d_output] = weight_shape.dims(); + + let new_weight = linear + .weight + .map(|_| Tensor::full([d_input, d_output], weight_val, device)); + + let new_bias = linear + .bias + .map(|b| b.map(|_| Tensor::full([d_output], bias_val, device))); + + Linear { + weight: new_weight, + bias: new_bias, + } + } + + /// Test ResnetBlock2D with fixed weights matches diffusers-rs + /// Reference values from diffusers-rs v0.3.1 + #[test] + fn test_resnet_block_2d_fixed_weights_matches_diffusers_rs() { + let device = Default::default(); + + // Create ResnetBlock2D: in_channels=4, out_channels=4, groups=2 + let config = ResnetBlock2DConfig::new(4) + .with_out_channels(Some(4)) + .with_groups(2) + .with_groups_out(Some(2)) + .with_eps(1e-6) + .with_use_in_shortcut(Some(false)); + + let mut block = config.init::(&device); + + // Set all weights to 0.1 and biases to 0.0 to match diffusers-rs test + // NOTE: diffusers-rs sets ALL "weight" params to 0.1, including GroupNorm gamma + block.norm1 = set_group_norm_weights(block.norm1, 0.1, 0.0, &device); + block.norm2 = set_group_norm_weights(block.norm2, 0.1, 0.0, &device); + block.conv1 = set_conv2d_weights(block.conv1, 0.1, 0.0, &device); + block.conv2 = set_conv2d_weights(block.conv2, 0.1, 0.0, &device); + + // Input: arange(64).reshape([1, 4, 4, 4]) / 64.0 + let input_data: Vec = (0..64).map(|i| i as f32 / 64.0).collect(); + let xs: Tensor = + 
Tensor::from_data(TensorData::from(input_data.as_slice()), &device); + let xs: Tensor = xs.reshape([1, 4, 4, 4]); + + let result = block.forward(xs, None); + + // Reference values from diffusers-rs + let result_flat = result.clone().flatten::<1>(0, 3); + let result_data = result_flat.to_data(); + let result_vec: Vec = result_data.to_vec().unwrap(); + + // Check first 16 values match diffusers-rs reference + let expected_first_16 = [ + -0.087926514, + -0.106717095, + -0.07181215, + -0.0077848956, + -0.0029995441, + 0.005315691, + 0.054237127, + 0.10205648, + 0.1431311, + 0.20416035, + 0.25529763, + 0.25261912, + 0.2456759, + 0.31968236, + 0.35902193, + 0.33475745, + ]; + + for (i, (actual, expected)) in result_vec + .iter() + .take(16) + .zip(expected_first_16.iter()) + .enumerate() + { + assert!( + (actual - expected).abs() < 1e-4, + "Mismatch at index {}: actual={}, expected={}", + i, + actual, + expected + ); + } + + // Check overall mean matches + let mean = result.mean().into_scalar(); + assert!( + (mean - 0.4999196529388428).abs() < 1e-4, + "Mean mismatch: actual={}, expected=0.4999196529388428", + mean + ); + } + + /// Test ResnetBlock2D with time embedding and fixed weights matches diffusers-rs + /// Reference values from diffusers-rs v0.3.1 + #[test] + fn test_resnet_block_2d_with_temb_fixed_weights_matches_diffusers_rs() { + let device = Default::default(); + + // Create ResnetBlock2D with time embedding: in_channels=4, out_channels=4, temb_channels=8 + let config = ResnetBlock2DConfig::new(4) + .with_out_channels(Some(4)) + .with_temb_channels(Some(8)) + .with_groups(2) + .with_groups_out(Some(2)) + .with_eps(1e-6) + .with_use_in_shortcut(Some(false)); + + let mut block = config.init::(&device); + + // Set all weights to 0.1 and biases to 0.0 to match diffusers-rs test + // NOTE: diffusers-rs sets ALL "weight" params to 0.1, including GroupNorm gamma + block.norm1 = set_group_norm_weights(block.norm1, 0.1, 0.0, &device); + block.norm2 = 
set_group_norm_weights(block.norm2, 0.1, 0.0, &device); + block.conv1 = set_conv2d_weights(block.conv1, 0.1, 0.0, &device); + block.conv2 = set_conv2d_weights(block.conv2, 0.1, 0.0, &device); + block.time_emb_proj = block + .time_emb_proj + .map(|proj| set_linear_weights(proj, 0.1, 0.0, &device)); + + // Input: arange(64).reshape([1, 4, 4, 4]) / 64.0 + let input_data: Vec = (0..64).map(|i| i as f32 / 64.0).collect(); + let xs: Tensor = + Tensor::from_data(TensorData::from(input_data.as_slice()), &device); + let xs: Tensor = xs.reshape([1, 4, 4, 4]); + + // Time embedding: arange(8).reshape([1, 8]) / 8.0 + let temb_data: Vec = (0..8).map(|i| i as f32 / 8.0).collect(); + let temb: Tensor = + Tensor::from_data(TensorData::from(temb_data.as_slice()), &device); + let temb: Tensor = temb.reshape([1, 8]); + + let result = block.forward(xs, Some(temb)); + + // Reference values from diffusers-rs (same as without temb due to the specific input values) + let result_flat = result.flatten::<1>(0, 3); + let result_data = result_flat.to_data(); + let result_vec: Vec = result_data.to_vec().unwrap(); + + // Check first 16 values match diffusers-rs reference + let expected_first_16 = [ + -0.08792652, + -0.10671712, + -0.07181221, + -0.007784918, + -0.0029995516, + 0.005315639, + 0.0542371, + 0.10205645, + 0.14313108, + 0.2041603, + 0.2552976, + 0.2526191, + 0.24567588, + 0.3196823, + 0.3590219, + 0.33475742, + ]; + + for (i, (actual, expected)) in result_vec + .iter() + .take(16) + .zip(expected_first_16.iter()) + .enumerate() + { + assert!( + (actual - expected).abs() < 1e-4, + "Mismatch at index {}: actual={}, expected={}", + i, + actual, + expected + ); + } + } } diff --git a/src/models/unet_2d.rs b/src/models/unet_2d.rs new file mode 100644 index 0000000..7689c75 --- /dev/null +++ b/src/models/unet_2d.rs @@ -0,0 +1,634 @@ +//! 2D UNet Denoising Models +//! +//! The 2D UNet models take as input a noisy sample and the current diffusion +//! 
timestep and return a denoised version of the input. + +use burn::config::Config; +use burn::module::Module; +use burn::nn::conv::{Conv2d, Conv2dConfig}; +use burn::nn::{GroupNorm, GroupNormConfig, PaddingConfig2d}; +use burn::tensor::activation::silu; +use burn::tensor::backend::Backend; +use burn::tensor::Tensor; + +use alloc::vec; +use alloc::vec::Vec; + +use super::embeddings::{get_timestep_embedding, TimestepEmbedding, TimestepEmbeddingConfig}; +use super::unet_2d_blocks::{ + CrossAttnDownBlock2D, CrossAttnDownBlock2DConfig, CrossAttnUpBlock2D, CrossAttnUpBlock2DConfig, + DownBlock2D, DownBlock2DConfig, UNetMidBlock2DCrossAttn, UNetMidBlock2DCrossAttnConfig, + UpBlock2D, UpBlock2DConfig, +}; + +/// Configuration for a single UNet block. +#[derive(Debug, Clone, burn::serde::Serialize, burn::serde::Deserialize)] +pub struct BlockConfig { + /// Output channels for this block. + pub out_channels: usize, + /// Whether to use cross-attention in this block. + pub use_cross_attn: bool, + /// Number of attention heads. + pub attention_head_dim: usize, +} + +impl BlockConfig { + /// Create a new block configuration. + pub fn new(out_channels: usize) -> Self { + Self { + out_channels, + use_cross_attn: true, + attention_head_dim: 8, + } + } + + /// Set whether to use cross-attention. + pub fn with_use_cross_attn(mut self, use_cross_attn: bool) -> Self { + self.use_cross_attn = use_cross_attn; + self + } + + /// Set the attention head dimension. + pub fn with_attention_head_dim(mut self, attention_head_dim: usize) -> Self { + self.attention_head_dim = attention_head_dim; + self + } +} + +/// Configuration for the UNet2DConditionModel. +#[derive(Config, Debug)] +pub struct UNet2DConditionModelConfig { + /// Whether to center the input sample. + #[config(default = false)] + pub center_input_sample: bool, + /// Whether to flip sin to cos in timestep embedding. + #[config(default = true)] + pub flip_sin_to_cos: bool, + /// Frequency shift for timestep embedding. 
+ #[config(default = 0.0)] + pub freq_shift: f64, + /// Configuration for each block. + pub blocks: Vec, + /// Number of ResNet layers per block. + #[config(default = 2)] + pub layers_per_block: usize, + /// Padding for downsampling convolutions. + #[config(default = 1)] + pub downsample_padding: usize, + /// Scale factor for mid block. + #[config(default = 1.0)] + pub mid_block_scale_factor: f64, + /// Number of groups for group normalization. + #[config(default = 32)] + pub norm_num_groups: usize, + /// Epsilon for normalization layers. + #[config(default = 1e-5)] + pub norm_eps: f64, + /// Dimension of cross-attention context. + #[config(default = 1280)] + pub cross_attention_dim: usize, + /// Size for sliced attention (None for full attention). + pub sliced_attention_size: Option, + /// Whether to use linear projection in attention. + #[config(default = false)] + pub use_linear_projection: bool, +} + +impl Default for UNet2DConditionModelConfig { + fn default() -> Self { + Self { + center_input_sample: false, + flip_sin_to_cos: true, + freq_shift: 0.0, + blocks: vec![ + BlockConfig::new(320) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(640) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(1280) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(1280) + .with_use_cross_attn(false) + .with_attention_head_dim(8), + ], + layers_per_block: 2, + downsample_padding: 1, + mid_block_scale_factor: 1.0, + norm_num_groups: 32, + norm_eps: 1e-5, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +/// Down block types for UNet. +#[derive(Module, Debug)] +pub enum UNetDownBlock { + Basic(DownBlock2D), + CrossAttn(CrossAttnDownBlock2D), +} + +/// Up block types for UNet. +#[derive(Module, Debug)] +pub enum UNetUpBlock { + Basic(UpBlock2D), + CrossAttn(CrossAttnUpBlock2D), +} + +/// UNet2D Conditional Model for denoising diffusion. 
+/// +/// This model takes a noisy sample, timestep, and encoder hidden states +/// (from text conditioning) and predicts the noise to be removed. +#[derive(Module, Debug)] +pub struct UNet2DConditionModel { + conv_in: Conv2d, + time_embedding: TimestepEmbedding, + down_blocks: Vec>, + mid_block: UNetMidBlock2DCrossAttn, + up_blocks: Vec>, + conv_norm_out: GroupNorm, + conv_out: Conv2d, + #[module(skip)] + time_proj_channels: usize, + #[module(skip)] + flip_sin_to_cos: bool, + #[module(skip)] + freq_shift: f64, + #[module(skip)] + center_input_sample: bool, +} + +impl UNet2DConditionModelConfig { + /// Initialize the UNet2DConditionModel. + pub fn init( + &self, + in_channels: usize, + out_channels: usize, + device: &B::Device, + ) -> UNet2DConditionModel { + let n_blocks = self.blocks.len(); + let b_channels = self.blocks[0].out_channels; + let bl_channels = self.blocks.last().unwrap().out_channels; + let bl_attention_head_dim = self.blocks.last().unwrap().attention_head_dim; + let time_embed_dim = b_channels * 4; + + // Input convolution + let conv_in = Conv2dConfig::new([in_channels, b_channels], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + // Time embeddings + let time_embedding = TimestepEmbeddingConfig::new(b_channels, time_embed_dim).init(device); + + // Down blocks + let down_blocks = (0..n_blocks) + .map(|i| { + let block_config = &self.blocks[i]; + let out_channels = block_config.out_channels; + let attention_head_dim = block_config.attention_head_dim; + + // Enable automatic attention slicing if config sliced_attention_size is 0 + let sliced_attention_size = match self.sliced_attention_size { + Some(0) => Some(attention_head_dim / 2), + other => other, + }; + + let in_ch = if i > 0 { + self.blocks[i - 1].out_channels + } else { + b_channels + }; + + let db_config = DownBlock2DConfig::new(in_ch, out_channels) + .with_temb_channels(Some(time_embed_dim)) + .with_n_layers(self.layers_per_block) + 
.with_resnet_eps(self.norm_eps) + .with_resnet_groups(self.norm_num_groups) + .with_add_downsample(i < n_blocks - 1) + .with_downsample_padding(self.downsample_padding); + + if block_config.use_cross_attn { + let config = CrossAttnDownBlock2DConfig::new(in_ch, out_channels, db_config) + .with_temb_channels(Some(time_embed_dim)) + .with_attn_num_head_channels(attention_head_dim) + .with_cross_attention_dim(self.cross_attention_dim) + .with_sliced_attention_size(sliced_attention_size) + .with_use_linear_projection(self.use_linear_projection); + UNetDownBlock::CrossAttn(config.init(device)) + } else { + UNetDownBlock::Basic(db_config.init(device)) + } + }) + .collect(); + + // Mid block + let mid_config = UNetMidBlock2DCrossAttnConfig::new(bl_channels) + .with_temb_channels(Some(time_embed_dim)) + .with_resnet_eps(self.norm_eps) + .with_output_scale_factor(self.mid_block_scale_factor) + .with_cross_attn_dim(self.cross_attention_dim) + .with_attn_num_head_channels(bl_attention_head_dim) + .with_resnet_groups(Some(self.norm_num_groups)) + .with_use_linear_projection(self.use_linear_projection); + let mid_block = mid_config.init(device); + + // Up blocks + let up_blocks = (0..n_blocks) + .map(|i| { + let block_config = &self.blocks[n_blocks - 1 - i]; + let out_channels = block_config.out_channels; + let attention_head_dim = block_config.attention_head_dim; + + // Enable automatic attention slicing if config sliced_attention_size is 0 + let sliced_attention_size = match self.sliced_attention_size { + Some(0) => Some(attention_head_dim / 2), + other => other, + }; + + let prev_out_channels = if i > 0 { + self.blocks[n_blocks - i].out_channels + } else { + bl_channels + }; + + let in_ch = { + let index = if i == n_blocks - 1 { + 0 + } else { + n_blocks - i - 2 + }; + self.blocks[index].out_channels + }; + + let ub_config = UpBlock2DConfig::new(in_ch, prev_out_channels, out_channels) + .with_temb_channels(Some(time_embed_dim)) + .with_n_layers(self.layers_per_block + 1) + 
.with_resnet_eps(self.norm_eps) + .with_resnet_groups(self.norm_num_groups) + .with_add_upsample(i < n_blocks - 1); + + if block_config.use_cross_attn { + let config = CrossAttnUpBlock2DConfig::new( + in_ch, + prev_out_channels, + out_channels, + ub_config, + ) + .with_temb_channels(Some(time_embed_dim)) + .with_attn_num_head_channels(attention_head_dim) + .with_cross_attention_dim(self.cross_attention_dim) + .with_sliced_attention_size(sliced_attention_size) + .with_use_linear_projection(self.use_linear_projection); + UNetUpBlock::CrossAttn(config.init(device)) + } else { + UNetUpBlock::Basic(ub_config.init(device)) + } + }) + .collect(); + + // Output layers + let conv_norm_out = GroupNormConfig::new(self.norm_num_groups, b_channels) + .with_epsilon(self.norm_eps) + .init(device); + + let conv_out = Conv2dConfig::new([b_channels, out_channels], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + UNet2DConditionModel { + conv_in, + time_embedding, + down_blocks, + mid_block, + time_proj_channels: b_channels, + flip_sin_to_cos: self.flip_sin_to_cos, + freq_shift: self.freq_shift, + up_blocks, + conv_norm_out, + conv_out, + center_input_sample: self.center_input_sample, + } + } +} + +impl UNet2DConditionModel { + /// Forward pass through the UNet. + /// + /// # Arguments + /// * `xs` - Noisy input tensor [batch, channels, height, width] + /// * `timestep` - Current diffusion timestep + /// * `encoder_hidden_states` - Encoder hidden states for cross-attention [batch, seq_len, dim] + pub fn forward( + &self, + xs: Tensor, + timestep: f64, + encoder_hidden_states: Tensor, + ) -> Tensor { + self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None) + } + + /// Forward pass with additional residuals (for ControlNet support). 
+ pub fn forward_with_additional_residuals( + &self, + xs: Tensor, + timestep: f64, + encoder_hidden_states: Tensor, + down_block_additional_residuals: Option<&[Tensor]>, + mid_block_additional_residual: Option<&Tensor>, + ) -> Tensor { + let [bsize, _channels, height, width] = xs.dims(); + let device = xs.device(); + let n_blocks = self.down_blocks.len(); + let num_upsamplers = n_blocks - 1; + let default_overall_up_factor = 2usize.pow(num_upsamplers as u32); + let forward_upsample_size = + height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0; + + // 0. Center input if necessary + let xs = if self.center_input_sample { + xs * 2.0 - 1.0 + } else { + xs + }; + + // 1. Time embedding + let timesteps: Tensor = Tensor::full([bsize], timestep as f32, &device); + let emb = get_timestep_embedding( + timesteps, + self.time_proj_channels, + self.flip_sin_to_cos, + self.freq_shift, + ); + let emb = self.time_embedding.forward(emb); + + // 2. Pre-process + let xs = self.conv_in.forward(xs); + + // 3. Down blocks + let mut down_block_res_xs = vec![xs.clone()]; + let mut xs = xs; + for down_block in self.down_blocks.iter() { + let (new_xs, res_xs) = match down_block { + UNetDownBlock::Basic(b) => b.forward(xs, Some(emb.clone())), + UNetDownBlock::CrossAttn(b) => { + b.forward(xs, Some(emb.clone()), Some(encoder_hidden_states.clone())) + } + }; + down_block_res_xs.extend(res_xs); + xs = new_xs; + } + + // Add additional residuals if provided (for ControlNet) + let mut down_block_res_xs = if let Some(additional) = down_block_additional_residuals { + down_block_res_xs + .iter() + .zip(additional.iter()) + .map(|(r, a)| r.clone() + a.clone()) + .collect() + } else { + down_block_res_xs + }; + + // 4. Mid block + let xs = self + .mid_block + .forward(xs, Some(emb.clone()), Some(encoder_hidden_states.clone())); + let xs = match mid_block_additional_residual { + Some(m) => xs + m.clone(), + None => xs, + }; + + // 5. 
Up blocks + let mut xs = xs; + let mut upsample_size = None; + for (i, up_block) in self.up_blocks.iter().enumerate() { + let n_resnets = match up_block { + UNetUpBlock::Basic(b) => b.resnets.len(), + UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(), + }; + let res_xs: Vec<_> = down_block_res_xs + .drain(down_block_res_xs.len() - n_resnets..) + .collect(); + + if i < n_blocks - 1 && forward_upsample_size { + let last = down_block_res_xs.last().unwrap(); + let [_, _, h, w] = last.dims(); + upsample_size = Some((h, w)); + } + + xs = match up_block { + UNetUpBlock::Basic(b) => b.forward(xs, &res_xs, Some(emb.clone()), upsample_size), + UNetUpBlock::CrossAttn(b) => b.forward( + xs, + &res_xs, + Some(emb.clone()), + upsample_size, + Some(encoder_hidden_states.clone()), + ), + }; + } + + // 6. Post-process + let xs = self.conv_norm_out.forward(xs); + let xs = silu(xs); + self.conv_out.forward(xs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + use alloc::string::{String, ToString}; + use burn::module::{Module, ModuleMapper, Param}; + use burn::tensor::{Shape, TensorData}; + + /// A ModuleMapper that sets weights to one value and biases to another. + /// Uses the field name from enter_module to distinguish weight vs bias. 
+ struct WeightBiasMapper<'a, B: Backend> { + weight_val: f32, + bias_val: f32, + device: &'a B::Device, + current_field: String, + } + + impl<'a, B: Backend> ModuleMapper for WeightBiasMapper<'a, B> { + fn enter_module(&mut self, name: &str, _container_type: &str) { + self.current_field = name.to_string(); + } + + fn map_float( + &mut self, + tensor: Param>, + ) -> Param> { + let shape = tensor.shape(); + let dims: [usize; D] = shape.dims(); + + // Use field name to distinguish: "bias" and "beta" get bias_val, + // "weight" and "gamma" get weight_val + let is_bias = self.current_field.contains("bias") || self.current_field == "beta"; + let val = if is_bias { + self.bias_val + } else { + self.weight_val + }; + + Param::initialized(tensor.id, Tensor::full(dims, val, self.device)) + } + } + + /// Set weights and biases to different constant values. + fn set_weights_and_biases( + unet: UNet2DConditionModel, + weight_val: f32, + bias_val: f32, + device: &B::Device, + ) -> UNet2DConditionModel { + let mut mapper = WeightBiasMapper:: { + weight_val, + bias_val, + device, + current_field: String::new(), + }; + unet.map(&mut mapper) + } + + #[test] + fn test_unet2d_output_shape() { + let device = Default::default(); + + // Create a small UNet for testing + let config = UNet2DConditionModelConfig { + blocks: vec![ + BlockConfig::new(32) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(64) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + ], + layers_per_block: 1, + norm_num_groups: 32, + cross_attention_dim: 64, + ..Default::default() + }; + + let unet = config.init::(4, 4, &device); + + // Input: [batch=1, channels=4, height=32, width=32] + let xs: Tensor = Tensor::zeros([1, 4, 32, 32], &device); + // Encoder hidden states: [batch=1, seq_len=8, dim=64] + let encoder_hidden_states: Tensor = Tensor::zeros([1, 8, 64], &device); + + let output = unet.forward(xs, 1.0, encoder_hidden_states); + + // Output should have same spatial 
dimensions as input + assert_eq!(output.shape(), Shape::from([1, 4, 32, 32])); + } + + /// Test UNet2D forward with fixed weights matches diffusers-rs + /// Reference values from diffusers-rs v0.3.1 + /// Note: Skipped for wgpu due to GPU floating-point precision differences + #[test] + #[cfg(not(feature = "wgpu"))] + fn test_unet2d_fixed_weights_matches_diffusers_rs() { + let device = Default::default(); + + // Create a small UNet with same config as diffusers-rs test + let config = UNet2DConditionModelConfig { + blocks: vec![ + BlockConfig::new(32) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(64) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + ], + layers_per_block: 1, + norm_num_groups: 32, + cross_attention_dim: 64, + ..Default::default() + }; + + let unet = config.init::(4, 4, &device); + + // Set weights to 0.1, biases to 0.0 (matching diffusers-rs test) + let unet = set_weights_and_biases(unet, 0.1, 0.0, &device); + + // Input: normalized arange values for reproducibility + let input_data: Vec = (0..(4 * 32 * 32)) + .map(|i| i as f32 / (4.0 * 32.0 * 32.0)) + .collect(); + let xs: Tensor = + Tensor::from_data(TensorData::from(input_data.as_slice()), &device); + let xs: Tensor = xs.reshape([1, 4, 32, 32]); + + // Encoder hidden states: normalized values + let enc_data: Vec = (0..(8 * 64)).map(|i| i as f32 / (8.0 * 64.0)).collect(); + let encoder_hidden_states: Tensor = + Tensor::from_data(TensorData::from(enc_data.as_slice()), &device); + let encoder_hidden_states: Tensor = + encoder_hidden_states.reshape([1, 8, 64]); + + let output = unet.forward(xs, 1.0, encoder_hidden_states); + + // Verify output shape + assert_eq!(output.shape(), Shape::from([1, 4, 32, 32])); + + // Get result values + let result_flat = output.clone().flatten::<1>(0, 3); + let result_data = result_flat.to_data(); + let result_vec: Vec = result_data.to_vec().unwrap(); + + // Reference values from diffusers-rs v0.3.1 + let expected_first_16 = 
[ + -1.7083509_f32, + -2.3742673, + -1.9932699, + -1.836023, + -1.7530675, + -1.7501539, + -1.7482271, + -1.7476634, + -1.7470942, + -1.7467214, + -1.7462531, + -1.7458805, + -1.7454054, + -1.7450254, + -1.7445545, + -1.7441819, + ]; + + for (i, (actual, expected)) in result_vec + .iter() + .take(16) + .zip(expected_first_16.iter()) + .enumerate() + { + assert!( + (actual - expected).abs() < 1e-3, + "Mismatch at index {}: actual={}, expected={}", + i, + actual, + expected + ); + } + + // Check overall mean (reference: 0.1854093372821808) + let mean = output.clone().mean().into_scalar(); + let expected_mean = 0.18540934_f32; + assert!( + (mean - expected_mean).abs() < 1e-4, + "Mean mismatch: actual={}, expected={}", + mean, + expected_mean + ); + } +} diff --git a/src/models/unet_2d_blocks.rs b/src/models/unet_2d_blocks.rs index 9526e7b..22841c6 100644 --- a/src/models/unet_2d_blocks.rs +++ b/src/models/unet_2d_blocks.rs @@ -66,7 +66,7 @@ impl Downsample2D { return pad_with_zeros(xs, 4 - 2, 0, 1); } - return xs; + xs } fn forward(&self, xs: Tensor) -> Tensor { @@ -187,7 +187,7 @@ impl DownEncoderBlock2DConfig { } impl DownEncoderBlock2D { - fn forward(&self, xs: Tensor) -> Tensor { + pub fn forward(&self, xs: Tensor) -> Tensor { let mut xs = xs.clone(); for resnet in self.resnets.iter() { xs = resnet.forward(xs, None) @@ -259,7 +259,7 @@ impl UpDecoderBlock2DConfig { } impl UpDecoderBlock2D { - fn forward(&self, xs: Tensor) -> Tensor { + pub fn forward(&self, xs: Tensor) -> Tensor { let mut xs = xs.clone(); for resnet in self.resnets.iter() { xs = resnet.forward(xs, None) @@ -477,15 +477,21 @@ pub struct DownBlock2DConfig { #[derive(Module, Debug)] pub struct DownBlock2D { - resnets: Vec>, + pub resnets: Vec>, downsampler: Option>, } impl DownBlock2DConfig { pub fn init(&self, device: &B::Device) -> DownBlock2D { let resnets = (0..self.n_layers) - .map(|_| { - ResnetBlock2DConfig::new(self.out_channels) + .map(|i| { + let in_channels = if i == 0 { + self.in_channels 
+ } else { + self.out_channels + }; + ResnetBlock2DConfig::new(in_channels) + .with_out_channels(Some(self.out_channels)) .with_eps(self.resnet_eps) .with_groups(self.resnet_groups) .with_output_scale_factor(self.output_scale_factor) @@ -636,7 +642,7 @@ pub struct UpBlock2DConfig { #[derive(Module, Debug)] pub struct UpBlock2D { - resnets: Vec>, + pub resnets: Vec>, upsampler: Option>, } @@ -922,4 +928,40 @@ mod tests { assert_eq!(output.shape(), Shape::new([4, 32, 64, 64])); } + + /// Test Downsample2D avg_pool matches diffusers-rs + /// Reference values from diffusers-rs v0.3.1: avg_pool2d([2,2], [2,2], [0,0], false, true, None) + #[test] + fn test_downsample_2d_avg_pool_matches_diffusers_rs() { + let device = Default::default(); + let tensor: Tensor = Tensor::from_data( + TensorData::from([ + [ + [[0.0351f32, 0.4179], [0.0137, 0.6947]], + [[0.9526, 0.5386], [0.2856, 0.1839]], + [[0.3215, 0.4595], [0.6777, 0.3946]], + [[0.5221, 0.4230], [0.2774, 0.1069]], + ], + [ + [[0.8941, 0.8696], [0.5735, 0.8750]], + [[0.6718, 0.4144], [0.1038, 0.2629]], + [[0.7467, 0.9415], [0.5005, 0.6309]], + [[0.6534, 0.2019], [0.3670, 0.8074]], + ], + ]), + &device, + ); + + let downsample_2d = Downsample2DConfig::new(4, false, 4, 0).init(&device); + let output = downsample_2d.forward(tensor); + + // Reference values from diffusers-rs: [0.29035002, 0.49017498, 0.463325, 0.33235, 0.80305, 0.363225, 0.7049, 0.507425] + output.into_data().assert_approx_eq::( + &TensorData::from([ + [[[0.29035002f32]], [[0.49017498]], [[0.463325]], [[0.33235]]], + [[[0.80305]], [[0.363225]], [[0.7049]], [[0.507425]]], + ]), + Tolerance::rel_abs(1e-4, 1e-4), + ); + } } diff --git a/src/models/vae.rs b/src/models/vae.rs new file mode 100644 index 0000000..60e24c7 --- /dev/null +++ b/src/models/vae.rs @@ -0,0 +1,605 @@ +//! # Variational Auto-Encoder (VAE) Models. +//! +//! Auto-encoder models compress their input to a usually smaller latent space +//! before expanding it back to its original shape. 
This results in the latent values +//! compressing the original information. + +use burn::config::Config; +use burn::module::Module; +use burn::nn::conv::{Conv2d, Conv2dConfig}; +use burn::nn::{GroupNorm, GroupNormConfig, PaddingConfig2d}; +use burn::tensor::activation::silu; +use burn::tensor::backend::Backend; +use burn::tensor::{Distribution, Tensor}; + +use alloc::vec; +use alloc::vec::Vec; + +use super::unet_2d_blocks::{ + DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig, + UpDecoderBlock2D, UpDecoderBlock2DConfig, +}; + +/// Configuration for the VAE Encoder. +#[derive(Config, Debug)] +pub struct EncoderConfig { + /// Number of input channels (e.g., 3 for RGB images). + in_channels: usize, + /// Number of output channels (latent channels). + out_channels: usize, + /// Output channels for each block. + block_out_channels: Vec, + /// Number of resnet layers per block. + #[config(default = 2)] + layers_per_block: usize, + /// Number of groups for group normalization. + #[config(default = 32)] + norm_num_groups: usize, + /// Whether to output double channels for mean and logvar. + #[config(default = true)] + double_z: bool, +} + +/// VAE Encoder - compresses images to latent space. +#[derive(Module, Debug)] +pub struct Encoder { + conv_in: Conv2d, + down_blocks: Vec>, + mid_block: UNetMidBlock2D, + conv_norm_out: GroupNorm, + conv_out: Conv2d, +} + +impl EncoderConfig { + /// Initialize the Encoder. 
+ pub fn init(&self, device: &B::Device) -> Encoder { + let conv_in = Conv2dConfig::new([self.in_channels, self.block_out_channels[0]], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + let mut down_blocks = vec![]; + for index in 0..self.block_out_channels.len() { + let out_channels = self.block_out_channels[index]; + let in_channels = if index > 0 { + self.block_out_channels[index - 1] + } else { + self.block_out_channels[0] + }; + let is_final = index + 1 == self.block_out_channels.len(); + + let down_block = DownEncoderBlock2DConfig::new(in_channels, out_channels) + .with_n_layers(self.layers_per_block) + .with_resnet_eps(1e-6) + .with_resnet_groups(self.norm_num_groups) + .with_add_downsample(!is_final) + .with_downsample_padding(0) + .init(device); + + down_blocks.push(down_block); + } + + let last_block_out_channels = *self.block_out_channels.last().unwrap(); + + let mid_block = UNetMidBlock2DConfig::new(last_block_out_channels) + .with_resnet_eps(1e-6) + .with_output_scale_factor(1.0) + .with_attn_num_head_channels(None) + .with_resnet_groups(Some(self.norm_num_groups)) + .with_n_layers(1) + .init(device); + + let conv_norm_out = GroupNormConfig::new(self.norm_num_groups, last_block_out_channels) + .with_epsilon(1e-6) + .init(device); + + let conv_out_channels = if self.double_z { + 2 * self.out_channels + } else { + self.out_channels + }; + + let conv_out = Conv2dConfig::new([last_block_out_channels, conv_out_channels], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + Encoder { + conv_in, + down_blocks, + mid_block, + conv_norm_out, + conv_out, + } + } +} + +impl Encoder { + /// Forward pass through the encoder. 
+    pub fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+        let mut xs = self.conv_in.forward(xs);
+
+        for down_block in self.down_blocks.iter() {
+            xs = down_block.forward(xs);
+        }
+
+        let xs = self.mid_block.forward(xs, None);
+        let xs = self.conv_norm_out.forward(xs);
+        let xs = silu(xs);
+        self.conv_out.forward(xs)
+    }
+}
+
+/// Configuration for the VAE Decoder.
+#[derive(Config, Debug)]
+pub struct DecoderConfig {
+    /// Number of input channels (latent channels).
+    in_channels: usize,
+    /// Number of output channels (e.g., 3 for RGB images).
+    out_channels: usize,
+    /// Output channels for each block.
+    block_out_channels: Vec<usize>,
+    /// Number of resnet layers per block.
+    #[config(default = 2)]
+    layers_per_block: usize,
+    /// Number of groups for group normalization.
+    #[config(default = 32)]
+    norm_num_groups: usize,
+}
+
+/// VAE Decoder - expands latent space back to images.
+#[derive(Module, Debug)]
+pub struct Decoder<B: Backend> {
+    conv_in: Conv2d<B>,
+    up_blocks: Vec<UpDecoderBlock2D<B>>,
+    mid_block: UNetMidBlock2D<B>,
+    conv_norm_out: GroupNorm<B>,
+    conv_out: Conv2d<B>,
+}
+
+impl DecoderConfig {
+    /// Initialize the Decoder.
+ pub fn init(&self, device: &B::Device) -> Decoder { + let n_block_out_channels = self.block_out_channels.len(); + let last_block_out_channels = *self.block_out_channels.last().unwrap(); + + let conv_in = Conv2dConfig::new([self.in_channels, last_block_out_channels], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + let mid_block = UNetMidBlock2DConfig::new(last_block_out_channels) + .with_resnet_eps(1e-6) + .with_output_scale_factor(1.0) + .with_attn_num_head_channels(None) + .with_resnet_groups(Some(self.norm_num_groups)) + .with_n_layers(1) + .init(device); + + let mut up_blocks = vec![]; + let reversed_block_out_channels: Vec<_> = + self.block_out_channels.iter().copied().rev().collect(); + + for index in 0..n_block_out_channels { + let out_channels = reversed_block_out_channels[index]; + let in_channels = if index > 0 { + reversed_block_out_channels[index - 1] + } else { + reversed_block_out_channels[0] + }; + let is_final = index + 1 == n_block_out_channels; + + let up_block = UpDecoderBlock2DConfig::new(in_channels, out_channels) + .with_n_layers(self.layers_per_block + 1) + .with_resnet_eps(1e-6) + .with_resnet_groups(self.norm_num_groups) + .with_add_upsample(!is_final) + .init(device); + + up_blocks.push(up_block); + } + + let conv_norm_out = GroupNormConfig::new(self.norm_num_groups, self.block_out_channels[0]) + .with_epsilon(1e-6) + .init(device); + + let conv_out = Conv2dConfig::new([self.block_out_channels[0], self.out_channels], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(device); + + Decoder { + conv_in, + up_blocks, + mid_block, + conv_norm_out, + conv_out, + } + } +} + +impl Decoder { + /// Forward pass through the decoder. 
+    pub fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+        let xs = self.conv_in.forward(xs);
+        let mut xs = self.mid_block.forward(xs, None);
+
+        for up_block in self.up_blocks.iter() {
+            xs = up_block.forward(xs);
+        }
+
+        let xs = self.conv_norm_out.forward(xs);
+        let xs = silu(xs);
+        self.conv_out.forward(xs)
+    }
+}
+
+/// Diagonal Gaussian Distribution for VAE latent space.
+///
+/// Represents the posterior distribution q(z|x) as a diagonal Gaussian
+/// with learned mean and variance.
+pub struct DiagonalGaussianDistribution<B: Backend> {
+    mean: Tensor<B, 4>,
+    std: Tensor<B, 4>,
+}
+
+impl<B: Backend> DiagonalGaussianDistribution<B> {
+    /// Create a new distribution from the encoder output parameters.
+    ///
+    /// The parameters tensor is expected to have shape [batch, 2*latent_channels, height, width]
+    /// where the first half contains the mean and the second half contains the log variance.
+    pub fn new(parameters: Tensor<B, 4>) -> Self {
+        // Split along channel dimension
+        let [batch, channels, height, width] = parameters.dims();
+        let half_channels = channels / 2;
+
+        // Get mean (first half of channels)
+        let mean = parameters
+            .clone()
+            .slice([0..batch, 0..half_channels, 0..height, 0..width]);
+
+        // Get logvar (second half of channels)
+        let logvar = parameters.slice([0..batch, half_channels..channels, 0..height, 0..width]);
+
+        // std = exp(0.5 * logvar)
+        let std = (logvar * 0.5).exp();
+
+        DiagonalGaussianDistribution { mean, std }
+    }
+
+    /// Sample from the distribution using the reparameterization trick.
+    ///
+    /// z = mean + std * epsilon, where epsilon ~ N(0, 1)
+    pub fn sample(&self) -> Tensor<B, 4> {
+        let noise: Tensor<B, 4> = Tensor::random(
+            self.mean.shape(),
+            Distribution::Normal(0.0, 1.0),
+            &self.mean.device(),
+        );
+        self.mean.clone() + self.std.clone() * noise
+    }
+
+    /// Get the mean of the distribution.
+    pub fn mean(&self) -> Tensor<B, 4> {
+        self.mean.clone()
+    }
+
+    /// Get the mode of the distribution (same as mean for Gaussian).
+    pub fn mode(&self) -> Tensor<B, 4> {
+        self.mean.clone()
+    }
+}
+
+/// Configuration for the AutoEncoder KL model.
+#[derive(Config, Debug)]
+pub struct AutoEncoderKLConfig {
+    /// Output channels for each block.
+    #[config(default = "vec![128, 256, 512, 512]")]
+    pub block_out_channels: Vec<usize>,
+    /// Number of resnet layers per block.
+    #[config(default = 2)]
+    pub layers_per_block: usize,
+    /// Number of latent channels.
+    #[config(default = 4)]
+    pub latent_channels: usize,
+    /// Number of groups for group normalization.
+    #[config(default = 32)]
+    pub norm_num_groups: usize,
+}
+
+/// AutoEncoder with KL divergence loss (VAE).
+///
+/// This model compresses images to a latent space and can reconstruct them.
+/// It's used in Stable Diffusion to work in a compressed latent space.
+#[derive(Module, Debug)]
+pub struct AutoEncoderKL<B: Backend> {
+    encoder: Encoder<B>,
+    decoder: Decoder<B>,
+    quant_conv: Conv2d<B>,
+    post_quant_conv: Conv2d<B>,
+}
+
+impl AutoEncoderKLConfig {
+    /// Initialize the AutoEncoderKL model.
+ /// + /// # Arguments + /// * `in_channels` - Number of input image channels (e.g., 3 for RGB) + /// * `out_channels` - Number of output image channels (e.g., 3 for RGB) + /// * `device` - The device to create the model on + pub fn init( + &self, + in_channels: usize, + out_channels: usize, + device: &B::Device, + ) -> AutoEncoderKL { + let encoder = EncoderConfig::new( + in_channels, + self.latent_channels, + self.block_out_channels.clone(), + ) + .with_layers_per_block(self.layers_per_block) + .with_norm_num_groups(self.norm_num_groups) + .with_double_z(true) + .init(device); + + let decoder = DecoderConfig::new( + self.latent_channels, + out_channels, + self.block_out_channels.clone(), + ) + .with_layers_per_block(self.layers_per_block) + .with_norm_num_groups(self.norm_num_groups) + .init(device); + + // 1x1 convolutions for quantization + let quant_conv = + Conv2dConfig::new([2 * self.latent_channels, 2 * self.latent_channels], [1, 1]) + .init(device); + + let post_quant_conv = + Conv2dConfig::new([self.latent_channels, self.latent_channels], [1, 1]).init(device); + + AutoEncoderKL { + encoder, + decoder, + quant_conv, + post_quant_conv, + } + } +} + +impl AutoEncoderKL { + /// Encode an image to the latent space distribution. + /// + /// Returns a DiagonalGaussianDistribution that can be sampled from. + pub fn encode(&self, xs: Tensor) -> DiagonalGaussianDistribution { + let parameters = self.encoder.forward(xs); + let parameters = self.quant_conv.forward(parameters); + DiagonalGaussianDistribution::new(parameters) + } + + /// Decode latent vectors back to images. 
+ pub fn decode(&self, xs: Tensor) -> Tensor { + let xs = self.post_quant_conv.forward(xs); + self.decoder.forward(xs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + use alloc::string::{String, ToString}; + use burn::module::Module; + use burn::tensor::{Shape, TensorData}; + + use burn::module::{ModuleMapper, Param}; + + /// A ModuleMapper that sets weights to one value and biases to another. + /// Uses the field name from enter_module to distinguish weight vs bias. + struct WeightBiasMapper<'a, B: Backend> { + weight_val: f32, + bias_val: f32, + device: &'a B::Device, + current_field: String, + } + + impl<'a, B: Backend> ModuleMapper for WeightBiasMapper<'a, B> { + fn enter_module(&mut self, name: &str, _container_type: &str) { + self.current_field = name.to_string(); + } + + fn map_float( + &mut self, + tensor: Param>, + ) -> Param> { + let shape = tensor.shape(); + let dims: [usize; D] = shape.dims(); + + // Use field name to distinguish: "bias" and "beta" get bias_val, + // "weight" and "gamma" get weight_val + let is_bias = self.current_field.contains("bias") || self.current_field == "beta"; + let val = if is_bias { + self.bias_val + } else { + self.weight_val + }; + + Param::initialized(tensor.id, Tensor::full(dims, val, self.device)) + } + } + + /// Set weights and biases to different constant values. 
+ fn set_weights_and_biases( + vae: AutoEncoderKL, + weight_val: f32, + bias_val: f32, + device: &B::Device, + ) -> AutoEncoderKL { + let mut mapper = WeightBiasMapper:: { + weight_val, + bias_val, + device, + current_field: String::new(), + }; + vae.map(&mut mapper) + } + + #[test] + fn test_encoder_output_shape() { + let device = Default::default(); + + let encoder = EncoderConfig::new(3, 4, vec![32, 64]) + .with_layers_per_block(1) + .with_norm_num_groups(32) + .with_double_z(true) + .init::(&device); + + // Input: [batch=1, channels=3, height=64, width=64] + let xs: Tensor = Tensor::zeros([1, 3, 64, 64], &device); + let output = encoder.forward(xs); + + // Output should have 2*latent_channels due to double_z + // Spatial dimensions reduced by factor of 2 for each downsample (1 downsample here) + // 64 -> 32 + assert_eq!(output.shape(), Shape::from([1, 8, 32, 32])); + } + + #[test] + fn test_decoder_output_shape() { + let device = Default::default(); + + let decoder = DecoderConfig::new(4, 3, vec![32, 64]) + .with_layers_per_block(1) + .with_norm_num_groups(32) + .init::(&device); + + // Input: [batch=1, channels=4, height=32, width=32] + let xs: Tensor = Tensor::zeros([1, 4, 32, 32], &device); + let output = decoder.forward(xs); + + // Output should upsample back + // 32 -> 64 + assert_eq!(output.shape(), Shape::from([1, 3, 64, 64])); + } + + #[test] + fn test_diagonal_gaussian_distribution() { + let device = Default::default(); + + // Parameters with 8 channels (4 for mean, 4 for logvar) + let parameters: Tensor = Tensor::zeros([1, 8, 4, 4], &device); + let dist = DiagonalGaussianDistribution::new(parameters); + + let sample = dist.sample(); + assert_eq!(sample.shape(), Shape::from([1, 4, 4, 4])); + + let mean = dist.mean(); + assert_eq!(mean.shape(), Shape::from([1, 4, 4, 4])); + } + + #[test] + fn test_autoencoder_kl_shapes() { + let device = Default::default(); + + let vae = AutoEncoderKLConfig::new() + .with_block_out_channels(vec![32, 64]) + 
.with_layers_per_block(1) + .with_latent_channels(4) + .with_norm_num_groups(32) + .init::(3, 3, &device); + + // Input image: [batch=1, channels=3, height=64, width=64] + let xs: Tensor = Tensor::zeros([1, 3, 64, 64], &device); + + // Encode + let dist = vae.encode(xs); + let latent = dist.sample(); + + // Latent should be [1, 4, 32, 32] (4 channels, 2x downsampled spatially) + assert_eq!(latent.shape(), Shape::from([1, 4, 32, 32])); + + // Decode + let reconstructed = vae.decode(latent); + + // Reconstructed should match input shape + assert_eq!(reconstructed.shape(), Shape::from([1, 3, 64, 64])); + } + + /// Test VAE decode with fixed weights matches diffusers-rs + /// Reference values from diffusers-rs v0.3.1 + #[test] + fn test_vae_decode_fixed_weights_matches_diffusers_rs() { + let device = Default::default(); + + // Create VAE with same config as diffusers-rs test + let vae = AutoEncoderKLConfig::new() + .with_block_out_channels(vec![32, 64]) + .with_layers_per_block(1) + .with_latent_channels(4) + .with_norm_num_groups(32) + .init::(3, 3, &device); + + // Set weights to 0.1, biases to 0.0 (matching diffusers-rs test) + let vae = set_weights_and_biases(vae, 0.1, 0.0, &device); + + // Input: arange(4*16*16).reshape([1, 4, 16, 16]) / (4*16*16) + let input_data: Vec = (0..(4 * 16 * 16)) + .map(|i| i as f32 / (4.0 * 16.0 * 16.0)) + .collect(); + let xs: Tensor = + Tensor::from_data(TensorData::from(input_data.as_slice()), &device); + let xs: Tensor = xs.reshape([1, 4, 16, 16]); + + let decoded = vae.decode(xs); + + // Verify output shape + assert_eq!(decoded.shape(), Shape::from([1, 3, 32, 32])); + + // Get result values + let result_flat = decoded.clone().flatten::<1>(0, 3); + let result_data = result_flat.to_data(); + let result_vec: Vec = result_data.to_vec().unwrap(); + + // Reference values from diffusers-rs (all weights=0.1, biases=0.0) + let expected_first_16 = [ + -0.42073298_f32, + -0.689827, + -0.8111322, + -0.85588384, + -0.884616, + -0.8894976, + 
-0.8884052, + -0.8844389, + -0.88121814, + -0.8796178, + -0.8778493, + -0.8766395, + -0.87531906, + -0.8743709, + -0.87319934, + -0.8722713, + ]; + + for (i, (actual, expected)) in result_vec + .iter() + .take(16) + .zip(expected_first_16.iter()) + .enumerate() + { + assert!( + (actual - expected).abs() < 1e-4, + "Mismatch at index {}: actual={}, expected={}", + i, + actual, + expected + ); + } + + // Check overall mean matches + let mean = decoded.mean().into_scalar(); + assert!( + (mean - 0.09089245647192001).abs() < 1e-4, + "Mean mismatch: actual={}, expected=0.09089245647192001", + mean + ); + } +} diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index 628b03b..f11c47f 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -1 +1,10 @@ pub mod stable_diffusion; + +#[cfg(feature = "std")] +pub mod weights; + +#[cfg(feature = "std")] +pub use weights::{ + download_instructions, load_clip_safetensors, load_unet_safetensors, load_vae_safetensors, + WeightLoadError, +}; diff --git a/src/pipelines/stable_diffusion.rs b/src/pipelines/stable_diffusion.rs index 09a6f4a..c6aa973 100644 --- a/src/pipelines/stable_diffusion.rs +++ b/src/pipelines/stable_diffusion.rs @@ -1,35 +1,796 @@ -use crate::transformers::clip; -use crate::transformers::clip::ClipConfig; -use burn::config::Config; +//! Stable Diffusion Pipeline +//! +//! This module provides the Stable Diffusion pipeline for text-to-image generation. +//! The pipeline combines a CLIP text encoder, a VAE, a UNet, and a noise scheduler +//! to generate images from text prompts. +//! +//! # Example (pseudocode) +//! +//! ```ignore +//! // Create configuration +//! let config = StableDiffusionConfig::v1_5(None, None, None); +//! +//! // Build pipeline components +//! let pipeline = config.init::(&device); +//! +//! // Build scheduler +//! let scheduler = config.build_ddim_scheduler::(30, &device); +//! +//! // Tokenize prompt (requires std feature) +//! 
let tokenizer = SimpleTokenizer::new("bpe_simple_vocab_16e6.txt", SimpleTokenizerConfig::v1_5())?; +//! let tokens = tokenizer.encode("a photo of a cat")?; +//! let uncond_tokens = tokenizer.encode("")?; +//! +//! // Generate image +//! let image = generate_image_ddim( +//! &pipeline, +//! &scheduler, +//! tokens, +//! uncond_tokens, +//! 7.5, // guidance_scale +//! 42, // seed +//! &device, +//! ); +//! ``` + +use alloc::vec; +use alloc::vec::Vec; + use burn::module::Module; use burn::tensor::backend::Backend; +use burn::tensor::{Distribution, Int, Tensor}; + +use crate::models::unet_2d::{BlockConfig, UNet2DConditionModel, UNet2DConditionModelConfig}; +use crate::models::vae::{AutoEncoderKL, AutoEncoderKLConfig}; +use crate::schedulers::{ + BetaSchedule, DDIMScheduler, DDIMSchedulerConfig, DDPMScheduler, DDPMSchedulerConfig, + DPMSolverMultistepScheduler, DPMSolverMultistepSchedulerConfig, + EulerAncestralDiscreteScheduler, EulerAncestralDiscreteSchedulerConfig, EulerDiscreteScheduler, + EulerDiscreteSchedulerConfig, PNDMScheduler, PNDMSchedulerConfig, PredictionType, +}; +use crate::transformers::clip::{ClipConfig, ClipTextTransformer}; + +/// The guidance scale for classifier-free guidance. +/// Higher values give more weight to the text prompt. +pub const GUIDANCE_SCALE: f64 = 7.5; + +/// The scaling factor for the VAE latent space. +/// Latents are scaled by 1/0.18215 when encoding and 0.18215 when decoding. +pub const VAE_SCALE: f64 = 0.18215; -#[derive(Config, Debug)] +/// Configuration for the Stable Diffusion pipeline. +#[derive(Debug, Clone)] pub struct StableDiffusionConfig { - width: i64, - height: i64, + /// Width of the generated image in pixels. + pub width: usize, + /// Height of the generated image in pixels. + pub height: usize, + /// CLIP text encoder configuration. + pub clip: ClipConfig, + /// VAE configuration. + pub vae: AutoEncoderKLConfig, + /// UNet configuration. + pub unet: UNet2DConditionModelConfig, + /// Beta schedule start value. 
+ pub beta_start: f64, + /// Beta schedule end value. + pub beta_end: f64, + /// Beta schedule type. + pub beta_schedule: BetaSchedule, + /// Prediction type for the scheduler. + pub prediction_type: PredictionType, + /// Number of training timesteps. + pub train_timesteps: usize, } impl StableDiffusionConfig { - pub fn init( + /// Create a configuration for Stable Diffusion v1.5. + /// + /// Reference: https://huggingface.co/runwayml/stable-diffusion-v1-5 + pub fn v1_5( + sliced_attention_size: Option, + height: Option, + width: Option, + ) -> Self { + let height = height.unwrap_or(512); + let width = width.unwrap_or(512); + assert!(height.is_multiple_of(8), "height must be divisible by 8"); + assert!(width.is_multiple_of(8), "width must be divisible by 8"); + + // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json + let unet = UNet2DConditionModelConfig { + blocks: vec![ + BlockConfig::new(320) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(640) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(1280) + .with_use_cross_attn(true) + .with_attention_head_dim(8), + BlockConfig::new(1280) + .with_use_cross_attn(false) + .with_attention_head_dim(8), + ], + center_input_sample: false, + cross_attention_dim: 768, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0.0, + layers_per_block: 2, + mid_block_scale_factor: 1.0, + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: false, + }; + + // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json + let vae = AutoEncoderKLConfig::new() + .with_block_out_channels(vec![128, 256, 512, 512]) + .with_layers_per_block(2) + .with_latent_channels(4) + .with_norm_num_groups(32); + + Self { + width, + height, + clip: ClipConfig::v1_5(), + vae, + unet, + beta_start: 0.00085, + beta_end: 0.012, + beta_schedule: BetaSchedule::ScaledLinear, + prediction_type: 
PredictionType::Epsilon, + train_timesteps: 1000, + } + } + + /// Create a configuration for Stable Diffusion v2.1. + /// + /// Reference: https://huggingface.co/stabilityai/stable-diffusion-2-1 + pub fn v2_1( + sliced_attention_size: Option, + height: Option, + width: Option, + ) -> Self { + let height = height.unwrap_or(768); + let width = width.unwrap_or(768); + assert!(height.is_multiple_of(8), "height must be divisible by 8"); + assert!(width.is_multiple_of(8), "width must be divisible by 8"); + + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json + let unet = UNet2DConditionModelConfig { + blocks: vec![ + BlockConfig::new(320) + .with_use_cross_attn(true) + .with_attention_head_dim(5), + BlockConfig::new(640) + .with_use_cross_attn(true) + .with_attention_head_dim(10), + BlockConfig::new(1280) + .with_use_cross_attn(true) + .with_attention_head_dim(20), + BlockConfig::new(1280) + .with_use_cross_attn(false) + .with_attention_head_dim(20), + ], + center_input_sample: false, + cross_attention_dim: 1024, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0.0, + layers_per_block: 2, + mid_block_scale_factor: 1.0, + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: true, + }; + + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json + let vae = AutoEncoderKLConfig::new() + .with_block_out_channels(vec![128, 256, 512, 512]) + .with_layers_per_block(2) + .with_latent_channels(4) + .with_norm_num_groups(32); + + Self { + width, + height, + clip: ClipConfig::v2_1(), + vae, + unet, + beta_start: 0.00085, + beta_end: 0.012, + beta_schedule: BetaSchedule::ScaledLinear, + prediction_type: PredictionType::VPrediction, + train_timesteps: 1000, + } + } + + /// Initialize the CLIP text transformer. + pub fn build_clip_transformer(&self, device: &B::Device) -> ClipTextTransformer { + self.clip.init_text_transformer(device) + } + + /// Initialize the VAE. 
+ pub fn build_vae(&self, device: &B::Device) -> AutoEncoderKL { + self.vae.init(3, 3, device) + } + + /// Initialize the UNet. + pub fn build_unet( &self, - clip_config: ClipConfig, device: &B::Device, - ) -> StableDiffusion { + in_channels: usize, + ) -> UNet2DConditionModel { + self.unet.init(in_channels, 4, device) + } + + /// Build a DDIM scheduler. + pub fn build_ddim_scheduler( + &self, + n_steps: usize, + device: &B::Device, + ) -> DDIMScheduler { + let config = DDIMSchedulerConfig { + beta_start: self.beta_start, + beta_end: self.beta_end, + beta_schedule: self.beta_schedule, + prediction_type: self.prediction_type, + train_timesteps: self.train_timesteps, + ..DDIMSchedulerConfig::default() + }; + DDIMScheduler::new::(n_steps, config, device) + } + + /// Build a DDPM scheduler. + pub fn build_ddpm_scheduler( + &self, + n_steps: usize, + device: &B::Device, + ) -> DDPMScheduler { + let config = DDPMSchedulerConfig { + beta_start: self.beta_start, + beta_end: self.beta_end, + beta_schedule: self.beta_schedule, + prediction_type: self.prediction_type, + train_timesteps: self.train_timesteps, + ..DDPMSchedulerConfig::default() + }; + DDPMScheduler::new::(n_steps, config, device) + } + + /// Build a DPM-Solver++ Multistep scheduler. + pub fn build_dpm_solver_scheduler( + &self, + n_steps: usize, + device: &B::Device, + ) -> DPMSolverMultistepScheduler { + let config = DPMSolverMultistepSchedulerConfig { + beta_start: self.beta_start, + beta_end: self.beta_end, + beta_schedule: self.beta_schedule, + prediction_type: self.prediction_type, + train_timesteps: self.train_timesteps, + ..DPMSolverMultistepSchedulerConfig::default() + }; + DPMSolverMultistepScheduler::new(n_steps, config, device) + } + + /// Build an Euler Discrete scheduler. 
+ pub fn build_euler_discrete_scheduler(&self, n_steps: usize) -> EulerDiscreteScheduler { + let config = EulerDiscreteSchedulerConfig { + beta_start: self.beta_start, + beta_end: self.beta_end, + beta_schedule: self.beta_schedule, + prediction_type: self.prediction_type, + train_timesteps: self.train_timesteps, + }; + EulerDiscreteScheduler::new(n_steps, config) + } + + /// Build an Euler Ancestral Discrete scheduler. + pub fn build_euler_ancestral_scheduler( + &self, + n_steps: usize, + ) -> EulerAncestralDiscreteScheduler { + let config = EulerAncestralDiscreteSchedulerConfig { + beta_start: self.beta_start, + beta_end: self.beta_end, + beta_schedule: self.beta_schedule, + prediction_type: self.prediction_type, + train_timesteps: self.train_timesteps, + }; + EulerAncestralDiscreteScheduler::new(n_steps, config) + } + + /// Build a PNDM scheduler. + pub fn build_pndm_scheduler( + &self, + n_steps: usize, + device: &B::Device, + ) -> PNDMScheduler { + let config = PNDMSchedulerConfig { + beta_start: self.beta_start, + beta_end: self.beta_end, + beta_schedule: self.beta_schedule, + prediction_type: self.prediction_type, + train_timesteps: self.train_timesteps, + ..PNDMSchedulerConfig::default() + }; + PNDMScheduler::new(n_steps, config, device) + } + + /// Initialize the complete Stable Diffusion pipeline. + pub fn init(&self, device: &B::Device) -> StableDiffusion { StableDiffusion { + clip: self.build_clip_transformer(device), + vae: self.build_vae(device), + unet: self.build_unet(device, 4), width: self.width, height: self.height, - clip: clip_config.init_text_transformer(device), } } } +/// The Stable Diffusion pipeline. +/// +/// This struct holds all the models needed for text-to-image generation. 
#[derive(Module, Debug)] pub struct StableDiffusion { - width: i64, - height: i64, - clip: clip::ClipTextTransformer, - // autoencoder: vae::AutoEncoderKLConfig, - // unet: unet_2d::UNet2DConditionModelConfig, - // scheduler: ddim::DDIMSchedulerConfig, + /// The CLIP text encoder. + pub clip: ClipTextTransformer, + /// The VAE for encoding/decoding images. + pub vae: AutoEncoderKL, + /// The UNet for denoising. + pub unet: UNet2DConditionModel, + /// Width of the generated image. + pub width: usize, + /// Height of the generated image. + pub height: usize, +} + +impl StableDiffusion { + /// Encode text tokens to embeddings using the CLIP model. + /// + /// # Arguments + /// * `tokens` - Token IDs from the tokenizer [batch_size, seq_len] + /// + /// # Returns + /// Text embeddings [batch_size, seq_len, embed_dim] + pub fn encode_text(&self, tokens: Tensor) -> Tensor { + self.clip.forward(tokens) + } + + /// Encode an image to latent space. + /// + /// # Arguments + /// * `image` - Input image tensor [batch_size, 3, height, width] with values in [0, 1] + /// + /// # Returns + /// Latent tensor [batch_size, 4, height/8, width/8] + pub fn encode_image(&self, image: Tensor) -> Tensor { + // Scale image to [-1, 1] + let image = image * 2.0 - 1.0; + // Encode and sample from the distribution + let dist = self.vae.encode(image); + // Scale latent + dist.sample() * VAE_SCALE + } + + /// Decode latent vectors to images. + /// + /// # Arguments + /// * `latents` - Latent tensor [batch_size, 4, height/8, width/8] + /// + /// # Returns + /// Image tensor [batch_size, 3, height, width] with values in [0, 1] + pub fn decode_latents(&self, latents: Tensor) -> Tensor { + // Scale latent + let latents = latents / VAE_SCALE; + // Decode + let image = self.vae.decode(latents); + // Scale back to [0, 1] + (image / 2.0 + 0.5).clamp(0.0, 1.0) + } + + /// Predict the noise for a given noisy latent and timestep. 
+ /// + /// # Arguments + /// * `latents` - Noisy latent tensor [batch_size, 4, height/8, width/8] + /// * `timestep` - Current diffusion timestep + /// * `encoder_hidden_states` - Text embeddings [batch_size, seq_len, embed_dim] + /// + /// # Returns + /// Predicted noise tensor [batch_size, 4, height/8, width/8] + pub fn predict_noise( + &self, + latents: Tensor, + timestep: f64, + encoder_hidden_states: Tensor, + ) -> Tensor { + self.unet.forward(latents, timestep, encoder_hidden_states) + } + + /// Get text embeddings with classifier-free guidance. + /// + /// Concatenates unconditional (empty prompt) and conditional embeddings. + /// + /// # Arguments + /// * `prompt_tokens` - Token IDs for the prompt [1, seq_len] + /// * `uncond_tokens` - Token IDs for empty/negative prompt [1, seq_len] + /// + /// # Returns + /// Combined embeddings [2, seq_len, embed_dim] (uncond first, then cond) + pub fn encode_prompt_with_guidance( + &self, + prompt_tokens: Tensor, + uncond_tokens: Tensor, + ) -> Tensor { + let text_embeddings = self.encode_text(prompt_tokens); + let uncond_embeddings = self.encode_text(uncond_tokens); + Tensor::cat(vec![uncond_embeddings, text_embeddings], 0) + } +} + +/// Generate an image using the DDIM scheduler. +/// +/// This function implements the full diffusion loop with classifier-free guidance. 
+/// +/// # Arguments +/// * `pipeline` - The Stable Diffusion pipeline with loaded models +/// * `scheduler` - The DDIM scheduler configured with the number of steps +/// * `prompt_tokens` - Tokenized prompt as a vector of token IDs +/// * `uncond_tokens` - Tokenized empty/negative prompt as a vector of token IDs +/// * `guidance_scale` - Classifier-free guidance scale (typically 7.5) +/// * `seed` - Random seed for reproducibility +/// * `device` - Device to run inference on +/// +/// # Returns +/// Generated image tensor [1, 3, height, width] with values in [0, 1] +pub fn generate_image_ddim( + pipeline: &StableDiffusion, + scheduler: &DDIMScheduler, + prompt_tokens: &[usize], + uncond_tokens: &[usize], + guidance_scale: f64, + seed: u64, + device: &B::Device, +) -> Tensor { + // Seed the random number generator for reproducibility + B::seed(device, seed); + + // Convert tokens to tensors + let prompt_tokens: Vec = prompt_tokens.iter().map(|&x| x as i64).collect(); + let uncond_tokens: Vec = uncond_tokens.iter().map(|&x| x as i64).collect(); + + let prompt_tensor: Tensor = Tensor::from_ints(&prompt_tokens[..], device); + let prompt_tensor: Tensor = prompt_tensor.unsqueeze_dim(0); + let uncond_tensor: Tensor = Tensor::from_ints(&uncond_tokens[..], device); + let uncond_tensor: Tensor = uncond_tensor.unsqueeze_dim(0); + + // Get text embeddings with guidance + let text_embeddings = pipeline.encode_prompt_with_guidance(prompt_tensor, uncond_tensor); + + // Initialize random latents + let latent_height = pipeline.height / 8; + let latent_width = pipeline.width / 8; + let mut latents: Tensor = Tensor::random( + [1, 4, latent_height, latent_width], + Distribution::Normal(0.0, 1.0), + device, + ); + + // Scale initial noise by scheduler's init_noise_sigma + latents = latents * scheduler.init_noise_sigma(); + + // Diffusion loop + for ×tep in scheduler.timesteps().iter() { + // Duplicate latents for classifier-free guidance (uncond + cond) + let latent_model_input = 
Tensor::cat(vec![latents.clone(), latents.clone()], 0); + + // Scale model input + let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep); + + // Predict noise + let noise_pred = + pipeline.predict_noise(latent_model_input, timestep as f64, text_embeddings.clone()); + + // Split predictions for guidance + let [noise_pred_uncond, noise_pred_text] = noise_pred.chunk(2, 0).try_into().unwrap(); + + // Apply classifier-free guidance + let noise_pred = + noise_pred_uncond.clone() + (noise_pred_text - noise_pred_uncond) * guidance_scale; + + // Scheduler step + latents = scheduler.step(&noise_pred, timestep, &latents); + } + + // Decode latents to image + pipeline.decode_latents(latents) +} + +/// Generate an image using the Euler Discrete scheduler. +/// +/// # Arguments +/// * `pipeline` - The Stable Diffusion pipeline with loaded models +/// * `scheduler` - The Euler Discrete scheduler configured with the number of steps +/// * `prompt_tokens` - Tokenized prompt as a vector of token IDs +/// * `uncond_tokens` - Tokenized empty/negative prompt as a vector of token IDs +/// * `guidance_scale` - Classifier-free guidance scale (typically 7.5) +/// * `seed` - Random seed for reproducibility +/// * `device` - Device to run inference on +/// +/// # Returns +/// Generated image tensor [1, 3, height, width] with values in [0, 1] +pub fn generate_image_euler( + pipeline: &StableDiffusion, + scheduler: &EulerDiscreteScheduler, + prompt_tokens: &[usize], + uncond_tokens: &[usize], + guidance_scale: f64, + seed: u64, + device: &B::Device, +) -> Tensor { + // Seed the random number generator for reproducibility + B::seed(device, seed); + + // Convert tokens to tensors + let prompt_tokens: Vec = prompt_tokens.iter().map(|&x| x as i64).collect(); + let uncond_tokens: Vec = uncond_tokens.iter().map(|&x| x as i64).collect(); + + let prompt_tensor: Tensor = Tensor::from_ints(&prompt_tokens[..], device); + let prompt_tensor: Tensor = 
prompt_tensor.unsqueeze_dim(0); + let uncond_tensor: Tensor = Tensor::from_ints(&uncond_tokens[..], device); + let uncond_tensor: Tensor = uncond_tensor.unsqueeze_dim(0); + + // Get text embeddings with guidance + let text_embeddings = pipeline.encode_prompt_with_guidance(prompt_tensor, uncond_tensor); + + // Initialize random latents + let latent_height = pipeline.height / 8; + let latent_width = pipeline.width / 8; + let mut latents: Tensor = Tensor::random( + [1, 4, latent_height, latent_width], + Distribution::Normal(0.0, 1.0), + device, + ); + + // Scale initial noise by scheduler's init_noise_sigma + latents = latents * scheduler.init_noise_sigma(); + + // Diffusion loop + for (i, ×tep) in scheduler.timesteps().iter().enumerate() { + // Duplicate latents for classifier-free guidance (uncond + cond) + let latent_model_input = Tensor::cat(vec![latents.clone(), latents.clone()], 0); + + // Scale model input + let latent_model_input = scheduler.scale_model_input(latent_model_input, i as f64); + + // Predict noise + let noise_pred = + pipeline.predict_noise(latent_model_input, timestep, text_embeddings.clone()); + + // Split predictions for guidance + let [noise_pred_uncond, noise_pred_text] = noise_pred.chunk(2, 0).try_into().unwrap(); + + // Apply classifier-free guidance + let noise_pred = + noise_pred_uncond.clone() + (noise_pred_text - noise_pred_uncond) * guidance_scale; + + // Scheduler step + latents = scheduler.step(&noise_pred, i as f64, &latents); + } + + // Decode latents to image + pipeline.decode_latents(latents) +} + +/// Generate an image using the DPM-Solver++ Multistep scheduler. 
+/// +/// # Arguments +/// * `pipeline` - The Stable Diffusion pipeline with loaded models +/// * `scheduler` - The DPM-Solver++ scheduler configured with the number of steps +/// * `prompt_tokens` - Tokenized prompt as a vector of token IDs +/// * `uncond_tokens` - Tokenized empty/negative prompt as a vector of token IDs +/// * `guidance_scale` - Classifier-free guidance scale (typically 7.5) +/// * `seed` - Random seed for reproducibility +/// * `device` - Device to run inference on +/// +/// # Returns +/// Generated image tensor [1, 3, height, width] with values in [0, 1] +pub fn generate_image_dpm( + pipeline: &StableDiffusion, + scheduler: &mut DPMSolverMultistepScheduler, + prompt_tokens: &[usize], + uncond_tokens: &[usize], + guidance_scale: f64, + seed: u64, + device: &B::Device, +) -> Tensor { + // Seed the random number generator for reproducibility + B::seed(device, seed); + + // Convert tokens to tensors + let prompt_tokens: Vec = prompt_tokens.iter().map(|&x| x as i64).collect(); + let uncond_tokens: Vec = uncond_tokens.iter().map(|&x| x as i64).collect(); + + let prompt_tensor: Tensor = Tensor::from_ints(&prompt_tokens[..], device); + let prompt_tensor: Tensor = prompt_tensor.unsqueeze_dim(0); + let uncond_tensor: Tensor = Tensor::from_ints(&uncond_tokens[..], device); + let uncond_tensor: Tensor = uncond_tensor.unsqueeze_dim(0); + + // Get text embeddings with guidance + let text_embeddings = pipeline.encode_prompt_with_guidance(prompt_tensor, uncond_tensor); + + // Initialize random latents + let latent_height = pipeline.height / 8; + let latent_width = pipeline.width / 8; + let mut latents: Tensor = Tensor::random( + [1, 4, latent_height, latent_width], + Distribution::Normal(0.0, 1.0), + device, + ); + + // Scale initial noise by scheduler's init_noise_sigma + latents = latents * scheduler.init_noise_sigma(); + + // Get timesteps (need to clone since we iterate while mutating scheduler) + let timesteps: alloc::vec::Vec = 
scheduler.timesteps().to_vec(); + + // Diffusion loop + for (i, ×tep) in timesteps.iter().enumerate() { + // Duplicate latents for classifier-free guidance (uncond + cond) + let latent_model_input = Tensor::cat(vec![latents.clone(), latents.clone()], 0); + + // Scale model input + let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep); + + // Predict noise + let noise_pred = + pipeline.predict_noise(latent_model_input, timestep as f64, text_embeddings.clone()); + + // Split predictions for guidance + let [noise_pred_uncond, noise_pred_text] = noise_pred.chunk(2, 0).try_into().unwrap(); + + // Apply classifier-free guidance + let noise_pred = + noise_pred_uncond.clone() + (noise_pred_text - noise_pred_uncond) * guidance_scale; + + // Scheduler step + latents = scheduler.step(&noise_pred, i, &latents); + } + + // Decode latents to image + pipeline.decode_latents(latents) +} + +/// Generate an image using the PNDM scheduler. +/// +/// # Arguments +/// * `pipeline` - The Stable Diffusion pipeline with loaded models +/// * `scheduler` - The PNDM scheduler configured with the number of steps +/// * `prompt_tokens` - Tokenized prompt as a vector of token IDs +/// * `uncond_tokens` - Tokenized empty/negative prompt as a vector of token IDs +/// * `guidance_scale` - Classifier-free guidance scale (typically 7.5) +/// * `seed` - Random seed for reproducibility +/// * `device` - Device to run inference on +/// +/// # Returns +/// Generated image tensor [1, 3, height, width] with values in [0, 1] +pub fn generate_image_pndm( + pipeline: &StableDiffusion, + scheduler: &mut PNDMScheduler, + prompt_tokens: &[usize], + uncond_tokens: &[usize], + guidance_scale: f64, + seed: u64, + device: &B::Device, +) -> Tensor { + // Seed the random number generator for reproducibility + B::seed(device, seed); + + // Convert tokens to tensors + let prompt_tokens: Vec = prompt_tokens.iter().map(|&x| x as i64).collect(); + let uncond_tokens: Vec = 
uncond_tokens.iter().map(|&x| x as i64).collect(); + + let prompt_tensor: Tensor = Tensor::from_ints(&prompt_tokens[..], device); + let prompt_tensor: Tensor = prompt_tensor.unsqueeze_dim(0); + let uncond_tensor: Tensor = Tensor::from_ints(&uncond_tokens[..], device); + let uncond_tensor: Tensor = uncond_tensor.unsqueeze_dim(0); + + // Get text embeddings with guidance + let text_embeddings = pipeline.encode_prompt_with_guidance(prompt_tensor, uncond_tensor); + + // Initialize random latents + let latent_height = pipeline.height / 8; + let latent_width = pipeline.width / 8; + let mut latents: Tensor = Tensor::random( + [1, 4, latent_height, latent_width], + Distribution::Normal(0.0, 1.0), + device, + ); + + // Scale initial noise by scheduler's init_noise_sigma + latents = latents * scheduler.init_noise_sigma(); + + // Get timesteps (need to clone since we iterate while mutating scheduler) + let timesteps: alloc::vec::Vec = scheduler.timesteps().to_vec(); + + // Diffusion loop + for ×tep in timesteps.iter() { + // Duplicate latents for classifier-free guidance (uncond + cond) + let latent_model_input = Tensor::cat(vec![latents.clone(), latents.clone()], 0); + + // Scale model input + let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep); + + // Predict noise + let noise_pred = + pipeline.predict_noise(latent_model_input, timestep as f64, text_embeddings.clone()); + + // Split predictions for guidance + let [noise_pred_uncond, noise_pred_text] = noise_pred.chunk(2, 0).try_into().unwrap(); + + // Apply classifier-free guidance + let noise_pred = + noise_pred_uncond.clone() + (noise_pred_text - noise_pred_uncond) * guidance_scale; + + // Scheduler step + latents = scheduler.step(&noise_pred, timestep, &latents); + } + + // Decode latents to image + pipeline.decode_latents(latents) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_v1_5_config() { + let config = StableDiffusionConfig::v1_5(None, None, None); + + 
assert_eq!(config.width, 512); + assert_eq!(config.height, 512); + assert_eq!(config.unet.cross_attention_dim, 768); + assert_eq!(config.unet.blocks.len(), 4); + assert!(!config.unet.use_linear_projection); + } + + #[test] + fn test_v2_1_config() { + let config = StableDiffusionConfig::v2_1(None, None, None); + + assert_eq!(config.width, 768); + assert_eq!(config.height, 768); + assert_eq!(config.unet.cross_attention_dim, 1024); + assert_eq!(config.unet.blocks.len(), 4); + assert!(config.unet.use_linear_projection); + } + + #[test] + fn test_custom_dimensions() { + let config = StableDiffusionConfig::v1_5(None, Some(768), Some(1024)); + + assert_eq!(config.width, 1024); + assert_eq!(config.height, 768); + } + + #[test] + #[should_panic(expected = "height must be divisible by 8")] + fn test_invalid_height() { + let _ = StableDiffusionConfig::v1_5(None, Some(513), None); + } + + #[test] + #[should_panic(expected = "width must be divisible by 8")] + fn test_invalid_width() { + let _ = StableDiffusionConfig::v1_5(None, None, Some(513)); + } } diff --git a/src/pipelines/weights.rs b/src/pipelines/weights.rs new file mode 100644 index 0000000..f18851e --- /dev/null +++ b/src/pipelines/weights.rs @@ -0,0 +1,409 @@ +//! Weight Loading Utilities +//! +//! This module provides utilities for loading pre-trained weights from +//! SafeTensors format into Burn models using `burn-store`. +//! +//! # Overview +//! +//! Stable Diffusion models typically come in SafeTensors format from Hugging Face. +//! This module provides helper functions to load these weights into the +//! diffusers-burn model structures. +//! +//! # Weight File Structure +//! +//! For Stable Diffusion, you typically need: +//! - `text_encoder/model.safetensors` - CLIP text encoder weights +//! - `vae/diffusion_pytorch_model.safetensors` - VAE encoder/decoder weights +//! - `unet/diffusion_pytorch_model.safetensors` - UNet denoising model weights +//! +//! # Example Usage +//! +//! ```ignore +//! 
use diffusers_burn::pipelines::weights::{load_clip_safetensors, load_vae_safetensors, load_unet_safetensors}; +//! use diffusers_burn::pipelines::stable_diffusion::StableDiffusionConfig; +//! +//! let config = StableDiffusionConfig::v1_5(None, None, None); +//! let device = Default::default(); +//! +//! // Initialize models +//! let clip = config.build_clip_transformer::(&device); +//! let vae = config.build_vae::(&device); +//! let unet = config.build_unet::(&device, 4); +//! +//! // Load weights +//! let clip = load_clip_safetensors::(clip, "path/to/model.safetensors", &device)?; +//! let vae = load_vae_safetensors::(vae, "path/to/vae.safetensors", &device)?; +//! let unet = load_unet_safetensors::(unet, "path/to/unet.safetensors", &device)?; +//! ``` + +use std::path::Path; + +use burn::tensor::backend::Backend; +use burn_store::{KeyRemapper, ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore}; + +/// Errors that can occur during weight loading. +#[derive(Debug, thiserror::Error)] +pub enum WeightLoadError { + #[error("Failed to load safetensors file: {0}")] + SafetensorsLoad(String), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +/// Load CLIP text encoder weights from HuggingFace format. +/// +/// This handles the key remapping from HuggingFace's naming convention +/// (e.g., `text_model.embeddings.token_embedding`) to Burn's convention +/// (e.g., `embeddings.token_embedding`). +/// +/// Uses `PyTorchToBurnAdapter` to automatically handle: +/// - Transposing linear layer weights +/// - Renaming normalization parameters (weight->gamma, bias->beta) +pub fn load_clip_safetensors( + mut module: M, + path: P, + _device: &B::Device, +) -> Result +where + B: Backend, + M: ModuleSnapshot, + P: AsRef, +{ + // Remap HuggingFace CLIP keys to our model structure + let key_mappings: Vec<(&str, &str)> = vec![ + // Remove "text_model." 
prefix + ("^text_model\\.", ""), + ]; + + let remapper = KeyRemapper::from_patterns(key_mappings) + .map_err(|e| WeightLoadError::SafetensorsLoad(e.to_string()))?; + + let checkpoint_path = path.as_ref().to_path_buf(); + let mut store = SafetensorsStore::from_file(checkpoint_path) + .with_from_adapter(PyTorchToBurnAdapter) + .remap(remapper); + + module + .load_from(&mut store) + .map_err(|e| WeightLoadError::SafetensorsLoad(e.to_string()))?; + + Ok(module) +} + +/// Load VAE weights from HuggingFace format. +/// +/// This handles the key remapping from HuggingFace's naming convention to Burn's. +/// Main differences: +/// - HF uses `mid_block.resnets.0/1` but Burn uses `mid_block.resnet` (first) and +/// `mid_block.attn_resnets.0.resnet_block` (second) +/// - HF uses `mid_block.attentions.0` but Burn uses `mid_block.attn_resnets.0.attention_block` +/// - HF uses `downsamplers.0` but Burn uses `downsampler` +/// - HF uses `upsamplers.0` but Burn uses `upsampler` +/// +/// Uses `PyTorchToBurnAdapter` to automatically handle: +/// - Transposing linear layer weights +/// - Renaming normalization parameters (weight->gamma, bias->beta) +pub fn load_vae_safetensors( + mut module: M, + path: P, + _device: &B::Device, +) -> Result +where + B: Backend, + M: ModuleSnapshot, + P: AsRef, +{ + // Remap HuggingFace VAE keys to our model structure + // Order matters: more specific patterns should come first + // Note: VAE has encoder/decoder prefixes, so patterns should not use ^ anchor + let key_mappings: Vec<(&str, &str)> = vec![ + // Mid block: first resnet (index 0) maps to standalone resnet field + ("\\.mid_block\\.resnets\\.0\\.", ".mid_block.resnet."), + // Mid block: second resnet (index 1) maps to attn_resnets.0.resnet_block + ( + "\\.mid_block\\.resnets\\.1\\.", + ".mid_block.attn_resnets.0.resnet_block.", + ), + // Mid block: attention maps to attn_resnets.X.attention_block + ( + "\\.mid_block\\.attentions\\.(\\d+)\\.", + ".mid_block.attn_resnets.$1.attention_block.", 
+ ), + // Downsamplers: downsamplers.0 -> downsampler + ("\\.downsamplers\\.0\\.", ".downsampler."), + // Upsamplers: upsamplers.0 -> upsampler + ("\\.upsamplers\\.0\\.", ".upsampler."), + ]; + + let remapper = KeyRemapper::from_patterns(key_mappings) + .map_err(|e| WeightLoadError::SafetensorsLoad(e.to_string()))?; + + let checkpoint_path = path.as_ref().to_path_buf(); + let mut store = SafetensorsStore::from_file(checkpoint_path) + .with_from_adapter(PyTorchToBurnAdapter) + .remap(remapper); + + module + .load_from(&mut store) + .map_err(|e| WeightLoadError::SafetensorsLoad(e.to_string()))?; + + Ok(module) +} + +/// Inspect a safetensors file to determine which down/up blocks have attention. +/// +/// Returns two vectors of booleans indicating which blocks have attention: +/// - First vector: down_blocks (true if block has attention) +/// - Second vector: up_blocks (true if block has attention) +fn inspect_unet_block_types>( + path: P, +) -> Result<(Vec, Vec), WeightLoadError> { + use std::collections::HashSet; + + let file = std::fs::File::open(path.as_ref()) + .map_err(|e| WeightLoadError::SafetensorsLoad(e.to_string()))?; + let buffer = unsafe { memmap2::MmapOptions::new().map(&file) } + .map_err(|e| WeightLoadError::SafetensorsLoad(e.to_string()))?; + let tensors = safetensors::SafeTensors::deserialize(&buffer) + .map_err(|e| WeightLoadError::SafetensorsLoad(e.to_string()))?; + + let keys: Vec = tensors.names().into_iter().cloned().collect(); + + // Find max block indices + let mut max_down_block = 0usize; + let mut max_up_block = 0usize; + let mut down_blocks_with_attn = HashSet::new(); + let mut up_blocks_with_attn = HashSet::new(); + + for key in keys.iter() { + // Check for down_blocks.X pattern + if let Some(rest) = key.strip_prefix("down_blocks.") { + if let Some(dot_pos) = rest.find('.') { + if let Ok(idx) = rest[..dot_pos].parse::() { + max_down_block = max_down_block.max(idx); + // Check if this block has attentions + if 
rest[dot_pos..].starts_with(".attentions.") { + down_blocks_with_attn.insert(idx); + } + } + } + } + // Check for up_blocks.X pattern + if let Some(rest) = key.strip_prefix("up_blocks.") { + if let Some(dot_pos) = rest.find('.') { + if let Ok(idx) = rest[..dot_pos].parse::() { + max_up_block = max_up_block.max(idx); + // Check if this block has attentions + if rest[dot_pos..].starts_with(".attentions.") { + up_blocks_with_attn.insert(idx); + } + } + } + } + } + + // Build boolean vectors + let down_has_attn: Vec = (0..=max_down_block) + .map(|i| down_blocks_with_attn.contains(&i)) + .collect(); + let up_has_attn: Vec = (0..=max_up_block) + .map(|i| up_blocks_with_attn.contains(&i)) + .collect(); + + Ok((down_has_attn, up_has_attn)) +} + +/// Load UNet weights from HuggingFace format with smart block detection. +/// +/// This function uses burn-store with `skip_enum_variants` to handle the enum-based +/// block types (UNetDownBlock, UNetUpBlock) without needing variant names in the +/// weight file keys. 
+/// +/// Main differences between HuggingFace and Burn naming: +/// - Mid block resnets: `mid_block.resnets.0/1` → `mid_block.resnet` / `mid_block.attn_resnets.0.resnet_block` +/// - Mid block attention: `mid_block.attentions.0` → `mid_block.attn_resnets.0.spatial_transformer` +/// - Downsampler: `downsamplers.0` → `downsampler` +/// - Upsampler: `upsamplers.0` → `upsampler` +/// - Cross-attention keys: `to_k/to_q/to_v` → `key/query/value` +/// - Cross-attention output: `to_out.0` → `output` +/// - FeedForward: `ff.net.0.proj` → `ff.geglu.proj`, `ff.net.2` → `ff.linear_outer` +/// +/// For blocks with cross-attention (CrossAttnDownBlock2D, CrossAttnUpBlock2D): +/// - resnets/downsamplers/upsamplers are nested under `downblock`/`upblock` +/// +/// For basic blocks (DownBlock2D, UpBlock2D): +/// - resnets/downsamplers/upsamplers are at the top level +/// +/// Uses `PyTorchToBurnAdapter` to automatically handle: +/// - Transposing linear layer weights +/// - Renaming normalization parameters (weight->gamma, bias->beta) +pub fn load_unet_safetensors( + mut module: M, + path: P, + _device: &B::Device, +) -> Result +where + B: Backend, + M: ModuleSnapshot, + P: AsRef, +{ + // Inspect the file to determine block types + let (down_has_attn, up_has_attn) = inspect_unet_block_types(path.as_ref())?; + + // Build key mappings for HuggingFace -> Burn structure + // Order matters: more specific patterns should come first + let mut key_mappings: Vec<(&str, &str)> = vec![ + // Mid block remappings + ("^mid_block\\.resnets\\.0\\.", "mid_block.resnet."), + ( + "^mid_block\\.resnets\\.1\\.", + "mid_block.attn_resnets.0.resnet_block.", + ), + ( + "^mid_block\\.attentions\\.([0-9]+)\\.", + "mid_block.attn_resnets.$1.spatial_transformer.", + ), + // Cross-attention key remapping + ("\\.to_k\\.", ".key."), + ("\\.to_q\\.", ".query."), + ("\\.to_v\\.", ".value."), + ("\\.to_out\\.0\\.", ".output."), + // FeedForward remapping + ("\\.ff\\.net\\.0\\.proj\\.", ".ff.geglu.proj."), + 
("\\.ff\\.net\\.2\\.", ".ff.linear_outer."), + ]; + + // Down block remappings - depends on whether block has attention + // We need to use owned strings for dynamic patterns + let mut dynamic_mappings: Vec<(String, String)> = Vec::new(); + + for (i, has_attn) in down_has_attn.iter().enumerate() { + if *has_attn { + // CrossAttnDownBlock2D: resnets/downsamplers nested under downblock + dynamic_mappings.push(( + format!("^down_blocks\\.{}\\.resnets\\.", i), + format!("down_blocks.{}.downblock.resnets.", i), + )); + dynamic_mappings.push(( + format!("^down_blocks\\.{}\\.downsamplers\\.0\\.", i), + format!("down_blocks.{}.downblock.downsampler.", i), + )); + // attentions stay at block level + } else { + // DownBlock2D: flat structure, just remap downsamplers.0 -> downsampler + dynamic_mappings.push(( + format!("^down_blocks\\.{}\\.downsamplers\\.0\\.", i), + format!("down_blocks.{}.downsampler.", i), + )); + } + } + + for (i, has_attn) in up_has_attn.iter().enumerate() { + if *has_attn { + // CrossAttnUpBlock2D: resnets/upsamplers nested under upblock + dynamic_mappings.push(( + format!("^up_blocks\\.{}\\.resnets\\.", i), + format!("up_blocks.{}.upblock.resnets.", i), + )); + dynamic_mappings.push(( + format!("^up_blocks\\.{}\\.upsamplers\\.0\\.", i), + format!("up_blocks.{}.upblock.upsampler.", i), + )); + // attentions stay at block level + } else { + // UpBlock2D: flat structure, just remap upsamplers.0 -> upsampler + dynamic_mappings.push(( + format!("^up_blocks\\.{}\\.upsamplers\\.0\\.", i), + format!("up_blocks.{}.upsampler.", i), + )); + } + } + + // Combine static and dynamic mappings + let dynamic_refs: Vec<(&str, &str)> = dynamic_mappings + .iter() + .map(|(a, b)| (a.as_str(), b.as_str())) + .collect(); + key_mappings.extend(dynamic_refs); + + let remapper = KeyRemapper::from_patterns(key_mappings) + .map_err(|e| WeightLoadError::SafetensorsLoad(e.to_string()))?; + + let checkpoint_path = path.as_ref().to_path_buf(); + let mut store = 
SafetensorsStore::from_file(checkpoint_path) + .with_from_adapter(PyTorchToBurnAdapter) + .remap(remapper) + .skip_enum_variants(true); // This is the key: skip enum variant names when matching paths + + module + .load_from(&mut store) + .map_err(|e| WeightLoadError::SafetensorsLoad(e.to_string()))?; + + Ok(module) +} + +/// Instructions for downloading Stable Diffusion weights. +/// +/// Returns a string with instructions for obtaining the necessary weight files. +pub fn download_instructions() -> &'static str { + r#" +# Downloading Stable Diffusion Weights + +## Option 1: From Hugging Face Hub (Recommended) + +1. Install huggingface-cli: + pip install huggingface_hub + +2. Download SD 1.5 weights: + huggingface-cli download runwayml/stable-diffusion-v1-5 \ + text_encoder/model.safetensors \ + vae/diffusion_pytorch_model.safetensors \ + unet/diffusion_pytorch_model.safetensors \ + --local-dir ./sd-v1-5 + +3. Or download SD 2.1 weights: + huggingface-cli download stabilityai/stable-diffusion-2-1 \ + text_encoder/model.safetensors \ + vae/diffusion_pytorch_model.safetensors \ + unet/diffusion_pytorch_model.safetensors \ + --local-dir ./sd-v2-1 + +## Option 2: Manual Download + +Visit the model pages on Hugging Face: +- SD 1.5: https://huggingface.co/runwayml/stable-diffusion-v1-5 +- SD 2.1: https://huggingface.co/stabilityai/stable-diffusion-2-1 + +Download the .safetensors files from the "Files" tab. 
+ +## BPE Vocabulary File + +You also need the BPE vocabulary file for the tokenizer: + wget https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz + gunzip bpe_simple_vocab_16e6.txt.gz + +## Directory Structure + +Your weights directory should look like: + model_dir/ + ├── text_encoder/ + │ └── model.safetensors + ├── vae/ + │ └── diffusion_pytorch_model.safetensors + ├── unet/ + │ └── diffusion_pytorch_model.safetensors + └── bpe_simple_vocab_16e6.txt +"# +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_download_instructions() { + let instructions = download_instructions(); + assert!(instructions.contains("Hugging Face")); + assert!(instructions.contains("safetensors")); + } +} diff --git a/src/schedulers/ddim.rs b/src/schedulers/ddim.rs new file mode 100644 index 0000000..3d78d5f --- /dev/null +++ b/src/schedulers/ddim.rs @@ -0,0 +1,569 @@ +//! # Denoising Diffusion Implicit Models +//! +//! The Denoising Diffusion Implicit Models (DDIM) is a simple scheduler +//! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM +//! generative process is the reverse of a Markovian process, DDIM generalizes +//! this to non-Markovian guidance. +//! +//! Denoising Diffusion Implicit Models, J. Song et al, 2020. +//! + +use alloc::vec::Vec; +use burn::tensor::{backend::Backend, Tensor}; + +use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; + +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; + +/// The configuration for the DDIM scheduler. +#[derive(Debug, Clone, Copy)] +pub struct DDIMSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// The amount of noise to be added at each step. + pub eta: f64, + /// Adjust the indexes of the inference schedule by this value. 
+ pub steps_offset: usize, + /// Prediction type of the scheduler function, one of `epsilon` (predicting + /// the noise of the diffusion process), `sample` (directly predicting the noisy sample) + /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + pub prediction_type: PredictionType, + /// Number of diffusion steps used to train the model. + pub train_timesteps: usize, +} + +impl Default for DDIMSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085, + beta_end: 0.012, + beta_schedule: BetaSchedule::ScaledLinear, + eta: 0.0, + steps_offset: 1, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + } + } +} + +/// The DDIM scheduler. +#[derive(Debug, Clone)] +pub struct DDIMScheduler { + timesteps: Vec, + alphas_cumprod: Vec, + step_ratio: usize, + init_noise_sigma: f64, + /// The configuration used to create this scheduler. + pub config: DDIMSchedulerConfig, +} + +impl DDIMScheduler { + /// Creates a new DDIM scheduler given the number of steps to be + /// used for inference as well as the number of steps that was used + /// during training. 
+ pub fn new( + inference_steps: usize, + config: DDIMSchedulerConfig, + device: &B::Device, + ) -> Self { + let step_ratio = config.train_timesteps / inference_steps; + + // Generate timesteps in reverse order + let timesteps: Vec = (0..inference_steps) + .map(|s| s * step_ratio + config.steps_offset) + .rev() + .collect(); + + // Compute betas based on schedule type + let betas: Tensor = match config.beta_schedule { + BetaSchedule::ScaledLinear => { + // linspace of sqrt(beta_start) to sqrt(beta_end), then squared + let start = config.beta_start.sqrt(); + let end = config.beta_end.sqrt(); + Tensor::from_floats( + linspace(start, end, config.train_timesteps).as_slice(), + device, + ) + .powf_scalar(2.0) + } + BetaSchedule::Linear => Tensor::from_floats( + linspace(config.beta_start, config.beta_end, config.train_timesteps).as_slice(), + device, + ), + BetaSchedule::SquaredcosCapV2 => { + betas_for_alpha_bar(config.train_timesteps, 0.999, device) + } + }; + + // alphas = 1 - betas + let alphas = betas.neg().add_scalar(1.0); + + // Compute cumulative product of alphas + let alphas_cumprod = cumprod_vec::(alphas); + + Self { + alphas_cumprod, + timesteps, + step_ratio, + init_noise_sigma: 1.0, + config, + } + } + + /// Returns the timesteps for the scheduler. + pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + pub fn scale_model_input( + &self, + sample: Tensor, + _timestep: usize, + ) -> Tensor { + sample + } + + /// Performs a backward step during inference. 
+ pub fn step( + &self, + model_output: &Tensor, + timestep: usize, + sample: &Tensor, + ) -> Tensor { + // Clamp timestep if needed + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + + // Calculate previous timestep + let prev_timestep = timestep.saturating_sub(self.step_ratio); + + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]; + let beta_prod_t = 1.0 - alpha_prod_t; + let beta_prod_t_prev = 1.0 - alpha_prod_t_prev; + + // Compute predicted original sample and epsilon based on prediction type + let (pred_original_sample, pred_epsilon) = match self.config.prediction_type { + PredictionType::Epsilon => { + // pred_original_sample = (sample - sqrt(beta_prod_t) * model_output) / sqrt(alpha_prod_t) + let pred_original_sample = sample + .clone() + .sub(model_output.clone().mul_scalar(beta_prod_t.sqrt())) + .div_scalar(alpha_prod_t.sqrt()); + (pred_original_sample, model_output.clone()) + } + PredictionType::VPrediction => { + // pred_original_sample = sqrt(alpha_prod_t) * sample - sqrt(beta_prod_t) * model_output + let pred_original_sample = sample + .clone() + .mul_scalar(alpha_prod_t.sqrt()) + .sub(model_output.clone().mul_scalar(beta_prod_t.sqrt())); + // pred_epsilon = sqrt(alpha_prod_t) * model_output + sqrt(beta_prod_t) * sample + let pred_epsilon = model_output + .clone() + .mul_scalar(alpha_prod_t.sqrt()) + .add(sample.clone().mul_scalar(beta_prod_t.sqrt())); + (pred_original_sample, pred_epsilon) + } + PredictionType::Sample => { + let pred_original_sample = model_output.clone(); + // pred_epsilon = (sample - sqrt(alpha_prod_t) * pred_original_sample) / sqrt(beta_prod_t) + let pred_epsilon = sample + .clone() + .sub(pred_original_sample.clone().mul_scalar(alpha_prod_t.sqrt())) + .div_scalar(beta_prod_t.sqrt()); + (pred_original_sample, pred_epsilon) + } + }; + + // Compute variance + let variance = (beta_prod_t_prev / beta_prod_t) * (1.0 - 
alpha_prod_t / alpha_prod_t_prev); + let std_dev_t = self.config.eta * variance.sqrt(); + + // pred_sample_direction = sqrt(1 - alpha_prod_t_prev - std_dev_t^2) * pred_epsilon + let pred_sample_direction = + pred_epsilon.mul_scalar((1.0 - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt()); + + // prev_sample = sqrt(alpha_prod_t_prev) * pred_original_sample + pred_sample_direction + let prev_sample = pred_original_sample + .mul_scalar(alpha_prod_t_prev.sqrt()) + .add(pred_sample_direction); + + // Add noise if eta > 0 + if self.config.eta > 0.0 { + let noise = + Tensor::random_like(&prev_sample, burn::tensor::Distribution::Normal(0.0, 1.0)); + prev_sample.add(noise.mul_scalar(std_dev_t)) + } else { + prev_sample + } + } + + /// Adds noise to original samples. + pub fn add_noise( + &self, + original: &Tensor, + noise: Tensor, + timestep: usize, + ) -> Tensor { + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + + let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt(); + let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt(); + + // sqrt_alpha_prod * original + sqrt_one_minus_alpha_prod * noise + original + .clone() + .mul_scalar(sqrt_alpha_prod) + .add(noise.mul_scalar(sqrt_one_minus_alpha_prod)) + } + + /// Returns the initial noise sigma. + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} + +/// Creates a vector of linearly spaced values. +fn linspace(start: f64, end: f64, steps: usize) -> Vec { + if steps == 0 { + return Vec::new(); + } + if steps == 1 { + return alloc::vec![start]; + } + let step_size = (end - start) / (steps - 1) as f64; + (0..steps).map(|i| start + step_size * i as f64).collect() +} + +/// Computes cumulative product and returns as Vec. 
fn cumprod_vec<B: Backend>(tensor: Tensor<B, 1>) -> Vec<f64> {
    let data = tensor.into_data();
    let values: Vec<f32> = data.to_vec().unwrap();

    // Accumulate in f64 to limit rounding error over ~1000 factors.
    let mut result = Vec::with_capacity(values.len());
    let mut acc = 1.0f64;
    for v in values {
        acc *= v as f64;
        result.push(acc);
    }
    result
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::prelude::ElementConversion;

    #[test]
    fn test_linspace() {
        let result = linspace(0.0, 1.0, 5);
        assert_eq!(result.len(), 5);
        assert!((result[0] - 0.0).abs() < 1e-10);
        assert!((result[4] - 1.0).abs() < 1e-10);
        assert!((result[2] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_ddim_scheduler_creation() {
        let device = Default::default();
        let config = DDIMSchedulerConfig::default();
        let scheduler = DDIMScheduler::new::<TestBackend>(50, config, &device);

        // Check timesteps are in descending order.
        let timesteps = scheduler.timesteps();
        assert_eq!(timesteps.len(), 50);
        assert!(timesteps[0] > timesteps[timesteps.len() - 1]);

        // First timestep should be 981 (49 * 20 + 1 with step_ratio=20, steps_offset=1).
        assert_eq!(timesteps[0], 981);

        // Last timestep should be 1 (0 * 20 + 1).
        assert_eq!(timesteps[timesteps.len() - 1], 1);
    }

    #[test]
    fn test_ddim_alphas_cumprod() {
        let device = Default::default();
        let config = DDIMSchedulerConfig::default();
        let scheduler = DDIMScheduler::new::<TestBackend>(50, config, &device);

        // alphas_cumprod should have train_timesteps entries.
        assert_eq!(scheduler.alphas_cumprod.len(), 1000);

        // First alpha_cumprod should be close to 1 (since beta_start is small).
        assert!(scheduler.alphas_cumprod[0] > 0.99);

        // Last alpha_cumprod should be small (accumulated product decreases).
        assert!(scheduler.alphas_cumprod[999] < 0.1);

        // Should be monotonically decreasing.
        for i in 1..scheduler.alphas_cumprod.len() {
            assert!(scheduler.alphas_cumprod[i] < scheduler.alphas_cumprod[i - 1]);
        }
    }

    #[test]
    fn test_ddim_add_noise() {
        let device = Default::default();
        let config = DDIMSchedulerConfig::default();
        let scheduler = DDIMScheduler::new::<TestBackend>(50, config, &device);

        let original: Tensor<TestBackend, 4> = Tensor::ones([1, 4, 64, 64], &device);
        let noise: Tensor<TestBackend, 4> = Tensor::zeros([1, 4, 64, 64], &device);

        // With zero noise at timestep 0, result should be close to
        // original * sqrt(alpha_cumprod[0]).
        let result = scheduler.add_noise(&original, noise, 0);
        let expected_scale = scheduler.alphas_cumprod[0].sqrt();

        // Check that the result is scaled correctly.
        let result_mean: f32 = result.mean().into_scalar().elem();
        assert!((result_mean as f64 - expected_scale).abs() < 1e-4);
    }

    #[test]
    fn test_ddim_step_epsilon_prediction() {
        let device = Default::default();
        let config = DDIMSchedulerConfig {
            prediction_type: PredictionType::Epsilon,
            ..Default::default()
        };
        let scheduler = DDIMScheduler::new::<TestBackend>(50, config, &device);

        // Simple test: zero model output (no predicted noise) should return scaled sample.
        let sample: Tensor<TestBackend, 4> = Tensor::ones([1, 4, 8, 8], &device);
        let model_output: Tensor<TestBackend, 4> = Tensor::zeros([1, 4, 8, 8], &device);

        let timestep = scheduler.timesteps()[0]; // First timestep
        let result = scheduler.step(&model_output, timestep, &sample);

        // Result should not be NaN or Inf.
        let result_data = result.into_data();
        let values: Vec<f32> = result_data.to_vec().unwrap();
        for v in &values {
            assert!(v.is_finite(), "Result contains non-finite values");
        }
    }

    #[test]
    fn test_linear_beta_schedule() {
        let device = <TestBackend as Backend>::Device::default();
        let config = DDIMSchedulerConfig {
            beta_schedule: BetaSchedule::Linear,
            ..Default::default()
        };
        let scheduler = DDIMScheduler::new::<TestBackend>(50, config, &device);

        // With linear schedule, alphas_cumprod should still be monotonically decreasing.
        for i in 1..scheduler.alphas_cumprod.len() {
            assert!(scheduler.alphas_cumprod[i] < scheduler.alphas_cumprod[i - 1]);
        }
    }

    #[test]
    fn test_init_noise_sigma() {
        let device = Default::default();
        let config = DDIMSchedulerConfig::default();
        let scheduler = DDIMScheduler::new::<TestBackend>(50, config, &device);

        assert_eq!(scheduler.init_noise_sigma(), 1.0);
    }

    /// Test that alphas_cumprod values match diffusers-rs within acceptable tolerance.
    /// Reference values derived from diffusers-rs v0.3.1 using add_noise with zero noise.
    /// Note: Small differences (~1e-4) are expected due to f32 precision in tensor ops.
    #[test]
    fn test_alphas_cumprod_matches_diffusers_rs() {
        let device = Default::default();
        let config = DDIMSchedulerConfig::default();
        let scheduler = DDIMScheduler::new::<TestBackend>(50, config, &device);

        // Expected values from diffusers-rs v0.3.1 with ScaledLinear beta schedule
        // beta_start=0.00085, beta_end=0.012, train_timesteps=1000.
        // Values derived by running: scheduler.add_noise(ones, zeros, t),
        // which yields sqrt(alpha_cumprod[t]).
        let expected_first = 0.999149980057211; // alphas_cumprod[0]
        let expected_at_500 = 0.276332449694738; // alphas_cumprod[500]
        let expected_last = 0.004660095778424; // alphas_cumprod[999]

        assert!(
            (scheduler.alphas_cumprod[0] - expected_first).abs() < 1e-4,
            "alphas_cumprod[0]: expected {}, got {}",
            expected_first,
            scheduler.alphas_cumprod[0]
        );
        assert!(
            (scheduler.alphas_cumprod[500] - expected_at_500).abs() < 1e-4,
            "alphas_cumprod[500]: expected {}, got {}",
            expected_at_500,
            scheduler.alphas_cumprod[500]
        );
        assert!(
            (scheduler.alphas_cumprod[999] - expected_last).abs() < 1e-4,
            "alphas_cumprod[999]: expected {}, got {}",
            expected_last,
            scheduler.alphas_cumprod[999]
        );
    }

    /// Test step() produces correct output matching diffusers-rs.
    /// Reference values from diffusers-rs v0.3.1.
+ #[test] + fn test_step_matches_diffusers_rs() { + let device = Default::default(); + let config = DDIMSchedulerConfig { + eta: 0.0, // Deterministic (no noise added) + ..Default::default() + }; + let scheduler = DDIMScheduler::new::(50, config, &device); + + // Use a simple known input + let sample: Tensor = Tensor::ones([1, 1, 2, 2], &device); + let model_output: Tensor = + Tensor::ones([1, 1, 2, 2], &device).mul_scalar(0.5); + + // Step at timestep 981 (first timestep with 50 inference steps) + // diffusers-rs returns: 1.061225771903992 + let result = scheduler.step(&model_output, 981, &sample); + let result_data = result.into_data(); + let values: Vec = result_data.to_vec().unwrap(); + let first_val = values[0]; + + let expected_981 = 1.061225771903992; + assert!( + (first_val as f64 - expected_981).abs() < 1e-3, + "Step at 981: expected {}, got {}", + expected_981, + first_val + ); + + // Step at timestep 500 + // diffusers-rs returns: 1.019660353660583 + let result2 = scheduler.step(&model_output, 500, &sample); + let result2_data = result2.into_data(); + let values2: Vec = result2_data.to_vec().unwrap(); + + let expected_500 = 1.019660353660583; + assert!( + (values2[0] as f64 - expected_500).abs() < 1e-3, + "Step at 500: expected {}, got {}", + expected_500, + values2[0] + ); + } + + /// Test add_noise produces correct output matching diffusers-rs. + /// Reference values from diffusers-rs v0.3.1. 
+ #[test] + fn test_add_noise_matches_diffusers_rs() { + let device = Default::default(); + let config = DDIMSchedulerConfig::default(); + let scheduler = DDIMScheduler::new::(50, config, &device); + + let original: Tensor = Tensor::ones([1, 1, 2, 2], &device); + + // Test at timestep 0 + // diffusers-rs returns: 1.028730034828186 + let noise0: Tensor = Tensor::ones([1, 1, 2, 2], &device); + let result0 = scheduler.add_noise(&original, noise0, 0); + let val0: f32 = result0.into_data().to_vec::().unwrap()[0]; + let expected_0 = 1.028730034828186; + assert!( + (val0 as f64 - expected_0).abs() < 1e-3, + "add_noise at 0: expected {}, got {}", + expected_0, + val0 + ); + + // Test at timestep 500 + // diffusers-rs returns: 1.376359820365906 + let noise500: Tensor = Tensor::ones([1, 1, 2, 2], &device); + let result500 = scheduler.add_noise(&original, noise500, 500); + let val500: f32 = result500.into_data().to_vec::().unwrap()[0]; + let expected_500 = 1.376359820365906; + assert!( + (val500 as f64 - expected_500).abs() < 1e-3, + "add_noise at 500: expected {}, got {}", + expected_500, + val500 + ); + + // Test at timestep 999 + // diffusers-rs returns: 1.065932154655457 + let noise999: Tensor = Tensor::ones([1, 1, 2, 2], &device); + let result999 = scheduler.add_noise(&original, noise999, 999); + let val999: f32 = result999.into_data().to_vec::().unwrap()[0]; + let expected_999 = 1.065932154655457; + assert!( + (val999 as f64 - expected_999).abs() < 1e-3, + "add_noise at 999: expected {}, got {}", + expected_999, + val999 + ); + } + + /// Test V-prediction mode produces valid output. 
+ #[test] + fn test_v_prediction_step() { + let device = Default::default(); + let config = DDIMSchedulerConfig { + prediction_type: PredictionType::VPrediction, + eta: 0.0, + ..Default::default() + }; + let scheduler = DDIMScheduler::new::(50, config, &device); + + let sample: Tensor = Tensor::ones([1, 1, 2, 2], &device); + let model_output: Tensor = + Tensor::ones([1, 1, 2, 2], &device).mul_scalar(0.5); + + let result = scheduler.step(&model_output, 981, &sample); + let result_data = result.into_data(); + let values: Vec = result_data.to_vec().unwrap(); + + // Result should be finite and uniform + for v in &values { + assert!(v.is_finite(), "V-prediction result should be finite"); + } + } + + /// Test Sample prediction mode produces valid output. + #[test] + fn test_sample_prediction_step() { + let device = Default::default(); + let config = DDIMSchedulerConfig { + prediction_type: PredictionType::Sample, + eta: 0.0, + ..Default::default() + }; + let scheduler = DDIMScheduler::new::(50, config, &device); + + let sample: Tensor = Tensor::ones([1, 1, 2, 2], &device); + let model_output: Tensor = + Tensor::ones([1, 1, 2, 2], &device).mul_scalar(0.5); + + let result = scheduler.step(&model_output, 981, &sample); + let result_data = result.into_data(); + let values: Vec = result_data.to_vec().unwrap(); + + // Result should be finite and uniform + for v in &values { + assert!(v.is_finite(), "Sample prediction result should be finite"); + } + } +} diff --git a/src/schedulers/ddpm.rs b/src/schedulers/ddpm.rs new file mode 100644 index 0000000..c83a2e4 --- /dev/null +++ b/src/schedulers/ddpm.rs @@ -0,0 +1,406 @@ +//! DDPM Scheduler +//! +//! Denoising Diffusion Probabilistic Models (DDPM) scheduler. +//! 
Based on the paper: https://arxiv.org/abs/2006.11239 + +use alloc::vec::Vec; +use burn::tensor::{backend::Backend, Distribution, Tensor}; + +use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; + +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; + +/// Variance type for DDPM scheduler. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DDPMVarianceType { + /// Fixed small variance. + FixedSmall, + /// Fixed small variance (log). + FixedSmallLog, + /// Fixed large variance. + FixedLarge, + /// Fixed large variance (log). + FixedLargeLog, + /// Learned variance. + Learned, +} + +impl Default for DDPMVarianceType { + fn default() -> Self { + Self::FixedSmall + } +} + +/// Configuration for the DDPM Scheduler. +#[derive(Debug, Clone)] +pub struct DDPMSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// Option to clip the predicted sample between -1 and 1 for numerical stability. + pub clip_sample: bool, + /// Option to clip the variance used when adding noise to the denoised sample. + pub variance_type: DDPMVarianceType, + /// Prediction type of the scheduler function. + pub prediction_type: PredictionType, + /// Number of diffusion steps used to train the model. + pub train_timesteps: usize, +} + +impl Default for DDPMSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085, + beta_end: 0.012, + beta_schedule: BetaSchedule::ScaledLinear, + clip_sample: false, + variance_type: DDPMVarianceType::FixedSmall, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + } + } +} + +/// DDPM Scheduler for diffusion models. +/// +/// This scheduler implements the DDPM algorithm for denoising diffusion models. 
+pub struct DDPMScheduler { + alphas_cumprod: Vec, + init_noise_sigma: f64, + timesteps: Vec, + step_ratio: usize, + /// The scheduler configuration. + pub config: DDPMSchedulerConfig, +} + +impl DDPMScheduler { + /// Create a new DDPM Scheduler. + /// + /// # Arguments + /// * `inference_steps` - Number of inference steps + /// * `config` - Scheduler configuration + /// * `device` - The device to create tensors on + pub fn new( + inference_steps: usize, + config: DDPMSchedulerConfig, + device: &B::Device, + ) -> Self { + let betas: Vec = match config.beta_schedule { + BetaSchedule::ScaledLinear => { + let start = config.beta_start.sqrt(); + let end = config.beta_end.sqrt(); + let step = (end - start) / (config.train_timesteps - 1) as f64; + (0..config.train_timesteps) + .map(|i| { + let v = start + step * i as f64; + v * v + }) + .collect() + } + BetaSchedule::Linear => { + let step = + (config.beta_end - config.beta_start) / (config.train_timesteps - 1) as f64; + (0..config.train_timesteps) + .map(|i| config.beta_start + step * i as f64) + .collect() + } + BetaSchedule::SquaredcosCapV2 => { + let betas_tensor: Tensor = + betas_for_alpha_bar(config.train_timesteps, 0.999, device); + let data = betas_tensor.into_data(); + data.to_vec::() + .unwrap() + .into_iter() + .map(|x| x as f64) + .collect() + } + }; + + // alphas = 1 - betas + let alphas: Vec = betas.iter().map(|b| 1.0 - b).collect(); + + // alphas_cumprod = cumprod(alphas) + let mut alphas_cumprod: Vec = Vec::with_capacity(config.train_timesteps); + let mut cumprod = 1.0; + for alpha in &alphas { + cumprod *= alpha; + alphas_cumprod.push(cumprod); + } + + // min(train_timesteps, inference_steps) + let inference_steps = inference_steps.min(config.train_timesteps); + + // Timesteps: arange(0, inference_steps) * step_ratio, reversed + let step_ratio = config.train_timesteps / inference_steps; + let timesteps: Vec = (0..inference_steps).map(|s| s * step_ratio).rev().collect(); + + Self { + alphas_cumprod, + 
init_noise_sigma: 1.0, + timesteps, + step_ratio, + config, + } + } + + /// Compute the variance for a given timestep. + fn get_variance(&self, timestep: usize) -> f64 { + let prev_t = timestep as isize - self.step_ratio as isize; + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = if prev_t >= 0 { + self.alphas_cumprod[prev_t as usize] + } else { + 1.0 + }; + let current_beta_t = 1.0 - alpha_prod_t / alpha_prod_t_prev; + + // For t > 0, compute predicted variance βt (see formula (6) and (7) from paper) + let variance = (1.0 - alpha_prod_t_prev) / (1.0 - alpha_prod_t) * current_beta_t; + + // Retrieve variance based on type + match self.config.variance_type { + DDPMVarianceType::FixedSmall => variance.max(1e-20), + DDPMVarianceType::FixedSmallLog => { + let variance = variance.max(1e-20).ln(); + (variance * 0.5).exp() + } + DDPMVarianceType::FixedLarge => current_beta_t, + DDPMVarianceType::FixedLargeLog => current_beta_t.ln(), + DDPMVarianceType::Learned => variance, + } + } + + /// Get the timesteps for the scheduler. + pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Get the initial noise sigma value. + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } + + /// Scale the model input (identity for DDPM). + pub fn scale_model_input( + &self, + sample: Tensor, + _timestep: usize, + ) -> Tensor { + sample + } + + /// Perform one step of the DDPM. + pub fn step( + &self, + model_output: &Tensor, + timestep: usize, + sample: &Tensor, + ) -> Tensor { + let prev_t = timestep as isize - self.step_ratio as isize; + + // 1. Compute alphas, betas + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = if prev_t >= 0 { + self.alphas_cumprod[prev_t as usize] + } else { + 1.0 + }; + let beta_prod_t = 1.0 - alpha_prod_t; + let beta_prod_t_prev = 1.0 - alpha_prod_t_prev; + let current_alpha_t = alpha_prod_t / alpha_prod_t_prev; + let current_beta_t = 1.0 - current_alpha_t; + + // 2. 
Compute predicted original sample from predicted noise (formula (15)) + let mut pred_original_sample = match self.config.prediction_type { + PredictionType::Epsilon => { + (sample.clone() - model_output.clone() * beta_prod_t.sqrt()) / alpha_prod_t.sqrt() + } + PredictionType::Sample => model_output.clone(), + PredictionType::VPrediction => { + sample.clone() * alpha_prod_t.sqrt() - model_output.clone() * beta_prod_t.sqrt() + } + }; + + // 3. Clip predicted x_0 + if self.config.clip_sample { + pred_original_sample = pred_original_sample.clamp(-1.0, 1.0); + } + + // 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + // See formula (7) from paper + let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t; + let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t; + + // 5. Compute predicted previous sample µ_t (formula (7)) + let pred_prev_sample = pred_original_sample * pred_original_sample_coeff + + sample.clone() * current_sample_coeff; + + // 6. Add noise + if timestep > 0 { + let device = sample.device(); + let variance_noise: Tensor = Tensor::random( + model_output.shape(), + Distribution::Normal(0.0, 1.0), + &device, + ); + + let variance = if self.config.variance_type == DDPMVarianceType::FixedSmallLog { + self.get_variance(timestep) * variance_noise + } else { + self.get_variance(timestep).sqrt() * variance_noise + }; + pred_prev_sample + variance + } else { + pred_prev_sample + } + } + + /// Add noise to original samples. 
+ pub fn add_noise( + &self, + original_samples: &Tensor, + noise: Tensor, + timestep: usize, + ) -> Tensor { + let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt(); + let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt(); + + original_samples.clone() * sqrt_alpha_prod + noise * sqrt_one_minus_alpha_prod + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + use burn::tensor::Shape; + + #[test] + fn test_ddpm_scheduler_creation() { + let device = Default::default(); + let config = DDPMSchedulerConfig::default(); + let scheduler = DDPMScheduler::new::(20, config, &device); + + assert_eq!(scheduler.timesteps().len(), 20); + assert_eq!(scheduler.init_noise_sigma(), 1.0); + } + + #[test] + fn test_ddpm_timesteps() { + let device = Default::default(); + let config = DDPMSchedulerConfig::default(); + let scheduler = DDPMScheduler::new::(20, config, &device); + + let timesteps = scheduler.timesteps(); + // First timestep should be high, last should be 0 + assert!(timesteps[0] > timesteps[timesteps.len() - 1]); + assert_eq!(timesteps[timesteps.len() - 1], 0); + + // Should be monotonically decreasing + for i in 1..timesteps.len() { + assert!(timesteps[i] < timesteps[i - 1]); + } + } + + #[test] + fn test_ddpm_scale_model_input() { + let device = Default::default(); + let config = DDPMSchedulerConfig::default(); + let scheduler = DDPMScheduler::new::(20, config, &device); + + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + + // DDPM doesn't scale input + let scaled = scheduler.scale_model_input(sample.clone(), timestep); + let diff: f32 = (scaled - sample).abs().mean().into_scalar(); + assert!(diff < 1e-6); + } + + #[test] + fn test_ddpm_step() { + let device = Default::default(); + let config = DDPMSchedulerConfig::default(); + let scheduler = DDPMScheduler::new::(20, config, &device); + + let model_output: Tensor = Tensor::zeros([1, 4, 8, 8], &device); + let 
sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + + // Step at timestep 0 (no noise added) + let result = scheduler.step(&model_output, 0, &sample); + assert_eq!(result.shape(), Shape::from([1, 4, 8, 8])); + + // Result should be finite + let result_data = result.into_data(); + let values: Vec = result_data.to_vec().unwrap(); + for v in &values { + assert!(v.is_finite(), "Result contains non-finite values"); + } + } + + /// Test DDPM scheduler values match diffusers-rs + #[test] + fn test_ddpm_matches_diffusers_rs() { + let device = Default::default(); + let config = DDPMSchedulerConfig::default(); + let scheduler = DDPMScheduler::new::(20, config, &device); + + // Reference values from diffusers-rs + // Timesteps: [950, 900, 850, 800, 750, 700, 650, 600, 550, 500, 450, 400, 350, 300, 250, 200, 150, 100, 50, 0] + let expected_timesteps = [ + 950, 900, 850, 800, 750, 700, 650, 600, 550, 500, 450, 400, 350, 300, 250, 200, 150, + 100, 50, 0, + ]; + + let timesteps = scheduler.timesteps(); + assert_eq!(timesteps.len(), expected_timesteps.len()); + for (i, (actual, expected)) in timesteps.iter().zip(expected_timesteps.iter()).enumerate() { + assert_eq!( + *actual, *expected, + "Timestep mismatch at {}: actual={}, expected={}", + i, actual, expected + ); + } + + // Check init_noise_sigma (reference: 1.0) + assert_eq!(scheduler.init_noise_sigma(), 1.0); + + // Check scale_model_input (reference mean: 1.0 - identity) + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + let scaled = scheduler.scale_model_input(sample, timestep); + let scaled_mean: f32 = scaled.mean().into_scalar(); + assert!( + (scaled_mean - 1.0).abs() < 1e-6, + "scale_model_input mean mismatch: actual={}, expected=1.0", + scaled_mean + ); + + // Check step at timestep 0 (reference mean: 1.0004253387451172) + let model_output: Tensor = Tensor::zeros([1, 4, 8, 8], &device); + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let result = 
scheduler.step(&model_output, 0, &sample); + let result_mean: f32 = result.mean().into_scalar(); + assert!( + (result_mean as f64 - 1.0004253387451172).abs() < 1e-4, + "step mean mismatch: actual={}, expected=1.0004253387451172", + result_mean + ); + + // Check add_noise (reference mean: 1.0866814851760864 at timestep 950) + let original: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let noise: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let noisy = scheduler.add_noise(&original, noise, 950); + let noisy_mean: f32 = noisy.mean().into_scalar(); + assert!( + (noisy_mean as f64 - 1.0866814851760864).abs() < 1e-4, + "add_noise mean mismatch: actual={}, expected=1.0866814851760864", + noisy_mean + ); + } +} diff --git a/src/schedulers/dpmsolver_multistep.rs b/src/schedulers/dpmsolver_multistep.rs new file mode 100644 index 0000000..5cb7066 --- /dev/null +++ b/src/schedulers/dpmsolver_multistep.rs @@ -0,0 +1,701 @@ +//! DPM-Solver++ Multistep Scheduler +//! +//! DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver +//! for diffusion ODEs with the convergence order guarantee. +//! +//! Based on: +//! - DPM-Solver: +//! - DPM-Solver++: + +use alloc::vec; +use alloc::vec::Vec; +use burn::tensor::{backend::Backend, Tensor}; + +use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; + +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; + +/// The algorithm type for the solver. +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +pub enum DPMSolverAlgorithmType { + /// Implements the algorithms defined in . + #[default] + DPMSolverPlusPlus, + /// Implements the algorithms defined in . + DPMSolver, +} + +/// The solver type for the second-order solver. +/// The solver type slightly affects the sample quality, especially for +/// small number of steps. 
+#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +pub enum DPMSolverType { + #[default] + Midpoint, + Heun, +} + +/// Configuration for the DPM-Solver++ Multistep Scheduler. +#[derive(Debug, Clone)] +pub struct DPMSolverMultistepSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// Number of diffusion steps used to train the model. + pub train_timesteps: usize, + /// The order of DPM-Solver; can be 1, 2, or 3. We recommend solver_order=2 for guided + /// sampling, and solver_order=3 for unconditional sampling. + pub solver_order: usize, + /// Prediction type of the scheduler function. + pub prediction_type: PredictionType, + /// Whether to use the "dynamic thresholding" method (introduced by Imagen). + /// For pixel-space diffusion models, you can set both `algorithm_type=DPMSolverPlusPlus` + /// and `thresholding=true` to use dynamic thresholding. Note that thresholding is + /// unsuitable for latent-space diffusion models (such as stable-diffusion). + pub thresholding: bool, + /// The ratio for the dynamic thresholding method. Default is 0.995, same as Imagen. + pub dynamic_thresholding_ratio: f64, + /// The threshold value for dynamic thresholding. Valid only when `thresholding: true` + /// and `algorithm_type: DPMSolverPlusPlus`. + pub sample_max_value: f64, + /// The algorithm type for the solver. + pub algorithm_type: DPMSolverAlgorithmType, + /// The solver type for the second-order solver. + pub solver_type: DPMSolverType, + /// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference + /// steps. This can stabilize the sampling of DPM-Solver for steps < 15. 
+ pub lower_order_final: bool, +} + +impl Default for DPMSolverMultistepSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085, + beta_end: 0.012, + beta_schedule: BetaSchedule::ScaledLinear, + train_timesteps: 1000, + solver_order: 2, + prediction_type: PredictionType::Epsilon, + thresholding: false, + dynamic_thresholding_ratio: 0.995, + sample_max_value: 1.0, + algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus, + solver_type: DPMSolverType::Midpoint, + lower_order_final: true, + } + } +} + +/// DPM-Solver++ Multistep Scheduler for diffusion models. +/// +/// This scheduler implements DPM-Solver and DPM-Solver++ algorithms for fast +/// sampling in diffusion models. +pub struct DPMSolverMultistepScheduler<B: Backend> { + alphas_cumprod: Vec<f64>, + alpha_t: Vec<f64>, + sigma_t: Vec<f64>, + lambda_t: Vec<f64>, + init_noise_sigma: f64, + lower_order_nums: usize, + model_outputs: Vec<Option<Tensor<B, 4>>>, + timesteps: Vec<usize>, + /// The scheduler configuration. + pub config: DPMSolverMultistepSchedulerConfig, +} + +impl<B: Backend> DPMSolverMultistepScheduler<B> { + /// Create a new DPM-Solver++ Multistep Scheduler. 
+ /// + /// # Arguments + /// * `inference_steps` - Number of inference steps + /// * `config` - Scheduler configuration + /// * `device` - The device to create tensors on + pub fn new( + inference_steps: usize, + config: DPMSolverMultistepSchedulerConfig, + device: &B::Device, + ) -> Self { + let betas: Vec = match config.beta_schedule { + BetaSchedule::ScaledLinear => { + let start = config.beta_start.sqrt(); + let end = config.beta_end.sqrt(); + let step = (end - start) / (config.train_timesteps - 1) as f64; + (0..config.train_timesteps) + .map(|i| { + let v = start + step * i as f64; + v * v + }) + .collect() + } + BetaSchedule::Linear => { + let step = + (config.beta_end - config.beta_start) / (config.train_timesteps - 1) as f64; + (0..config.train_timesteps) + .map(|i| config.beta_start + step * i as f64) + .collect() + } + BetaSchedule::SquaredcosCapV2 => { + let betas_tensor: Tensor = + betas_for_alpha_bar(config.train_timesteps, 0.999, device); + let data = betas_tensor.into_data(); + data.to_vec::() + .unwrap() + .into_iter() + .map(|x| x as f64) + .collect() + } + }; + + // alphas = 1 - betas + let alphas: Vec = betas.iter().map(|b| 1.0 - b).collect(); + + // alphas_cumprod = cumprod(alphas) + let mut alphas_cumprod: Vec = Vec::with_capacity(config.train_timesteps); + let mut cumprod = 1.0; + for alpha in &alphas { + cumprod *= alpha; + alphas_cumprod.push(cumprod); + } + + // alpha_t = sqrt(alphas_cumprod) + let alpha_t: Vec = alphas_cumprod.iter().map(|&acp| acp.sqrt()).collect(); + + // sigma_t = sqrt(1 - alphas_cumprod) + let sigma_t: Vec = alphas_cumprod + .iter() + .map(|&acp| (1.0 - acp).sqrt()) + .collect(); + + // lambda_t = log(alpha_t) - log(sigma_t) + let lambda_t: Vec = alpha_t + .iter() + .zip(sigma_t.iter()) + .map(|(&a, &s)| a.ln() - s.ln()) + .collect(); + + // timesteps = linspace(train_timesteps - 1, 0, inference_steps + 1), skip first, reverse + let step = (config.train_timesteps - 1) as f64 / inference_steps as f64; + let mut 
timesteps: Vec = (0..=inference_steps) + .map(|i| (i as f64 * step).round() as usize) + .skip(1) + .collect(); + timesteps.reverse(); + + // Create a vector of solver_order None tensors for model outputs + let model_outputs: Vec>> = vec![None; config.solver_order]; + + Self { + alphas_cumprod, + alpha_t, + sigma_t, + lambda_t, + init_noise_sigma: 1.0, + lower_order_nums: 0, + model_outputs, + timesteps, + config, + } + } + + /// Convert the model output to the corresponding type that the algorithm needs. + /// + /// DPM-Solver is designed to discretize an integral of the noise prediction model, + /// and DPM-Solver++ is designed to discretize an integral of the data prediction model. + fn convert_model_output( + &self, + model_output: &Tensor, + timestep: usize, + sample: &Tensor, + ) -> Tensor { + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => { + let x0_pred = match self.config.prediction_type { + PredictionType::Epsilon => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + // (sample - sigma_t * model_output) / alpha_t + (sample.clone() - model_output.clone() * sigma_t) / alpha_t + } + PredictionType::Sample => model_output.clone(), + PredictionType::VPrediction => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + // alpha_t * sample - sigma_t * model_output + sample.clone() * alpha_t - model_output.clone() * sigma_t + } + }; + + // Note: thresholding is not implemented for burn tensors + // as it requires quantile operations not available in burn + if self.config.thresholding { + // For now, just return x0_pred without thresholding + // In a full implementation, you would need to add quantile support + x0_pred + } else { + x0_pred + } + } + DPMSolverAlgorithmType::DPMSolver => match self.config.prediction_type { + PredictionType::Epsilon => model_output.clone(), + PredictionType::Sample => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = 
self.sigma_t[timestep]; + // (sample - alpha_t * model_output) / sigma_t + (sample.clone() - model_output.clone() * alpha_t) / sigma_t + } + PredictionType::VPrediction => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + // alpha_t * model_output + sigma_t * sample + model_output.clone() * alpha_t + sample.clone() * sigma_t + } + }, + } + } + + /// One step for the first-order DPM-Solver (equivalent to DDIM). + fn dpm_solver_first_order_update( + &self, + model_output: Tensor, + timestep: usize, + prev_timestep: usize, + sample: &Tensor, + ) -> Tensor { + let (lambda_t, lambda_s) = (self.lambda_t[prev_timestep], self.lambda_t[timestep]); + let (alpha_t, _alpha_s) = (self.alpha_t[prev_timestep], self.alpha_t[timestep]); + let (sigma_t, sigma_s) = (self.sigma_t[prev_timestep], self.sigma_t[timestep]); + let h = lambda_t - lambda_s; + + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => { + // (sigma_t / sigma_s) * sample - (alpha_t * (exp(-h) - 1)) * model_output + sample.clone() * (sigma_t / sigma_s) - model_output * (alpha_t * ((-h).exp() - 1.0)) + } + DPMSolverAlgorithmType::DPMSolver => { + let alpha_s = self.alpha_t[timestep]; + // (alpha_t / alpha_s) * sample - (sigma_t * (exp(h) - 1)) * model_output + sample.clone() * (alpha_t / alpha_s) - model_output * (sigma_t * (h.exp() - 1.0)) + } + } + } + + /// One step for the second-order multistep DPM-Solver. 
+ fn multistep_dpm_solver_second_order_update( + &self, + model_output_list: &[Option>], + timestep_list: [usize; 2], + prev_timestep: usize, + sample: &Tensor, + ) -> Tensor { + let (t, s0, s1) = ( + prev_timestep, + timestep_list[timestep_list.len() - 1], + timestep_list[timestep_list.len() - 2], + ); + + let m0 = model_output_list[model_output_list.len() - 1] + .as_ref() + .unwrap(); + let m1 = model_output_list[model_output_list.len() - 2] + .as_ref() + .unwrap(); + + let (lambda_t, lambda_s0, lambda_s1) = + (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]); + let (alpha_t, alpha_s0) = (self.alpha_t[t], self.alpha_t[s0]); + let (sigma_t, sigma_s0) = (self.sigma_t[t], self.sigma_t[s0]); + let (h, h_0) = (lambda_t - lambda_s0, lambda_s0 - lambda_s1); + let r0 = h_0 / h; + let d0 = m0; + let d1 = (m0.clone() - m1.clone()) * (1.0 / r0); + + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => match self.config.solver_type { + DPMSolverType::Midpoint => { + // (sigma_t / sigma_s0) * sample + // - (alpha_t * (exp(-h) - 1)) * d0 + // - 0.5 * (alpha_t * (exp(-h) - 1)) * d1 + let coeff = alpha_t * ((-h).exp() - 1.0); + sample.clone() * (sigma_t / sigma_s0) - d0.clone() * coeff - d1 * (0.5 * coeff) + } + DPMSolverType::Heun => { + // (sigma_t / sigma_s0) * sample + // - (alpha_t * (exp(-h) - 1)) * d0 + // + (alpha_t * ((exp(-h) - 1) / h + 1)) * d1 + let exp_neg_h = (-h).exp(); + sample.clone() * (sigma_t / sigma_s0) + - d0.clone() * (alpha_t * (exp_neg_h - 1.0)) + + d1 * (alpha_t * ((exp_neg_h - 1.0) / h + 1.0)) + } + }, + DPMSolverAlgorithmType::DPMSolver => match self.config.solver_type { + DPMSolverType::Midpoint => { + let coeff = sigma_t * (h.exp() - 1.0); + sample.clone() * (alpha_t / alpha_s0) - d0.clone() * coeff - d1 * (0.5 * coeff) + } + DPMSolverType::Heun => { + let exp_h = h.exp(); + sample.clone() * (alpha_t / alpha_s0) + - d0.clone() * (sigma_t * (exp_h - 1.0)) + - d1 * (sigma_t * ((exp_h - 1.0) / h - 1.0)) + } + }, + } 
+ } + + /// One step for the third-order multistep DPM-Solver. + fn multistep_dpm_solver_third_order_update( + &self, + model_output_list: &[Option>], + timestep_list: [usize; 3], + prev_timestep: usize, + sample: &Tensor, + ) -> Tensor { + let (t, s0, s1, s2) = ( + prev_timestep, + timestep_list[timestep_list.len() - 1], + timestep_list[timestep_list.len() - 2], + timestep_list[timestep_list.len() - 3], + ); + + let m0 = model_output_list[model_output_list.len() - 1] + .as_ref() + .unwrap(); + let m1 = model_output_list[model_output_list.len() - 2] + .as_ref() + .unwrap(); + let m2 = model_output_list[model_output_list.len() - 3] + .as_ref() + .unwrap(); + + let (lambda_t, lambda_s0, lambda_s1, lambda_s2) = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ); + let (alpha_t, alpha_s0) = (self.alpha_t[t], self.alpha_t[s0]); + let (sigma_t, sigma_s0) = (self.sigma_t[t], self.sigma_t[s0]); + let (h, h_0, h_1) = ( + lambda_t - lambda_s0, + lambda_s0 - lambda_s1, + lambda_s1 - lambda_s2, + ); + let (r0, r1) = (h_0 / h, h_1 / h); + + let d0 = m0; + let d1_0 = (m0.clone() - m1.clone()) * (1.0 / r0); + let d1_1 = (m1.clone() - m2.clone()) * (1.0 / r1); + let d1 = d1_0.clone() + (d1_0.clone() - d1_1) * (r0 / (r0 + r1)); + let d2 = (d1_0.clone() - (m1.clone() - m2.clone()) * (1.0 / r1)) * (1.0 / (r0 + r1)); + + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => { + let exp_neg_h = (-h).exp(); + sample.clone() * (sigma_t / sigma_s0) - d0.clone() * (alpha_t * (exp_neg_h - 1.0)) + + d1 * (alpha_t * ((exp_neg_h - 1.0) / h + 1.0)) + - d2 * (alpha_t * ((exp_neg_h - 1.0 + h) / h.powi(2) - 0.5)) + } + DPMSolverAlgorithmType::DPMSolver => { + let exp_h = h.exp(); + sample.clone() * (alpha_t / alpha_s0) + - d0.clone() * (sigma_t * (exp_h - 1.0)) + - d1 * (sigma_t * ((exp_h - 1.0) / h - 1.0)) + - d2 * (sigma_t * ((exp_h - 1.0 - h) / h.powi(2) - 0.5)) + } + } + } + + /// Get the timesteps for the scheduler. 
+ pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Get the initial noise sigma value. + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } + + /// Scale the model input (no scaling needed for DPM-Solver). + pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { + sample + } + + /// Perform one step of the DPM-Solver. + pub fn step( + &mut self, + model_output: &Tensor, + timestep: usize, + sample: &Tensor, + ) -> Tensor { + let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); + + let prev_timestep = if step_index == self.timesteps.len() - 1 { + 0 + } else { + self.timesteps[step_index + 1] + }; + + let lower_order_final = (step_index == self.timesteps.len() - 1) + && self.config.lower_order_final + && self.timesteps.len() < 15; + let lower_order_second = (step_index == self.timesteps.len() - 2) + && self.config.lower_order_final + && self.timesteps.len() < 15; + + let model_output = self.convert_model_output(model_output, timestep, sample); + + // Shift model outputs + for i in 0..self.config.solver_order - 1 { + self.model_outputs[i] = self.model_outputs[i + 1].take(); + } + // Store the latest model output + let m = self.model_outputs.len(); + self.model_outputs[m - 1] = Some(model_output.clone()); + + let prev_sample = if self.config.solver_order == 1 + || self.lower_order_nums < 1 + || lower_order_final + { + self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) + } else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second { + let timestep_list = [self.timesteps[step_index - 1], timestep]; + self.multistep_dpm_solver_second_order_update( + &self.model_outputs, + timestep_list, + prev_timestep, + sample, + ) + } else { + let timestep_list = [ + self.timesteps[step_index - 2], + self.timesteps[step_index - 1], + timestep, + ]; + self.multistep_dpm_solver_third_order_update( + &self.model_outputs, + timestep_list, 
+ prev_timestep, + sample, + ) + }; + + if self.lower_order_nums < self.config.solver_order { + self.lower_order_nums += 1; + } + + prev_sample + } + + /// Add noise to original samples. + pub fn add_noise( + &self, + original_samples: &Tensor, + noise: Tensor, + timestep: usize, + ) -> Tensor { + let sqrt_alpha_cumprod = self.alphas_cumprod[timestep].sqrt(); + let sqrt_one_minus_alpha_cumprod = (1.0 - self.alphas_cumprod[timestep]).sqrt(); + + original_samples.clone() * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + use burn::tensor::Shape; + + #[test] + fn test_dpmsolver_multistep_scheduler_creation() { + let device = Default::default(); + let config = DPMSolverMultistepSchedulerConfig::default(); + let scheduler = DPMSolverMultistepScheduler::::new(20, config, &device); + + assert_eq!(scheduler.timesteps().len(), 20); + assert_eq!(scheduler.init_noise_sigma(), 1.0); + } + + #[test] + fn test_dpmsolver_timesteps() { + let device = Default::default(); + let config = DPMSolverMultistepSchedulerConfig::default(); + let scheduler = DPMSolverMultistepScheduler::::new(20, config, &device); + + let timesteps = scheduler.timesteps(); + // First timestep should be close to train_timesteps - 1 + assert!(timesteps[0] > 900); + // Last timestep should be small + assert!(timesteps[timesteps.len() - 1] < 100); + + // Should be monotonically decreasing + for i in 1..timesteps.len() { + assert!(timesteps[i] < timesteps[i - 1]); + } + } + + #[test] + fn test_dpmsolver_scale_model_input() { + let device = Default::default(); + let config = DPMSolverMultistepSchedulerConfig::default(); + let scheduler = DPMSolverMultistepScheduler::::new(20, config, &device); + + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + + // DPM-Solver doesn't scale input + let scaled = scheduler.scale_model_input(sample.clone(), timestep); + let diff: f32 = 
(scaled - sample).abs().mean().into_scalar(); + assert!(diff < 1e-6); + } + + #[test] + fn test_dpmsolver_step() { + let device = Default::default(); + let config = DPMSolverMultistepSchedulerConfig::default(); + let mut scheduler = DPMSolverMultistepScheduler::::new(20, config, &device); + + let model_output: Tensor = Tensor::zeros([1, 4, 8, 8], &device); + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + + let result = scheduler.step(&model_output, timestep, &sample); + assert_eq!(result.shape(), Shape::from([1, 4, 8, 8])); + } + + #[test] + fn test_dpmsolver_add_noise() { + let device = Default::default(); + let config = DPMSolverMultistepSchedulerConfig::default(); + let scheduler = DPMSolverMultistepScheduler::::new(20, config, &device); + + let original: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let noise: Tensor = Tensor::zeros([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + + let noisy = scheduler.add_noise(&original, noise, timestep); + assert_eq!(noisy.shape(), Shape::from([1, 4, 8, 8])); + } + + #[test] + fn test_dpmsolver_multiple_steps() { + let device = Default::default(); + let config = DPMSolverMultistepSchedulerConfig { + solver_order: 2, + ..Default::default() + }; + let mut scheduler = DPMSolverMultistepScheduler::::new(20, config, &device); + + let mut sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + + // Run multiple steps to test the multistep logic + for i in 0..5 { + let timestep = scheduler.timesteps()[i]; + let model_output: Tensor = Tensor::zeros([1, 4, 8, 8], &device); + sample = scheduler.step(&model_output, timestep, &sample); + } + + // After multiple steps with zero model output, sample should still be finite + let sample_data = sample.into_data(); + let values: Vec = sample_data.to_vec().unwrap(); + for v in &values { + assert!(v.is_finite(), "Sample contains non-finite values"); + } + } + + /// Test DPM-Solver++ values match diffusers-rs + 
#[test] + fn test_dpmsolver_matches_diffusers_rs() { + let device = Default::default(); + let config = DPMSolverMultistepSchedulerConfig::default(); + let scheduler = DPMSolverMultistepScheduler::::new(20, config, &device); + + // Reference values from diffusers-rs + // step = 999/20 = 49.95 + // Timesteps computed as: (i * step).round() for i in 1..=20, then reversed + let expected_timesteps = [ + 999, 949, 899, 849, 799, 749, 699, 649, 599, 549, 500, 450, 400, 350, 300, 250, 200, + 150, 100, 50, + ]; + + let timesteps = scheduler.timesteps(); + assert_eq!(timesteps.len(), expected_timesteps.len()); + for (i, (actual, expected)) in timesteps.iter().zip(expected_timesteps.iter()).enumerate() { + assert_eq!( + *actual, *expected, + "Timestep mismatch at {}: actual={}, expected={}", + i, actual, expected + ); + } + + // Check init_noise_sigma + assert_eq!(scheduler.init_noise_sigma(), 1.0); + + // Check alphas_cumprod at key positions + // Reference: diffusers-rs alphas_cumprod[0] ≈ 0.9991 + assert!( + (scheduler.alphas_cumprod[0] - 0.9991499800572107).abs() < 1e-4, + "alphas_cumprod[0] mismatch: {}", + scheduler.alphas_cumprod[0] + ); + + // Reference: diffusers-rs alphas_cumprod[999] ≈ 0.0047 + assert!( + (scheduler.alphas_cumprod[999] - 0.004660095977824908).abs() < 1e-4, + "alphas_cumprod[999] mismatch: {}", + scheduler.alphas_cumprod[999] + ); + } + + /// Test DPM-Solver++ step matches diffusers-rs reference values + #[test] + fn test_dpmsolver_step_matches_diffusers_rs() { + let device = Default::default(); + let config = DPMSolverMultistepSchedulerConfig::default(); + let mut scheduler = DPMSolverMultistepScheduler::::new(20, config, &device); + + // Use simple known inputs + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let model_output: Tensor = Tensor::zeros([1, 4, 8, 8], &device); + + // First step at timestep 999 + let timestep = scheduler.timesteps()[0]; + let result = scheduler.step(&model_output, timestep, &sample); + let result_mean: f32 
= result.mean().into_scalar(); + + // Reference from diffusers-rs: step(zeros, timestep=999, ones): mean=1.3377922773361206 + assert!( + (result_mean as f64 - 1.3377922773361206).abs() < 1e-4, + "Step mean mismatch: actual={}, expected=1.3377922773361206", + result_mean + ); + } + + /// Test DPM-Solver++ add_noise matches diffusers-rs reference values + #[test] + fn test_dpmsolver_add_noise_matches_diffusers_rs() { + let device = Default::default(); + let config = DPMSolverMultistepSchedulerConfig::default(); + let scheduler = DPMSolverMultistepScheduler::::new(20, config, &device); + + let original: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let noise: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; // 999 + + let noisy = scheduler.add_noise(&original, noise, timestep); + let noisy_mean: f32 = noisy.mean().into_scalar(); + + // Reference from diffusers-rs: add_noise(ones, ones, timestep=999): mean=1.0659321546554565 + assert!( + (noisy_mean as f64 - 1.0659321546554565).abs() < 1e-4, + "add_noise mean mismatch: actual={}, expected=1.0659321546554565", + noisy_mean + ); + } +} diff --git a/src/schedulers/euler_ancestral_discrete.rs b/src/schedulers/euler_ancestral_discrete.rs new file mode 100644 index 0000000..e8bcab9 --- /dev/null +++ b/src/schedulers/euler_ancestral_discrete.rs @@ -0,0 +1,432 @@ +//! Euler Ancestral Discrete Scheduler +//! +//! Ancestral sampling with Euler method steps. +//! Based on the original k-diffusion implementation by Katherine Crowson: +//! https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 + +use alloc::vec; +use alloc::vec::Vec; +use burn::tensor::{backend::Backend, Distribution, Tensor}; + +use super::{BetaSchedule, PredictionType}; + +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; + +/// Configuration for the Euler Ancestral Discrete Scheduler. 
+#[derive(Debug, Clone)] +pub struct EulerAncestralDiscreteSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// Number of diffusion steps used to train the model. + pub train_timesteps: usize, + /// Prediction type of the scheduler function. + pub prediction_type: PredictionType, +} + +impl Default for EulerAncestralDiscreteSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085, + beta_end: 0.012, + beta_schedule: BetaSchedule::ScaledLinear, + train_timesteps: 1000, + prediction_type: PredictionType::Epsilon, + } + } +} + +/// Euler Ancestral Discrete Scheduler for diffusion models. +/// +/// This scheduler implements ancestral sampling with Euler method steps, +/// adding noise at each step for stochastic sampling. +#[derive(Debug, Clone)] +pub struct EulerAncestralDiscreteScheduler { + timesteps: Vec, + sigmas: Vec, + init_noise_sigma: f64, + /// The scheduler configuration. + pub config: EulerAncestralDiscreteSchedulerConfig, +} + +impl EulerAncestralDiscreteScheduler { + /// Create a new Euler Ancestral Discrete Scheduler. 
+ /// + /// # Arguments + /// * `inference_steps` - Number of inference steps + /// * `config` - Scheduler configuration + pub fn new(inference_steps: usize, config: EulerAncestralDiscreteSchedulerConfig) -> Self { + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => { + let start = config.beta_start.sqrt(); + let end = config.beta_end.sqrt(); + let step = (end - start) / (config.train_timesteps - 1) as f64; + let betas: Vec<f64> = (0..config.train_timesteps) + .map(|i| { + let v = start + step * i as f64; + v * v + }) + .collect(); + betas + } + BetaSchedule::Linear => { + let step = + (config.beta_end - config.beta_start) / (config.train_timesteps - 1) as f64; + (0..config.train_timesteps) + .map(|i| config.beta_start + step * i as f64) + .collect() + } + BetaSchedule::SquaredcosCapV2 => { + unimplemented!( + "EulerAncestralDiscreteScheduler only implements linear and scaled_linear betas." + ) + } + }; + + // alphas = 1 - betas + let alphas: Vec<f64> = betas.iter().map(|b| 1.0 - b).collect(); + + // alphas_cumprod = cumprod(alphas) + let mut alphas_cumprod: Vec<f64> = Vec::with_capacity(config.train_timesteps); + let mut cumprod = 1.0; + for alpha in &alphas { + cumprod *= alpha; + alphas_cumprod.push(cumprod); + } + + // timesteps = linspace(train_timesteps - 1, 0, inference_steps) + let timesteps: Vec<f64> = + linspace((config.train_timesteps - 1) as f64, 0.0, inference_steps); + + // sigmas = sqrt((1 - alphas_cumprod) / alphas_cumprod) + let sigmas_full: Vec<f64> = alphas_cumprod + .iter() + .map(|&acp| ((1.0 - acp) / acp).sqrt()) + .collect(); + + // Interpolate sigmas at timestep positions + let xp: Vec<f64> = (0..sigmas_full.len()).map(|i| i as f64).collect(); + let sigmas = interp(&timesteps, &xp, &sigmas_full); + + // Append 0.0 to sigmas + let mut sigmas = sigmas; + sigmas.push(0.0); + + // init_noise_sigma = max(sigmas) + let init_noise_sigma = sigmas.iter().cloned().fold(f64::NEG_INFINITY, f64::max); + + Self { + timesteps, + sigmas, + init_noise_sigma, + config, + 
} + } + + /// Get the timesteps for the scheduler. + pub fn timesteps(&self) -> &[f64] { + self.timesteps.as_slice() + } + + /// Get the initial noise sigma value. + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } + + /// Scale the model input by the appropriate sigma value. + /// + /// # Arguments + /// * `sample` - The input sample tensor + /// * `timestep` - The current timestep + pub fn scale_model_input( + &self, + sample: Tensor, + timestep: f64, + ) -> Tensor { + let step_index = self + .timesteps + .iter() + .position(|&t| t == timestep) + .expect("Timestep not found in scheduler timesteps"); + let sigma = self.sigmas[step_index]; + + // sample / sqrt(sigma^2 + 1) + let scale = (sigma.powi(2) + 1.0).sqrt(); + sample / scale + } + + /// Perform one step of the Euler Ancestral method. + /// + /// # Arguments + /// * `model_output` - The model's predicted noise + /// * `timestep` - The current timestep + /// * `sample` - The current noisy sample + pub fn step( + &self, + model_output: &Tensor, + timestep: f64, + sample: &Tensor, + ) -> Tensor { + let step_index = self + .timesteps + .iter() + .position(|&t| t == timestep) + .expect("Timestep not found in scheduler timesteps"); + let sigma = self.sigmas[step_index]; + + // 1. 
Compute predicted original sample (x_0) from sigma-scaled predicted noise + let pred_original_sample = match self.config.prediction_type { + PredictionType::Epsilon => sample.clone() - model_output.clone() * sigma, + PredictionType::VPrediction => { + let sigma_sq_plus_1 = sigma.powi(2) + 1.0; + model_output.clone() * (-sigma / sigma_sq_plus_1.sqrt()) + + sample.clone() / sigma_sq_plus_1 + } + PredictionType::Sample => { + unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`") + } + }; + + let sigma_from = self.sigmas[step_index]; + let sigma_to = self.sigmas[step_index + 1]; + + // Compute sigma_up and sigma_down for ancestral sampling + let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2)) + / sigma_from.powi(2)) + .sqrt(); + let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt(); + + // 2. Convert to an ODE derivative + let derivative = (sample.clone() - pred_original_sample) / sigma; + let dt = sigma_down - sigma; + + let prev_sample = sample.clone() + derivative * dt; + + // Add noise for ancestral sampling + let device = sample.device(); + let noise: Tensor = Tensor::random( + model_output.shape(), + Distribution::Normal(0.0, 1.0), + &device, + ); + + prev_sample + noise * sigma_up + } + + /// Add noise to original samples. + /// + /// # Arguments + /// * `original_samples` - The original clean samples + /// * `noise` - The noise to add + /// * `timestep` - The timestep at which to add noise + pub fn add_noise( + &self, + original_samples: &Tensor, + noise: Tensor, + timestep: f64, + ) -> Tensor { + let step_index = self + .timesteps + .iter() + .position(|&t| t == timestep) + .expect("Timestep not found in scheduler timesteps"); + let sigma = self.sigmas[step_index]; + + original_samples.clone() + noise * sigma + } +} + +/// Create a linearly spaced vector from start to end with n points. 
+fn linspace(start: f64, end: f64, n: usize) -> Vec<f64> { + if n == 0 { + return vec![]; + } + if n == 1 { + return vec![start]; + } + let step = (end - start) / (n - 1) as f64; + (0..n).map(|i| start + step * i as f64).collect() +} + +/// One-dimensional linear interpolation for monotonically increasing sample points. +/// Mimics numpy's interp() function. +fn interp(x: &[f64], xp: &[f64], yp: &[f64]) -> Vec<f64> { + assert_eq!(xp.len(), yp.len()); + let sz = xp.len(); + + // Compute slopes: m = (yp[1:] - yp[:-1]) / (xp[1:] - xp[:-1]) + let m: Vec<f64> = (0..sz - 1) + .map(|i| (yp[i + 1] - yp[i]) / (xp[i + 1] - xp[i])) + .collect(); + + // Compute intercepts: b = yp[:-1] - m * xp[:-1] + let b: Vec<f64> = (0..sz - 1).map(|i| yp[i] - m[i] * xp[i]).collect(); + + // For each x value, find the appropriate segment and interpolate + x.iter() + .map(|&xi| { + // Find index: sum(x >= xp) - 1, clamped to valid range + let mut idx = 0; + for (i, &xp_val) in xp.iter().enumerate() { + if xi >= xp_val { + idx = i; + } + } + // Clamp to valid range for m and b + let idx = idx.min(m.len() - 1); + m[idx] * xi + b[idx] + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + use burn::tensor::Shape; + + #[test] + fn test_euler_ancestral_scheduler_creation() { + let config = EulerAncestralDiscreteSchedulerConfig::default(); + let scheduler = EulerAncestralDiscreteScheduler::new(20, config); + + assert_eq!(scheduler.timesteps().len(), 20); + assert_eq!(scheduler.sigmas.len(), 21); // 20 + 1 (appended 0) + assert!(scheduler.init_noise_sigma() > 0.0); + } + + #[test] + fn test_euler_ancestral_timesteps() { + let config = EulerAncestralDiscreteSchedulerConfig::default(); + let scheduler = EulerAncestralDiscreteScheduler::new(10, config); + + let timesteps = scheduler.timesteps(); + // Should be decreasing from train_timesteps-1 to 0 + assert!((timesteps[0] - 999.0).abs() < 0.1); + assert!((timesteps[9] - 0.0).abs() < 0.1); + + // Should be monotonically decreasing + 
for i in 1..timesteps.len() { + assert!(timesteps[i] < timesteps[i - 1]); + } + } + + #[test] + fn test_euler_ancestral_scale_model_input() { + let device = Default::default(); + let config = EulerAncestralDiscreteSchedulerConfig::default(); + let scheduler = EulerAncestralDiscreteScheduler::new(20, config); + + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + + let scaled = scheduler.scale_model_input(sample, timestep); + assert_eq!(scaled.shape(), Shape::from([1, 4, 8, 8])); + + // Scaled values should be less than original (division by sqrt(sigma^2 + 1)) + let scaled_mean: f32 = scaled.mean().into_scalar(); + assert!(scaled_mean < 1.0); + assert!(scaled_mean > 0.0); + } + + #[test] + fn test_euler_ancestral_step() { + let device = Default::default(); + let config = EulerAncestralDiscreteSchedulerConfig::default(); + let scheduler = EulerAncestralDiscreteScheduler::new(20, config); + + let model_output: Tensor = Tensor::zeros([1, 4, 8, 8], &device); + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + + let result = scheduler.step(&model_output, timestep, &sample); + assert_eq!(result.shape(), Shape::from([1, 4, 8, 8])); + + // Result should be finite + let result_data = result.into_data(); + let values: Vec = result_data.to_vec().unwrap(); + for v in &values { + assert!(v.is_finite(), "Result contains non-finite values"); + } + } + + /// Test Euler Ancestral scheduler values match diffusers-rs + #[test] + fn test_euler_ancestral_matches_diffusers_rs() { + let device = Default::default(); + let config = EulerAncestralDiscreteSchedulerConfig::default(); + let scheduler = EulerAncestralDiscreteScheduler::new(20, config); + + // Reference values from diffusers-rs + let expected_timesteps = [ + 999.0, + 946.4210205078125, + 893.8421020507813, + 841.26318359375, + 788.6842041015625, + 736.105224609375, + 683.5263061523438, + 630.9473876953125, + 578.368408203125, 
+ 525.7894287109375, + 473.2105407714844, + 420.631591796875, + 368.0526428222656, + 315.47369384765625, + 262.8947448730469, + 210.3157958984375, + 157.73684692382813, + 105.15789794921875, + 52.578948974609375, + 0.0, + ]; + + // Check timesteps + let timesteps = scheduler.timesteps(); + assert_eq!(timesteps.len(), expected_timesteps.len()); + for (i, (actual, expected)) in timesteps.iter().zip(expected_timesteps.iter()).enumerate() { + assert!( + (actual - expected).abs() < 1e-3, + "Timestep mismatch at {}: actual={}, expected={}", + i, + actual, + expected + ); + } + + // Check init_noise_sigma (reference: 14.614646291831562) + let init_sigma = scheduler.init_noise_sigma(); + assert!( + (init_sigma - 14.614646291831562).abs() < 1e-4, + "init_noise_sigma mismatch: actual={}, expected=14.614646291831562", + init_sigma + ); + + // Check scale_model_input (reference mean: 0.06826489418745041) + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + let scaled = scheduler.scale_model_input(sample, timestep); + let scaled_mean: f32 = scaled.mean().into_scalar(); + assert!( + (scaled_mean as f64 - 0.06826489418745041).abs() < 1e-4, + "scale_model_input mean mismatch: actual={}, expected=0.06826489418745041", + scaled_mean + ); + + // Check add_noise (reference mean: 15.614645957946777) + let original: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let noise: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let noisy = scheduler.add_noise(&original, noise, timestep); + let noisy_mean: f32 = noisy.mean().into_scalar(); + assert!( + (noisy_mean as f64 - 15.614645957946777).abs() < 1e-3, + "add_noise mean mismatch: actual={}, expected=15.614645957946777", + noisy_mean + ); + } +} diff --git a/src/schedulers/euler_discrete.rs b/src/schedulers/euler_discrete.rs new file mode 100644 index 0000000..6d8f792 --- /dev/null +++ b/src/schedulers/euler_discrete.rs @@ -0,0 +1,454 @@ +//! Euler Discrete Scheduler +//! +//! 
Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. +//! Based on the original k-diffusion implementation by Katherine Crowson: +//! https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 + +use alloc::vec; +use alloc::vec::Vec; +use burn::tensor::{backend::Backend, Distribution, Tensor}; + +use super::{BetaSchedule, PredictionType}; + +/// Configuration for the Euler Discrete Scheduler. +#[derive(Debug, Clone)] +pub struct EulerDiscreteSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// Number of diffusion steps used to train the model. + pub train_timesteps: usize, + /// Prediction type of the scheduler function. + pub prediction_type: PredictionType, +} + +impl Default for EulerDiscreteSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085, + beta_end: 0.012, + beta_schedule: BetaSchedule::ScaledLinear, + train_timesteps: 1000, + prediction_type: PredictionType::Epsilon, + } + } +} + +/// Euler Discrete Scheduler for diffusion models. +/// +/// This scheduler implements the Euler method for solving the probability flow ODE +/// in diffusion models, as described in Karras et al. (2022). +#[derive(Debug, Clone)] +pub struct EulerDiscreteScheduler { + timesteps: Vec, + sigmas: Vec, + init_noise_sigma: f64, + /// The scheduler configuration. + pub config: EulerDiscreteSchedulerConfig, +} + +impl EulerDiscreteScheduler { + /// Create a new Euler Discrete Scheduler. 
+    ///
+    /// # Arguments
+    /// * `inference_steps` - Number of inference steps
+    /// * `config` - Scheduler configuration
+    pub fn new(inference_steps: usize, config: EulerDiscreteSchedulerConfig) -> Self {
+        let betas = match config.beta_schedule {
+            BetaSchedule::ScaledLinear => {
+                // linspace(beta_start.sqrt(), beta_end.sqrt(), train_timesteps).square()
+                let start = config.beta_start.sqrt();
+                let end = config.beta_end.sqrt();
+                let step = (end - start) / (config.train_timesteps - 1) as f64;
+                let betas: Vec<f64> = (0..config.train_timesteps)
+                    .map(|i| {
+                        let v = start + step * i as f64;
+                        v * v
+                    })
+                    .collect();
+                betas
+            }
+            BetaSchedule::Linear => {
+                let step =
+                    (config.beta_end - config.beta_start) / (config.train_timesteps - 1) as f64;
+                (0..config.train_timesteps)
+                    .map(|i| config.beta_start + step * i as f64)
+                    .collect()
+            }
+            BetaSchedule::SquaredcosCapV2 => {
+                unimplemented!(
+                    "EulerDiscreteScheduler only implements linear and scaled_linear betas."
+                )
+            }
+        };
+
+        // alphas = 1 - betas
+        let alphas: Vec<f64> = betas.iter().map(|b| 1.0 - b).collect();
+
+        // alphas_cumprod = cumprod(alphas)
+        let mut alphas_cumprod: Vec<f64> = Vec::with_capacity(config.train_timesteps);
+        let mut cumprod = 1.0;
+        for alpha in &alphas {
+            cumprod *= alpha;
+            alphas_cumprod.push(cumprod);
+        }
+
+        // timesteps = linspace(train_timesteps - 1, 0, inference_steps)
+        let timesteps: Vec<f64> =
+            linspace((config.train_timesteps - 1) as f64, 0.0, inference_steps);
+
+        // sigmas = sqrt((1 - alphas_cumprod) / alphas_cumprod)
+        let sigmas_full: Vec<f64> = alphas_cumprod
+            .iter()
+            .map(|&acp| ((1.0 - acp) / acp).sqrt())
+            .collect();
+
+        // Interpolate sigmas at timestep positions
+        let xp: Vec<f64> = (0..sigmas_full.len()).map(|i| i as f64).collect();
+        let sigmas = interp(&timesteps, &xp, &sigmas_full);
+
+        // Append 0.0 to sigmas
+        let mut sigmas = sigmas;
+        sigmas.push(0.0);
+
+        // init_noise_sigma = max(sigmas)
+        let init_noise_sigma = sigmas.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
+
+        
Self { + timesteps, + sigmas, + init_noise_sigma, + config, + } + } + + /// Get the timesteps for the scheduler. + pub fn timesteps(&self) -> &[f64] { + self.timesteps.as_slice() + } + + /// Get the initial noise sigma value. + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } + + /// Scale the model input by the appropriate sigma value. + /// + /// # Arguments + /// * `sample` - The input sample tensor + /// * `timestep` - The current timestep + pub fn scale_model_input( + &self, + sample: Tensor, + timestep: f64, + ) -> Tensor { + let step_index = self + .timesteps + .iter() + .position(|&t| t == timestep) + .expect("Timestep not found in scheduler timesteps"); + let sigma = self.sigmas[step_index]; + + // sample / sqrt(sigma^2 + 1) + let scale = (sigma.powi(2) + 1.0).sqrt(); + sample / scale + } + + /// Perform one step of the Euler method. + /// + /// # Arguments + /// * `model_output` - The model's predicted noise + /// * `timestep` - The current timestep + /// * `sample` - The current noisy sample + pub fn step( + &self, + model_output: &Tensor, + timestep: f64, + sample: &Tensor, + ) -> Tensor { + // Euler method parameters (no stochasticity by default) + let (s_churn, s_tmin, s_tmax, s_noise) = (0.0, 0.0, f64::INFINITY, 1.0); + + let step_index = self + .timesteps + .iter() + .position(|&t| t == timestep) + .expect("Timestep not found in scheduler timesteps"); + let sigma = self.sigmas[step_index]; + + let gamma = if s_tmin <= sigma && sigma <= s_tmax { + (s_churn / (self.sigmas.len() as f64 - 1.0)).min(2.0_f64.sqrt() - 1.0) + } else { + 0.0 + }; + + let device = sample.device(); + let noise: Tensor = Tensor::random( + model_output.shape(), + Distribution::Normal(0.0, 1.0), + &device, + ); + let eps = noise * s_noise; + let sigma_hat = sigma * (gamma + 1.0); + + let sample = if gamma > 0.0 { + sample.clone() + eps * (sigma_hat.powi(2) - sigma.powi(2)).sqrt() + } else { + sample.clone() + }; + + // 1. 
Compute predicted original sample (x_0) from sigma-scaled predicted noise + let pred_original_sample = match self.config.prediction_type { + PredictionType::Epsilon => sample.clone() - model_output.clone() * sigma_hat, + PredictionType::VPrediction => { + let sigma_sq_plus_1 = sigma.powi(2) + 1.0; + model_output.clone() * (-sigma / sigma_sq_plus_1.sqrt()) + + sample.clone() / sigma_sq_plus_1 + } + PredictionType::Sample => { + unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`") + } + }; + + // 2. Convert to an ODE derivative + let derivative = (sample.clone() - pred_original_sample) / sigma_hat; + let dt = self.sigmas[step_index + 1] - sigma_hat; + + // Euler step: sample + derivative * dt + sample + derivative * dt + } + + /// Add noise to original samples. + /// + /// # Arguments + /// * `original_samples` - The original clean samples + /// * `noise` - The noise to add + /// * `timestep` - The timestep at which to add noise + pub fn add_noise( + &self, + original_samples: &Tensor, + noise: Tensor, + timestep: f64, + ) -> Tensor { + let step_index = self + .timesteps + .iter() + .position(|&t| t == timestep) + .expect("Timestep not found in scheduler timesteps"); + let sigma = self.sigmas[step_index]; + + original_samples.clone() + noise * sigma + } +} + +/// Create a linearly spaced vector from start to end with n points. +fn linspace(start: f64, end: f64, n: usize) -> Vec { + if n == 0 { + return vec![]; + } + if n == 1 { + return vec![start]; + } + let step = (end - start) / (n - 1) as f64; + (0..n).map(|i| start + step * i as f64).collect() +} + +/// One-dimensional linear interpolation for monotonically increasing sample points. +/// Mimics numpy's interp() function. 
+/// +/// # Arguments +/// * `x` - x-coordinates at which to evaluate the interpolated values +/// * `xp` - x-coordinates of the data points (must be increasing) +/// * `yp` - y-coordinates of the data points +fn interp(x: &[f64], xp: &[f64], yp: &[f64]) -> Vec { + assert_eq!(xp.len(), yp.len()); + let sz = xp.len(); + + // Compute slopes: m = (yp[1:] - yp[:-1]) / (xp[1:] - xp[:-1]) + let m: Vec = (0..sz - 1) + .map(|i| (yp[i + 1] - yp[i]) / (xp[i + 1] - xp[i])) + .collect(); + + // Compute intercepts: b = yp[:-1] - m * xp[:-1] + let b: Vec = (0..sz - 1).map(|i| yp[i] - m[i] * xp[i]).collect(); + + // For each x value, find the appropriate segment and interpolate + x.iter() + .map(|&xi| { + // Find index: sum(x >= xp) - 1, clamped to valid range + let mut idx = 0; + for (i, &xp_val) in xp.iter().enumerate() { + if xi >= xp_val { + idx = i; + } + } + // Clamp to valid range for m and b + let idx = idx.min(m.len() - 1); + m[idx] * xi + b[idx] + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + use burn::tensor::Shape; + + #[test] + fn test_euler_discrete_scheduler_creation() { + let config = EulerDiscreteSchedulerConfig::default(); + let scheduler = EulerDiscreteScheduler::new(20, config); + + assert_eq!(scheduler.timesteps().len(), 20); + assert_eq!(scheduler.sigmas.len(), 21); // 20 + 1 (appended 0) + assert!(scheduler.init_noise_sigma() > 0.0); + } + + #[test] + fn test_euler_discrete_timesteps() { + let config = EulerDiscreteSchedulerConfig::default(); + let scheduler = EulerDiscreteScheduler::new(10, config); + + let timesteps = scheduler.timesteps(); + // Should be decreasing from train_timesteps-1 to 0 + assert!((timesteps[0] - 999.0).abs() < 0.1); + assert!((timesteps[9] - 0.0).abs() < 0.1); + + // Should be monotonically decreasing + for i in 1..timesteps.len() { + assert!(timesteps[i] < timesteps[i - 1]); + } + } + + #[test] + fn test_euler_discrete_scale_model_input() { + let device = Default::default(); + let 
config = EulerDiscreteSchedulerConfig::default(); + let scheduler = EulerDiscreteScheduler::new(20, config); + + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + + let scaled = scheduler.scale_model_input(sample, timestep); + assert_eq!(scaled.shape(), Shape::from([1, 4, 8, 8])); + + // Scaled values should be less than original (division by sqrt(sigma^2 + 1)) + let scaled_mean: f32 = scaled.mean().into_scalar(); + assert!(scaled_mean < 1.0); + assert!(scaled_mean > 0.0); + } + + #[test] + fn test_euler_discrete_step() { + let device = Default::default(); + let config = EulerDiscreteSchedulerConfig::default(); + let scheduler = EulerDiscreteScheduler::new(20, config); + + let model_output: Tensor = Tensor::zeros([1, 4, 8, 8], &device); + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + + let result = scheduler.step(&model_output, timestep, &sample); + assert_eq!(result.shape(), Shape::from([1, 4, 8, 8])); + } + + #[test] + fn test_linspace() { + let result = linspace(0.0, 10.0, 5); + assert_eq!(result.len(), 5); + assert!((result[0] - 0.0).abs() < 1e-10); + assert!((result[4] - 10.0).abs() < 1e-10); + assert!((result[2] - 5.0).abs() < 1e-10); + } + + #[test] + fn test_interp() { + let xp = vec![0.0, 1.0, 2.0, 3.0]; + let yp = vec![0.0, 2.0, 4.0, 6.0]; + let x = vec![0.5, 1.5, 2.5]; + + let result = interp(&x, &xp, &yp); + assert!((result[0] - 1.0).abs() < 1e-10); + assert!((result[1] - 3.0).abs() < 1e-10); + assert!((result[2] - 5.0).abs() < 1e-10); + } + + /// Test Euler Discrete scheduler values match diffusers-rs + #[test] + fn test_euler_discrete_matches_diffusers_rs() { + let device = Default::default(); + let config = EulerDiscreteSchedulerConfig::default(); + let scheduler = EulerDiscreteScheduler::new(20, config); + + // Reference values from diffusers-rs + let expected_timesteps = [ + 999.0, + 946.4210205078125, + 893.8421020507813, + 
841.26318359375, + 788.6842041015625, + 736.105224609375, + 683.5263061523438, + 630.9473876953125, + 578.368408203125, + 525.7894287109375, + 473.2105407714844, + 420.631591796875, + 368.0526428222656, + 315.47369384765625, + 262.8947448730469, + 210.3157958984375, + 157.73684692382813, + 105.15789794921875, + 52.578948974609375, + 0.0, + ]; + + // Check timesteps + let timesteps = scheduler.timesteps(); + assert_eq!(timesteps.len(), expected_timesteps.len()); + for (i, (actual, expected)) in timesteps.iter().zip(expected_timesteps.iter()).enumerate() { + assert!( + (actual - expected).abs() < 1e-3, + "Timestep mismatch at {}: actual={}, expected={}", + i, + actual, + expected + ); + } + + // Check init_noise_sigma (reference: 14.614646291831562) + let init_sigma = scheduler.init_noise_sigma(); + assert!( + (init_sigma - 14.614646291831562).abs() < 1e-4, + "init_noise_sigma mismatch: actual={}, expected=14.614646291831562", + init_sigma + ); + + // Check scale_model_input (reference mean: 0.06826489418745041) + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + let scaled = scheduler.scale_model_input(sample, timestep); + let scaled_mean: f32 = scaled.mean().into_scalar(); + assert!( + (scaled_mean as f64 - 0.06826489418745041).abs() < 1e-4, + "scale_model_input mean mismatch: actual={}, expected=0.06826489418745041", + scaled_mean + ); + + // Check step (reference mean: 1.0 when model_output is zeros) + let model_output: Tensor = Tensor::zeros([1, 4, 8, 8], &device); + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let result = scheduler.step(&model_output, timestep, &sample); + let result_mean: f32 = result.mean().into_scalar(); + assert!( + (result_mean - 1.0).abs() < 1e-4, + "step mean mismatch: actual={}, expected=1.0", + result_mean + ); + } +} diff --git a/src/schedulers/mod.rs b/src/schedulers/mod.rs new file mode 100644 index 0000000..6ee8c80 --- /dev/null +++ b/src/schedulers/mod.rs @@ -0,0 
+1,76 @@ +//! # Noise Schedulers +//! +//! Noise schedulers can be used to set the trade-off between +//! inference speed and quality. + +use alloc::vec::Vec; +use burn::tensor::{backend::Backend, Tensor}; +use core::f64::consts::FRAC_PI_2; + +pub mod ddim; +pub mod ddpm; +pub mod dpmsolver_multistep; +pub mod euler_ancestral_discrete; +pub mod euler_discrete; +pub mod pndm; + +pub use ddim::{DDIMScheduler, DDIMSchedulerConfig}; +pub use ddpm::{DDPMScheduler, DDPMSchedulerConfig, DDPMVarianceType}; +pub use dpmsolver_multistep::{ + DPMSolverAlgorithmType, DPMSolverMultistepScheduler, DPMSolverMultistepSchedulerConfig, + DPMSolverType, +}; +pub use euler_ancestral_discrete::{ + EulerAncestralDiscreteScheduler, EulerAncestralDiscreteSchedulerConfig, +}; +pub use euler_discrete::{EulerDiscreteScheduler, EulerDiscreteSchedulerConfig}; +pub use pndm::{PNDMScheduler, PNDMSchedulerConfig}; + +/// This represents how beta ranges from its minimum value to the maximum +/// during training. +#[derive(Debug, Clone, Copy)] +pub enum BetaSchedule { + /// Linear interpolation. + Linear, + /// Linear interpolation of the square root of beta. + ScaledLinear, + /// Glide cosine schedule + SquaredcosCapV2, +} + +/// The type of prediction the model makes. +#[derive(Debug, Clone, Copy)] +pub enum PredictionType { + /// Predicting the noise of the diffusion process. + Epsilon, + /// See section 2.4 https://imagen.research.google/video/paper.pdf + VPrediction, + /// Directly predicting the noisy sample. + Sample, +} + +/// Create a beta schedule that discretizes the given alpha_t_bar function, +/// which defines the cumulative product of `(1-beta)` over time from `t = [0,1]`. +/// +/// Contains a function `alpha_bar` that takes an argument `t` and transforms it +/// to the cumulative product of `(1-beta)` up to that part of the diffusion process. 
+pub fn betas_for_alpha_bar<B: Backend>(
+    num_diffusion_timesteps: usize,
+    max_beta: f64,
+    device: &B::Device,
+) -> Tensor<B, 1> {
+    let alpha_bar = |t: f64| -> f64 {
+        let x = (t + 0.008) / 1.008 * FRAC_PI_2;
+        x.cos().powi(2)
+    };
+
+    let mut betas: Vec<f64> = Vec::with_capacity(num_diffusion_timesteps);
+    for i in 0..num_diffusion_timesteps {
+        let t1 = i as f64 / num_diffusion_timesteps as f64;
+        let t2 = (i + 1) as f64 / num_diffusion_timesteps as f64;
+        let beta = (1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta);
+        betas.push(beta);
+    }
+
+    Tensor::from_floats(betas.as_slice(), device)
+}
diff --git a/src/schedulers/pndm.rs b/src/schedulers/pndm.rs
new file mode 100644
index 0000000..4f05288
--- /dev/null
+++ b/src/schedulers/pndm.rs
@@ -0,0 +1,421 @@
+//! PNDM Scheduler
+//!
+//! Pseudo numerical methods for diffusion models (PNDM) proposes using more
+//! advanced ODE integration techniques, namely Runge-Kutta method and a
+//! linear multi-step method.
+//! Based on the paper: https://arxiv.org/abs/2202.09778
+
+use alloc::vec;
+use alloc::vec::Vec;
+use burn::tensor::{backend::Backend, Tensor};
+
+use super::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+
+#[cfg(not(feature = "std"))]
+#[allow(unused_imports)]
+use num_traits::Float;
+
+/// Configuration for the PNDM Scheduler.
+#[derive(Debug, Clone)]
+pub struct PNDMSchedulerConfig {
+    /// The value of beta at the beginning of training.
+    pub beta_start: f64,
+    /// The value of beta at the end of training.
+    pub beta_end: f64,
+    /// How beta evolved during training.
+    pub beta_schedule: BetaSchedule,
+    /// Each diffusion step uses the value of alphas product at that step and
+    /// at the previous one. For the final step there is no previous alpha.
+    /// When this option is `true` the previous alpha product is fixed to `1`,
+    /// otherwise it uses the value of alpha at step 0.
+    pub set_alpha_to_one: bool,
+    /// Prediction type of the scheduler function.
+    pub prediction_type: PredictionType,
+    /// An offset added to the inference steps.
+    pub steps_offset: usize,
+    /// Number of diffusion steps used to train the model.
+    pub train_timesteps: usize,
+}
+
+impl Default for PNDMSchedulerConfig {
+    fn default() -> Self {
+        Self {
+            beta_start: 0.00085,
+            beta_end: 0.012,
+            beta_schedule: BetaSchedule::ScaledLinear,
+            set_alpha_to_one: false,
+            prediction_type: PredictionType::Epsilon,
+            steps_offset: 1,
+            train_timesteps: 1000,
+        }
+    }
+}
+
+/// PNDM Scheduler for diffusion models.
+///
+/// This scheduler implements the PLMS method for fast sampling in diffusion models.
+pub struct PNDMScheduler<B: Backend> {
+    alphas_cumprod: Vec<f64>,
+    final_alpha_cumprod: f64,
+    step_ratio: usize,
+    init_noise_sigma: f64,
+    counter: usize,
+    cur_sample: Option<Tensor<B, 4>>,
+    ets: Vec<Tensor<B, 4>>,
+    timesteps: Vec<usize>,
+    /// The scheduler configuration.
+    pub config: PNDMSchedulerConfig,
+}
+
+impl<B: Backend> PNDMScheduler<B> {
+    /// Create a new PNDM Scheduler.
+    ///
+    /// # Arguments
+    /// * `inference_steps` - Number of inference steps
+    /// * `config` - Scheduler configuration
+    /// * `device` - The device to create tensors on
+    pub fn new(inference_steps: usize, config: PNDMSchedulerConfig, device: &B::Device) -> Self {
+        let betas: Vec<f64> = match config.beta_schedule {
+            BetaSchedule::ScaledLinear => {
+                let start = config.beta_start.sqrt();
+                let end = config.beta_end.sqrt();
+                let step = (end - start) / (config.train_timesteps - 1) as f64;
+                (0..config.train_timesteps)
+                    .map(|i| {
+                        let v = start + step * i as f64;
+                        v * v
+                    })
+                    .collect()
+            }
+            BetaSchedule::Linear => {
+                let step =
+                    (config.beta_end - config.beta_start) / (config.train_timesteps - 1) as f64;
+                (0..config.train_timesteps)
+                    .map(|i| config.beta_start + step * i as f64)
+                    .collect()
+            }
+            BetaSchedule::SquaredcosCapV2 => {
+                let betas_tensor: Tensor<B, 1> =
+                    betas_for_alpha_bar(config.train_timesteps, 0.999, device);
+                let data = betas_tensor.into_data();
+                data.to_vec::<f32>()
+                    .unwrap()
+                    .into_iter()
+                    .map(|x| x as 
f64)
+                    .collect()
+            }
+        };
+
+        // alphas = 1 - betas
+        let alphas: Vec<f64> = betas.iter().map(|b| 1.0 - b).collect();
+
+        // alphas_cumprod = cumprod(alphas)
+        let mut alphas_cumprod: Vec<f64> = Vec::with_capacity(config.train_timesteps);
+        let mut cumprod = 1.0;
+        for alpha in &alphas {
+            cumprod *= alpha;
+            alphas_cumprod.push(cumprod);
+        }
+
+        let final_alpha_cumprod = if config.set_alpha_to_one {
+            1.0
+        } else {
+            alphas_cumprod[0]
+        };
+
+        // Create integer timesteps by multiplying by ratio
+        let step_ratio = config.train_timesteps / inference_steps;
+        let timesteps: Vec<usize> = (0..inference_steps)
+            .map(|s| s * step_ratio + config.steps_offset)
+            .collect();
+
+        // Create PLMS timesteps
+        // plms_timesteps = [timesteps[:-2], timesteps[-2], timesteps[-2:]]
+        let n_ts = timesteps.len();
+        let mut plms_timesteps = Vec::new();
+        // timesteps[:-2]
+        plms_timesteps.extend_from_slice(&timesteps[..n_ts - 2]);
+        // timesteps[-2] (duplicate)
+        plms_timesteps.push(timesteps[n_ts - 2]);
+        // timesteps[-2:]
+        plms_timesteps.extend_from_slice(&timesteps[n_ts - 2..]);
+        // Reverse
+        plms_timesteps.reverse();
+
+        Self {
+            alphas_cumprod,
+            final_alpha_cumprod,
+            step_ratio,
+            init_noise_sigma: 1.0,
+            counter: 0,
+            cur_sample: None,
+            ets: vec![],
+            timesteps: plms_timesteps,
+            config,
+        }
+    }
+
+    /// Get the timesteps for the scheduler.
+    pub fn timesteps(&self) -> &[usize] {
+        self.timesteps.as_slice()
+    }
+
+    /// Get the initial noise sigma value.
+    pub fn init_noise_sigma(&self) -> f64 {
+        self.init_noise_sigma
+    }
+
+    /// Scale the model input (identity for PNDM).
+    pub fn scale_model_input(&self, sample: Tensor<B, 4>, _timestep: usize) -> Tensor<B, 4> {
+        sample
+    }
+
+    /// Perform one step of the PNDM (using PLMS method).
+    pub fn step(
+        &mut self,
+        model_output: &Tensor<B, 4>,
+        timestep: usize,
+        sample: &Tensor<B, 4>,
+    ) -> Tensor<B, 4> {
+        self.step_plms(model_output, timestep, sample)
+    }
+
+    /// Step function propagating the sample with the linear multi-step method.
+ fn step_plms( + &mut self, + model_output: &Tensor, + mut timestep: usize, + sample: &Tensor, + ) -> Tensor { + let mut prev_timestep = timestep as isize - self.step_ratio as isize; + + if self.counter != 1 { + // Make sure ets has at most 4 elements (keep last 3) + if self.ets.len() > 3 { + self.ets.drain(0..self.ets.len() - 3); + } + self.ets.push(model_output.clone()); + } else { + prev_timestep = timestep as isize; + timestep += self.step_ratio; + } + + let n_ets = self.ets.len(); + let (mut model_output, mut sample) = (model_output.clone(), sample.clone()); + + if n_ets == 1 && self.counter == 0 { + self.cur_sample = Some(sample.clone()); + } else if n_ets == 1 && self.counter == 1 { + sample = self.cur_sample.take().unwrap(); + model_output = (model_output + self.ets.last().unwrap().clone()) / 2.0; + } else if n_ets == 2 { + let ets_last = self.ets.last().unwrap(); + model_output = (ets_last.clone() * 3.0 - self.ets[n_ets - 2].clone()) / 2.0; + } else if n_ets == 3 { + let ets_last = self.ets.last().unwrap(); + model_output = (ets_last.clone() * 23.0 - self.ets[n_ets - 2].clone() * 16.0 + + self.ets[n_ets - 3].clone() * 5.0) + / 12.0; + } else { + let ets_last = self.ets.last().unwrap(); + model_output = (ets_last.clone() * 55.0 - self.ets[n_ets - 2].clone() * 59.0 + + self.ets[n_ets - 3].clone() * 37.0 + - self.ets[n_ets - 4].clone() * 9.0) + * (1.0 / 24.0); + } + + let prev_sample = self.get_prev_sample(sample, timestep, prev_timestep, model_output); + self.counter += 1; + + prev_sample + } + + /// Compute the previous sample using the PNDM formula. 
+ fn get_prev_sample( + &self, + sample: Tensor, + timestep: usize, + prev_timestep: isize, + model_output: Tensor, + ) -> Tensor { + // See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = if prev_timestep >= 0 { + self.alphas_cumprod[prev_timestep as usize] + } else { + self.final_alpha_cumprod + }; + + let beta_prod_t = 1.0 - alpha_prod_t; + let beta_prod_t_prev = 1.0 - alpha_prod_t_prev; + + let model_output = match self.config.prediction_type { + PredictionType::VPrediction => { + model_output * alpha_prod_t.sqrt() + sample.clone() * beta_prod_t.sqrt() + } + PredictionType::Epsilon => model_output, + PredictionType::Sample => { + unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`") + } + }; + + // Corresponds to (α_(t−δ) - α_t) divided by + // denominator of x_t in formula (9) and plus 1 + let sample_coeff = (alpha_prod_t_prev / alpha_prod_t).sqrt(); + + // Corresponds to denominator of e_θ(x_t, t) in formula (9) + let model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev.sqrt() + + (alpha_prod_t * beta_prod_t * alpha_prod_t_prev).sqrt(); + + // Full formula (9) + sample * sample_coeff + - model_output * (alpha_prod_t_prev - alpha_prod_t) / model_output_denom_coeff + } + + /// Add noise to original samples. 
+ pub fn add_noise( + &self, + original_samples: &Tensor, + noise: Tensor, + timestep: usize, + ) -> Tensor { + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + + let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt(); + let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt(); + + original_samples.clone() * sqrt_alpha_prod + noise * sqrt_one_minus_alpha_prod + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + use burn::tensor::Shape; + + #[test] + fn test_pndm_scheduler_creation() { + let device = Default::default(); + let config = PNDMSchedulerConfig::default(); + let scheduler = PNDMScheduler::::new(20, config, &device); + + // PNDM has 21 timesteps due to the PLMS method + assert_eq!(scheduler.timesteps().len(), 21); + assert_eq!(scheduler.init_noise_sigma(), 1.0); + } + + #[test] + fn test_pndm_timesteps() { + let device = Default::default(); + let config = PNDMSchedulerConfig::default(); + let scheduler = PNDMScheduler::::new(20, config, &device); + + let timesteps = scheduler.timesteps(); + // First timestep should be high + assert!(timesteps[0] > 900); + // Last timestep should be small + assert!(timesteps[timesteps.len() - 1] < 10); + } + + #[test] + fn test_pndm_scale_model_input() { + let device = Default::default(); + let config = PNDMSchedulerConfig::default(); + let scheduler = PNDMScheduler::::new(20, config, &device); + + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + + // PNDM doesn't scale input + let scaled = scheduler.scale_model_input(sample.clone(), timestep); + let diff: f32 = (scaled - sample).abs().mean().into_scalar(); + assert!(diff < 1e-6); + } + + #[test] + fn test_pndm_step() { + let device = Default::default(); + let config = PNDMSchedulerConfig::default(); + let mut scheduler = PNDMScheduler::::new(20, config, &device); + + let model_output: Tensor = Tensor::zeros([1, 4, 8, 
8], &device); + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + + let result = scheduler.step(&model_output, timestep, &sample); + assert_eq!(result.shape(), Shape::from([1, 4, 8, 8])); + + // Result should be finite + let result_data = result.into_data(); + let values: Vec = result_data.to_vec().unwrap(); + for v in &values { + assert!(v.is_finite(), "Result contains non-finite values"); + } + } + + /// Test PNDM scheduler values match diffusers-rs + #[test] + fn test_pndm_matches_diffusers_rs() { + let device = Default::default(); + let config = PNDMSchedulerConfig::default(); + let mut scheduler = PNDMScheduler::::new(20, config, &device); + + // Reference values from diffusers-rs + // Timesteps: [951, 901, 901, 851, 801, 751, 701, 651, 601, 551, 501, 451, 401, 351, 301, 251, 201, 151, 101, 51, 1] + let expected_timesteps = [ + 951, 901, 901, 851, 801, 751, 701, 651, 601, 551, 501, 451, 401, 351, 301, 251, 201, + 151, 101, 51, 1, + ]; + + let timesteps = scheduler.timesteps(); + assert_eq!(timesteps.len(), expected_timesteps.len()); + for (i, (actual, expected)) in timesteps.iter().zip(expected_timesteps.iter()).enumerate() { + assert_eq!( + *actual, *expected, + "Timestep mismatch at {}: actual={}, expected={}", + i, actual, expected + ); + } + + // Check init_noise_sigma (reference: 1.0) + assert_eq!(scheduler.init_noise_sigma(), 1.0); + + // Check scale_model_input (reference mean: 1.0 - identity) + let sample: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let timestep = scheduler.timesteps()[0]; + let scaled = scheduler.scale_model_input(sample, timestep); + let scaled_mean: f32 = scaled.mean().into_scalar(); + assert!( + (scaled_mean - 1.0).abs() < 1e-6, + "scale_model_input mean mismatch: actual={}, expected=1.0", + scaled_mean + ); + + // Check step (reference mean: 1.3104724884033203) + let model_output: Tensor = Tensor::zeros([1, 4, 8, 8], &device); + let sample: Tensor = Tensor::ones([1, 4, 8, 
8], &device); + let result = scheduler.step(&model_output, timestep, &sample); + let result_mean: f32 = result.mean().into_scalar(); + assert!( + (result_mean as f64 - 1.3104724884033203).abs() < 1e-4, + "step mean mismatch: actual={}, expected=1.3104724884033203", + result_mean + ); + + // Check add_noise (reference mean: 1.0862191915512085 at timestep 951) + let original: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let noise: Tensor = Tensor::ones([1, 4, 8, 8], &device); + let noisy = scheduler.add_noise(&original, noise, 951); + let noisy_mean: f32 = noisy.mean().into_scalar(); + assert!( + (noisy_mean as f64 - 1.0862191915512085).abs() < 1e-4, + "add_noise mean mismatch: actual={}, expected=1.0862191915512085", + noisy_mean + ); + } +} diff --git a/src/transformers/clip.rs b/src/transformers/clip.rs index ef02b89..3332e35 100644 --- a/src/transformers/clip.rs +++ b/src/transformers/clip.rs @@ -145,15 +145,10 @@ impl ClipConfig { let embed_dim = self.embed_dim; let num_attention_heads = self.num_attention_heads; - let k_proj = nn::LinearConfig::new(embed_dim, embed_dim) - .with_bias(false) - .init(device); - let v_proj = nn::LinearConfig::new(embed_dim, embed_dim) - .with_bias(false) - .init(device); - let q_proj = nn::LinearConfig::new(embed_dim, embed_dim) - .with_bias(false) - .init(device); + // CLIP attention layers have bias (unlike UNet cross-attention which doesn't) + let k_proj = nn::LinearConfig::new(embed_dim, embed_dim).init(device); + let v_proj = nn::LinearConfig::new(embed_dim, embed_dim).init(device); + let q_proj = nn::LinearConfig::new(embed_dim, embed_dim).init(device); let out_proj = nn::LinearConfig::new(embed_dim, embed_dim).init(device); let head_dim = embed_dim / num_attention_heads; let scale = (head_dim as f64).powf(-0.5); @@ -261,7 +256,7 @@ impl ClipAttention { .reshape([bsz, self.num_attention_heads, seq_len, src_len]) .add(causal_attention_mask); let attn_weights = attn_weights.reshape([bsz * self.num_attention_heads, seq_len, 
src_len]); - let attn_weights = softmax(attn_weights, 3); + let attn_weights = softmax(attn_weights, 2); let attn_output = attn_weights.matmul(value_states); let attn_output = attn_output @@ -300,10 +295,10 @@ impl ClipEncoderLayer { let residual = xs; let xs = self.layer_norm1.forward(residual.clone()); let xs = self.self_attn.forward(xs, causal_attention_mask); - let xs2 = xs.clone() + residual; + let xs = xs + residual; - let residual = xs2; - let xs = self.layer_norm2.forward(xs.clone()); + let residual = xs.clone(); + let xs = self.layer_norm2.forward(xs); let xs = self.mlp.forward(xs); xs + residual } @@ -343,7 +338,14 @@ impl ClipTextTransformer { mask.triu(1).unsqueeze_dim(1) } - fn forward(&self, xs: Tensor) -> Tensor { + /// Forward pass through the CLIP text transformer. + /// + /// # Arguments + /// * `xs` - Token IDs [batch_size, seq_len] + /// + /// # Returns + /// Text embeddings [batch_size, seq_len, embed_dim] + pub fn forward(&self, xs: Tensor) -> Tensor { let [bsz, seq_len] = xs.dims(); let xs = self.embeddings.forward(xs); let causal_attention_mask = @@ -419,4 +421,88 @@ mod tests { Tolerance::rel_abs(1e-3, 1e-3), ); } + + /// Test QuickGelu activation matches diffusers-rs + /// QuickGelu: x * sigmoid(1.702 * x) + /// Reference values from diffusers-rs v0.3.1 + #[test] + fn test_quick_gelu_matches_diffusers_rs() { + let device = Default::default(); + let xs: Tensor = Tensor::from_data( + TensorData::from([-2.0f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]), + &device, + ); + + let result = Activation::QuickGelu.forward(xs); + + // Reference values from diffusers-rs: xs * (xs * 1.702).sigmoid() + result.into_data().assert_approx_eq::( + &TensorData::from([ + -0.064341374f32, + -0.15420423, + -0.14961156, + 0.0, + 0.35038844, + 0.84579575, + 1.9356586, + ]), + Tolerance::rel_abs(1e-4, 1e-4), + ); + } + + /// Test GeluErf activation matches diffusers-rs + /// GeluErf: 0.5 * x * (1 + erf(x / sqrt(2))) + /// Reference values from diffusers-rs v0.3.1 + #[test] 
+ fn test_gelu_erf_matches_diffusers_rs() { + let device = Default::default(); + let xs: Tensor = Tensor::from_data( + TensorData::from([-2.0f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]), + &device, + ); + + let result = Activation::GeluErf.forward(xs); + + // Reference values from diffusers-rs: (xs * (xs / sqrt(2)).erf() + 1) / 2 + result.into_data().assert_approx_eq::( + &TensorData::from([ + -0.04550028f32, + -0.15865526, + -0.15426877, + 0.0, + 0.34573123, + 0.8413447, + 1.9544997, + ]), + Tolerance::rel_abs(1e-4, 1e-4), + ); + } + + /// Test Gelu activation matches diffusers-rs (gelu("none")) + /// Reference values from diffusers-rs v0.3.1 + #[test] + fn test_gelu_matches_diffusers_rs() { + let device = Default::default(); + let xs: Tensor = Tensor::from_data( + TensorData::from([-2.0f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]), + &device, + ); + + let result = Activation::Gelu.forward(xs); + + // Reference values from diffusers-rs: gelu("none") + // Note: Burn's gelu() uses the same approximation as PyTorch's gelu("none") + result.into_data().assert_approx_eq::( + &TensorData::from([ + -0.04550028f32, + -0.15865526, + -0.15426877, + 0.0, + 0.34573123, + 0.8413447, + 1.9544997, + ]), + Tolerance::rel_abs(1e-4, 1e-4), + ); + } } diff --git a/src/transformers/clip_tokenizer.rs b/src/transformers/clip_tokenizer.rs new file mode 100644 index 0000000..17eb182 --- /dev/null +++ b/src/transformers/clip_tokenizer.rs @@ -0,0 +1,607 @@ +//! CLIP Tokenizer +//! +//! BPE tokenizer for CLIP text encoding. This module is only available +//! with the `std` feature enabled. +//! +//! The tokenizer uses byte-pair encoding (BPE) to convert text into tokens +//! that can be processed by the CLIP text encoder. + +use std::collections::{HashMap, HashSet}; +use std::io::BufRead; +use std::path::Path; + +/// Errors that can occur during tokenization. 
+#[derive(Debug, thiserror::Error)] +pub enum TokenizerError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Regex error: {0}")] + Regex(#[from] regex::Error), + #[error("Invalid BPE file format: {0}")] + InvalidFormat(String), + #[error("Unknown padding character: {0}")] + UnknownPadChar(String), +} + +/// Mapping from bytes to unicode characters used by CLIP's BPE. +/// This allows representing any byte sequence as valid unicode. +const BYTES_TO_UNICODE: [(u8, char); 256] = [ + (33, '!'), + (34, '"'), + (35, '#'), + (36, '$'), + (37, '%'), + (38, '&'), + (39, '\''), + (40, '('), + (41, ')'), + (42, '*'), + (43, '+'), + (44, ','), + (45, '-'), + (46, '.'), + (47, '/'), + (48, '0'), + (49, '1'), + (50, '2'), + (51, '3'), + (52, '4'), + (53, '5'), + (54, '6'), + (55, '7'), + (56, '8'), + (57, '9'), + (58, ':'), + (59, ';'), + (60, '<'), + (61, '='), + (62, '>'), + (63, '?'), + (64, '@'), + (65, 'A'), + (66, 'B'), + (67, 'C'), + (68, 'D'), + (69, 'E'), + (70, 'F'), + (71, 'G'), + (72, 'H'), + (73, 'I'), + (74, 'J'), + (75, 'K'), + (76, 'L'), + (77, 'M'), + (78, 'N'), + (79, 'O'), + (80, 'P'), + (81, 'Q'), + (82, 'R'), + (83, 'S'), + (84, 'T'), + (85, 'U'), + (86, 'V'), + (87, 'W'), + (88, 'X'), + (89, 'Y'), + (90, 'Z'), + (91, '['), + (92, '\\'), + (93, ']'), + (94, '^'), + (95, '_'), + (96, '`'), + (97, 'a'), + (98, 'b'), + (99, 'c'), + (100, 'd'), + (101, 'e'), + (102, 'f'), + (103, 'g'), + (104, 'h'), + (105, 'i'), + (106, 'j'), + (107, 'k'), + (108, 'l'), + (109, 'm'), + (110, 'n'), + (111, 'o'), + (112, 'p'), + (113, 'q'), + (114, 'r'), + (115, 's'), + (116, 't'), + (117, 'u'), + (118, 'v'), + (119, 'w'), + (120, 'x'), + (121, 'y'), + (122, 'z'), + (123, '{'), + (124, '|'), + (125, '}'), + (126, '~'), + (161, '\u{00A1}'), // Inverted exclamation mark + (162, '\u{00A2}'), // Cent sign + (163, '\u{00A3}'), // Pound sign + (164, '\u{00A4}'), // Currency sign + (165, '\u{00A5}'), // Yen sign + (166, '\u{00A6}'), // Broken bar + (167, 
'\u{00A7}'), // Section sign + (168, '\u{00A8}'), // Diaeresis + (169, '\u{00A9}'), // Copyright sign + (170, '\u{00AA}'), // Feminine ordinal indicator + (171, '\u{00AB}'), // Left-pointing double angle quotation mark + (172, '\u{00AC}'), // Not sign + (174, '\u{00AE}'), // Registered sign + (175, '\u{00AF}'), // Macron + (176, '\u{00B0}'), // Degree sign + (177, '\u{00B1}'), // Plus-minus sign + (178, '\u{00B2}'), // Superscript two + (179, '\u{00B3}'), // Superscript three + (180, '\u{00B4}'), // Acute accent + (181, '\u{00B5}'), // Micro sign + (182, '\u{00B6}'), // Pilcrow sign + (183, '\u{00B7}'), // Middle dot + (184, '\u{00B8}'), // Cedilla + (185, '\u{00B9}'), // Superscript one + (186, '\u{00BA}'), // Masculine ordinal indicator + (187, '\u{00BB}'), // Right-pointing double angle quotation mark + (188, '\u{00BC}'), // Vulgar fraction one quarter + (189, '\u{00BD}'), // Vulgar fraction one half + (190, '\u{00BE}'), // Vulgar fraction three quarters + (191, '\u{00BF}'), // Inverted question mark + (192, '\u{00C0}'), // Latin capital letter A with grave + (193, '\u{00C1}'), // Latin capital letter A with acute + (194, '\u{00C2}'), // Latin capital letter A with circumflex + (195, '\u{00C3}'), // Latin capital letter A with tilde + (196, '\u{00C4}'), // Latin capital letter A with diaeresis + (197, '\u{00C5}'), // Latin capital letter A with ring above + (198, '\u{00C6}'), // Latin capital letter AE + (199, '\u{00C7}'), // Latin capital letter C with cedilla + (200, '\u{00C8}'), // Latin capital letter E with grave + (201, '\u{00C9}'), // Latin capital letter E with acute + (202, '\u{00CA}'), // Latin capital letter E with circumflex + (203, '\u{00CB}'), // Latin capital letter E with diaeresis + (204, '\u{00CC}'), // Latin capital letter I with grave + (205, '\u{00CD}'), // Latin capital letter I with acute + (206, '\u{00CE}'), // Latin capital letter I with circumflex + (207, '\u{00CF}'), // Latin capital letter I with diaeresis + (208, '\u{00D0}'), // 
Latin capital letter Eth + (209, '\u{00D1}'), // Latin capital letter N with tilde + (210, '\u{00D2}'), // Latin capital letter O with grave + (211, '\u{00D3}'), // Latin capital letter O with acute + (212, '\u{00D4}'), // Latin capital letter O with circumflex + (213, '\u{00D5}'), // Latin capital letter O with tilde + (214, '\u{00D6}'), // Latin capital letter O with diaeresis + (215, '\u{00D7}'), // Multiplication sign + (216, '\u{00D8}'), // Latin capital letter O with stroke + (217, '\u{00D9}'), // Latin capital letter U with grave + (218, '\u{00DA}'), // Latin capital letter U with acute + (219, '\u{00DB}'), // Latin capital letter U with circumflex + (220, '\u{00DC}'), // Latin capital letter U with diaeresis + (221, '\u{00DD}'), // Latin capital letter Y with acute + (222, '\u{00DE}'), // Latin capital letter Thorn + (223, '\u{00DF}'), // Latin small letter sharp s + (224, '\u{00E0}'), // Latin small letter a with grave + (225, '\u{00E1}'), // Latin small letter a with acute + (226, '\u{00E2}'), // Latin small letter a with circumflex + (227, '\u{00E3}'), // Latin small letter a with tilde + (228, '\u{00E4}'), // Latin small letter a with diaeresis + (229, '\u{00E5}'), // Latin small letter a with ring above + (230, '\u{00E6}'), // Latin small letter ae + (231, '\u{00E7}'), // Latin small letter c with cedilla + (232, '\u{00E8}'), // Latin small letter e with grave + (233, '\u{00E9}'), // Latin small letter e with acute + (234, '\u{00EA}'), // Latin small letter e with circumflex + (235, '\u{00EB}'), // Latin small letter e with diaeresis + (236, '\u{00EC}'), // Latin small letter i with grave + (237, '\u{00ED}'), // Latin small letter i with acute + (238, '\u{00EE}'), // Latin small letter i with circumflex + (239, '\u{00EF}'), // Latin small letter i with diaeresis + (240, '\u{00F0}'), // Latin small letter eth + (241, '\u{00F1}'), // Latin small letter n with tilde + (242, '\u{00F2}'), // Latin small letter o with grave + (243, '\u{00F3}'), // Latin 
small letter o with acute + (244, '\u{00F4}'), // Latin small letter o with circumflex + (245, '\u{00F5}'), // Latin small letter o with tilde + (246, '\u{00F6}'), // Latin small letter o with diaeresis + (247, '\u{00F7}'), // Division sign + (248, '\u{00F8}'), // Latin small letter o with stroke + (249, '\u{00F9}'), // Latin small letter u with grave + (250, '\u{00FA}'), // Latin small letter u with acute + (251, '\u{00FB}'), // Latin small letter u with circumflex + (252, '\u{00FC}'), // Latin small letter u with diaeresis + (253, '\u{00FD}'), // Latin small letter y with acute + (254, '\u{00FE}'), // Latin small letter thorn + (255, '\u{00FF}'), // Latin small letter y with diaeresis + // Extended characters for bytes 0-32 and 127-160 + (0, '\u{0100}'), // Latin capital letter A with macron + (1, '\u{0101}'), // Latin small letter a with macron + (2, '\u{0102}'), // Latin capital letter A with breve + (3, '\u{0103}'), // Latin small letter a with breve + (4, '\u{0104}'), // Latin capital letter A with ogonek + (5, '\u{0105}'), // Latin small letter a with ogonek + (6, '\u{0106}'), // Latin capital letter C with acute + (7, '\u{0107}'), // Latin small letter c with acute + (8, '\u{0108}'), // Latin capital letter C with circumflex + (9, '\u{0109}'), // Latin small letter c with circumflex + (10, '\u{010A}'), // Latin capital letter C with dot above + (11, '\u{010B}'), // Latin small letter c with dot above + (12, '\u{010C}'), // Latin capital letter C with caron + (13, '\u{010D}'), // Latin small letter c with caron + (14, '\u{010E}'), // Latin capital letter D with caron + (15, '\u{010F}'), // Latin small letter d with caron + (16, '\u{0110}'), // Latin capital letter D with stroke + (17, '\u{0111}'), // Latin small letter d with stroke + (18, '\u{0112}'), // Latin capital letter E with macron + (19, '\u{0113}'), // Latin small letter e with macron + (20, '\u{0114}'), // Latin capital letter E with breve + (21, '\u{0115}'), // Latin small letter e with breve + 
(22, '\u{0116}'), // Latin capital letter E with dot above + (23, '\u{0117}'), // Latin small letter e with dot above + (24, '\u{0118}'), // Latin capital letter E with ogonek + (25, '\u{0119}'), // Latin small letter e with ogonek + (26, '\u{011A}'), // Latin capital letter E with caron + (27, '\u{011B}'), // Latin small letter e with caron + (28, '\u{011C}'), // Latin capital letter G with circumflex + (29, '\u{011D}'), // Latin small letter g with circumflex + (30, '\u{011E}'), // Latin capital letter G with breve + (31, '\u{011F}'), // Latin small letter g with breve + (32, '\u{0120}'), // Latin capital letter G with dot above + (127, '\u{0121}'), // Latin small letter g with dot above + (128, '\u{0122}'), // Latin capital letter G with cedilla + (129, '\u{0123}'), // Latin small letter g with cedilla + (130, '\u{0124}'), // Latin capital letter H with circumflex + (131, '\u{0125}'), // Latin small letter h with circumflex + (132, '\u{0126}'), // Latin capital letter H with stroke + (133, '\u{0127}'), // Latin small letter h with stroke + (134, '\u{0128}'), // Latin capital letter I with tilde + (135, '\u{0129}'), // Latin small letter i with tilde + (136, '\u{012A}'), // Latin capital letter I with macron + (137, '\u{012B}'), // Latin small letter i with macron + (138, '\u{012C}'), // Latin capital letter I with breve + (139, '\u{012D}'), // Latin small letter i with breve + (140, '\u{012E}'), // Latin capital letter I with ogonek + (141, '\u{012F}'), // Latin small letter i with ogonek + (142, '\u{0130}'), // Latin capital letter I with dot above + (143, '\u{0131}'), // Latin small letter dotless i + (144, '\u{0132}'), // Latin capital ligature IJ + (145, '\u{0133}'), // Latin small ligature ij + (146, '\u{0134}'), // Latin capital letter J with circumflex + (147, '\u{0135}'), // Latin small letter j with circumflex + (148, '\u{0136}'), // Latin capital letter K with cedilla + (149, '\u{0137}'), // Latin small letter k with cedilla + (150, '\u{0138}'), // 
Latin small letter kra + (151, '\u{0139}'), // Latin capital letter L with acute + (152, '\u{013A}'), // Latin small letter l with acute + (153, '\u{013B}'), // Latin capital letter L with cedilla + (154, '\u{013C}'), // Latin small letter l with cedilla + (155, '\u{013D}'), // Latin capital letter L with caron + (156, '\u{013E}'), // Latin small letter l with caron + (157, '\u{013F}'), // Latin capital letter L with middle dot + (158, '\u{0140}'), // Latin small letter l with middle dot + (159, '\u{0141}'), // Latin capital letter L with stroke + (160, '\u{0142}'), // Latin small letter l with stroke + (173, '\u{0143}'), // Latin capital letter N with acute +]; + +/// Regex pattern for tokenizing text. +/// Matches special tokens, contractions, letters, numbers, and other characters. +const TOKENIZER_PATTERN: &str = + r"<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"; + +/// Configuration for the CLIP tokenizer. +#[derive(Debug, Clone)] +pub struct SimpleTokenizerConfig { + /// Maximum sequence length (default: 77 for CLIP). + pub max_position_embeddings: usize, + /// Character to use for padding. If None, uses end-of-text token. + pub pad_with: Option, +} + +impl Default for SimpleTokenizerConfig { + fn default() -> Self { + Self { + max_position_embeddings: 77, + pad_with: None, + } + } +} + +impl SimpleTokenizerConfig { + /// Create config for Stable Diffusion v1.5. + pub fn v1_5() -> Self { + Self { + max_position_embeddings: 77, + pad_with: None, + } + } + + /// Create config for Stable Diffusion v2.1. + pub fn v2_1() -> Self { + Self { + max_position_embeddings: 77, + pad_with: Some("!".to_string()), + } + } +} + +/// A BPE tokenizer for CLIP text encoding. +/// +/// This tokenizer converts text into token IDs that can be processed +/// by the CLIP text encoder. It uses byte-pair encoding (BPE) with +/// a vocabulary derived from the original CLIP model. 
+pub struct SimpleTokenizer { + regex: regex::Regex, + encoder: HashMap, + decoder: HashMap, + bpe_ranks: HashMap<(String, String), usize>, + start_of_text_token: usize, + end_of_text_token: usize, + config: SimpleTokenizerConfig, +} + +impl SimpleTokenizer { + /// Create a new tokenizer from a BPE vocabulary file. + /// + /// # Arguments + /// * `bpe_path` - Path to the BPE vocabulary file (e.g., bpe_simple_vocab_16e6.txt) + /// * `config` - Tokenizer configuration + /// + /// # Returns + /// A new tokenizer instance or an error if the file cannot be read. + pub fn new>( + bpe_path: P, + config: SimpleTokenizerConfig, + ) -> Result { + let file = std::fs::File::open(bpe_path)?; + let reader = std::io::BufReader::new(file); + + let bpe_lines: Result, _> = reader.lines().collect(); + let bpe_lines = bpe_lines?; + + // Parse BPE merges (skip header, take 49152 - 256 - 2 merges) + let merge_count = 49152 - 256 - 2; + let bpe_merges: Result, TokenizerError> = bpe_lines[1..=merge_count] + .iter() + .map(|line| { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() != 2 { + return Err(TokenizerError::InvalidFormat(format!( + "Expected 2 tokens, got {}: '{}'", + parts.len(), + line + ))); + } + Ok((parts[0].to_string(), parts[1].to_string())) + }) + .collect(); + let bpe_merges = bpe_merges?; + + // Build vocabulary + let mut vocab: Vec = Vec::new(); + + // Add base characters + for (_, c) in BYTES_TO_UNICODE.iter() { + vocab.push(c.to_string()); + } + + // Add base characters with end-of-word marker + for (_, c) in BYTES_TO_UNICODE.iter() { + vocab.push(format!("{}", c)); + } + + // Add BPE merges + for (first, second) in bpe_merges.iter() { + vocab.push(format!("{}{}", first, second)); + } + + // Add special tokens + let start_of_text_token = vocab.len(); + vocab.push("<|startoftext|>".to_string()); + let end_of_text_token = vocab.len(); + vocab.push("<|endoftext|>".to_string()); + + // Build encoder/decoder mappings + let encoder: HashMap = 
vocab + .iter() + .enumerate() + .map(|(i, v)| (v.clone(), i)) + .collect(); + let decoder: HashMap = + encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); + + // Build BPE ranks + let bpe_ranks: HashMap<(String, String), usize> = bpe_merges + .into_iter() + .enumerate() + .map(|(i, v)| (v, i)) + .collect(); + + let regex = regex::Regex::new(TOKENIZER_PATTERN)?; + + Ok(Self { + regex, + encoder, + decoder, + bpe_ranks, + start_of_text_token, + end_of_text_token, + config, + }) + } + + /// Get pairs of adjacent tokens in a word. + fn get_pairs(word: &[String]) -> HashSet<(String, String)> { + let mut pairs = HashSet::new(); + for i in 1..word.len() { + pairs.insert((word[i - 1].clone(), word[i].clone())); + } + pairs + } + + /// Apply BPE encoding to a single token. + fn bpe(&self, token: &str) -> Vec { + let mut word: Vec = token.chars().map(|c| c.to_string()).collect(); + + if word.is_empty() { + return Vec::new(); + } + + // Add end-of-word marker to last character + let last_idx = word.len() - 1; + word[last_idx] = format!("{}", word[last_idx]); + + // Iteratively merge pairs with lowest BPE rank + while word.len() > 1 { + let pairs = Self::get_pairs(&word); + + // Find pair with lowest rank + let best_pair = pairs + .iter() + .filter_map(|p| self.bpe_ranks.get(p).map(|rank| (rank, p))) + .min_by_key(|(rank, _)| *rank) + .map(|(_, p)| p.clone()); + + let (first, second) = match best_pair { + Some(p) => p, + None => break, + }; + + // Merge the pair + let mut new_word = Vec::new(); + let mut i = 0; + while i < word.len() { + if i + 1 < word.len() && word[i] == first && word[i + 1] == second { + new_word.push(format!("{}{}", first, second)); + i += 2; + } else { + new_word.push(word[i].clone()); + i += 1; + } + } + word = new_word; + } + + // Convert to token IDs + word.iter() + .filter_map(|w| self.encoder.get(w).copied()) + .collect() + } + + /// Encode text to token IDs with optional padding. 
+ /// + /// # Arguments + /// * `text` - The text to encode + /// * `pad_to` - If Some, pad the result to this length + /// + /// # Returns + /// A vector of token IDs. + pub fn encode_with_padding( + &self, + text: &str, + pad_to: Option, + ) -> Result, TokenizerError> { + let text = text.to_lowercase(); + let mut tokens = vec![self.start_of_text_token]; + + // Tokenize each match + for cap in self.regex.captures_iter(&text) { + if let Some(m) = cap.get(0) { + tokens.extend(self.bpe(m.as_str())); + } + } + + // Add end token + tokens.push(self.end_of_text_token); + + // Apply padding if requested + if let Some(target_len) = pad_to { + // Truncate if necessary (keep room for end token) + if tokens.len() > target_len { + tokens.truncate(target_len - 1); + tokens.push(self.end_of_text_token); + } + + // Pad to target length + let pad_token = match &self.config.pad_with { + None => self.end_of_text_token, + Some(pad_char) => self + .encoder + .get(pad_char) + .copied() + .ok_or_else(|| TokenizerError::UnknownPadChar(pad_char.clone()))?, + }; + + while tokens.len() < target_len { + tokens.push(pad_token); + } + } + + Ok(tokens) + } + + /// Encode text to token IDs, padding to max_position_embeddings. + /// + /// This is the main entry point for tokenization. + /// + /// # Arguments + /// * `text` - The text to encode + /// + /// # Returns + /// A vector of token IDs padded to the configured max length. + pub fn encode(&self, text: &str) -> Result, TokenizerError> { + self.encode_with_padding(text, Some(self.config.max_position_embeddings)) + } + + /// Decode token IDs back to text. + /// + /// # Arguments + /// * `tokens` - The token IDs to decode + /// + /// # Returns + /// The decoded text string. + pub fn decode(&self, tokens: &[usize]) -> String { + let text: String = tokens + .iter() + .filter_map(|t| self.decoder.get(t)) + .cloned() + .collect(); + text.replace("", " ") + } + + /// Get the start-of-text token ID. 
+ pub fn start_of_text_token(&self) -> usize { + self.start_of_text_token + } + + /// Get the end-of-text token ID. + pub fn end_of_text_token(&self) -> usize { + self.end_of_text_token + } + + /// Get the maximum sequence length. + pub fn max_length(&self) -> usize { + self.config.max_position_embeddings + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Note: These tests require the BPE vocabulary file to be present. + // The file can be downloaded from: + // https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz + + #[test] + fn test_tokenizer_config() { + let config = SimpleTokenizerConfig::v1_5(); + assert_eq!(config.max_position_embeddings, 77); + assert!(config.pad_with.is_none()); + + let config = SimpleTokenizerConfig::v2_1(); + assert_eq!(config.max_position_embeddings, 77); + assert_eq!(config.pad_with, Some("!".to_string())); + } +} diff --git a/src/transformers/mod.rs b/src/transformers/mod.rs index c044cac..a873d50 100644 --- a/src/transformers/mod.rs +++ b/src/transformers/mod.rs @@ -6,3 +6,9 @@ //! simple tokenization. pub mod clip; + +#[cfg(feature = "std")] +pub mod clip_tokenizer; + +#[cfg(feature = "std")] +pub use clip_tokenizer::{SimpleTokenizer, SimpleTokenizerConfig, TokenizerError};