Skip to content

Commit d05769c

Browse files
committed
More types and less null checking
1 parent 2494d0d commit d05769c

File tree

13 files changed

+363
-216
lines changed

13 files changed

+363
-216
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnBaseConfig.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,22 @@
2828
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
2929

3030
import java.util.Collection;
31-
import java.util.List;
3231

3332
@ValueClass
3433
@Configuration
3534
@SuppressWarnings("immutables:subtype")
3635
public interface FilteredKnnBaseConfig extends KnnBaseConfig {
3736

3837
@Value.Default
39-
@Configuration.ConvertWith("org.neo4j.gds.similarity.filteredknn.NodeFilterSpec#create")
38+
@Configuration.ConvertWith("org.neo4j.gds.similarity.filteredknn.NodeFilterSpecFactory#create")
4039
default NodeFilterSpec sourceNodeFilter() {
41-
return NodeFilterSpec.create(List.of());
40+
return NodeFilterSpec.noOp;
4241
}
4342

4443
@Value.Default
45-
@Configuration.ConvertWith("org.neo4j.gds.similarity.filteredknn.NodeFilterSpec#create")
44+
@Configuration.ConvertWith("org.neo4j.gds.similarity.filteredknn.NodeFilterSpecFactory#create")
4645
default NodeFilterSpec targetNodeFilter() {
47-
return NodeFilterSpec.create(List.of());
46+
return NodeFilterSpec.noOp;
4847
}
4948

5049
@Configuration.GraphStoreValidationCheck

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnResult.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ public long nodePairsConsidered() {
100100

101101
@Override
102102
public NodeFilter sourceNodeFilter() {
103-
return NodeFilter.noOp();
103+
return NodeFilter.noOp;
104104
}
105105

106106
@Override
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.filteredknn;
21+
22+
import org.neo4j.gds.NodeLabel;
23+
import org.neo4j.gds.api.IdMap;
24+
25+
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
26+
27+
public final class LabelNodeFilter implements NodeFilter {
28+
29+
public static LabelNodeFilter create(String labelString, IdMap idMap) {
30+
NodeLabel label = null;
31+
for (var existingLabel : idMap.availableNodeLabels()) {
32+
if (existingLabel.name.equalsIgnoreCase(labelString)) {
33+
label = existingLabel;
34+
}
35+
}
36+
if (null == label) {
37+
throw new IllegalArgumentException(formatWithLocale(
38+
"The label `%s` does not exist in the graph",
39+
labelString
40+
));
41+
}
42+
return new LabelNodeFilter(label, idMap);
43+
}
44+
45+
private final NodeLabel label;
46+
private final IdMap idMap;
47+
48+
private LabelNodeFilter(NodeLabel label, IdMap idMap) {
49+
this.label = label;
50+
this.idMap = idMap;
51+
}
52+
53+
@Override
54+
public boolean test(long nodeId) {
55+
return idMap.hasLabel(nodeId, label);
56+
}
57+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.filteredknn;
21+
22+
import org.neo4j.gds.api.IdMap;
23+
24+
public class LabelNodeFilterSpec implements NodeFilterSpec {
25+
26+
private final String labelString;
27+
28+
LabelNodeFilterSpec(String labelString) {
29+
this.labelString = labelString;
30+
}
31+
32+
@Override
33+
public NodeFilter toNodeFilter(IdMap idMap) {
34+
return LabelNodeFilter.create(labelString, idMap);
35+
}
36+
}

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/NodeFilter.java

Lines changed: 2 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -19,74 +19,8 @@
1919
*/
2020
package org.neo4j.gds.similarity.filteredknn;
2121

22-
import org.neo4j.gds.NodeLabel;
23-
import org.neo4j.gds.api.IdMap;
24-
25-
import java.util.Set;
2622
import java.util.function.LongPredicate;
27-
import java.util.stream.Collectors;
28-
29-
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
30-
31-
public class NodeFilter implements LongPredicate {
32-
33-
public static NodeFilter create(Set<Long> externalNodeIds, IdMap idMap) {
34-
var mappedNodeIds = externalNodeIds.stream().map(idMap::toMappedNodeId).collect(Collectors.toSet());
35-
return new NodeFilter(mappedNodeIds);
36-
}
37-
38-
public static NodeFilter create(String labelString, IdMap idMap) {
39-
NodeLabel label = null;
40-
for (var existingLabel : idMap.availableNodeLabels()) {
41-
if (existingLabel.name.equalsIgnoreCase(labelString)) {
42-
label = existingLabel;
43-
}
44-
}
45-
if (null == label) {
46-
throw new IllegalArgumentException(formatWithLocale(
47-
"The label `%s` does not exist in the graph",
48-
labelString
49-
));
50-
}
51-
return new NodeFilter(label, idMap);
52-
}
53-
54-
public static NodeFilter noOp() {
55-
return new NoOpNodeFilter(Set.of());
56-
}
57-
58-
private final Set<Long> nodeIds;
59-
private final NodeLabel label;
60-
private final IdMap idMap;
61-
62-
private NodeFilter(Set<Long> nodeIds) {
63-
this.nodeIds = nodeIds;
64-
this.label = null;
65-
this.idMap = null;
66-
}
67-
68-
private NodeFilter(NodeLabel label, IdMap idMap) {
69-
this.nodeIds = null;
70-
this.label = label;
71-
this.idMap = idMap;
72-
}
73-
74-
@Override
75-
public boolean test(long nodeId) {
76-
return null == nodeIds
77-
? idMap.hasLabel(nodeId, label)
78-
: nodeIds.contains(nodeId);
79-
}
80-
81-
private static class NoOpNodeFilter extends NodeFilter {
82-
83-
NoOpNodeFilter(Set<Long> nodeIds) {
84-
super(nodeIds);
85-
}
8623

87-
@Override
88-
public boolean test(long nodeId) {
89-
return true;
90-
}
91-
}
24+
interface NodeFilter extends LongPredicate {
25+
NodeFilter noOp = (nodeId) -> true;
9226
}

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/NodeFilterSpec.java

Lines changed: 10 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -20,124 +20,20 @@
2020
package org.neo4j.gds.similarity.filteredknn;
2121

2222
import org.neo4j.gds.api.IdMap;
23-
import org.neo4j.graphdb.Node;
24-
25-
import java.util.ArrayList;
26-
import java.util.HashSet;
27-
import java.util.List;
28-
import java.util.Set;
29-
30-
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
3123

3224
/**
33-
* This class serves to do as much parsing and validation as possible in the UI for creating a {@link NodeFilter}.
25+
* A {@code NodeFilterSpec} is a partially constructed {@link NodeFilter}. Because we cannot fully construct the
26+
* {@link NodeFilter} from user inputs alone, we parse and validate what we can, and prepare for completing the
27+
* construction once the rest is available.
3428
*
35-
* We can
36-
* <ul>
37-
* <li>normalize {@code Long}, {@code List<Long>}, {@code Node} and {@code List<Node>} to {@code Set<Long>}</li>
38-
* <li>store the normalized {@code Set<Long>}, or the label {@code String}, as the case may be.</li>
39-
* </ul>
40-
* But we cannot
41-
* <ul>
42-
* <li>validated that the nodes or label exist in the graph.</li>
43-
* <li>translate node ids from Neo4j id space to the internal id space.</li>
44-
* </ul>
29+
* There are two types of node filters: ones based on a list of node IDs, and ones based on labels.
30+
* There are therefore two types of node filter specs, accordingly.
4531
*
46-
* The latter two have to happen later, when the {@link NodeFilterSpec} is turned into a {@link NodeFilter}.
32+
* The spec is created using {@link NodeFilterSpecFactory#create(Object)} and the {@link NodeFilter} is then created
33+
* using {@link NodeFilterSpec#toNodeFilter(org.neo4j.gds.api.IdMap)}.
4734
*/
48-
public class NodeFilterSpec {
49-
50-
public static NodeFilterSpec create(Object input) {
51-
if (input instanceof NodeFilterSpec) {
52-
return (NodeFilterSpec) input;
53-
}
54-
55-
if (input instanceof String) {
56-
// parse as label
57-
return new NodeFilterSpec((String) input);
58-
}
59-
60-
Set<Long> nodeIds = null;
61-
62-
if (input instanceof List) {
63-
nodeIds = parseFromList((List) input);
64-
}
65-
66-
if (input instanceof Long) {
67-
nodeIds = parseFromLong((Long) input);
68-
}
69-
70-
if (input instanceof Node) {
71-
nodeIds = parseFromNode((Node) input);
72-
}
73-
74-
if (nodeIds == null) {
75-
throw new IllegalArgumentException(
76-
formatWithLocale("Invalid scalar type. Expected Long or Node but found: %s", input.getClass().getSimpleName())
77-
);
78-
}
79-
80-
return new NodeFilterSpec(nodeIds);
81-
}
82-
83-
private static Set<Long> parseFromLong(Long input) {
84-
Set<Long> nodeIds = new HashSet<>();
85-
nodeIds.add(input);
86-
return nodeIds;
87-
}
88-
89-
private static Set<Long> parseFromNode(Node input) {
90-
Set<Long> nodeIds = new HashSet<>();
91-
nodeIds.add(input.getId());
92-
return nodeIds;
93-
}
94-
95-
private static Set<Long> parseFromList(List input) {
96-
Set<Long> nodeIds = new HashSet<>();
97-
List<String> badTypes = new ArrayList<>();
98-
input.forEach(o -> {
99-
if (o instanceof Long) {
100-
nodeIds.add((Long) o);
101-
} else if (o instanceof Node) {
102-
nodeIds.add(((Node) o).getId());
103-
} else {
104-
badTypes.add(o.getClass().getSimpleName());
105-
}
106-
});
107-
108-
if (badTypes.isEmpty()) {
109-
return nodeIds;
110-
}
111-
112-
throw new IllegalArgumentException(formatWithLocale(
113-
"Invalid types in list. Expected Longs or Nodes but found %s",
114-
badTypes
115-
));
116-
}
117-
118-
private final Set<Long> nodeIds;
119-
private final String labelString;
120-
121-
NodeFilterSpec(Set<Long> nodeIds) {
122-
this.nodeIds = nodeIds;
123-
this.labelString = null;
124-
}
125-
126-
NodeFilterSpec(String labelString) {
127-
this.nodeIds = null;
128-
this.labelString = labelString;
129-
}
35+
interface NodeFilterSpec {
36+
NodeFilter toNodeFilter(IdMap idMap);
13037

131-
NodeFilter toNodeFilter(IdMap idMap) {
132-
if (nodeIds != null) {
133-
if (nodeIds.isEmpty()) {
134-
return NodeFilter.noOp();
135-
}
136-
return NodeFilter.create(nodeIds, idMap);
137-
}
138-
if (labelString != null) {
139-
return NodeFilter.create(labelString, idMap);
140-
}
141-
throw new IllegalStateException("This object is broken. This should not happen, says Jonatan.");
142-
}
38+
NodeFilterSpec noOp = (idMap) -> NodeFilter.noOp;
14339
}

0 commit comments

Comments
 (0)