diff --git a/README.md b/README.md
index 3582ea9d..38c8e64c 100644
--- a/README.md
+++ b/README.md
@@ -272,7 +272,7 @@ fn test_matmul_square_matrix() {
|InstanceNormalization|6, 1|
|IsInf|10|
|IsNaN|13, 9|
-|LRN|13, 1||
+|LRN|13, 1|✅||
|LSTM|14, 7, 1|
|LeakyRelu|6, 1|✅|✅|
|Less|13, 9, 7, 1|✅|
diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs
index fe9a3fcf..3c66f66e 100644
--- a/wonnx/src/compiler.rs
+++ b/wonnx/src/compiler.rs
@@ -83,6 +83,11 @@ lazy_static! {
include_str!("../templates/matrix/transpose.wgsl"),
)
.unwrap();
+ tera.add_raw_template(
+ "matrix/lrn.wgsl",
+ include_str!("../templates/matrix/lrn.wgsl"),
+ )
+ .unwrap();
tera.add_raw_template(
"pool/aggregate.wgsl",
include_str!("../templates/pool/aggregate.wgsl"),
@@ -1321,6 +1326,38 @@ pub fn compile(
threads: (ceil(output_lengths[0], 256) as _, 1, 1),
}
}
+ "LRN" => {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#lrn
+ let alpha = node.get_attribute_value("alpha", Some(0.0001))?;
+ let beta = node.get_attribute_value("beta", Some(0.75))?;
+ let bias = node.get_attribute_value("bias", Some(1.0))?;
+ let size = node.get_attribute_value("size", Some(1))?;
+
+ context.insert("alpha", &alpha);
+ context.insert("beta", &beta);
+ context.insert("bias", &bias);
+ context.insert("size", &size);
+
+ let left_size = f64::floor((size - 1) as f64 / 2.0) as u32;
+ let right_size = f64::ceil((size - 1) as f64 / 2.0) as u32;
+
+ context.insert("left_size", &left_size);
+ context.insert("right_size", &right_size);
+
+ let (x_threads, workgroup_size_x) = workgroup_size(
+ output_lengths[0],
+ MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,
+ MAX_WORKGROUP_SIZE_X,
+ )?;
+ context.insert("workgroup_size_x", &workgroup_size_x);
+ context.insert("i_chunks", &input_chunks);
+
+ NodeTemplate {
+ scalar_type: agreed_type(input_shapes, output_shapes)?,
+ template: "matrix/lrn.wgsl",
+ threads: (x_threads, 1, 1),
+ }
+ }
op => return Err(CompileError::UnimplementedOp(op.to_string())),
};
diff --git a/wonnx/templates/matrix/lrn.wgsl b/wonnx/templates/matrix/lrn.wgsl
new file mode 100644
index 00000000..c3517ab2
--- /dev/null
+++ b/wonnx/templates/matrix/lrn.wgsl
@@ -0,0 +1,23 @@
+{%- include "structs.wgsl" -%}
+
+@group(0) @binding(0)
+var<storage, read> input_0: Array;
+
+@group(0) @binding(1)
+var<storage, read_write> output_0: Array;
+
+@compute @workgroup_size({{ workgroup_size_x }})
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+ let c = global_id.x;
+ //let chunk_start = {{ i_chunks[0][1] }}u * c;
+ let start = (c / {{ i_shape[0][1] }}u) * {{ i_shape[0][1] }}u;
+ let end = start + {{ i_shape[0][1] - 1 }}u;
+
+ var square_sum: Scalar = Scalar();
+ for (var i = max(start, c - {{left_size}}u); i <= min(end, c + {{right_size}}u); i++) {
+ let I = input_0.data[i];
+ square_sum += I * I;
+ }
+
+ output_0.data[c] = input_0.data[ c ] / pow({{ scalar_type }}({{ bias }}) + ({{ scalar_type }}({{ alpha }}) / {{ scalar_type }}({{ size }})) * square_sum, {{ scalar_type }}({{ beta }}));
+}
diff --git a/wonnx/tests/localresponsenormalization.rs b/wonnx/tests/localresponsenormalization.rs
new file mode 100644
index 00000000..8f4202fa
--- /dev/null
+++ b/wonnx/tests/localresponsenormalization.rs
@@ -0,0 +1,61 @@
+use std::{collections::HashMap, convert::TryInto};
+use wonnx::utils::{attribute, graph, model, node, tensor};
+mod common;
+
+#[test]
+fn local_response_normalization() {
+ let mut input_data = HashMap::new();
+
+ let batches = 1;
+ let width_height: usize = 3;
+ let channels: usize = 4;
+ let data: Vec<f32> = [
+ 1., 1., 2., 4., 2., 2., 1., 2., 3., 1., 2., 1., 4., 2., 3., 5., 3., 3., 2., 2., 6., 2., 3.,
+ 1., 7., 3., 4., 2., 8., 4., 3., 2., 9., 3., 4., 4.,
+ ]
+ .to_vec();
+
+ let shape = vec![
+ batches as i64,
+ channels as i64,
+ width_height as i64,
+ width_height as i64,
+ ];
+ input_data.insert("X".to_string(), data.as_slice().into());
+
+ let bn_model = model(graph(
+ vec![tensor("X", &shape)], // input
+ vec![tensor("Y", &shape)], // output
+ vec![], // infos
+ vec![], // intializers
+ // nodes
+ vec![node(
+ vec!["X"],
+ vec!["Y"],
+ "lrn",
+ "LRN",
+ vec![
+ attribute("alpha", 1.0),
+ attribute("beta", 1.0),
+ attribute("bias", 0.0),
+ attribute("size", 2),
+ ],
+ )],
+ ));
+
+ // LOGIC
+ let session =
+ pollster::block_on(wonnx::Session::from_model(bn_model)).expect("Session did not create");
+
+ let result = pollster::block_on(session.run(&input_data)).unwrap();
+ let out_y = &result["Y"];
+
+ common::assert_eq_vector(
+ out_y.try_into().unwrap(),
+ &[
+ 1.0, 0.4, 0.2, 0.5, 0.5, 0.8, 0.4, 1.0, 0.6, 0.4, 0.8, 2.0, 0.4, 0.30769232, 0.1764706,
+ 0.39999998, 0.33333334, 0.4615385, 0.5, 1.0, 0.3, 0.30769232, 0.6, 2.0, 0.2413793,
+ 0.24, 0.4, 1.0, 0.2, 0.32, 0.4615385, 1.0, 0.2, 0.24, 0.25, 0.5,
+ ],
+ );
+}