Skip to content

Commit 237a8d0

Browse files
committed
Inline PassthroughVariable
This Variable should not be used outside the ComputationContext
1 parent ed19d0a commit 237a8d0

File tree

2 files changed

+23
-47
lines changed

2 files changed

+23
-47
lines changed

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/ComputationContext.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.neo4j.gds.ml.core;
2121

2222
import org.jetbrains.annotations.TestOnly;
23-
import org.neo4j.gds.ml.core.functions.PassthroughVariable;
23+
import org.neo4j.gds.ml.core.functions.SingleParentVariable;
2424
import org.neo4j.gds.ml.core.tensor.Tensor;
2525
import org.neo4j.gds.ml.core.tensor.TensorFactory;
2626

@@ -71,7 +71,7 @@ public void backward(Variable<?> function) {
7171

7272
gradients.clear();
7373
Queue<BackPropTask> executionQueue = new LinkedBlockingQueue<>();
74-
PassthroughVariable<?> dummy = new PassthroughVariable<>(function);
74+
var dummy = new PassthroughVariable<>(function);
7575
executionQueue.add(new BackPropTask(function, dummy));
7676
Map<Variable<?>, AtomicInteger> upstreamCounters = new HashMap<>();
7777
initUpstream(dummy, upstreamCounters);
@@ -167,4 +167,25 @@ static class BackPropTask {
167167
}
168168
}
169169

170+
private static class PassthroughVariable<T extends Tensor<T>> extends SingleParentVariable<T, T> {
171+
172+
public PassthroughVariable(Variable<T> parent) {
173+
super(parent, parent.dimensions());
174+
175+
if (parent instanceof PassthroughVariable) {
176+
throw new IllegalArgumentException("Redundant use of PassthroughVariables. Chaining does not make sense.");
177+
}
178+
}
179+
180+
@Override
181+
public T apply(ComputationContext ctx) {
182+
return ctx.data(parent);
183+
}
184+
185+
@Override
186+
public T gradientForParent(ComputationContext ctx) {
187+
// initialize gradient computation with `1`
188+
return ctx.data(parent).map(v -> 1);
189+
}
190+
}
170191
}

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/functions/PassthroughVariable.java

Lines changed: 0 additions & 45 deletions
This file was deleted.

0 commit comments

Comments
 (0)