Skip to content

Commit 6e4a436

Browse files
committed
Move out internal classes
1 parent bd48bd4 commit 6e4a436

File tree

3 files changed

+120
-94
lines changed

3 files changed

+120
-94
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java

Lines changed: 4 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,24 @@
2323
import org.jetbrains.annotations.NotNull;
2424
import org.jetbrains.annotations.Nullable;
2525
import org.neo4j.gds.Algorithm;
26-
import org.neo4j.gds.annotation.ValueClass;
2726
import org.neo4j.gds.api.Graph;
2827
import org.neo4j.gds.core.concurrency.ParallelUtil;
2928
import org.neo4j.gds.core.utils.ProgressTimer;
30-
import org.neo4j.gds.core.utils.paged.HugeCursor;
3129
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
3230
import org.neo4j.gds.core.utils.partition.PartitionUtils;
3331
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
34-
import org.neo4j.gds.similarity.SimilarityResult;
3532
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
3633

3734
import java.util.List;
3835
import java.util.Optional;
3936
import java.util.SplittableRandom;
4037
import java.util.concurrent.ExecutorService;
4138
import java.util.function.Function;
42-
import java.util.function.UnaryOperator;
4339
import java.util.stream.Collectors;
44-
import java.util.stream.IntStream;
45-
import java.util.stream.LongStream;
46-
import java.util.stream.Stream;
4740

4841
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
4942

