Commit ed19d0a

Consider self gradient in ReducedCrossEntropyLoss
1 parent f9247ed commit ed19d0a
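
What the change does, in brief: ReducedCrossEntropyLoss.gradient previously behaved as if the loss were always the root of the computation graph, in effect assuming an upstream gradient of 1. When the loss variable is itself consumed by another variable (the new test wraps it in a Sigmoid), the chain rule requires scaling the local gradients by the gradient accumulated for the loss node itself, which is what ctx.gradient(this).value() retrieves. Schematically, for a scalar function g composed with the loss L(w):

    \frac{\partial}{\partial w}\, g(L(w)) = g'(L(w)) \cdot \frac{\partial L(w)}{\partial w}

Here g'(L(w)) is the "self gradient" factor applied in the diff below.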

2 files changed: 34 additions, 2 deletions

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/ReducedCrossEntropyLoss.java

Lines changed: 5 additions & 2 deletions
@@ -90,6 +90,9 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
         var predMatrix = ctx.forward(predictions);
         var labelsVector = ctx.data(labels);
         int numberOfExamples = labelsVector.length();
+
+        var selfGradient = ctx.gradient(this).value();
+
         if (parent == weights) {
             var weightsMatrix = ctx.data(weights);
             var featureMatrix = ctx.data(features);
@@ -104,7 +107,7 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
                     var indicatorIsTrueClass = trueClass == classIdx ? 1.0 : 0.0;
                     var errorPerExample = (predictedClassProbability - indicatorIsTrueClass) / numberOfExamples;
                     for (int feature = 0; feature < featureCount; feature++) {
-                        gradient.addDataAt(classIdx, feature, errorPerExample * featureMatrix.dataAt(row, feature));
+                        gradient.addDataAt(classIdx, feature, selfGradient * errorPerExample * featureMatrix.dataAt(row, feature));
                     }
                 }
             }
@@ -120,7 +123,7 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
                 double predictedClassProbability = predMatrix.dataAt(row, classIdx);
                 var indicatorIsTrueClass = trueClass == classIdx ? 1.0 : 0.0;
                 var errorPerExample = (predictedClassProbability - indicatorIsTrueClass) / numberOfExamples;
-                gradient.addDataAt(classIdx, errorPerExample);
+                gradient.addDataAt(classIdx, selfGradient * errorPerExample);
             }
         }
         return gradient;

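For intuition, here is a minimal self-contained sketch of the scaled weight-gradient computation above. It mirrors the loop structure of the patched method, but uses plain Java arrays and hypothetical names rather than the GDS Tensor/ComputationContext API:

    // Illustrative sketch only, not the GDS implementation.
    // selfGradient is the upstream gradient of the loss node
    // (1.0 when the loss is the root of the backward pass).
    final class ScaledGradientSketch {
        static double[][] weightGradient(
            double[][] predictions, // softmax outputs, [examples][classes]
            int[] labels,           // true class index per example
            double[][] features,    // [examples][features]
            double selfGradient
        ) {
            int numberOfExamples = predictions.length;
            int classCount = predictions[0].length;
            int featureCount = features[0].length;
            double[][] gradient = new double[classCount][featureCount];
            for (int row = 0; row < numberOfExamples; row++) {
                for (int classIdx = 0; classIdx < classCount; classIdx++) {
                    double indicator = labels[row] == classIdx ? 1.0 : 0.0;
                    double errorPerExample =
                        (predictions[row][classIdx] - indicator) / numberOfExamples;
                    for (int feature = 0; feature < featureCount; feature++) {
                        // Chain rule: scale the local gradient by the upstream one.
                        gradient[classIdx][feature] +=
                            selfGradient * errorPerExample * features[row][feature];
                    }
                }
            }
            return gradient;
        }
    }

With selfGradient == 1.0 this reduces to the pre-patch behavior; any other upstream value scales every entry uniformly, which is exactly the chain rule for a scalar-valued loss.
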
ml/ml-core/src/test/java/org/neo4j/gds/ml/core/functions/ReducedCrossEntropyLossTest.java

Lines changed: 29 additions & 0 deletions
@@ -172,6 +172,35 @@ void shouldComputeGradientCorrectlyStandard() {
         finiteDifferenceShouldApproximateGradient(List.of(bias, weights), loss);
     }
 
+    @Test
+    void considerSelfGradient() {
+        var features = Constant.matrix(
+            new double[]{0.23, 0.52, 0.62, 0.32, 0.64, 0.71, 0.29, -0.52, 0.12, -0.92, 0.6, -0.11},
+            3,
+            4
+        );
+        var labels = Constant.vector(new double[]{1.0, 0.0, 2.0});
+
+        var weights = new Weights<>(new Matrix(new double[]{0.35, 0.41, 1.0, 0.1, 0.54, 0.12, 0.81, 0.7}, 2, 4));
+        var bias = Weights.ofVector(0.37, 0.37);
+
+        var weightedFeatures = new MatrixMultiplyWithTransposedSecondOperand(features, weights);
+        var affineVariable = new MatrixVectorSum(weightedFeatures, bias);
+
+        var predictions = new ReducedSoftmax(affineVariable);
+
+        var loss = new ReducedCrossEntropyLoss(
+            predictions,
+            weights,
+            bias,
+            features,
+            labels
+        );
+        var chainedLoss = new Sigmoid<>(loss);
+
+        finiteDifferenceShouldApproximateGradient(List.of(bias, weights), chainedLoss);
+    }
+
     @Override
     public double epsilon() {
         return 1e-7;

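The test is what makes the fix observable: wrapping the loss in a Sigmoid gives the loss node a non-trivial upstream gradient, σ'(loss) = σ(loss) · (1 − σ(loss)), so the finite-difference check on chainedLoss exercises the new selfGradient factor; under the old code that factor was silently dropped. As a reminder of the idea behind the pre-existing helper, a generic central-difference approximation looks like this (an illustrative sketch, not the actual finiteDifferenceShouldApproximateGradient implementation):

    import java.util.function.DoubleUnaryOperator;

    final class CentralDifferenceSketch {
        // Numerically approximates f'(x); a gradient check compares such
        // estimates, one perturbed parameter at a time, against the
        // analytic gradients produced by the backward pass.
        static double centralDifference(DoubleUnaryOperator f, double x, double eps) {
            return (f.applyAsDouble(x + eps) - f.applyAsDouble(x - eps)) / (2 * eps);
        }
    }

The epsilon() override shown in the diff parameterizes that comparison.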