Skip to content

Commit a233030

Browse files
committed
Use predicate type for source nodes also
1 parent 82681b2 commit a233030

File tree

4 files changed

+19
-20
lines changed

4 files changed

+19
-20
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3232
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
3333

34-
import java.util.List;
3534
import java.util.Optional;
3635
import java.util.SplittableRandom;
3736
import java.util.concurrent.ExecutorService;
@@ -57,8 +56,8 @@ public class FilteredKnn extends Algorithm<FilteredKnnResult> {
5756
private final int sampledK;
5857
private final double sampleRate;
5958
private final double similarityCutoff;
60-
private final List<Long> sourceNodes;
61-
private final TargetNodePredicate targetNodePredicate;
59+
private final NodeFilter sourceNodeFilter;
60+
private final NodeFilter targetNodeFilter;
6261
private final int topK;
6362

6463
private long nodePairsConsidered;
@@ -79,7 +78,8 @@ public static FilteredKnn create(
7978
var splittableRandom = getSplittableRandom(config.randomSeed());
8079
var sourceNodes = config.sourceNodeFilter().stream().map(graph::toMappedNodeId).collect(Collectors.toList());
8180
var targetNodes = config.targetNodeFilter().stream().map(graph::toMappedNodeId).collect(Collectors.toList());
82-
var targetNodePredicate = new TargetNodePredicate(targetNodes);
81+
var sourceNodeFilter = new NodeFilter(sourceNodes);
82+
var targetNodeFilter = new NodeFilter(targetNodes);
8383
var samplerSupplier = samplerSupplier(graph, config);
8484
return new FilteredKnn(
8585
context.progressTracker(),
@@ -98,8 +98,8 @@ public static FilteredKnn create(
9898
config.sampledK(graph.nodeCount()),
9999
config.sampleRate(),
100100
config.similarityCutoff(),
101-
sourceNodes,
102-
targetNodePredicate,
101+
sourceNodeFilter,
102+
targetNodeFilter,
103103
config.topK()
104104
);
105105
}
@@ -142,8 +142,8 @@ private FilteredKnn(
142142
int sampledK,
143143
double sampleRate,
144144
double similarityCutoff,
145-
List<Long> sourceNodes,
146-
TargetNodePredicate targetNodePredicate,
145+
NodeFilter sourceNodeFilter,
146+
NodeFilter targetNodeFilter,
147147
int topK
148148

149149
) {
@@ -163,8 +163,8 @@ private FilteredKnn(
163163
this.sampledK = sampledK;
164164
this.sampleRate = sampleRate;
165165
this.similarityCutoff = similarityCutoff;
166-
this.sourceNodes = sourceNodes;
167-
this.targetNodePredicate = targetNodePredicate;
166+
this.sourceNodeFilter = sourceNodeFilter;
167+
this.targetNodeFilter = targetNodeFilter;
168168
this.topK = topK;
169169
}
170170

@@ -223,7 +223,7 @@ public FilteredKnnResult compute() {
223223
this.progressTracker.endSubTask();
224224

225225
this.progressTracker.endSubTask();
226-
return ImmutableFilteredKnnResult.of(neighbors, iteration, didConverge, this.nodePairsConsidered, this.sourceNodes);
226+
return ImmutableFilteredKnnResult.of(neighbors, iteration, didConverge, this.nodePairsConsidered, this.sourceNodeFilter);
227227
}
228228
}
229229

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnResult.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public abstract class FilteredKnnResult {
4242

4343
public abstract long nodePairsConsidered();
4444

45-
public abstract List<Long> sourceNodes();
45+
public abstract NodeFilter sourceNodeFilter();
4646

4747
public LongStream neighborsOf(long nodeId) {
4848
return neighborList().get(nodeId).elements().map(FilteredNeighborList::clearCheckedFlag);
@@ -55,7 +55,7 @@ public Stream<SimilarityResult> streamSimilarityResult() {
5555
return Stream
5656
.iterate(neighborList.initCursor(neighborList.newCursor()), HugeCursor::next, UnaryOperator.identity())
5757
.flatMap(cursor -> IntStream.range(cursor.offset, cursor.limit)
58-
// .filter(index -> sourceNodes().contains(index + cursor.base))
58+
.filter(index -> sourceNodeFilter().test(index + cursor.base))
5959
.mapToObj(index -> cursor.array[index].similarityStream(index + cursor.base))
6060
.flatMap(Function.identity())
6161
);
@@ -66,7 +66,7 @@ public long totalSimilarityPairs() {
6666
return Stream
6767
.iterate(neighborList.initCursor(neighborList.newCursor()), HugeCursor::next, UnaryOperator.identity())
6868
.flatMapToLong(cursor -> IntStream.range(cursor.offset, cursor.limit)
69-
// .filter(index -> sourceNodes().contains(index + cursor.base))
69+
.filter(index -> sourceNodeFilter().test(index + cursor.base))
7070
.mapToLong(index -> cursor.array[index].size()))
7171
.sum();
7272
}
@@ -100,8 +100,8 @@ public long nodePairsConsidered() {
100100
}
101101

102102
@Override
103-
public List<Long> sourceNodes() {
104-
return List.of();
103+
public NodeFilter sourceNodeFilter() {
104+
return new NodeFilter(List.of());
105105
}
106106

107107
@Override

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/TargetNodePredicate.java renamed to algo/src/main/java/org/neo4j/gds/similarity/filteredknn/NodeFilter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121

2222
import java.util.List;
2323

24-
public class TargetNodePredicate {
24+
public class NodeFilter {
2525
private final List<Long> nodeIds;
2626

27-
public TargetNodePredicate(List<Long> nodeIds) {this.nodeIds = nodeIds;}
27+
public NodeFilter(List<Long> nodeIds) {this.nodeIds = nodeIds;}
2828

2929
public boolean test(long nodeId) {
3030
return nodeIds.contains(nodeId);

algo/src/test/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnResultTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import java.util.Arrays;
2727
import java.util.List;
2828
import java.util.SplittableRandom;
29-
import java.util.function.Function;
3029
import java.util.stream.Collectors;
3130
import java.util.stream.LongStream;
3231
import java.util.stream.Stream;
@@ -50,7 +49,7 @@ void should() {
5049
1,
5150
true,
5251
2,
53-
List.of()
52+
new NodeFilter(List.of())
5453
);
5554

5655
var neighborLists = result.neighborList();

0 commit comments

Comments
 (0)