Skip to content

Commit 72b137e

Browse files
authored
Merge pull request #5362 from FlorentinD/gs-extract-batch-sampling-logic
GraphSage extract sampling logic
2 parents 5247932 + 45cb005 commit 72b137e

File tree

6 files changed

+249
-152
lines changed

6 files changed

+249
-152
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.graphsage;
21+
22+
import com.carrotsearch.hppc.LongHashSet;
23+
import org.neo4j.gds.api.Graph;
24+
import org.neo4j.gds.api.ImmutableRelationshipCursor;
25+
import org.neo4j.gds.core.utils.partition.Partition;
26+
import org.neo4j.gds.core.utils.partition.PartitionUtils;
27+
import org.neo4j.gds.ml.core.samplers.WeightedUniformSampler;
28+
import org.neo4j.gds.ml.core.subgraph.NeighborhoodSampler;
29+
30+
import java.util.Arrays;
31+
import java.util.List;
32+
import java.util.OptionalLong;
33+
import java.util.Random;
34+
import java.util.concurrent.atomic.AtomicLong;
35+
import java.util.stream.LongStream;
36+
37+
public final class BatchSampler {
38+
39+
public static final double DEGREE_SMOOTHING_FACTOR = 0.75;
40+
private final Graph graph;
41+
42+
BatchSampler(Graph graph) {
43+
this.graph = graph;
44+
}
45+
46+
List<long[]> extendedBatches(int batchSize, int searchDepth, long randomSeed) {
47+
return PartitionUtils.rangePartitionWithBatchSize(
48+
graph.nodeCount(),
49+
batchSize,
50+
batch -> {
51+
var localSeed = Math.toIntExact(Math.floorDiv(batch.startNode(), graph.nodeCount())) + randomSeed;
52+
return sampleNeighborAndNegativeNodePerBatchNode(batch, searchDepth, localSeed);
53+
}
54+
);
55+
}
56+
57+
/**
58+
* For each node in the batch we sample one neighbor node and one negative node from the graph.
59+
*/
60+
long[] sampleNeighborAndNegativeNodePerBatchNode(Partition batch, int searchDepth, long randomSeed) {
61+
var neighbours = neighborBatch(batch, randomSeed, searchDepth).toArray();
62+
63+
LongStream negativeSamples = negativeBatch(Math.toIntExact(batch.nodeCount()), neighbours, randomSeed);
64+
65+
return LongStream.concat(
66+
batch.stream(),
67+
LongStream.concat(
68+
Arrays.stream(neighbours),
69+
// batch.nodeCount is <= config.batchsize (which is an int)
70+
negativeSamples
71+
)
72+
).toArray();
73+
}
74+
75+
LongStream neighborBatch(Partition batch, long batchLocalSeed, int searchDepth) {
76+
var neighborBatchBuilder = LongStream.builder();
77+
var localRandom = new Random(batchLocalSeed);
78+
79+
// sample a neighbor for each batchNode
80+
batch.consume(nodeId -> {
81+
// randomWalk with at most maxSearchDepth steps and only save last node
82+
int actualSearchDepth = localRandom.nextInt(searchDepth) + 1;
83+
AtomicLong currentNode = new AtomicLong(nodeId);
84+
while (actualSearchDepth > 0) {
85+
NeighborhoodSampler neighborhoodSampler = new NeighborhoodSampler(currentNode.get() + actualSearchDepth);
86+
OptionalLong maybeSample = neighborhoodSampler.sampleOne(graph, nodeId);
87+
if (maybeSample.isPresent()) {
88+
currentNode.set(maybeSample.getAsLong());
89+
} else {
90+
// terminate
91+
actualSearchDepth = 0;
92+
}
93+
actualSearchDepth--;
94+
}
95+
neighborBatchBuilder.add(currentNode.get());
96+
});
97+
98+
return neighborBatchBuilder.build();
99+
}
100+
101+
// get a negative sample per node in batch
102+
LongStream negativeBatch(int batchSize, long[] batchNeighbors, long batchLocalRandomSeed) {
103+
long nodeCount = graph.nodeCount();
104+
var sampler = new WeightedUniformSampler(batchLocalRandomSeed);
105+
106+
// avoid sampling the sampled neighbor as a negative example
107+
var neighborsSet = new LongHashSet(batchNeighbors.length);
108+
neighborsSet.addAll(batchNeighbors);
109+
110+
// each node should be possible to sample
111+
// therefore we need fictive rels to all nodes
112+
// Math.log to avoid always sampling the high degree nodes
113+
var degreeWeightedNodes = LongStream.range(0, nodeCount)
114+
.mapToObj(nodeId -> ImmutableRelationshipCursor.of(0, nodeId, Math.pow(graph.degree(nodeId),
115+
DEGREE_SMOOTHING_FACTOR
116+
)));
117+
118+
return sampler.sample(degreeWeightedNodes, nodeCount, batchSize, sample -> !neighborsSet.contains(sample));
119+
}
120+
}

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java

