Skip to content

Commit 82681b2

Browse files
committed
Wrap target nodes in a predicate type
1 parent b4d8424 commit 82681b2

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public class FilteredKnn extends Algorithm<FilteredKnnResult> {
5858
private final double sampleRate;
5959
private final double similarityCutoff;
6060
private final List<Long> sourceNodes;
61-
private final List<Long> targetNodes;
61+
private final TargetNodePredicate targetNodePredicate;
6262
private final int topK;
6363

6464
private long nodePairsConsidered;
@@ -79,6 +79,7 @@ public static FilteredKnn create(
7979
var splittableRandom = getSplittableRandom(config.randomSeed());
8080
var sourceNodes = config.sourceNodeFilter().stream().map(graph::toMappedNodeId).collect(Collectors.toList());
8181
var targetNodes = config.targetNodeFilter().stream().map(graph::toMappedNodeId).collect(Collectors.toList());
82+
var targetNodePredicate = new TargetNodePredicate(targetNodes);
8283
var samplerSupplier = samplerSupplier(graph, config);
8384
return new FilteredKnn(
8485
context.progressTracker(),
@@ -98,7 +99,7 @@ public static FilteredKnn create(
9899
config.sampleRate(),
99100
config.similarityCutoff(),
100101
sourceNodes,
101-
targetNodes,
102+
targetNodePredicate,
102103
config.topK()
103104
);
104105
}
@@ -142,7 +143,7 @@ private FilteredKnn(
142143
double sampleRate,
143144
double similarityCutoff,
144145
List<Long> sourceNodes,
145-
List<Long> targetNodes,
146+
TargetNodePredicate targetNodePredicate,
146147
int topK
147148

148149
) {
@@ -163,7 +164,7 @@ private FilteredKnn(
163164
this.sampleRate = sampleRate;
164165
this.similarityCutoff = similarityCutoff;
165166
this.sourceNodes = sourceNodes;
166-
this.targetNodes = targetNodes;
167+
this.targetNodePredicate = targetNodePredicate;
167168
this.topK = topK;
168169
}
169170

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.filteredknn;
21+
22+
import java.util.List;
23+
24+
public class TargetNodePredicate {
25+
private final List<Long> nodeIds;
26+
27+
public TargetNodePredicate(List<Long> nodeIds) {this.nodeIds = nodeIds;}
28+
29+
public boolean test(long nodeId) {
30+
return nodeIds.contains(nodeId);
31+
}
32+
}

0 commit comments

Comments
 (0)