1919 */
2020package org .neo4j .gds .embeddings .graphsage ;
2121
22- import com .carrotsearch .hppc .LongHashSet ;
2322import org .immutables .value .Value ;
2423import org .neo4j .gds .annotation .ValueClass ;
2524import org .neo4j .gds .api .Graph ;
26- import org .neo4j .gds .api .ImmutableRelationshipCursor ;
2725import org .neo4j .gds .config .ToMapConvertible ;
2826import org .neo4j .gds .core .concurrency .ParallelUtil ;
2927import org .neo4j .gds .core .utils .paged .HugeObjectArray ;
30- import org .neo4j .gds .core .utils .partition .Partition ;
31- import org .neo4j .gds .core .utils .partition .PartitionUtils ;
3228import org .neo4j .gds .core .utils .progress .tasks .ProgressTracker ;
3329import org .neo4j .gds .core .utils .progress .tasks .Task ;
3430import org .neo4j .gds .core .utils .progress .tasks .Tasks ;
3834import org .neo4j .gds .ml .core .features .FeatureExtraction ;
3935import org .neo4j .gds .ml .core .functions .Weights ;
4036import org .neo4j .gds .ml .core .optimizer .AdamOptimizer ;
41- import org .neo4j .gds .ml .core .samplers .WeightedUniformSampler ;
42- import org .neo4j .gds .ml .core .subgraph .NeighborhoodSampler ;
4337import org .neo4j .gds .ml .core .subgraph .SubGraph ;
4438import org .neo4j .gds .ml .core .tensor .Matrix ;
4539import org .neo4j .gds .ml .core .tensor .Scalar ;
4640import org .neo4j .gds .ml .core .tensor .Tensor ;
4741
4842import java .util .ArrayList ;
49- import java .util .Arrays ;
5043import java .util .Collection ;
5144import java .util .Collections ;
5245import java .util .List ;
5346import java .util .Map ;
54- import java .util .OptionalLong ;
5547import java .util .Random ;
5648import java .util .concurrent .ExecutorService ;
5749import java .util .concurrent .ThreadLocalRandom ;
58- import java .util .concurrent .atomic .AtomicLong ;
5950import java .util .function .Function ;
6051import java .util .function .Supplier ;
6152import java .util .stream .Collectors ;
6253import java .util .stream .IntStream ;
63- import java .util .stream .LongStream ;
6454
6555import static org .neo4j .gds .embeddings .graphsage .GraphSageHelper .embeddingsComputationGraph ;
6656import static org .neo4j .gds .ml .core .RelationshipWeights .UNWEIGHTED ;
@@ -125,11 +115,14 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
125115
126116 progressTracker .beginSubTask ("Prepare batches" );
127117
128- var batchTasks = PartitionUtils .rangePartitionWithBatchSize (
129- graph .nodeCount (),
130- config .batchSize (),
131- batch -> createBatchTask (graph , features , layers , weights , batch )
132- );
118+ var batchSampler = new BatchSampler (graph );
119+
120+ var batchTasks = batchSampler
121+ .extendedBatches (config .batchSize (), config .searchDepth (), randomSeed )
122+ .stream ()
123+ .map (extendedBatch -> createBatchTask (extendedBatch , graph , features , layers , weights ))
124+ .collect (Collectors .toList ());
125+
133126 var random = new Random (randomSeed );
134127 Supplier <List <BatchTask >> batchTaskSampler = () -> IntStream .range (0 , config .batchesPerIteration (graph .nodeCount ()))
135128 .mapToObj (__ -> batchTasks .get (random .nextInt (batchTasks .size ())))
@@ -160,17 +153,15 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
160153 }
161154
162155 private BatchTask createBatchTask (
156+ long [] extendedBatch ,
163157 Graph graph ,
164158 HugeObjectArray <double []> features ,
165159 Layer [] layers ,
166- ArrayList <Weights <? extends Tensor <?>>> weights ,
167- Partition batch
160+ ArrayList <Weights <? extends Tensor <?>>> weights
168161 ) {
169162 var localGraph = graph .concurrentCopy ();
170163
171- long [] totalBatch = addSamplesPerBatchNode (batch , localGraph );
172-
173- List <SubGraph > subGraphs = GraphSageHelper .subGraphsPerLayer (localGraph , useWeights , totalBatch , layers );
164+ List <SubGraph > subGraphs = GraphSageHelper .subGraphsPerLayer (localGraph , useWeights , extendedBatch , layers );
174165
175166 Variable <Matrix > batchedFeaturesExtractor = featureFunction .apply (
176167 localGraph ,
@@ -183,7 +174,7 @@ private BatchTask createBatchTask(
183174 GraphSageLoss lossFunction = new GraphSageLoss (
184175 useWeights ? localGraph ::relationshipProperty : UNWEIGHTED ,
185176 embeddingVariable ,
186- totalBatch ,
177+ extendedBatch ,
187178 config .negativeSampleWeight ()
188179 );
189180
@@ -281,68 +272,6 @@ List<? extends Tensor<?>> weightGradients() {
281272 }
282273 }
283274
284- private long [] addSamplesPerBatchNode (Partition batch , Graph localGraph ) {
285- var batchLocalRandomSeed = getBatchIndex (batch , localGraph .nodeCount ()) + randomSeed ;
286-
287- var neighbours = neighborBatch (localGraph , batch , batchLocalRandomSeed ).toArray ();
288-
289- var neighborsSet = new LongHashSet (neighbours .length );
290- neighborsSet .addAll (neighbours );
291-
292- return LongStream .concat (
293- batch .stream (),
294- LongStream .concat (
295- Arrays .stream (neighbours ),
296- // batch.nodeCount is <= config.batchsize (which is an int)
297- negativeBatch (localGraph , Math .toIntExact (batch .nodeCount ()), neighborsSet , batchLocalRandomSeed )
298- )
299- ).toArray ();
300- }
301-
302- LongStream neighborBatch (Graph graph , Partition batch , long batchLocalSeed ) {
303- var neighborBatchBuilder = LongStream .builder ();
304- var localRandom = new Random (batchLocalSeed );
305-
306- // sample a neighbor for each batchNode
307- batch .consume (nodeId -> {
308- // randomWalk with at most maxSearchDepth steps and only save last node
309- int searchDepth = localRandom .nextInt (config .searchDepth ()) + 1 ;
310- AtomicLong currentNode = new AtomicLong (nodeId );
311- while (searchDepth > 0 ) {
312- NeighborhoodSampler neighborhoodSampler = new NeighborhoodSampler (currentNode .get () + searchDepth );
313- OptionalLong maybeSample = neighborhoodSampler .sampleOne (graph , nodeId );
314- if (maybeSample .isPresent ()) {
315- currentNode .set (maybeSample .getAsLong ());
316- } else {
317- // terminate
318- searchDepth = 0 ;
319- }
320- searchDepth --;
321- }
322- neighborBatchBuilder .add (currentNode .get ());
323- });
324-
325- return neighborBatchBuilder .build ();
326- }
327-
328- // get a negative sample per node in batch
329- LongStream negativeBatch (Graph graph , int batchSize , LongHashSet neighbours , long batchLocalRandomSeed ) {
330- long nodeCount = graph .nodeCount ();
331- var sampler = new WeightedUniformSampler (batchLocalRandomSeed );
332-
333- // each node should be possible to sample
334- // therefore we need fictive rels to all nodes
335- // Math.log to avoid always sampling the high degree nodes
336- var degreeWeightedNodes = LongStream .range (0 , nodeCount )
337- .mapToObj (nodeId -> ImmutableRelationshipCursor .of (0 , nodeId , Math .pow (graph .degree (nodeId ), 0.75 )));
338-
339- return sampler .sample (degreeWeightedNodes , nodeCount , batchSize , sample -> !neighbours .contains (sample ));
340- }
341-
342- private static int getBatchIndex (Partition partition , long nodeCount ) {
343- return Math .toIntExact (Math .floorDiv (partition .startNode (), nodeCount ));
344- }
345-
346275 private static int firstLayerColumns (GraphSageTrainConfig config , Graph graph ) {
347276 return config .projectedFeatureDimension ().orElseGet (() -> {
348277 var featureExtractors = GraphSageHelper .featureExtractors (graph , config );
0 commit comments