Skip to content

Commit dbdc9d7

Browse files
authored
Merge pull request #5294 from FlorentinD/gs-loss-per-iteration
GraphSage return loss per iteration
2 parents 4644250 + 434a177 commit dbdc9d7

File tree

5 files changed

+64
-33
lines changed

5 files changed

+64
-33
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -145,28 +145,22 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
145145

146146
progressTracker.endSubTask("Prepare batches");
147147

148-
double previousLoss = Double.MAX_VALUE;
149148
boolean converged = false;
150-
var epochLosses = new ArrayList<Double>();
149+
var iterationLossesPerEpoch = new ArrayList<List<Double>>();
151150

152151
progressTracker.beginSubTask("Train model");
153152

154-
for (int epoch = 1; epoch <= epochs; epoch++) {
153+
for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
155154
progressTracker.beginSubTask("Epoch");
156-
157-
double newLoss = trainEpoch(batchTasks, weights);
158-
epochLosses.add(newLoss);
155+
var epochResult = trainEpoch(batchTasks, weights);
156+
iterationLossesPerEpoch.add(epochResult.losses());
157+
converged = epochResult.converged();
159158
progressTracker.endSubTask("Epoch");
160-
if (Math.abs((newLoss - previousLoss) / previousLoss) < tolerance) {
161-
converged = true;
162-
break;
163-
}
164-
previousLoss = newLoss;
165159
}
166160

167161
progressTracker.endSubTask("Train model");
168162

169-
return ModelTrainResult.of(epochLosses, converged, layers);
163+
return ModelTrainResult.of(iterationLossesPerEpoch, converged, layers);
170164
}
171165

172166
private BatchTask createBatchTask(
@@ -200,19 +194,22 @@ private BatchTask createBatchTask(
200194
return new BatchTask(lossFunction, weights, tolerance, progressTracker);
201195
}
202196

203-
private double trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
197+
private EpochResult trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
204198
var updater = new AdamOptimizer(weights, learningRate);
205199

206-
double totalLoss = Double.NaN;
207200
int iteration = 1;
201+
var iterationLosses = new ArrayList<Double>();
202+
var converged = false;
203+
208204
for (;iteration <= maxIterations; iteration++) {
209205
progressTracker.beginSubTask("Iteration");
210206

211207
// run forward + maybe backward for each Batch
212208
ParallelUtil.runWithConcurrency(concurrency, batchTasks, executor);
213-
totalLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
209+
var avgLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
210+
iterationLosses.add(avgLoss);
214211

215-
var converged = batchTasks.stream().allMatch(task -> task.converged);
212+
converged = batchTasks.stream().allMatch(task -> task.converged);
216213
if (converged) {
217214
progressTracker.endSubTask();
218215
break;
@@ -227,12 +224,18 @@ private double trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Ten
227224

228225
updater.update(meanGradients);
229226

230-
progressTracker.logMessage(formatWithLocale("LOSS: %.10f", totalLoss));
231-
227+
progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));
232228
progressTracker.endSubTask("Iteration");
233229
}
234230

235-
return totalLoss;
231+
return ImmutableEpochResult.of(converged, iterationLosses);
232+
}
233+
234+
@ValueClass
235+
interface EpochResult {
236+
boolean converged();
237+
238+
List<Double> losses();
236239
}
237240