50-
public class FilteredKnn extends Algorithm<FilteredKnn.Result> {
43+
public class FilteredKnn extends Algorithm<FilteredKnnResult> {
5144
private final Graph graph;
5245
private final FilteredNeighborFilterFactory neighborFilterFactory;
5346
private final ExecutorService executorService;
@@ -177,7 +170,7 @@ public ExecutorService executorService() {
177170
}
178171

179172
@Override
180-
public Result compute() {
173+
public FilteredKnnResult compute() {
181174
this.progressTracker.beginSubTask();
182175
HugeObjectArray<FilteredNeighborList> neighbors;
183176
try (var ignored1 = ProgressTimer.start(this::logOverallTime)) {
@@ -187,7 +180,7 @@ public Result compute() {
187180
this.progressTracker.endSubTask();
188181
}
189182
if (neighbors == null) {
190-
return new EmptyResult();
183+
return FilteredKnnResult.empty();
191184
}
192185

193186
var maxUpdates = (long) Math.ceil(this.sampleRate * this.topK * graph.nodeCount());
@@ -223,7 +216,7 @@ public Result compute() {
223216
this.progressTracker.endSubTask();
224217

225218
this.progressTracker.endSubTask();
226-
return ImmutableResult.of(neighbors, iteration, didConverge, this.nodePairsConsidered, this.sourceNodes);
219+
return ImmutableFilteredKnnResult.of(neighbors, iteration, didConverge, this.nodePairsConsidered, this.sourceNodes);
227220
}
228221
}
229222

@@ -395,83 +388,4 @@ private void logIterationTime(int iteration, long ms) {
395388
private void logOverallTime(long ms) {
396389
progressTracker.logMessage(formatWithLocale("Graph execution took %d ms", ms));
397390
}
398-
399-
@ValueClass
400-
public abstract static class Result {
401-
abstract HugeObjectArray<FilteredNeighborList> neighborList();
402-
403-
public abstract int ranIterations();
404-
405-
public abstract boolean didConverge();
406-
407-
public abstract long nodePairsConsidered();
408-
409-
public abstract List<Long> sourceNodes();
410-
411-
public LongStream neighborsOf(long nodeId) {
412-
return neighborList().get(nodeId).elements().map(FilteredNeighborList::clearCheckedFlag);
413-
}
414-
415-
// http://www.flatmapthatshit.com/
416-
public Stream<SimilarityResult> streamSimilarityResult() {
417-
var neighborList = neighborList();
418-
return Stream.iterate(neighborList.initCursor(neighborList.newCursor()), HugeCursor::next, UnaryOperator.identity())
419-
.flatMap(cursor -> IntStream.range(cursor.offset, cursor.limit)
420-
.filter(index -> sourceNodes().contains(index + cursor.base))
421-
.mapToObj(index -> cursor.array[index].similarityStream(index + cursor.base))
422-
.flatMap(Function.identity())
423-
);
424-
}
425-
426-
public long totalSimilarityPairs() {
427-
var neighborList = neighborList();
428-
return Stream.iterate(neighborList.initCursor(neighborList.newCursor()), HugeCursor::next, UnaryOperator.identity())
429-
.flatMapToLong(cursor -> IntStream.range(cursor.offset, cursor.limit)
430-
.filter(index -> sourceNodes().contains(index + cursor.base))
431-
.mapToLong(index -> cursor.array[index].size()))
432-
.sum();
433-
}
434-
435-
public long size() {
436-
return neighborList().size();
437-
}
438-
}
439-
440-
private static final class EmptyResult extends Result {
441-
442-
@Override
443-
HugeObjectArray<FilteredNeighborList> neighborList() {
444-
return HugeObjectArray.of();
445-
}
446-
447-
@Override
448-
public int ranIterations() {
449-
return 0;
450-
}
451-
452-
@Override
453-
public boolean didConverge() {
454-
return false;
455-
}
456-
457-
@Override
458-
public long nodePairsConsidered() {
459-
return 0;
460-
}
461-
462-
@Override
463-
public List<Long> sourceNodes() {
464-
return List.of();
465-
}
466-
467-
@Override
468-
public LongStream neighborsOf(long nodeId) {
469-
return LongStream.empty();
470-
}
471-
472-
@Override
473-
public long size() {
474-
return 0;
475-
}
476-
}
477391
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.filteredknn;
21+
22+
import org.jetbrains.annotations.NotNull;
23+
import org.neo4j.gds.annotation.ValueClass;
24+
import org.neo4j.gds.core.utils.paged.HugeCursor;
25+
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
26+
import org.neo4j.gds.similarity.SimilarityResult;
27+
28+
import java.util.List;
29+
import java.util.function.Function;
30+
import java.util.function.UnaryOperator;
31+
import java.util.stream.IntStream;
32+
import java.util.stream.LongStream;
33+
import java.util.stream.Stream;
34+
35+
@ValueClass
36+
public abstract class FilteredKnnResult {
37+
abstract HugeObjectArray<FilteredNeighborList> neighborList();
38+
39+
public abstract int ranIterations();
40+
41+
public abstract boolean didConverge();
42+
43+
public abstract long nodePairsConsidered();
44+
45+
public abstract List<Long> sourceNodes();
46+
47+
public LongStream neighborsOf(long nodeId) {
48+
return neighborList().get(nodeId).elements().map(FilteredNeighborList::clearCheckedFlag);
49+
}
50+
51+
// http://www.flatmapthatshit.com/
52+
public Stream<SimilarityResult> streamSimilarityResult() {
53+
var neighborList = neighborList();
54+
return Stream
55+
.iterate(neighborList.initCursor(neighborList.newCursor()), HugeCursor::next, UnaryOperator.identity())
56+
.flatMap(cursor -> IntStream.range(cursor.offset, cursor.limit)
57+
.filter(index -> sourceNodes().contains(index + cursor.base))
58+
.mapToObj(index -> cursor.array[index].similarityStream(index + cursor.base))
59+
.flatMap(Function.identity())
60+
);
61+
}
62+
63+
public long totalSimilarityPairs() {
64+
var neighborList = neighborList();
65+
return Stream
66+
.iterate(neighborList.initCursor(neighborList.newCursor()), HugeCursor::next, UnaryOperator.identity())
67+
.flatMapToLong(cursor -> IntStream.range(cursor.offset, cursor.limit)
68+
.filter(index -> sourceNodes().contains(index + cursor.base))
69+
.mapToLong(index -> cursor.array[index].size()))
70+
.sum();
71+
}
72+
73+
public long size() {
74+
return neighborList().size();
75+
}
76+
77+
@NotNull
78+
static FilteredKnnResult empty() {
79+
return new FilteredKnnResult() {
80+
81+
@Override
82+
HugeObjectArray<FilteredNeighborList> neighborList() {
83+
return HugeObjectArray.of();
84+
}
85+
86+
@Override
87+
public int ranIterations() {
88+
return 0;
89+
}
90+
91+
@Override
92+
public boolean didConverge() {
93+
return false;
94+
}
95+
96+
@Override
97+
public long nodePairsConsidered() {
98+
return 0;
99+
}
100+
101+
@Override
102+
public List<Long> sourceNodes() {
103+
return List.of();
104+
}
105+
106+
@Override
107+
public LongStream neighborsOf(long nodeId) {
108+
return LongStream.empty();
109+
}
110+
};
111+
}
112+
}

algo/src/test/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ void shouldHaveEachNodeConnected() {
150150
assertCorrectNeighborList(result, nodeCId, nodeAId, nodeBId);
151151
}
152152
private void assertCorrectNeighborList(
153-
FilteredKnn.Result result,
153+
FilteredKnnResult result,
154154
long nodeId,
155155
long... expectedNeighbors
156156
) {
@@ -265,7 +265,7 @@ void shouldFilterResultsOfLowSimilarity() {
265265
assertCorrectNeighborList(result, nodeEveId, nodeBobId);
266266
}
267267

268-
private void assertEmptyNeighborList(FilteredKnn.Result result, long nodeId) {
268+
private void assertEmptyNeighborList(FilteredKnnResult result, long nodeId) {
269269
var actualNeighbors = result.neighborsOf(nodeId).toArray();
270270
assertThat(actualNeighbors).isEmpty();
271271
}
@@ -288,7 +288,7 @@ void testNonExistingProperties(NodePropertyValues nodeProperties) {
288288
var result = knn.compute();
289289
assertThat(result)
290290
.isNotNull()
291-
.extracting(FilteredKnn.Result::size)
291+
.extracting(FilteredKnnResult::size)
292292
.isEqualTo(3L);
293293
}
294294

@@ -322,7 +322,7 @@ void testMixedExistingAndNonExistingProperties(SoftAssertions softly) {
322322

323323
softly.assertThat(result)
324324
.isNotNull()
325-
.extracting(FilteredKnn.Result::size)
325+
.extracting(FilteredKnnResult::size)
326326
.isEqualTo(3L);
327327

328328
long nodeAId = idFunction.of("a");

0 commit comments

Comments
 (0)