Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions .github/workflows/validate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,11 @@ jobs:
uses: Swatinem/rust-cache@v2
with:
key: ${{ runner.os }}-${{ matrix.cache }}-${{ matrix.backend }}-${{ hashFiles('**/Cargo.toml') }}
- name: (Linux) Install llvmpipe, lavapipe
if: runner.os == 'Linux'
- name: (Linux) Install Vulkan deps
if: runner.os == 'Linux' && matrix.backend == 'wgpu'
run: |-
sudo apt-get update -y -qq
sudo add-apt-repository ppa:kisak/kisak-mesa -y
sudo apt-get update
sudo apt install -y libegl1-mesa libgl1-mesa-dri libxcb-xfixes0-dev mesa-vulkan-drivers
sudo apt-get install -y libvulkan1 mesa-vulkan-drivers vulkan-tools libxcb-xfixes0-dev
- name: (Windows) Install warp
if: runner.os == 'Windows'
shell: bash
Expand Down Expand Up @@ -88,9 +86,6 @@ jobs:

echo "VK_DRIVER_FILES=$PWD/mesa/lvp_icd.x86_64.json" >> "$GITHUB_ENV"
echo "GALLIUM_DRIVER=llvmpipe" >> "$GITHUB_ENV"
- name: (Windows) Install dxc
if: runner.os == 'Windows'
uses: napokue/setup-dxc@v1.1.0
- name: Run tests
shell: bash
run: |-
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ Cargo.lock
# IDEs
.idea
.fleet