238241
static class BatchTask implements Runnable {
@@ -359,14 +362,27 @@ static GraphSageTrainMetrics empty() {
359362
return ImmutableGraphSageTrainMetrics.of(List.of(), false);
360363
}
361364

362-
List<Double> epochLosses();
365+
@Value.Derived
366+
default List<Double> epochLosses() {
367+
return iterationLossPerEpoch().stream()
368+
.map(iterationLosses -> iterationLosses.get(iterationLosses.size() - 1))
369+
.collect(Collectors.toList());
370+
}
371+
372+
List<List<Double>> iterationLossPerEpoch();
373+
363374
boolean didConverge();
364375

365376
@Value.Derived
366377
default int ranEpochs() {
367-
return epochLosses().isEmpty()
378+
return iterationLossPerEpoch().isEmpty()
368379
? 0
369-
: epochLosses().size();
380+
: iterationLossPerEpoch().size();
381+
}
382+
383+
@Value.Derived
384+
default List<Integer> ranIterationsPerEpoch() {
385+
return iterationLossPerEpoch().stream().map(List::size).collect(Collectors.toList());
370386
}
371387

372388
@Override
@@ -376,8 +392,10 @@ default Map<String, Object> toMap() {
376392
return Map.of(
377393
"metrics", Map.of(
378394
"epochLosses", epochLosses(),
395+
"iterationLossesPerEpoch", iterationLossPerEpoch(),
379396
"didConverge", didConverge(),
380-
"ranEpochs", ranEpochs()
397+
"ranEpochs", ranEpochs(),
398+
"ranIterationsPerEpoch", ranIterationsPerEpoch()
381399
));
382400
}
383401
}
@@ -390,13 +408,13 @@ public interface ModelTrainResult {
390408
Layer[] layers();
391409

392410
static ModelTrainResult of(
393-
List<Double> epochLosses,
411+
List<List<Double>> iterationLossesPerEpoch,
394412
boolean converged,
395413
Layer[] layers
396414
) {
397415
return ImmutableModelTrainResult.builder()
398416
.layers(layers)
399-
.metrics(ImmutableGraphSageTrainMetrics.of(epochLosses, converged))
417+
.metrics(ImmutableGraphSageTrainMetrics.of(iterationLossesPerEpoch, converged))
400418
.build();
401419
}
402420
}

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ void testLosses() {
219219
var metrics = trainResult.metrics();
220220
assertThat(metrics.didConverge()).isFalse();
221221
assertThat(metrics.ranEpochs()).isEqualTo(10);
222+
assertThat(metrics.ranIterationsPerEpoch()).containsExactly(100, 100, 100, 100, 100, 100, 100, 100, 100, 100);
222223

223224
var metricsMap = metrics.toMap().get("metrics");
224225
assertThat(metricsMap).isInstanceOf(Map.class);
@@ -266,6 +267,7 @@ void testLossesWithPoolAggregator() {
266267
var metrics = trainResult.metrics();
267268
assertThat(metrics.didConverge()).isFalse();
268269
assertThat(metrics.ranEpochs()).isEqualTo(10);
270+
assertThat(metrics.ranIterationsPerEpoch()).containsExactly(100, 100, 100, 100, 100, 100, 100, 100, 100, 100);
269271

270272
var metricsMap = metrics.toMap().get("metrics");
271273
assertThat(metricsMap).isInstanceOf(Map.class);
@@ -301,6 +303,7 @@ void testConvergence() {
301303
var trainMetrics = trainResult.metrics();
302304
assertThat(trainMetrics.didConverge()).isTrue();
303305
assertThat(trainMetrics.ranEpochs()).isEqualTo(1);
306+
assertThat(trainMetrics.ranIterationsPerEpoch()).containsExactly(2);
304307
}
305308

306309
@ParameterizedTest

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainAlgorithmFactoryTest.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import static org.eclipse.collections.impl.tuple.primitive.PrimitiveTuples.pair;
5959
import static org.junit.jupiter.api.Assertions.assertEquals;
6060
import static org.junit.jupiter.params.provider.Arguments.arguments;
61+
import static org.neo4j.gds.assertj.Extractors.keepingFixedNumberOfDecimals;
6162
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
6263
import static org.neo4j.gds.compat.TestLog.INFO;
6364
import static org.neo4j.gds.core.utils.mem.MemoryEstimations.RESIDENT_MEMORY;
@@ -461,13 +462,15 @@ void memoryEstimationTreeStructure(boolean isMultiLabel) {
461462
void testLogging() {
462463
var config = ImmutableGraphSageTrainConfig.builder()
463464
.addFeatureProperties(DUMMY_PROPERTY)
464-
.embeddingDimension(64)
465+
.embeddingDimension(12)
466+
.aggregator(Aggregator.AggregatorType.POOL)
467+
.tolerance(1e-10)
468+
.sampleSizes(List.of(5, 3))
469+
.batchSize(5)
470+
.randomSeed(42L)
465471
.modelName("model")
466472
.epochs(2)
467473
.maxIterations(2)
468-
.tolerance(1e-10)
469-
.learningRate(0.001)
470-
.randomSeed(42L)
471474
.build();
472475

473476
var log = Neo4jProxy.testLog();
@@ -485,6 +488,7 @@ void testLogging() {
485488
AssertionsForInterfaceTypes.assertThat(messagesInOrder)
486489
// avoid asserting on the thread id
487490
.extracting(removingThreadId())
491+
.extracting(keepingFixedNumberOfDecimals(2))
488492
.containsExactly(
489493
"GraphSageTrain :: Start",
490494
"GraphSageTrain :: Prepare batches :: Start",
@@ -493,17 +497,23 @@ void testLogging() {
493497
"GraphSageTrain :: Train model :: Start",
494498
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Start",
495499
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: Start",
496-
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: LOSS: 531.5699087433",
500+
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: LOSS: 132.63",
497501
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 100%",
498502
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: Finished",
499503
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: Start",
504+
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: LOSS: 129.13",
500505
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 100%",
501506
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: Finished",
502507
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Finished",
503508
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Start",
504509
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: Start",
510+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: LOSS: 123.38",
505511
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 100%",
506512
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: Finished",
513+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: Start",
514+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: LOSS: 116.06",
515+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 100%",
516+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: Finished",
507517
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Finished",
508518
"GraphSageTrain :: Train model :: Finished",
509519
"GraphSageTrain :: Finished"

doc/asciidoc/machine-learning/node-embeddings/graph-sage/graph-sage.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ RETURN
278278
.Results
279279
|===
280280
| modelName | didConverge | ranEpochs | epochLosses
281-
| "exampleTrainModel" | false | 1 | [186.04946807210226]
281+
| "exampleTrainModel" | true | 1 | [186.04946807210226]
282282
|===
283283
--
284284

doc/asciidoc/machine-learning/node-embeddings/graph-sage/specific-train-configuration.adoc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
| sampleSizes | List of Integer | [25, 10] | yes | A list of Integer values, the size of the list determines the number of layers and the values determine how many nodes will be sampled by the layers.
1111
| projectedFeatureDimension | Integer | n/a | yes | The dimension of the projected `featureProperties`. This enables multi-label GraphSage, where each label can have a subset of the `featureProperties`.
1212
| batchSize | Integer | 100 | yes | The number of nodes per batch.
13-
| <<common-configuration-tolerance,tolerance>> | Float | 1e-4 | yes | Tolerance used for the early convergence of an epoch.
13+
| <<common-configuration-tolerance,tolerance>> | Float | 1e-4 | yes | Tolerance used for the early convergence of an epoch, which is checked after each iteration.
1414
| learningRate | Float | 0.1 | yes | The learning rate determines the step size at each iteration while moving toward a minimum of a loss function.
1515
| epochs | Integer | 1 | yes | Number of times to traverse the graph.
16-
| <<common-configuration-max-iterations,maxIterations>> | Integer | 10 | yes | Maximum number of weight updates per batch. Batches can also converge early based on `tolerance`.
16+
| <<common-configuration-max-iterations,maxIterations>> | Integer | 10 | yes | Maximum number of iterations per epoch. In each iteration, the weights are updated.
1717
| searchDepth | Integer | 5 | yes | Maximum depth of the RandomWalks to sample nearby nodes for the training.
1818
| negativeSampleWeight | Integer | 20 | yes | The weight of the negative samples. Higher values increase the impact of negative samples in the loss.
1919
| <<common-configuration-relationship-weight-property,relationshipWeightProperty>> | String | null | yes | Name of the relationship property to use as weights. If unspecified, the algorithm runs unweighted.

0 commit comments

Comments
 (0)