@@ -145,26 +145,17 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
145145
146146 progressTracker .endSubTask ("Prepare batches" );
147147
148- double previousLoss = Double .MAX_VALUE ;
149148 boolean converged = false ;
150149 var iterationLossesPerEpoch = new ArrayList <List <Double >>();
151150
152151 progressTracker .beginSubTask ("Train model" );
153152
154- for (int epoch = 1 ; epoch <= epochs ; epoch ++) {
153+ for (int epoch = 1 ; epoch <= epochs && ! converged ; epoch ++) {
155154 progressTracker .beginSubTask ("Epoch" );
156-
157-
158- var iterationLosses = trainEpoch (batchTasks , weights );
159- iterationLossesPerEpoch .add (iterationLosses );
160- var newLoss = iterationLosses .get (iterationLosses .size () - 1 );
161-
155+ var epochResult = trainEpoch (batchTasks , weights );
156+ iterationLossesPerEpoch .add (epochResult .losses ());
157+ converged = epochResult .converged ();
162158 progressTracker .endSubTask ("Epoch" );
163- if (Math .abs ((newLoss - previousLoss ) / previousLoss ) < tolerance ) {
164- converged = true ;
165- break ;
166- }
167- previousLoss = newLoss ;
168159 }
169160
170161 progressTracker .endSubTask ("Train model" );
@@ -203,11 +194,13 @@ private BatchTask createBatchTask(
203194 return new BatchTask (lossFunction , weights , tolerance , progressTracker );
204195 }
205196
206- private List < Double > trainEpoch (List <BatchTask > batchTasks , List <Weights <? extends Tensor <?>>> weights ) {
197+ private EpochResult trainEpoch (List <BatchTask > batchTasks , List <Weights <? extends Tensor <?>>> weights ) {
207198 var updater = new AdamOptimizer (weights , learningRate );
208199
209200 int iteration = 1 ;
210201 var iterationLosses = new ArrayList <Double >();
202+ var converged = false ;
203+
211204 for (;iteration <= maxIterations ; iteration ++) {
212205 progressTracker .beginSubTask ("Iteration" );
213206
@@ -216,7 +209,7 @@ private List<Double> trainEpoch(List<BatchTask> batchTasks, List<Weights<? exten
216209 var avgLoss = batchTasks .stream ().mapToDouble (BatchTask ::loss ).average ().orElseThrow ();
217210 iterationLosses .add (avgLoss );
218211
219- var converged = batchTasks .stream ().allMatch (task -> task .converged );
212+ converged = batchTasks .stream ().allMatch (task -> task .converged );
220213 if (converged ) {
221214 progressTracker .endSubTask ();
222215 break ;
@@ -235,7 +228,14 @@ private List<Double> trainEpoch(List<BatchTask> batchTasks, List<Weights<? exten
235228 progressTracker .endSubTask ("Iteration" );
236229 }
237230
238- return iterationLosses ;
231+ return ImmutableEpochResult .of (converged , iterationLosses );
232+ }
233+
234+ @ ValueClass
235+ interface EpochResult {
236+ boolean converged ();
237+
238+ List <Double > losses ();
239239 }
240240
241241 static class BatchTask implements Runnable {
0 commit comments