Skip to content

Commit f00c6b4

Browse files
committed
integrating target node filter
adding smoke test for target node filtering
1 parent 402934a commit f00c6b4

File tree

5 files changed

+81
-7
lines changed

5 files changed

+81
-7
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,8 @@ public final class FilteredKnn extends Algorithm<FilteredKnnResult> {
4545
private final NodeFilter sourceNodeFilter;
4646

4747
public static FilteredKnn create(Graph graph, FilteredKnnBaseConfig config, KnnContext context) {
48-
var targetNodeFilter = config.sourceNodeFilter().toNodeFilter(graph);
49-
var targetNodeFiltering = TargetNodeFiltering.create(graph.nodeCount(), config.boundedK(graph.nodeCount())/*, targetNodeFilter*/);
48+
var targetNodeFilter = config.targetNodeFilter().toNodeFilter(graph);
49+
var targetNodeFiltering = TargetNodeFiltering.create(graph.nodeCount(), config.boundedK(graph.nodeCount()), targetNodeFilter);
5050
var knn = Knn.createWithDefaultsAndInstrumentation(graph, config, context, targetNodeFiltering);
5151
var sourceNodeFilter = config.sourceNodeFilter().toNodeFilter(graph);
5252

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/TargetNodeFilter.java

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525

2626
import java.util.Comparator;
2727
import java.util.TreeSet;
28+
import java.util.function.LongPredicate;
2829
import java.util.stream.Stream;
2930

3031
/**
@@ -34,14 +35,19 @@
3435
*/
3536
public class TargetNodeFilter implements NeighbourConsumer {
3637
private final TreeSet<Pair<Double, Long>> priorityQueue = new TreeSet<>(Comparator.reverseOrder());
38+
39+
private final LongPredicate predicate;
3740
private final int bound;
3841

39-
public TargetNodeFilter(int bound) {
42+
public TargetNodeFilter(LongPredicate predicate, int bound) {
43+
this.predicate = predicate;
4044
this.bound = bound;
4145
}
4246

4347
@Override
4448
public void offer(long element, double priority) {
49+
if (! predicate.test(element)) return;
50+
4551
priorityQueue.add(Pair.of(priority, element));
4652

4753
if (priorityQueue.size() > bound) priorityQueue.pollLast();

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/TargetNodeFiltering.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,14 +34,14 @@
3434
public final class TargetNodeFiltering implements NeighbourConsumers {
3535
private final HugeObjectArray<TargetNodeFilter> neighbourConsumers;
3636

37-
static TargetNodeFiltering create(long nodeCount, int k) {
37+
static TargetNodeFiltering create(long nodeCount, int k, LongPredicate targetNodePredicate) {
3838
HugeObjectArray<TargetNodeFilter> neighbourConsumers = HugeObjectArray.newArray(
3939
TargetNodeFilter.class,
4040
nodeCount
4141
);
4242

4343
for (int i = 0; i < nodeCount; i++) {
44-
neighbourConsumers.set(i, new TargetNodeFilter(k));
44+
neighbourConsumers.set(i, new TargetNodeFilter(targetNodePredicate, k));
4545
}
4646

4747
return new TargetNodeFiltering(neighbourConsumers);

algo/src/test/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnTest.java

Lines changed: 57 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -172,4 +172,61 @@ void shouldOnlyProduceResultsForMultipleFilteredSourceNode() {
172172
).isEqualTo(Set.of(filteredNode1, filteredNode2));
173173
}
174174
}
175+
176+
@Nested
177+
class TargetNodeFiltering {
178+
@GdlGraph
179+
private static final String DB_CYPHER =
180+
"CREATE" +
181+
" (a { knn: 1.2 } )" +
182+
", (b { knn: 1.1 } )" +
183+
", (c { knn: 2.1 } )" +
184+
", (d { knn: 3.1 } )" +
185+
", (e { knn: 4.1 } )";
186+
187+
@Test
188+
void shouldOnlyProduceResultsForFilteredTargetNode() {
189+
var targetNode = idFunction.of("a");
190+
var config = FilteredKnnBaseConfigImpl.builder()
191+
.nodeProperties(List.of("knn"))
192+
.topK(3)
193+
.randomJoins(0)
194+
.maxIterations(1)
195+
.randomSeed(20L)
196+
.concurrency(1)
197+
.targetNodeFilter(targetNode)
198+
.build();
199+
var knnContext = KnnContext.empty();
200+
var knn = FilteredKnn.create(graph, config, knnContext);
201+
var result = knn.compute();
202+
203+
assertThat(result.similarityResultStream()
204+
.map(SimilarityResult::targetNodeId)
205+
.collect(Collectors.toSet())
206+
).isEqualTo(Set.of(targetNode));
207+
}
208+
209+
@Test
210+
void shouldOnlyProduceResultsForFilteredTargetNodes() {
211+
var targetNode1 = idFunction.of("a");
212+
var targetNode2 = idFunction.of("b");
213+
var config = FilteredKnnBaseConfigImpl.builder()
214+
.nodeProperties("knn")
215+
.topK(3)
216+
.randomJoins(0)
217+
.maxIterations(1)
218+
.randomSeed(20L)
219+
.concurrency(1)
220+
.targetNodeFilter(List.of(targetNode1, targetNode2))
221+
.build();
222+
var knnContext = KnnContext.empty();
223+
var knn = FilteredKnn.create(graph, config, knnContext);
224+
var result = knn.compute();
225+
226+
assertThat(result.similarityResultStream()
227+
.map(SimilarityResult::targetNodeId)
228+
.collect(Collectors.toSet())
229+
).isEqualTo(Set.of(targetNode1, targetNode2));
230+
}
231+
}
175232
}

algo/src/test/java/org/neo4j/gds/similarity/filteredknn/TargetNodeFilterTest.java

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
class TargetNodeFilterTest {
2828
@Test
2929
void shouldPrioritiseTargetNodes() {
30-
TargetNodeFilter consumer = new TargetNodeFilter(3);
30+
TargetNodeFilter consumer = new TargetNodeFilter(l -> true, 3);
3131

3232
consumer.offer(23, 3.14);
3333
consumer.offer(42, 1.61);
@@ -42,7 +42,7 @@ void shouldPrioritiseTargetNodes() {
4242

4343
@Test
4444
void shouldOnlyKeepTopK() {
45-
TargetNodeFilter consumer = new TargetNodeFilter(2);
45+
TargetNodeFilter consumer = new TargetNodeFilter(l -> true, 2);
4646

4747
consumer.offer(23, 3.14);
4848
consumer.offer(42, 1.61);
@@ -53,4 +53,15 @@ void shouldOnlyKeepTopK() {
5353
new SimilarityResult(117, 87, 2.71)
5454
);
5555
}
56+
57+
@Test
58+
void shouldOnlyIncludeTargetNodes() {
59+
TargetNodeFilter consumer = new TargetNodeFilter(l -> false, 3);
60+
61+
consumer.offer(23, 3.14);
62+
consumer.offer(42, 1.61);
63+
consumer.offer(87, 2.71);
64+
65+
assertThat(consumer.asSimilarityStream(117)).isEmpty();
66+
}
5667
}

0 commit comments

Comments
 (0)