Skip to content

Commit 91822bb

Browse files
breakanalysis authored and adamnsch committed
Estimate relationship counts at creation of some progress tasks in LP
1 parent ebe3b1f commit 91822bb

File tree

5 files changed: 47 additions and 13 deletions

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.java

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -95,19 +95,31 @@ public LinkPredictionTrain(
9595
this.classIdMap = makeClassIdMap();
9696
}
9797

98-
public static List<Task> progressTasks(int validationFolds, int numberOfModelSelectionTrials) {
98+
public static List<Task> progressTasks(
99+
long relationshipCount,
100+
LinkPredictionSplitConfig splitConfig,
101+
int numberOfModelSelectionTrials
102+
) {
103+
// the relationship count estimates depend on both UndirectedEdgeSplitter
104+
// and the volume set in extractFeaturesAndLabels
105+
var selectionRatio = (1 + splitConfig.negativeSamplingRatio());
106+
double nonTestRelationshipCount = relationshipCount * (1 - splitConfig.testFraction());
99107
return List.of(
100-
Tasks.leaf("Extract train features"),
108+
Tasks.leaf("Extract train features",
109+
(long) (nonTestRelationshipCount * splitConfig.trainFraction() * selectionRatio)
110+
),
101111
Tasks.iterativeFixed(
102112
"Select best model",
103-
() -> List.of(Tasks.leaf("Trial", validationFolds)),
113+
() -> List.of(Tasks.leaf("Trial", splitConfig.validationFolds())),
104114
numberOfModelSelectionTrials
105115
),
106116
ClassifierTrainer.progressTask("Train best model"),
107117
Tasks.leaf("Compute train metrics"),
108118
Tasks.task(
109119
"Evaluate on test data",
110-
Tasks.leaf("Extract test features"),
120+
Tasks.leaf("Extract test features",
121+
(long) (relationshipCount * splitConfig.testFraction() * selectionRatio)
122+
),
111123
Tasks.leaf("Compute test metrics")
112124
)
113125
);

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainTest.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,8 @@ void logProgressRF() {
461461
var log = Neo4jProxy.testLog();
462462
var progressTracker = new TestProgressTracker(
463463
progressTask(
464-
pipeline.splitConfig().validationFolds(),
464+
trainGraph.relationshipCount(),
465+
pipeline.splitConfig(),
465466
pipeline.numberOfModelSelectionTrials()
466467
),
467468
log,
@@ -552,7 +553,8 @@ void logProgressLR() {
552553
var log = Neo4jProxy.testLog();
553554
var progressTracker = new TestProgressTracker(
554555
progressTask(
555-
pipeline.splitConfig().validationFolds(),
556+
trainGraph.relationshipCount(),
557+
pipeline.splitConfig(),
556558
pipeline.numberOfModelSelectionTrials()
557559
),
558560
log,
@@ -614,7 +616,8 @@ void logProgressLRWithRange() {
614616
var log = Neo4jProxy.testLog();
615617
var progressTracker = new TestProgressTracker(
616618
progressTask(
617-
pipeline.splitConfig().validationFolds(),
619+
trainGraph.relationshipCount(),
620+
pipeline.splitConfig(),
618621
pipeline.numberOfModelSelectionTrials()
619622
),
620623
log,
@@ -713,8 +716,11 @@ void logProgressLRWithRange() {
713716
);
714717
}
715718

716-
static Task progressTask(int validationFolds, int numberOfModelSelectionTrials) {
717-
return Tasks.task("MY TEST TASK", LinkPredictionTrain.progressTasks(validationFolds, numberOfModelSelectionTrials));
719+
static Task progressTask(long relationshipCount, LinkPredictionSplitConfig splitConfig, int numberOfModelSelectionTrials) {
720+
return Tasks.task(
721+
"MY TEST TASK",
722+
LinkPredictionTrain.progressTasks(relationshipCount, splitConfig, numberOfModelSelectionTrials)
723+
);
718724
}
719725

720726
private LinkPredictionTrainingPipeline linkPredictionPipeline() {

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineAlgorithmFactory.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,15 @@ public String taskName() {
7070

7171
@Override
7272
public Task progressTask(GraphStore graphStore, LinkPredictionTrainConfig config) {
73+
var relationshipCount = config
74+
.internalRelationshipTypes(graphStore)
75+
.stream()
76+
.mapToLong(graphStore::relationshipCount)
77+
.sum();
7378
return LinkPredictionTrainPipelineExecutor.progressTask(
7479
taskName(),
75-
PipelineCatalog.getTyped(config.username(), config.pipeline(), LinkPredictionTrainingPipeline.class)
80+
PipelineCatalog.getTyped(config.username(), config.pipeline(), LinkPredictionTrainingPipeline.class),
81+
relationshipCount
7682
);
7783
}
7884

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineExecutor.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public LinkPredictionTrainPipelineExecutor(
8181
);
8282
}
8383

84-
public static Task progressTask(String taskName, LinkPredictionTrainingPipeline pipeline) {
84+
public static Task progressTask(String taskName, LinkPredictionTrainingPipeline pipeline, long relationshipCount) {
8585
return Tasks.task(taskName, new ArrayList<>() {{
8686
add(Tasks.leaf("Split relationships"));
8787
add(Tasks.iterativeFixed(
@@ -90,7 +90,8 @@ public static Task progressTask(String taskName, LinkPredictionTrainingPipeline
9090
pipeline.nodePropertySteps().size()
9191
));
9292
addAll(LinkPredictionTrain.progressTasks(
93-
pipeline.splitConfig().validationFolds(),
93+
relationshipCount,
94+
pipeline.splitConfig(),
9495
pipeline.numberOfModelSelectionTrials()
9596
));
9697
}});

proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineExecutorTest.java

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -386,8 +386,17 @@ void shouldLogProgress() {
386386

387387
TestProcedureRunner.applyOnProcedure(db, TestProc.class, caller -> {
388388
var log = Neo4jProxy.testLog();
389+
var relationshipCount = config
390+
.internalRelationshipTypes(graphStore)
391+
.stream()
392+
.mapToLong(graphStore::relationshipCount)
393+
.sum();
389394
var progressTracker = new TestProgressTracker(
390-
LinkPredictionTrainPipelineExecutor.progressTask("Link Prediction Train Pipeline", pipeline),
395+
LinkPredictionTrainPipelineExecutor.progressTask(
396+
"Link Prediction Train Pipeline",
397+
pipeline,
398+
relationshipCount
399+
),
391400
log,
392401
1,
393402
EmptyTaskRegistryFactory.INSTANCE

0 commit comments

Comments
 (0)