diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index d177d18..87d80db 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -52,13 +52,11 @@ jobs: uses: Swatinem/rust-cache@v2 with: key: ${{ runner.os }}-${{ matrix.cache }}-${{ matrix.backend }}-${{ hashFiles('**/Cargo.toml') }} - - name: (Linux) Install llvmpipe, lavapipe - if: runner.os == 'Linux' + - name: (Linux) Install Vulkan deps + if: runner.os == 'Linux' && matrix.backend == 'wgpu' run: |- sudo apt-get update -y -qq - sudo add-apt-repository ppa:kisak/kisak-mesa -y - sudo apt-get update - sudo apt install -y libegl1-mesa libgl1-mesa-dri libxcb-xfixes0-dev mesa-vulkan-drivers + sudo apt-get install -y libvulkan1 mesa-vulkan-drivers vulkan-tools libxcb-xfixes0-dev - name: (Windows) Install warp if: runner.os == 'Windows' shell: bash @@ -88,9 +86,6 @@ jobs: echo "VK_DRIVER_FILES=$PWD/mesa/lvp_icd.x86_64.json" >> "$GITHUB_ENV" echo "GALLIUM_DRIVER=llvmpipe" >> "$GITHUB_ENV" - - name: (Windows) Install dxc - if: runner.os == 'Windows' - uses: napokue/setup-dxc@v1.1.0 - name: Run tests shell: bash run: |- diff --git a/.gitignore b/.gitignore index 3efb051..dda62bd 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ Cargo.lock # IDEs .idea .fleet + +# Others +inspiration/ diff --git a/Cargo.toml b/Cargo.toml index ba21c39..197a6fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ torch = ["burn/tch"] wgpu = ["burn/wgpu"] [dependencies] -burn = { version = "0.13.0", default-features = false } +burn = { version = "0.20.1", default-features = false } num-traits = { version = "0.2.18", default-features = false } serde = { version = "1.0.197", default-features = false, features = [ "derive", diff --git a/src/lib.rs b/src/lib.rs index 679284c..1e2bea9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,25 +9,13 @@ pub mod pipelines; pub mod transformers; pub mod utils; -#[cfg(all(test, feature = "ndarray"))] -use burn::backend::ndarray; - -#[cfg(all(test, feature = "torch"))] -use burn::backend::libtorch; - -#[cfg(all(test, feature = "wgpu"))] -use burn::backend::wgpu; - extern crate alloc; #[cfg(all(test, feature = "ndarray"))] -pub type TestBackend = ndarray::NdArray; +pub type TestBackend = burn::backend::NdArray; #[cfg(all(test, feature = "torch"))] -pub type TestBackend = libtorch::LibTorch; +pub type TestBackend = burn::backend::LibTorch; -#[cfg(all(test, feature = "wgpu", not(target_os = "macos")))] -pub type TestBackend = wgpu::Wgpu; - -#[cfg(all(test, feature = "wgpu", target_os = "macos"))] -pub type TestBackend = wgpu::Wgpu; +#[cfg(all(test, feature = "wgpu"))] +pub type TestBackend = burn::backend::Wgpu; diff --git a/src/models/attention.rs b/src/models/attention.rs index 70372f1..2326fb5 100644 --- a/src/models/attention.rs +++ b/src/models/attention.rs @@ -16,7 +16,7 @@ use burn::tensor::Tensor; #[allow(unused_imports)] use num_traits::Float; -#[derive(Config)] +#[derive(Config, Debug)] pub struct GeGluConfig { /// The size of the input features. d_input: usize, @@ -44,7 +44,7 @@ impl GeGlu { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct FeedForwardConfig { /// The size of the input features. pub d_input: usize, @@ -90,7 +90,7 @@ impl FeedForward { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct CrossAttentionConfig { /// The number of channels in the query. d_query: usize, @@ -232,7 +232,7 @@ impl CrossAttention { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct BasicTransformerBlockConfig { d_model: usize, d_context: Option, @@ -488,13 +488,13 @@ mod tests { use super::*; use crate::TestBackend; use burn::module::{Param, ParamId}; - use burn::tensor::{Data, Shape}; + use burn::tensor::{Shape, TensorData, Tolerance}; #[test] fn test_geglu_tensor_shape_3() { let device = Default::default(); let weight = Tensor::from_data( - Data::from([ + TensorData::from([ [ 0.1221, 2.0378, -0.1171, 1.3004, -0.9630, -0.3108, -1.3376, -1.0593, ], @@ -505,7 +505,7 @@ mod tests { &device, ); let bias = Tensor::from_data( - Data::from([ + TensorData::from([ 0.2867778149426027, 0.6646517317105776, 0.023946332404821136, @@ -526,7 +526,7 @@ mod tests { }; let tensor: Tensor = Tensor::from_data( - Data::from([ + TensorData::from([ [[1., 2.], [3., 4.], [5., 6.]], [[7., 8.], [9., 10.], [11., 12.]], ]), @@ -535,8 +535,8 @@ mod tests { let output = geglu.forward(tensor); assert_eq!(output.shape(), Shape::from([2, 3, 4])); - output.to_data().assert_approx_eq( - &Data::from([ + output.into_data().assert_approx_eq::( + &TensorData::from([ [ [4.2632e0, -1.7927e-1, -2.3216e-1, -3.7916e-2], [1.3460e1, -2.9266e-1, -2.1707e-4, -4.5595e-2], @@ -548,7 +548,7 @@ mod tests { [1.0119e2, -2.1943e-5, -0.0000e0, -0.0000e0], ], ]), - 2, + Tolerance::rel_abs(1e-2, 1e-2), ); } @@ -556,14 +556,14 @@ mod tests { fn test_geglu_tensor_shape_2() { let device = Default::default(); let weight = Tensor::from_data( - Data::from([ + TensorData::from([ [0.6054, 1.9322, 0.1445, 1.3004, -0.6853, -0.8947], [-0.3678, 0.4081, -1.9001, -1.5843, -0.9399, 0.1018], ]), &device, ); let bias = Tensor::from_data( - Data::from([ + TensorData::from([ 0.3237631905393836, 0.22052049807936902, -0.3196353346822061, @@ -582,17 +582,17 @@ mod tests { }; let tensor: Tensor = - Tensor::from_data(Data::from([[1., 2.], [3., 4.], [5., 6.]]), &device); + Tensor::from_data(TensorData::from([[1., 2.], [3., 4.], [5., 6.]]), &device); let output = geglu.forward(tensor); assert_eq!(output.shape(), Shape::from([3, 3])); - output.to_data().assert_approx_eq( - &Data::from([ + output.into_data().assert_approx_eq::( + &TensorData::from([ [-2.4192e-5, -3.3057e-2, 2.8535e-1], [-0.0000e0, -2.0983e-7, 5.2465e-1], [-0.0000e0, -0.0000e0, 1.2599e-2], ]), - 1, + Tolerance::rel_abs(1e-1, 1e-1), ); } @@ -601,7 +601,7 @@ mod tests { let device = Default::default(); // create tensor of size [2, 4, 2] let query: Tensor = Tensor::from_data( - Data::from([ + TensorData::from([ [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]], [[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]], @@ -610,7 +610,7 @@ mod tests { &device, ); let key: Tensor = Tensor::from_data( - Data::from([ + TensorData::from([ [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]], [[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]], @@ -619,7 +619,7 @@ mod tests { &device, ); let value: Tensor = Tensor::from_data( - Data::from([ + TensorData::from([ [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]], [[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]], @@ -637,8 +637,8 @@ mod tests { let output = cross_attention.sliced_attention(query, key, value, 2); assert_eq!(output.shape(), Shape::from([2, 4, 4])); - output.into_data().assert_approx_eq( - &Data::from([ + output.into_data().assert_approx_eq::( + &TensorData::from([ [ [5.9201, 6.9201, 14.9951, 15.9951], [6.7557, 7.7557, 14.9986, 15.9986], @@ -652,7 +652,7 @@ mod tests { [23.0000, 24.0000, 31.0000, 32.0000], ], ]), - 3, + Tolerance::rel_abs(1e-3, 1e-3), ) } } diff --git a/src/models/embeddings.rs b/src/models/embeddings.rs index 8a13124..83b663e 100644 --- a/src/models/embeddings.rs +++ b/src/models/embeddings.rs @@ -85,26 +85,27 @@ impl Timesteps { mod tests { use super::*; use crate::TestBackend; - use burn::tensor::{Data, Shape}; + use burn::tensor::{Shape, TensorData, Tolerance}; #[test] #[cfg(not(feature = "torch"))] fn test_timesteps_even_channels() { let device = Default::default(); let timesteps = Timesteps::::new(4, true, 0.); - let xs: Tensor = Tensor::from_data(Data::from([1., 2., 3., 4.]), &device); + let xs: Tensor = + Tensor::from_data(TensorData::from([1., 2., 3., 4.]), &device); - let emb = timesteps.forward(xs); + let emb: Tensor = timesteps.forward(xs); assert_eq!(emb.shape(), Shape::from([4, 4])); - emb.to_data().assert_approx_eq( - &Data::from([ + emb.into_data().assert_approx_eq::( + &TensorData::from([ [0.5403, 1.0000, 0.8415, 0.0100], [-0.4161, 0.9998, 0.9093, 0.0200], [-0.9900, 0.9996, 0.1411, 0.0300], [-0.6536, 0.9992, -0.7568, 0.0400], ]), - 3, + Tolerance::rel_abs(1e-3, 1e-3), ); } @@ -114,13 +115,13 @@ mod tests { let device = Default::default(); let timesteps = Timesteps::::new(5, true, 0.); let xs: Tensor = - Tensor::from_data(Data::from([1., 2., 3., 4., 5.]), &device); + Tensor::from_data(TensorData::from([1., 2., 3., 4., 5.]), &device); - let emb = timesteps.forward(xs); + let emb: Tensor = timesteps.forward(xs); assert_eq!(emb.shape(), Shape::from([6, 4])); - emb.to_data().assert_approx_eq( - &Data::from([ + emb.into_data().assert_approx_eq::( + &TensorData::from([ [0.5403, 1.0000, 0.8415, 0.0100], [-0.4161, 0.9998, 0.9093, 0.0200], [-0.9900, 0.9996, 0.1411, 0.0300], @@ -128,7 +129,7 @@ mod tests { [0.2837, 0.9988, -0.9589, 0.0500], [0.0000, 0.0000, 0.0000, 0.0000], ]), - 3, + Tolerance::rel_abs(1e-3, 1e-3), ); } } diff --git a/src/models/unet_2d_blocks.rs b/src/models/unet_2d_blocks.rs index 6ad2ff3..9526e7b 100644 --- a/src/models/unet_2d_blocks.rs +++ b/src/models/unet_2d_blocks.rs @@ -25,7 +25,7 @@ use super::{ use alloc::vec; use alloc::vec::Vec; -#[derive(Config)] +#[derive(Config, Debug)] struct Downsample2DConfig { in_channels: usize, use_conv: bool, @@ -71,13 +71,13 @@ impl Downsample2D { fn forward(&self, xs: Tensor) -> Tensor { match &self.conv { - None => avg_pool2d(xs, [2, 2], [2, 2], [0, 0], true), + None => avg_pool2d(xs, [2, 2], [2, 2], [0, 0], true, false), Some(conv) => conv.forward(Self::pad_tensor(xs, self.padding)), } } } -#[derive(Config)] +#[derive(Config, Debug)] struct Upsample2DConfig { in_channels: usize, out_channels: usize, @@ -353,7 +353,7 @@ impl UNetMidBlock2D { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct UNetMidBlock2DCrossAttnConfig { in_channels: usize, temb_channels: Option, @@ -454,7 +454,7 @@ impl UNetMidBlock2DCrossAttn { } } -#[derive(Config, Copy)] +#[derive(Config, Debug, Copy)] pub struct DownBlock2DConfig { in_channels: usize, out_channels: usize, @@ -537,7 +537,7 @@ impl DownBlock2D { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct CrossAttnDownBlock2DConfig { in_channels: usize, out_channels: usize, @@ -614,7 +614,7 @@ impl CrossAttnDownBlock2D { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct UpBlock2DConfig { in_channels: usize, prev_output_channels: usize, @@ -703,7 +703,7 @@ impl UpBlock2D { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct CrossAttnUpBlock2DConfig { in_channels: usize, prev_output_channels: usize, @@ -788,13 +788,13 @@ impl CrossAttnUpBlock2D { mod tests { use super::*; use crate::TestBackend; - use burn::tensor::{Data, Distribution, Shape}; + use burn::tensor::{Distribution, Shape, TensorData, Tolerance}; #[test] fn test_downsample_2d_no_conv() { let device = Default::default(); let tensor: Tensor = Tensor::from_data( - Data::from([ + TensorData::from([ [ [[0.0351, 0.4179], [0.0137, 0.6947]], [[0.9526, 0.5386], [0.2856, 0.1839]], @@ -814,12 +814,12 @@ mod tests { let downsample_2d = Downsample2DConfig::new(4, false, 4, 0).init(&device); let output = downsample_2d.forward(tensor); - output.into_data().assert_approx_eq( - &Data::from([ + output.into_data().assert_approx_eq::( + &TensorData::from([ [[[0.2904]], [[0.4902]], [[0.4633]], [[0.3323]]], [[[0.8031]], [[0.3632]], [[0.7049]], [[0.5074]]], ]), - 3, + Tolerance::rel_abs(1e-3, 1e-3), ); } @@ -827,7 +827,7 @@ mod tests { fn test_pad_tensor_0() { let device = Default::default(); let tensor: Tensor = Tensor::from_data( - Data::from([ + TensorData::from([ [ [[0.8600, 0.9473], [0.2543, 0.6181]], [[0.3889, 0.7722], [0.6736, 0.0454]], @@ -846,8 +846,8 @@ mod tests { let output = Downsample2D::pad_tensor(tensor, 0); - output.into_data().assert_approx_eq( - &Data::from([ + output.into_data().assert_approx_eq::( + &TensorData::from([ [ [ [0.8600, 0.9473, 0.0000], @@ -893,15 +893,14 @@ mod tests { ], ], ]), - 3, + Tolerance::rel_abs(1e-3, 1e-3), ); } #[test] fn test_down_encoder_block2d() { - TestBackend::seed(0); - let device = Default::default(); + TestBackend::seed(&device, 0); let block = DownEncoderBlock2DConfig::new(32, 32).init::(&device); let tensor: Tensor = @@ -913,9 +912,8 @@ mod tests { #[test] fn test_up_decoder_block2d() { - TestBackend::seed(0); - let device = Default::default(); + TestBackend::seed(&device, 0); let block = UpDecoderBlock2DConfig::new(32, 32).init::(&device); let tensor: Tensor = diff --git a/src/transformers/clip.rs b/src/transformers/clip.rs index ca1ea0c..ef02b89 100644 --- a/src/transformers/clip.rs +++ b/src/transformers/clip.rs @@ -357,7 +357,7 @@ impl ClipTextTransformer { mod tests { use super::*; use crate::TestBackend; - use burn::tensor::{Data, Shape}; + use burn::tensor::{Shape, TensorData, Tolerance}; #[test] fn test_init_text_embeddings() { @@ -366,17 +366,19 @@ mod tests { let text_embeddings: ClipTextEmbeddings = clip_config.init_text_embeddings(&device); + assert_eq!(text_embeddings.position_ids.shape(), Shape::from([1, 77])); + + // Convert to i32 for comparison since wgpu uses I32 while other backends use I64 + let expected: Vec = (0..77).collect(); assert_eq!( - text_embeddings.position_ids.to_data(), - Data::from([[ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, - 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, - 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, - 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76 - ]]) + text_embeddings + .position_ids + .into_data() + .convert::() + .to_vec::() + .unwrap(), + expected ); - - assert_eq!(text_embeddings.position_ids.shape(), Shape::from([1, 77])); } #[test] @@ -399,8 +401,8 @@ mod tests { ClipTextTransformer::generate_causal_attention_mask(2, 4, &device); assert_eq!(mask.shape(), Shape::from([2, 1, 4, 4])); - mask.to_data().assert_approx_eq( - &Data::from([ + mask.into_data().assert_approx_eq::( + &TensorData::from([ [[ [0.0000e0, f32::MIN, f32::MIN, f32::MIN], [0.0000e0, 0.0000e0, f32::MIN, f32::MIN], @@ -414,7 +416,7 @@ mod tests { [0.0000e0, 0.0000e0, 0.0000e0, 0.0000e0], ]], ]), - 3, + Tolerance::rel_abs(1e-3, 1e-3), ); } } diff --git a/src/utils.rs b/src/utils.rs index dfc27c1..2fd2968 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -51,27 +51,27 @@ where mod tests { use super::*; use crate::TestBackend; - use burn::tensor::{Data, Shape}; + use burn::tensor::{Shape, TensorData, Tolerance}; #[test] fn test_pad_with_zeros() { let device = Default::default(); let tensor: Tensor = Tensor::from_data( - Data::from([[[1.6585, 0.4320], [-0.8701, -0.4649]]]), + TensorData::from([[[1.6585, 0.4320], [-0.8701, -0.4649]]]), &device, ); let padded = pad_with_zeros(tensor, 0, 1, 2); assert_eq!(padded.shape(), Shape::from([4, 2, 2])); - padded.to_data().assert_approx_eq( - &Data::from([ + padded.into_data().assert_approx_eq::( + &TensorData::from([ [[0.0000, 0.0000], [0.0000, 0.0000]], [[1.6585, 0.4320], [-0.8701, -0.4649]], [[0.0000, 0.0000], [0.0000, 0.0000]], [[0.0000, 0.0000], [0.0000, 0.0000]], ]), - 3, + Tolerance::rel_abs(1e-3, 1e-3), ) } }