Lines changed: 12 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,12 @@
1919
*/
2020
package org.neo4j.gds.embeddings.graphsage;
2121

22-
import com.carrotsearch.hppc.LongHashSet;
2322
import org.immutables.value.Value;
2423
import org.neo4j.gds.annotation.ValueClass;
2524
import org.neo4j.gds.api.Graph;
26-
import org.neo4j.gds.api.ImmutableRelationshipCursor;
2725
import org.neo4j.gds.config.ToMapConvertible;
2826
import org.neo4j.gds.core.concurrency.ParallelUtil;
2927
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
30-
import org.neo4j.gds.core.utils.partition.Partition;
31-
import org.neo4j.gds.core.utils.partition.PartitionUtils;
3228
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3329
import org.neo4j.gds.core.utils.progress.tasks.Task;
3430
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
@@ -38,29 +34,23 @@
3834
import org.neo4j.gds.ml.core.features.FeatureExtraction;
3935
import org.neo4j.gds.ml.core.functions.Weights;
4036
import org.neo4j.gds.ml.core.optimizer.AdamOptimizer;
41-
import org.neo4j.gds.ml.core.samplers.WeightedUniformSampler;
42-
import org.neo4j.gds.ml.core.subgraph.NeighborhoodSampler;
4337
import org.neo4j.gds.ml.core.subgraph.SubGraph;
4438
import org.neo4j.gds.ml.core.tensor.Matrix;
4539
import org.neo4j.gds.ml.core.tensor.Scalar;
4640
import org.neo4j.gds.ml.core.tensor.Tensor;
4741

4842
import java.util.ArrayList;
49-
import java.util.Arrays;
5043
import java.util.Collection;
5144
import java.util.Collections;
5245
import java.util.List;
5346
import java.util.Map;
54-
import java.util.OptionalLong;
5547
import java.util.Random;
5648
import java.util.concurrent.ExecutorService;
5749
import java.util.concurrent.ThreadLocalRandom;
58-
import java.util.concurrent.atomic.AtomicLong;
5950
import java.util.function.Function;
6051
import java.util.function.Supplier;
6152
import java.util.stream.Collectors;
6253
import java.util.stream.IntStream;
63-
import java.util.stream.LongStream;
6454

6555
import static org.neo4j.gds.embeddings.graphsage.GraphSageHelper.embeddingsComputationGraph;
6656
import static org.neo4j.gds.ml.core.RelationshipWeights.UNWEIGHTED;
@@ -125,11 +115,14 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
125115

126116
progressTracker.beginSubTask("Prepare batches");
127117

128-
var batchTasks = PartitionUtils.rangePartitionWithBatchSize(
129-
graph.nodeCount(),
130-
config.batchSize(),
131-
batch -> createBatchTask(graph, features, layers, weights, batch)
132-
);
118+
var batchSampler = new BatchSampler(graph);
119+
120+
var batchTasks = batchSampler
121+
.extendedBatches(config.batchSize(), config.searchDepth(), randomSeed)
122+
.stream()
123+
.map(extendedBatch -> createBatchTask(extendedBatch, graph, features, layers, weights))
124+
.collect(Collectors.toList());
125+
133126
var random = new Random(randomSeed);
134127
Supplier<List<BatchTask>> batchTaskSampler = () -> IntStream.range(0, config.batchesPerIteration(graph.nodeCount()))
135128
.mapToObj(__ -> batchTasks.get(random.nextInt(batchTasks.size())))
@@ -160,17 +153,15 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
160153
}
161154

