diff --git a/Cargo.toml b/Cargo.toml
index 197a6fa..9b9c690 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,7 +5,7 @@ edition = "2021"
[features]
default = ["burn/default", "std"]
-std = ["burn/std"]
+std = ["burn/std", "regex", "thiserror", "burn-store", "safetensors", "memmap2"]
# Backend
accelerate = ["burn/accelerate"]
@@ -17,11 +17,26 @@ wgpu = ["burn/wgpu"]
[dependencies]
burn = { version = "0.20.1", default-features = false }
+burn-store = { version = "0.20.1", optional = true, features = ["safetensors"] }
num-traits = { version = "0.2.18", default-features = false }
+regex = { version = "1.10", optional = true }
+safetensors = { version = "0.4", optional = true }
+memmap2 = { version = "0.9", optional = true }
serde = { version = "1.0.197", default-features = false, features = [
"derive",
"alloc",
] }
+thiserror = { version = "1.0", optional = true }
-[patch.crates-io]
-#burn = { git = "https://github.com/tracel-ai/burn" }
+[[example]]
+name = "stable-diffusion"
+required-features = ["std"]
+
+[dev-dependencies]
+anyhow = "1.0"
+clap = { version = "4.4", features = ["derive", "env"] }
+dirs = "5.0"
+flate2 = "1.0"
+hf-hub = "0.3"
+image = "0.25"
+ureq = "2.9"
diff --git a/README.md b/README.md
index b78b31b..018bab7 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
-# diffusers-burn: A diffusers API in Rust/Burn
+# diffusers-burn
-> **⚠️ This is still in development - contributors welcome!**
+Stable Diffusion in Rust using [Burn](https://github.com/burn-rs/burn). Supports SD 1.5 and 2.1.
-The `diffusers-burn` crate is a conversion of [diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using [burn](https://github.com/burn-rs/burn) rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1, as well as Stable Diffusion XL 1.0.
+Based on [diffusers-rs](https://github.com/LaurentMazare/diffusers-rs).
-## Feature Flags
+## Quick Start
-This crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling
-the default `std` feature.
+```bash
+# Using wgpu backend (default, works on most GPUs)
+cargo run --release --features wgpu --example stable-diffusion -- \
+ --prompt "A photo of a rusty robot on a beach"
-* `std` - enables the standard library. Enabled by default.
-* `wgpu` - uses ndarray as the backend. Enabled by default when none specified and `std`.
-* `ndarray` - uses ndarray as the backend.
-* `ndarray-no-std` - uses ndarray-no-std as the backend. Enabled by default when none and `#![no_std]`.
-* `ndarray-blas-accelerate` - uses ndarray with Accelerate framework (macOS only).
-* `torch` - uses torch as the backend.
+# Using torch backend
+cargo run --release --features torch --example stable-diffusion -- \
+ --prompt "A photo of a rusty robot on a beach"
-## Community
+# SD 2.1 at 768x768
+cargo run --release --features torch --example stable-diffusion -- \
+ --sd-version v2-1 \
+ --prompt "A majestic lion on a cliff at sunset"
+```
+
+## Backends
+
+| Feature | Backend | Notes |
+|---------|---------|-------|
+| `wgpu` | WebGPU | Cross-platform GPU support |
+| `torch` | LibTorch | Requires libtorch |
+| `ndarray` | ndarray | CPU only, pure Rust |
-If you are excited about the project or want to contribute, don't hesitate to join our [Discord](https://discord.gg/UHtSgF6j5J)!
-We try to be as welcoming as possible to everybody from any background. We're still building this out, but you can ask your questions there!
+## no_std Support
-## Status
+This crate supports `#![no_std]` with `alloc` by disabling the default `std` feature.
+
+## Community
-diffusers-burn is currently in active development, and is not yet complete.
+Join our [Discord](https://discord.gg/UHtSgF6j5J) if you want to contribute or have questions!
## License
-diffusers-burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0).
-See [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for details. Opening a pull
-request is assumed to signal agreement with these licensing terms.
+MIT or Apache-2.0, at your option.
diff --git a/examples/stable-diffusion/Cargo.toml b/examples/stable-diffusion/Cargo.toml
new file mode 100644
index 0000000..5e4d89f
--- /dev/null
+++ b/examples/stable-diffusion/Cargo.toml
@@ -0,0 +1,25 @@
+[package]
+name = "stable-diffusion-example"
+version = "0.1.0"
+edition = "2021"
+
+[[bin]]
+name = "stable-diffusion"
+path = "main.rs"
+
+[features]
+default = ["wgpu"]
+torch = ["diffusers-burn/torch", "burn/tch"]
+ndarray = ["diffusers-burn/ndarray", "burn/ndarray"]
+wgpu = ["diffusers-burn/wgpu", "burn/wgpu"]
+
+[dependencies]
+diffusers-burn = { path = "../..", features = ["std"], default-features = false }
+burn = { version = "0.20.1", default-features = false }
+clap = { version = "4", features = ["derive"] }
+image = "0.25"
+anyhow = "1.0"
+hf-hub = "0.4"
+ureq = "2"
+flate2 = "1"
+dirs = "6"
diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs
new file mode 100644
index 0000000..50a5dbe
--- /dev/null
+++ b/examples/stable-diffusion/main.rs
@@ -0,0 +1,367 @@
+#![recursion_limit = "256"]
+// Stable Diffusion Example
+//
+// This example generates images from text prompts using Stable Diffusion.
+// Weights are automatically downloaded from Hugging Face Hub on first run.
+//
+// # Usage
+//
+// cargo run --release -- --prompt "a photo of a cat"
+//
+// The BPE vocabulary file will be downloaded automatically if not present.
+
+use std::fs;
+use std::io::Read;
+use std::path::PathBuf;
+
+use clap::{Parser, ValueEnum};
+
+use burn::tensor::Tensor;
+use hf_hub::api::sync::Api;
+
+use diffusers_burn::pipelines::stable_diffusion::{
+ generate_image_ddim, StableDiffusion, StableDiffusionConfig,
+};
+use diffusers_burn::pipelines::weights::{
+ load_clip_safetensors, load_unet_safetensors, load_vae_safetensors,
+};
+use diffusers_burn::transformers::{SimpleTokenizer, SimpleTokenizerConfig};
+
+const GUIDANCE_SCALE: f64 = 7.5;
+
+#[cfg(feature = "wgpu")]
+type Backend = burn::backend::Wgpu;
+
+#[cfg(feature = "torch")]
+type Backend = burn::backend::LibTorch;
+
+#[cfg(feature = "ndarray")]
+type Backend = burn::backend::NdArray;
+
+#[derive(Debug, Clone, Copy, ValueEnum)]
+enum StableDiffusionVersion {
+ #[value(name = "v1-5")]
+ V1_5,
+ #[value(name = "v2-1")]
+ V2_1,
+}
+
+impl StableDiffusionVersion {
+ fn repo_id(&self) -> &'static str {
+ match self {
+ StableDiffusionVersion::V1_5 => "runwayml/stable-diffusion-v1-5",
+ // Use community repo (ungated) instead of stabilityai/stable-diffusion-2-1 (gated)
+ StableDiffusionVersion::V2_1 => "sd2-community/stable-diffusion-2-1",
+ }
+ }
+
+ fn clip_repo_id(&self) -> &'static str {
+ match self {
+ // SD 1.5 uses OpenAI's CLIP
+ StableDiffusionVersion::V1_5 => "openai/clip-vit-large-patch14",
+ // SD 2.1 community repo has CLIP in text_encoder subdirectory
+ StableDiffusionVersion::V2_1 => "sd2-community/stable-diffusion-2-1",
+ }
+ }
+
+ fn clip_weights_file(&self) -> &'static str {
+ match self {
+ // SD 1.5 uses standalone CLIP model
+ StableDiffusionVersion::V1_5 => "model.safetensors",
+ // SD 2.1 has CLIP in text_encoder subdirectory
+ StableDiffusionVersion::V2_1 => "text_encoder/model.safetensors",
+ }
+ }
+
+ fn tokenizer_config(&self) -> SimpleTokenizerConfig {
+ match self {
+ StableDiffusionVersion::V1_5 => SimpleTokenizerConfig::v1_5(),
+ StableDiffusionVersion::V2_1 => SimpleTokenizerConfig::v2_1(),
+ }
+ }
+}
+
+#[derive(Parser)]
+#[command(author, version, about = "Generate images with Stable Diffusion", long_about = None)]
+struct Args {
+ /// The prompt to be used for image generation.
+ #[arg(
+ long,
+ default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
+ )]
+ prompt: String,
+
+ /// The negative prompt (what to avoid in the image).
+ #[arg(long, default_value = "")]
+ negative_prompt: String,
+
+ /// The height in pixels of the generated image.
+ #[arg(long)]
+    height: Option<usize>,
+
+ /// The width in pixels of the generated image.
+ #[arg(long)]
+    width: Option<usize>,
+
+ /// The UNet weight file, in .safetensors format (auto-downloaded if not specified).
+ #[arg(long, value_name = "FILE")]
+    unet_weights: Option<String>,
+
+ /// The CLIP weight file, in .safetensors format (auto-downloaded if not specified).
+ #[arg(long, value_name = "FILE")]
+    clip_weights: Option<String>,
+
+ /// The VAE weight file, in .safetensors format (auto-downloaded if not specified).
+ #[arg(long, value_name = "FILE")]
+    vae_weights: Option<String>,
+
+ /// The file specifying the vocabulary to use for tokenization (auto-downloaded if not specified).
+ #[arg(long, value_name = "FILE")]
+    vocab_file: Option<String>,
+
+ /// The number of steps to run the diffusion for.
+ #[arg(long, default_value_t = 30)]
+ n_steps: usize,
+
+ /// The random seed to be used for the generation.
+ #[arg(long, default_value_t = 32)]
+ seed: u64,
+
+ /// The name of the final image to generate.
+ #[arg(long, value_name = "FILE", default_value = "output.png")]
+ output: String,
+
+ /// The Stable Diffusion version to use.
+ #[arg(long, value_enum, default_value = "v1-5")]
+ sd_version: StableDiffusionVersion,
+
+ /// Hugging Face API token for gated models (e.g., SD 2.1).
+ /// Can also be set via HF_TOKEN environment variable.
+ #[arg(long, env = "HF_TOKEN")]
+    hf_token: Option<String>,
+}
+
+/// Downloads a file from Hugging Face Hub if not already cached.
+fn download_hf_file(repo_id: &str, filename: &str, token: Option<&str>) -> anyhow::Result<PathBuf> {
+ println!(" Downloading {} from {}...", filename, repo_id);
+
+ let api = match token {
+ Some(t) => hf_hub::api::sync::ApiBuilder::new()
+ .with_token(Some(t.to_string()))
+ .build()?,
+ None => Api::new()?,
+ };
+
+ let repo = api.model(repo_id.to_string());
+ match repo.get(filename) {
+ Ok(path) => {
+ println!(" Cached at: {}", path.display());
+ Ok(path)
+ }
+ Err(e) => {
+ // Provide helpful error message for 401 errors
+ let err_str = e.to_string();
+ if err_str.contains("401") {
+ anyhow::bail!(
+ "Authentication required for {}.\n\
+ This model requires accepting the license at https://huggingface.co/{}\n\
+ Then provide your HF token via --hf-token or HF_TOKEN environment variable.\n\
+ Get your token at: https://huggingface.co/settings/tokens",
+ repo_id,
+ repo_id
+ );
+ }
+ Err(e.into())
+ }
+ }
+}
+
+/// Downloads the BPE vocabulary file from OpenAI's GitHub.
+fn download_bpe_vocab() -> anyhow::Result<PathBuf> {
+ // Cache in HF cache directory for consistency
+ let cache_dir = dirs::cache_dir()
+ .unwrap_or_else(|| PathBuf::from("."))
+ .join("huggingface")
+ .join("clip");
+ fs::create_dir_all(&cache_dir)?;
+
+ let vocab_path = cache_dir.join("bpe_simple_vocab_16e6.txt");
+
+ if vocab_path.exists() {
+ println!(" Using cached vocabulary at: {}", vocab_path.display());
+ return Ok(vocab_path);
+ }
+
+ println!(" Downloading BPE vocabulary from OpenAI GitHub...");
+ let url = "https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz";
+
+ // Download gzipped file
+ let response = ureq::get(url).call()?;
+ let mut gz_data = Vec::new();
+ response.into_reader().read_to_end(&mut gz_data)?;
+
+ // Decompress
+ let mut decoder = flate2::read::GzDecoder::new(&gz_data[..]);
+ let mut content = String::new();
+ decoder.read_to_string(&mut content)?;
+
+ // Save to cache
+ fs::write(&vocab_path, &content)?;
+ println!(" Cached at: {}", vocab_path.display());
+
+ Ok(vocab_path)
+}
+
+fn tensor_to_image(tensor: Tensor<Backend, 4>) -> image::RgbImage {
+ // tensor shape: [1, 3, height, width], values in [0, 1]
+ let [_, _, height, width] = tensor.dims();
+
+ // Convert to [0, 255] range
+ let tensor = tensor * 255.0;
+    let data: Vec<f32> = tensor.into_data().to_vec().unwrap();
+
+ // Create image buffer
+ let mut img = image::RgbImage::new(width as u32, height as u32);
+
+ for y in 0..height {
+ for x in 0..width {
+ // Round and clamp to [0, 255] for proper u8 conversion (matches PyTorch's to_kind(Uint8))
+ let r = data[0 * height * width + y * width + x]
+ .round()
+ .clamp(0.0, 255.0) as u8;
+ let g = data[1 * height * width + y * width + x]
+ .round()
+ .clamp(0.0, 255.0) as u8;
+ let b = data[2 * height * width + y * width + x]
+ .round()
+ .clamp(0.0, 255.0) as u8;
+ img.put_pixel(x as u32, y as u32, image::Rgb([r, g, b]));
+ }
+ }
+
+ img
+}
+
+fn run(args: Args) -> anyhow::Result<()> {
+ let device = Default::default();
+
+ // Build configuration
+ println!("Building configuration for {:?}...", args.sd_version);
+ let sd_config = match args.sd_version {
+ StableDiffusionVersion::V1_5 => StableDiffusionConfig::v1_5(None, args.height, args.width),
+ StableDiffusionVersion::V2_1 => StableDiffusionConfig::v2_1(None, args.height, args.width),
+ };
+
+ // Download or use provided vocab file
+ // The tokenizer expects the OpenAI CLIP BPE vocabulary format
+ println!("\nPreparing vocabulary file...");
+ let vocab_path = match &args.vocab_file {
+ Some(path) => PathBuf::from(path),
+ None => download_bpe_vocab()?,
+ };
+
+ // Build tokenizer
+ println!("\nLoading tokenizer from {}...", vocab_path.display());
+ let tokenizer = SimpleTokenizer::new(&vocab_path, args.sd_version.tokenizer_config())?;
+
+ // Tokenize prompts
+ println!("Tokenizing prompt: \"{}\"", args.prompt);
+ let tokens = tokenizer.encode(&args.prompt)?;
+ let uncond_tokens = tokenizer.encode(&args.negative_prompt)?;
+ println!(" Prompt tokens: {} tokens", tokens.len());
+ println!(" Negative prompt tokens: {} tokens", uncond_tokens.len());
+
+ // Download weights
+ println!("\nPreparing model weights...");
+ let hf_token = args.hf_token.as_deref();
+
+ let clip_weights = match &args.clip_weights {
+ Some(path) => PathBuf::from(path),
+ None => download_hf_file(
+ args.sd_version.clip_repo_id(),
+ args.sd_version.clip_weights_file(),
+ hf_token,
+ )?,
+ };
+
+ let vae_weights = match &args.vae_weights {
+ Some(path) => PathBuf::from(path),
+ None => download_hf_file(
+ args.sd_version.repo_id(),
+ "vae/diffusion_pytorch_model.safetensors",
+ hf_token,
+ )?,
+ };
+
+ let unet_weights = match &args.unet_weights {
+ Some(path) => PathBuf::from(path),
+ None => download_hf_file(
+ args.sd_version.repo_id(),
+ "unet/diffusion_pytorch_model.safetensors",
+ hf_token,
+ )?,
+ };
+
+ // Build scheduler
+ println!("\nBuilding DDIM scheduler with {} steps...", args.n_steps);
+    let scheduler = sd_config.build_ddim_scheduler::<Backend>(args.n_steps, &device);
+
+ // Build models
+ println!("Building CLIP text encoder...");
+    let clip = sd_config.build_clip_transformer::<Backend>(&device);
+
+ println!("Building VAE...");
+    let vae = sd_config.build_vae::<Backend>(&device);
+
+ println!("Building UNet...");
+    let unet = sd_config.build_unet::<Backend>(&device, 4);
+
+ // Load weights
+ println!("\nLoading CLIP weights...");
+    let clip = load_clip_safetensors::<Backend>(clip, &clip_weights, &device)?;
+
+ println!("Loading VAE weights...");
+    let vae = load_vae_safetensors::<Backend>(vae, &vae_weights, &device)?;
+
+ println!("Loading UNet weights...");
+    let unet = load_unet_safetensors::<Backend>(unet, &unet_weights, &device)?;
+
+ // Assemble pipeline
+ let pipeline = StableDiffusion {
+ clip,
+ vae,
+ unet,
+ width: sd_config.width,
+ height: sd_config.height,
+ };
+
+ // Generate image
+ println!("\nGenerating image...");
+ println!(" Size: {}x{}", sd_config.width, sd_config.height);
+ println!(" Steps: {}", args.n_steps);
+ println!(" Guidance scale: {}", GUIDANCE_SCALE);
+ println!(" Seed: {}", args.seed);
+
+ let image_tensor = generate_image_ddim(
+ &pipeline,
+ &scheduler,
+ &tokens,
+ &uncond_tokens,
+ GUIDANCE_SCALE,
+ args.seed,
+ &device,
+ );
+
+ // Save image
+ println!("\nSaving image to {}...", args.output);
+ let img = tensor_to_image(image_tensor);
+ img.save(&args.output)?;
+
+ println!("Done!");
+ Ok(())
+}
+
+fn main() -> anyhow::Result<()> {
+ let args = Args::parse();
+ run(args)
+}
diff --git a/src/lib.rs b/src/lib.rs
index 1e2bea9..8e12c68 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,6 +6,7 @@
pub mod models;
pub mod pipelines;
+pub mod schedulers;
pub mod transformers;
pub mod utils;
diff --git a/src/models/attention.rs b/src/models/attention.rs
index 2326fb5..254a753 100644
--- a/src/models/attention.rs
+++ b/src/models/attention.rs
@@ -154,11 +154,9 @@ impl CrossAttention {
     fn reshape_batch_dim_to_heads(&self, xs: Tensor<B, 3>) -> Tensor<B, 3> {
let [batch_size, seq_len, dim] = xs.dims();
- let output = xs
- .reshape([batch_size / self.n_heads, self.n_heads, seq_len, dim])
+ xs.reshape([batch_size / self.n_heads, self.n_heads, seq_len, dim])
.swap_dims(1, 2)
- .reshape([batch_size / self.n_heads, seq_len, dim * self.n_heads]);
- output
+ .reshape([batch_size / self.n_heads, seq_len, dim * self.n_heads])
}
fn sliced_attention(
@@ -314,7 +312,13 @@ impl Proj {
     fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
match self {
Proj::Conv2d(conv) => conv.forward(xs),
- Proj::Linear(linear) => linear.forward(xs),
+ Proj::Linear(linear) => {
+ // For linear projection, we need to permute from [batch, channels, h, w]
+ // to [batch, h, w, channels], apply linear, then permute back
+ let xs = xs.swap_dims(1, 2).swap_dims(2, 3); // [batch, h, w, channels]
+ let xs = linear.forward(xs);
+ xs.swap_dims(2, 3).swap_dims(1, 2) // [batch, channels, h, w]
+ }
}
}
}
@@ -325,7 +329,7 @@ pub struct SpatialTransformer {
     norm: GroupNorm<B>,
     proj_in: Proj<B>,
     transformer_blocks: Vec<BasicTransformerBlock<B>>,
-    proj_out: nn::conv::Conv2d<B>,
+    proj_out: Proj<B>,
}
impl SpatialTransformerConfig {
@@ -352,8 +356,13 @@ impl SpatialTransformerConfig {
transformer_blocks.push(tb)
}
- let proj_out =
- nn::conv::Conv2dConfig::new([d_inner, self.in_channels], [1, 1]).init(device);
+ let proj_out = if self.use_linear_projection {
+ Proj::Linear(nn::LinearConfig::new(d_inner, self.in_channels).init(device))
+ } else {
+ Proj::Conv2d(
+ nn::conv::Conv2dConfig::new([d_inner, self.in_channels], [1, 1]).init(device),
+ )
+ };
SpatialTransformer {
norm,
@@ -655,4 +664,32 @@ mod tests {
Tolerance::rel_abs(1e-3, 1e-3),
)
}
+
+ /// Test GELU activation matches diffusers-rs (tch gelu("none"))
+ /// Reference values from diffusers-rs v0.3.1
+ #[test]
+ fn test_gelu_matches_diffusers_rs() {
+ let device = Default::default();
+        let xs: Tensor<TestBackend, 1> = Tensor::from_data(
+ TensorData::from([-2.0f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]),
+ &device,
+ );
+
+ let result = gelu(xs);
+
+ // Reference values from diffusers-rs: gelu("none")
+ // [-0.04550028, -0.15865526, -0.15426877, 0.0, 0.34573123, 0.8413447, 1.9544997]
+        result.into_data().assert_approx_eq::<f32>(
+ &TensorData::from([
+ -0.04550028,
+ -0.15865526,
+ -0.15426877,
+ 0.0,
+ 0.34573123,
+ 0.8413447,
+ 1.9544997,
+ ]),
+ Tolerance::rel_abs(1e-4, 1e-4),
+ );
+ }
}
diff --git a/src/models/controlnet.rs b/src/models/controlnet.rs
new file mode 100644
index 0000000..c0023a0
--- /dev/null
+++ b/src/models/controlnet.rs
@@ -0,0 +1,582 @@
+//! ControlNet Model
+//!
+//! ControlNet is a neural network structure to control diffusion models by adding extra conditions.
+//! https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/controlnet.py
+
+use alloc::vec;
+use alloc::vec::Vec;
+use burn::config::Config;
+use burn::module::Module;
+use burn::nn::conv::{Conv2d, Conv2dConfig};
+use burn::nn::PaddingConfig2d;
+use burn::tensor::activation::silu;
+use burn::tensor::backend::Backend;
+use burn::tensor::Tensor;
+
+use super::embeddings::{get_timestep_embedding, TimestepEmbedding, TimestepEmbeddingConfig};
+use super::unet_2d::{BlockConfig, UNetDownBlock};
+use super::unet_2d_blocks::{
+ CrossAttnDownBlock2DConfig, DownBlock2DConfig, UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DCrossAttnConfig,
+};
+
+/// ControlNet conditioning embedding module.
+///
+/// Processes the conditioning image (e.g., canny edges, depth map) into
+/// an embedding that can be added to the UNet features.
+#[derive(Module, Debug)]
+pub struct ControlNetConditioningEmbedding<B: Backend> {
+    conv_in: Conv2d<B>,
+    conv_out: Conv2d<B>,
+    blocks: Vec<(Conv2d<B>, Conv2d<B>)>,
+
+/// Configuration for ControlNetConditioningEmbedding.
+#[derive(Config, Debug)]
+pub struct ControlNetConditioningEmbeddingConfig {
+ /// Output channels for the conditioning embedding.
+ conditioning_embedding_channels: usize,
+ /// Input channels from the conditioning image.
+ #[config(default = 3)]
+ conditioning_channels: usize,
+ /// Channel progression for the embedding blocks.
+    block_out_channels: Vec<usize>,
+}
+
+impl ControlNetConditioningEmbeddingConfig {
+ /// Initialize the ControlNetConditioningEmbedding module.
+    pub fn init<B: Backend>(&self, device: &B::Device) -> ControlNetConditioningEmbedding<B> {
+ let b_channels = self.block_out_channels[0];
+ let bl_channels = *self.block_out_channels.last().unwrap();
+
+ let conv_in = Conv2dConfig::new([self.conditioning_channels, b_channels], [3, 3])
+ .with_padding(PaddingConfig2d::Explicit(1, 1))
+ .init(device);
+
+ let conv_out =
+ Conv2dConfig::new([bl_channels, self.conditioning_embedding_channels], [3, 3])
+ .with_padding(PaddingConfig2d::Explicit(1, 1))
+ .init(device);
+
+ let blocks = (0..self.block_out_channels.len() - 1)
+ .map(|i| {
+ let channel_in = self.block_out_channels[i];
+ let channel_out = self.block_out_channels[i + 1];
+
+ let c1 = Conv2dConfig::new([channel_in, channel_in], [3, 3])
+ .with_padding(PaddingConfig2d::Explicit(1, 1))
+ .init(device);
+
+ let c2 = Conv2dConfig::new([channel_in, channel_out], [3, 3])
+ .with_stride([2, 2])
+ .with_padding(PaddingConfig2d::Explicit(1, 1))
+ .init(device);
+
+ (c1, c2)
+ })
+ .collect();
+
+ ControlNetConditioningEmbedding {
+ conv_in,
+ conv_out,
+ blocks,
+ }
+ }
+}
+
+impl<B: Backend> ControlNetConditioningEmbedding<B> {
+    /// Forward pass through the conditioning embedding.
+    pub fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+ let mut xs = silu(self.conv_in.forward(xs));
+
+ for (c1, c2) in &self.blocks {
+ xs = silu(c1.forward(xs));
+ xs = silu(c2.forward(xs));
+ }
+
+ self.conv_out.forward(xs)
+ }
+}
+
+/// Configuration for the ControlNet model.
+#[derive(Config, Debug)]
+pub struct ControlNetConfig {
+ /// Whether to flip sin to cos in timestep embedding.
+ #[config(default = true)]
+ pub flip_sin_to_cos: bool,
+ /// Frequency shift for timestep embedding.
+ #[config(default = 0.0)]
+ pub freq_shift: f64,
+ /// Configuration for each block.
+    pub blocks: Vec<BlockConfig>,
+ /// Output channels for conditioning embedding blocks.
+    pub conditioning_embedding_out_channels: Vec<usize>,
+ /// Number of ResNet layers per block.
+ #[config(default = 2)]
+ pub layers_per_block: usize,
+ /// Padding for downsampling convolutions.
+ #[config(default = 1)]
+ pub downsample_padding: usize,
+ /// Scale factor for mid block.
+ #[config(default = 1.0)]
+ pub mid_block_scale_factor: f64,
+ /// Number of groups for group normalization.
+ #[config(default = 32)]
+ pub norm_num_groups: usize,
+ /// Epsilon for normalization layers.
+ #[config(default = 1e-5)]
+ pub norm_eps: f64,
+ /// Dimension of cross-attention context.
+ #[config(default = 768)]
+ pub cross_attention_dim: usize,
+ /// Whether to use linear projection in attention.
+ #[config(default = false)]
+ pub use_linear_projection: bool,
+}
+
+impl Default for ControlNetConfig {
+ fn default() -> Self {
+ Self {
+ flip_sin_to_cos: true,
+ freq_shift: 0.0,
+ blocks: vec![
+ BlockConfig::new(320)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ BlockConfig::new(640)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ BlockConfig::new(1280)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ BlockConfig::new(1280)
+ .with_use_cross_attn(false)
+ .with_attention_head_dim(8),
+ ],
+ conditioning_embedding_out_channels: vec![16, 32, 96, 256],
+ layers_per_block: 2,
+ downsample_padding: 1,
+ mid_block_scale_factor: 1.0,
+ norm_num_groups: 32,
+ norm_eps: 1e-5,
+ cross_attention_dim: 768,
+ use_linear_projection: false,
+ }
+ }
+}
+
+/// ControlNet model for adding spatial conditioning to diffusion models.
+///
+/// ControlNet copies the weights of a pretrained UNet and adds zero convolution
+/// layers to inject conditioning information. The output consists of residuals
+/// that are added to the corresponding layers of the UNet.
+#[derive(Module, Debug)]
+pub struct ControlNet<B: Backend> {
+    conv_in: Conv2d<B>,
+    controlnet_mid_block: Conv2d<B>,
+    controlnet_cond_embedding: ControlNetConditioningEmbedding<B>,
+    time_embedding: TimestepEmbedding<B>,
+    down_blocks: Vec<UNetDownBlock<B>>,
+    controlnet_down_blocks: Vec<Conv2d<B>>,
+    mid_block: UNetMidBlock2DCrossAttn<B>,
+ #[module(skip)]
+ time_proj_channels: usize,
+ #[module(skip)]
+ flip_sin_to_cos: bool,
+ #[module(skip)]
+ freq_shift: f64,
+}
+
+impl ControlNetConfig {
+    /// Initialize the ControlNet model.
+    pub fn init<B: Backend>(&self, in_channels: usize, device: &B::Device) -> ControlNet<B> {
+ let n_blocks = self.blocks.len();
+ let b_channels = self.blocks[0].out_channels;
+ let bl_channels = self.blocks.last().unwrap().out_channels;
+ let bl_attention_head_dim = self.blocks.last().unwrap().attention_head_dim;
+ let time_embed_dim = b_channels * 4;
+
+ // Time embeddings
+ let time_embedding = TimestepEmbeddingConfig::new(b_channels, time_embed_dim).init(device);
+
+ // Input convolution
+ let conv_in = Conv2dConfig::new([in_channels, b_channels], [3, 3])
+ .with_stride([1, 1])
+ .with_padding(PaddingConfig2d::Explicit(1, 1))
+ .init(device);
+
+ // ControlNet mid block (1x1 conv, zero initialized in practice)
+ let controlnet_mid_block =
+ Conv2dConfig::new([bl_channels, bl_channels], [1, 1]).init(device);
+
+ // Conditioning embedding
+ let controlnet_cond_embedding = ControlNetConditioningEmbeddingConfig::new(
+ b_channels,
+ self.conditioning_embedding_out_channels.clone(),
+ )
+ .init(device);
+
+ // Down blocks
+        let down_blocks: Vec<UNetDownBlock<B>> = (0..n_blocks)
+ .map(|i| {
+ let block_config = &self.blocks[i];
+ let out_channels = block_config.out_channels;
+ let attention_head_dim = block_config.attention_head_dim;
+
+ let in_ch = if i > 0 {
+ self.blocks[i - 1].out_channels
+ } else {
+ b_channels
+ };
+
+ let db_config = DownBlock2DConfig::new(in_ch, out_channels)
+ .with_temb_channels(Some(time_embed_dim))
+ .with_n_layers(self.layers_per_block)
+ .with_resnet_eps(self.norm_eps)
+ .with_resnet_groups(self.norm_num_groups)
+ .with_add_downsample(i < n_blocks - 1)
+ .with_downsample_padding(self.downsample_padding);
+
+ if block_config.use_cross_attn {
+ let config = CrossAttnDownBlock2DConfig::new(in_ch, out_channels, db_config)
+ .with_temb_channels(Some(time_embed_dim))
+ .with_attn_num_head_channels(attention_head_dim)
+ .with_cross_attention_dim(self.cross_attention_dim)
+ .with_use_linear_projection(self.use_linear_projection);
+ UNetDownBlock::CrossAttn(config.init(device))
+ } else {
+ UNetDownBlock::Basic(db_config.init(device))
+ }
+ })
+ .collect();
+
+ // Mid block
+ let mid_config = UNetMidBlock2DCrossAttnConfig::new(bl_channels)
+ .with_temb_channels(Some(time_embed_dim))
+ .with_resnet_eps(self.norm_eps)
+ .with_output_scale_factor(self.mid_block_scale_factor)
+ .with_cross_attn_dim(self.cross_attention_dim)
+ .with_attn_num_head_channels(bl_attention_head_dim)
+ .with_resnet_groups(Some(self.norm_num_groups))
+ .with_use_linear_projection(self.use_linear_projection);
+ let mid_block = mid_config.init(device);
+
+ // ControlNet down blocks (1x1 convs, zero initialized in practice)
+ let mut controlnet_down_blocks =
+ vec![Conv2dConfig::new([b_channels, b_channels], [1, 1]).init(device)];
+
+ for (i, block) in self.blocks.iter().enumerate() {
+ let out_channels = block.out_channels;
+ for _ in 0..self.layers_per_block {
+ controlnet_down_blocks
+ .push(Conv2dConfig::new([out_channels, out_channels], [1, 1]).init(device));
+ }
+ if i + 1 != self.blocks.len() {
+ controlnet_down_blocks
+ .push(Conv2dConfig::new([out_channels, out_channels], [1, 1]).init(device));
+ }
+ }
+
+ ControlNet {
+ conv_in,
+ controlnet_mid_block,
+ controlnet_cond_embedding,
+ time_embedding,
+ down_blocks,
+ controlnet_down_blocks,
+ mid_block,
+ time_proj_channels: b_channels,
+ flip_sin_to_cos: self.flip_sin_to_cos,
+ freq_shift: self.freq_shift,
+ }
+ }
+}
+
+impl<B: Backend> ControlNet<B> {
+ /// Forward pass through ControlNet.
+ ///
+ /// # Arguments
+ /// * `xs` - Noisy input tensor [batch, channels, height, width]
+ /// * `timestep` - Current diffusion timestep
+ /// * `encoder_hidden_states` - Encoder hidden states for cross-attention [batch, seq_len, dim]
+ /// * `controlnet_cond` - Conditioning image [batch, 3, height, width]
+ /// * `conditioning_scale` - Scale factor for the conditioning (typically 1.0)
+ ///
+ /// # Returns
+ /// A tuple of (down_block_residuals, mid_block_residual) to be added to the UNet.
+    pub fn forward(
+        &self,
+        xs: Tensor<B, 4>,
+        timestep: f64,
+        encoder_hidden_states: Tensor<B, 3>,
+        controlnet_cond: Tensor<B, 4>,
+        conditioning_scale: f64,
+    ) -> (Vec<Tensor<B, 4>>, Tensor<B, 4>) {
+ let [bsize, _channels, _height, _width] = xs.dims();
+ let device = xs.device();
+
+ // 1. Time embedding
+        let timesteps: Tensor<B, 1> = Tensor::full([bsize], timestep as f32, &device);
+ let emb = get_timestep_embedding(
+ timesteps,
+ self.time_proj_channels,
+ self.flip_sin_to_cos,
+ self.freq_shift,
+ );
+ let emb = self.time_embedding.forward(emb);
+
+ // 2. Pre-process
+ let xs = self.conv_in.forward(xs);
+ let controlnet_cond = self.controlnet_cond_embedding.forward(controlnet_cond);
+ let xs = xs + controlnet_cond;
+
+ // 3. Down blocks
+ let mut down_block_res_xs = vec![xs.clone()];
+ let mut xs = xs;
+ for down_block in self.down_blocks.iter() {
+ let (new_xs, res_xs) = match down_block {
+ UNetDownBlock::Basic(b) => b.forward(xs, Some(emb.clone())),
+ UNetDownBlock::CrossAttn(b) => {
+ b.forward(xs, Some(emb.clone()), Some(encoder_hidden_states.clone()))
+ }
+ };
+ down_block_res_xs.extend(res_xs);
+ xs = new_xs;
+ }
+
+ // 4. Mid block
+ let xs = self
+ .mid_block
+ .forward(xs, Some(emb.clone()), Some(encoder_hidden_states.clone()));
+
+ // 5. ControlNet blocks - apply 1x1 convs and scale
+        let controlnet_down_block_res_xs: Vec<Tensor<B, 4>> = self
+ .controlnet_down_blocks
+ .iter()
+ .enumerate()
+ .map(|(i, block)| block.forward(down_block_res_xs[i].clone()) * conditioning_scale)
+ .collect();
+
+ let mid_block_res = self.controlnet_mid_block.forward(xs) * conditioning_scale;
+
+ (controlnet_down_block_res_xs, mid_block_res)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::TestBackend;
+ use alloc::string::{String, ToString};
+ use burn::module::{Module, ModuleMapper, Param};
+ use burn::tensor::{Shape, TensorData};
+
+ #[test]
+ fn test_controlnet_conditioning_embedding_shape() {
+ let device = Default::default();
+
+ let config = ControlNetConditioningEmbeddingConfig::new(320, vec![16, 32, 96, 256]);
+        let embedding = config.init::<TestBackend>(&device);
+
+ // Input: [batch=1, channels=3, height=64, width=64]
+        let xs: Tensor<TestBackend, 4> = Tensor::zeros([1, 3, 64, 64], &device);
+ let output = embedding.forward(xs);
+
+ // Output should have 320 channels and be downsampled by 2^3 = 8
+ // 64 / 8 = 8
+ assert_eq!(output.shape(), Shape::from([1, 320, 8, 8]));
+ }
+
+ #[test]
+ fn test_controlnet_output_shape() {
+ let device = Default::default();
+
+ // Create a small ControlNet for testing
+ // Note: conditioning_embedding_out_channels length determines downsampling factor
+ // With [16] (1 element), there's no downsampling in the conditioning embedding
+ let config = ControlNetConfig {
+ blocks: vec![
+ BlockConfig::new(32)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ BlockConfig::new(64)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ ],
+ conditioning_embedding_out_channels: vec![16], // Single element = no downsampling
+ layers_per_block: 1,
+ norm_num_groups: 32,
+ cross_attention_dim: 64,
+ ..Default::default()
+ };
+
+        let controlnet = config.init::<TestBackend>(4, &device);
+
+ // Input: [batch=1, channels=4, height=32, width=32]
+        let xs: Tensor<TestBackend, 4> = Tensor::zeros([1, 4, 32, 32], &device);
+        // Encoder hidden states: [batch=1, seq_len=8, dim=64]
+        let encoder_hidden_states: Tensor<TestBackend, 3> = Tensor::zeros([1, 8, 64], &device);
+        // Conditioning image: [batch=1, channels=3, height=32, width=32]
+        let controlnet_cond: Tensor<TestBackend, 4> = Tensor::zeros([1, 3, 32, 32], &device);
+
+ let (down_residuals, mid_residual) =
+ controlnet.forward(xs, 1.0, encoder_hidden_states, controlnet_cond, 1.0);
+
+ // Should have residuals for each down block output
+ // With 2 blocks and 1 layer each, plus initial conv_in:
+ // - 1 from conv_in (32 channels)
+ // - 1 from block 0 resnet (32 channels)
+ // - 1 from block 0 downsample (32 channels)
+ // - 1 from block 1 resnet (64 channels)
+ // Total: 4 residuals
+ assert_eq!(down_residuals.len(), 4);
+
+ // Mid block residual should match mid block channels
+ assert_eq!(mid_residual.dims()[1], 64);
+ }
+
+ /// A ModuleMapper that sets weights to one value and biases to another.
+ struct WeightBiasMapper<'a, B: Backend> {
+ weight_val: f32,
+ bias_val: f32,
+ device: &'a B::Device,
+ current_field: String,
+ }
+
+ impl<'a, B: Backend> ModuleMapper for WeightBiasMapper<'a, B> {
+ fn enter_module(&mut self, name: &str, _container_type: &str) {
+ self.current_field = name.to_string();
+ }
+
+ fn map_float(
+ &mut self,
+ tensor: Param>,
+ ) -> Param> {
+ let shape = tensor.shape();
+ let dims: [usize; D] = shape.dims();
+
+ let is_bias = self.current_field.contains("bias") || self.current_field == "beta";
+ let val = if is_bias {
+ self.bias_val
+ } else {
+ self.weight_val
+ };
+
+ Param::initialized(tensor.id, Tensor::full(dims, val, self.device))
+ }
+ }
+
+ fn set_weights_and_biases(
+ controlnet: ControlNet,
+ weight_val: f32,
+ bias_val: f32,
+ device: &B::Device,
+ ) -> ControlNet {
+ let mut mapper = WeightBiasMapper::<TestBackend> {
+ weight_val,
+ bias_val,
+ device,
+ current_field: String::new(),
+ };
+ controlnet.map(&mut mapper)
+ }
+
+ /// Test ControlNet with fixed weights matches diffusers-rs
+ /// Note: Skipped for wgpu due to GPU floating-point precision differences
+ #[test]
+ #[cfg(not(feature = "wgpu"))]
+ fn test_controlnet_fixed_weights_matches_diffusers_rs() {
+ let device = Default::default();
+
+ // Create a small ControlNet matching the diffusers-rs test config
+ let config = ControlNetConfig {
+ blocks: vec![
+ BlockConfig::new(32)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ BlockConfig::new(64)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ ],
+ conditioning_embedding_out_channels: vec![16],
+ layers_per_block: 1,
+ norm_num_groups: 32,
+ cross_attention_dim: 64,
+ ..Default::default()
+ };
+
+ let controlnet = config.init::<TestBackend>(4, &device);
+
+ // Set weights to 0.1, biases to 0.0
+ let controlnet = set_weights_and_biases(controlnet, 0.1, 0.0, &device);
+
+ // Input tensor [batch=1, channels=4, height=32, width=32]
+ let input_data: Vec = (0..(4 * 32 * 32))
+ .map(|i| i as f32 / (4.0 * 32.0 * 32.0))
+ .collect();
+ let xs: Tensor =
+ Tensor::from_data(TensorData::from(input_data.as_slice()), &device);
+ let xs: Tensor = xs.reshape([1, 4, 32, 32]);
+
+ // Encoder hidden states [batch=1, seq_len=8, dim=64]
+ let enc_data: Vec = (0..(8 * 64)).map(|i| i as f32 / (8.0 * 64.0)).collect();
+ let encoder_hidden_states: Tensor =
+ Tensor::from_data(TensorData::from(enc_data.as_slice()), &device);
+ let encoder_hidden_states: Tensor =
+ encoder_hidden_states.reshape([1, 8, 64]);
+
+ // Conditioning image [batch=1, channels=3, height=32, width=32]
+ let cond_data: Vec = (0..(3 * 32 * 32))
+ .map(|i| i as f32 / (3.0 * 32.0 * 32.0))
+ .collect();
+ let controlnet_cond: Tensor =
+ Tensor::from_data(TensorData::from(cond_data.as_slice()), &device);
+ let controlnet_cond: Tensor = controlnet_cond.reshape([1, 3, 32, 32]);
+
+ let (down_residuals, mid_residual) =
+ controlnet.forward(xs, 1.0, encoder_hidden_states, controlnet_cond, 1.0);
+
+ // Reference values from diffusers-rs
+ assert_eq!(down_residuals.len(), 4);
+
+ // Down residual 0: shape=[1, 32, 32, 32], mean=51.52621078491211
+ let mean0: f32 = down_residuals[0].clone().mean().into_scalar();
+ assert!(
+ (mean0 - 51.526210).abs() < 0.1,
+ "Down residual 0 mean mismatch: actual={}, expected=51.526210",
+ mean0
+ );
+
+ // Down residual 1: shape=[1, 32, 32, 32], mean=156.5426025390625
+ let mean1: f32 = down_residuals[1].clone().mean().into_scalar();
+ assert!(
+ (mean1 - 156.54260).abs() < 0.1,
+ "Down residual 1 mean mismatch: actual={}, expected=156.54260",
+ mean1
+ );
+
+ // Down residual 2: shape=[1, 32, 16, 16], mean=4349.87353515625
+ let mean2: f32 = down_residuals[2].clone().mean().into_scalar();
+ assert!(
+ (mean2 - 4349.8735).abs() < 1.0,
+ "Down residual 2 mean mismatch: actual={}, expected=4349.8735",
+ mean2
+ );
+
+ // Down residual 3: shape=[1, 64, 16, 16], mean=28677.99609375
+ let mean3: f32 = down_residuals[3].clone().mean().into_scalar();
+ assert!(
+ (mean3 - 28677.996).abs() < 10.0,
+ "Down residual 3 mean mismatch: actual={}, expected=28677.996",
+ mean3
+ );
+
+ // Mid residual: shape=[1, 64, 16, 16], mean=29518.357421875
+ let mid_mean: f32 = mid_residual.clone().mean().into_scalar();
+ assert!(
+ (mid_mean - 29518.357).abs() < 10.0,
+ "Mid residual mean mismatch: actual={}, expected=29518.357",
+ mid_mean
+ );
+ }
+}
diff --git a/src/models/embeddings.rs b/src/models/embeddings.rs
index 83b663e..071770e 100644
--- a/src/models/embeddings.rs
+++ b/src/models/embeddings.rs
@@ -6,7 +6,6 @@ use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::silu;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
-use core::marker::PhantomData;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
@@ -35,49 +34,49 @@ impl TimestepEmbeddingConfig {
}
impl TimestepEmbedding {
- fn forward(&self, xs: Tensor) -> Tensor {
+ pub fn forward(&self, xs: Tensor) -> Tensor {
let xs = silu(self.linear_1.forward(xs));
self.linear_2.forward(xs)
}
}
-#[derive(Module, Debug)]
-pub struct Timesteps {
+/// Computes sinusoidal timestep embeddings.
+///
+/// This is a pure function with no learnable parameters.
+/// It generates positional embeddings for diffusion timesteps using sinusoidal encoding.
+///
+/// # Arguments
+/// * `timesteps` - 1D tensor of timestep values
+/// * `num_channels` - Number of embedding channels to produce
+/// * `flip_sin_to_cos` - If true, output is [cos, sin]; if false, output is [sin, cos]
+/// * `downscale_freq_shift` - Frequency shift applied to the exponent denominator
+///
+/// # Returns
+/// A 2D tensor of shape [batch_size, num_channels] containing the timestep embeddings.
+pub fn get_timestep_embedding(
+ timesteps: Tensor,
num_channels: usize,
flip_sin_to_cos: bool,
downscale_freq_shift: f64,
- _backend: PhantomData,
-}
-
-impl Timesteps {
- pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
- Self {
- num_channels,
- flip_sin_to_cos,
-
- downscale_freq_shift,
- _backend: PhantomData,
- }
- }
-
- pub fn forward(&self, xs: Tensor) -> Tensor {
- let half_dim = self.num_channels / 2;
- let exponent = Tensor::arange(0..half_dim as i64, &xs.device()).float() * -f64::ln(10000.);
- let exponent = exponent / (half_dim as f64 - self.downscale_freq_shift);
- let emb = exponent.exp();
- // emb = timesteps[:, None].float() * emb[None, :]
- let emb: Tensor = xs.unsqueeze_dim(D1) * emb.unsqueeze();
- let emb: Tensor = if self.flip_sin_to_cos {
- Tensor::cat(vec![emb.clone().cos(), emb.clone().sin()], D1)
- } else {
- Tensor::cat(vec![emb.clone().sin(), emb.clone().cos()], D1)
- };
-
- if self.num_channels % 2 == 1 {
- pad_with_zeros(emb, D1 - 1, 0, 1)
- } else {
- emb
- }
+) -> Tensor {
+ let half_dim = num_channels / 2;
+ let exponent =
+ Tensor::arange(0..half_dim as i64, &timesteps.device()).float() * -f64::ln(10000.);
+ let exponent = exponent / (half_dim as f64 - downscale_freq_shift);
+ let emb = exponent.exp();
+ // emb = timesteps[:, None].float() * emb[None, :]
+ let emb: Tensor = timesteps.unsqueeze_dim(1) * emb.unsqueeze();
+ // Concatenate along the last dimension (-1 in PyTorch, which is dim 1 for 2D tensor)
+ let emb: Tensor = if flip_sin_to_cos {
+ Tensor::cat(vec![emb.clone().cos(), emb.clone().sin()], 1)
+ } else {
+ Tensor::cat(vec![emb.clone().sin(), emb.clone().cos()], 1)
+ };
+
+ if num_channels % 2 == 1 {
+ pad_with_zeros(emb, 1, 0, 1)
+ } else {
+ emb
}
}
@@ -87,49 +86,93 @@ mod tests {
use crate::TestBackend;
use burn::tensor::{Shape, TensorData, Tolerance};
+ /// Test get_timestep_embedding with even channels - validated against diffusers-rs v0.3.1
#[test]
#[cfg(not(feature = "torch"))]
fn test_timesteps_even_channels() {
let device = Default::default();
- let timesteps = Timesteps::::new(4, true, 0.);
let xs: Tensor =
Tensor::from_data(TensorData::from([1., 2., 3., 4.]), &device);
- let emb: Tensor = timesteps.forward(xs);
+ let emb: Tensor = get_timestep_embedding(xs, 4, true, 0.);
assert_eq!(emb.shape(), Shape::from([4, 4]));
+ // Reference values from diffusers-rs v0.3.1:
emb.into_data().assert_approx_eq::(
&TensorData::from([
- [0.5403, 1.0000, 0.8415, 0.0100],
- [-0.4161, 0.9998, 0.9093, 0.0200],
- [-0.9900, 0.9996, 0.1411, 0.0300],
- [-0.6536, 0.9992, -0.7568, 0.0400],
+ [0.5403023, 0.99995, 0.84147096, 0.009999833],
+ [-0.41614684, 0.9998, 0.9092974, 0.019998666],
+ [-0.9899925, 0.99955004, 0.14112, 0.0299955],
+ [-0.6536436, 0.9992001, -0.7568025, 0.039989334],
]),
- Tolerance::rel_abs(1e-3, 1e-3),
+ Tolerance::rel_abs(1e-4, 1e-4),
);
}
+ /// Test get_timestep_embedding with odd channels (padding) - validated against diffusers-rs v0.3.1
#[test]
#[cfg(not(feature = "torch"))]
fn test_timesteps_odd_channels() {
let device = Default::default();
- let timesteps = Timesteps::::new(5, true, 0.);
+ let xs: Tensor = Tensor::from_data(TensorData::from([1., 2., 3.]), &device);
+
+ let emb: Tensor = get_timestep_embedding(xs, 5, true, 0.);
+
+ assert_eq!(emb.shape(), Shape::from([3, 5]));
+ // Reference values from diffusers-rs v0.3.1:
+ emb.into_data().assert_approx_eq::(
+ &TensorData::from([
+ [0.5403023, 0.99995, 0.84147096, 0.009999833, 0.0],
+ [-0.41614684, 0.9998, 0.9092974, 0.019998666, 0.0],
+ [-0.9899925, 0.99955004, 0.14112, 0.0299955, 0.0],
+ ]),
+ Tolerance::rel_abs(1e-4, 1e-4),
+ );
+ }
+
+ /// Test get_timestep_embedding with flip_sin_to_cos=false - validated against diffusers-rs v0.3.1
+ #[test]
+ #[cfg(not(feature = "torch"))]
+ fn test_timesteps_no_flip() {
+ let device = Default::default();
let xs: Tensor =
- Tensor::from_data(TensorData::from([1., 2., 3., 4., 5.]), &device);
+ Tensor::from_data(TensorData::from([1., 2., 3., 4.]), &device);
+
+ let emb: Tensor = get_timestep_embedding(xs, 4, false, 0.);
+
+ assert_eq!(emb.shape(), Shape::from([4, 4]));
+ // Reference values from diffusers-rs v0.3.1 with flip_sin_to_cos=false:
+ emb.into_data().assert_approx_eq::(
+ &TensorData::from([
+ [0.84147096, 0.009999833, 0.5403023, 0.99995],
+ [0.9092974, 0.019998666, -0.41614684, 0.9998],
+ [0.14112, 0.0299955, -0.9899925, 0.99955004],
+ [-0.7568025, 0.039989334, -0.6536436, 0.9992001],
+ ]),
+ Tolerance::rel_abs(1e-4, 1e-4),
+ );
+ }
+
+ /// Test get_timestep_embedding with downscale_freq_shift - validated against diffusers-rs v0.3.1
+ #[test]
+ #[cfg(not(feature = "torch"))]
+ fn test_timesteps_with_downscale() {
+ let device = Default::default();
+ let xs: Tensor =
+ Tensor::from_data(TensorData::from([1., 2., 3., 4.]), &device);
- let emb: Tensor = timesteps.forward(xs);
+ let emb: Tensor = get_timestep_embedding(xs, 4, true, 1.0);
- assert_eq!(emb.shape(), Shape::from([6, 4]));
+ assert_eq!(emb.shape(), Shape::from([4, 4]));
+ // Reference values from diffusers-rs v0.3.1 with downscale_freq_shift=1.0:
emb.into_data().assert_approx_eq::(
&TensorData::from([
- [0.5403, 1.0000, 0.8415, 0.0100],
- [-0.4161, 0.9998, 0.9093, 0.0200],
- [-0.9900, 0.9996, 0.1411, 0.0300],
- [-0.6536, 0.9992, -0.7568, 0.0400],
- [0.2837, 0.9988, -0.9589, 0.0500],
- [0.0000, 0.0000, 0.0000, 0.0000],
+ [0.5403023, 1.0, 0.84147096, 9.999999e-5],
+ [-0.41614684, 1.0, 0.9092974, 0.00019999998],
+ [-0.9899925, 0.99999994, 0.14112, 0.00029999996],
+ [-0.6536436, 0.99999994, -0.7568025, 0.00039999996],
]),
- Tolerance::rel_abs(1e-3, 1e-3),
+ Tolerance::rel_abs(1e-4, 1e-4),
);
}
}
diff --git a/src/models/mod.rs b/src/models/mod.rs
index 5d7234d..743d60f 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -3,6 +3,9 @@
//! A collection of models to be used in a diffusion loop.
pub mod attention;
+pub mod controlnet;
pub mod embeddings;
pub mod resnet;
+pub mod unet_2d;
pub mod unet_2d_blocks;
+pub mod vae;
diff --git a/src/models/resnet.rs b/src/models/resnet.rs
index 57c1e31..2c7e47c 100644
--- a/src/models/resnet.rs
+++ b/src/models/resnet.rs
@@ -120,7 +120,8 @@ impl ResnetBlock2D {
mod tests {
use super::*;
use crate::TestBackend;
- use burn::tensor::{Distribution, Shape};
+ use alloc::vec::Vec;
+ use burn::tensor::{Distribution, Shape, TensorData, Tolerance};
#[test]
fn test_resnet_block_2d_no_temb() {
@@ -142,4 +143,327 @@ mod tests {
assert_eq!(output.shape(), Shape::from([2, 128, 64, 64]));
}
+
+ /// Test SiLU activation matches diffusers-rs (tch silu)
+ /// Reference values from diffusers-rs v0.3.1
+ #[test]
+ fn test_silu_matches_diffusers_rs() {
+ let device = Default::default();
+ let xs: Tensor = Tensor::from_data(
+ TensorData::from([-2.0f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]),
+ &device,
+ );
+
+ let result = silu(xs);
+
+ // Reference values from diffusers-rs: tensor.silu()
+ result.into_data().assert_approx_eq::(
+ &TensorData::from([
+ -0.23840584,
+ -0.26894143,
+ -0.18877034,
+ 0.0,
+ 0.31122968,
+ 0.7310586,
+ 1.761594,
+ ]),
+ Tolerance::rel_abs(1e-4, 1e-4),
+ );
+ }
+
+ /// Test GroupNorm matches diffusers-rs
+ /// Reference values from diffusers-rs v0.3.1 with weight=1, bias=0
+ #[test]
+ fn test_group_norm_matches_diffusers_rs() {
+ let device = Default::default();
+
+ // Create GroupNorm: 2 groups, 4 channels
+ let norm = GroupNormConfig::new(2, 4)
+ .with_epsilon(1e-6)
+ .init::(&device);
+
+ // Set weight to 1 and bias to 0 (default initialization)
+ // GroupNorm in Burn initializes gamma=1, beta=0 by default
+
+ // Input: [batch=1, channels=4, height=2, width=2] with sequential values 1-16
+ let xs: Tensor = Tensor::from_data(
+ TensorData::from([
+ 1.0f32, 2.0, 3.0, 4.0, // channel 0
+ 5.0, 6.0, 7.0, 8.0, // channel 1
+ 9.0, 10.0, 11.0, 12.0, // channel 2
+ 13.0, 14.0, 15.0, 16.0, // channel 3
+ ]),
+ &device,
+ );
+ let xs: Tensor = xs.reshape([1, 4, 2, 2]);
+
+ let result = norm.forward(xs);
+
+ // Reference values from diffusers-rs GroupNorm
+ let result_flat = result.flatten::<1>(0, 3);
+ result_flat.into_data().assert_approx_eq::(
+ &TensorData::from([
+ -1.5275251,
+ -1.0910892,
+ -0.65465355,
+ -0.21821785,
+ 0.21821797,
+ 0.65465367,
+ 1.0910894,
+ 1.5275251,
+ -1.5275252,
+ -1.0910892,
+ -0.65465355,
+ -0.21821785,
+ 0.21821785,
+ 0.65465355,
+ 1.0910892,
+ 1.527525,
+ ]),
+ Tolerance::rel_abs(1e-4, 1e-4),
+ );
+ }
+
+ /// Helper function to set all weights in a Conv2d to a constant value
+ fn set_conv2d_weights(
+ conv: Conv2d,
+ weight_val: f32,
+ bias_val: f32,
+ device: &B::Device,
+ ) -> Conv2d {
+ let weight_shape = conv.weight.shape();
+ let [out_ch, in_ch, kh, kw] = weight_shape.dims();
+
+ // Use Param::map to transform the weight tensor
+ let new_weight = conv
+ .weight
+ .map(|_| Tensor::full([out_ch, in_ch, kh, kw], weight_val, device));
+
+ let new_bias = conv
+ .bias
+ .map(|b| b.map(|_| Tensor::full([out_ch], bias_val, device)));
+
+ Conv2d {
+ weight: new_weight,
+ bias: new_bias,
+ stride: conv.stride,
+ kernel_size: conv.kernel_size,
+ dilation: conv.dilation,
+ groups: conv.groups,
+ padding: conv.padding,
+ }
+ }
+
+ /// Helper function to set GroupNorm weights
+ fn set_group_norm_weights(
+ norm: GroupNorm,
+ gamma_val: f32,
+ beta_val: f32,
+ device: &B::Device,
+ ) -> GroupNorm {
+ let num_channels = norm.num_channels;
+
+ let new_gamma = norm
+ .gamma
+ .map(|g| g.map(|_| Tensor::full([num_channels], gamma_val, device)));
+
+ let new_beta = norm
+ .beta
+ .map(|b| b.map(|_| Tensor::full([num_channels], beta_val, device)));
+
+ GroupNorm {
+ gamma: new_gamma,
+ beta: new_beta,
+ num_groups: norm.num_groups,
+ num_channels: norm.num_channels,
+ epsilon: norm.epsilon,
+ affine: norm.affine,
+ }
+ }
+
+ /// Helper function to set Linear weights
+ fn set_linear_weights(
+ linear: Linear,
+ weight_val: f32,
+ bias_val: f32,
+ device: &B::Device,
+ ) -> Linear {
+ let weight_shape = linear.weight.shape();
+ let [d_input, d_output] = weight_shape.dims();
+
+ let new_weight = linear
+ .weight
+ .map(|_| Tensor::full([d_input, d_output], weight_val, device));
+
+ let new_bias = linear
+ .bias
+ .map(|b| b.map(|_| Tensor::full([d_output], bias_val, device)));
+
+ Linear {
+ weight: new_weight,
+ bias: new_bias,
+ }
+ }
+
+ /// Test ResnetBlock2D with fixed weights matches diffusers-rs
+ /// Reference values from diffusers-rs v0.3.1
+ #[test]
+ fn test_resnet_block_2d_fixed_weights_matches_diffusers_rs() {
+ let device = Default::default();
+
+ // Create ResnetBlock2D: in_channels=4, out_channels=4, groups=2
+ let config = ResnetBlock2DConfig::new(4)
+ .with_out_channels(Some(4))
+ .with_groups(2)
+ .with_groups_out(Some(2))
+ .with_eps(1e-6)
+ .with_use_in_shortcut(Some(false));
+
+ let mut block = config.init::(&device);
+
+ // Set all weights to 0.1 and biases to 0.0 to match diffusers-rs test
+ // NOTE: diffusers-rs sets ALL "weight" params to 0.1, including GroupNorm gamma
+ block.norm1 = set_group_norm_weights(block.norm1, 0.1, 0.0, &device);
+ block.norm2 = set_group_norm_weights(block.norm2, 0.1, 0.0, &device);
+ block.conv1 = set_conv2d_weights(block.conv1, 0.1, 0.0, &device);
+ block.conv2 = set_conv2d_weights(block.conv2, 0.1, 0.0, &device);
+
+ // Input: arange(64).reshape([1, 4, 4, 4]) / 64.0
+ let input_data: Vec = (0..64).map(|i| i as f32 / 64.0).collect();
+ let xs: Tensor =
+ Tensor::from_data(TensorData::from(input_data.as_slice()), &device);
+ let xs: Tensor = xs.reshape([1, 4, 4, 4]);
+
+ let result = block.forward(xs, None);
+
+ // Reference values from diffusers-rs
+ let result_flat = result.clone().flatten::<1>(0, 3);
+ let result_data = result_flat.to_data();
+ let result_vec: Vec = result_data.to_vec().unwrap();
+
+ // Check first 16 values match diffusers-rs reference
+ let expected_first_16 = [
+ -0.087926514,
+ -0.106717095,
+ -0.07181215,
+ -0.0077848956,
+ -0.0029995441,
+ 0.005315691,
+ 0.054237127,
+ 0.10205648,
+ 0.1431311,
+ 0.20416035,
+ 0.25529763,
+ 0.25261912,
+ 0.2456759,
+ 0.31968236,
+ 0.35902193,
+ 0.33475745,
+ ];
+
+ for (i, (actual, expected)) in result_vec
+ .iter()
+ .take(16)
+ .zip(expected_first_16.iter())
+ .enumerate()
+ {
+ assert!(
+ (actual - expected).abs() < 1e-4,
+ "Mismatch at index {}: actual={}, expected={}",
+ i,
+ actual,
+ expected
+ );
+ }
+
+ // Check overall mean matches
+ let mean = result.mean().into_scalar();
+ assert!(
+ (mean - 0.4999196529388428).abs() < 1e-4,
+ "Mean mismatch: actual={}, expected=0.4999196529388428",
+ mean
+ );
+ }
+
+ /// Test ResnetBlock2D with time embedding and fixed weights matches diffusers-rs
+ /// Reference values from diffusers-rs v0.3.1
+ #[test]
+ fn test_resnet_block_2d_with_temb_fixed_weights_matches_diffusers_rs() {
+ let device = Default::default();
+
+ // Create ResnetBlock2D with time embedding: in_channels=4, out_channels=4, temb_channels=8
+ let config = ResnetBlock2DConfig::new(4)
+ .with_out_channels(Some(4))
+ .with_temb_channels(Some(8))
+ .with_groups(2)
+ .with_groups_out(Some(2))
+ .with_eps(1e-6)
+ .with_use_in_shortcut(Some(false));
+
+ let mut block = config.init::(&device);
+
+ // Set all weights to 0.1 and biases to 0.0 to match diffusers-rs test
+ // NOTE: diffusers-rs sets ALL "weight" params to 0.1, including GroupNorm gamma
+ block.norm1 = set_group_norm_weights(block.norm1, 0.1, 0.0, &device);
+ block.norm2 = set_group_norm_weights(block.norm2, 0.1, 0.0, &device);
+ block.conv1 = set_conv2d_weights(block.conv1, 0.1, 0.0, &device);
+ block.conv2 = set_conv2d_weights(block.conv2, 0.1, 0.0, &device);
+ block.time_emb_proj = block
+ .time_emb_proj
+ .map(|proj| set_linear_weights(proj, 0.1, 0.0, &device));
+
+ // Input: arange(64).reshape([1, 4, 4, 4]) / 64.0
+ let input_data: Vec = (0..64).map(|i| i as f32 / 64.0).collect();
+ let xs: Tensor =
+ Tensor::from_data(TensorData::from(input_data.as_slice()), &device);
+ let xs: Tensor = xs.reshape([1, 4, 4, 4]);
+
+ // Time embedding: arange(8).reshape([1, 8]) / 8.0
+ let temb_data: Vec = (0..8).map(|i| i as f32 / 8.0).collect();
+ let temb: Tensor =
+ Tensor::from_data(TensorData::from(temb_data.as_slice()), &device);
+ let temb: Tensor = temb.reshape([1, 8]);
+
+ let result = block.forward(xs, Some(temb));
+
+ // Reference values from diffusers-rs (same as without temb due to the specific input values)
+ let result_flat = result.flatten::<1>(0, 3);
+ let result_data = result_flat.to_data();
+ let result_vec: Vec = result_data.to_vec().unwrap();
+
+ // Check first 16 values match diffusers-rs reference
+ let expected_first_16 = [
+ -0.08792652,
+ -0.10671712,
+ -0.07181221,
+ -0.007784918,
+ -0.0029995516,
+ 0.005315639,
+ 0.0542371,
+ 0.10205645,
+ 0.14313108,
+ 0.2041603,
+ 0.2552976,
+ 0.2526191,
+ 0.24567588,
+ 0.3196823,
+ 0.3590219,
+ 0.33475742,
+ ];
+
+ for (i, (actual, expected)) in result_vec
+ .iter()
+ .take(16)
+ .zip(expected_first_16.iter())
+ .enumerate()
+ {
+ assert!(
+ (actual - expected).abs() < 1e-4,
+ "Mismatch at index {}: actual={}, expected={}",
+ i,
+ actual,
+ expected
+ );
+ }
+ }
}
diff --git a/src/models/unet_2d.rs b/src/models/unet_2d.rs
new file mode 100644
index 0000000..7689c75
--- /dev/null
+++ b/src/models/unet_2d.rs
@@ -0,0 +1,634 @@
+//! 2D UNet Denoising Models
+//!
+//! The 2D UNet models take as input a noisy sample and the current diffusion
+//! timestep and return a denoised version of the input.
+
+use burn::config::Config;
+use burn::module::Module;
+use burn::nn::conv::{Conv2d, Conv2dConfig};
+use burn::nn::{GroupNorm, GroupNormConfig, PaddingConfig2d};
+use burn::tensor::activation::silu;
+use burn::tensor::backend::Backend;
+use burn::tensor::Tensor;
+
+use alloc::vec;
+use alloc::vec::Vec;
+
+use super::embeddings::{get_timestep_embedding, TimestepEmbedding, TimestepEmbeddingConfig};
+use super::unet_2d_blocks::{
+ CrossAttnDownBlock2D, CrossAttnDownBlock2DConfig, CrossAttnUpBlock2D, CrossAttnUpBlock2DConfig,
+ DownBlock2D, DownBlock2DConfig, UNetMidBlock2DCrossAttn, UNetMidBlock2DCrossAttnConfig,
+ UpBlock2D, UpBlock2DConfig,
+};
+
+/// Configuration for a single UNet block.
+#[derive(Debug, Clone, burn::serde::Serialize, burn::serde::Deserialize)]
+pub struct BlockConfig {
+ /// Output channels for this block.
+ pub out_channels: usize,
+ /// Whether to use cross-attention in this block.
+ pub use_cross_attn: bool,
+ /// Number of attention heads.
+ pub attention_head_dim: usize,
+}
+
+impl BlockConfig {
+ /// Create a new block configuration.
+ pub fn new(out_channels: usize) -> Self {
+ Self {
+ out_channels,
+ use_cross_attn: true,
+ attention_head_dim: 8,
+ }
+ }
+
+ /// Set whether to use cross-attention.
+ pub fn with_use_cross_attn(mut self, use_cross_attn: bool) -> Self {
+ self.use_cross_attn = use_cross_attn;
+ self
+ }
+
+ /// Set the attention head dimension.
+ pub fn with_attention_head_dim(mut self, attention_head_dim: usize) -> Self {
+ self.attention_head_dim = attention_head_dim;
+ self
+ }
+}
+
+/// Configuration for the UNet2DConditionModel.
+#[derive(Config, Debug)]
+pub struct UNet2DConditionModelConfig {
+ /// Whether to center the input sample.
+ #[config(default = false)]
+ pub center_input_sample: bool,
+ /// Whether to flip sin to cos in timestep embedding.
+ #[config(default = true)]
+ pub flip_sin_to_cos: bool,
+ /// Frequency shift for timestep embedding.
+ #[config(default = 0.0)]
+ pub freq_shift: f64,
+ /// Configuration for each block.
+ pub blocks: Vec,
+ /// Number of ResNet layers per block.
+ #[config(default = 2)]
+ pub layers_per_block: usize,
+ /// Padding for downsampling convolutions.
+ #[config(default = 1)]
+ pub downsample_padding: usize,
+ /// Scale factor for mid block.
+ #[config(default = 1.0)]
+ pub mid_block_scale_factor: f64,
+ /// Number of groups for group normalization.
+ #[config(default = 32)]
+ pub norm_num_groups: usize,
+ /// Epsilon for normalization layers.
+ #[config(default = 1e-5)]
+ pub norm_eps: f64,
+ /// Dimension of cross-attention context.
+ #[config(default = 1280)]
+ pub cross_attention_dim: usize,
+ /// Size for sliced attention (None for full attention).
+ pub sliced_attention_size: Option,
+ /// Whether to use linear projection in attention.
+ #[config(default = false)]
+ pub use_linear_projection: bool,
+}
+
+impl Default for UNet2DConditionModelConfig {
+ fn default() -> Self {
+ Self {
+ center_input_sample: false,
+ flip_sin_to_cos: true,
+ freq_shift: 0.0,
+ blocks: vec![
+ BlockConfig::new(320)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ BlockConfig::new(640)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ BlockConfig::new(1280)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ BlockConfig::new(1280)
+ .with_use_cross_attn(false)
+ .with_attention_head_dim(8),
+ ],
+ layers_per_block: 2,
+ downsample_padding: 1,
+ mid_block_scale_factor: 1.0,
+ norm_num_groups: 32,
+ norm_eps: 1e-5,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+/// Down block types for UNet.
+#[derive(Module, Debug)]
+pub enum UNetDownBlock {
+ Basic(DownBlock2D),
+ CrossAttn(CrossAttnDownBlock2D),
+}
+
+/// Up block types for UNet.
+#[derive(Module, Debug)]
+pub enum UNetUpBlock {
+ Basic(UpBlock2D),
+ CrossAttn(CrossAttnUpBlock2D),
+}
+
+/// UNet2D Conditional Model for denoising diffusion.
+///
+/// This model takes a noisy sample, timestep, and encoder hidden states
+/// (from text conditioning) and predicts the noise to be removed.
+#[derive(Module, Debug)]
+pub struct UNet2DConditionModel {
+ conv_in: Conv2d,
+ time_embedding: TimestepEmbedding,
+ down_blocks: Vec>,
+ mid_block: UNetMidBlock2DCrossAttn,
+ up_blocks: Vec>,
+ conv_norm_out: GroupNorm,
+ conv_out: Conv2d,
+ #[module(skip)]
+ time_proj_channels: usize,
+ #[module(skip)]
+ flip_sin_to_cos: bool,
+ #[module(skip)]
+ freq_shift: f64,
+ #[module(skip)]
+ center_input_sample: bool,
+}
+
+impl UNet2DConditionModelConfig {
+ /// Initialize the UNet2DConditionModel.
+ pub fn init(
+ &self,
+ in_channels: usize,
+ out_channels: usize,
+ device: &B::Device,
+ ) -> UNet2DConditionModel {
+ let n_blocks = self.blocks.len();
+ let b_channels = self.blocks[0].out_channels;
+ let bl_channels = self.blocks.last().unwrap().out_channels;
+ let bl_attention_head_dim = self.blocks.last().unwrap().attention_head_dim;
+ let time_embed_dim = b_channels * 4;
+
+ // Input convolution
+ let conv_in = Conv2dConfig::new([in_channels, b_channels], [3, 3])
+ .with_stride([1, 1])
+ .with_padding(PaddingConfig2d::Explicit(1, 1))
+ .init(device);
+
+ // Time embeddings
+ let time_embedding = TimestepEmbeddingConfig::new(b_channels, time_embed_dim).init(device);
+
+ // Down blocks
+ let down_blocks = (0..n_blocks)
+ .map(|i| {
+ let block_config = &self.blocks[i];
+ let out_channels = block_config.out_channels;
+ let attention_head_dim = block_config.attention_head_dim;
+
+ // Enable automatic attention slicing if config sliced_attention_size is 0
+ let sliced_attention_size = match self.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ other => other,
+ };
+
+ let in_ch = if i > 0 {
+ self.blocks[i - 1].out_channels
+ } else {
+ b_channels
+ };
+
+ let db_config = DownBlock2DConfig::new(in_ch, out_channels)
+ .with_temb_channels(Some(time_embed_dim))
+ .with_n_layers(self.layers_per_block)
+ .with_resnet_eps(self.norm_eps)
+ .with_resnet_groups(self.norm_num_groups)
+ .with_add_downsample(i < n_blocks - 1)
+ .with_downsample_padding(self.downsample_padding);
+
+ if block_config.use_cross_attn {
+ let config = CrossAttnDownBlock2DConfig::new(in_ch, out_channels, db_config)
+ .with_temb_channels(Some(time_embed_dim))
+ .with_attn_num_head_channels(attention_head_dim)
+ .with_cross_attention_dim(self.cross_attention_dim)
+ .with_sliced_attention_size(sliced_attention_size)
+ .with_use_linear_projection(self.use_linear_projection);
+ UNetDownBlock::CrossAttn(config.init(device))
+ } else {
+ UNetDownBlock::Basic(db_config.init(device))
+ }
+ })
+ .collect();
+
+ // Mid block
+ let mid_config = UNetMidBlock2DCrossAttnConfig::new(bl_channels)
+ .with_temb_channels(Some(time_embed_dim))
+ .with_resnet_eps(self.norm_eps)
+ .with_output_scale_factor(self.mid_block_scale_factor)
+ .with_cross_attn_dim(self.cross_attention_dim)
+ .with_attn_num_head_channels(bl_attention_head_dim)
+ .with_resnet_groups(Some(self.norm_num_groups))
+ .with_use_linear_projection(self.use_linear_projection);
+ let mid_block = mid_config.init(device);
+
+ // Up blocks
+ let up_blocks = (0..n_blocks)
+ .map(|i| {
+ let block_config = &self.blocks[n_blocks - 1 - i];
+ let out_channels = block_config.out_channels;
+ let attention_head_dim = block_config.attention_head_dim;
+
+ // Enable automatic attention slicing if config sliced_attention_size is 0
+ let sliced_attention_size = match self.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ other => other,
+ };
+
+ let prev_out_channels = if i > 0 {
+ self.blocks[n_blocks - i].out_channels
+ } else {
+ bl_channels
+ };
+
+ let in_ch = {
+ let index = if i == n_blocks - 1 {
+ 0
+ } else {
+ n_blocks - i - 2
+ };
+ self.blocks[index].out_channels
+ };
+
+ let ub_config = UpBlock2DConfig::new(in_ch, prev_out_channels, out_channels)
+ .with_temb_channels(Some(time_embed_dim))
+ .with_n_layers(self.layers_per_block + 1)
+ .with_resnet_eps(self.norm_eps)
+ .with_resnet_groups(self.norm_num_groups)
+ .with_add_upsample(i < n_blocks - 1);
+
+ if block_config.use_cross_attn {
+ let config = CrossAttnUpBlock2DConfig::new(
+ in_ch,
+ prev_out_channels,
+ out_channels,
+ ub_config,
+ )
+ .with_temb_channels(Some(time_embed_dim))
+ .with_attn_num_head_channels(attention_head_dim)
+ .with_cross_attention_dim(self.cross_attention_dim)
+ .with_sliced_attention_size(sliced_attention_size)
+ .with_use_linear_projection(self.use_linear_projection);
+ UNetUpBlock::CrossAttn(config.init(device))
+ } else {
+ UNetUpBlock::Basic(ub_config.init(device))
+ }
+ })
+ .collect();
+
+ // Output layers
+ let conv_norm_out = GroupNormConfig::new(self.norm_num_groups, b_channels)
+ .with_epsilon(self.norm_eps)
+ .init(device);
+
+ let conv_out = Conv2dConfig::new([b_channels, out_channels], [3, 3])
+ .with_padding(PaddingConfig2d::Explicit(1, 1))
+ .init(device);
+
+ UNet2DConditionModel {
+ conv_in,
+ time_embedding,
+ down_blocks,
+ mid_block,
+ time_proj_channels: b_channels,
+ flip_sin_to_cos: self.flip_sin_to_cos,
+ freq_shift: self.freq_shift,
+ up_blocks,
+ conv_norm_out,
+ conv_out,
+ center_input_sample: self.center_input_sample,
+ }
+ }
+}
+
+impl UNet2DConditionModel {
+ /// Forward pass through the UNet.
+ ///
+ /// # Arguments
+ /// * `xs` - Noisy input tensor [batch, channels, height, width]
+ /// * `timestep` - Current diffusion timestep
+ /// * `encoder_hidden_states` - Encoder hidden states for cross-attention [batch, seq_len, dim]
+ pub fn forward(
+ &self,
+ xs: Tensor,
+ timestep: f64,
+ encoder_hidden_states: Tensor,
+ ) -> Tensor {
+ self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
+ }
+
+ /// Forward pass with additional residuals (for ControlNet support).
+ pub fn forward_with_additional_residuals(
+ &self,
+ xs: Tensor,
+ timestep: f64,
+ encoder_hidden_states: Tensor,
+ down_block_additional_residuals: Option<&[Tensor]>,
+ mid_block_additional_residual: Option<&Tensor>,
+ ) -> Tensor {
+ let [bsize, _channels, height, width] = xs.dims();
+ let device = xs.device();
+ let n_blocks = self.down_blocks.len();
+ let num_upsamplers = n_blocks - 1;
+ let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
+ let forward_upsample_size =
+ height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
+
+ // 0. Center input if necessary
+ let xs = if self.center_input_sample {
+ xs * 2.0 - 1.0
+ } else {
+ xs
+ };
+
+ // 1. Time embedding
+ let timesteps: Tensor = Tensor::full([bsize], timestep as f32, &device);
+ let emb = get_timestep_embedding(
+ timesteps,
+ self.time_proj_channels,
+ self.flip_sin_to_cos,
+ self.freq_shift,
+ );
+ let emb = self.time_embedding.forward(emb);
+
+ // 2. Pre-process
+ let xs = self.conv_in.forward(xs);
+
+ // 3. Down blocks
+ let mut down_block_res_xs = vec![xs.clone()];
+ let mut xs = xs;
+ for down_block in self.down_blocks.iter() {
+ let (new_xs, res_xs) = match down_block {
+ UNetDownBlock::Basic(b) => b.forward(xs, Some(emb.clone())),
+ UNetDownBlock::CrossAttn(b) => {
+ b.forward(xs, Some(emb.clone()), Some(encoder_hidden_states.clone()))
+ }
+ };
+ down_block_res_xs.extend(res_xs);
+ xs = new_xs;
+ }
+
+ // Add additional residuals if provided (for ControlNet)
+ let mut down_block_res_xs = if let Some(additional) = down_block_additional_residuals {
+ down_block_res_xs
+ .iter()
+ .zip(additional.iter())
+ .map(|(r, a)| r.clone() + a.clone())
+ .collect()
+ } else {
+ down_block_res_xs
+ };
+
+ // 4. Mid block
+ let xs = self
+ .mid_block
+ .forward(xs, Some(emb.clone()), Some(encoder_hidden_states.clone()));
+ let xs = match mid_block_additional_residual {
+ Some(m) => xs + m.clone(),
+ None => xs,
+ };
+
+ // 5. Up blocks
+ let mut xs = xs;
+ let mut upsample_size = None;
+ for (i, up_block) in self.up_blocks.iter().enumerate() {
+ let n_resnets = match up_block {
+ UNetUpBlock::Basic(b) => b.resnets.len(),
+ UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
+ };
+ let res_xs: Vec<_> = down_block_res_xs
+ .drain(down_block_res_xs.len() - n_resnets..)
+ .collect();
+
+ if i < n_blocks - 1 && forward_upsample_size {
+ let last = down_block_res_xs.last().unwrap();
+ let [_, _, h, w] = last.dims();
+ upsample_size = Some((h, w));
+ }
+
+ xs = match up_block {
+ UNetUpBlock::Basic(b) => b.forward(xs, &res_xs, Some(emb.clone()), upsample_size),
+ UNetUpBlock::CrossAttn(b) => b.forward(
+ xs,
+ &res_xs,
+ Some(emb.clone()),
+ upsample_size,
+ Some(encoder_hidden_states.clone()),
+ ),
+ };
+ }
+
+ // 6. Post-process
+ let xs = self.conv_norm_out.forward(xs);
+ let xs = silu(xs);
+ self.conv_out.forward(xs)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::TestBackend;
+ use alloc::string::{String, ToString};
+ use burn::module::{Module, ModuleMapper, Param};
+ use burn::tensor::{Shape, TensorData};
+
+ /// A ModuleMapper that sets weights to one value and biases to another.
+ /// Uses the field name from enter_module to distinguish weight vs bias.
+ struct WeightBiasMapper<'a, B: Backend> {
+ weight_val: f32,
+ bias_val: f32,
+ device: &'a B::Device,
+ current_field: String,
+ }
+
+    impl<'a, B: Backend> ModuleMapper<B> for WeightBiasMapper<'a, B> {
+ fn enter_module(&mut self, name: &str, _container_type: &str) {
+ self.current_field = name.to_string();
+ }
+
+        fn map_float<const D: usize>(
+            &mut self,
+            tensor: Param<Tensor<B, D>>,
+        ) -> Param<Tensor<B, D>> {
+ let shape = tensor.shape();
+ let dims: [usize; D] = shape.dims();
+
+ // Use field name to distinguish: "bias" and "beta" get bias_val,
+ // "weight" and "gamma" get weight_val
+ let is_bias = self.current_field.contains("bias") || self.current_field == "beta";
+ let val = if is_bias {
+ self.bias_val
+ } else {
+ self.weight_val
+ };
+
+ Param::initialized(tensor.id, Tensor::full(dims, val, self.device))
+ }
+ }
+
+ /// Set weights and biases to different constant values.
+    fn set_weights_and_biases<B: Backend>(
+        unet: UNet2DConditionModel<B>,
+        weight_val: f32,
+        bias_val: f32,
+        device: &B::Device,
+    ) -> UNet2DConditionModel<B> {
+        let mut mapper = WeightBiasMapper::<B> {
+ weight_val,
+ bias_val,
+ device,
+ current_field: String::new(),
+ };
+ unet.map(&mut mapper)
+ }
+
+ #[test]
+ fn test_unet2d_output_shape() {
+ let device = Default::default();
+
+ // Create a small UNet for testing
+ let config = UNet2DConditionModelConfig {
+ blocks: vec![
+ BlockConfig::new(32)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ BlockConfig::new(64)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ ],
+ layers_per_block: 1,
+ norm_num_groups: 32,
+ cross_attention_dim: 64,
+ ..Default::default()
+ };
+
+        let unet = config.init::<TestBackend>(4, 4, &device);
+
+ // Input: [batch=1, channels=4, height=32, width=32]
+        let xs: Tensor<TestBackend, 4> = Tensor::zeros([1, 4, 32, 32], &device);
+ // Encoder hidden states: [batch=1, seq_len=8, dim=64]
+        let encoder_hidden_states: Tensor<TestBackend, 3> = Tensor::zeros([1, 8, 64], &device);
+
+ let output = unet.forward(xs, 1.0, encoder_hidden_states);
+
+ // Output should have same spatial dimensions as input
+ assert_eq!(output.shape(), Shape::from([1, 4, 32, 32]));
+ }
+
+ /// Test UNet2D forward with fixed weights matches diffusers-rs
+ /// Reference values from diffusers-rs v0.3.1
+ /// Note: Skipped for wgpu due to GPU floating-point precision differences
+ #[test]
+ #[cfg(not(feature = "wgpu"))]
+ fn test_unet2d_fixed_weights_matches_diffusers_rs() {
+ let device = Default::default();
+
+ // Create a small UNet with same config as diffusers-rs test
+ let config = UNet2DConditionModelConfig {
+ blocks: vec![
+ BlockConfig::new(32)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ BlockConfig::new(64)
+ .with_use_cross_attn(true)
+ .with_attention_head_dim(8),
+ ],
+ layers_per_block: 1,
+ norm_num_groups: 32,
+ cross_attention_dim: 64,
+ ..Default::default()
+ };
+
+        let unet = config.init::<TestBackend>(4, 4, &device);
+
+ // Set weights to 0.1, biases to 0.0 (matching diffusers-rs test)
+ let unet = set_weights_and_biases(unet, 0.1, 0.0, &device);
+
+ // Input: normalized arange values for reproducibility
+        let input_data: Vec<f32> = (0..(4 * 32 * 32))
+            .map(|i| i as f32 / (4.0 * 32.0 * 32.0))
+            .collect();
+        let xs: Tensor<TestBackend, 1> =
+            Tensor::from_data(TensorData::from(input_data.as_slice()), &device);
+        let xs: Tensor<TestBackend, 4> = xs.reshape([1, 4, 32, 32]);
+
+ // Encoder hidden states: normalized values
+        let enc_data: Vec<f32> = (0..(8 * 64)).map(|i| i as f32 / (8.0 * 64.0)).collect();
+        let encoder_hidden_states: Tensor<TestBackend, 1> =
+            Tensor::from_data(TensorData::from(enc_data.as_slice()), &device);
+        let encoder_hidden_states: Tensor<TestBackend, 3> =
+            encoder_hidden_states.reshape([1, 8, 64]);
+
+ let output = unet.forward(xs, 1.0, encoder_hidden_states);
+
+ // Verify output shape
+ assert_eq!(output.shape(), Shape::from([1, 4, 32, 32]));
+
+ // Get result values
+ let result_flat = output.clone().flatten::<1>(0, 3);
+ let result_data = result_flat.to_data();
+        let result_vec: Vec<f32> = result_data.to_vec().unwrap();
+
+ // Reference values from diffusers-rs v0.3.1
+ let expected_first_16 = [
+ -1.7083509_f32,
+ -2.3742673,
+ -1.9932699,
+ -1.836023,
+ -1.7530675,
+ -1.7501539,
+ -1.7482271,
+ -1.7476634,
+ -1.7470942,
+ -1.7467214,
+ -1.7462531,
+ -1.7458805,
+ -1.7454054,
+ -1.7450254,
+ -1.7445545,
+ -1.7441819,
+ ];
+
+ for (i, (actual, expected)) in result_vec
+ .iter()
+ .take(16)
+ .zip(expected_first_16.iter())
+ .enumerate()
+ {
+ assert!(
+ (actual - expected).abs() < 1e-3,
+ "Mismatch at index {}: actual={}, expected={}",
+ i,
+ actual,
+ expected
+ );
+ }
+
+ // Check overall mean (reference: 0.1854093372821808)
+ let mean = output.clone().mean().into_scalar();
+ let expected_mean = 0.18540934_f32;
+ assert!(
+ (mean - expected_mean).abs() < 1e-4,
+ "Mean mismatch: actual={}, expected={}",
+ mean,
+ expected_mean
+ );
+ }
+}
diff --git a/src/models/unet_2d_blocks.rs b/src/models/unet_2d_blocks.rs
index 9526e7b..22841c6 100644
--- a/src/models/unet_2d_blocks.rs
+++ b/src/models/unet_2d_blocks.rs
@@ -66,7 +66,7 @@ impl Downsample2D {
return pad_with_zeros(xs, 4 - 2, 0, 1);
}
- return xs;
+ xs
}
     fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
@@ -187,7 +187,7 @@ impl DownEncoderBlock2DConfig {
}
impl DownEncoderBlock2D {
-    fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+    pub fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(xs, None)
@@ -259,7 +259,7 @@ impl UpDecoderBlock2DConfig {
}
impl UpDecoderBlock2D {
-    fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+    pub fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(xs, None)
@@ -477,15 +477,21 @@ pub struct DownBlock2DConfig {
#[derive(Module, Debug)]
pub struct DownBlock2D {
-    resnets: Vec<ResnetBlock2D<B>>,
+    pub resnets: Vec<ResnetBlock2D<B>>,
downsampler: Option>,
}
impl DownBlock2DConfig {
     pub fn init<B: Backend>(&self, device: &B::Device) -> DownBlock2D<B> {
let resnets = (0..self.n_layers)
- .map(|_| {
- ResnetBlock2DConfig::new(self.out_channels)
+ .map(|i| {
+ let in_channels = if i == 0 {
+ self.in_channels
+ } else {
+ self.out_channels
+ };
+ ResnetBlock2DConfig::new(in_channels)
+ .with_out_channels(Some(self.out_channels))
.with_eps(self.resnet_eps)
.with_groups(self.resnet_groups)
.with_output_scale_factor(self.output_scale_factor)
@@ -636,7 +642,7 @@ pub struct UpBlock2DConfig {
#[derive(Module, Debug)]
pub struct UpBlock2D {
-    resnets: Vec<ResnetBlock2D<B>>,
+    pub resnets: Vec<ResnetBlock2D<B>>,
upsampler: Option>,
}
@@ -922,4 +928,40 @@ mod tests {
assert_eq!(output.shape(), Shape::new([4, 32, 64, 64]));
}
+
+ /// Test Downsample2D avg_pool matches diffusers-rs
+ /// Reference values from diffusers-rs v0.3.1: avg_pool2d([2,2], [2,2], [0,0], false, true, None)
+ #[test]
+ fn test_downsample_2d_avg_pool_matches_diffusers_rs() {
+ let device = Default::default();
+        let tensor: Tensor<TestBackend, 4> = Tensor::from_data(
+ TensorData::from([
+ [
+ [[0.0351f32, 0.4179], [0.0137, 0.6947]],
+ [[0.9526, 0.5386], [0.2856, 0.1839]],
+ [[0.3215, 0.4595], [0.6777, 0.3946]],
+ [[0.5221, 0.4230], [0.2774, 0.1069]],
+ ],
+ [
+ [[0.8941, 0.8696], [0.5735, 0.8750]],
+ [[0.6718, 0.4144], [0.1038, 0.2629]],
+ [[0.7467, 0.9415], [0.5005, 0.6309]],
+ [[0.6534, 0.2019], [0.3670, 0.8074]],
+ ],
+ ]),
+ &device,
+ );
+
+ let downsample_2d = Downsample2DConfig::new(4, false, 4, 0).init(&device);
+ let output = downsample_2d.forward(tensor);
+
+ // Reference values from diffusers-rs: [0.29035002, 0.49017498, 0.463325, 0.33235, 0.80305, 0.363225, 0.7049, 0.507425]
+        output.into_data().assert_approx_eq::<f32>(
+ &TensorData::from([
+ [[[0.29035002f32]], [[0.49017498]], [[0.463325]], [[0.33235]]],
+ [[[0.80305]], [[0.363225]], [[0.7049]], [[0.507425]]],
+ ]),
+ Tolerance::rel_abs(1e-4, 1e-4),
+ );
+ }
}
diff --git a/src/models/vae.rs b/src/models/vae.rs
new file mode 100644
index 0000000..60e24c7
--- /dev/null
+++ b/src/models/vae.rs
@@ -0,0 +1,605 @@
+//! # Variational Auto-Encoder (VAE) Models.
+//!
+//! Auto-encoder models compress their input to a usually smaller latent space
+//! before expanding it back to its original shape. This results in the latent values
+//! compressing the original information.
+
+use burn::config::Config;
+use burn::module::Module;
+use burn::nn::conv::{Conv2d, Conv2dConfig};
+use burn::nn::{GroupNorm, GroupNormConfig, PaddingConfig2d};
+use burn::tensor::activation::silu;
+use burn::tensor::backend::Backend;
+use burn::tensor::{Distribution, Tensor};
+
+use alloc::vec;
+use alloc::vec::Vec;
+
+use super::unet_2d_blocks::{
+ DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
+ UpDecoderBlock2D, UpDecoderBlock2DConfig,
+};
+
+/// Configuration for the VAE Encoder.
+#[derive(Config, Debug)]
+pub struct EncoderConfig {
+ /// Number of input channels (e.g., 3 for RGB images).
+ in_channels: usize,
+ /// Number of output channels (latent channels).
+ out_channels: usize,
+ /// Output channels for each block.
+    block_out_channels: Vec<usize>,
+ /// Number of resnet layers per block.
+ #[config(default = 2)]
+ layers_per_block: usize,
+ /// Number of groups for group normalization.
+ #[config(default = 32)]
+ norm_num_groups: usize,
+ /// Whether to output double channels for mean and logvar.
+ #[config(default = true)]
+ double_z: bool,
+}
+
+/// VAE Encoder - compresses images to latent space.
+#[derive(Module, Debug)]
+pub struct Encoder<B: Backend> {
+    conv_in: Conv2d<B>,
+    down_blocks: Vec<DownEncoderBlock2D<B>>,
+    mid_block: UNetMidBlock2D<B>,
+    conv_norm_out: GroupNorm<B>,
+    conv_out: Conv2d<B>,
+}
+
+impl EncoderConfig {
+ /// Initialize the Encoder.
+    pub fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
+ let conv_in = Conv2dConfig::new([self.in_channels, self.block_out_channels[0]], [3, 3])
+ .with_stride([1, 1])
+ .with_padding(PaddingConfig2d::Explicit(1, 1))
+ .init(device);
+
+ let mut down_blocks = vec![];
+ for index in 0..self.block_out_channels.len() {
+ let out_channels = self.block_out_channels[index];
+ let in_channels = if index > 0 {
+ self.block_out_channels[index - 1]
+ } else {
+ self.block_out_channels[0]
+ };
+ let is_final = index + 1 == self.block_out_channels.len();
+
+ let down_block = DownEncoderBlock2DConfig::new(in_channels, out_channels)
+ .with_n_layers(self.layers_per_block)
+ .with_resnet_eps(1e-6)
+ .with_resnet_groups(self.norm_num_groups)
+ .with_add_downsample(!is_final)
+ .with_downsample_padding(0)
+ .init(device);
+
+ down_blocks.push(down_block);
+ }
+
+ let last_block_out_channels = *self.block_out_channels.last().unwrap();
+
+ let mid_block = UNetMidBlock2DConfig::new(last_block_out_channels)
+ .with_resnet_eps(1e-6)
+ .with_output_scale_factor(1.0)
+ .with_attn_num_head_channels(None)
+ .with_resnet_groups(Some(self.norm_num_groups))
+ .with_n_layers(1)
+ .init(device);
+
+ let conv_norm_out = GroupNormConfig::new(self.norm_num_groups, last_block_out_channels)
+ .with_epsilon(1e-6)
+ .init(device);
+
+ let conv_out_channels = if self.double_z {
+ 2 * self.out_channels
+ } else {
+ self.out_channels
+ };
+
+ let conv_out = Conv2dConfig::new([last_block_out_channels, conv_out_channels], [3, 3])
+ .with_padding(PaddingConfig2d::Explicit(1, 1))
+ .init(device);
+
+ Encoder {
+ conv_in,
+ down_blocks,
+ mid_block,
+ conv_norm_out,
+ conv_out,
+ }
+ }
+}
+
+impl<B: Backend> Encoder<B> {
+    /// Forward pass through the encoder.
+    pub fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+ let mut xs = self.conv_in.forward(xs);
+
+ for down_block in self.down_blocks.iter() {
+ xs = down_block.forward(xs);
+ }
+
+ let xs = self.mid_block.forward(xs, None);
+ let xs = self.conv_norm_out.forward(xs);
+ let xs = silu(xs);
+ self.conv_out.forward(xs)
+ }
+}
+
+/// Configuration for the VAE Decoder.
+#[derive(Config, Debug)]
+pub struct DecoderConfig {
+ /// Number of input channels (latent channels).
+ in_channels: usize,
+ /// Number of output channels (e.g., 3 for RGB images).
+ out_channels: usize,
+ /// Output channels for each block.
+    block_out_channels: Vec<usize>,
+ /// Number of resnet layers per block.
+ #[config(default = 2)]
+ layers_per_block: usize,
+ /// Number of groups for group normalization.
+ #[config(default = 32)]
+ norm_num_groups: usize,
+}
+
+/// VAE Decoder - expands latent space back to images.
+#[derive(Module, Debug)]
+pub struct Decoder<B: Backend> {
+    conv_in: Conv2d<B>,
+    up_blocks: Vec<UpDecoderBlock2D<B>>,
+ mid_block: UNetMidBlock2D