From f22dbe6bb2cc84837ea807ab635c931ea5277e02 Mon Sep 17 00:00:00 2001 From: trivernis Date: Thu, 16 Nov 2023 20:03:12 +0100 Subject: [PATCH 1/8] Fix compile errors in clip --- Cargo.toml | 1 + src/transformers/clip.rs | 142 +++++++++++++++++++++++++-------------- 2 files changed, 92 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7d96477..c9aa870 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,4 @@ anyhow = "1.0.75" burn = "0.10.0" burn-tch = "0.10.0" clap = { version = "4.4.7", features = ["derive"] } +serde = { version = "1.0.192", features = ["derive"] } diff --git a/src/transformers/clip.rs b/src/transformers/clip.rs index 5f87542..20153db 100644 --- a/src/transformers/clip.rs +++ b/src/transformers/clip.rs @@ -5,17 +5,20 @@ //! //! https://github.com/openai/CLIP +use std::alloc::GlobalAlloc; +use std::cmp::max; + +use burn::tensor::activation::softmax; +use burn::tensor::ops::TensorOps; use burn::{ module::Module, nn, tensor::{ - Int, - Tensor, activation::{gelu, sigmoid}, backend::Backend, - } + Int, Tensor, + }, }; -use burn::tensor::activation::softmax; #[derive(Module, Debug, Clone, Copy)] pub enum Activation { @@ -35,7 +38,7 @@ impl Activation { #[derive(Debug, Clone)] pub struct Config { vocab_size: usize, - embed_dim: usize, // aka config.hidden_size + embed_dim: usize, // aka config.hidden_size activation: Activation, // aka config.hidden_act intermediate_size: usize, max_position_embeddings: usize, @@ -124,25 +127,35 @@ impl Config { struct ClipTextEmbeddings { token_embedding: nn::Embedding, position_embedding: nn::Embedding, - position_ids: Tensor, + position_ids: Tensor, } impl ClipTextEmbeddings { - fn new(device: &Backend::Device, c: &Config) -> Self { + fn new(device: &B::Device, c: &Config) -> Self { let token_embedding = nn::EmbeddingConfig::new(c.vocab_size, c.embed_dim).init(); - let position_embedding = nn::EmbeddingConfig::new(c.max_position_embeddings, c.embed_dim).init(); - let position_ids = Tensor::arange_device(0..c.max_position_embeddings, device).unsqueeze(); - ClipTextEmbeddings { token_embedding, position_embedding, position_ids } + let position_embedding = + nn::EmbeddingConfig::new(c.max_position_embeddings, c.embed_dim).init(); + let position_ids = Tensor::arange_device(0..c.max_position_embeddings, device) + .unsqueeze() + .float(); + + ClipTextEmbeddings { + token_embedding, + position_embedding, + position_ids, + } } pub fn forward(&self, xs: Tensor) -> Tensor { let token_embedding = self.token_embedding.forward(xs); - let position_embedding = self.position_embedding.forward(self.position_ids.clone()); + let position_embedding = self + .position_embedding + .forward(self.position_ids.to_owned().int()); token_embedding + position_embedding } } -#[derive(Debug)] +#[derive(Module, Debug)] struct ClipAttention { k_proj: nn::Linear, v_proj: nn::Linear, @@ -155,7 +168,13 @@ struct ClipAttention { impl ClipAttention { fn new(c: &Config) -> Self { - assert_eq!(c.embed_dim % c.num_attention_heads, 0, "embed_dim {} must be a multiple of num_attention_heads {}", c.embed_dim, c.num_attention_heads); + assert_eq!( + c.embed_dim % c.num_attention_heads, + 0, + "embed_dim {} must be a multiple of num_attention_heads {}", + c.embed_dim, + c.num_attention_heads + ); let embed_dim = c.embed_dim; let num_attention_heads = c.num_attention_heads; @@ -171,50 +190,62 @@ impl ClipAttention { let out_proj = nn::LinearConfig::new(embed_dim, embed_dim).init(); let head_dim = embed_dim / num_attention_heads; let scale = (head_dim as f64).powf(-0.5); - ClipAttention { k_proj, v_proj, q_proj, out_proj, head_dim, scale, num_attention_heads } + ClipAttention { + k_proj, + v_proj, + q_proj, + out_proj, + head_dim, + scale, + num_attention_heads, + } } - fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Tensor { + fn shape(&self, xs: Tensor, seq_len: usize, bsz: usize) -> Tensor { xs.reshape([bsz, seq_len, self.num_attention_heads, self.head_dim]) .swap_dims(1, 2) - .contiguous() + // .contiguous() // TODO: Figure out if this is needed or if we can abstract over memory } pub fn forward(&self, xs: Tensor, causal_attention_mask: &Tensor) -> Tensor { let [bsz, seq_len, embed_dim] = xs.dims(); let query_states = self.q_proj.forward(xs.clone()) * self.scale; - let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim); + let proj_shape = [bsz * self.num_attention_heads, seq_len, self.head_dim]; + let query_states = self - .shape(&query_states, seq_len, bsz) + .shape(query_states, seq_len, bsz) .reshape(proj_shape) .to_full_precision(); let key_states = self - .shape(&self.k_proj.forward(xs.clone()), seq_len, bsz) + .shape(self.k_proj.forward(xs.clone()), seq_len, bsz) .reshape(proj_shape) .to_full_precision(); let value_states = self - .shape(&self.v_proj.forward(xs), seq_len, bsz) + .shape(self.v_proj.forward(xs), seq_len, bsz) .reshape(proj_shape) .to_full_precision(); + let src_len = key_states.dims()[1]; let attn_weights = query_states.matmul(key_states.swap_dims(1, 2)); - let src_len = key_states.dims()[1]; let attn_weights = attn_weights .reshape([bsz, self.num_attention_heads, seq_len, src_len]) - + causal_attention_mask.unsqueeze::<4>(); - let attn_weights = attn_weights - .reshape([bsz * self.num_attention_heads, seq_len, src_len]); + .add( + causal_attention_mask + .to_owned() + .unsqueeze::<4>() + .to_full_precision(), + ); + let attn_weights = attn_weights.reshape([bsz * self.num_attention_heads, seq_len, src_len]); let attn_weights = softmax(attn_weights, 3); - let attn_output = attn_weights - .matmul(value_states) - .to_full_precision(); + let attn_output = attn_weights.matmul(value_states); let attn_output = attn_output .reshape([bsz, self.num_attention_heads, seq_len, self.head_dim]) .swap_dims(1, 2) .reshape([bsz, seq_len, embed_dim]); - self.out_proj.forward(attn_output) + self.out_proj + .forward(Tensor::from_full_precision(attn_output)) } } @@ -229,7 +260,11 @@ impl ClipMlp { fn new(c: &Config) -> Self { let fc1 = nn::LinearConfig::new(c.embed_dim, c.intermediate_size).init(); let fc2 = nn::LinearConfig::new(c.intermediate_size, c.embed_dim).init(); - ClipMlp { fc1, fc2, activation: c.activation } + ClipMlp { + fc1, + fc2, + activation: c.activation, + } } fn forward(&self, xs: Tensor) -> Tensor { @@ -238,7 +273,7 @@ impl ClipMlp { } } -#[derive(Debug)] +#[derive(Module, Debug)] struct ClipEncoderLayer { self_attn: ClipAttention, layer_norm1: nn::LayerNorm, @@ -252,23 +287,28 @@ impl ClipEncoderLayer { let layer_norm1 = nn::LayerNormConfig::new(c.embed_dim).init(); let mlp = ClipMlp::new(c); let layer_norm2 = nn::LayerNormConfig::new(c.embed_dim).init(); - ClipEncoderLayer { self_attn, layer_norm1, mlp, layer_norm2 } + ClipEncoderLayer { + self_attn, + layer_norm1, + mlp, + layer_norm2, + } } pub fn forward(&self, xs: Tensor, causal_attention_mask: &Tensor) -> Tensor { let residual = xs; - let xs = self.layer_norm1.forward(xs.clone()); + let xs = self.layer_norm1.forward(residual.clone()); let xs = self.self_attn.forward(xs, causal_attention_mask); - let xs = xs + residual; + let xs2 = xs.clone() + residual; - let residual = xs; - let xs = self.layer_norm2.forward(xs.clone())?; - let xs = self.mlp.forward(xs)?; + let residual = xs2; + let xs = self.layer_norm2.forward(xs.clone()); + let xs = self.mlp.forward(xs); xs + residual } } -#[derive(Debug)] +#[derive(Module, Debug)] struct ClipEncoder { layers: Vec>, } @@ -301,7 +341,7 @@ pub struct ClipTextTransformer { } impl ClipTextTransformer { - pub fn new(device: &Backend::Device, c: &Config) -> Self { + pub fn new(device: &B::Device, c: &Config) -> Self { let embeddings = ClipTextEmbeddings::new(device, c); let encoder = ClipEncoder::new(c); let final_layer_norm = nn::LayerNormConfig::new(c.embed_dim).init(); @@ -313,7 +353,7 @@ impl ClipTextTransformer { } // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678 - fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Backend::Device) -> Tensor { + fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &B::Device) -> Tensor { let mask = Tensor::full_device([bsz, seq_len, seq_len], f32::MIN, device); let mask = zero_lower_diagonal(mask); // zero out the lower diagonal let mask = mask.unsqueeze(); // expand mask @@ -321,21 +361,21 @@ impl ClipTextTransformer { } fn forward(&self, xs: Tensor) -> Tensor { - let [bsz, seq_len, _] = xs.dims(); - let xs = self.embeddings.forward(xs)?; - let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device()); - let xs = self.encoder.forward(xs, &causal_attention_mask)?; + let [bsz, seq_len] = xs.dims(); + let xs = self.embeddings.forward(xs); + let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, &xs.device()); + let xs = self.encoder.forward(xs, &causal_attention_mask); self.final_layer_norm.forward(xs) } } -fn zero_lower_diagonal(mut xs: Tensor) -> Tensor { - let [bsz, seq_len, _] = xs.dims(); - for i in 0..seq_len { - for j in 0..i { - xs[[i, j, 0]] = 0.0; // Assuming the third dimension is the channel/diagonal dimension - } - } +fn zero_lower_diagonal(xs: Tensor) -> Tensor { + let [m, n, _] = xs.dims(); + + // build an upper-triangle matrix + let upper_diag = (0..max(m, n)) + .map(Tensor::::diagonal) + .fold(Tensor::zeros([max(m, n); 2]), Tensor::add); - xs + upper_diag.reshape([m, n]).unsqueeze().float().mul(xs) } From 9d53644bdd889055c3b5fd5e3ba129314dd1591a Mon Sep 17 00:00:00 2001 From: trivernis Date: Sun, 19 Nov 2023 17:53:08 +0100 Subject: [PATCH 2/8] Add ddim initialization code --- Cargo.toml | 1 + src/lib.rs | 1 + src/schedulers/ddim.rs | 155 +++++++++++++++++++++++++++++++++++++++ src/schedulers/mod.rs | 23 ++++++ src/transformers/clip.rs | 1 - 5 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 src/schedulers/ddim.rs create mode 100644 src/schedulers/mod.rs diff --git a/Cargo.toml b/Cargo.toml index c9aa870..e3beac0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,4 +16,5 @@ anyhow = "1.0.75" burn = "0.10.0" burn-tch = "0.10.0" clap = { version = "4.4.7", features = ["derive"] } +num-traits = "0.2.17" serde = { version = "1.0.192", features = ["derive"] } diff --git a/src/lib.rs b/src/lib.rs index 1e51cfe..4f8263d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod pipelines; +pub mod schedulers; pub mod transformers; // pub fn add(left: usize, right: usize) -> usize { diff --git a/src/schedulers/ddim.rs b/src/schedulers/ddim.rs new file mode 100644 index 0000000..0d921c9 --- /dev/null +++ b/src/schedulers/ddim.rs @@ -0,0 +1,155 @@ +use num_traits::ToPrimitive; +use std::marker::PhantomData; + +use burn::{ + module::Module, + tensor::{backend::Backend, Data, ElementConversion, Shape, Tensor}, +}; + +use super::{BetaSchedule, PredictionType}; + +/// The configuration for the DDIM scheduler. +#[derive(Module, Debug, Clone)] +pub struct DDIMSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// The amount of noise to be added at each step. + pub eta: f64, + /// Adjust the indexes of the inference schedule by this value. + pub steps_offset: usize, + /// prediction type of the scheduler function + pub prediction_type: PredictionType, + /// number of diffusion steps used to train the model + pub train_timesteps: usize, +} + +impl Default for DDIMSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085f64, + beta_end: 0.012f64, + beta_schedule: BetaSchedule::ScaledLinear, + eta: 0., + steps_offset: 1, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + } + } +} + +#[derive(Module, Debug)] +pub struct DDIMScheduler { + timesteps: Vec, + alphas_cumprod: Vec, + step_ratio: usize, + init_noise_sigma: f64, + config: DDIMSchedulerConfig, + __phantom: PhantomData, +} + +impl DDIMScheduler { + pub fn new(device: &B::Device, inference_steps: usize, config: DDIMSchedulerConfig) -> Self { + let step_ratio = config.train_timesteps / inference_steps; + let timesteps = (0..inference_steps) + .map(|s| s * step_ratio + config.steps_offset) + .rev() + .collect(); + let betas = match config.beta_schedule { + BetaSchedule::Linear => linear_tensor::( + device, + config.beta_start, + config.beta_end, + config.train_timesteps, + ), + BetaSchedule::ScaledLinear => scaled_linear_tensor::( + device, + config.beta_start, + config.beta_end, + config.train_timesteps, + ), + BetaSchedule::SquaredcosCapV2 => { + squared_cos_tensor::(device, config.train_timesteps, 0.999) + } + }; + + let betas_vec: Vec = betas.to_data().value; + let mut alphas_cumprod = Vec::with_capacity(betas_vec.len()); + + for beta in &betas_vec { + let alpha = 1.0 - beta.to_f64().expect("beta to be a float"); + alphas_cumprod.push(alpha * alphas_cumprod.last().copied().unwrap_or(1.0)) + } + + Self { + timesteps, + alphas_cumprod, + step_ratio, + init_noise_sigma: 1.0, + config, + __phantom: PhantomData, + } + } + + pub fn timesteps(&self) -> &[usize] { + &self.timesteps.as_slice() + } +} + +fn scaled_linear_tensor( + device: &B::Device, + start: f64, + end: f64, + num_steps: usize, +) -> Tensor { + linear_tensor(device, start.sqrt(), end.sqrt(), num_steps) +} + +/// Creates a linear tensor (vector) with the values `start..end` evenly distributed +/// over `num_steps` +fn linear_tensor( + device: &B::Device, + start: f64, + end: f64, + num_steps: usize, +) -> Tensor { + let mut cur = start; + let mut betas = Vec::with_capacity(num_steps); + + assert!(start < end); + + let step_size = (end - start) / num_steps as f64; + + assert!(step_size > 0.0); + + while cur < end { + betas.push(cur.elem()); + cur += step_size; + } + Tensor::from_data_device(Data::new(betas, Shape::new([betas.len()])), device) +} + +/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of +/// `(1-beta)` over time from `t = [0,1]`. +/// +/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)` +/// up to that part of the diffusion process. +fn squared_cos_tensor( + device: &B::Device, + num_diffusion_timesteps: usize, + max_beta: f64, +) -> Tensor { + let alpha_bar = |time_step: usize| { + f64::cos((time_step as f64 + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2) + }; + let mut betas = Vec::with_capacity(num_diffusion_timesteps); + for i in 0..num_diffusion_timesteps { + let t1 = i / num_diffusion_timesteps; + let t2 = (i + 1) / num_diffusion_timesteps; + betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta).elem()); + } + Tensor::from_data_device(Data::new(betas, Shape::new([betas.len()])), device) +} diff --git a/src/schedulers/mod.rs b/src/schedulers/mod.rs new file mode 100644 index 0000000..9402f5b --- /dev/null +++ b/src/schedulers/mod.rs @@ -0,0 +1,23 @@ +pub mod ddim; + +/// This represents how beta ranges from its minimum value to the maximum +/// during training. +#[derive(Debug, Clone, Copy)] +pub enum BetaSchedule { + /// Linear interpolation. + Linear, + /// Linear interpolation of the square root of beta. + ScaledLinear, + /// Glide cosine schedule + SquaredcosCapV2, +} + +/// prediction type of the scheduler function, one of `epsilon` (predicting +/// the noise of the diffusion process), `sample` (directly predicting the noisy sample`) +/// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) +#[derive(Debug, Clone, Copy)] +pub enum PredictionType { + Epsilon, + VPrediction, + Sample, +} diff --git a/src/transformers/clip.rs b/src/transformers/clip.rs index 20153db..453bb4b 100644 --- a/src/transformers/clip.rs +++ b/src/transformers/clip.rs @@ -5,7 +5,6 @@ //! //! https://github.com/openai/CLIP -use std::alloc::GlobalAlloc; use std::cmp::max; use burn::tensor::activation::softmax; From fb7ce918799ef41a84d30e3a3de4b6d279be9e5e Mon Sep 17 00:00:00 2001 From: trivernis Date: Sun, 19 Nov 2023 19:26:37 +0100 Subject: [PATCH 3/8] Implement `add_noise` for ddim --- src/schedulers/ddim.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/schedulers/ddim.rs b/src/schedulers/ddim.rs index 0d921c9..b374641 100644 --- a/src/schedulers/ddim.rs +++ b/src/schedulers/ddim.rs @@ -97,6 +97,23 @@ impl DDIMScheduler { pub fn timesteps(&self) -> &[usize] { &self.timesteps.as_slice() } + + pub fn add_noise( + &self, + original: &Tensor, + noise: Tensor, + timestep: usize, + ) -> Tensor { + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt(); + let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt(); + + original.mul_scalar(sqrt_alpha_prod) + noise.mul_scalar(sqrt_one_minus_alpha_prod) + } } fn scaled_linear_tensor( From ece5bd7b4d49ba09244fe4ed956bb1ea099f8a1e Mon Sep 17 00:00:00 2001 From: trivernis Date: Sat, 25 Nov 2023 13:58:14 +0100 Subject: [PATCH 4/8] Add ddim step implementation --- src/schedulers/ddim.rs | 78 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 5 deletions(-) diff --git a/src/schedulers/ddim.rs b/src/schedulers/ddim.rs index b374641..3124972 100644 --- a/src/schedulers/ddim.rs +++ b/src/schedulers/ddim.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use burn::{ module::Module, - tensor::{backend::Backend, Data, ElementConversion, Shape, Tensor}, + tensor::{backend::Backend, Data, Distribution, ElementConversion, Shape, Tensor}, }; use super::{BetaSchedule, PredictionType}; @@ -94,8 +94,67 @@ impl DDIMScheduler { } } - pub fn timesteps(&self) -> &[usize] { - &self.timesteps.as_slice() + // Perform a backward step + pub fn step( + &self, + model_output: &Tensor, + timestep: usize, + sample: &Tensor, + ) -> Tensor { + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + let prev_timestep = if timestep > self.step_ratio { + timestep - self.step_ratio + } else { + 0 + }; + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]; + let beta_prod_t = 1. - alpha_prod_t; + let beta_prod_t_prev = 1. - alpha_prod_t_prev; + + let (pred_original_sample, pred_epsilon) = match self.config.prediction_type { + PredictionType::Epsilon => { + let pred_original_sample = sample.sub(model_output.mul_scalar(beta_prod_t.sqrt())) + * (1. / alpha_prod_t.sqrt()); + (pred_original_sample, model_output.clone()) + } + PredictionType::VPrediction => { + let pred_original_sample = sample.mul_scalar(alpha_prod_t.sqrt()) + - model_output.mul_scalar(beta_prod_t.sqrt()); + let pred_epsilon = model_output.mul_scalar(alpha_prod_t.sqrt()) + + sample.mul_scalar(beta_prod_t.sqrt()); + (pred_original_sample, pred_epsilon) + } + PredictionType::Sample => { + let pred_original_sample = model_output.clone(); + let pred_epsilon = sample.sub(pred_original_sample.mul_scalar(alpha_prod_t.sqrt())) + * (1. / beta_prod_t.sqrt()); + (pred_original_sample, pred_epsilon) + } + }; + + let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev); + let std_dev_t = self.config.eta * variance.sqrt(); + + let pred_sample_direction = + pred_epsilon.mul_scalar((1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt()); + let prev_sample = + pred_original_sample.mul_scalar(alpha_prod_t_prev.sqrt()) + pred_sample_direction; + + if self.config.eta > 0. { + prev_sample + + Tensor::random_device( + prev_sample.shape(), + Distribution::Normal(0f64, std_dev_t as f64), + &prev_sample.device(), + ) + } else { + prev_sample + } } pub fn add_noise( @@ -114,6 +173,14 @@ impl DDIMScheduler { original.mul_scalar(sqrt_alpha_prod) + noise.mul_scalar(sqrt_one_minus_alpha_prod) } + + pub fn timesteps(&self) -> &[usize] { + &self.timesteps.as_slice() + } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } } fn scaled_linear_tensor( @@ -152,8 +219,8 @@ fn linear_tensor( /// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of /// `(1-beta)` over time from `t = [0,1]`. /// -/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)` -/// up to that part of the diffusion process. +/// Contains a function `alpha_bar` that takes an argument `t` and transforms +/// it to the cumulative product of `(1-beta)` up to that part of the diffusion process. fn squared_cos_tensor( device: &B::Device, num_diffusion_timesteps: usize, @@ -168,5 +235,6 @@ fn squared_cos_tensor( let t2 = (i + 1) / num_diffusion_timesteps; betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta).elem()); } + Tensor::from_data_device(Data::new(betas, Shape::new([betas.len()])), device) } From 56bd7a5fdc6d3d6092eaa0d13835d368d594dae2 Mon Sep 17 00:00:00 2001 From: trivernis Date: Sat, 25 Nov 2023 14:08:22 +0100 Subject: [PATCH 5/8] Fix build errors in ddim --- src/schedulers/ddim.rs | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/schedulers/ddim.rs b/src/schedulers/ddim.rs index 3124972..b698283 100644 --- a/src/schedulers/ddim.rs +++ b/src/schedulers/ddim.rs @@ -97,9 +97,9 @@ impl DDIMScheduler { // Perform a backward step pub fn step( &self, - model_output: &Tensor, + model_output: Tensor, timestep: usize, - sample: &Tensor, + sample: Tensor, ) -> Tensor { let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 @@ -118,22 +118,22 @@ impl DDIMScheduler { let (pred_original_sample, pred_epsilon) = match self.config.prediction_type { PredictionType::Epsilon => { - let pred_original_sample = sample.sub(model_output.mul_scalar(beta_prod_t.sqrt())) + let pred_original_sample = sample + .sub(model_output.clone().mul_scalar(beta_prod_t.sqrt())) * (1. / alpha_prod_t.sqrt()); - (pred_original_sample, model_output.clone()) + (pred_original_sample, model_output) } PredictionType::VPrediction => { - let pred_original_sample = sample.mul_scalar(alpha_prod_t.sqrt()) - - model_output.mul_scalar(beta_prod_t.sqrt()); + let pred_original_sample = sample.clone().mul_scalar(alpha_prod_t.sqrt()) + - model_output.clone().mul_scalar(beta_prod_t.sqrt()); let pred_epsilon = model_output.mul_scalar(alpha_prod_t.sqrt()) + sample.mul_scalar(beta_prod_t.sqrt()); (pred_original_sample, pred_epsilon) } PredictionType::Sample => { - let pred_original_sample = model_output.clone(); - let pred_epsilon = sample.sub(pred_original_sample.mul_scalar(alpha_prod_t.sqrt())) + let pred_epsilon = sample.sub(model_output.clone().mul_scalar(alpha_prod_t.sqrt())) * (1. / beta_prod_t.sqrt()); - (pred_original_sample, pred_epsilon) + (model_output, pred_epsilon) } }; @@ -146,7 +146,7 @@ impl DDIMScheduler { pred_original_sample.mul_scalar(alpha_prod_t_prev.sqrt()) + pred_sample_direction; if self.config.eta > 0. { - prev_sample + prev_sample.clone() + Tensor::random_device( prev_sample.shape(), Distribution::Normal(0f64, std_dev_t as f64), @@ -159,7 +159,7 @@ impl DDIMScheduler { pub fn add_noise( &self, - original: &Tensor, + original: Tensor, noise: Tensor, timestep: usize, ) -> Tensor { @@ -213,7 +213,9 @@ fn linear_tensor( betas.push(cur.elem()); cur += step_size; } - Tensor::from_data_device(Data::new(betas, Shape::new([betas.len()])), device) + let dims = [betas.len()]; + + Tensor::from_data_device(Data::new(betas, Shape::new(dims)), device) } /// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -230,11 +232,13 @@ fn squared_cos_tensor( f64::cos((time_step as f64 + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2) }; let mut betas = Vec::with_capacity(num_diffusion_timesteps); + for i in 0..num_diffusion_timesteps { let t1 = i / num_diffusion_timesteps; let t2 = (i + 1) / num_diffusion_timesteps; betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta).elem()); } + let dims = [betas.len()]; - Tensor::from_data_device(Data::new(betas, Shape::new([betas.len()])), device) + Tensor::from_data_device(Data::new(betas, Shape::new(dims)), device) } From f6706842e0174b7f1a6c731a1168d0fc20116c99 Mon Sep 17 00:00:00 2001 From: trivernis Date: Wed, 29 Nov 2023 19:12:03 +0100 Subject: [PATCH 6/8] Merge branch 'main' into port/ddim --- .github/workflows/validate.yml | 117 ++++++- .gitignore | 4 + Cargo.toml | 34 +- README.md | 12 + build.rs | 21 ++ src/cli/convert.rs | 16 - src/cli/mod.rs | 1 - src/lib.rs | 16 + src/main.rs | 33 -- src/models/attention.rs | 616 +++++++++++++++++++++++++++++++++ src/models/embeddings.rs | 53 ++- src/models/mod.rs | 1 + src/models/resnet.rs | 5 +- src/transformers/clip.rs | 19 +- src/utils.rs | 20 +- 15 files changed, 842 insertions(+), 126 deletions(-) create mode 100644 build.rs delete mode 100644 src/cli/convert.rs delete mode 100644 src/cli/mod.rs delete mode 100644 src/main.rs create mode 100644 src/models/attention.rs diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index cb4961e..b935807 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -20,30 +20,109 @@ jobs: run: cargo fmt -- --check # - name: Run cargo clippy # run: cargo clippy -- -D warnings - macos-check: - runs-on: macos-latest + check-std: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [macos-13, ubuntu-latest, windows-latest] + rust: [stable] + feature: ['ndarray', 'wgpu', 'torch', 'ndarray-blas-accelerate'] + include: + - cache: stable + rust: stable + exclude: + # macOS feature only + - os: ubuntu-latest + feature: 'ndarray-blas-accelerate' + # macOS feature only + - os: windows-latest + feature: 'ndarray-blas-accelerate' + # macos on CI does not have Metal GPU + - os: macos-13 + feature: 'wgpu' + # windows can have CPU Vulkan but Burn doesn't select CPU well yet + - os: windows-latest + feature: 'wgpu' + # ubuntu is throwing SIGSEGV + - os: ubuntu-latest + feature: 'wgpu' steps: - - uses: actions/checkout@v2 - - uses: actions-rust-lang/setup-rust-toolchain@v1 + - name: checkout + uses: actions/checkout@v2 + - name: install rust + uses: actions-rust-lang/setup-rust-toolchain@v1 with: + toolchain: ${{ matrix.rust }} rustflags: "" # Disable when we're ready - - name: Test - run: cargo test - ubuntu-check: + - name: caching + uses: Swatinem/rust-cache@v2 + with: + key: ${{ runner.os }}-${{ matrix.cache }}-${{ matrix.feature }}-${{ hashFiles('**/Cargo.toml') }} + - name: (linux) install llvmpipe, lavapipe + if: runner.os == 'Linux' + run: |- + sudo apt-get update -y -qq + sudo add-apt-repository ppa:kisak/kisak-mesa -y + sudo apt-get update + sudo apt install -y libegl1-mesa libgl1-mesa-dri libxcb-xfixes0-dev mesa-vulkan-drivers + - name: (windows) install warp + if: runner.os == 'Windows' + shell: bash + run: |- + set -e + + curl.exe -L https://www.nuget.org/api/v2/package/Microsoft.Direct3D.WARP/1.0.7.1 -o warp.zip + 7z.exe e warp.zip -owarp build/native/amd64/d3d10warp.dll + + mkdir -p target/debug/deps + + cp -v warp/d3d10warp.dll target/debug/ + cp -v warp/d3d10warp.dll target/debug/deps + - name: (windows) install mesa + if: runner.os == 'Windows' + shell: bash + run: |- + set -e + + curl.exe -L https://github.com/pal1000/mesa-dist-win/releases/download/23.2.1/mesa3d-23.2.1-release-msvc.7z -o mesa.7z + 7z.exe e mesa.7z -omesa x64/{opengl32.dll,libgallium_wgl.dll,libglapi.dll,vulkan_lvp.dll,lvp_icd.x86_64.json} + + mkdir -p target/debug/deps + + cp -v mesa/* target/debug/ + cp -v mesa/* target/debug/deps + + echo "VK_DRIVER_FILES=$PWD/mesa/lvp_icd.x86_64.json" >> "$GITHUB_ENV" + echo "GALLIUM_DRIVER=llvmpipe" >> "$GITHUB_ENV" + - name: (windows) install dxc + if: runner.os == 'Windows' + uses: napokue/setup-dxc@v1.1.0 + - name: test + run: cargo test --features ${{ matrix.feature }} + check-no-std: runs-on: ubuntu-latest + strategy: + matrix: + rust: [stable] + target: ['wasm32-unknown-unknown', 'thumbv7m-none-eabi'] + feature: ['ndarray-no-std'] + include: + - cache: stable + rust: stable steps: - - uses: actions/checkout@v2 - - uses: actions-rust-lang/setup-rust-toolchain@v1 + - name: checkout + uses: actions/checkout@v2 + - name: install rust + uses: actions-rust-lang/setup-rust-toolchain@v1 with: + toolchain: ${{ matrix.rust }} + target: ${{ matrix.target }} rustflags: "" # Disable when we're ready - - name: Test - run: cargo test - windows-check: - runs-on: windows-latest - steps: - - uses: actions/checkout@v2 - - uses: actions-rust-lang/setup-rust-toolchain@v1 + - name: caching + uses: Swatinem/rust-cache@v2 with: - rustflags: "" # Disable when we're ready - - name: Test - run: cargo test + key: ${{ runner.os }}-${{ matrix.cache }}-${{ matrix.feature }}-${{ hashFiles('**/Cargo.toml') }} + - name: test + run: cargo test --no-default-features --features ${{ matrix.feature }} + - name: build ${{ matrix.target }} + run: cargo build --no-default-features --features ${{ matrix.feature }} --target ${{ matrix.target }} diff --git a/.gitignore b/.gitignore index fb02044..3efb051 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,7 @@ Cargo.lock # Ignore macOS directory attributes .DS_Store + +# IDEs +.idea +.fleet diff --git a/Cargo.toml b/Cargo.toml index 57e3338..41da61d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,24 +4,28 @@ version = "0.1.0" edition = "2021" [features] -wgpu-backend = ["burn-wgpu"] +default = ["burn/default", "std"] -[dependencies.burn-wgpu] -package = "burn-wgpu" -version = "0.11.0" -optional = true +std = ["burn/std"] + +ndarray = ["burn-ndarray", "burn/ndarray"] +ndarray-no-std = ["burn-ndarray", "burn/ndarray-no-std"] +ndarray-blas-accelerate = ["burn-ndarray", "burn/ndarray-blas-accelerate"] + +torch = ["burn-tch", "burn/tch"] +wgpu = ["burn-wgpu", "burn/wgpu"] [dependencies] -anyhow = "1.0.75" -burn = "0.11.0" -burn-ndarray = "0.11.0" -burn-tch = "0.11.0" -clap = { version = "4.4.7", features = ["derive"] } +burn = { version = "0.11.0", default-features = false } +burn-ndarray = { version = "0.11.0", default-features = false, optional = true } +burn-tch = { version = "0.11.0", default-features = false, optional = true } +burn-wgpu = { version = "0.11.0", default-features = false, optional = true } +num-traits = { version = "0.2.17", default-features = false } num-traits = "0.2.17" -serde = { version = "1.0.192", features = ["derive"] } +serde = { version = "1.0.192", default-features = false, features = ["derive", "alloc"] } [patch.crates-io] -burn-wgpu = { git = "https://github.com/Tracel-AI/burn", rev = "be5bb33" } -burn = { git = "https://github.com/Tracel-AI/burn", rev = "be5bb33" } -burn-ndarray = { git = "https://github.com/Tracel-AI/burn", rev = "be5bb33" } -burn-tch = { git = "https://github.com/Tracel-AI/burn", rev = "be5bb33" } +burn-wgpu = { git = "https://github.com/Tracel-AI/burn", rev = "60c24430c6d4685032d9e351a537eded7da1a35c" } +burn = { git = "https://github.com/Tracel-AI/burn", rev = "60c24430c6d4685032d9e351a537eded7da1a35c" } +burn-ndarray = { git = "https://github.com/Tracel-AI/burn", rev = "60c24430c6d4685032d9e351a537eded7da1a35c" } +burn-tch = { git = "https://github.com/Tracel-AI/burn", rev = "60c24430c6d4685032d9e351a537eded7da1a35c" } diff --git a/README.md b/README.md index 2f5f1fc..6c4258b 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,18 @@ The `diffusers-burn` crate is a conversion of [diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using [burn](https://github.com/burn-rs/burn) rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1, as well as Stable Diffusion XL 1.0. +## Feature Flags + +This crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling +the default `std` feature. + +* `std` - enables the standard library. Enabled by default. +* `wgpu` - uses ndarray as the backend. Enabled by default when none specified and `std`. +* `ndarray` - uses ndarray as the backend. +* `ndarray-no-std` - uses ndarray-no-std as the backend. Enabled by default when none and `#![no_std]`. +* `ndarray-blas-accelerate` - uses ndarray with Accelerate framework (macOS only). +* `torch` - uses torch as the backend. + ## Community If you are excited about the project or want to contribute, don't hesitate to join our [Discord](https://discord.gg/UHtSgF6j5J)! diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..bf6b577 --- /dev/null +++ b/build.rs @@ -0,0 +1,21 @@ +fn main() { + // Check if any of the features are specified + let any_feature_selected = cfg!(any( + feature = "ndarray", + feature = "ndarray-no-std", + feature = "torch", + feature = "wgpu" + )); + let has_std = cfg!(feature = "std"); + + // If none of the features are specified, default to wgpu + if !any_feature_selected { + if !has_std { + println!("cargo:rerun-if-env-changed=FORCE_NDARRAY_NO_STD"); // Optional: Trigger a recompile if needed + println!("cargo:rustc-cfg=ndarray-no-std"); + } else { + println!("cargo:rerun-if-env-changed=FORCE_WGPU"); // Optional: Trigger a recompile if needed + println!("cargo:rustc-cfg=wgpu"); + } + } +} diff --git a/src/cli/convert.rs b/src/cli/convert.rs deleted file mode 100644 index 715c0ed..0000000 --- a/src/cli/convert.rs +++ /dev/null @@ -1,16 +0,0 @@ -use anyhow::Result; -use clap::Args; - -#[derive(Args, Debug)] -pub struct ConvertArgs { - /// Path to the fined tuned model - input: String, - - /// Output directory to save the converted model - #[arg(short)] - output_dir: String, -} - -pub fn handle_convert(_args: &ConvertArgs) -> Result<()> { - Ok(()) -} diff --git a/src/cli/mod.rs b/src/cli/mod.rs deleted file mode 100644 index b5b6721..0000000 --- a/src/cli/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod convert; diff --git a/src/lib.rs b/src/lib.rs index a460e24..c836d65 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,24 @@ //! //! This is a Rust port of Hugging Face's [diffusers](https://github.com/huggingface/diffusers) Python api using [Burn](https://github.com/burn-rs/burn) +#![cfg_attr(not(feature = "std"), no_std)] + pub mod models; pub mod pipelines; pub mod schedulers; pub mod transformers; pub mod utils; + +extern crate alloc; + +#[cfg(all(test, not(feature = "wgpu"), not(feature = "torch")))] +pub type TestBackend = burn_ndarray::NdArray; + +#[cfg(all(test, feature = "torch"))] +pub type TestBackend = burn_tch::LibTorch; + +#[cfg(all(test, feature = "wgpu", not(target_os = "macos")))] +pub type TestBackend = burn_wgpu::Wgpu; + +#[cfg(all(test, feature = "wgpu", target_os = "macos"))] +pub type TestBackend = burn_wgpu::Wgpu; diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index c493477..0000000 --- a/src/main.rs +++ /dev/null @@ -1,33 +0,0 @@ -use anyhow::Result; -use clap::{Parser, Subcommand}; - -mod cli; - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -#[command(propagate_version = true)] -struct Cli { - #[command(subcommand)] - command: Commands, -} - -#[derive(Subcommand, Debug)] -enum Commands { - /// Convert fined tuned Stable Diffusion version into burn's native format - Convert { - #[clap(flatten)] - args: cli::convert::ConvertArgs, - }, -} - -fn main() -> Result<()> { - let cli = Cli::parse(); - - match &cli.command { - Commands::Convert { args } => { - cli::convert::handle_convert(args)?; - } - } - - Ok(()) -} diff --git a/src/models/attention.rs b/src/models/attention.rs new file mode 100644 index 0000000..21fc8cd --- /dev/null +++ b/src/models/attention.rs @@ -0,0 +1,616 @@ +//! Attention Based Building Blocks + +use alloc::vec; +use alloc::vec::Vec; +use burn::config::Config; +use burn::module::Module; +use burn::nn::{ + self, Dropout, DropoutConfig, GroupNorm, GroupNormConfig, LayerNorm, LayerNormConfig, + LinearConfig, +}; +use burn::tensor::activation::{gelu, softmax}; +use burn::tensor::backend::Backend; +use burn::tensor::Tensor; + +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; + +#[derive(Config)] +pub struct GeGluConfig { + /// The size of the input features. + d_input: usize, + /// The size of the output features. + d_output: usize, +} + +/// A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. +#[derive(Module, Debug)] +struct GeGlu { + proj: nn::Linear, +} + +impl GeGluConfig { + fn init(&self) -> GeGlu { + let proj = LinearConfig::new(self.d_input, 2 * self.d_output).init(); + GeGlu { proj } + } +} + +impl GeGlu { + fn forward(&self, xs: Tensor) -> Tensor { + let hidden_states_and_gate = self.proj.forward(xs).chunk(2, D - 1); + hidden_states_and_gate[0].clone() * gelu(hidden_states_and_gate[1].clone()) + } +} + +#[derive(Config)] +pub struct FeedForwardConfig { + /// The size of the input features. + pub d_input: usize, + /// The size of the output features. If not given defaults to `d_input`. + d_output: Option, + /// The multiplier to use for the hidden dimension. + #[config(default = 4)] + multiplier: usize, + /// The dropout probability. Default: 0.0 + #[config(default = 0.)] + dropout: f64, +} + +#[derive(Module, Debug)] +pub struct FeedForward { + geglu: GeGlu, + dropout: Dropout, + linear_outer: nn::Linear, +} + +impl FeedForwardConfig { + pub fn init(&self) -> FeedForward { + let inner_dim = self.d_input * self.multiplier; + let dim_out = self.d_output.unwrap_or(self.d_input); + + FeedForward { + geglu: GeGluConfig { + d_input: self.d_input, + d_output: inner_dim, + } + .init(), + linear_outer: LinearConfig::new(inner_dim, dim_out).init(), + dropout: DropoutConfig::new(self.dropout).init(), + } + } +} + +impl FeedForward { + pub fn forward(&self, xs: Tensor) -> Tensor { + let x = self.geglu.forward(xs); + let x = self.dropout.forward(x); + self.linear_outer.forward(x) + } +} + +#[derive(Config)] +pub struct CrossAttentionConfig { + /// The number of channels in the query. + d_query: usize, + /// The number of channels in the context. If not given defaults to `query_dim`. + d_context: Option, + /// The number of heads to use for the multi-head attention. + #[config(default = 8)] + n_heads: usize, + /// The number of channels in each head. + #[config(default = 64)] + d_head: usize, + /// The size of the slices to use for the multi-head attention. + slice_size: Option, + #[config(default = 0.)] + // The dropout probability. + dropout: f64, +} + +#[derive(Module, Debug)] +pub struct CrossAttention { + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + output: nn::Linear, + n_heads: usize, + scale: f64, + slice_size: Option, +} + +impl CrossAttentionConfig { + pub fn init(&self) -> CrossAttention { + let inner_dim = self.d_head * self.n_heads; + let context_dim = self.d_context.unwrap_or(self.d_query); + let scale = 1. / (self.d_head as f64).sqrt(); + + CrossAttention { + query: LinearConfig::new(self.d_query, inner_dim) + .with_bias(false) + .init(), + key: LinearConfig::new(context_dim, inner_dim) + .with_bias(false) + .init(), + value: LinearConfig::new(context_dim, inner_dim) + .with_bias(false) + .init(), + output: LinearConfig::new(inner_dim, self.d_query).init(), + n_heads: self.n_heads, + scale, + slice_size: self.slice_size, + } + } +} + +impl CrossAttention { + fn reshape_heads_to_batch_dim(&self, xs: Tensor) -> Tensor { + let [batch_size, seq_len, dim] = xs.dims(); + xs.reshape([batch_size, seq_len, self.n_heads, dim / self.n_heads]) + .swap_dims(1, 2) + .reshape([batch_size * self.n_heads, seq_len, dim / self.n_heads]) + } + + fn reshape_batch_dim_to_heads(&self, xs: Tensor) -> Tensor { + let [batch_size, seq_len, dim] = xs.dims(); + let output = xs + .reshape([batch_size / self.n_heads, self.n_heads, seq_len, dim]) + .swap_dims(1, 2) + .reshape([batch_size / self.n_heads, seq_len, dim * self.n_heads]); + output + } + + fn sliced_attention( + &self, + query: Tensor, + key: Tensor, + value: Tensor, + slice_size: usize, + ) -> Tensor { + let batch_size_attention = query.clone().shape().dims[0]; + let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size); + + for i in 0..batch_size_attention / slice_size { + let start_idx = i * slice_size; + let end_idx = (i + 1) * slice_size; + + let xs = query + .clone() + .slice([start_idx..end_idx, 0..query.shape().dims[1]]) + .matmul( + key.clone() + .slice([start_idx..end_idx, 0..key.shape().dims[1]]) + .swap_dims(3 - 1, 3 - 2) + * self.scale, + ); + + let xs = softmax(xs, 3 - 1).matmul( + value + .clone() + .slice([start_idx..end_idx, 0..value.shape().dims[1]]), + ); + + hidden_states.push(xs); + } + + let output = Tensor::cat(hidden_states, 0); + self.reshape_batch_dim_to_heads(output) + } + + fn attention( + &self, + query: Tensor, + key: Tensor, + value: Tensor, + ) -> Tensor { + let xs = query.matmul(key.swap_dims(3 - 1, 3 - 2) * self.scale); + let xs = softmax(xs, 3 - 1).matmul(value); + + self.reshape_batch_dim_to_heads(xs) + } + + pub fn forward(&self, xs: Tensor, context: Option>) -> Tensor { + let query = self.query.forward(xs.clone()); + let context = context.unwrap_or(xs); + let key = self.key.forward(context.clone()); + let value = self.value.forward(context); + + let query = self.reshape_heads_to_batch_dim(query); + let key = self.reshape_heads_to_batch_dim(key); + let value = self.reshape_heads_to_batch_dim(value); + + let output_tensor = match self.slice_size { + None => self.attention(query, key, value), + Some(slice_size) if query.shape().dims[0] / slice_size <= 1 => { + self.attention(query, key, value) + } + Some(slice_size) => self.sliced_attention(query, key, value, slice_size), + }; + + self.output.forward(output_tensor) + } +} + +#[derive(Config)] +pub struct BasicTransformerBlockConfig { + d_model: usize, + d_context: Option, + n_heads: usize, + d_head: usize, + sliced_attn_size: Option, +} + +/// A basic Transformer block. +#[derive(Module, Debug)] +pub struct BasicTransformerBlock { + attn1: CrossAttention, + ff: FeedForward, + attn2: CrossAttention, + norm1: LayerNorm, + norm2: LayerNorm, + norm3: LayerNorm, +} + +impl BasicTransformerBlockConfig { + fn init(&self) -> BasicTransformerBlock { + let attn1 = CrossAttentionConfig::new(self.d_model) + .with_n_heads(self.n_heads) + .with_d_head(self.d_head) + .with_slice_size(self.sliced_attn_size) + .init(); + let ff = FeedForwardConfig::new(self.d_model).init(); + let attn2 = CrossAttentionConfig::new(self.d_model) + .with_d_context(self.d_context) + .with_n_heads(self.n_heads) + .with_d_head(self.d_head) + .with_slice_size(self.sliced_attn_size) + .init(); + let norm1 = LayerNormConfig::new(self.d_model).init(); + let norm2 = LayerNormConfig::new(self.d_model).init(); + let norm3 = LayerNormConfig::new(self.d_model).init(); + + BasicTransformerBlock { + attn1, + ff, + attn2, + norm1, + norm2, + norm3, + } + } +} + +impl BasicTransformerBlock { + pub fn forward(&self, xs: Tensor, context: Option>) -> Tensor { + let xs = self.attn1.forward(self.norm1.forward(xs.clone()), None) + xs; + let xs = self.attn2.forward(self.norm2.forward(xs.clone()), context) + xs; + self.ff.forward(self.norm3.forward(xs.clone())) + xs + } +} + +#[derive(Config, Debug)] +pub struct SpatialTransformerConfig { + #[config(default = 1)] + pub depth: usize, + #[config(default = 32)] + pub n_groups: usize, + pub d_context: Option, + pub sliced_attn_size: Option, + // #[config(default = false)] + // pub use_linear_projection: bool, + pub in_channels: usize, + pub n_heads: usize, + pub d_head: usize, +} + +//#[derive(Config, Debug)] +//enum Proj { +// Conv2d(nn::conv::Conv2d), +// Linear(nn::Linear) +//} + +/// Aka Transformer2DModel +#[derive(Module, Debug)] +pub struct SpatialTransformer { + norm: GroupNorm, + proj_in: nn::conv::Conv2d, + transformer_blocks: Vec>, + proj_out: nn::conv::Conv2d, +} + +impl SpatialTransformerConfig { + fn init(&self) -> SpatialTransformer { + let d_inner = self.n_heads * self.d_head; + let norm = GroupNormConfig::new(self.n_groups, self.in_channels) + .with_epsilon(1e-6) + .init(); + // let proj_in = if config.use_linear_projection { + let proj_in = nn::conv::Conv2dConfig::new([self.in_channels, d_inner], [1, 1]).init(); + + let mut transformer_blocks = vec![]; + for _index in 0..self.depth { + let tb = BasicTransformerBlockConfig::new(d_inner, self.n_heads, self.d_head) + .with_d_context(self.d_context) + .with_sliced_attn_size(self.sliced_attn_size) + .init(); + + transformer_blocks.push(tb) + } + + let proj_out = nn::conv::Conv2dConfig::new([d_inner, self.in_channels], [1, 1]).init(); + + SpatialTransformer { + norm, + proj_in, + transformer_blocks, + proj_out, + } + } +} + +impl SpatialTransformer { + fn forward(&self, xs: Tensor, context: Option>) -> Tensor { + let [n_batch, _n_channel, height, weight] = xs.dims(); + + let residual = xs.clone(); + let xs = self.norm.forward(xs); + let xs = self.proj_in.forward(xs); + let d_inner = xs.shape().dims[1]; + let xs = xs + .swap_dims(1, 2) + .transpose() + .reshape([n_batch, height * weight, d_inner]); + + let mut xs = xs; + for block in self.transformer_blocks.iter() { + xs = block.forward(xs, context.clone()) + } + + let xs = xs + .reshape([n_batch, height, weight, d_inner]) + .transpose() + .swap_dims(1, 2); + + self.proj_out.forward(xs) + residual + } +} + +#[derive(Config, Debug)] +pub struct AttentionBlockConfig { + pub channels: usize, + pub n_head_channels: Option, + #[config(default = 32)] + pub n_groups: usize, + #[config(default = 1.)] + pub rescale_output_factor: f64, + #[config(default = 1e-5)] + pub eps: f64, +} + +#[derive(Module, Debug)] +pub struct AttentionBlock { + group_norm: nn::GroupNorm, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + proj_attn: nn::Linear, + channels: usize, + n_heads: usize, + rescale_output_factor: f64, +} + +impl AttentionBlockConfig { + fn init(&self) -> AttentionBlock { + let n_head_channels = self.n_head_channels.unwrap_or(self.channels); + let n_heads = self.channels / n_head_channels; + let group_norm = GroupNormConfig::new(self.n_groups, self.channels) + .with_epsilon(self.eps) + .init(); + let query = LinearConfig::new(self.channels, self.channels).init(); + let key = LinearConfig::new(self.channels, self.channels).init(); + let value = LinearConfig::new(self.channels, self.channels).init(); + let proj_attn = LinearConfig::new(self.channels, self.channels).init(); + + AttentionBlock { + group_norm, + query, + key, + value, + proj_attn, + channels: self.channels, + n_heads, + rescale_output_factor: self.rescale_output_factor, + } + } +} + +impl AttentionBlock { + fn transpose_for_scores(&self, xs: Tensor) -> Tensor { + let [n_batch, t, h_times_d] = xs.dims(); + xs.reshape([n_batch, t, self.n_heads, h_times_d / self.n_heads]) + .swap_dims(1, 2) + } + + fn forward(&self, xs: Tensor) -> Tensor { + let residual = xs.clone(); + let [n_batch, channel, height, width] = xs.dims(); + let xs = self + .group_norm + .forward(xs) + .reshape([n_batch, channel, height * width]) + .swap_dims(1, 2); + + let query_proj = self.query.forward(xs.clone()); + let key_proj = self.key.forward(xs.clone()); + let value_proj = self.value.forward(xs.clone()); + + let query_states = self.transpose_for_scores(query_proj); + let key_states = self.transpose_for_scores(key_proj); + let value_states = self.transpose_for_scores(value_proj); + + // scale is applied twice, hence the -0.25 here rather than -0.5. + // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L87 + let scale = f64::powf(self.channels as f64 / self.n_heads as f64, -0.25); + let attention_scores = (query_states * scale).matmul(key_states.transpose() * scale); + let attention_probs = softmax(attention_scores, 4 - 1); + + let xs = attention_probs.matmul(value_states); + let xs = xs.swap_dims(1, 2); + let xs: Tensor = xs.flatten(4 - 2, 4 - 1); + let xs = self + .proj_attn + .forward(xs) + .transpose() + .reshape([n_batch, channel, height, width]); + + (xs + residual) / self.rescale_output_factor + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + use burn::module::{Param, ParamId}; + use burn::tensor::{Data, Shape}; + + #[test] + fn test_geglu_tensor_shape_3() { + let weight = Tensor::from_data(Data::from([ + [ + 0.1221, 2.0378, -0.1171, 1.3004, -0.9630, -0.3108, -1.3376, -1.0593, + ], + [ + 0.4669, -0.8146, 0.9965, -0.4659, 2.0444, -0.0709, -0.0147, 0.2135, + ], + ])); + let bias = Tensor::from_data(Data::from([ + 0.2867778149426027, + 0.6646517317105776, + 0.023946332404821136, + -0.1395737454364393, + 0.05131041098737321, + -0.4225726694675192, + 0.036411720220954735, + 0.01829268669677364, + ])); + + let geglu = GeGlu { + proj: nn::Linear { + weight: Param::new(ParamId::new(), weight), + bias: Some(Param::new(ParamId::new(), bias)), + }, + }; + + let tensor: Tensor = Tensor::from_data(Data::from([ + [[1., 2.], [3., 4.], [5., 6.]], + [[7., 8.], [9., 10.], [11., 12.]], + ])); + + let output = geglu.forward(tensor); + assert_eq!(output.shape(), Shape::from([2, 3, 4])); + output.to_data().assert_approx_eq( + &Data::from([ + [ + [4.2632e0, -1.7927e-1, -2.3216e-1, -3.7916e-2], + [1.3460e1, -2.9266e-1, -2.1707e-4, -4.5595e-2], + [2.7750e1, -1.1442e-1, -2.5335e-13, -2.5403e-4], + ], + [ + [4.7135e1, -1.7708e-2, -0.0000e0, -6.7097e-9], + [7.1616e1, -1.0652e-3, -0.0000e0, -0.0000e0], + [1.0119e2, -2.1943e-5, -0.0000e0, -0.0000e0], + ], + ]), + 2, + ); + } + + #[test] + fn test_geglu_tensor_shape_2() { + let weight = Tensor::from_data(Data::from([ + [0.6054, 1.9322, 0.1445, 1.3004, -0.6853, -0.8947], + [-0.3678, 0.4081, -1.9001, -1.5843, -0.9399, 0.1018], + ])); + let bias = Tensor::from_data(Data::from([ + 0.3237631905393836, + 0.22052049807936902, + -0.3196353346822061, + -0.02244043444199162, + -0.33600250665852865, + 0.5259391939301621, + ])); + + let geglu = GeGlu { + proj: nn::Linear { + weight: Param::new(ParamId::new(), weight), + bias: Some(Param::new(ParamId::new(), bias)), + }, + }; + + let tensor: Tensor = + Tensor::from_data(Data::from([[1., 2.], [3., 4.], [5., 6.]])); + + let output = geglu.forward(tensor); + assert_eq!(output.shape(), Shape::from([3, 3])); + output.to_data().assert_approx_eq( + &Data::from([ + [-2.4192e-5, -3.3057e-2, 2.8535e-1], + [-0.0000e0, -2.0983e-7, 5.2465e-1], + [-0.0000e0, -0.0000e0, 1.2599e-2], + ]), + 1, + ); + } + + #[test] + fn test_sliced_attention() { + // create tensor of size [2, 4, 2] + let query: Tensor = Tensor::from_data(Data::from([ + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]], + [[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]], + [[25.0, 26.0], [27.0, 28.0], [29.0, 30.0], [31.0, 32.0]], + ])); + let key: Tensor = Tensor::from_data(Data::from([ + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]], + [[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]], + [[25.0, 26.0], [27.0, 28.0], [29.0, 30.0], [31.0, 32.0]], + ])); + let value: Tensor = Tensor::from_data(Data::from([ + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]], + [[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]], + [[25.0, 26.0], [27.0, 28.0], [29.0, 30.0], [31.0, 32.0]], + ])); + + let cross_attention = CrossAttentionConfig::new(320) + .with_n_heads(2) + .with_d_head(40) + .with_slice_size(Some(2)) + .init(); + + let output = cross_attention.sliced_attention(query, key, value, 2); + + assert_eq!(output.shape(), Shape::from([2, 4, 4])); + output.into_data().assert_approx_eq( + &Data::from([ + [ + [5.9201, 6.9201, 14.9951, 15.9951], + [6.7557, 7.7557, 14.9986, 15.9986], + [6.9363, 7.9363, 14.9996, 15.9996], + [6.9824, 7.9824, 14.9999, 15.9999], + ], + [ + [23.0000, 24.0000, 31.0000, 32.0000], + [23.0000, 24.0000, 31.0000, 32.0000], + [23.0000, 24.0000, 31.0000, 32.0000], + [23.0000, 24.0000, 31.0000, 32.0000], + ], + ]), + 3, + ) + } +} diff --git a/src/models/embeddings.rs b/src/models/embeddings.rs index c17ee83..5a5cb5b 100644 --- a/src/models/embeddings.rs +++ b/src/models/embeddings.rs @@ -1,21 +1,43 @@ use crate::utils::pad_with_zeros; +use alloc::vec; +use burn::config::Config; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; +use burn::tensor::activation::silu; use burn::tensor::backend::Backend; use burn::tensor::Tensor; +use core::marker::PhantomData; -#[derive(Debug)] +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; + +#[derive(Config, Debug)] +pub struct TimestepEmbeddingConfig { + channel: usize, + time_embed_dim: usize, +} + +#[derive(Module, Debug)] pub struct TimestepEmbedding { linear_1: Linear, linear_2: Linear, } +impl TimestepEmbeddingConfig { + /// Initialize a new [embedding](TimestepEmbedding) module. + /// Uses activating function: "silu". + pub fn init(&self) -> TimestepEmbedding { + let linear_1 = LinearConfig::new(self.channel, self.time_embed_dim).init(); + let linear_2 = LinearConfig::new(self.time_embed_dim, self.time_embed_dim).init(); + TimestepEmbedding { linear_1, linear_2 } + } +} + impl TimestepEmbedding { - // act_fn: "silu" - pub fn new(channel: usize, time_embed_dim: usize) -> Self { - let linear_1 = LinearConfig::new(channel, time_embed_dim).init(); - let linear_2 = LinearConfig::new(time_embed_dim, time_embed_dim).init(); - Self { linear_1, linear_2 } + fn forward(&self, xs: Tensor) -> Tensor { + let xs = silu(self.linear_1.forward(xs)); + self.linear_2.forward(xs) } } @@ -24,7 +46,7 @@ pub struct Timesteps { num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64, - _backend: std::marker::PhantomData, + _backend: PhantomData, } impl Timesteps { @@ -33,7 +55,7 @@ impl Timesteps { num_channels, flip_sin_to_cos, downscale_freq_shift, - _backend: std::marker::PhantomData, + _backend: PhantomData, } } @@ -61,16 +83,14 @@ impl Timesteps { #[cfg(test)] mod tests { use super::*; + use crate::TestBackend; use burn::tensor::{Data, Shape}; #[test] + #[cfg(not(feature = "torch"))] fn test_timesteps_even_channels() { - type TestBackend = burn_ndarray::NdArray; - let device = ::Device::default(); - let timesteps = Timesteps::::new(4, true, 0.); - let xs: Tensor = - Tensor::from_data_device(Data::from([1., 2., 3., 4.]), &device); + let xs: Tensor = Tensor::from_data(Data::from([1., 2., 3., 4.])); let emb = timesteps.forward(xs); @@ -87,13 +107,10 @@ mod tests { } #[test] + #[cfg(not(feature = "torch"))] fn test_timesteps_odd_channels() { - type TestBackend = burn_ndarray::NdArray; - let device = ::Device::default(); - let timesteps = Timesteps::::new(5, true, 0.); - let xs: Tensor = - Tensor::from_data_device(Data::from([1., 2., 3., 4., 5.]), &device); + let xs: Tensor = Tensor::from_data(Data::from([1., 2., 3., 4., 5.])); let emb = timesteps.forward(xs); diff --git a/src/models/mod.rs b/src/models/mod.rs index 9ba410e..845a097 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -2,5 +2,6 @@ //! //! A collection of models to be used in a diffusion loop. +pub mod attention; pub mod embeddings; pub mod resnet; diff --git a/src/models/resnet.rs b/src/models/resnet.rs index 0eac04d..4fd2477 100644 --- a/src/models/resnet.rs +++ b/src/models/resnet.rs @@ -117,12 +117,11 @@ impl ResnetBlock2D { #[cfg(test)] mod tests { use super::*; + use crate::TestBackend; use burn::tensor::{Distribution, Shape}; #[test] fn test_resnet_block_2d_no_temb() { - type TestBackend = burn_ndarray::NdArray; - let block = ResnetBlock2DConfig::new(128).init::(); let xs = Tensor::::random([2, 128, 64, 64], Distribution::Default); let output = block.forward(xs, None); @@ -132,8 +131,6 @@ mod tests { #[test] fn test_resnet_block_2d_with_temb() { - type TestBackend = burn_ndarray::NdArray; - let block = ResnetBlock2DConfig::new(128).init::(); let xs = Tensor::::random([2, 128, 64, 64], Distribution::Default); let temb = Tensor::::random([2, 128], Distribution::Default); diff --git a/src/transformers/clip.rs b/src/transformers/clip.rs index 6771c97..9f87adb 100644 --- a/src/transformers/clip.rs +++ b/src/transformers/clip.rs @@ -5,9 +5,10 @@ //! //! https://github.com/openai/CLIP -use std::f32::consts::SQRT_2; - -use crate::utils::build_causal_attention_mask; +use crate::utils::generate_causal_attention_mask; +use alloc::string::String; +use alloc::string::ToString; +use alloc::vec::Vec; use burn::config::Config; use burn::tensor::activation::softmax; use burn::{ @@ -19,6 +20,11 @@ use burn::{ Int, Tensor, }, }; +use core::f32::consts::SQRT_2; + +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; #[derive(Module, Debug, Clone, serde::Deserialize, serde::Serialize)] pub enum Activation { @@ -329,7 +335,7 @@ impl ClipTextTransformer { fn forward(&self, xs: Tensor) -> Tensor { let [bsz, seq_len] = xs.dims(); let xs = self.embeddings.forward(xs); - let causal_attention_mask = build_causal_attention_mask(bsz, seq_len, &xs.device()); + let causal_attention_mask = generate_causal_attention_mask(bsz, seq_len, &xs.device()); let xs = self.encoder.forward(xs, causal_attention_mask); self.final_layer_norm.forward(xs) } @@ -338,12 +344,11 @@ impl ClipTextTransformer { #[cfg(test)] mod tests { use super::*; + use crate::TestBackend; use burn::tensor::{Data, Shape}; #[test] fn test_init_text_embeddings() { - type TestBackend = burn_ndarray::NdArray; - let clip_config = ClipConfig::v1_5(); let text_embeddings: ClipTextEmbeddings = clip_config.init_text_embeddings(); @@ -362,8 +367,6 @@ mod tests { #[test] fn test_clip_attention_shape() { - type TestBackend = burn_ndarray::NdArray; - let clip_config = ClipConfig::v1_5(); let clip_attention: ClipAttention = clip_config.init_attention(); diff --git a/src/utils.rs b/src/utils.rs index 3e761e7..aa0683b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,8 +1,10 @@ +use alloc::vec; +use alloc::vec::Vec; use burn::tensor::backend::Backend; use burn::tensor::{Data, Element, ElementConversion, Numeric, Shape, Tensor}; // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678 -pub(crate) fn build_causal_attention_mask( +pub(crate) fn generate_causal_attention_mask( bsz: usize, seq_len: usize, device: &B::Device, @@ -73,15 +75,14 @@ where #[cfg(test)] mod tests { use super::*; - use burn::tensor::backend::Backend; - use burn::tensor::{Data, Shape}; + use crate::TestBackend; + use burn::tensor::{backend::Backend, Data, Shape}; #[test] fn test_build_causal_attention_mask() { - type TestBackend = burn_ndarray::NdArray; let device = ::Device::default(); - let mask = build_causal_attention_mask::(2, 4, &device); + let mask = generate_causal_attention_mask::(2, 4, &device); assert_eq!(mask.shape(), Shape::from([2, 1, 4, 4])); mask.to_data().assert_approx_eq( @@ -105,13 +106,8 @@ mod tests { #[test] fn test_pad_with_zeros() { - type TestBackend = burn_ndarray::NdArray; - let device = ::Device::default(); - - let tensor: Tensor = Tensor::from_data_device( - Data::from([[[1.6585, 0.4320], [-0.8701, -0.4649]]]), - &device, - ); + let tensor: Tensor = + Tensor::from_data(Data::from([[[1.6585, 0.4320], [-0.8701, -0.4649]]])); let padded = pad_with_zeros(tensor, 0, 1, 2); From 154de88a4b77a14a1e62ef097d6cfea641ff825e Mon Sep 17 00:00:00 2001 From: trivernis Date: Wed, 29 Nov 2023 19:33:09 +0100 Subject: [PATCH 7/8] Add ddim tests for linear and scaled tensors --- Cargo.toml | 1 - src/schedulers/ddim.rs | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 41da61d..470a08d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,6 @@ burn-ndarray = { version = "0.11.0", default-features = false, optional = true } burn-tch = { version = "0.11.0", default-features = false, optional = true } burn-wgpu = { version = "0.11.0", default-features = false, optional = true } num-traits = { version = "0.2.17", default-features = false } -num-traits = "0.2.17" serde = { version = "1.0.192", default-features = false, features = ["derive", "alloc"] } [patch.crates-io] diff --git a/src/schedulers/ddim.rs b/src/schedulers/ddim.rs index b698283..137f7ec 100644 --- a/src/schedulers/ddim.rs +++ b/src/schedulers/ddim.rs @@ -242,3 +242,38 @@ fn squared_cos_tensor( Tensor::from_data_device(Data::new(betas, Shape::new(dims)), device) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + + #[test] + fn it_creates_linear_tensors() { + let expected = (0..10) + .map(|i| i as f32) + .map(::FloatElem::from) + .collect::>(); + let actual = + linear_tensor::(&::Device::default(), 0., 10., 10) + .into_data() + .value; + + assert_eq!(expected, actual); + } + + #[test] + fn it_creates_squared_linear_tensors() { + let expected = vec![2., 2.25, 2.5, 2.75]; + let actual = scaled_linear_tensor::( + &::Device::default(), + 4., + 9., + 4, + ) + .into_data() + .value; + + assert_eq!(expected, actual); + } +} From 319f5c2cb73e78f3e202b150fc0e79eb4231bdf5 Mon Sep 17 00:00:00 2001 From: David Chavez Date: Fri, 1 Dec 2023 02:21:44 +0100 Subject: [PATCH 8/8] Fix merge issue --- src/transformers/clip.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/clip.rs b/src/transformers/clip.rs index 3239faf..35559bc 100644 --- a/src/transformers/clip.rs +++ b/src/transformers/clip.rs @@ -5,7 +5,6 @@ //! //! https://github.com/openai/CLIP -use crate::utils::generate_causal_attention_mask; use alloc::string::String; use alloc::string::ToString; use alloc::vec::Vec; @@ -344,6 +343,7 @@ impl ClipTextTransformer { fn forward(&self, xs: Tensor) -> Tensor { let [bsz, seq_len] = xs.dims(); + let xs = self.embeddings.forward(xs); let causal_attention_mask = Self::generate_causal_attention_mask(bsz, seq_len, &xs.device()); let xs = self.encoder.forward(xs, causal_attention_mask);