162155
private BatchTask createBatchTask(
156+
long[] extendedBatch,
163157
Graph graph,
164158
HugeObjectArray<double[]> features,
165159
Layer[] layers,
166-
ArrayList<Weights<? extends Tensor<?>>> weights,
167-
Partition batch
160+
ArrayList<Weights<? extends Tensor<?>>> weights
168161
) {
169162
var localGraph = graph.concurrentCopy();
170163

171-
long[] totalBatch = addSamplesPerBatchNode(batch, localGraph);
172-
173-
List<SubGraph> subGraphs = GraphSageHelper.subGraphsPerLayer(localGraph, useWeights, totalBatch, layers);
164+
List<SubGraph> subGraphs = GraphSageHelper.subGraphsPerLayer(localGraph, useWeights, extendedBatch, layers);
174165

175166
Variable<Matrix> batchedFeaturesExtractor = featureFunction.apply(
176167
localGraph,
@@ -183,7 +174,7 @@ private BatchTask createBatchTask(
183174
GraphSageLoss lossFunction = new GraphSageLoss(
184175
useWeights ? localGraph::relationshipProperty : UNWEIGHTED,
185176
embeddingVariable,
186-
totalBatch,
177+
extendedBatch,
187178
config.negativeSampleWeight()
188179
);
189180

@@ -281,68 +272,6 @@ List<? extends Tensor<?>> weightGradients() {
281272
}
282273
}
283274

284-
private long[] addSamplesPerBatchNode(Partition batch, Graph localGraph) {
285-
var batchLocalRandomSeed = getBatchIndex(batch, localGraph.nodeCount()) + randomSeed;
286-
287-
var neighbours = neighborBatch(localGraph, batch, batchLocalRandomSeed).toArray();
288-
289-
var neighborsSet = new LongHashSet(neighbours.length);
290-
neighborsSet.addAll(neighbours);
291-
292-
return LongStream.concat(
293-
batch.stream(),
294-
LongStream.concat(
295-
Arrays.stream(neighbours),
296-
// batch.nodeCount is <= config.batchsize (which is an int)
297-
negativeBatch(localGraph, Math.toIntExact(batch.nodeCount()), neighborsSet, batchLocalRandomSeed)
298-
)
299-
).toArray();
300-
}
301-
302-
LongStream neighborBatch(Graph graph, Partition batch, long batchLocalSeed) {
303-
var neighborBatchBuilder = LongStream.builder();
304-
var localRandom = new Random(batchLocalSeed);
305-
306-
// sample a neighbor for each batchNode
307-
batch.consume(nodeId -> {
308-
// randomWalk with at most maxSearchDepth steps and only save last node
309-
int searchDepth = localRandom.nextInt(config.searchDepth()) + 1;
310-
AtomicLong currentNode = new AtomicLong(nodeId);
311-
while (searchDepth > 0) {
312-
NeighborhoodSampler neighborhoodSampler = new NeighborhoodSampler(currentNode.get() + searchDepth);
313-
OptionalLong maybeSample = neighborhoodSampler.sampleOne(graph, nodeId);
314-
if (maybeSample.isPresent()) {
315-
currentNode.set(maybeSample.getAsLong());
316-
} else {
317-
// terminate
318-
searchDepth = 0;
319-
}
320-
searchDepth--;
321-
}
322-
neighborBatchBuilder.add(currentNode.get());
323-
});
324-
325-
return neighborBatchBuilder.build();
326-
}
327-
328-
// get a negative sample per node in batch
329-
LongStream negativeBatch(Graph graph, int batchSize, LongHashSet neighbours, long batchLocalRandomSeed) {
330-
long nodeCount = graph.nodeCount();
331-
var sampler = new WeightedUniformSampler(batchLocalRandomSeed);
332-
333-
// each node should be possible to sample
334-
// therefore we need fictive rels to all nodes
335-
// Math.log to avoid always sampling the high degree nodes
336-
var degreeWeightedNodes = LongStream.range(0, nodeCount)
337-
.mapToObj(nodeId -> ImmutableRelationshipCursor.of(0, nodeId, Math.pow(graph.degree(nodeId), 0.75)));
338-
339-
return sampler.sample(degreeWeightedNodes, nodeCount, batchSize, sample -> !neighbours.contains(sample));
340-
}
341-
342-
private static int getBatchIndex(Partition partition, long nodeCount) {
343-
return Math.toIntExact(Math.floorDiv(partition.startNode(), nodeCount));
344-
}
345-
346275
private static int firstLayerColumns(GraphSageTrainConfig config, Graph graph) {
347276
return config.projectedFeatureDimension().orElseGet(() -> {
348277
var featureExtractors = GraphSageHelper.featureExtractors(graph, config);
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.graphsage;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.api.Graph;
24+
import org.neo4j.gds.core.utils.partition.Partition;
25+
import org.neo4j.gds.core.utils.partition.PartitionUtils;
26+
import org.neo4j.gds.extension.GdlExtension;
27+
import org.neo4j.gds.extension.GdlGraph;
28+
import org.neo4j.gds.extension.Inject;
29+
import org.neo4j.gds.gdl.GdlFactory;
30+
31+
import java.util.function.Function;
32+
import java.util.stream.Collectors;
33+
34+
import static org.assertj.core.api.Assertions.assertThat;
35+
36+
@GdlExtension
37+
class BatchSamplerTest {
38+
39+
@GdlGraph
40+
public static final String GDL = GraphSageTestGraph.GDL;
41+
42+
@Inject
43+
Graph graph;
44+
45+
@Test
46+
void sampleDenseGraph() {
47+
Graph clique = GdlFactory.of("(a)-->(b), (b)-->(a), (b)-->(c), (c)-->(b), (c)-->(a), (a)-->(c)").build().getUnion();
48+
Partition allNodes = Partition.of(0, 2);
49+
int searchDepth = 3;
50+
51+
assertThat(new BatchSampler(clique).sampleNeighborAndNegativeNodePerBatchNode(allNodes, searchDepth, 42))
52+
.containsExactly(
53+
0L, 1L,
54+
2L, 2L,
55+
0L, 1L
56+
);
57+
}
58+
59+
60+
@Test
61+
void seededNegativeBatch() {
62+
var batchSize = 5;
63+
var seed = 20L;
64+
65+
var partitions = PartitionUtils.rangePartitionWithBatchSize(
66+
100,
67+
batchSize,
68+
Function.identity()
69+
);
70+
71+
long[] neighborsSet = {0, 3, 5, 6, 10};
72+
73+
for (int i = 0; i < partitions.size(); i++) {
74+
var localSeed = i + seed;
75+
var negativeBatch = new BatchSampler(graph).negativeBatch(Math.toIntExact(partitions.get(i).nodeCount()), neighborsSet, localSeed);
76+
var otherNegativeBatch = new BatchSampler(graph).negativeBatch(Math.toIntExact(partitions.get(i).nodeCount()), neighborsSet, localSeed);
77+
78+
assertThat(negativeBatch).containsExactlyElementsOf(otherNegativeBatch.boxed().collect(Collectors.toList()));
79+
}
80+
}
81+
82+
@Test
83+
void seededNeighborBatch() {
84+
var batchSize = 5;
85+
var seed = 20L;
86+
int searchDepth = 12;
87+
88+
var partitions = PartitionUtils.rangePartitionWithBatchSize(
89+
graph.nodeCount(),
90+
batchSize,
91+
Function.identity()
92+
);
93+
94+
for (int i = 0; i < partitions.size(); i++) {
95+
var localSeed = i + seed;
96+
var neighborBatch = new BatchSampler(graph).neighborBatch(partitions.get(i), localSeed, searchDepth);
97+
var otherNeighborBatch = new BatchSampler(graph).neighborBatch(partitions.get(i), localSeed, searchDepth);
98+
assertThat(neighborBatch).containsExactlyElementsOf(otherNeighborBatch.boxed().collect(Collectors.toList()));
99+
}
100+
}
101+
}

0 commit comments

Comments
 (0)