Skip to content

Commit bf4aa87

Browse files
committed
Consider self gradient in CrossEntropyLoss::gradient
1 parent bbf6acc commit bf4aa87

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/CrossEntropyLoss.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,14 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
7171
Matrix gradient = predictionsMatrix.createWithSameDimensions();
7272
var targetsVector = ctx.data(targets);
7373

74-
var multiplier = -1.0 / gradient.rows();
74+
var multiplier = - ctx.gradient(this).value() / gradient.rows();
7575
for (int row = 0; row < gradient.rows(); row++) {
7676
var trueClass = (int) targetsVector.dataAt(row);
7777
var predictedProbabilityForTrueClass = predictionsMatrix.dataAt(row * predictionsMatrix.cols() + trueClass);
7878

7979
// Compare to a threshold value rather than `0`, very small probability can result in setting infinite gradient values.
8080
if (predictedProbabilityForTrueClass > PREDICTED_PROBABILITY_THRESHOLD) {
81-
gradient.setDataAt(
82-
row * predictionsMatrix.cols() + trueClass,
83-
multiplier / predictedProbabilityForTrueClass
81+
gradient.setDataAt(row, trueClass, multiplier / predictedProbabilityForTrueClass
8482
);
8583
}
8684
}

ml/ml-core/src/test/java/org/neo4j/gds/ml/core/functions/CrossEntropyLossTest.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,28 @@ void shouldComputeGradientCorrectly() {
7676
}
7777

7878
@Test
79-
void infiniteSmallProbabilities() {
79+
void considerSelfGradient() {
80+
var targets = Constant.vector(new double[]{1.0, 2.0, 0.0});
81+
var predictions = new Weights<>(
82+
new Matrix(
83+
new double[]{
84+
0.35, 0.65, 0.0,
85+
0.45, 0.45, 0.1,
86+
0.14, 0.66, 0.2
87+
},
88+
3, 3
89+
)
90+
);
8091

92+
var loss = new CrossEntropyLoss(predictions, targets);
93+
var chainedLoss = new Sigmoid<>(loss);
94+
95+
finiteDifferenceShouldApproximateGradient(predictions, chainedLoss);
96+
}
97+
98+
99+
@Test
100+
void infiniteSmallProbabilities() {
81101
var predictions = new Weights<>(new Matrix(new double[]{5.277E-321, 5.277E-321}, 1, 2));
82102
var targets = Constant.vector(new double[]{1});
83103

0 commit comments

Comments
 (0)