Skip to content

Commit 199ebaf

Browse files
committed
Filtered knn accepts target node filter
1 parent 666caf9 commit 199ebaf

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,25 @@
4141
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
4242

4343
public class FilteredKnn extends Algorithm<FilteredKnnResult> {
44+
private final ExecutorService executorService;
45+
private final int concurrency;
4446
private final Graph graph;
47+
private final SimilarityComputer similarityComputer;
4548
private final FilteredNeighborFilterFactory neighborFilterFactory;
46-
private final ExecutorService executorService;
4749
private final SplittableRandom splittableRandom;
48-
private final SimilarityComputer similarityComputer;
49-
private final List<Long> sourceNodes;
50-
private final int maxIterations;
51-
private final double sampleRate;
52-
private final int topK;
50+
private final Function<SplittableRandom, FilteredKnnSampler> samplerSupplier;
51+
5352
private final double deltaThreshold;
54-
private final double similarityCutoff;
55-
private final int concurrency;
53+
private final int maxIterations;
5654
private final int minBatchSize;
5755
private final double perturbationRate;
58-
private final int sampledK;
5956
private final int randomJoins;
60-
private final Function<SplittableRandom, FilteredKnnSampler> samplerSupplier;
57+
private final int sampledK;
58+
private final double sampleRate;
59+
private final double similarityCutoff;
60+
private final List<Long> sourceNodes;
61+
private final List<Long> targetNodes;
62+
private final int topK;
6163

6264
private long nodePairsConsidered;
6365

@@ -76,6 +78,7 @@ public static FilteredKnn create(
7678
) {
7779
var splittableRandom = getSplittableRandom(config.randomSeed());
7880
var sourceNodes = config.sourceNodeFilter().stream().map(graph::toMappedNodeId).collect(Collectors.toList());
81+
var targetNodes = config.targetNodeFilter().stream().map(graph::toMappedNodeId).collect(Collectors.toList());
7982
var samplerSupplier = samplerSupplier(graph, config);
8083
return new FilteredKnn(
8184
context.progressTracker(),
@@ -95,6 +98,7 @@ public static FilteredKnn create(
9598
config.sampleRate(),
9699
config.similarityCutoff(),
97100
sourceNodes,
101+
targetNodes,
98102
config.topK()
99103
);
100104
}
@@ -138,6 +142,7 @@ private FilteredKnn(
138142
double sampleRate,
139143
double similarityCutoff,
140144
List<Long> sourceNodes,
145+
List<Long> targetNodes,
141146
int topK
142147

143148
) {
@@ -158,6 +163,7 @@ private FilteredKnn(
158163
this.sampleRate = sampleRate;
159164
this.similarityCutoff = similarityCutoff;
160165
this.sourceNodes = sourceNodes;
166+
this.targetNodes = targetNodes;
161167
this.topK = topK;
162168
}
163169

0 commit comments

Comments
 (0)