@@ -103,23 +103,21 @@ public static List<Task> progressTasks(
103103 // the relationship count estimates depend on both UndirectedEdgeSplitter
104104 // and the volume set in extractFeaturesAndLabels
105105 var selectionRatio = (1 + splitConfig .negativeSamplingRatio ());
106- double nonTestRelationshipCount = relationshipCount * (1 - splitConfig .testFraction ());
106+ long nonTestRelationshipCount = (long ) (relationshipCount * (1 - splitConfig .testFraction ()));
107+ long testRelationshipCount = (long ) (relationshipCount * splitConfig .testFraction () * selectionRatio );
108+ long trainRelationshipCount = (long ) (nonTestRelationshipCount * splitConfig .trainFraction () * selectionRatio );
107109 return List .of (
108- Tasks .leaf ("Extract train features" ,
109- (long ) (nonTestRelationshipCount * splitConfig .trainFraction () * selectionRatio )
110- ),
110+ Tasks .leaf ("Extract train features" , trainRelationshipCount ),
111111 Tasks .iterativeFixed (
112112 "Select best model" ,
113- () -> List .of (Tasks .leaf ("Trial" , splitConfig .validationFolds ())),
113+ () -> List .of (Tasks .leaf ("Trial" , splitConfig .validationFolds () * trainRelationshipCount )),
114114 numberOfModelSelectionTrials
115115 ),
116116 ClassifierTrainer .progressTask ("Train best model" ),
117117 Tasks .leaf ("Compute train metrics" ),
118118 Tasks .task (
119119 "Evaluate on test data" ,
120- Tasks .leaf ("Extract test features" ,
121- (long ) (relationshipCount * splitConfig .testFraction () * selectionRatio )
122- ),
120+ Tasks .leaf ("Extract test features" , testRelationshipCount ),
123121 Tasks .leaf ("Compute test metrics" )
124122 )
125123 );
@@ -217,9 +215,11 @@ private void modelSelect(
217215 config .randomSeed ()
218216 );
219217
218+ var trainRelationshipCount = trainGraph .relationshipCount ();
220219 int trial = 0 ;
221220 while (hyperParameterOptimizer .hasNext ()) {
222221 progressTracker .beginSubTask ();
222+ progressTracker .setVolume (pipeline .splitConfig ().validationFolds () * trainRelationshipCount );
223223 var modelParams = hyperParameterOptimizer .next ();
224224 progressTracker .logMessage (formatWithLocale ("Method: %s, Parameters: %s" , modelParams .method (), modelParams .toMap ()));
225225 var trainStatsBuilder = new ModelStatsBuilder (pipeline .splitConfig ().validationFolds ());
@@ -254,7 +254,7 @@ private void modelSelect(
254254 ProgressTracker .NULL_TRACKER
255255 );
256256
257- progressTracker .logProgress ();
257+ progressTracker .logProgress (trainRelationshipCount );
258258 }
259259
260260 // insert the candidates' metrics into trainStats and validationStats
0 commit comments