Skip to content

Commit 54ef30a

Browse files
breakanalysisadamnsch
authored andcommitted
Increase volume for model select task in LP
1 parent 91822bb commit 54ef30a

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,23 +103,21 @@ public static List<Task> progressTasks(
103103
// the relationship count estimates depend on both UndirectedEdgeSplitter
104104
// and the volume set in extractFeaturesAndLabels
105105
var selectionRatio = (1 + splitConfig.negativeSamplingRatio());
106-
double nonTestRelationshipCount = relationshipCount * (1 - splitConfig.testFraction());
106+
long nonTestRelationshipCount = (long) (relationshipCount * (1 - splitConfig.testFraction()));
107+
long testRelationshipCount = (long) (relationshipCount * splitConfig.testFraction() * selectionRatio);
108+
long trainRelationshipCount = (long) (nonTestRelationshipCount * splitConfig.trainFraction() * selectionRatio);
107109
return List.of(
108-
Tasks.leaf("Extract train features",
109-
(long) (nonTestRelationshipCount * splitConfig.trainFraction() * selectionRatio)
110-
),
110+
Tasks.leaf("Extract train features", trainRelationshipCount),
111111
Tasks.iterativeFixed(
112112
"Select best model",
113-
() -> List.of(Tasks.leaf("Trial", splitConfig.validationFolds())),
113+
() -> List.of(Tasks.leaf("Trial", splitConfig.validationFolds() * trainRelationshipCount)),
114114
numberOfModelSelectionTrials
115115
),
116116
ClassifierTrainer.progressTask("Train best model"),
117117
Tasks.leaf("Compute train metrics"),
118118
Tasks.task(
119119
"Evaluate on test data",
120-
Tasks.leaf("Extract test features",
121-
(long) (relationshipCount * splitConfig.testFraction() * selectionRatio)
122-
),
120+
Tasks.leaf("Extract test features", testRelationshipCount),
123121
Tasks.leaf("Compute test metrics")
124122
)
125123
);
@@ -217,9 +215,11 @@ private void modelSelect(
217215
config.randomSeed()
218216
);
219217

218+
var trainRelationshipCount = trainGraph.relationshipCount();
220219
int trial = 0;
221220
while (hyperParameterOptimizer.hasNext()) {
222221
progressTracker.beginSubTask();
222+
progressTracker.setVolume(pipeline.splitConfig().validationFolds() * trainRelationshipCount);
223223
var modelParams = hyperParameterOptimizer.next();
224224
progressTracker.logMessage(formatWithLocale("Method: %s, Parameters: %s", modelParams.method(), modelParams.toMap()));
225225
var trainStatsBuilder = new ModelStatsBuilder(pipeline.splitConfig().validationFolds());
@@ -254,7 +254,7 @@ private void modelSelect(
254254
ProgressTracker.NULL_TRACKER
255255
);
256256

257-
progressTracker.logProgress();
257+
progressTracker.logProgress(trainRelationshipCount);
258258
}
259259

260260
// insert the candidates' metrics into trainStats and validationStats

0 commit comments

Comments
 (0)