# Others
inspiration/
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ torch = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
burn = { version = "0.13.0", default-features = false }
burn = { version = "0.20.1", default-features = false }
num-traits = { version = "0.2.18", default-features = false }
serde = { version = "1.0.197", default-features = false, features = [
"derive",
Expand Down
20 changes: 4 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,13 @@ pub mod pipelines;
pub mod transformers;
pub mod utils;

#[cfg(all(test, feature = "ndarray"))]
use burn::backend::ndarray;

#[cfg(all(test, feature = "torch"))]
use burn::backend::libtorch;

#[cfg(all(test, feature = "wgpu"))]
use burn::backend::wgpu;

extern crate alloc;

#[cfg(all(test, feature = "ndarray"))]
pub type TestBackend = ndarray::NdArray<f32>;
pub type TestBackend = burn::backend::NdArray<f32>;

#[cfg(all(test, feature = "torch"))]
pub type TestBackend = libtorch::LibTorch<f32>;
pub type TestBackend = burn::backend::LibTorch<f32>;

#[cfg(all(test, feature = "wgpu", not(target_os = "macos")))]
pub type TestBackend = wgpu::Wgpu<wgpu::Vulkan, f32, i32>;

#[cfg(all(test, feature = "wgpu", target_os = "macos"))]
pub type TestBackend = wgpu::Wgpu<wgpu::Metal, f32, i32>;
#[cfg(all(test, feature = "wgpu"))]
pub type TestBackend = burn::backend::Wgpu;
46 changes: 23 additions & 23 deletions src/models/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use burn::tensor::Tensor;
#[allow(unused_imports)]
use num_traits::Float;

#[derive(Config)]
#[derive(Config, Debug)]
pub struct GeGluConfig {
/// The size of the input features.
d_input: usize,
Expand Down Expand Up @@ -44,7 +44,7 @@ impl<B: Backend> GeGlu<B> {
}
}

#[derive(Config)]
#[derive(Config, Debug)]
pub struct FeedForwardConfig {
/// The size of the input features.
pub d_input: usize,
Expand Down Expand Up @@ -90,7 +90,7 @@ impl<B: Backend> FeedForward<B> {
}
}

#[derive(Config)]
#[derive(Config, Debug)]
pub struct CrossAttentionConfig {
/// The number of channels in the query.
d_query: usize,
Expand Down Expand Up @@ -232,7 +232,7 @@ impl<B: Backend> CrossAttention<B> {
}
}

#[derive(Config)]
#[derive(Config, Debug)]
pub struct BasicTransformerBlockConfig {
d_model: usize,
d_context: Option<usize>,
Expand Down Expand Up @@ -488,13 +488,13 @@ mod tests {
use super::*;
use crate::TestBackend;
use burn::module::{Param, ParamId};
use burn::tensor::{Data, Shape};
use burn::tensor::{Shape, TensorData, Tolerance};

#[test]
fn test_geglu_tensor_shape_3() {
let device = Default::default();
let weight = Tensor::from_data(
Data::from([
TensorData::from([
[
0.1221, 2.0378, -0.1171, 1.3004, -0.9630, -0.3108, -1.3376, -1.0593,
],
Expand All @@ -505,7 +505,7 @@ mod tests {
&device,
);
let bias = Tensor::from_data(
Data::from([
TensorData::from([
0.2867778149426027,
0.6646517317105776,
0.023946332404821136,
Expand All @@ -526,7 +526,7 @@ mod tests {
};

let tensor: Tensor<TestBackend, 3> = Tensor::from_data(
Data::from([
TensorData::from([
[[1., 2.], [3., 4.], [5., 6.]],
[[7., 8.], [9., 10.], [11., 12.]],
]),
Expand All @@ -535,8 +535,8 @@ mod tests {

let output = geglu.forward(tensor);
assert_eq!(output.shape(), Shape::from([2, 3, 4]));
output.to_data().assert_approx_eq(
&Data::from([
output.into_data().assert_approx_eq::<f32>(
&TensorData::from([
[
[4.2632e0, -1.7927e-1, -2.3216e-1, -3.7916e-2],
[1.3460e1, -2.9266e-1, -2.1707e-4, -4.5595e-2],
Expand All @@ -548,22 +548,22 @@ mod tests {
[1.0119e2, -2.1943e-5, -0.0000e0, -0.0000e0],
],
]),
2,
Tolerance::rel_abs(1e-2, 1e-2),
);
}

#[test]
fn test_geglu_tensor_shape_2() {
let device = Default::default();
let weight = Tensor::from_data(
Data::from([
TensorData::from([
[0.6054, 1.9322, 0.1445, 1.3004, -0.6853, -0.8947],
[-0.3678, 0.4081, -1.9001, -1.5843, -0.9399, 0.1018],
]),
&device,
);
let bias = Tensor::from_data(
Data::from([
TensorData::from([
0.3237631905393836,
0.22052049807936902,
-0.3196353346822061,
Expand All @@ -582,17 +582,17 @@ mod tests {
};

let tensor: Tensor<TestBackend, 2> =
Tensor::from_data(Data::from([[1., 2.], [3., 4.], [5., 6.]]), &device);
Tensor::from_data(TensorData::from([[1., 2.], [3., 4.], [5., 6.]]), &device);

let output = geglu.forward(tensor);
assert_eq!(output.shape(), Shape::from([3, 3]));
output.to_data().assert_approx_eq(
&Data::from([
output.into_data().assert_approx_eq::<f32>(
&TensorData::from([
[-2.4192e-5, -3.3057e-2, 2.8535e-1],
[-0.0000e0, -2.0983e-7, 5.2465e-1],
[-0.0000e0, -0.0000e0, 1.2599e-2],
]),
1,
Tolerance::rel_abs(1e-1, 1e-1),
);
}

Expand All @@ -601,7 +601,7 @@ mod tests {
let device = Default::default();
// create tensor of size [2, 4, 2]
let query: Tensor<TestBackend, 3> = Tensor::from_data(
Data::from([
TensorData::from([
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]],
[[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]],
Expand All @@ -610,7 +610,7 @@ mod tests {
&device,
);
let key: Tensor<TestBackend, 3> = Tensor::from_data(
Data::from([
TensorData::from([
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]],
[[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]],
Expand All @@ -619,7 +619,7 @@ mod tests {
&device,
);
let value: Tensor<TestBackend, 3> = Tensor::from_data(
Data::from([
TensorData::from([
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]],
[[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]],
Expand All @@ -637,8 +637,8 @@ mod tests {
let output = cross_attention.sliced_attention(query, key, value, 2);

assert_eq!(output.shape(), Shape::from([2, 4, 4]));
output.into_data().assert_approx_eq(
&Data::from([
output.into_data().assert_approx_eq::<f32>(
&TensorData::from([
[
[5.9201, 6.9201, 14.9951, 15.9951],
[6.7557, 7.7557, 14.9986, 15.9986],
Expand All @@ -652,7 +652,7 @@ mod tests {
[23.0000, 24.0000, 31.0000, 32.0000],
],
]),
3,
Tolerance::rel_abs(1e-3, 1e-3),
)
}
}
23 changes: 12 additions & 11 deletions src/models/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
}

impl<B: Backend> TimestepEmbedding<B> {
fn forward(&self, xs: Tensor<B, 2>) -> Tensor<B, 2> {

Check warning on line 38 in src/models/embeddings.rs

View workflow job for this annotation

GitHub Actions / check-std (ubuntu-latest, stable, ndarray)

method `forward` is never used
let xs = silu(self.linear_1.forward(xs));
self.linear_2.forward(xs)
}
Expand Down Expand Up @@ -85,26 +85,27 @@
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::{Data, Shape};
use burn::tensor::{Shape, TensorData, Tolerance};

#[test]
#[cfg(not(feature = "torch"))]
fn test_timesteps_even_channels() {
let device = Default::default();
let timesteps = Timesteps::<TestBackend>::new(4, true, 0.);
let xs: Tensor<TestBackend, 1> = Tensor::from_data(Data::from([1., 2., 3., 4.]), &device);
let xs: Tensor<TestBackend, 1> =
Tensor::from_data(TensorData::from([1., 2., 3., 4.]), &device);

let emb = timesteps.forward(xs);
let emb: Tensor<TestBackend, 2> = timesteps.forward(xs);

assert_eq!(emb.shape(), Shape::from([4, 4]));
emb.to_data().assert_approx_eq(
&Data::from([
emb.into_data().assert_approx_eq::<f32>(
&TensorData::from([
[0.5403, 1.0000, 0.8415, 0.0100],
[-0.4161, 0.9998, 0.9093, 0.0200],
[-0.9900, 0.9996, 0.1411, 0.0300],
[-0.6536, 0.9992, -0.7568, 0.0400],
]),
3,
Tolerance::rel_abs(1e-3, 1e-3),
);
}

Expand All @@ -114,21 +115,21 @@
let device = Default::default();
let timesteps = Timesteps::<TestBackend>::new(5, true, 0.);
let xs: Tensor<TestBackend, 1> =
Tensor::from_data(Data::from([1., 2., 3., 4., 5.]), &device);
Tensor::from_data(TensorData::from([1., 2., 3., 4., 5.]), &device);

let emb = timesteps.forward(xs);
let emb: Tensor<TestBackend, 2> = timesteps.forward(xs);

assert_eq!(emb.shape(), Shape::from([6, 4]));
emb.to_data().assert_approx_eq(
&Data::from([
emb.into_data().assert_approx_eq::<f32>(
&TensorData::from([
[0.5403, 1.0000, 0.8415, 0.0100],
[-0.4161, 0.9998, 0.9093, 0.0200],
[-0.9900, 0.9996, 0.1411, 0.0300],
[-0.6536, 0.9992, -0.7568, 0.0400],
[0.2837, 0.9988, -0.9589, 0.0500],
[0.0000, 0.0000, 0.0000, 0.0000],
]),
3,
Tolerance::rel_abs(1e-3, 1e-3),
);
}
}
Loading