Skip to content

Commit 57fee87

Browse files
committed
Throw on requesting gradient for targets
1 parent bf4aa87 commit 57fee87

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/CrossEntropyLoss.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,12 @@ public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
7878

7979
// Compare to a threshold value rather than `0`, very small probability can result in setting infinite gradient values.
8080
if (predictedProbabilityForTrueClass > PREDICTED_PROBABILITY_THRESHOLD) {
81-
gradient.setDataAt(row, trueClass, multiplier / predictedProbabilityForTrueClass
82-
);
81+
gradient.setDataAt(row, trueClass, multiplier / predictedProbabilityForTrueClass);
8382
}
8483
}
8584
return gradient;
8685
} else {
87-
// targets should never require a gradient
88-
return ctx.data(parent).createWithSameDimensions();
86+
throw new IllegalStateException("The gradient should not be necessary for the targets. But got: " + targets.render());
8987
}
9088
}
9189
}

0 commit comments

Comments
 (0)