Skip to content

Commit c611b8a

Browse files
breakanalysis, adamnsch
authored and committed
Finish modifying LP train tasks
Co-Authored-By: Adam Schill Collberg <adam.schill.collberg@protonmail.com>
1 parent c62220f commit c611b8a

File tree

2 files changed

+8
-13
lines changed

2 files changed

+8
-13
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.java

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,25 +100,20 @@ public static List<Task> progressTasks(
100100
LinkPredictionSplitConfig splitConfig,
101101
int numberOfModelSelectionTrials
102102
) {
103-
// the relationship count estimates depend on both UndirectedEdgeSplitter
104-
// and the volume set in extractFeaturesAndLabels
105-
var selectionRatio = (1 + splitConfig.negativeSamplingRatio());
106-
long nonTestRelationshipCount = (long) (relationshipCount * (1 - splitConfig.testFraction()));
107-
long testRelationshipCount = (long) (relationshipCount * splitConfig.testFraction() * selectionRatio);
108-
long trainRelationshipCount = (long) (nonTestRelationshipCount * splitConfig.trainFraction() * selectionRatio);
103+
var sizes = splitConfig.expectedSetSizes(relationshipCount);
109104
return List.of(
110-
Tasks.leaf("Extract train features", trainRelationshipCount),
105+
Tasks.leaf("Extract train features", sizes.trainSize()),
111106
Tasks.iterativeFixed(
112107
"Select best model",
113-
() -> List.of(Tasks.leaf("Trial", splitConfig.validationFolds() * trainRelationshipCount)),
108+
() -> List.of(Tasks.leaf("Trial", splitConfig.validationFolds() * sizes.trainSize())),
114109
numberOfModelSelectionTrials
115110
),
116-
ClassifierTrainer.progressTask("Train best model", trainRelationshipCount),
117-
Tasks.leaf("Compute train metrics"),
111+
ClassifierTrainer.progressTask("Train best model", sizes.trainSize()),
112+
Tasks.leaf("Compute train metrics", sizes.trainSize()),
118113
Tasks.task(
119114
"Evaluate on test data",
120-
Tasks.leaf("Extract test features", testRelationshipCount),
121-
Tasks.leaf("Compute test metrics")
115+
Tasks.leaf("Extract test features", sizes.testSize()),
116+
Tasks.leaf("Compute test metrics", sizes.testSize())
122117
)
123118
);
124119
}

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ void logProgressLRWithRange() {
616616
var log = Neo4jProxy.testLog();
617617
var progressTracker = new TestProgressTracker(
618618
progressTask(
619-
trainGraph.relationshipCount(),
619+
2 * trainGraph.relationshipCount(),
620620
pipeline.splitConfig(),
621621
pipeline.numberOfModelSelectionTrials()
622622
),

0 commit comments

Comments
 (0)