@@ -100,25 +100,20 @@ public static List<Task> progressTasks(
100100 LinkPredictionSplitConfig splitConfig ,
101101 int numberOfModelSelectionTrials
102102 ) {
103- // the relationship count estimates depend on both UndirectedEdgeSplitter
104- // and the volume set in extractFeaturesAndLabels
105- var selectionRatio = (1 + splitConfig .negativeSamplingRatio ());
106- long nonTestRelationshipCount = (long ) (relationshipCount * (1 - splitConfig .testFraction ()));
107- long testRelationshipCount = (long ) (relationshipCount * splitConfig .testFraction () * selectionRatio );
108- long trainRelationshipCount = (long ) (nonTestRelationshipCount * splitConfig .trainFraction () * selectionRatio );
103+ var sizes = splitConfig .expectedSetSizes (relationshipCount );
109104 return List .of (
110- Tasks .leaf ("Extract train features" , trainRelationshipCount ),
105+ Tasks .leaf ("Extract train features" , sizes . trainSize () ),
111106 Tasks .iterativeFixed (
112107 "Select best model" ,
113- () -> List .of (Tasks .leaf ("Trial" , splitConfig .validationFolds () * trainRelationshipCount )),
108+ () -> List .of (Tasks .leaf ("Trial" , splitConfig .validationFolds () * sizes . trainSize () )),
114109 numberOfModelSelectionTrials
115110 ),
116- ClassifierTrainer .progressTask ("Train best model" , trainRelationshipCount ),
117- Tasks .leaf ("Compute train metrics" ),
111+ ClassifierTrainer .progressTask ("Train best model" , sizes . trainSize () ),
112+ Tasks .leaf ("Compute train metrics" , sizes . trainSize () ),
118113 Tasks .task (
119114 "Evaluate on test data" ,
120- Tasks .leaf ("Extract test features" , testRelationshipCount ),
121- Tasks .leaf ("Compute test metrics" )
115+ Tasks .leaf ("Extract test features" , sizes . testSize () ),
116+ Tasks .leaf ("Compute test metrics" , sizes . testSize () )
122117 )
123118 );
124119 }
0 commit comments