Skip to content

Commit 9ee8b7c

Browse files
committed
Test ReLu considers self gradient
1 parent 8eb4858 commit 9ee8b7c

File tree

1 file changed

+10
-2
lines changed
  • ml/ml-core/src/test/java/org/neo4j/gds/ml/core/functions

1 file changed

+10
-2
lines changed

ml/ml-core/src/test/java/org/neo4j/gds/ml/core/functions/ReluTest.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,22 @@ void shouldApproximateGradient() {
5050
finiteDifferenceShouldApproximateGradient(weights, new ElementSum(List.of(new Relu<>(weights))));
5151
}
5252

53+
@Test
54+
void considerSelfGradient() {
55+
Weights<Vector> weights = new Weights<>(new Vector(-1, 5, 2));
56+
var chainedRelu = new Sigmoid<>(new Relu<>(weights));
57+
58+
finiteDifferenceShouldApproximateGradient(weights, new ElementSum(List.of(chainedRelu)));
59+
}
60+
5361
@Test
5462
void shouldComputeRelu() {
5563
double[] vectorData = {14, -5, 36, 0};
5664
Constant<Vector> p = Constant.vector(vectorData);
5765

5866
Variable<Vector> relu = new Relu<>(p);
5967

60-
var expected = new Vector(new double[]{14, 0.01 * -5, 36, 0});
68+
var expected = new Vector(14, 0.01 * -5, 36, 0);
6169
assertThat(ctx.forward(relu)).isEqualTo(expected);
6270
}
6371

@@ -68,7 +76,7 @@ void returnsEmptyDataForEmptyVariable() {
6876

6977
Variable<Vector> relu = new Relu<>(p);
7078

71-
var expected = new Vector(new double[]{});
79+
var expected = new Vector();
7280
assertThat(ctx.forward(relu)).isEqualTo(expected);
7381
}
7482

0 commit comments

Comments
 (0)