Skip to content

Commit 55805c5

Browse files
authored
Merge pull request #5219 from FlorentinD/relu-consider-selfgradient
Fix gradient computation for Relu and NormalizeRows
2 parents d452ee5 + 9ee8b7c commit 55805c5

File tree

6 files changed

+85
-37
lines changed

6 files changed

+85
-37
lines changed

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/NormalizeRows.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public Matrix apply(ComputationContext ctx) {
5656
@Override
5757
public Matrix gradientForParent(ComputationContext ctx) {
5858
Matrix parentData = ctx.data(parent);
59-
Matrix thisGradient = ctx.gradient(this);
59+
Matrix normalizeRowsGradient = ctx.gradient(this);
6060

6161
Matrix parentGradient = parentData.createWithSameDimensions();
6262
int rows = parentData.rows();
@@ -71,21 +71,26 @@ public Matrix gradientForParent(ComputationContext ctx) {
7171
double l2 = Math.sqrt(l2Squared);
7272
double l2Cubed = l2 * l2Squared;
7373

74+
if (Double.compare(l2Cubed, 0) == 0) {
75+
continue;
76+
}
77+
7478
for (int col = 0; col < cols; col++) {
7579
double parentCellValue = parentData.dataAt(row, col);
7680
for (int gradCol = 0; gradCol < cols; gradCol++) {
7781
double partialGradient;
7882
if (col == gradCol) {
79-
partialGradient = thisGradient.dataAt(row, col) * (l2Squared - parentCellValue * parentCellValue);
83+
partialGradient = normalizeRowsGradient.dataAt(row, col) * (l2Squared - parentCellValue * parentCellValue);
8084
} else {
81-
partialGradient = -thisGradient.dataAt(row, gradCol) * (parentCellValue * parentData.dataAt(row, gradCol));
85+
partialGradient = -normalizeRowsGradient.dataAt(row, gradCol) * (parentCellValue * parentData.dataAt(row, gradCol));
8286
}
8387
parentGradient.addDataAt(row, col, partialGradient);
8488
}
8589

8690
parentGradient.setDataAt(row, col, parentGradient.dataAt(row, col) / l2Cubed);
8791
}
8892
}
93+
8994
return parentGradient;
9095
}
9196
}

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/Relu.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ public T apply(ComputationContext ctx) {
3838

3939
@Override
4040
public T gradientForParent(ComputationContext ctx) {
41-
return ctx.data(parent).map(value -> value > 0 ? 1 : ALPHA);
41+
T gradient = ctx.data(parent).map(value -> value > 0 ? 1 : ALPHA);
42+
gradient.elementwiseProductMutate(ctx.gradient(this));
43+
return gradient;
4244
}
4345
}

ml/ml-core/src/test/java/org/neo4j/gds/ml/core/functions/NormalizeRowsTest.java

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.neo4j.gds.ml.core.functions;
2121

2222
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.ml.core.ComputationContext;
2324
import org.neo4j.gds.ml.core.FiniteDifferenceTest;
2425
import org.neo4j.gds.ml.core.Variable;
2526
import org.neo4j.gds.ml.core.tensor.Matrix;
@@ -45,7 +46,27 @@ void testGradient() {
4546
Weights<Matrix> w = new Weights<>(new Matrix(data, 3, 2));
4647
Variable<Matrix> normalizeRows = new NormalizeRows(w);
4748

48-
finiteDifferenceShouldApproximateGradient(w, new ElementSum(List.of(normalizeRows)));
49+
finiteDifferenceShouldApproximateGradient(w, new ElementSum(List.of(new Sigmoid<>(normalizeRows))));
50+
}
51+
52+
@Test
53+
void testGradientOnZeroData() {
54+
double[] data = new double[] {
55+
0, 0,
56+
0, 0,
57+
0, 0
58+
};
59+
60+
Weights<Matrix> w = new Weights<>(new Matrix(data, 3, 2));
61+
Variable<Matrix> normalizeRows = new NormalizeRows(w);
62+
63+
ComputationContext ctx = new ComputationContext();
64+
65+
ElementSum loss = new ElementSum(List.of(normalizeRows));
66+
ctx.forward(loss);
67+
ctx.backward(loss);
68+
69+
assertThat(ctx.gradient(w)).isEqualTo(Matrix.create(0, 3, 2));
4970
}
5071

