Skip to content

Commit b4d8424

Browse files
committed
show bug in source node filtering
1 parent 199ebaf commit b4d8424

File tree

2 files changed

+113
-2
lines changed

2 files changed

+113
-2
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnResult.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,12 @@ public LongStream neighborsOf(long nodeId) {
5050

5151
// http://www.flatmapthatshit.com/
5252
public Stream<SimilarityResult> streamSimilarityResult() {
53+
// [[],[],[]]
5354
var neighborList = neighborList();
5455
return Stream
5556
.iterate(neighborList.initCursor(neighborList.newCursor()), HugeCursor::next, UnaryOperator.identity())
5657
.flatMap(cursor -> IntStream.range(cursor.offset, cursor.limit)
57-
.filter(index -> sourceNodes().contains(index + cursor.base))
58+
// .filter(index -> sourceNodes().contains(index + cursor.base))
5859
.mapToObj(index -> cursor.array[index].similarityStream(index + cursor.base))
5960
.flatMap(Function.identity())
6061
);
@@ -65,7 +66,7 @@ public long totalSimilarityPairs() {
6566
return Stream
6667
.iterate(neighborList.initCursor(neighborList.newCursor()), HugeCursor::next, UnaryOperator.identity())
6768
.flatMapToLong(cursor -> IntStream.range(cursor.offset, cursor.limit)
68-
.filter(index -> sourceNodes().contains(index + cursor.base))
69+
// .filter(index -> sourceNodes().contains(index + cursor.base))
6970
.mapToLong(index -> cursor.array[index].size()))
7071
.sum();
7172
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.filteredknn;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
24+
import org.neo4j.gds.similarity.SimilarityResult;
25+
26+
import java.util.Arrays;
27+
import java.util.List;
28+
import java.util.SplittableRandom;
29+
import java.util.function.Function;
30+
import java.util.stream.Collectors;
31+
import java.util.stream.LongStream;
32+
import java.util.stream.Stream;
33+
34+
import static org.assertj.core.api.Assertions.assertThat;
35+
36+
class FilteredKnnResultTest {
37+
38+
@Test
39+
void should() {
40+
var rng = new SplittableRandom();
41+
42+
var neighbors0 = new FilteredNeighborList(1);
43+
neighbors0.add(1, 0.3, rng, 0.0);
44+
45+
var neighbors1 = new FilteredNeighborList(1);
46+
neighbors1.add(0, 0.7, rng, 0.0);
47+
48+
var result = ImmutableFilteredKnnResult.of(
49+
HugeObjectArray.of(neighbors0, neighbors1),
50+
1,
51+
true,
52+
2,
53+
List.of()
54+
);
55+
56+
var neighborLists = result.neighborList();
57+
58+
// test1
59+
var other = Stream.concat(
60+
neighborLists.get(0).similarityStream(0),
61+
neighborLists.get(1).similarityStream(1)
62+
).collect(Collectors.toList());
63+
var actual = result.streamSimilarityResult().collect(Collectors.toList());
64+
assertThat(other).isEqualTo(actual);
65+
assertThat(actual).isEqualTo(other);
66+
67+
68+
// test2
69+
var resultString1 = result.streamSimilarityResult()
70+
.map(this::formatRecord)
71+
.collect(Collectors.toList())
72+
.toString();
73+
74+
var resultString2 = List.of(
75+
neighborLists.get(0).similarityStream(0).findFirst().map(this::formatRecord).get(),
76+
neighborLists.get(1).similarityStream(1).findFirst().map(this::formatRecord).get()
77+
).toString();
78+
79+
// those two should produce same result
80+
assertThat(resultString1).isEqualTo(resultString2);
81+
82+
83+
// test3
84+
var resultString3 = List.of(
85+
formatNeighborStream(result.neighborsOf(0)),
86+
formatNeighborStream(result.neighborsOf(1))
87+
).toString();
88+
89+
var resultString4 = List.of(
90+
formatNeighborStream(result.streamSimilarityResult().filter(sr -> sr.node1 == 0).mapToLong(sr -> sr.node2)),
91+
formatNeighborStream(result.streamSimilarityResult().filter(sr -> sr.node1 == 1).mapToLong(sr -> sr.node2))
92+
).toString();
93+
94+
var resultString5 = List.of(
95+
formatNeighborStream(neighborLists.get(0).similarityStream(0).mapToLong(sr -> sr.node2)),
96+
formatNeighborStream(neighborLists.get(1).similarityStream(1).mapToLong(sr -> sr.node2))
97+
).toString();
98+
99+
assertThat(resultString3).isEqualTo(resultString4).isEqualTo(resultString5);
100+
}
101+
102+
private String formatNeighborStream(LongStream stream) {
103+
return Arrays.toString(stream.toArray());
104+
}
105+
106+
private String formatRecord(SimilarityResult sr) {
107+
return String.format("%d,%d %f", sr.node1, sr.node2, sr.similarity);
108+
}
109+
110+
}

0 commit comments

Comments
 (0)