1919 */
2020package org .neo4j .gds .similarity .filteredknn ;
2121
22+ import com .carrotsearch .hppc .IntObjectHashMap ;
23+ import com .carrotsearch .hppc .IntObjectMap ;
2224import org .junit .jupiter .api .Test ;
25+ import org .neo4j .gds .NodeLabel ;
2326import org .neo4j .gds .core .huge .DirectIdMap ;
27+ import org .neo4j .gds .core .loading .ArrayIdMapBuilder ;
28+ import org .neo4j .gds .core .loading .LabelInformation ;
2429
30+ import java .util .Collections ;
2531import java .util .List ;
2632
2733import static org .assertj .core .api .Assertions .assertThat ;
34+ import static org .assertj .core .api .Assertions .assertThatNoException ;
2835import static org .assertj .core .api .Assertions .assertThatThrownBy ;
2936
3037class NodeFilterTest {
@@ -43,12 +50,59 @@ void shouldFailToParseInvalidInput() {
4350 .hasMessage ("Invalid types in list. Expected Longs or Nodes but found [Double]" );
4451
4552 // String is valid as scalar but not in a list
46- // assertThatNoException().isThrownBy(() -> NodeFilter.create("foo", idMap)); // Not implemented yet
53+ assertThatNoException ().isThrownBy (() -> NodeFilter .create ("foo" , idMap ));
4754 assertThatThrownBy (() -> NodeFilter .create (List .of (validInput , "foo" ), idMap ))
4855 .isInstanceOf (IllegalArgumentException .class )
4956 .hasMessage ("Invalid types in list. Expected Longs or Nodes but found [String]" );
5057 }
5158
59+ @ Test
60+ void shouldFilterBasedOnLabel () {
61+ // set up an idMap with four nodes and two labels
62+ var labelOne = NodeLabel .of ("one" );
63+ var labelTwo = NodeLabel .of ("two" );
64+
65+ IntObjectMap <List <NodeLabel >> labelTokenNodeLabelMappings = new IntObjectHashMap <List <NodeLabel >>();
66+ labelTokenNodeLabelMappings .put (1 , Collections .singletonList (labelOne ));
67+ labelTokenNodeLabelMappings .put (2 , Collections .singletonList (labelTwo ));
68+
69+ var labelInformationBuilder = LabelInformation .builder (4 , labelTokenNodeLabelMappings );
70+ labelInformationBuilder .addNodeIdToLabel (labelOne , 0 );
71+ labelInformationBuilder .addNodeIdToLabel (labelTwo , 1 );
72+ labelInformationBuilder .addNodeIdToLabel (labelOne , 2 );
73+ labelInformationBuilder .addNodeIdToLabel (labelTwo , 3 );
74+
75+ var arrayIdMapBuilder = ArrayIdMapBuilder .of (4 );
76+ arrayIdMapBuilder .allocate (4 );
77+ var graphIds = arrayIdMapBuilder .array ();
78+ graphIds .set (0 , 0 );
79+ graphIds .set (1 , 1 );
80+ graphIds .set (2 , 2 );
81+ graphIds .set (3 , 3 );
82+ var arrayIdMap = arrayIdMapBuilder .build (labelInformationBuilder , 3 , 1 );
83+
84+ // test that the idMap is as expected
85+ assertThat (arrayIdMap .hasLabel (0 , labelOne )).isTrue ();
86+ assertThat (arrayIdMap .hasLabel (1 , labelTwo )).isTrue ();
87+ assertThat (arrayIdMap .hasLabel (2 , labelOne )).isTrue ();
88+ assertThat (arrayIdMap .hasLabel (3 , labelTwo )).isTrue ();
89+
90+ // create a node filter based on the idMap
91+ var nodeFilterOne = NodeFilter .create ("one" , arrayIdMap );
92+ var nodeFilterTwo = NodeFilter .create ("two" , arrayIdMap );
93+
94+ // test that the filter correctly filters based on the label
95+ assertThat (nodeFilterOne .test (0 )).isTrue ();
96+ assertThat (nodeFilterOne .test (1 )).isFalse ();
97+ assertThat (nodeFilterOne .test (2 )).isTrue ();
98+ assertThat (nodeFilterOne .test (3 )).isFalse ();
99+
100+ assertThat (nodeFilterTwo .test (0 )).isFalse ();
101+ assertThat (nodeFilterTwo .test (1 )).isTrue ();
102+ assertThat (nodeFilterTwo .test (2 )).isFalse ();
103+ assertThat (nodeFilterTwo .test (3 )).isTrue ();
104+ }
105+
52106 @ Test
53107 void shouldFilter () {
54108 var nodeFilter = NodeFilter .create (10L , new DirectIdMap (10 ));
0 commit comments