Skip to content

Commit b63e7a0

Browse files
committed
implementing a simple bounded priority queue for target node filtering
shortening names
1 parent 5fa1b40 commit b63e7a0

File tree

5 files changed

+112
-23
lines changed

5 files changed

+112
-23
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,33 +25,46 @@
2525
import org.neo4j.gds.similarity.knn.Knn;
2626
import org.neo4j.gds.similarity.knn.KnnContext;
2727

28+
<<<<<<< HEAD
29+
=======
30+
import java.util.List;
31+
32+
/**
33+
* Filtered KNN is the same as ordinary KNN, _but_ we allow users to regulate final output in two ways.
34+
*
35+
* Firstly, we enable source node filtering, meaning reported results are limited nodes from a certain set.
36+
*
37+
* Secondly, we enable target node filtering in the sense that every result will be from a certain set of nodes.
38+
*
39+
* In both cases the source or target node set can be actual specified nodes, or it could be all nodes with a label.
40+
*/
2841
public final class FilteredKnn extends Algorithm<FilteredKnnResult> {
2942
/**
3043
* This is KNN instrumented with neighbour consumers
3144
*/
3245
private final Knn delegate;
3346

34-
private final TargetNodeFilteringNeighbourConsumers neighbourConsumers;
47+
private final TargetNodeFiltering targetNodeFiltering;
3548
private final NodeFilter sourceNodeFilter;
3649

3750
public static FilteredKnn create(Graph graph, FilteredKnnBaseConfig config, KnnContext context) {
3851
var targetNodeFilter = config.sourceNodeFilter().toNodeFilter(graph);
39-
var neighbourConsumers = TargetNodeFilteringNeighbourConsumers.create(graph.nodeCount()/*, targetNodeFilter*/);
40-
var knn = Knn.createWithDefaultsAndInstrumentation(graph, config, context, neighbourConsumers);
52+
var targetNodeFiltering = TargetNodeFiltering.create(graph.nodeCount(), config.boundedK(graph.nodeCount())/*, targetNodeFilter*/);
53+
var knn = Knn.createWithDefaultsAndInstrumentation(graph, config, context, targetNodeFiltering);
4154
var sourceNodeFilter = config.sourceNodeFilter().toNodeFilter(graph);
4255

43-
return new FilteredKnn(context.progressTracker(), knn, neighbourConsumers, sourceNodeFilter);
56+
return new FilteredKnn(context.progressTracker(), knn, targetNodeFiltering, sourceNodeFilter);
4457
}
4558

4659
private FilteredKnn(
4760
ProgressTracker progressTracker,
4861
Knn delegate,
49-
TargetNodeFilteringNeighbourConsumers neighbourConsumers,
62+
TargetNodeFiltering targetNodeFiltering,
5063
NodeFilter sourceNodeFilter
5164
) {
5265
super(progressTracker);
5366
this.delegate = delegate;
54-
this.neighbourConsumers = neighbourConsumers;
67+
this.targetNodeFiltering = targetNodeFiltering;
5568
this.sourceNodeFilter = sourceNodeFilter;
5669
}
5770

@@ -60,7 +73,7 @@ public FilteredKnnResult compute() {
6073
Knn.Result result = delegate.compute();
6174

6275
return ImmutableFilteredKnnResult.of(
63-
neighbourConsumers,
76+
targetNodeFiltering,
6477
result.ranIterations(),
6578
result.didConverge(),
6679
result.nodePairsConsidered(),

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnResult.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
@ValueClass
2929
public abstract class FilteredKnnResult {
30-
abstract TargetNodeFilteringNeighbourConsumers neighbourConsumers();
30+
abstract TargetNodeFiltering neighbourConsumers();
3131

3232
public abstract int ranIterations();
3333

@@ -38,7 +38,7 @@ public abstract class FilteredKnnResult {
3838
public abstract List<Long> sourceNodes();
3939

4040
public Stream<SimilarityResult> similarityResultStream() {
41-
TargetNodeFilteringNeighbourConsumers neighbourConsumers = neighbourConsumers();
41+
TargetNodeFiltering neighbourConsumers = neighbourConsumers();
4242
List<Long> sourceNodes = sourceNodes();
4343

4444
return neighbourConsumers.asSimilarityResultStream(sourceNodes::contains);

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/TargetNodeFilteringNeighbourConsumer.java renamed to algo/src/main/java/org/neo4j/gds/similarity/filteredknn/TargetNodeFilter.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,38 @@
1919
*/
2020
package org.neo4j.gds.similarity.filteredknn;
2121

22+
import org.apache.commons.lang3.tuple.Pair;
2223
import org.neo4j.gds.similarity.SimilarityResult;
2324
import org.neo4j.gds.similarity.knn.NeighbourConsumer;
2425

26+
import java.util.Comparator;
27+
import java.util.TreeSet;
2528
import java.util.stream.Stream;
2629

27-
public class TargetNodeFilteringNeighbourConsumer implements NeighbourConsumer {
30+
/**
31+
* We sort results by score, descending.
32+
*
33+
* For now a simple bounded priority queue that does _not_ handle duplicates.
34+
*/
35+
public class TargetNodeFilter implements NeighbourConsumer {
36+
private final TreeSet<Pair<Double, Long>> priorityQueue = new TreeSet<>(Comparator.reverseOrder());
37+
private final int bound;
38+
39+
public TargetNodeFilter(int bound) {
40+
this.bound = bound;
41+
}
42+
2843
@Override
2944
public void offer(long element, double priority) {
30-
throw new UnsupportedOperationException("TODO");
45+
priorityQueue.add(Pair.of(priority, element));
46+
47+
if (priorityQueue.size() > bound) priorityQueue.pollLast();
3148
}
3249

50+
/**
51+
* As part of an instrumentation of KNN this is a handy utility.
52+
*/
3353
Stream<SimilarityResult> asSimilarityStream(long nodeId) {
34-
throw new UnsupportedOperationException("TODO");
54+
return priorityQueue.stream().map(p -> new SimilarityResult(nodeId, p.getRight(), p.getLeft()));
3555
}
3656
}

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/TargetNodeFilteringNeighbourConsumers.java renamed to algo/src/main/java/org/neo4j/gds/similarity/filteredknn/TargetNodeFiltering.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,24 @@
3131
import java.util.stream.IntStream;
3232
import java.util.stream.Stream;
3333

34-
public class TargetNodeFilteringNeighbourConsumers implements NeighbourConsumers {
35-
private final HugeObjectArray<TargetNodeFilteringNeighbourConsumer> neighbourConsumers;
34+
public final class TargetNodeFiltering implements NeighbourConsumers {
35+
private final HugeObjectArray<TargetNodeFilter> neighbourConsumers;
3636

37-
public TargetNodeFilteringNeighbourConsumers(HugeObjectArray<TargetNodeFilteringNeighbourConsumer> neighbourConsumers) {
38-
this.neighbourConsumers = neighbourConsumers;
39-
}
40-
41-
static TargetNodeFilteringNeighbourConsumers create(long nodeCount) {
42-
HugeObjectArray<TargetNodeFilteringNeighbourConsumer> neighbourConsumers = HugeObjectArray.newArray(
43-
TargetNodeFilteringNeighbourConsumer.class,
37+
static TargetNodeFiltering create(long nodeCount, int k) {
38+
HugeObjectArray<TargetNodeFilter> neighbourConsumers = HugeObjectArray.newArray(
39+
TargetNodeFilter.class,
4440
nodeCount
4541
);
4642

4743
for (int i = 0; i < nodeCount; i++) {
48-
neighbourConsumers.set(i, new TargetNodeFilteringNeighbourConsumer());
44+
neighbourConsumers.set(i, new TargetNodeFilter(k));
4945
}
5046

51-
return new TargetNodeFilteringNeighbourConsumers(neighbourConsumers);
47+
return new TargetNodeFiltering(neighbourConsumers);
48+
}
49+
50+
private TargetNodeFiltering(HugeObjectArray<TargetNodeFilter> neighbourConsumers) {
51+
this.neighbourConsumers = neighbourConsumers;
5252
}
5353

5454
@Override
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.filteredknn;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.similarity.SimilarityResult;
24+
25+
import static org.assertj.core.api.Assertions.assertThat;
26+
27+
class TargetNodeFilterTest {
28+
@Test
29+
void shouldPrioritiseTargetNodes() {
30+
TargetNodeFilter consumer = new TargetNodeFilter(3);
31+
32+
consumer.offer(23, 3.14);
33+
consumer.offer(42, 1.61);
34+
consumer.offer(87, 2.71);
35+
36+
assertThat(consumer.asSimilarityStream(117)).containsExactly(
37+
new SimilarityResult(117, 23, 3.14),
38+
new SimilarityResult(117, 87, 2.71),
39+
new SimilarityResult(117, 42, 1.61)
40+
);
41+
}
42+
43+
@Test
44+
void shouldOnlyKeepTopK() {
45+
TargetNodeFilter consumer = new TargetNodeFilter(2);
46+
47+
consumer.offer(23, 3.14);
48+
consumer.offer(42, 1.61);
49+
consumer.offer(87, 2.71);
50+
51+
assertThat(consumer.asSimilarityStream(117)).containsExactly(
52+
new SimilarityResult(117, 23, 3.14),
53+
new SimilarityResult(117, 87, 2.71)
54+
);
55+
}
56+
}

0 commit comments

Comments
 (0)