diff --git a/README.md b/README.md index 3582ea9d..38c8e64c 100644 --- a/README.md +++ b/README.md @@ -272,7 +272,7 @@ fn test_matmul_square_matrix() { |InstanceNormalization|6, 1| |IsInf|10| |IsNaN|13, 9| -|LRN|13, 1|| +|LRN|13, 1|✅|| |LSTM|14, 7, 1| |LeakyRelu|6, 1|✅|✅| |Less|13, 9, 7, 1|✅| diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index fe9a3fcf..3c66f66e 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -83,6 +83,11 @@ lazy_static! { include_str!("../templates/matrix/transpose.wgsl"), ) .unwrap(); + tera.add_raw_template( + "matrix/lrn.wgsl", + include_str!("../templates/matrix/lrn.wgsl"), + ) + .unwrap(); tera.add_raw_template( "pool/aggregate.wgsl", include_str!("../templates/pool/aggregate.wgsl"), @@ -1321,6 +1326,38 @@ pub fn compile( threads: (ceil(output_lengths[0], 256) as _, 1, 1), } } + "LRN" => { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#lrn + let alpha = node.get_attribute_value("alpha", Some(0.0001))?; + let beta = node.get_attribute_value("beta", Some(0.75))?; + let bias = node.get_attribute_value("bias", Some(1.0))?; + let size = node.get_attribute_value("size", Some(1))?; + + context.insert("alpha", &alpha); + context.insert("beta", &beta); + context.insert("bias", &bias); + context.insert("size", &size); + + let left_size = f64::floor((size - 1) as f64 / 2.0) as u32; + let right_size = f64::ceil((size - 1) as f64 / 2.0) as u32; + + context.insert("left_size", &left_size); + context.insert("right_size", &right_size); + + let (x_threads, workgroup_size_x) = workgroup_size( + output_lengths[0], + MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, + MAX_WORKGROUP_SIZE_X, + )?; + context.insert("workgroup_size_x", &workgroup_size_x); + context.insert("i_chunks", &input_chunks); + + NodeTemplate { + scalar_type: agreed_type(input_shapes, output_shapes)?, + template: "matrix/lrn.wgsl", + threads: (x_threads, 1, 1), + } + } op => return Err(CompileError::UnimplementedOp(op.to_string())), }; diff --git a/wonnx/templates/matrix/lrn.wgsl b/wonnx/templates/matrix/lrn.wgsl new file mode 100644 index 00000000..c3517ab2 --- /dev/null +++ b/wonnx/templates/matrix/lrn.wgsl @@ -0,0 +1,23 @@ +{%- include "structs.wgsl" -%} + +@group(0) @binding(0) +var input_0: Array; + +@group(0) @binding(1) +var output_0: Array; + +@compute @workgroup_size({{ workgroup_size_x }}) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let c = global_id.x; + //let chunk_start = {{ i_chunks[0][1] }}u * c; + let start = (c / {{ i_shape[0][1] }}u) * {{ i_shape[0][1] }}u; + let end = start + {{ i_shape[0][1] - 1 }}u; + + var square_sum: Scalar = Scalar(); + for (var i = max(start, c - {{left_size}}u); i <= min(end, c + {{right_size}}u); i++) { + let I = input_0.data[i]; + square_sum += I * I; + } + + output_0.data[c] = input_0.data[ c ] / pow({{ scalar_type }}({{ bias }}) + ({{ scalar_type }}({{ alpha }}) / {{ scalar_type }}({{ size }})) * square_sum, {{ scalar_type }}({{ beta }})); +} diff --git a/wonnx/tests/localresponsenormalization.rs b/wonnx/tests/localresponsenormalization.rs new file mode 100644 index 00000000..8f4202fa --- /dev/null +++ b/wonnx/tests/localresponsenormalization.rs @@ -0,0 +1,61 @@ +use std::{collections::HashMap, convert::TryInto}; +use wonnx::utils::{attribute, graph, model, node, tensor}; +mod common; + +#[test] +fn local_response_normalization() { + let mut input_data = HashMap::new(); + + let batches = 1; + let width_height: usize = 3; + let channels: usize = 4; + let data: Vec = [ + 1., 1., 2., 4., 2., 2., 1., 2., 3., 1., 2., 1., 4., 2., 3., 5., 3., 3., 2., 2., 6., 2., 3., + 1., 7., 3., 4., 2., 8., 4., 3., 2., 9., 3., 4., 4., + ] + .to_vec(); + + let shape = vec![ + batches as i64, + channels as i64, + width_height as i64, + width_height as i64, + ]; + input_data.insert("X".to_string(), data.as_slice().into()); + + let bn_model = model(graph( + vec![tensor("X", &shape)], // input + vec![tensor("Y", &shape)], // output + vec![], // infos + vec![], // intializers + // nodes + vec![node( + vec!["X"], + vec!["Y"], + "lrn", + "LRN", + vec![ + attribute("alpha", 1.0), + attribute("beta", 1.0), + attribute("bias", 0.0), + attribute("size", 2), + ], + )], + )); + + // LOGIC + let session = + pollster::block_on(wonnx::Session::from_model(bn_model)).expect("Session did not create"); + + let result = pollster::block_on(session.run(&input_data)).unwrap(); + let out_y = &result["Y"]; + + common::assert_eq_vector( + out_y.try_into().unwrap(), + &[ + 1.0, 0.4, 0.2, 0.5, 0.5, 0.8, 0.4, 1.0, 0.6, 0.4, 0.8, 2.0, 0.4, 0.30769232, 0.1764706, + 0.39999998, 0.33333334, 0.4615385, 0.5, 1.0, 0.3, 0.30769232, 0.6, 2.0, 0.2413793, + 0.24, 0.4, 1.0, 0.2, 0.32, 0.4615385, 1.0, 0.2, 0.24, 0.25, 0.5, + ], + ); +}