4141import static org .neo4j .gds .utils .StringFormatting .formatWithLocale ;
4242
4343public class FilteredKnn extends Algorithm <FilteredKnnResult > {
44+ private final ExecutorService executorService ;
45+ private final int concurrency ;
4446 private final Graph graph ;
47+ private final SimilarityComputer similarityComputer ;
4548 private final FilteredNeighborFilterFactory neighborFilterFactory ;
46- private final ExecutorService executorService ;
4749 private final SplittableRandom splittableRandom ;
48- private final SimilarityComputer similarityComputer ;
49- private final List <Long > sourceNodes ;
50- private final int maxIterations ;
51- private final double sampleRate ;
52- private final int topK ;
50+ private final Function <SplittableRandom , FilteredKnnSampler > samplerSupplier ;
51+
5352 private final double deltaThreshold ;
54- private final double similarityCutoff ;
55- private final int concurrency ;
53+ private final int maxIterations ;
5654 private final int minBatchSize ;
5755 private final double perturbationRate ;
58- private final int sampledK ;
5956 private final int randomJoins ;
60- private final Function <SplittableRandom , FilteredKnnSampler > samplerSupplier ;
57+ private final int sampledK ;
58+ private final double sampleRate ;
59+ private final double similarityCutoff ;
60+ private final List <Long > sourceNodes ;
61+ private final List <Long > targetNodes ;
62+ private final int topK ;
6163
6264 private long nodePairsConsidered ;
6365
@@ -76,6 +78,7 @@ public static FilteredKnn create(
7678 ) {
7779 var splittableRandom = getSplittableRandom (config .randomSeed ());
7880 var sourceNodes = config .sourceNodeFilter ().stream ().map (graph ::toMappedNodeId ).collect (Collectors .toList ());
81+ var targetNodes = config .targetNodeFilter ().stream ().map (graph ::toMappedNodeId ).collect (Collectors .toList ());
7982 var samplerSupplier = samplerSupplier (graph , config );
8083 return new FilteredKnn (
8184 context .progressTracker (),
@@ -95,6 +98,7 @@ public static FilteredKnn create(
9598 config .sampleRate (),
9699 config .similarityCutoff (),
97100 sourceNodes ,
101+ targetNodes ,
98102 config .topK ()
99103 );
100104 }
@@ -138,6 +142,7 @@ private FilteredKnn(
138142 double sampleRate ,
139143 double similarityCutoff ,
140144 List <Long > sourceNodes ,
145+ List <Long > targetNodes ,
141146 int topK
142147
143148 ) {
@@ -158,6 +163,7 @@ private FilteredKnn(
158163 this .sampleRate = sampleRate ;
159164 this .similarityCutoff = similarityCutoff ;
160165 this .sourceNodes = sourceNodes ;
166+ this .targetNodes = targetNodes ;
161167 this .topK = topK ;
162168 }
163169
0 commit comments