Skip to content

Commit bbf6acc

Browse files
committed
Highlight selfGradient
Also remove redundant copy() call. `Tensor:scalarMultiply` is copying internally.
1 parent d452ee5 commit bbf6acc

File tree

4 files changed

+9
-4
lines changed

4 files changed

+9
-4
lines changed

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/EWiseAddMatrixScalar.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,11 @@ public Matrix apply(ComputationContext ctx) {
5252

5353
@Override
5454
public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
55+
Matrix selfGradient = ctx.gradient(this);
5556
if (parent == matrixVariable) {
56-
return ctx.gradient(this);
57+
return selfGradient;
5758
} else {
58-
return new Scalar(ctx.gradient(this).aggregateSum());
59+
return new Scalar(selfGradient.aggregateSum());
5960
}
6061
}
6162
}

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/ElementSum.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public Scalar apply(ComputationContext ctx) {
4545

4646
@Override
4747
public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
48-
return ctx.data(parent).map(ignore -> ctx.gradient(this).value());
48+
double selfGradient = ctx.gradient(this).value();
49+
return ctx.data(parent).map(ignore -> selfGradient);
4950
}
5051
}

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/L2NormSquared.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public Scalar apply(ComputationContext ctx) {
5454

5555
@Override
5656
public Matrix gradientForParent(ComputationContext ctx) {
57-
return ctx.data(parent).copy().scalarMultiply(2 * ctx.gradient(this).value());
57+
double selfGradient = ctx.gradient(this).value();
58+
return ctx.data(parent).scalarMultiply(2 * selfGradient);
5859
}
5960
}

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/ReducedCrossEntropyLoss.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ public static long sizeInBytes() {
6969

7070
@Override
7171
public Scalar apply(ComputationContext ctx) {
72+
// manually call forward as `predictions` is not registered as a parent
7273
var predictionsMatrix = ctx.forward(predictions);
7374
var labelsVector = ctx.data(labels);
7475

@@ -85,6 +86,7 @@ public Scalar apply(ComputationContext ctx) {
8586

8687
@Override
8788
public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
89+
// manually call forward as `predictions` is not registered as a parent
8890
var predMatrix = ctx.forward(predictions);
8991
var labelsVector = ctx.data(labels);
9092
int numberOfExamples = labelsVector.length();

0 commit comments

Comments
 (0)