@@ -147,15 +147,18 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
 
         double previousLoss = Double.MAX_VALUE;
         boolean converged = false;
-        var epochLosses = new ArrayList<Double>();
+        var iterationLossesPerEpoch = new ArrayList<List<Double>>();
 
         progressTracker.beginSubTask("Train model");
 
         for (int epoch = 1; epoch <= epochs; epoch++) {
             progressTracker.beginSubTask("Epoch");
 
-            double newLoss = trainEpoch(batchTasks, weights);
-            epochLosses.add(newLoss);
+
+            var iterationLosses = trainEpoch(batchTasks, weights);
+            iterationLossesPerEpoch.add(iterationLosses);
+            var newLoss = iterationLosses.get(iterationLosses.size() - 1);
+
             progressTracker.endSubTask("Epoch");
             if (Math.abs((newLoss - previousLoss) / previousLoss) < tolerance) {
                 converged = true;
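With this hunk, an epoch's reported loss becomes the loss of its last iteration, and convergence is still judged on the relative change between consecutive epoch losses. A minimal, self-contained sketch of that criterion (class name and sample values are illustrative, not part of the patch):

```java
final class ConvergenceCheck {
    // Relative-change criterion, as in the training loop above: converged once
    // |(newLoss - previousLoss) / previousLoss| drops below the tolerance.
    static boolean hasConverged(double newLoss, double previousLoss, double tolerance) {
        return Math.abs((newLoss - previousLoss) / previousLoss) < tolerance;
    }

    public static void main(String[] args) {
        // |(9.9 - 10.0) / 10.0| = 0.01, below a tolerance of 0.05.
        System.out.println(hasConverged(9.9, 10.0, 0.05)); // true
    }
}
```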
@@ -166,7 +169,7 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
 
         progressTracker.endSubTask("Train model");
 
-        return ModelTrainResult.of(epochLosses, converged, layers);
+        return ModelTrainResult.of(iterationLossesPerEpoch, converged, layers);
     }
 
     private BatchTask createBatchTask(
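The value handed to ModelTrainResult.of is now nested: the outer list is indexed by epoch, the inner lists by iteration within that epoch. A toy illustration of the shape (all numbers invented):

```java
import java.util.List;

final class LossShape {
    public static void main(String[] args) {
        // Two epochs: the first ran three iterations, the second two.
        List<List<Double>> iterationLossesPerEpoch = List.of(
            List.of(1.40, 1.10, 0.95), // epoch 1
            List.of(0.90, 0.88)        // epoch 2
        );
        // The per-epoch loss is the last entry of each inner list: 0.95, 0.88.
        iterationLossesPerEpoch.forEach(losses ->
            System.out.println(losses.get(losses.size() - 1)));
    }
}
```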
@@ -200,17 +203,18 @@ private BatchTask createBatchTask(
         return new BatchTask(lossFunction, weights, tolerance, progressTracker);
     }
 
-    private double trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
+    private List<Double> trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
         var updater = new AdamOptimizer(weights, learningRate);
 
-        double totalLoss = Double.NaN;
         int iteration = 1;
+        var iterationLosses = new ArrayList<Double>();
         for (;iteration <= maxIterations; iteration++) {
             progressTracker.beginSubTask("Iteration");
 
             // run forward + maybe backward for each Batch
             ParallelUtil.runWithConcurrency(concurrency, batchTasks, executor);
-            totalLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
+            var avgLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
+            iterationLosses.add(avgLoss);
 
             var converged = batchTasks.stream().allMatch(task -> task.converged);
             if (converged) {
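Each iteration now records the loss averaged over all batch tasks rather than overwriting a single running value. A hedged sketch of that aggregation, with a stand-in record in place of the real BatchTask:

```java
import java.util.List;

final class LossAveraging {
    // Stand-in for BatchTask, which exposes its loss after a run.
    record FakeBatchTask(double loss) {}

    public static void main(String[] args) {
        var batchTasks = List.of(new FakeBatchTask(0.8), new FakeBatchTask(1.2));
        // Average the per-batch losses; an empty task list would be a bug,
        // hence orElseThrow() rather than a default value.
        var avgLoss = batchTasks.stream()
            .mapToDouble(FakeBatchTask::loss)
            .average()
            .orElseThrow();
        System.out.println(avgLoss); // 1.0
    }
}
```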
@@ -227,12 +231,11 @@ private double trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
 
             updater.update(meanGradients);
 
-            progressTracker.logMessage(formatWithLocale("LOSS: %.10f", totalLoss));
-
+            progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));
             progressTracker.endSubTask("Iteration");
         }
 
-        return totalLoss;
+        return iterationLosses;
     }
 
     static class BatchTask implements Runnable {
@@ -359,14 +362,27 @@ static GraphSageTrainMetrics empty() {
             return ImmutableGraphSageTrainMetrics.of(List.of(), false);
         }
 
-        List<Double> epochLosses();
+        @Value.Derived
+        default List<Double> epochLosses() {
+            return iterationLossPerEpoch().stream()
+                .map(iterationLosses -> iterationLosses.get(iterationLosses.size() - 1))
+                .collect(Collectors.toList());
+        }
+
+        List<List<Double>> iterationLossPerEpoch();
+
         boolean didConverge();
 
         @Value.Derived
         default int ranEpochs() {
-            return epochLosses().isEmpty()
+            return iterationLossPerEpoch().isEmpty()
                 ? 0
-                : epochLosses().size();
+                : iterationLossPerEpoch().size();
+        }
+
+        @Value.Derived
+        default List<Integer> ranIterationsPerEpoch() {
+            return iterationLossPerEpoch().stream().map(List::size).collect(Collectors.toList());
         }
 
         @Override
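epochLosses() is kept for backwards compatibility but is now derived from the nested list: the last iteration loss of each epoch. The same derivations in plain Java, without the Immutables machinery (data invented):

```java
import java.util.List;
import java.util.stream.Collectors;

final class DerivedMetrics {
    public static void main(String[] args) {
        List<List<Double>> iterationLossPerEpoch = List.of(
            List.of(1.40, 1.10, 0.95),
            List.of(0.90, 0.88)
        );
        // epochLosses(): last loss of each epoch -> [0.95, 0.88]
        List<Double> epochLosses = iterationLossPerEpoch.stream()
            .map(losses -> losses.get(losses.size() - 1))
            .collect(Collectors.toList());
        // ranIterationsPerEpoch(): size of each inner list -> [3, 2]
        List<Integer> ranIterationsPerEpoch = iterationLossPerEpoch.stream()
            .map(List::size)
            .collect(Collectors.toList());
        System.out.println(epochLosses);
        System.out.println(ranIterationsPerEpoch);
    }
}
```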
@@ -376,8 +392,10 @@ default Map<String, Object> toMap() {
             return Map.of(
                 "metrics", Map.of(
                     "epochLosses", epochLosses(),
+                    "iterationLossesPerEpoch", iterationLossPerEpoch(),
                     "didConverge", didConverge(),
-                    "ranEpochs", ranEpochs()
+                    "ranEpochs", ranEpochs(),
+                    "ranIterationsPerEpoch", ranIterationsPerEpoch()
             ));
         }
     }
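Callers reading the metrics map now see both the legacy and the new keys. A small sketch of the structure toMap() produces, with invented numbers:

```java
import java.util.List;
import java.util.Map;

final class MetricsMapExample {
    public static void main(String[] args) {
        // Mirrors the structure built in toMap() above; values are made up.
        Map<String, Object> result = Map.of(
            "metrics", Map.of(
                "epochLosses", List.of(0.95, 0.88),
                "iterationLossesPerEpoch", List.of(List.of(1.40, 1.10, 0.95), List.of(0.90, 0.88)),
                "didConverge", true,
                "ranEpochs", 2,
                "ranIterationsPerEpoch", List.of(3, 2)
            )
        );
        System.out.println(result);
    }
}
```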
@@ -390,13 +408,13 @@ public interface ModelTrainResult {
         Layer[] layers();
 
         static ModelTrainResult of(
-            List<Double> epochLosses,
+            List<List<Double>> iterationLossesPerEpoch,
             boolean converged,
             Layer[] layers
         ) {
             return ImmutableModelTrainResult.builder()
                 .layers(layers)
-                .metrics(ImmutableGraphSageTrainMetrics.of(epochLosses, converged))
+                .metrics(ImmutableGraphSageTrainMetrics.of(iterationLossesPerEpoch, converged))
                 .build();
         }
     }
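ImmutableModelTrainResult and ImmutableGraphSageTrainMetrics are generated by the Immutables annotation processor from the interfaces above. For readers without that context, a simplified hand-rolled equivalent of what the factory assembles (records as stand-ins, Layer[] payload omitted):

```java
import java.util.List;

// Stand-ins for the generated value classes; the real code relies on the
// Immutables annotation processor instead of records.
record TrainMetrics(List<List<Double>> iterationLossPerEpoch, boolean didConverge) {}

record TrainResult(TrainMetrics metrics) {
    static TrainResult of(List<List<Double>> iterationLossesPerEpoch, boolean converged) {
        return new TrainResult(new TrainMetrics(iterationLossesPerEpoch, converged));
    }

    public static void main(String[] args) {
        var result = of(List.of(List.of(1.1, 0.9)), true);
        System.out.println(result.metrics().didConverge()); // true
    }
}
```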