Skip to content

Commit 0c84d82

Browse files
committed
First implementation of filter based on label
1 parent 9e2198a commit 0c84d82

File tree

2 files changed

+72
-5
lines changed

2 files changed

+72
-5
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/NodeFilter.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.similarity.filteredknn;
2121

22+
import org.neo4j.gds.NodeLabel;
2223
import org.neo4j.gds.api.IdMap;
2324
import org.neo4j.graphdb.Node;
2425

@@ -39,7 +40,7 @@ public static NodeFilter create(Object input, IdMap idMap) {
3940

4041
if (input instanceof String) {
4142
// parse as label
42-
return parseFromString((String) input);
43+
return parseFromString((String) input, idMap);
4344
}
4445

4546
Set<Long> nodeIds = null;
@@ -65,8 +66,8 @@ public static NodeFilter create(Object input, IdMap idMap) {
6566
return new NodeFilter(nodeIds);
6667
}
6768

68-
private static NodeFilter parseFromString(String input) {
69-
throw new UnsupportedOperationException("Not implemented yet");
69+
private static NodeFilter parseFromString(String input, IdMap idMap) {
70+
return new NodeFilter(input, idMap);
7071
}
7172

7273
private static Set<Long> parseFromLong(Long input, IdMap idMap) {
@@ -109,14 +110,26 @@ public static NodeFilter noOp() {
109110
}
110111

111112
private final Set<Long> nodeIds;
113+
private final NodeLabel nodeLabel;
114+
private final IdMap idMap;
112115

113116
private NodeFilter(Set<Long> nodeIds) {
114117
this.nodeIds = nodeIds;
118+
this.nodeLabel = null;
119+
this.idMap = null;
120+
}
121+
122+
private NodeFilter(String labelString, IdMap idMap) {
123+
this.nodeIds = null;
124+
this.nodeLabel = NodeLabel.of(labelString);
125+
this.idMap = idMap;
115126
}
116127

117128
@Override
118129
public boolean test(long nodeId) {
119-
return nodeIds.contains(nodeId);
130+
return null == nodeIds
131+
? idMap.hasLabel(nodeId, nodeLabel)
132+
: nodeIds.contains(nodeId);
120133
}
121134

122135
private static class NoOpNodeFilter extends NodeFilter {

algo/src/test/java/org/neo4j/gds/similarity/filteredknn/NodeFilterTest.java

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,19 @@
1919
*/
2020
package org.neo4j.gds.similarity.filteredknn;
2121

22+
import com.carrotsearch.hppc.IntObjectHashMap;
23+
import com.carrotsearch.hppc.IntObjectMap;
2224
import org.junit.jupiter.api.Test;
25+
import org.neo4j.gds.NodeLabel;
2326
import org.neo4j.gds.core.huge.DirectIdMap;
27+
import org.neo4j.gds.core.loading.ArrayIdMapBuilder;
28+
import org.neo4j.gds.core.loading.LabelInformation;
2429

30+
import java.util.Collections;
2531
import java.util.List;
2632

2733
import static org.assertj.core.api.Assertions.assertThat;
34+
import static org.assertj.core.api.Assertions.assertThatNoException;
2835
import static org.assertj.core.api.Assertions.assertThatThrownBy;
2936

3037
class NodeFilterTest {
@@ -43,12 +50,59 @@ void shouldFailToParseInvalidInput() {
4350
.hasMessage("Invalid types in list. Expected Longs or Nodes but found [Double]");
4451

4552
// String is valid as scalar but not in a list
46-
// assertThatNoException().isThrownBy(() -> NodeFilter.create("foo", idMap)); // Not implemented yet
53+
assertThatNoException().isThrownBy(() -> NodeFilter.create("foo", idMap));
4754
assertThatThrownBy(() -> NodeFilter.create(List.of(validInput, "foo"), idMap))
4855
.isInstanceOf(IllegalArgumentException.class)
4956
.hasMessage("Invalid types in list. Expected Longs or Nodes but found [String]");
5057
}
5158

59+
@Test
60+
void shouldFilterBasedOnLabel() {
61+
// set up an idMap with four nodes and two labels
62+
var labelOne = NodeLabel.of("one");
63+
var labelTwo = NodeLabel.of("two");
64+
65+
IntObjectMap<List<NodeLabel>> labelTokenNodeLabelMappings = new IntObjectHashMap<List<NodeLabel>>();
66+
labelTokenNodeLabelMappings.put(1, Collections.singletonList(labelOne));
67+
labelTokenNodeLabelMappings.put(2, Collections.singletonList(labelTwo));
68+
69+
var labelInformationBuilder = LabelInformation.builder(4, labelTokenNodeLabelMappings);
70+
labelInformationBuilder.addNodeIdToLabel(labelOne, 0);
71+
labelInformationBuilder.addNodeIdToLabel(labelTwo, 1);
72+
labelInformationBuilder.addNodeIdToLabel(labelOne, 2);
73+
labelInformationBuilder.addNodeIdToLabel(labelTwo, 3);
74+
75+
var arrayIdMapBuilder = ArrayIdMapBuilder.of(4);
76+
arrayIdMapBuilder.allocate(4);
77+
var graphIds = arrayIdMapBuilder.array();
78+
graphIds.set(0, 0);
79+
graphIds.set(1, 1);
80+
graphIds.set(2, 2);
81+
graphIds.set(3, 3);
82+
var arrayIdMap = arrayIdMapBuilder.build(labelInformationBuilder, 3, 1);
83+
84+
// test that the idMap is as expected
85+
assertThat(arrayIdMap.hasLabel(0, labelOne)).isTrue();
86+
assertThat(arrayIdMap.hasLabel(1, labelTwo)).isTrue();
87+
assertThat(arrayIdMap.hasLabel(2, labelOne)).isTrue();
88+
assertThat(arrayIdMap.hasLabel(3, labelTwo)).isTrue();
89+
90+
// create a node filter based on the idMap
91+
var nodeFilterOne = NodeFilter.create("one", arrayIdMap);
92+
var nodeFilterTwo = NodeFilter.create("two", arrayIdMap);
93+
94+
// test that the filter correctly filters based on the label
95+
assertThat(nodeFilterOne.test(0)).isTrue();
96+
assertThat(nodeFilterOne.test(1)).isFalse();
97+
assertThat(nodeFilterOne.test(2)).isTrue();
98+
assertThat(nodeFilterOne.test(3)).isFalse();
99+
100+
assertThat(nodeFilterTwo.test(0)).isFalse();
101+
assertThat(nodeFilterTwo.test(1)).isTrue();
102+
assertThat(nodeFilterTwo.test(2)).isFalse();
103+
assertThat(nodeFilterTwo.test(3)).isTrue();
104+
}
105+
52106
@Test
53107
void shouldFilter() {
54108
var nodeFilter = NodeFilter.create(10L, new DirectIdMap(10));

0 commit comments

Comments
 (0)