File tree Expand file tree Collapse file tree 4 files changed +9
-4
lines changed
ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions Expand file tree Collapse file tree 4 files changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -52,10 +52,11 @@ public Matrix apply(ComputationContext ctx) {
5252
5353 @ Override
5454 public Tensor <?> gradient (Variable <?> parent , ComputationContext ctx ) {
55+ Matrix selfGradient = ctx .gradient (this );
5556 if (parent == matrixVariable ) {
56- return ctx . gradient ( this ) ;
57+ return selfGradient ;
5758 } else {
58- return new Scalar (ctx . gradient ( this ) .aggregateSum ());
59+ return new Scalar (selfGradient .aggregateSum ());
5960 }
6061 }
6162}
Original file line number Diff line number Diff line change @@ -45,6 +45,7 @@ public Scalar apply(ComputationContext ctx) {
4545
4646 @ Override
4747 public Tensor <?> gradient (Variable <?> parent , ComputationContext ctx ) {
48- return ctx .data (parent ).map (ignore -> ctx .gradient (this ).value ());
48+ double selfGradient = ctx .gradient (this ).value ();
49+ return ctx .data (parent ).map (ignore -> selfGradient );
4950 }
5051}
Original file line number Diff line number Diff line change @@ -54,6 +54,7 @@ public Scalar apply(ComputationContext ctx) {
5454
5555 @ Override
5656 public Matrix gradientForParent (ComputationContext ctx ) {
57- return ctx .data (parent ).copy ().scalarMultiply (2 * ctx .gradient (this ).value ());
57+ double selfGradient = ctx .gradient (this ).value ();
58+ return ctx .data (parent ).scalarMultiply (2 * selfGradient );
5859 }
5960}
Original file line number Diff line number Diff line change @@ -69,6 +69,7 @@ public static long sizeInBytes() {
6969
7070 @ Override
7171 public Scalar apply (ComputationContext ctx ) {
72+ // manually call forward as `predictions` is not registered as a parent
7273 var predictionsMatrix = ctx .forward (predictions );
7374 var labelsVector = ctx .data (labels );
7475
@@ -85,6 +86,7 @@ public Scalar apply(ComputationContext ctx) {
8586
8687 @ Override
8788 public Tensor <?> gradient (Variable <?> parent , ComputationContext ctx ) {
89+ // manually call forward as `predictions` is not registered as a parent
8890 var predMatrix = ctx .forward (predictions );
8991 var labelsVector = ctx .data (labels );
9092 int numberOfExamples = labelsVector .length ();
You can’t perform that action at this time.
0 commit comments