Skip to content

Commit bd48bd4

Browse files
committed
Move out internal classes
1 parent d5bdbb6 commit bd48bd4

File tree

2 files changed

+265
-234
lines changed

2 files changed

+265
-234
lines changed
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.filteredknn;
21+
22+
import com.carrotsearch.hppc.LongArrayList;
23+
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
24+
import org.neo4j.gds.core.utils.partition.Partition;
25+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
26+
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
27+
28+
import java.util.SplittableRandom;
29+
30+
final class FilteredJoinNeighbors implements Runnable {
31+
private final SplittableRandom random;
32+
private final SimilarityComputer computer;
33+
private final FilteredNeighborFilter neighborFilter;
34+
private final HugeObjectArray<FilteredNeighborList> neighbors;
35+
private final HugeObjectArray<LongArrayList> allOldNeighbors;
36+
private final HugeObjectArray<LongArrayList> allNewNeighbors;
37+
private final HugeObjectArray<LongArrayList> allReverseOldNeighbors;
38+
private final HugeObjectArray<LongArrayList> allReverseNewNeighbors;
39+
private final long n;
40+
private final int k;
41+
private final int sampledK;
42+
private final int randomJoins;
43+
private final ProgressTracker progressTracker;
44+
private long updateCount;
45+
private final Partition partition;
46+
private long nodePairsConsidered;
47+
private final double perturbationRate;
48+
49+
FilteredJoinNeighbors(
50+
SplittableRandom random,
51+
SimilarityComputer computer,
52+
FilteredNeighborFilter neighborFilter,
53+
HugeObjectArray<FilteredNeighborList> neighbors,
54+
HugeObjectArray<LongArrayList> allOldNeighbors,
55+
HugeObjectArray<LongArrayList> allNewNeighbors,
56+
HugeObjectArray<LongArrayList> allReverseOldNeighbors,
57+
HugeObjectArray<LongArrayList> allReverseNewNeighbors,
58+
long n,
59+
int k,
60+
int sampledK,
61+
double perturbationRate,
62+
int randomJoins,
63+
Partition partition,
64+
ProgressTracker progressTracker
65+
) {
66+
this.random = random;
67+
this.computer = computer;
68+
this.neighborFilter = neighborFilter;
69+
this.neighbors = neighbors;
70+
this.allOldNeighbors = allOldNeighbors;
71+
this.allNewNeighbors = allNewNeighbors;
72+
this.allReverseOldNeighbors = allReverseOldNeighbors;
73+
this.allReverseNewNeighbors = allReverseNewNeighbors;
74+
this.n = n;
75+
this.k = k;
76+
this.sampledK = sampledK;
77+
this.randomJoins = randomJoins;
78+
this.partition = partition;
79+
this.progressTracker = progressTracker;
80+
this.perturbationRate = perturbationRate;
81+
this.updateCount = 0;
82+
this.nodePairsConsidered = 0;
83+
}
84+
85+
@Override
86+
public void run() {
87+
var rng = random;
88+
var computer = this.computer;
89+
var n = this.n;
90+
var k = this.k;
91+
var sampledK = this.sampledK;
92+
var allNeighbors = this.neighbors;
93+
var allNewNeighbors = this.allNewNeighbors;
94+
var allOldNeighbors = this.allOldNeighbors;
95+
var allReverseNewNeighbors = this.allReverseNewNeighbors;
96+
var allReverseOldNeighbors = this.allReverseOldNeighbors;
97+
98+
var startNode = partition.startNode();
99+
long endNode = startNode + partition.nodeCount();
100+
101+
for (long nodeId = startNode; nodeId < endNode; nodeId++) {
102+
// old[v] ∪ Sample(old′[v], ρK)
103+
var oldNeighbors = allOldNeighbors.get(nodeId);
104+
if (oldNeighbors != null) {
105+
joinOldNeighbors(rng, sampledK, allReverseOldNeighbors, nodeId, oldNeighbors);
106+
}
107+
108+
109+
// new[v] ∪ Sample(new′[v], ρK)
110+
var newNeighbors = allNewNeighbors.get(nodeId);
111+
if (newNeighbors != null) {
112+
this.updateCount += joinNewNeighbors(
113+
rng,
114+
computer,
115+
n,
116+
k,
117+
sampledK,
118+
allNeighbors,
119+
allReverseNewNeighbors,
120+
nodeId,
121+
oldNeighbors,
122+
newNeighbors
123+
);
124+
}
125+
126+
// this isn't in the paper
127+
randomJoins(rng, computer, n, k, allNeighbors, nodeId, this.randomJoins);
128+
}
129+
progressTracker.logProgress(partition.nodeCount());
130+
}
131+
132+
long updateCount() {
133+
return updateCount;
134+
}
135+
136+
private void joinOldNeighbors(
137+
SplittableRandom rng,
138+
int sampledK,
139+
HugeObjectArray<LongArrayList> allReverseOldNeighbors,
140+
long nodeId,
141+
LongArrayList oldNeighbors
142+
) {
143+
var reverseOldNeighbors = allReverseOldNeighbors.get(nodeId);
144+
if (reverseOldNeighbors != null) {
145+
var numberOfReverseOldNeighbors = reverseOldNeighbors.size();
146+
for (var elem : reverseOldNeighbors) {
147+
if (rng.nextInt(numberOfReverseOldNeighbors) < sampledK) {
148+
// TODO: this could add nodes twice, maybe? should this be a set?
149+
oldNeighbors.add(elem.value);
150+
}
151+
}
152+
}
153+
}
154+
155+
private long joinNewNeighbors(
156+
SplittableRandom rng,
157+
SimilarityComputer computer,
158+
long n,
159+
int k,
160+
int sampledK,
161+
HugeObjectArray<FilteredNeighborList> allNeighbors,
162+
HugeObjectArray<LongArrayList> allReverseNewNeighbors,
163+
long nodeId,
164+
LongArrayList oldNeighbors,
165+
LongArrayList newNeighbors
166+
) {
167+
long updateCount = 0;
168+
169+
joinOldNeighbors(rng, sampledK, allReverseNewNeighbors, nodeId, newNeighbors);
170+
171+
var newNeighborElements = newNeighbors.buffer;
172+
var newNeighborsCount = newNeighbors.elementsCount;
173+
174+
for (int i = 0; i < newNeighborsCount; i++) {
175+
var elem1 = newNeighborElements[i];
176+
assert elem1 != nodeId;
177+
178+
// join(u1, v), this isn't in the paper
179+
updateCount += join(rng, computer, allNeighbors, n, k, elem1, nodeId);
180+
181+
// join(new_nbd, new_ndb)
182+
for (int j = i + 1; j < newNeighborsCount; j++) {
183+
var elem2 = newNeighborElements[i];
184+
if (elem1 == elem2) {
185+
continue;
186+
}
187+
188+
updateCount += join(rng, computer, allNeighbors, n, k, elem1, elem2);
189+
updateCount += join(rng, computer, allNeighbors, n, k, elem2, elem1);
190+
}
191+
192+
// join(new_nbd, old_ndb)
193+
if (oldNeighbors != null) {
194+
for (var oldElemCursor : oldNeighbors) {
195+
var elem2 = oldElemCursor.value;
196+
197+
if (elem1 == elem2) {
198+
continue;
199+
}
200+
201+
updateCount += join(rng, computer, allNeighbors, n, k, elem1, elem2);
202+
updateCount += join(rng, computer, allNeighbors, n, k, elem2, elem1);
203+
}
204+
}
205+
}
206+
return updateCount;
207+
}
208+
209+
private void randomJoins(
210+
SplittableRandom rng,
211+
SimilarityComputer computer,
212+
long n,
213+
int k,
214+
HugeObjectArray<FilteredNeighborList> allNeighbors,
215+
long nodeId,
216+
int randomJoins
217+
) {
218+
for (int i = 0; i < randomJoins; i++) {
219+
var randomNodeId = rng.nextLong(n - 1);
220+
if (randomNodeId >= nodeId) {
221+
++randomNodeId;
222+
}
223+
// random joins are not counted towards the actual update counter
224+
join(rng, computer, allNeighbors, n, k, nodeId, randomNodeId);
225+
}
226+
}
227+
228+
private long join(
229+
SplittableRandom splittableRandom,
230+
SimilarityComputer computer,
231+
HugeObjectArray<FilteredNeighborList> allNeighbors,
232+
long n,
233+
int k,
234+
long base,
235+
long joiner
236+
) {
237+
assert base != joiner;
238+
assert n > 1 && k > 0;
239+
240+
if (neighborFilter.excludeNodePair(base, joiner)) {
241+
return 0;
242+
}
243+
244+
var similarity = computer.safeSimilarity(base, joiner);
245+
nodePairsConsidered++;
246+
var neighbors = allNeighbors.get(base);
247+
248+
synchronized (neighbors) {
249+
var k2 = neighbors.size();
250+
251+
assert k2 > 0;
252+
assert k2 <= k;
253+
assert k2 <= n - 1;
254+
255+
return neighbors.add(joiner, similarity, splittableRandom, perturbationRate);
256+
}
257+
}
258+
259+
long nodePairsConsidered() {
260+
return nodePairsConsidered;
261+
}
262+
}

0 commit comments

Comments
 (0)