5172
@Test
@@ -67,5 +88,19 @@ void testApply() {
6788
assertThat(ctx.forward(normalizeRows)).matches(tensor -> tensor.equals(expected, 1e-8));
6889
}
6990

91+
@Test
92+
void testApplyOnZeroData() {
93+
double[] data = new double[] {
94+
0, 0,
95+
0, 0,
96+
0, 0
97+
};
98+
99+
Weights<Matrix> w = new Weights<>(new Matrix(data, 3, 2));
100+
var expected = Matrix.create(0.0D, 3, 2);
101+
102+
assertThat(ctx.forward(new NormalizeRows(w))).matches(tensor -> tensor.equals(expected, 1e-8));
103+
}
104+
70105

71106
}

ml/ml-core/src/test/java/org/neo4j/gds/ml/core/functions/ReluTest.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,22 @@ void shouldApproximateGradient() {
5050
finiteDifferenceShouldApproximateGradient(weights, new ElementSum(List.of(new Relu<>(weights))));
5151
}
5252

53+
@Test
54+
void considerSelfGradient() {
55+
Weights<Vector> weights = new Weights<>(new Vector(-1, 5, 2));
56+
var chainedRelu = new Sigmoid<>(new Relu<>(weights));
57+
58+
finiteDifferenceShouldApproximateGradient(weights, new ElementSum(List.of(chainedRelu)));
59+
}
60+
5361
@Test
5462
void shouldComputeRelu() {
5563
double[] vectorData = {14, -5, 36, 0};
5664
Constant<Vector> p = Constant.vector(vectorData);
5765

5866
Variable<Vector> relu = new Relu<>(p);
5967

60-
var expected = new Vector(new double[]{14, 0.01 * -5, 36, 0});
68+
var expected = new Vector(14, 0.01 * -5, 36, 0);
6169
assertThat(ctx.forward(relu)).isEqualTo(expected);
6270
}
6371

