
Commit 4ee2878

Return loss per iteration
For more fine-grained tuning
1 parent 722b475 commit 4ee2878

File tree

2 files changed: +37 -16 lines changed


algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java

Lines changed: 34 additions & 16 deletions
@@ -147,15 +147,18 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
 
         double previousLoss = Double.MAX_VALUE;
         boolean converged = false;
-        var epochLosses = new ArrayList<Double>();
+        var iterationLossesPerEpoch = new ArrayList<List<Double>>();
 
         progressTracker.beginSubTask("Train model");
 
         for (int epoch = 1; epoch <= epochs; epoch++) {
             progressTracker.beginSubTask("Epoch");
 
-            double newLoss = trainEpoch(batchTasks, weights);
-            epochLosses.add(newLoss);
+            var iterationLosses = trainEpoch(batchTasks, weights);
+            iterationLossesPerEpoch.add(iterationLosses);
+            var newLoss = iterationLosses.get(iterationLosses.size() - 1);
+
             progressTracker.endSubTask("Epoch");
             if (Math.abs((newLoss - previousLoss) / previousLoss) < tolerance) {
                 converged = true;
@@ -166,7 +169,7 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
 
         progressTracker.endSubTask("Train model");
 
-        return ModelTrainResult.of(epochLosses, converged, layers);
+        return ModelTrainResult.of(iterationLossesPerEpoch, converged, layers);
     }
 
     private BatchTask createBatchTask(
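For context on the convergence test left unchanged above: a minimal, standalone sketch of the relative-change check that train() applies, now to each epoch's final iteration loss. The tolerance and loss values here are toy numbers, not taken from the codebase.

    class ToleranceCheckDemo {
        public static void main(String[] args) {
            double tolerance = 1e-4;        // assumed value, for illustration only
            double previousLoss = 0.40010;  // final iteration loss of the previous epoch
            double newLoss = 0.40008;       // final iteration loss of the current epoch
            // Same relative-change test as in train(): training stops once the
            // epoch-over-epoch improvement drops below the tolerance.
            boolean converged = Math.abs((newLoss - previousLoss) / previousLoss) < tolerance;
            System.out.println(converged);  // true: relative change is roughly 5e-5
        }
    }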
@@ -200,17 +203,18 @@ private BatchTask createBatchTask(
         return new BatchTask(lossFunction, weights, tolerance, progressTracker);
     }
 
-    private double trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
+    private List<Double> trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
         var updater = new AdamOptimizer(weights, learningRate);
 
-        double totalLoss = Double.NaN;
         int iteration = 1;
+        var iterationLosses = new ArrayList<Double>();
         for (;iteration <= maxIterations; iteration++) {
             progressTracker.beginSubTask("Iteration");
 
             // run forward + maybe backward for each Batch
             ParallelUtil.runWithConcurrency(concurrency, batchTasks, executor);
-            totalLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
+            var avgLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
+            iterationLosses.add(avgLoss);
 
             var converged = batchTasks.stream().allMatch(task -> task.converged);
             if (converged) {
@@ -227,12 +231,11 @@ private double trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Ten
 
             updater.update(meanGradients);
 
-            progressTracker.logMessage(formatWithLocale("LOSS: %.10f", totalLoss));
-
+            progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));
             progressTracker.endSubTask("Iteration");
         }
 
-        return totalLoss;
+        return iterationLosses;
     }
 
     static class BatchTask implements Runnable {
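Taken together, the two hunks above change trainEpoch's contract: it now returns the averaged batch loss of every optimizer iteration rather than only the final one. A standalone sketch of that pattern, with a toy loss supplier standing in (as an assumption) for the real BatchTask forward passes:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.IntToDoubleFunction;

    class PerIterationLossSketch {
        // Stand-in for trainEpoch: records the average loss of every iteration
        // instead of keeping only the last value.
        static List<Double> trainEpoch(int maxIterations, IntToDoubleFunction avgBatchLoss) {
            var iterationLosses = new ArrayList<Double>();
            for (int iteration = 1; iteration <= maxIterations; iteration++) {
                iterationLosses.add(avgBatchLoss.applyAsDouble(iteration));
            }
            return iterationLosses;
        }

        public static void main(String[] args) {
            var losses = trainEpoch(4, i -> 1.0 / i); // toy, monotonically decreasing loss
            System.out.println(losses); // [1.0, 0.5, 0.3333333333333333, 0.25]
        }
    }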
@@ -359,14 +362,27 @@ static GraphSageTrainMetrics empty() {
         return ImmutableGraphSageTrainMetrics.of(List.of(), false);
     }
 
-    List<Double> epochLosses();
+    @Value.Derived
+    default List<Double> epochLosses() {
+        return iterationLossPerEpoch().stream()
+            .map(iterationLosses -> iterationLosses.get(iterationLosses.size() - 1))
+            .collect(Collectors.toList());
+    }
+
+    List<List<Double>> iterationLossPerEpoch();
+
     boolean didConverge();
 
     @Value.Derived
     default int ranEpochs() {
-        return epochLosses().isEmpty()
+        return iterationLossPerEpoch().isEmpty()
             ? 0
-            : epochLosses().size();
+            : iterationLossPerEpoch().size();
+    }
+
+    @Value.Derived
+    default List<Integer> ranIterationsPerEpoch() {
+        return iterationLossPerEpoch().stream().map(List::size).collect(Collectors.toList());
     }
 
     @Override
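To make the derived accessor concrete: with made-up loss values, epochLosses() now falls out of iterationLossPerEpoch() by taking each epoch's last iteration loss, exactly as the default method above does.

    import java.util.List;
    import java.util.stream.Collectors;

    class EpochLossDerivationDemo {
        public static void main(String[] args) {
            // Toy data in the new shape: one inner list of losses per epoch.
            List<List<Double>> iterationLossPerEpoch = List.of(
                List.of(0.9, 0.7, 0.5), // epoch 1 ran three iterations
                List.of(0.45, 0.4)      // epoch 2 ran two iterations
            );
            // An epoch's loss is the loss of its last iteration.
            List<Double> epochLosses = iterationLossPerEpoch.stream()
                .map(losses -> losses.get(losses.size() - 1))
                .collect(Collectors.toList());
            System.out.println(epochLosses); // [0.5, 0.4]
        }
    }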
@@ -376,8 +392,10 @@ default Map<String, Object> toMap() {
         return Map.of(
             "metrics", Map.of(
                 "epochLosses", epochLosses(),
+                "iterationLossesPerEpoch", iterationLossPerEpoch(),
                 "didConverge", didConverge(),
-                "ranEpochs", ranEpochs()
+                "ranEpochs", ranEpochs(),
+                "ranIterationsPerEpoch", ranIterationsPerEpoch()
         ));
     }
 }
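For a caller-side view, a small sketch (toy values) of the inner map that toMap() now nests under the "metrics" key; the two new entries sit alongside the existing ones.

    import java.util.List;
    import java.util.Map;

    class MetricsMapShapeDemo {
        public static void main(String[] args) {
            // Shape mirrors the diff; all numbers are made up.
            Map<String, Object> metrics = Map.of(
                "epochLosses", List.of(0.5, 0.4),
                "iterationLossesPerEpoch", List.of(List.of(0.9, 0.7, 0.5), List.of(0.45, 0.4)),
                "didConverge", true,
                "ranEpochs", 2,
                "ranIterationsPerEpoch", List.of(3, 2)
            );
            System.out.println(metrics.get("ranIterationsPerEpoch")); // [3, 2]
        }
    }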
@@ -390,13 +408,13 @@ public interface ModelTrainResult {
     Layer[] layers();
 
     static ModelTrainResult of(
-        List<Double> epochLosses,
+        List<List<Double>> iterationLossesPerEpoch,
         boolean converged,
         Layer[] layers
     ) {
         return ImmutableModelTrainResult.builder()
             .layers(layers)
-            .metrics(ImmutableGraphSageTrainMetrics.of(epochLosses, converged))
+            .metrics(ImmutableGraphSageTrainMetrics.of(iterationLossesPerEpoch, converged))
             .build();
     }
 }

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java

Lines changed: 3 additions & 0 deletions
@@ -219,6 +219,7 @@ void testLosses() {
         var metrics = trainResult.metrics();
         assertThat(metrics.didConverge()).isFalse();
         assertThat(metrics.ranEpochs()).isEqualTo(10);
+        assertThat(metrics.ranIterationsPerEpoch()).containsExactly(100, 100, 100, 100, 100, 100, 100, 100, 100, 100);
 
         var metricsMap = metrics.toMap().get("metrics");
         assertThat(metricsMap).isInstanceOf(Map.class);
@@ -266,6 +267,7 @@ void testLossesWithPoolAggregator() {
         var metrics = trainResult.metrics();
         assertThat(metrics.didConverge()).isFalse();
         assertThat(metrics.ranEpochs()).isEqualTo(10);
+        assertThat(metrics.ranIterationsPerEpoch()).containsExactly(100, 100, 100, 100, 100, 100, 100, 100, 100, 100);
 
         var metricsMap = metrics.toMap().get("metrics");
         assertThat(metricsMap).isInstanceOf(Map.class);
@@ -301,6 +303,7 @@ void testConvergence() {
         var trainMetrics = trainResult.metrics();
         assertThat(trainMetrics.didConverge()).isTrue();
         assertThat(trainMetrics.ranEpochs()).isEqualTo(1);
+        assertThat(trainMetrics.ranIterationsPerEpoch()).containsExactly(2);
     }
 
     @ParameterizedTest
