Commit f59afea

Consider only #concurrency batches per iteration
Drastically lowers the runtime of the algorithm and even gives better quality in most cases. Reasoning: before, we averaged the gradient over all batches, which was just noisy. Using an approximate gradient over a small sample of batches keeps us close enough to the true gradient while saving a lot of time.
1 parent d8ee182 commit f59afea
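
To make the reasoning concrete, here is a minimal, self-contained sketch of the sampling scheme (an illustration only, not the GDS implementation: Batch, the quadratic loss, and plain SGD in place of Adam are all invented for this example). Each iteration draws only concurrency batches at random, with replacement, and averages just their gradients, a cheap stochastic estimate of the full-batch gradient:

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Sketch of "consider only #concurrency batches per iteration": sample a
// handful of batches, average only their gradients, and take one step.
public final class ApproximateGradientSketch {

    record Batch(double target) {}

    public static void main(String[] args) {
        int concurrency = 4;
        var random = new Random(42L);

        List<Batch> allBatches = new ArrayList<>();
        for (int i = 0; i < 64; i++) {
            allBatches.add(new Batch(i % 8));
        }

        double weight = 0.0;
        for (int iteration = 0; iteration < 200; iteration++) {
            // Instead of averaging over all 64 batches, draw `concurrency` of them.
            double gradientSum = 0.0;
            for (int i = 0; i < concurrency; i++) {
                Batch batch = allBatches.get(random.nextInt(allBatches.size()));
                gradientSum += 2 * (weight - batch.target()); // d/dw of (w - target)^2
            }
            weight -= 0.05 * (gradientSum / concurrency); // SGD step on the sampled mean
        }
        // Ends near the mean target (3.5) without ever touching all batches at once.
        System.out.printf("learned weight: %.3f%n", weight);
    }
}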

File tree

4 files changed: +56 additions, -45 deletions

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java

Lines changed: 23 additions & 14 deletions
@@ -57,7 +57,9 @@
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import java.util.stream.LongStream;
 
 import static org.neo4j.gds.embeddings.graphsage.GraphSageHelper.embeddingsComputationGraph;
@@ -148,12 +150,17 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
         boolean converged = false;
         var iterationLossesPerEpoch = new ArrayList<List<Double>>();
 
+        var prevEpochLoss = Double.NaN;
+        var random = new Random(randomSeed);
+
         progressTracker.beginSubTask("Train model");
 
         for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
             progressTracker.beginSubTask("Epoch");
-            var epochResult = trainEpoch(batchTasks, weights);
-            iterationLossesPerEpoch.add(epochResult.losses());
+            var epochResult = trainEpoch(() -> batchTasks.get(random.nextInt(batchTasks.size())), weights, prevEpochLoss);
+            List<Double> epochLosses = epochResult.losses();
+            iterationLossesPerEpoch.add(epochLosses);
+            prevEpochLoss = epochLosses.get(epochLosses.size() - 1);
             converged = epochResult.converged();
             progressTracker.endSubTask("Epoch");
         }
@@ -194,27 +201,37 @@ private BatchTask createBatchTask(
         return new BatchTask(lossFunction, weights, tolerance, progressTracker);
     }
 
-    private EpochResult trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
+    private EpochResult trainEpoch(Supplier<BatchTask> batchTaskSupplier, List<Weights<? extends Tensor<?>>> weights, double prevEpochLoss) {
         var updater = new AdamOptimizer(weights, learningRate);
 
         int iteration = 1;
         var iterationLosses = new ArrayList<Double>();
+        double prevLoss = prevEpochLoss;
         var converged = false;
 
         for (;iteration <= maxIterations; iteration++) {
             progressTracker.beginSubTask("Iteration");
 
+            // TODO let the user configure the number of batches per iteration
+            var batchTasks = IntStream
+                .range(0, concurrency)
+                .mapToObj(__ -> batchTaskSupplier.get())
+                .collect(Collectors.toList());
+
             // run forward + maybe backward for each Batch
             ParallelUtil.runWithConcurrency(concurrency, batchTasks, executor);
             var avgLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
             iterationLosses.add(avgLoss);
+            progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));
 
-            converged = batchTasks.stream().allMatch(task -> task.converged);
-            if (converged) {
-                progressTracker.endSubTask();
+            if (Math.abs(prevLoss - avgLoss) < tolerance) {
+                converged = true;
+                progressTracker.endSubTask("Iteration");
                 break;
             }
 
+            prevLoss = avgLoss;
+
             var batchedGradients = batchTasks
                 .stream()
                 .map(BatchTask::weightGradients)
@@ -223,8 +240,6 @@ private EpochResult trainEpoch(List<Weights<? extend
             var meanGradients = averageTensors(batchedGradients);
 
             updater.update(meanGradients);
-
-            progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));
             progressTracker.endSubTask("Iteration");
         }
 
@@ -245,7 +260,6 @@ static class BatchTask implements Runnable {
         private List<? extends Tensor<?>> weightGradients;
         private final double tolerance;
         private final ProgressTracker progressTracker;
-        private boolean converged;
         private double prevLoss;
 
         BatchTask(
@@ -262,14 +276,9 @@ static class BatchTask implements Runnable {
 
         @Override
         public void run() {
-            if(converged) { // Don't try to go further
-                return;
-            }
-
             var localCtx = new ComputationContext();
             var loss = localCtx.forward(lossFunction).value();
 
-            converged = Math.abs(prevLoss - loss) < tolerance;
             prevLoss = loss;
 
             localCtx.backward(lossFunction);
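
The diff above also changes where convergence is decided: instead of each BatchTask tracking its own converged flag, the trainer now compares the averaged iteration loss against the previous iteration's loss, seeding the first iteration of each epoch with prevEpochLoss from the epoch before. A condensed, runnable sketch of just that logic (an illustration, not the trainer itself; computeAverageLoss is a toy stand-in for the forward pass over the sampled batches):

import java.util.Random;

public final class ConvergenceSketch {

    public static void main(String[] args) {
        double tolerance = 1e-3;
        int maxIterations = 100;
        var random = new Random(42L);

        // prevEpochLoss starts as NaN; NaN comparisons are always false, so the
        // very first iteration can never report convergence spuriously.
        double prevLoss = Double.NaN;
        boolean converged = false;
        int iteration = 1;
        for (; iteration <= maxIterations; iteration++) {
            double avgLoss = computeAverageLoss(iteration, random);
            if (Math.abs(prevLoss - avgLoss) < tolerance) {
                converged = true;
                break;
            }
            prevLoss = avgLoss;
        }
        System.out.println("converged at iteration " + iteration + ": " + converged);
    }

    // Decaying loss with a little noise, mimicking randomly sampled batches.
    private static double computeAverageLoss(int iteration, Random random) {
        return 1.0 / iteration + random.nextDouble() * 1e-5;
    }
}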

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java

Lines changed: 21 additions & 21 deletions
@@ -228,17 +228,17 @@ void testLosses() {
         assertThat(epochLosses).isInstanceOf(List.class);
         assertThat(((List<Double>) epochLosses).stream().mapToDouble(Double::doubleValue).toArray())
             .contains(new double[]{
-                91.33327272,
-                88.17940500,
-                87.68340477,
-                85.60797746,
-                85.59108701,
-                85.59007234,
-                81.44403525,
-                81.44260858,
-                81.44349342,
-                81.45612978
-            }, Offset.offset(1e-8)
+                78.30,
+                71.55,
+                71.07,
+                71.65,
+                74.36,
+                74.08,
+                73.98,
+                80.28,
+                71.07,
+                71.07
+            }, Offset.offset(0.05)
             );
     }
 
@@ -276,16 +276,16 @@ void testLossesWithPoolAggregator() {
         assertThat(epochLosses).isInstanceOf(List.class);
         assertThat(((List<Double>) epochLosses).stream().mapToDouble(Double::doubleValue).toArray())
             .contains(new double[]{
-                90.53,
-                83.29,
-                74.75,
-                74.61,
-                74.68,
-                74.54,
-                74.46,
-                74.47,
-                74.41,
-                74.41
+                87.34,
+                80.75,
+                74.07,
+                93.12,
+                96.36,
+                80.50,
+                77.31,
+                99.70,
+                83.60,
+                83.60
             }, Offset.offset(0.05)
             );
     }

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainAlgorithmFactoryTest.java

Lines changed: 4 additions & 2 deletions
@@ -460,15 +460,17 @@ void memoryEstimationTreeStructure(boolean isMultiLabel) {
 
     @Test
     void testLogging() {
-        var config = ImmutableGraphSageTrainConfig.builder()
-            .addFeatureProperties(DUMMY_PROPERTY)
+        var config = GraphSageTrainConfigImpl.builder()
+            .username("DUMMY")
+            .featureProperties(List.of(DUMMY_PROPERTY))
             .embeddingDimension(12)
             .aggregator(Aggregator.AggregatorType.POOL)
             .tolerance(1e-10)
             .sampleSizes(List.of(5, 3))
             .batchSize(5)
             .randomSeed(42L)
             .modelName("model")
+            .activationFunction("RELU")
             .epochs(2)
             .maxIterations(2)
             .build();

doc/asciidoc/machine-learning/node-embeddings/graph-sage/graph-sage.adoc

Lines changed: 8 additions & 8 deletions
@@ -278,7 +278,7 @@ RETURN
 .Results
 |===
 | modelName | didConverge | ranEpochs | epochLosses
-| "exampleTrainModel" | true | 1 | [186.04946807210226]
+| "exampleTrainModel" | true | 1 | [186.0494680638198]
 |===
 --
 
@@ -504,13 +504,13 @@ YIELD nodeId, embedding
 .Results
 |===
 | nodeId | embedding
-| 0 | [0.528500243954147, 0.46821819122905217, 0.7081378518617193]
-| 1 | [0.5285002439545966, 0.4682181912292858, 0.7081378518612291]
-| 2 | [0.5285002439541305, 0.4682181912290437, 0.7081378518617372]
-| 3 | [0.528500243952747, 0.46821819122832464, 0.7081378518632452]
-| 4 | [0.5285002439970667, 0.46821819125135444, 0.7081378518149409]
-| 5 | [0.5285002440594959, 0.46821819128379416, 0.7081378517468996]
-| 6 | [0.528500243952941, 0.46821819122842556, 0.7081378518630335]
+| 0 | [0.5285002294775042, 0.46821819621782496, 0.7081378593674258]
+| 1 | [0.5285002294779538, 0.4682181962180586, 0.7081378593669356]
+| 2 | [0.5285002294774878, 0.46821819621781646, 0.7081378593674437]
+| 3 | [0.5285002294761042, 0.4682181962170975, 0.7081378593689517]
+| 4 | [0.5285002295204241, 0.4682181962401272, 0.7081378593206474]
+| 5 | [0.528500229582853, 0.468218196272567, 0.7081378592526062]
+| 6 | [0.5285002294762983, 0.4682181962171984, 0.7081378593687399]
 |===
 --