2323import org .jetbrains .annotations .NotNull ;
2424import org .jetbrains .annotations .Nullable ;
2525import org .neo4j .gds .Algorithm ;
26- import org .neo4j .gds .annotation .ValueClass ;
2726import org .neo4j .gds .api .Graph ;
2827import org .neo4j .gds .core .concurrency .ParallelUtil ;
2928import org .neo4j .gds .core .utils .ProgressTimer ;
30- import org .neo4j .gds .core .utils .paged .HugeCursor ;
3129import org .neo4j .gds .core .utils .paged .HugeObjectArray ;
3230import org .neo4j .gds .core .utils .partition .PartitionUtils ;
3331import org .neo4j .gds .core .utils .progress .tasks .ProgressTracker ;
34- import org .neo4j .gds .similarity .SimilarityResult ;
3532import org .neo4j .gds .similarity .knn .metrics .SimilarityComputer ;
3633
3734import java .util .List ;
3835import java .util .Optional ;
3936import java .util .SplittableRandom ;
4037import java .util .concurrent .ExecutorService ;
4138import java .util .function .Function ;
42- import java .util .function .UnaryOperator ;
4339import java .util .stream .Collectors ;
44- import java .util .stream .IntStream ;
45- import java .util .stream .LongStream ;
46- import java .util .stream .Stream ;
4740
4841import static org .neo4j .gds .utils .StringFormatting .formatWithLocale ;
4942
50- public class FilteredKnn extends Algorithm <FilteredKnn . Result > {
43+ public class FilteredKnn extends Algorithm <FilteredKnnResult > {
5144 private final Graph graph ;
5245 private final FilteredNeighborFilterFactory neighborFilterFactory ;
5346 private final ExecutorService executorService ;
@@ -177,7 +170,7 @@ public ExecutorService executorService() {
177170 }
178171
179172 @ Override
180- public Result compute () {
173+ public FilteredKnnResult compute () {
181174 this .progressTracker .beginSubTask ();
182175 HugeObjectArray <FilteredNeighborList > neighbors ;
183176 try (var ignored1 = ProgressTimer .start (this ::logOverallTime )) {
@@ -187,7 +180,7 @@ public Result compute() {
187180 this .progressTracker .endSubTask ();
188181 }
189182 if (neighbors == null ) {
190- return new EmptyResult ();
183+ return FilteredKnnResult . empty ();
191184 }
192185
193186 var maxUpdates = (long ) Math .ceil (this .sampleRate * this .topK * graph .nodeCount ());
@@ -223,7 +216,7 @@ public Result compute() {
223216 this .progressTracker .endSubTask ();
224217
225218 this .progressTracker .endSubTask ();
226- return ImmutableResult .of (neighbors , iteration , didConverge , this .nodePairsConsidered , this .sourceNodes );
219+ return ImmutableFilteredKnnResult .of (neighbors , iteration , didConverge , this .nodePairsConsidered , this .sourceNodes );
227220 }
228221 }
229222
@@ -395,83 +388,4 @@ private void logIterationTime(int iteration, long ms) {
395388 private void logOverallTime (long ms ) {
396389 progressTracker .logMessage (formatWithLocale ("Graph execution took %d ms" , ms ));
397390 }
398-
399- @ ValueClass
400- public abstract static class Result {
401- abstract HugeObjectArray <FilteredNeighborList > neighborList ();
402-
403- public abstract int ranIterations ();
404-
405- public abstract boolean didConverge ();
406-
407- public abstract long nodePairsConsidered ();
408-
409- public abstract List <Long > sourceNodes ();
410-
411- public LongStream neighborsOf (long nodeId ) {
412- return neighborList ().get (nodeId ).elements ().map (FilteredNeighborList ::clearCheckedFlag );
413- }
414-
415- // http://www.flatmapthatshit.com/
416- public Stream <SimilarityResult > streamSimilarityResult () {
417- var neighborList = neighborList ();
418- return Stream .iterate (neighborList .initCursor (neighborList .newCursor ()), HugeCursor ::next , UnaryOperator .identity ())
419- .flatMap (cursor -> IntStream .range (cursor .offset , cursor .limit )
420- .filter (index -> sourceNodes ().contains (index + cursor .base ))
421- .mapToObj (index -> cursor .array [index ].similarityStream (index + cursor .base ))
422- .flatMap (Function .identity ())
423- );
424- }
425-
426- public long totalSimilarityPairs () {
427- var neighborList = neighborList ();
428- return Stream .iterate (neighborList .initCursor (neighborList .newCursor ()), HugeCursor ::next , UnaryOperator .identity ())
429- .flatMapToLong (cursor -> IntStream .range (cursor .offset , cursor .limit )
430- .filter (index -> sourceNodes ().contains (index + cursor .base ))
431- .mapToLong (index -> cursor .array [index ].size ()))
432- .sum ();
433- }
434-
435- public long size () {
436- return neighborList ().size ();
437- }
438- }
439-
440- private static final class EmptyResult extends Result {
441-
442- @ Override
443- HugeObjectArray <FilteredNeighborList > neighborList () {
444- return HugeObjectArray .of ();
445- }
446-
447- @ Override
448- public int ranIterations () {
449- return 0 ;
450- }
451-
452- @ Override
453- public boolean didConverge () {
454- return false ;
455- }
456-
457- @ Override
458- public long nodePairsConsidered () {
459- return 0 ;
460- }
461-
462- @ Override
463- public List <Long > sourceNodes () {
464- return List .of ();
465- }
466-
467- @ Override
468- public LongStream neighborsOf (long nodeId ) {
469- return LongStream .empty ();
470- }
471-
472- @ Override
473- public long size () {
474- return 0 ;
475- }
476- }
477391}
0 commit comments