Skip to content

Commit e6b6529

Browse files
committed
Move up use of KnnContext
1 parent 653c8a4 commit e6b6529

File tree

5 files changed

+29
-30
lines changed

5 files changed

+29
-30
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import java.util.List;
3939
import java.util.Optional;
4040
import java.util.SplittableRandom;
41+
import java.util.concurrent.ExecutorService;
4142
import java.util.function.Function;
4243
import java.util.function.UnaryOperator;
4344
import java.util.stream.Collectors;
@@ -50,7 +51,7 @@
5051
public class FilteredKnn extends Algorithm<FilteredKnn.Result> {
5152
private final Graph graph;
5253
private final FilteredNeighborFilterFactory neighborFilterFactory;
53-
private final FilteredKnnContext context;
54+
private final ExecutorService executorService;
5455
private final SplittableRandom splittableRandom;
5556
private final SimilarityComputer similarityComputer;
5657
private final List<Long> sourceNodes;
@@ -91,7 +92,7 @@ public static FilteredKnn create(
9192
sourceNodes,
9293
similarityComputer,
9394
neighborFilterFactory,
94-
context,
95+
context.executor(),
9596
splittableRandom,
9697
config.sampleRate(),
9798
config.deltaThreshold(),
@@ -134,7 +135,7 @@ private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
134135
List<Long> sourceNodes,
135136
SimilarityComputer similarityComputer,
136137
FilteredNeighborFilterFactory neighborFilterFactory,
137-
FilteredKnnContext context,
138+
ExecutorService executorService,
138139
SplittableRandom splittableRandom,
139140
double sampleRate,
140141
double deltaThreshold,
@@ -162,7 +163,7 @@ private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
162163
this.maxIterations = maxIterations;
163164
this.similarityComputer = similarityComputer;
164165
this.neighborFilterFactory = neighborFilterFactory;
165-
this.context = context;
166+
this.executorService = executorService;
166167
this.splittableRandom = splittableRandom;
167168
this.sourceNodes = sourceNodes;
168169
this.samplerSupplier = samplerSupplier;
@@ -172,8 +173,8 @@ public long nodeCount() {
172173
return graph.nodeCount();
173174
}
174175

175-
public FilteredKnnContext context() {
176-
return context;
176+
public ExecutorService executorService() {
177+
return this.executorService;
177178
}
178179

179180
@Override
@@ -218,7 +219,7 @@ public Result compute() {
218219
),
219220
Optional.of(this.minBatchSize)
220221
);
221-
ParallelUtil.runWithConcurrency(this.concurrency, neighborFilterTasks, context.executor());
222+
ParallelUtil.runWithConcurrency(this.concurrency, neighborFilterTasks, this.executorService);
222223
}
223224
this.progressTracker.endSubTask();
224225

@@ -262,7 +263,7 @@ public void release() {
262263
Optional.of(this.minBatchSize)
263264
);
264265

265-
ParallelUtil.runWithConcurrency(this.concurrency, randomNeighborGenerators, context.executor());
266+
ParallelUtil.runWithConcurrency(this.concurrency, randomNeighborGenerators, this.executorService);
266267

267268
this.nodePairsConsidered += randomNeighborGenerators.stream().mapToLong(FilteredGenerateRandomNeighbors::neighborsFound).sum();
268269

@@ -310,15 +311,12 @@ private long iteration(HugeObjectArray<FilteredNeighborList> neighbors) {
310311
return FilteredNeighborList.NOT_INSERTED;
311312
}
312313

313-
var executor = this.context.executor();
314-
315-
316314
// TODO: init in ctor and reuse - benchmark against new allocations
317315
var allOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
318316
var allNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
319317

320318
progressTracker.beginSubTask();
321-
ParallelUtil.readParallel(this.concurrency, nodeCount, executor, new FilteredSplitOldAndNewNeighbors(
319+
ParallelUtil.readParallel(this.concurrency, nodeCount, this.executorService, new FilteredSplitOldAndNewNeighbors(
322320
this.splittableRandom,
323321
neighbors,
324322
allOldNeighbors,
@@ -369,7 +367,7 @@ private long iteration(HugeObjectArray<FilteredNeighborList> neighbors) {
369367
);
370368

371369
progressTracker.beginSubTask();
372-
ParallelUtil.runWithConcurrency(this.concurrency, neighborsJoiners, executor);
370+
ParallelUtil.runWithConcurrency(this.concurrency, neighborsJoiners, executorService);
373371
progressTracker.endSubTask();
374372

375373
this.nodePairsConsidered += neighborsJoiners.stream().mapToLong(JoinNeighbors::nodePairsConsidered).sum();

algo/src/main/java/org/neo4j/gds/similarity/knn/Knn.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
import java.util.Optional;
3939
import java.util.SplittableRandom;
40+
import java.util.concurrent.ExecutorService;
4041
import java.util.function.Function;
4142
import java.util.function.UnaryOperator;
4243
import java.util.stream.IntStream;
@@ -49,7 +50,7 @@ public class Knn extends Algorithm<Knn.Result> {
4950
private final Graph graph;
5051
private final KnnBaseConfig config;
5152
private final NeighborFilterFactory neighborFilterFactory;
52-
private final KnnContext context;
53+
private final ExecutorService executorService;
5354
private final SplittableRandom splittableRandom;
5455
private final SimilarityComputer similarityComputer;
5556

@@ -62,7 +63,7 @@ public static Knn createWithDefaults(Graph graph, KnnBaseConfig config, KnnConte
6263
config,
6364
SimilarityComputer.ofProperties(graph, config.nodeProperties()),
6465
new KnnNeighborFilterFactory(graph.nodeCount()),
65-
context,
66+
context.executor(),
6667
getSplittableRandom(config.randomSeed())
6768
);
6869
}
@@ -81,7 +82,7 @@ public static Knn create(
8182
config,
8283
similarityComputer,
8384
neighborFilterFactory,
84-
context,
85+
context.executor(),
8586
splittableRandom
8687
);
8788
}
@@ -97,24 +98,24 @@ private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
9798
KnnBaseConfig config,
9899
SimilarityComputer similarityComputer,
99100
NeighborFilterFactory neighborFilterFactory,
100-
KnnContext context,
101+
ExecutorService executorService,
101102
SplittableRandom splittableRandom
102103
) {
103104
super(progressTracker);
104105
this.graph = graph;
105106
this.config = config;
106107
this.similarityComputer = similarityComputer;
107108
this.neighborFilterFactory = neighborFilterFactory;
108-
this.context = context;
109+
this.executorService = executorService;
109110
this.splittableRandom = splittableRandom;
110111
}
111112

112113
public long nodeCount() {
113114
return graph.nodeCount();
114115
}
115116

116-
public KnnContext context() {
117-
return context;
117+
public ExecutorService executorService() {
118+
return this.executorService;
118119
}
119120

120121
@Override
@@ -161,7 +162,7 @@ public Result compute() {
161162
),
162163
Optional.of(config.minBatchSize())
163164
);
164-
ParallelUtil.runWithConcurrency(config.concurrency(), neighborFilterTasks, context.executor());
165+
ParallelUtil.runWithConcurrency(config.concurrency(), neighborFilterTasks, this.executorService);
165166
}
166167
this.progressTracker.endSubTask();
167168

@@ -207,7 +208,7 @@ public void release() {
207208
Optional.of(config.minBatchSize())
208209
);
209210

210-
ParallelUtil.runWithConcurrency(config.concurrency(), randomNeighborGenerators, context.executor());
211+
ParallelUtil.runWithConcurrency(config.concurrency(), randomNeighborGenerators, this.executorService);
211212

212213
this.nodePairsConsidered += randomNeighborGenerators.stream().mapToLong(GenerateRandomNeighbors::neighborsFound).sum();
213214

@@ -242,7 +243,6 @@ private long iteration(HugeObjectArray<NeighborList> neighbors) {
242243
}
243244

244245
var concurrency = this.config.concurrency();
245-
var executor = this.context.executor();
246246

247247
var sampledK = this.config.sampledK(nodeCount);
248248

@@ -251,7 +251,7 @@ private long iteration(HugeObjectArray<NeighborList> neighbors) {
251251
var allNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
252252

253253
progressTracker.beginSubTask();
254-
ParallelUtil.readParallel(concurrency, nodeCount, executor, new SplitOldAndNewNeighbors(
254+
ParallelUtil.readParallel(concurrency, nodeCount, this.executorService, new SplitOldAndNewNeighbors(
255255
this.splittableRandom,
256256
neighbors,
257257
allOldNeighbors,
@@ -301,7 +301,7 @@ private long iteration(HugeObjectArray<NeighborList> neighbors) {
301301
);
302302

303303
progressTracker.beginSubTask();
304-
ParallelUtil.runWithConcurrency(concurrency, neighborsJoiners, executor);
304+
ParallelUtil.runWithConcurrency(concurrency, neighborsJoiners, this.executorService);
305305
progressTracker.endSubTask();
306306

307307
this.nodePairsConsidered += neighborsJoiners.stream().mapToLong(JoinNeighbors::nodePairsConsidered).sum();

proc/similarity/src/main/java/org/neo4j/gds/similarity/knn/KnnMutateProc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ public ComputationResultConsumer<Knn, Knn.Result, KnnMutateConfig, Stream<Result
118118
algorithm.nodeCount(),
119119
config.concurrency(),
120120
Objects.requireNonNull(result),
121-
algorithm.context()
121+
algorithm.executorService()
122122
);
123123
}
124124

proc/similarity/src/main/java/org/neo4j/gds/similarity/knn/KnnStatsProc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ public Stream<Result> stats(ComputationResult<Knn, Knn.Result, KnnStatsConfig> c
126126
algorithm.nodeCount(),
127127
config.concurrency(),
128128
result,
129-
algorithm.context()
129+
algorithm.executorService()
130130
);
131131

132132
Graph similarityGraph = similarityGraphResult.similarityGraph();

proc/similarity/src/main/java/org/neo4j/gds/similarity/knn/KnnWriteProc.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import java.util.Map;
3838
import java.util.Objects;
39+
import java.util.concurrent.ExecutorService;
3940
import java.util.stream.Stream;
4041

4142
import static org.neo4j.gds.executor.ExecutionMode.WRITE_RELATIONSHIP;
@@ -100,7 +101,7 @@ protected SimilarityGraphResult similarityGraphResult(ComputationResult<Knn, Knn
100101
algorithm.nodeCount(),
101102
config.concurrency(),
102103
Objects.requireNonNull(computationResult.result()),
103-
algorithm.context()
104+
algorithm.executorService()
104105
);
105106
}
106107

@@ -109,12 +110,12 @@ static SimilarityGraphResult computeToGraph(
109110
long nodeCount,
110111
int concurrency,
111112
Knn.Result result,
112-
KnnContext context
113+
ExecutorService executor
113114
) {
114115
Graph similarityGraph = new SimilarityGraphBuilder(
115116
graph,
116117
concurrency,
117-
context.executor()
118+
executor
118119
).build(result.streamSimilarityResult());
119120
return new SimilarityGraphResult(similarityGraph, nodeCount, false);
120121
}

0 commit comments

Comments
 (0)