Skip to content

Commit 2494d0d

Browse files
committed
Create what we can of node filters eariler
1 parent 0c84d82 commit 2494d0d

File tree

6 files changed

+180
-88
lines changed

6 files changed

+180
-88
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ public static FilteredKnn create(
7575
FilteredNeighborFilterFactory neighborFilterFactory
7676
) {
7777
var splittableRandom = getSplittableRandom(config.randomSeed());
78-
var sourceNodeFilter = NodeFilter.create(config.sourceNodeFilter(), graph);
79-
var targetNodeFilter = NodeFilter.create(config.targetNodeFilter(), graph);
78+
var sourceNodeFilter = config.sourceNodeFilter().toNodeFilter(graph);
79+
var targetNodeFilter = config.targetNodeFilter().toNodeFilter(graph);
8080
var samplerSupplier = samplerSupplier(graph, config);
8181
return new FilteredKnn(
8282
context.progressTracker(),

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnBaseConfig.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,23 @@
2828
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
2929

3030
import java.util.Collection;
31+
import java.util.List;
3132

3233
@ValueClass
3334
@Configuration
3435
@SuppressWarnings("immutables:subtype")
3536
public interface FilteredKnnBaseConfig extends KnnBaseConfig {
3637

3738
@Value.Default
38-
default Object sourceNodeFilter() {
39-
return NodeFilter.noOp();
39+
@Configuration.ConvertWith("org.neo4j.gds.similarity.filteredknn.NodeFilterSpec#create")
40+
default NodeFilterSpec sourceNodeFilter() {
41+
return NodeFilterSpec.create(List.of());
4042
}
4143

4244
@Value.Default
43-
default Object targetNodeFilter() {
44-
return NodeFilter.noOp();
45+
@Configuration.ConvertWith("org.neo4j.gds.similarity.filteredknn.NodeFilterSpec#create")
46+
default NodeFilterSpec targetNodeFilter() {
47+
return NodeFilterSpec.create(List.of());
4548
}
4649

4750
@Configuration.GraphStoreValidationCheck

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/NodeFilter.java

Lines changed: 21 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -21,114 +21,60 @@
2121

2222
import org.neo4j.gds.NodeLabel;
2323
import org.neo4j.gds.api.IdMap;
24-
import org.neo4j.graphdb.Node;
2524

26-
import java.util.ArrayList;
27-
import java.util.HashSet;
28-
import java.util.List;
2925
import java.util.Set;
3026
import java.util.function.LongPredicate;
27+
import java.util.stream.Collectors;
3128

3229
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
3330

3431
public class NodeFilter implements LongPredicate {
3532

36-
public static NodeFilter create(Object input, IdMap idMap) {
37-
if (input instanceof NodeFilter) {
38-
return (NodeFilter) input;
39-
}
40-
41-
if (input instanceof String) {
42-
// parse as label
43-
return parseFromString((String) input, idMap);
44-
}
45-
46-
Set<Long> nodeIds = null;
47-
48-
if (input instanceof List) {
49-
nodeIds = parseFromList((List) input, idMap);
50-
}
51-
52-
if (input instanceof Long) {
53-
nodeIds = parseFromLong((Long) input, idMap);
54-
}
55-
56-
if (input instanceof Node) {
57-
nodeIds = parseFromNode((Node) input, idMap);
58-
}
59-
60-
if (nodeIds == null) {
61-
throw new IllegalArgumentException(
62-
String.format("Invalid scalar type. Expected Long or Node but found: %s", input.getClass().getSimpleName())
63-
);
64-
}
65-
66-
return new NodeFilter(nodeIds);
67-
}
68-
69-
private static NodeFilter parseFromString(String input, IdMap idMap) {
70-
return new NodeFilter(input, idMap);
71-
}
72-
73-
private static Set<Long> parseFromLong(Long input, IdMap idMap) {
74-
Set<Long> nodeIds = new HashSet<>();
75-
nodeIds.add(idMap.toMappedNodeId(input));
76-
return nodeIds;
33+
public static NodeFilter create(Set<Long> externalNodeIds, IdMap idMap) {
34+
var mappedNodeIds = externalNodeIds.stream().map(idMap::toMappedNodeId).collect(Collectors.toSet());
35+
return new NodeFilter(mappedNodeIds);
7736
}
7837

79-
private static Set<Long> parseFromNode(Node input, IdMap idMap) {
80-
Set<Long> nodeIds = new HashSet<>();
81-
nodeIds.add(idMap.toMappedNodeId(input.getId()));
82-
return nodeIds;
83-
}
84-
85-
private static Set<Long> parseFromList(List input, IdMap idMap) {
86-
Set<Long> nodeIds = new HashSet<>();
87-
List<String> badTypes = new ArrayList<>();
88-
input.forEach(o -> {
89-
if (o instanceof Long) {
90-
nodeIds.add(idMap.toMappedNodeId((Long) o));
91-
} else if (o instanceof Node) {
92-
nodeIds.add(idMap.toMappedNodeId(((Node) o).getId()));
93-
} else {
94-
badTypes.add(o.getClass().getSimpleName());
38+
public static NodeFilter create(String labelString, IdMap idMap) {
39+
NodeLabel label = null;
40+
for (var existingLabel : idMap.availableNodeLabels()) {
41+
if (existingLabel.name.equalsIgnoreCase(labelString)) {
42+
label = existingLabel;
9543
}
96-
});
97-
98-
if (badTypes.isEmpty()) {
99-
return nodeIds;
10044
}
101-
102-
throw new IllegalArgumentException(formatWithLocale(
103-
"Invalid types in list. Expected Longs or Nodes but found %s",
104-
badTypes
105-
));
45+
if (null == label) {
46+
throw new IllegalArgumentException(formatWithLocale(
47+
"The label `%s` does not exist in the graph",
48+
labelString
49+
));
50+
}
51+
return new NodeFilter(label, idMap);
10652
}
10753

10854
public static NodeFilter noOp() {
10955
return new NoOpNodeFilter(Set.of());
11056
}
11157

11258
private final Set<Long> nodeIds;
113-
private final NodeLabel nodeLabel;
59+
private final NodeLabel label;
11460
private final IdMap idMap;
11561

11662
private NodeFilter(Set<Long> nodeIds) {
11763
this.nodeIds = nodeIds;
118-
this.nodeLabel = null;
64+
this.label = null;
11965
this.idMap = null;
12066
}
12167

122-
private NodeFilter(String labelString, IdMap idMap) {
68+
private NodeFilter(NodeLabel label, IdMap idMap) {
12369
this.nodeIds = null;
124-
this.nodeLabel = NodeLabel.of(labelString);
70+
this.label = label;
12571
this.idMap = idMap;
12672
}
12773

12874
@Override
12975
public boolean test(long nodeId) {
13076
return null == nodeIds
131-
? idMap.hasLabel(nodeId, nodeLabel)
77+
? idMap.hasLabel(nodeId, label)
13278
: nodeIds.contains(nodeId);
13379
}
13480

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.filteredknn;
21+
22+
import org.neo4j.gds.api.IdMap;
23+
import org.neo4j.graphdb.Node;
24+
25+
import java.util.ArrayList;
26+
import java.util.HashSet;
27+
import java.util.List;
28+
import java.util.Set;
29+
30+
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
31+
32+
/**
33+
* This class serves to do as much parsing and validation as possible in the UI for creating a {@link NodeFilter}.
34+
*
35+
* We can
36+
* <ul>
37+
* <li>normalize {@code Long}, {@code List<Long>}, {@code Node} and {@code List<Node>} to {@code Set<Long>}</li>
38+
* <li>store the normalized {@code Set<Long>}, or the label {@code String}, as the case may be.</li>
39+
* </ul>
40+
* But we cannot
41+
* <ul>
42+
* <li>validated that the nodes or label exist in the graph.</li>
43+
* <li>translate node ids from Neo4j id space to the internal id space.</li>
44+
* </ul>
45+
*
46+
* The latter two have to happen later, when the {@link NodeFilterSpec} is turned into a {@link NodeFilter}.
47+
*/
48+
public class NodeFilterSpec {
49+
50+
public static NodeFilterSpec create(Object input) {
51+
if (input instanceof NodeFilterSpec) {
52+
return (NodeFilterSpec) input;
53+
}
54+
55+
if (input instanceof String) {
56+
// parse as label
57+
return new NodeFilterSpec((String) input);
58+
}
59+
60+
Set<Long> nodeIds = null;
61+
62+
if (input instanceof List) {
63+
nodeIds = parseFromList((List) input);
64+
}
65+
66+
if (input instanceof Long) {
67+
nodeIds = parseFromLong((Long) input);
68+
}
69+
70+
if (input instanceof Node) {
71+
nodeIds = parseFromNode((Node) input);
72+
}
73+
74+
if (nodeIds == null) {
75+
throw new IllegalArgumentException(
76+
formatWithLocale("Invalid scalar type. Expected Long or Node but found: %s", input.getClass().getSimpleName())
77+
);
78+
}
79+
80+
return new NodeFilterSpec(nodeIds);
81+
}
82+
83+
private static Set<Long> parseFromLong(Long input) {
84+
Set<Long> nodeIds = new HashSet<>();
85+
nodeIds.add(input);
86+
return nodeIds;
87+
}
88+
89+
private static Set<Long> parseFromNode(Node input) {
90+
Set<Long> nodeIds = new HashSet<>();
91+
nodeIds.add(input.getId());
92+
return nodeIds;
93+
}
94+
95+
private static Set<Long> parseFromList(List input) {
96+
Set<Long> nodeIds = new HashSet<>();
97+
List<String> badTypes = new ArrayList<>();
98+
input.forEach(o -> {
99+
if (o instanceof Long) {
100+
nodeIds.add((Long) o);
101+
} else if (o instanceof Node) {
102+
nodeIds.add(((Node) o).getId());
103+
} else {
104+
badTypes.add(o.getClass().getSimpleName());
105+
}
106+
});
107+
108+
if (badTypes.isEmpty()) {
109+
return nodeIds;
110+
}
111+
112+
throw new IllegalArgumentException(formatWithLocale(
113+
"Invalid types in list. Expected Longs or Nodes but found %s",
114+
badTypes
115+
));
116+
}
117+
118+
private final Set<Long> nodeIds;
119+
private final String labelString;
120+
121+
NodeFilterSpec(Set<Long> nodeIds) {
122+
this.nodeIds = nodeIds;
123+
this.labelString = null;
124+
}
125+
126+
NodeFilterSpec(String labelString) {
127+
this.nodeIds = null;
128+
this.labelString = labelString;
129+
}
130+
131+
NodeFilter toNodeFilter(IdMap idMap) {
132+
if (nodeIds != null) {
133+
if (nodeIds.isEmpty()) {
134+
return NodeFilter.noOp();
135+
}
136+
return NodeFilter.create(nodeIds, idMap);
137+
}
138+
if (labelString != null) {
139+
return NodeFilter.create(labelString, idMap);
140+
}
141+
throw new IllegalStateException("This object is broken. This should not happen, says Jonatan.");
142+
}
143+
}

algo/src/test/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnIdMappingTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ void shouldIdMapTheSourceNodeFilter() {
6262
.maxIterations(1)
6363
.randomSeed(20L)
6464
.concurrency(1)
65-
.sourceNodeFilter(NodeFilter.create(lowestNeoId, graph))
65+
.sourceNodeFilter(NodeFilterSpec.create(lowestNeoId))
6666
.build();
6767
var knn = FilteredKnn.createWithDefaults(graph, config, FilteredKnnContext.empty());
6868

algo/src/test/java/org/neo4j/gds/similarity/filteredknn/NodeFilterTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
import java.util.Collections;
3131
import java.util.List;
32+
import java.util.Set;
3233

3334
import static org.assertj.core.api.Assertions.assertThat;
3435
import static org.assertj.core.api.Assertions.assertThatNoException;
@@ -39,19 +40,18 @@ class NodeFilterTest {
3940
@Test
4041
void shouldFailToParseInvalidInput() {
4142
var validInput = 1L;
42-
var idMap = new DirectIdMap(10);
4343

4444
// double is invalid
45-
assertThatThrownBy(() -> NodeFilter.create(1.0, idMap))
45+
assertThatThrownBy(() -> NodeFilterSpec.create(1.0))
4646
.isInstanceOf(IllegalArgumentException.class)
4747
.hasMessage("Invalid scalar type. Expected Long or Node but found: Double");
48-
assertThatThrownBy(() -> NodeFilter.create(List.of(validInput, 1.0), idMap))
48+
assertThatThrownBy(() -> NodeFilterSpec.create(List.of(validInput, 1.0)))
4949
.isInstanceOf(IllegalArgumentException.class)
5050
.hasMessage("Invalid types in list. Expected Longs or Nodes but found [Double]");
5151

5252
// String is valid as scalar but not in a list
53-
assertThatNoException().isThrownBy(() -> NodeFilter.create("foo", idMap));
54-
assertThatThrownBy(() -> NodeFilter.create(List.of(validInput, "foo"), idMap))
53+
assertThatNoException().isThrownBy(() -> NodeFilterSpec.create("foo"));
54+
assertThatThrownBy(() -> NodeFilterSpec.create(List.of(validInput, "foo")))
5555
.isInstanceOf(IllegalArgumentException.class)
5656
.hasMessage("Invalid types in list. Expected Longs or Nodes but found [String]");
5757
}
@@ -105,7 +105,7 @@ void shouldFilterBasedOnLabel() {
105105

106106
@Test
107107
void shouldFilter() {
108-
var nodeFilter = NodeFilter.create(10L, new DirectIdMap(10));
108+
var nodeFilter = NodeFilter.create(Set.of(10L), new DirectIdMap(10));
109109
assertThat(nodeFilter.test(10)).isTrue();
110110
assertThat(nodeFilter.test(1)).isFalse();
111111
}

0 commit comments

Comments
 (0)