Skip to content

Commit 3a1aa53

Browse files
committed
add instrumentation to KNN so we can spy on visited neighbours
1 parent 62c99fb commit 3a1aa53

File tree

8 files changed

+58
-28
lines changed

8 files changed

+58
-28
lines changed

algo/src/main/java/org/neo4j/gds/similarity/knn/GenerateRandomNeighbors.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ final class GenerateRandomNeighbors implements Runnable {
3939
private final int boundedK;
4040
private final ProgressTracker progressTracker;
4141
private final Partition partition;
42+
private final NeighbourConsumers neighbourConsumers;
43+
4244
private long neighborsFound;
4345

4446
GenerateRandomNeighbors(
@@ -50,7 +52,8 @@ final class GenerateRandomNeighbors implements Runnable {
5052
int k,
5153
int boundedK,
5254
Partition partition,
53-
ProgressTracker progressTracker
55+
ProgressTracker progressTracker,
56+
NeighbourConsumers neighbourConsumers
5457
) {
5558
this.sampler = sampler;
5659
this.random = random;
@@ -62,6 +65,7 @@ final class GenerateRandomNeighbors implements Runnable {
6265
this.progressTracker = progressTracker;
6366
this.partition = partition;
6467
this.neighborsFound = 0;
68+
this.neighbourConsumers = neighbourConsumers;
6569
}
6670

6771
@Override
@@ -80,7 +84,7 @@ public void run() {
8084
l -> neighborFilter.excludeNodePair(nodeId, l)
8185
);
8286

83-
var neighbors = new NeighborList(k);
87+
var neighbors = new NeighborList(k, neighbourConsumers.get(nodeId));
8488
for (long candidate : chosen) {
8589
neighbors.add(candidate, computer.safeSimilarity(nodeId, candidate), rng, 0.0);
8690
}

algo/src/main/java/org/neo4j/gds/similarity/knn/Knn.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,30 @@ public class Knn extends Algorithm<Knn.Result> {
5353
private final ExecutorService executorService;
5454
private final SplittableRandom splittableRandom;
5555
private final SimilarityComputer similarityComputer;
56+
private final NeighbourConsumers neighborConsumers;
5657

5758
private long nodePairsConsidered;
5859

5960
public static Knn createWithDefaults(Graph graph, KnnBaseConfig config, KnnContext context) {
61+
return createWithDefaultsss(graph, config, context, NeighbourConsumers.devNull);
62+
}
63+
64+
@NotNull
65+
public static Knn createWithDefaultsss(
66+
Graph graph,
67+
KnnBaseConfig config,
68+
KnnContext context,
69+
NeighbourConsumers neighborConsumers
70+
) {
6071
return new Knn(
6172
context.progressTracker(),
6273
graph,
6374
config,
6475
SimilarityComputer.ofProperties(graph, config.nodeProperties()),
6576
new KnnNeighborFilterFactory(graph.nodeCount()),
6677
context.executor(),
67-
getSplittableRandom(config.randomSeed())
78+
getSplittableRandom(config.randomSeed()),
79+
neighborConsumers
6880
);
6981
}
7082

@@ -83,7 +95,8 @@ public static Knn create(
8395
similarityComputer,
8496
neighborFilterFactory,
8597
context.executor(),
86-
splittableRandom
98+
splittableRandom,
99+
NeighbourConsumers.devNull
87100
);
88101
}
89102

@@ -99,7 +112,8 @@ private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
99112
SimilarityComputer similarityComputer,
100113
NeighborFilterFactory neighborFilterFactory,
101114
ExecutorService executorService,
102-
SplittableRandom splittableRandom
115+
SplittableRandom splittableRandom,
116+
NeighbourConsumers neighborConsumers
103117
) {
104118
super(progressTracker);
105119
this.graph = graph;
@@ -108,6 +122,7 @@ private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
108122
this.neighborFilterFactory = neighborFilterFactory;
109123
this.executorService = executorService;
110124
this.splittableRandom = splittableRandom;
125+
this.neighborConsumers = neighborConsumers;
111126
}
112127

113128
public long nodeCount() {
@@ -202,7 +217,8 @@ public void release() {
202217
k,
203218
boundedK,
204219
partition,
205-
progressTracker
220+
progressTracker,
221+
neighborConsumers
206222
);
207223
},
208224
Optional.of(config.minBatchSize())

algo/src/main/java/org/neo4j/gds/similarity/knn/NeighborList.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ static boolean isChecked(long value) {
8181
static final int NOT_INSERTED = 0;
8282
private static final int INSERTED = 1;
8383

84+
// a listener that sees every neighbour someone tries to add
85+
private final NeighbourConsumer neighbourConsumer;
86+
8487
// maximum number of elements, aka the top K
8588
private final int elementCapacity;
8689
// currently stored number of elements
@@ -90,13 +93,14 @@ static boolean isChecked(long value) {
9093
// every item occupies two entries in the array, [ doubleToLongBits(priority), element ]
9194
private final long[] priorityElementPairs;
9295

93-
NeighborList(int elementCapacity) {
96+
NeighborList(int elementCapacity, NeighbourConsumer neighbourConsumer) {
9497
if (elementCapacity <= 0) {
9598
throw new IllegalArgumentException("Bound cannot be smaller than or equal to 0");
9699
}
97100

98101
this.elementCapacity = elementCapacity;
99102
this.priorityElementPairs = new long[elementCapacity * 2];
103+
this.neighbourConsumer = neighbourConsumer;
100104
}
101105

102106
public LongStream elements() {
@@ -125,6 +129,8 @@ long getAndFlagAsChecked(int index) {
125129
* This allows KNN to just add the return values together without having the check on each of them.
126130
*/
127131
public long add(long element, double priority, SplittableRandom random, double perturbationRate) {
132+
neighbourConsumer.offer(element, priority);
133+
128134
int insertIdx = 0;
129135
int currNumElementsWithPriority = elementCount * 2;
130136

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredNeighborFilterFactory.java renamed to algo/src/main/java/org/neo4j/gds/similarity/knn/NeighbourConsumer.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,13 @@
1717
* You should have received a copy of the GNU General Public License
1818
* along with this program. If not, see <http://www.gnu.org/licenses/>.
1919
*/
20-
package org.neo4j.gds.similarity.filteredknn;
20+
package org.neo4j.gds.similarity.knn;
2121

22-
public interface FilteredNeighborFilterFactory {
23-
FilteredNeighborFilter create();
22+
/**
23+
* During KNN execution, this listener is offered every neighbour we encounter.
24+
*/
25+
public interface NeighbourConsumer {
26+
NeighbourConsumer devNull = (element, priority) -> { /* do nothing */ };
27+
28+
void offer(long element, double priority);
2429
}

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnNeighborFilterFactory.java renamed to algo/src/main/java/org/neo4j/gds/similarity/knn/NeighbourConsumers.java

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,16 @@
1717
* You should have received a copy of the GNU General Public License
1818
* along with this program. If not, see <http://www.gnu.org/licenses/>.
1919
*/
20-
package org.neo4j.gds.similarity.filteredknn;
20+
package org.neo4j.gds.similarity.knn;
2121

22-
public class FilteredKnnNeighborFilterFactory implements FilteredNeighborFilterFactory {
23-
24-
private final long nodeCount;
25-
26-
public FilteredKnnNeighborFilterFactory(long nodeCount) {
27-
this.nodeCount = nodeCount;
28-
}
22+
/**
23+
* A holder for {@link org.neo4j.gds.similarity.knn.NeighbourConsumer}s. This instrument helps us extend KNN.
24+
*/
25+
public interface NeighbourConsumers {
26+
/**
27+
* A holder for sending data into the void, which is the default behaviour in regular KNN
28+
*/
29+
NeighbourConsumers devNull = nodeId -> (element, priority) -> { /* do nothing */ };
2930

30-
@Override
31-
public FilteredNeighborFilter create() {
32-
return new FilteredKnnNeighborFilter(nodeCount);
33-
}
31+
NeighbourConsumer get(long nodeId);
3432
}

algo/src/test/java/org/neo4j/gds/similarity/knn/GenerateRandomNeighborsTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ public long size() {
8080
k,
8181
k,
8282
Partition.of(0, nodeCount),
83-
ProgressTracker.NULL_TRACKER
83+
ProgressTracker.NULL_TRACKER,
84+
NeighbourConsumers.devNull
8485
);
8586

8687
generateRandomNeighbors.run();

algo/src/test/java/org/neo4j/gds/similarity/knn/NeighborListTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class NeighborListTest {
3636
void shouldKeepMaxValuesOrderedByPriority() {
3737
long[] expected = {6L, 5L, 2L};
3838

39-
NeighborList queue = new NeighborList(3);
39+
NeighborList queue = new NeighborList(3, NeighbourConsumer.devNull);
4040
SplittableRandom splittableRandom = new SplittableRandom();
4141

4242
assertEquals(1, queue.add(0, Double.MIN_VALUE, splittableRandom, 0.0));
@@ -55,7 +55,7 @@ void shouldKeepMaxValuesOrderedByPriority() {
5555
void shouldLimitReturnWhenNotFull() {
5656
long[] expected = {6L, 5L, 4L};
5757

58-
NeighborList queue = new NeighborList(10);
58+
NeighborList queue = new NeighborList(10, NeighbourConsumer.devNull);
5959
SplittableRandom splittableRandom = new SplittableRandom();
6060

6161
assertEquals(1, queue.add(6, 6.0, splittableRandom, 0.0));
@@ -72,7 +72,7 @@ void insertEverything() {
7272
var elements = LongStream.range(0, nodeCount).boxed().collect(Collectors.toList());
7373
var rng = new SplittableRandom(1337L);
7474

75-
var queue = new NeighborList(nodeCount);
75+
var queue = new NeighborList(nodeCount, NeighbourConsumer.devNull);
7676

7777
elements.forEach(candidate -> queue.add(candidate, 1.0 / (1.0 + Math.abs(candidate - 2)), rng, 0.0));
7878

@@ -82,7 +82,7 @@ void insertEverything() {
8282
@Test
8383
void insertEveryThingTake2() {
8484
List<Long> elements = List.of(0L, 2L);
85-
var queue = new NeighborList(2);
85+
var queue = new NeighborList(2, NeighbourConsumer.devNull);
8686
var rng = new SplittableRandom(1337L);
8787

8888
elements.forEach(candidate -> queue.add(candidate, 1.0 / (1.0 + Math.abs(candidate - 1)), rng, 0.0));

algo/src/test/java/org/neo4j/gds/similarity/knn/SplitOldAndNewNeighborsTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ void name(
5050

5151
SplittableRandom rng = new SplittableRandom();
5252
allNeighbors.setAll(nodeId -> {
53-
var neighbors = new NeighborList(k);
53+
var neighbors = new NeighborList(k, NeighbourConsumer.devNull);
5454
LongStream.concat(
5555
LongStream.range(nodeId + 1, nodeCount),
5656
LongStream.range(0, nodeId)

0 commit comments

Comments
 (0)