Skip to content

Commit e948920

Browse files
committed
Fix didConverge
Actually, since we don't resample the neighbors per epoch, the previous convergence logic made no practical difference. Also, it was odd to check the tolerance twice before (once inside the iterations loop and once in the epochs loop).
1 parent 4ee2878 commit e948920

File tree

2 files changed

+17
-17
lines changed

2 files changed

+17
-17
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -145,26 +145,17 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
145145

146146
progressTracker.endSubTask("Prepare batches");
147147

148-
double previousLoss = Double.MAX_VALUE;
149148
boolean converged = false;
150149
var iterationLossesPerEpoch = new ArrayList<List<Double>>();
151150

152151
progressTracker.beginSubTask("Train model");
153152

154-
for (int epoch = 1; epoch <= epochs; epoch++) {
153+
for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
155154
progressTracker.beginSubTask("Epoch");
156-
157-
158-
var iterationLosses = trainEpoch(batchTasks, weights);
159-
iterationLossesPerEpoch.add(iterationLosses);
160-
var newLoss = iterationLosses.get(iterationLosses.size() - 1);
161-
155+
var epochResult = trainEpoch(batchTasks, weights);
156+
iterationLossesPerEpoch.add(epochResult.losses());
157+
converged = epochResult.converged();
162158
progressTracker.endSubTask("Epoch");
163-
if (Math.abs((newLoss - previousLoss) / previousLoss) < tolerance) {
164-
converged = true;
165-
break;
166-
}
167-
previousLoss = newLoss;
168159
}
169160

170161
progressTracker.endSubTask("Train model");
@@ -203,11 +194,13 @@ private BatchTask createBatchTask(
203194
return new BatchTask(lossFunction, weights, tolerance, progressTracker);
204195
}
205196

206-
private List<Double> trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
197+
private EpochResult trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
207198
var updater = new AdamOptimizer(weights, learningRate);
208199

209200
int iteration = 1;
210201
var iterationLosses = new ArrayList<Double>();
202+
var converged = false;
203+
211204
for (;iteration <= maxIterations; iteration++) {
212205
progressTracker.beginSubTask("Iteration");
213206

@@ -216,7 +209,7 @@ private List<Double> trainEpoch(List<BatchTask> batchTasks, List<Weights<? exten
216209
var avgLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
217210
iterationLosses.add(avgLoss);
218211

219-
var converged = batchTasks.stream().allMatch(task -> task.converged);
212+
converged = batchTasks.stream().allMatch(task -> task.converged);
220213
if (converged) {
221214
progressTracker.endSubTask();
222215
break;
@@ -235,7 +228,14 @@ private List<Double> trainEpoch(List<BatchTask> batchTasks, List<Weights<? exten
235228
progressTracker.endSubTask("Iteration");
236229
}
237230

238-
return iterationLosses;
231+
return ImmutableEpochResult.of(converged, iterationLosses);
232+
}
233+
234+
@ValueClass
235+
interface EpochResult {
236+
boolean converged();
237+
238+
List<Double> losses();
239239
}
240240

241241
static class BatchTask implements Runnable {

doc/asciidoc/machine-learning/node-embeddings/graph-sage/graph-sage.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ RETURN
278278
.Results
279279
|===
280280
| modelName | didConverge | ranEpochs | epochLosses
281-
| "exampleTrainModel" | false | 1 | [186.04946807210226]
281+
| "exampleTrainModel" | true | 1 | [186.04946807210226]
282282
|===
283283
--
284284

0 commit comments

Comments
 (0)