
Commit 93473ff

Merge pull request #5236 from FlorentinD/fix-gradients-apply-self-gradient
Correct gradients by applying the self gradient
2 parents b494f3a + 237a8d0 commit 93473ff

File tree

11 files changed: +113 −65 lines changed

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/ComputationContext.java

Lines changed: 23 additions & 2 deletions

```diff
@@ -20,7 +20,7 @@
 package org.neo4j.gds.ml.core;
 
 import org.jetbrains.annotations.TestOnly;
-import org.neo4j.gds.ml.core.functions.PassthroughVariable;
+import org.neo4j.gds.ml.core.functions.SingleParentVariable;
 import org.neo4j.gds.ml.core.tensor.Tensor;
 
 import java.util.HashMap;
@@ -70,7 +70,7 @@ public void backward(Variable<?> function) {
 
         gradients.clear();
         Queue<BackPropTask> executionQueue = new LinkedBlockingQueue<>();
-        PassthroughVariable<?> dummy = new PassthroughVariable<>(function);
+        var dummy = new PassthroughVariable<>(function);
         executionQueue.add(new BackPropTask(function, dummy));
         Map<Variable<?>, AtomicInteger> upstreamCounters = new HashMap<>();
         initUpstream(dummy, upstreamCounters);
@@ -169,4 +169,25 @@ static class BackPropTask {
         }
     }
 
+    private static class PassthroughVariable<T extends Tensor<T>> extends SingleParentVariable<T, T> {
+
+        public PassthroughVariable(Variable<T> parent) {
+            super(parent, parent.dimensions());
+
+            if (parent instanceof PassthroughVariable) {
+                throw new IllegalArgumentException("Redundant use of PassthroughVariables. Chaining does not make sense.");
+            }
+        }
+
+        @Override
+        public T apply(ComputationContext ctx) {
+            return ctx.data(parent);
+        }
+
+        @Override
+        public T gradientForParent(ComputationContext ctx) {
+            // initialize gradient computation with `1`
+            return ctx.data(parent).map(v -> 1);
+        }
+    }
 }
```
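The `PassthroughVariable` that previously lived in the public `functions` package is now a private seed node inside `ComputationContext`: `backward` wraps the loss variable in it, and its `gradientForParent` returns a tensor of ones, i.e. the chain-rule seed d(loss)/d(loss) = 1. A minimal, self-contained sketch of the seeding idea (hypothetical names, not GDS code):

```java
// backward() starts from a dummy node whose gradient is all ones,
// so every node below it can uniformly scale its local derivative
// by the gradient arriving from above (its "self gradient").
public class BackpropSeedSketch {

    // local rule for y = x^2: dy/dx = 2x, scaled by the upstream gradient
    static double squareGrad(double x, double upstream) {
        return upstream * 2 * x;
    }

    // local rule for z = 3y: dz/dy = 3, scaled by the upstream gradient
    static double tripleGrad(double upstream) {
        return upstream * 3;
    }

    public static void main(String[] args) {
        double x = 2.0;
        double seed = 1.0;                   // the PassthroughVariable's contribution
        double gradY = tripleGrad(seed);     // 3.0
        double gradX = squareGrad(x, gradY); // chain rule: 3 * 2x = 12.0 = d(3x^2)/dx
        System.out.println(gradX);
    }
}
```

With the seed fixed at `1`, every other variable can apply the chain rule by multiplying its local derivative with `ctx.gradient(this)`, which is exactly what the remaining changes in this commit do.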

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/CrossEntropyLoss.java

Lines changed: 3 additions & 7 deletions

```diff
@@ -71,23 +71,19 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
             Matrix gradient = predictionsMatrix.createWithSameDimensions();
             var targetsVector = ctx.data(targets);
 
-            var multiplier = -1.0 / gradient.rows();
+            var multiplier = - ctx.gradient(this).value() / gradient.rows();
             for (int row = 0; row < gradient.rows(); row++) {
                 var trueClass = (int) targetsVector.dataAt(row);
                 var predictedProbabilityForTrueClass = predictionsMatrix.dataAt(row * predictionsMatrix.cols() + trueClass);
 
                 // Compare to a threshold value rather than `0`, very small probability can result in setting infinite gradient values.
                 if (predictedProbabilityForTrueClass > PREDICTED_PROBABILITY_THRESHOLD) {
-                    gradient.setDataAt(
-                        row * predictionsMatrix.cols() + trueClass,
-                        multiplier / predictedProbabilityForTrueClass
-                    );
+                    gradient.setDataAt(row, trueClass, multiplier / predictedProbabilityForTrueClass);
                 }
             }
             return gradient;
         } else {
-            // targets should never require a gradient
-            return ctx.data(parent).createWithSameDimensions();
+            throw new IllegalStateException("The gradient should not be necessary for the targets. But got: " + targets.render());
         }
     }
 }
```
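This is the core of the fix: the old multiplier `-1.0 / gradient.rows()` silently assumed the loss was the root of the computation graph, i.e. an upstream gradient of exactly 1. Scaling by `ctx.gradient(this).value()` applies the chain rule, so functions chained on top of the loss now backpropagate correctly; the targets branch also fails fast instead of returning a zero tensor. A standalone numeric sketch of the corrected entry (hypothetical values, not GDS code):

```java
// For each row's true class, cross entropy contributes -log(p) / N,
// so d(loss)/dp = -1 / (N * p); the fix scales this by the upstream
// ("self") gradient g, which is only 1 when the loss is the graph root.
public class CrossEntropyGradSketch {
    public static void main(String[] args) {
        int n = 3;        // number of examples (rows)
        double p = 0.5;   // predicted probability of the true class
        double g = 0.25;  // upstream gradient reaching the loss

        double oldEntry = (-1.0 / n) / p; // pre-fix: correct only when g == 1
        double newEntry = (-g / n) / p;   // chain rule applied

        System.out.println(oldEntry);     // -0.666...
        System.out.println(newEntry);     // -0.1666...
    }
}
```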

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/EWiseAddMatrixScalar.java

Lines changed: 3 additions & 2 deletions

```diff
@@ -52,10 +52,11 @@ public Matrix apply(ComputationContext ctx) {
 
     @Override
     public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
+        Matrix selfGradient = ctx.gradient(this);
         if (parent == matrixVariable) {
-            return ctx.gradient(this);
+            return selfGradient;
         } else {
-            return new Scalar(ctx.gradient(this).aggregateSum());
+            return new Scalar(selfGradient.aggregateSum());
         }
     }
 }
```
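Behavior is unchanged here; `ctx.gradient(this)` is merely hoisted into a local, the same refactor applied to `ElementSum` and `L2NormSquared` below. The else branch is still worth a note: the scalar operand was broadcast into every matrix cell during the forward pass, so each cell contributes a partial derivative of 1 and the chain rule sums all upstream entries, hence `aggregateSum()`. A standalone sketch (hypothetical values, not GDS code):

```java
// Gradient of a broadcast add with respect to the scalar operand:
// the scalar s appears once in every cell, so its gradient is the
// sum of the upstream gradient's entries.
public class BroadcastAddGradSketch {
    public static void main(String[] args) {
        double[] upstream = {0.1, -0.2, 0.4, 0.3}; // gradient flowing into the add

        double scalarGrad = 0;
        for (double u : upstream) {
            scalarGrad += u;                       // plays the role of aggregateSum()
        }
        System.out.println(scalarGrad);            // 0.6
    }
}
```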

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/ElementSum.java

Lines changed: 2 additions & 1 deletion

```diff
@@ -45,6 +45,7 @@ public Scalar apply(ComputationContext ctx) {
 
     @Override
     public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
-        return ctx.data(parent).map(ignore -> ctx.gradient(this).value());
+        double selfGradient = ctx.gradient(this).value();
+        return ctx.data(parent).map(ignore -> selfGradient);
     }
 }
```

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/L2NormSquared.java

Lines changed: 2 additions & 1 deletion

```diff
@@ -54,6 +54,7 @@ public Scalar apply(ComputationContext ctx) {
 
     @Override
    public Matrix gradientForParent(ComputationContext ctx) {
-        return ctx.data(parent).copy().scalarMultiply(2 * ctx.gradient(this).value());
+        double selfGradient = ctx.gradient(this).value();
+        return ctx.data(parent).scalarMultiply(2 * selfGradient);
    }
 }
```
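Two details here: this function already honored the self gradient, since d‖X‖²/dX = 2X becomes 2·g·X for an upstream gradient g, so the hoisting is a readability refactor; and the `.copy()` is dropped, presumably because `scalarMultiply`, unlike the `scalarMultiplyMutate` used in `LogisticLoss`, returns a fresh tensor rather than mutating its receiver.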

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/LogisticLoss.java

Lines changed: 7 additions & 4 deletions

```diff
@@ -128,6 +128,8 @@ else if (predicted == 1.0) {
 
     @Override
     public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
+        var selfGradient = ctx.gradient(this).value();
+
         if (parent == weights) {
             ctx.forward(predictions);
             var predVector = ctx.data(predictions);
@@ -141,7 +143,7 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
             for (int idx = 0; idx < numberOfExamples; idx++) {
                 double errorPerExample = (predVector.dataAt(idx) - targetVector.dataAt(idx)) / numberOfExamples;
                 for (int feature = 0; feature < featureCount; feature++) {
-                    gradient.addDataAt(feature, errorPerExample * featuresTensor.dataAt(idx, feature));
+                    gradient.addDataAt(feature, selfGradient * errorPerExample * featuresTensor.dataAt(idx, feature));
                 }
             }
             return gradient;
@@ -154,13 +156,14 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
 
             for (int idx = 0; idx < numberOfExamples; idx++) {
                 double errorPerExample = (predVector.dataAt(idx) - targetVector.dataAt(idx));
-                gradient.addDataAt(0, errorPerExample);
+                gradient.addDataAt(0, selfGradient * errorPerExample);
             }
 
             return gradient.scalarMultiplyMutate(1.0D / numberOfExamples);
         } else {
-            // assume feature and target variables do not require gradient
-            return ctx.data(parent).createWithSameDimensions();
+            throw new IllegalStateException(
+                "The gradient should only be computed for the bias and the weights parents, but got " + parent.render()
+            );
         }
     }
 
```
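The same chain-rule application for `LogisticLoss`: both the weights branch and the bias branch scale the per-example error `prediction - target` by the self gradient, and requesting the gradient of any other parent now throws instead of silently returning zeros. A standalone sketch of the bias case (hypothetical numbers, not GDS code):

```java
// The bias gradient of logistic loss is the mean prediction error;
// the fix scales each error term by the upstream ("self") gradient
// before averaging, mirroring the loop in the diff above.
public class LogisticBiasGradSketch {
    public static void main(String[] args) {
        double[] predicted = {0.8, 0.3};
        double[] target = {1.0, 0.0};
        double selfGradient = 0.5; // 1.0 only when the loss is the graph root

        double gradBias = 0;
        for (int i = 0; i < predicted.length; i++) {
            gradBias += selfGradient * (predicted[i] - target[i]);
        }
        gradBias /= predicted.length; // the scalarMultiplyMutate(1.0 / numberOfExamples) step
        System.out.println(gradBias); // 0.5 * (-0.2 + 0.3) / 2 = 0.025
    }
}
```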

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/PassthroughVariable.java

Lines changed: 0 additions & 45 deletions
This file was deleted.

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/ReducedCrossEntropyLoss.java

Lines changed: 7 additions & 2 deletions

```diff
@@ -69,6 +69,7 @@ public static long sizeInBytes() {
 
     @Override
     public Scalar apply(ComputationContext ctx) {
+        // manually call forward as `predictions` is not registered as a parent
         var predictionsMatrix = ctx.forward(predictions);
         var labelsVector = ctx.data(labels);
 
@@ -85,9 +86,13 @@ public Scalar apply(ComputationContext ctx) {
 
     @Override
     public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
+        // manually call forward as `predictions` is not registered as a parent
         var predMatrix = ctx.forward(predictions);
         var labelsVector = ctx.data(labels);
         int numberOfExamples = labelsVector.length();
+
+        var selfGradient = ctx.gradient(this).value();
+
         if (parent == weights) {
             var weightsMatrix = ctx.data(weights);
             var featureMatrix = ctx.data(features);
@@ -102,7 +107,7 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
                     var indicatorIsTrueClass = trueClass == classIdx ? 1.0 : 0.0;
                     var errorPerExample = (predictedClassProbability - indicatorIsTrueClass) / numberOfExamples;
                     for (int feature = 0; feature < featureCount; feature++) {
-                        gradient.addDataAt(classIdx, feature, errorPerExample * featureMatrix.dataAt(row, feature));
+                        gradient.addDataAt(classIdx, feature, selfGradient * errorPerExample * featureMatrix.dataAt(row, feature));
                     }
                 }
             }
@@ -118,7 +123,7 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
                 double predictedClassProbability = predMatrix.dataAt(row, classIdx);
                 var indicatorIsTrueClass = trueClass == classIdx ? 1.0 : 0.0;
                 var errorPerExample = (predictedClassProbability - indicatorIsTrueClass) / numberOfExamples;
-                gradient.addDataAt(classIdx, errorPerExample);
+                gradient.addDataAt(classIdx, selfGradient * errorPerExample);
             }
         }
         return gradient;
```
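Besides scaling the error term by `selfGradient` in both the weights and bias branches, the new comments record a subtlety of this variable: `predictions` is not registered as a parent, so the regular forward pass over the graph never computes it, and both `apply` and `gradient` must invoke `ctx.forward(predictions)` themselves before reading its data.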

ml/ml-core/src/test/java/org/neo4j/gds/ml/core/functions/CrossEntropyLossTest.java

Lines changed: 21 additions & 1 deletion

```diff
@@ -76,8 +76,28 @@ void shouldComputeGradientCorrectly() {
     }
 
     @Test
-    void infiniteSmallProbabilities() {
+    void considerSelfGradient() {
+        var targets = Constant.vector(new double[]{1.0, 2.0, 0.0});
+        var predictions = new Weights<>(
+            new Matrix(
+                new double[]{
+                    0.35, 0.65, 0.0,
+                    0.45, 0.45, 0.1,
+                    0.14, 0.66, 0.2
+                },
+                3, 3
+            )
+        );
 
+        var loss = new CrossEntropyLoss(predictions, targets);
+        var chainedLoss = new Sigmoid<>(loss);
+
+        finiteDifferenceShouldApproximateGradient(predictions, chainedLoss);
+    }
+
+
+    @Test
+    void infiniteSmallProbabilities() {
         var predictions = new Weights<>(new Matrix(new double[]{5.277E-321, 5.277E-321}, 1, 2));
         var targets = Constant.vector(new double[]{1});
 
```
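Why this test catches the bug: wrapping the loss in a `Sigmoid` makes the loss an interior node of the graph, so the upstream gradient reaching `CrossEntropyLoss.gradient` is σ(L)·(1−σ(L)) rather than 1, and an implementation that ignores `ctx.gradient(this)` no longer matches the finite-difference estimate. A standalone sketch of that upstream factor (hypothetical loss value, not GDS code):

```java
// The derivative of sigmoid is sigma * (1 - sigma), which is at most
// 0.25; chaining Sigmoid over the loss therefore guarantees that the
// loss sees an upstream gradient different from 1.
public class SigmoidUpstreamSketch {
    public static void main(String[] args) {
        double loss = 1.2;                          // some loss value L
        double sig = 1.0 / (1.0 + Math.exp(-loss)); // sigma(L)
        double upstream = sig * (1.0 - sig);        // d sigma(L) / dL
        System.out.println(upstream);               // ~0.178, clearly != 1
    }
}
```

The `considerSelfGradient` test added to `LogisticLossTest` below uses the same trick.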

ml/ml-core/src/test/java/org/neo4j/gds/ml/core/functions/LogisticLossTest.java

Lines changed: 16 additions & 0 deletions

```diff
@@ -20,6 +20,7 @@
 package org.neo4j.gds.ml.core.functions;
 
 import org.assertj.core.data.Offset;
+import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
 import org.neo4j.gds.ml.core.ComputationContext;
@@ -71,6 +72,21 @@ void logisticLossApproximatesGradient(boolean withBias) {
         finiteDifferenceShouldApproximateGradient(weights, loss);
     }
 
+    @Test
+    void considerSelfGradient() {
+        var features = Constant.matrix(new double[]{0.23, 0.52, 0.62, 0.32, 0.64, 0.71}, 2, 3);
+        var targets = Constant.vector(new double[]{1.0, 0.0});
+        var weights = new Weights<>(new Matrix(new double[]{0.35, 0.41, 1.0}, 1, 3));
+        var bias = Weights.ofScalar(2.5);
+
+        var predictions = new Sigmoid<>(new MatrixMultiplyWithTransposedSecondOperand(features, weights));
+
+        var loss = new LogisticLoss(weights, bias, predictions, features, targets);
+        var chainedLoss = new Sigmoid<>(loss);
+
+        finiteDifferenceShouldApproximateGradient(weights, chainedLoss);
+    }
+
     @Override
     public double epsilon() {
         return 1e-7;
```
