Skip to content

Commit 653c8a4

Browse files
Move up use of KnnConfig
Co-Authored-By: Lasse Westh-Nielsen <lassewesth@gmail.com>
1 parent 90dd0c5 commit 653c8a4

File tree

2 files changed

+60
-46
lines changed

2 files changed

+60
-46
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -69,42 +69,59 @@ public class FilteredKnn extends Algorithm<FilteredKnn.Result> {
6969
private long nodePairsConsidered;
7070

7171
public static FilteredKnn createWithDefaults(Graph graph, FilteredKnnBaseConfig config, FilteredKnnContext context) {
72-
var sourceNodes = config.sourceNodeFilter().stream().map(graph::toMappedNodeId).collect(Collectors.toList());
73-
return new FilteredKnn(
74-
context.progressTracker(),
75-
graph,
76-
config,
77-
config.maxIterations(),
78-
sourceNodes,
79-
SimilarityComputer.ofProperties(graph, config.nodeProperties()),
80-
new FilteredKnnNeighborFilterFactory(graph.nodeCount()),
81-
context,
82-
getSplittableRandom(config.randomSeed())
83-
);
72+
var similarityComputer = SimilarityComputer.ofProperties(graph, config.nodeProperties());
73+
var neighborFilterFactory = new FilteredKnnNeighborFilterFactory(graph.nodeCount());
74+
return create(graph, config, context, similarityComputer, neighborFilterFactory);
8475
}
8576

8677
public static FilteredKnn create(
8778
Graph graph,
8879
FilteredKnnBaseConfig config,
80+
FilteredKnnContext context,
8981
SimilarityComputer similarityComputer,
90-
FilteredNeighborFilterFactory neighborFilterFactory,
91-
FilteredKnnContext context
82+
FilteredNeighborFilterFactory neighborFilterFactory
9283
) {
93-
SplittableRandom splittableRandom = getSplittableRandom(config.randomSeed());
84+
var splittableRandom = getSplittableRandom(config.randomSeed());
9485
var sourceNodes = config.sourceNodeFilter().stream().map(graph::toMappedNodeId).collect(Collectors.toList());
86+
var samplerSupplier = samplerSupplier(graph, config);
9587
return new FilteredKnn(
9688
context.progressTracker(),
9789
graph,
98-
config,
9990
config.maxIterations(),
10091
sourceNodes,
10192
similarityComputer,
10293
neighborFilterFactory,
10394
context,
104-
splittableRandom
95+
splittableRandom,
96+
config.sampleRate(),
97+
config.deltaThreshold(),
98+
config.similarityCutoff(),
99+
config.topK(),
100+
config.concurrency(),
101+
config.minBatchSize(),
102+
config.perturbationRate(),
103+
config.sampledK(graph.nodeCount()),
104+
config.randomJoins(),
105+
samplerSupplier
105106
);
106107
}
107108

109+
@NotNull
110+
private static Function<SplittableRandom, FilteredKnnSampler> samplerSupplier(Graph graph, FilteredKnnBaseConfig config) {
111+
switch(config.initialSampler()) {
112+
case UNIFORM:
113+
return new UniformFilteredKnnSamplerSupplier(graph);
114+
case RANDOMWALK:
115+
return new RandomWalkFilteredKnnSamplerSupplier(
116+
graph.concurrentCopy(),
117+
config.randomSeed(),
118+
config.boundedK(graph.nodeCount())
119+
);
120+
default:
121+
throw new IllegalStateException("Invalid FilteredKnnSampler");
122+
}
123+
}
124+
108125
@NotNull
109126
private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
110127
return randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
@@ -113,45 +130,42 @@ private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
113130
FilteredKnn(
114131
ProgressTracker progressTracker,
115132
Graph graph,
116-
FilteredKnnBaseConfig config,
117133
int maxIterations,
118134
List<Long> sourceNodes,
119135
SimilarityComputer similarityComputer,
120136
FilteredNeighborFilterFactory neighborFilterFactory,
121137
FilteredKnnContext context,
122-
SplittableRandom splittableRandom
138+
SplittableRandom splittableRandom,
139+
double sampleRate,
140+
double deltaThreshold,
141+
double similarityCutoff,
142+
int topK,
143+
int concurrency,
144+
int minBatchSize,
145+
double perturbationRate,
146+
int sampledK,
147+
int randomJoins,
148+
Function<SplittableRandom, FilteredKnnSampler> samplerSupplier
149+
123150
) {
124151
super(progressTracker);
125152
this.graph = graph;
126-
this.sampleRate = config.sampleRate();
127-
this.deltaThreshold = config.deltaThreshold();
128-
this.similarityCutoff = config.similarityCutoff();
129-
this.topK = config.topK();
130-
this.concurrency = config.concurrency();
131-
this.minBatchSize = config.minBatchSize();
132-
this.perturbationRate = config.perturbationRate();
133-
this.sampledK = config.sampledK(graph.nodeCount());
134-
this.randomJoins = config.randomJoins();
153+
this.sampleRate = sampleRate;
154+
this.deltaThreshold = deltaThreshold;
155+
this.similarityCutoff = similarityCutoff;
156+
this.topK = topK;
157+
this.concurrency = concurrency;
158+
this.minBatchSize = minBatchSize;
159+
this.perturbationRate = perturbationRate;
160+
this.sampledK = sampledK;
161+
this.randomJoins = randomJoins;
135162
this.maxIterations = maxIterations;
136163
this.similarityComputer = similarityComputer;
137164
this.neighborFilterFactory = neighborFilterFactory;
138165
this.context = context;
139166
this.splittableRandom = splittableRandom;
140167
this.sourceNodes = sourceNodes;
141-
switch(config.initialSampler()) {
142-
case UNIFORM:
143-
this.samplerSupplier = new UniformFilteredKnnSamplerSupplier(graph);
144-
break;
145-
case RANDOMWALK:
146-
this.samplerSupplier = new RandomWalkFilteredKnnSamplerSupplier(
147-
graph.concurrentCopy(),
148-
config.randomSeed(),
149-
config.boundedK(graph.nodeCount())
150-
);
151-
break;
152-
default:
153-
throw new IllegalStateException("Invalid FilteredKnnSampler");
154-
}
168+
this.samplerSupplier = samplerSupplier;
155169
}
156170

157171
public long nodeCount() {

algo/src/test/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,9 @@ void testNonExistingProperties(NodePropertyValues nodeProperties) {
281281
var knn = FilteredKnn.create(
282282
graph,
283283
knnConfig,
284+
knnContext,
284285
SimilarityComputer.ofProperty(graph, "knn", nodeProperties),
285-
new FilteredKnnNeighborFilterFactory(graph.nodeCount()),
286-
knnContext
286+
new FilteredKnnNeighborFilterFactory(graph.nodeCount())
287287
);
288288
var result = knn.compute();
289289
assertThat(result)
@@ -313,9 +313,9 @@ void testMixedExistingAndNonExistingProperties(SoftAssertions softly) {
313313
.concurrency(1)
314314
.randomSeed(42L)
315315
.build(),
316+
ImmutableFilteredKnnContext.builder().build(),
316317
SimilarityComputer.ofProperty(graph, "{knn}", nodeProperties),
317-
new FilteredKnnNeighborFilterFactory(graph.nodeCount()),
318-
ImmutableFilteredKnnContext.builder().build()
318+
new FilteredKnnNeighborFilterFactory(graph.nodeCount())
319319
);
320320

321321
var result = knn.compute();

0 commit comments

Comments
 (0)