@@ -68,7 +76,7 @@ void returnsEmptyDataForEmptyVariable() {
6876

6977
Variable<Vector> relu = new Relu<>(p);
7078

71-
var expected = new Vector(new double[]{});
79+
var expected = new Vector();
7280
assertThat(ctx.forward(relu)).isEqualTo(expected);
7381
}
7482

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkFeaturesAndLabelsExtractor.java

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
import org.apache.commons.lang3.mutable.MutableLong;
2323
import org.neo4j.gds.RelationshipType;
2424
import org.neo4j.gds.api.Graph;
25-
import org.neo4j.gds.core.utils.TerminationFlag;
2625
import org.neo4j.gds.core.concurrency.ParallelUtil;
2726
import org.neo4j.gds.core.concurrency.Pools;
27+
import org.neo4j.gds.core.utils.TerminationFlag;
2828
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2929
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
3030
import org.neo4j.gds.core.utils.mem.MemoryRange;
@@ -113,20 +113,18 @@ private static HugeLongArray extractLabels(
113113
var startRelationshipOffset = relationshipOffset.getValue();
114114
tasks.add(() -> {
115115
var currentRelationshipOffset = new MutableLong(startRelationshipOffset);
116-
partition.consume(nodeId -> {
117-
graph.forEachRelationship(nodeId, -10, (src, trg, weight) -> {
118-
if (weight == EdgeSplitter.NEGATIVE || weight == EdgeSplitter.POSITIVE) {
119-
globalLabels.set(currentRelationshipOffset.getAndIncrement(), (long) weight);
120-
} else {
121-
throw new IllegalArgumentException(formatWithLocale("Label should be either `1` or `0`. But got %f for relationship (%d, %d)",
122-
weight,
123-
src,
124-
trg
125-
));
126-
}
127-
return true;
128-
});
129-
});
116+
partition.consume(nodeId -> graph.concurrentCopy().forEachRelationship(nodeId, -10, (src, trg, weight) -> {
117+
if (weight == EdgeSplitter.NEGATIVE || weight == EdgeSplitter.POSITIVE) {
118+
globalLabels.set(currentRelationshipOffset.getAndIncrement(), (long) weight);
119+
} else {
120+
throw new IllegalArgumentException(formatWithLocale("Label should be either `1` or `0`. But got %f for relationship (%d, %d)",
121+
weight,
122+
src,
123+
trg
124+
));
125+
}
126+
return true;
127+
}));
130128
progressTracker.logProgress(partition.totalDegree());
131129
}
132130
);

proc/embeddings/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageStreamProcTest.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -87,21 +87,21 @@ void weightedGraphSage() {
8787
.yields();
8888

8989
assertCypherResult(streamQuery, List.of(
90-
Map.of("nodeId", 0L, "embedding", List.of(0.999999999980722)),
91-
Map.of("nodeId", 1L, "embedding", List.of(0.9999999999975783)),
92-
Map.of("nodeId", 2L, "embedding", List.of(0.9999999999880146)),
93-
Map.of("nodeId", 3L, "embedding", List.of(0.999999999965436)),
94-
Map.of("nodeId", 4L, "embedding", List.of(0.9999999999983651)),
95-
Map.of("nodeId", 5L, "embedding", List.of(0.9999999999377848)),
96-
Map.of("nodeId", 6L, "embedding", List.of(0.9999999999580143)),
97-
Map.of("nodeId", 7L, "embedding", List.of(0.9999999999420027)),
98-
Map.of("nodeId", 8L, "embedding", List.of(0.9999999996197955)),
99-
Map.of("nodeId", 9L, "embedding", List.of(-0.9999999619795564)),
100-
Map.of("nodeId", 10L, "embedding", List.of(0.9999999999485437)),
101-
Map.of("nodeId", 11L, "embedding", List.of(-0.9999999979930556)),
102-
Map.of("nodeId", 12L, "embedding", List.of(0.999999999965436)),
103-
Map.of("nodeId", 13L, "embedding", List.of(0.999999999965436)),
104-
Map.of("nodeId", 14L, "embedding", List.of(0.999999999965436))
90+
Map.of("nodeId", 0L, "embedding", List.of(-0.9999999992039947)),
91+
Map.of("nodeId", 1L, "embedding", List.of(-0.9999999999000064)),
92+
Map.of("nodeId", 2L, "embedding", List.of(-0.9999999995051105)),
93+
Map.of("nodeId", 3L, "embedding", List.of(-0.9999999985728191)),
94+
Map.of("nodeId", 4L, "embedding", List.of(-0.9999999999324936)),
95+
Map.of("nodeId", 5L, "embedding", List.of(-0.9999999974310742)),
96+
Map.of("nodeId", 6L, "embedding", List.of(-0.9999999982663691)),
97+
Map.of("nodeId", 7L, "embedding", List.of(-0.9999999976052386)),
98+
Map.of("nodeId", 8L, "embedding", List.of(-0.9999999843010093)),
99+
Map.of("nodeId", 9L, "embedding", List.of(0.99999999984301)),
100+
Map.of("nodeId", 10L, "embedding", List.of(-0.9999999978753245)),
101+
Map.of("nodeId", 11L, "embedding", List.of(0.9999999999917132)),
102+
Map.of("nodeId", 12L, "embedding", List.of(-0.9999999985728191)),
103+
Map.of("nodeId", 13L, "embedding", List.of(-0.9999999985728191)),
104+
Map.of("nodeId", 14L, "embedding", List.of(-0.9999999985728191))
105105
));
106106
}
107107

0 commit comments

Comments
 (0)