Skip to content

Commit f9247ed

Browse files
committed
Consider self gradient in LogisticLoss
1 parent 57fee87 commit f9247ed

File tree

2 files changed: +23 additions, -4 deletions

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/LogisticLoss.java

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -128,6 +128,8 @@ else if (predicted == 1.0) {
128128

129129
@Override
130130
public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
131+
var selfGradient = ctx.gradient(this).value();
132+
131133
if (parent == weights) {
132134
ctx.forward(predictions);
133135
var predVector = ctx.data(predictions);
@@ -141,7 +143,7 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
141143
for (int idx = 0; idx < numberOfExamples; idx++) {
142144
double errorPerExample = (predVector.dataAt(idx) - targetVector.dataAt(idx)) / numberOfExamples;
143145
for (int feature = 0; feature < featureCount; feature++) {
144-
gradient.addDataAt(feature, errorPerExample * featuresTensor.dataAt(idx, feature));
146+
gradient.addDataAt(feature, selfGradient * errorPerExample * featuresTensor.dataAt(idx, feature));
145147
}
146148
}
147149
return gradient;
@@ -154,13 +156,14 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
154156

155157
for (int idx = 0; idx < numberOfExamples; idx++) {
156158
double errorPerExample = (predVector.dataAt(idx) - targetVector.dataAt(idx));
157-
gradient.addDataAt(0, errorPerExample);
159+
gradient.addDataAt(0, selfGradient * errorPerExample);
158160
}
159161

160162
return gradient.scalarMultiplyMutate(1.0D / numberOfExamples);
161163
} else {
162-
// assume feature and target variables do not require gradient
163-
return ctx.data(parent).createWithSameDimensions();
164+
throw new IllegalStateException(
165+
"The gradient should only be computed for the bias and the weights parents, but got " + parent.render()
166+
);
164167
}
165168
}
166169

ml/ml-core/src/test/java/org/neo4j/gds/ml/core/functions/LogisticLossTest.java

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
package org.neo4j.gds.ml.core.functions;
2121

2222
import org.assertj.core.data.Offset;
23+
import org.junit.jupiter.api.Test;
2324
import org.junit.jupiter.params.ParameterizedTest;
2425
import org.junit.jupiter.params.provider.ValueSource;
2526
import org.neo4j.gds.ml.core.ComputationContext;
@@ -71,6 +72,21 @@ void logisticLossApproximatesGradient(boolean withBias) {
7172
finiteDifferenceShouldApproximateGradient(weights, loss);
7273
}
7374

75+
@Test
76+
void considerSelfGradient() {
77+
var features = Constant.matrix(new double[]{0.23, 0.52, 0.62, 0.32, 0.64, 0.71}, 2, 3);
78+
var targets = Constant.vector(new double[]{1.0, 0.0});
79+
var weights = new Weights<>(new Matrix(new double[]{0.35, 0.41, 1.0}, 1, 3));
80+
var bias = Weights.ofScalar(2.5);
81+
82+
var predictions = new Sigmoid<>(new MatrixMultiplyWithTransposedSecondOperand(features, weights));
83+
84+
var loss = new LogisticLoss(weights, bias, predictions, features, targets);
85+
var chainedLoss = new Sigmoid<>(loss);
86+
87+
finiteDifferenceShouldApproximateGradient(weights, chainedLoss);
88+
}
89+
7490
@Override
7591
public double epsilon() {
7692
return 1e-7;

0 commit comments

Comments (0)