@@ -69,42 +69,59 @@ public class FilteredKnn extends Algorithm<FilteredKnn.Result> {
6969 private long nodePairsConsidered ;
7070
/**
 * Factory that builds a FilteredKnn using default collaborators: a
 * SimilarityComputer derived from config.nodeProperties() and a
 * FilteredKnnNeighborFilterFactory sized by graph.nodeCount().
 *
 * Reading this patch: the removed (-) lines invoked the constructor directly
 * and passed the whole config; the added (+) lines only build the two default
 * collaborators and delegate the rest to create(...).
 */
7171 public static FilteredKnn createWithDefaults (Graph graph , FilteredKnnBaseConfig config , FilteredKnnContext context ) {
// Removed: source-node mapping and direct constructor call lived here.
72- var sourceNodes = config .sourceNodeFilter ().stream ().map (graph ::toMappedNodeId ).collect (Collectors .toList ());
73- return new FilteredKnn (
74- context .progressTracker (),
75- graph ,
76- config ,
77- config .maxIterations (),
78- sourceNodes ,
79- SimilarityComputer .ofProperties (graph , config .nodeProperties ()),
80- new FilteredKnnNeighborFilterFactory (graph .nodeCount ()),
81- context ,
82- getSplittableRandom (config .randomSeed ())
83- );
// Added: same two default collaborators as before, now handed to create(...).
72+ var similarityComputer = SimilarityComputer .ofProperties (graph , config .nodeProperties ());
73+ var neighborFilterFactory = new FilteredKnnNeighborFilterFactory (graph .nodeCount ());
74+ return create (graph , config , context , similarityComputer , neighborFilterFactory );
8475 }
8576
/**
 * Factory that builds a FilteredKnn from caller-supplied collaborators.
 *
 * After this patch the method unpacks the config into individual values
 * (sample rate, thresholds, topK, concurrency, ...) and passes those to the
 * constructor instead of the config object itself. It also resolves the
 * sampler supplier up front via samplerSupplier(graph, config).
 *
 * NOTE(review): the patch moves the context parameter from the last position
 * to the third position; callers outside this view must be updated to match.
 */
8677 public static FilteredKnn create (
8778 Graph graph ,
8879 FilteredKnnBaseConfig config ,
80+ FilteredKnnContext context ,
8981 SimilarityComputer similarityComputer ,
90- FilteredNeighborFilterFactory neighborFilterFactory ,
91- FilteredKnnContext context
82+ FilteredNeighborFilterFactory neighborFilterFactory
9283 ) {
93- SplittableRandom splittableRandom = getSplittableRandom (config .randomSeed ());
84+ var splittableRandom = getSplittableRandom (config .randomSeed ());
// Each configured source node id is translated with graph.toMappedNodeId before being handed over.
9485 var sourceNodes = config .sourceNodeFilter ().stream ().map (graph ::toMappedNodeId ).collect (Collectors .toList ());
86+ var samplerSupplier = samplerSupplier (graph , config );
9587 return new FilteredKnn (
9688 context .progressTracker (),
9789 graph ,
// Removed: the config object is no longer passed to the constructor.
98- config ,
9990 config .maxIterations (),
10091 sourceNodes ,
10192 similarityComputer ,
10293 neighborFilterFactory ,
10394 context ,
104- splittableRandom
// Added: values previously read from config inside the constructor are now read here.
95+ splittableRandom ,
96+ config .sampleRate (),
97+ config .deltaThreshold (),
98+ config .similarityCutoff (),
99+ config .topK (),
100+ config .concurrency (),
101+ config .minBatchSize (),
102+ config .perturbationRate (),
// sampledK and the sampler both depend on graph.nodeCount(), so they are derived while the graph is at hand.
103+ config .sampledK (graph .nodeCount ()),
104+ config .randomJoins (),
105+ samplerSupplier
105106 );
106107 }
107108
/**
 * Selects the sampler supplier matching config.initialSampler().
 *
 * UNIFORM yields a UniformFilteredKnnSamplerSupplier over the graph.
 * RANDOMWALK yields a RandomWalkFilteredKnnSamplerSupplier built from
 * graph.concurrentCopy(), config.randomSeed() and
 * config.boundedK(graph.nodeCount()).
 *
 * Every line of this helper is an added (+) line in the patch.
 *
 * @throws IllegalStateException for any other initialSampler value
 */
109+ @ NotNull
110+ private static Function <SplittableRandom , FilteredKnnSampler > samplerSupplier (Graph graph , FilteredKnnBaseConfig config ) {
111+ switch (config .initialSampler ()) {
112+ case UNIFORM :
113+ return new UniformFilteredKnnSamplerSupplier (graph );
114+ case RANDOMWALK :
// The random-walk variant is given a concurrent copy of the graph rather than the graph itself.
115+ return new RandomWalkFilteredKnnSamplerSupplier (
116+ graph .concurrentCopy (),
117+ config .randomSeed (),
118+ config .boundedK (graph .nodeCount ())
119+ );
120+ default :
121+ throw new IllegalStateException ("Invalid FilteredKnnSampler" );
122+ }
123+ }
124+
108125 @ NotNull
109126 private static SplittableRandom getSplittableRandom (Optional <Long > randomSeed ) {
110127 return randomSeed .map (SplittableRandom ::new ).orElseGet (SplittableRandom ::new );
@@ -113,45 +130,42 @@ private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
/**
 * Constructor: forwards progressTracker to the superclass and stores every
 * other argument in the field of the same name.
 *
 * Reading this patch: the removed (-) lines took a FilteredKnnBaseConfig,
 * read each setting from it and chose the sampler supplier with a switch on
 * config.initialSampler(). The added (+) lines receive those settings and the
 * sampler supplier as explicit parameters, so the constructor no longer
 * touches the config.
 */
113130 FilteredKnn (
114131 ProgressTracker progressTracker ,
115132 Graph graph ,
116- FilteredKnnBaseConfig config ,
117133 int maxIterations ,
118134 List <Long > sourceNodes ,
119135 SimilarityComputer similarityComputer ,
120136 FilteredNeighborFilterFactory neighborFilterFactory ,
121137 FilteredKnnContext context ,
122- SplittableRandom splittableRandom
// Added: one parameter per setting formerly pulled from config, plus the pre-built sampler supplier.
138+ SplittableRandom splittableRandom ,
139+ double sampleRate ,
140+ double deltaThreshold ,
141+ double similarityCutoff ,
142+ int topK ,
143+ int concurrency ,
144+ int minBatchSize ,
145+ double perturbationRate ,
146+ int sampledK ,
147+ int randomJoins ,
148+ Function <SplittableRandom , FilteredKnnSampler > samplerSupplier
149+
123150 ) {
124151 super (progressTracker );
125152 this .graph = graph ;
// Removed: settings were read from config here.
126- this .sampleRate = config . sampleRate () ;
127- this .deltaThreshold = config . deltaThreshold () ;
128- this .similarityCutoff = config . similarityCutoff () ;
129- this .topK = config . topK () ;
130- this .concurrency = config . concurrency () ;
131- this .minBatchSize = config . minBatchSize () ;
132- this .perturbationRate = config . perturbationRate () ;
133- this .sampledK = config . sampledK ( graph . nodeCount ()) ;
134- this .randomJoins = config . randomJoins () ;
// Added: plain field assignments from the new parameters.
153+ this .sampleRate = sampleRate ;
154+ this .deltaThreshold = deltaThreshold ;
155+ this .similarityCutoff = similarityCutoff ;
156+ this .topK = topK ;
157+ this .concurrency = concurrency ;
158+ this .minBatchSize = minBatchSize ;
159+ this .perturbationRate = perturbationRate ;
160+ this .sampledK = sampledK ;
161+ this .randomJoins = randomJoins ;
135162 this .maxIterations = maxIterations ;
136163 this .similarityComputer = similarityComputer ;
137164 this .neighborFilterFactory = neighborFilterFactory ;
138165 this .context = context ;
139166 this .splittableRandom = splittableRandom ;
140167 this .sourceNodes = sourceNodes ;
// Removed: sampler selection by switch on config.initialSampler(); the supplier is now injected.
141- switch (config .initialSampler ()) {
142- case UNIFORM :
143- this .samplerSupplier = new UniformFilteredKnnSamplerSupplier (graph );
144- break ;
145- case RANDOMWALK :
146- this .samplerSupplier = new RandomWalkFilteredKnnSamplerSupplier (
147- graph .concurrentCopy (),
148- config .randomSeed (),
149- config .boundedK (graph .nodeCount ())
150- );
151- break ;
152- default :
153- throw new IllegalStateException ("Invalid FilteredKnnSampler" );
154- }
168+ this .samplerSupplier = samplerSupplier ;
155169 }
156170
157171 public long nodeCount () {
0 commit comments