Skip to content

Commit 2080199

Browse files
authored
Merge pull request #5249 from FlorentinD/move-model-creation-to-executor-lp
Move model creation to LinkPrediction Pipeline executor
2 parents 13d66d4 + 7006bfd commit 2080199

File tree

9 files changed: 53 additions and 70 deletions.

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipeline.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
public class LinkPredictionTrainingPipeline extends TrainingPipeline<LinkFeatureStep> {
3636

3737
public static final String PIPELINE_TYPE = "Link prediction training pipeline";
38+
public static final String MODEL_TYPE = "LinkPrediction";
3839

3940
private LinkPredictionSplitConfig splitConfig;
4041

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.java

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.jetbrains.annotations.NotNull;
2323
import org.neo4j.gds.RelationshipType;
2424
import org.neo4j.gds.api.Graph;
25-
import org.neo4j.gds.core.model.Model;
2625
import org.neo4j.gds.core.utils.TerminationFlag;
2726
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2827
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
@@ -48,16 +47,13 @@
4847
import org.neo4j.gds.ml.models.automl.RandomSearch;
4948
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
5049
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
51-
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionModelInfo;
52-
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPredictPipeline;
5350
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
5451
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
5552
import org.neo4j.gds.ml.splitting.EdgeSplitter;
5653
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
5754
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
5855

5956
import java.util.List;
60-
import java.util.Map;
6157
import java.util.TreeSet;
6258
import java.util.function.BiConsumer;
6359
import java.util.stream.Collectors;
@@ -68,8 +64,6 @@
6864

6965
public final class LinkPredictionTrain {
7066

71-
public static final String MODEL_TYPE = "LinkPrediction";
72-
7367
private final Graph trainGraph;
7468
private final Graph validationGraph;
7569
private final LinkPredictionTrainingPipeline pipeline;
@@ -119,7 +113,6 @@ public static List<Task> progressTasks(int validationFolds, int numberOfModelSel
119113
}
120114

121115
public LinkPredictionTrainResult compute() {
122-
123116
progressTracker.beginSubTask("Extract train features");
124117
var trainData = extractFeaturesAndLabels(
125118
trainGraph,
@@ -169,13 +162,7 @@ public LinkPredictionTrainResult compute() {
169162
var testMetrics = trainingStatistics.winningModelTestMetrics();
170163
progressTracker.logMessage(formatWithLocale("Final model metrics on test set: %s", testMetrics));
171164

172-
var model = createModel(
173-
trainingStatistics.bestParameters(),
174-
classifier.data(),
175-
trainingStatistics.metricsForWinningModel()
176-
);
177-
178-
return LinkPredictionTrainResult.of(model, trainingStatistics);
165+
return ImmutableLinkPredictionTrainResult.of(classifier, trainingStatistics);
179166
}
180167

181168
@NotNull
@@ -348,26 +335,6 @@ private void computeTrainMetric(
348335
);
349336
}
350337

351-
private Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> createModel(
352-
TrainerConfig bestParameters,
353-
Classifier.ClassifierData classifierData,
354-
Map<Metric, BestMetricData> winnerMetrics
355-
) {
356-
return Model.of(
357-
config.username(),
358-
config.modelName(),
359-
MODEL_TYPE,
360-
trainGraph.schema(),
361-
classifierData,
362-
config,
363-
LinkPredictionModelInfo.of(
364-
bestParameters,
365-
winnerMetrics,
366-
LinkPredictionPredictPipeline.from(pipeline)
367-
)
368-
);
369-
}
370-
371338
public static MemoryEstimation estimate(
372339
LinkPredictionTrainingPipeline pipeline,
373340
LinkPredictionTrainConfig trainConfig

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainResult.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,12 @@
2020
package org.neo4j.gds.ml.pipeline.linkPipeline.train;
2121

2222
import org.neo4j.gds.annotation.ValueClass;
23-
import org.neo4j.gds.core.model.Model;
2423
import org.neo4j.gds.ml.models.Classifier;
2524
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
26-
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionModelInfo;
2725

2826
@ValueClass
2927
public interface LinkPredictionTrainResult {
30-
Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model();
28+
Classifier classifier();
3129

3230
TrainingStatistics trainingStatistics();
33-
34-
static LinkPredictionTrainResult of(
35-
Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model,
36-
TrainingStatistics trainingStatistics
37-
) {
38-
return ImmutableLinkPredictionTrainResult.of(model, trainingStatistics);
39-
}
4031
}

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainTest.java

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -221,24 +221,18 @@ void trainsAModel() {
221221

222222
assertThat(result.trainingStatistics().getTrainStats(LinkMetric.AUCPR).size()).isEqualTo(MAX_TRIALS);
223223

224-
var actualModel = result.model();
224+
var trainedClassifier = result.classifier();
225225

226-
assertThat(actualModel.name()).isEqualTo(modelName);
227-
assertThat(actualModel.algoType()).isEqualTo(LinkPredictionTrain.MODEL_TYPE);
228-
assertThat(actualModel.trainConfig()).isEqualTo(trainConfig);
229-
// length of the linkFeatures
230-
231-
assertThat((LogisticRegressionData) actualModel.data())
226+
assertThat((LogisticRegressionData) trainedClassifier.data())
232227
.extracting(llrData -> llrData.weights().data().totalSize())
233228
.isEqualTo(6);
234229

235-
var customInfo = actualModel.customInfo();
236230
assertThat(result.trainingStatistics().getValidationStats(LinkMetric.AUCPR))
237231
.satisfies(scores ->
238232
assertThat(scores.get(0).avg()).isNotCloseTo(scores.get(1).avg(), Percentage.withPercentage(0.2))
239233
);
240234

241-
assertThat(customInfo.bestParameters())
235+
assertThat(result.trainingStatistics().bestParameters())
242236
.usingRecursiveComparison()
243237
.isEqualTo(LogisticRegressionTrainConfig.of(Map.of("penalty", 1, "patience", 5, "tolerance", 0.00001)));
244238
}
@@ -250,10 +244,10 @@ void seededTrain() {
250244
LinkPredictionTrainConfig trainConfig = trainingConfig(modelName);
251245

252246
var modelData = runLinkPrediction(trainConfig)
253-
.model()
247+
.classifier()
254248
.data();
255249
var modelDataRepeated = runLinkPrediction(trainConfig)
256-
.model()
250+
.classifier()
257251
.data();
258252

259253
assertThat(modelData)

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionPipelineTrainProc.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@
2727
import org.neo4j.gds.executor.GdsCallable;
2828
import org.neo4j.gds.ml.MLTrainResult;
2929
import org.neo4j.gds.ml.PipelineCompanion;
30-
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrain;
30+
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
3131
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
32-
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainResult;
3332
import org.neo4j.gds.results.MemoryEstimateResult;
3433
import org.neo4j.procedure.Description;
3534
import org.neo4j.procedure.Mode;
@@ -40,12 +39,13 @@
4039
import java.util.stream.Stream;
4140

4241
import static org.neo4j.gds.executor.ExecutionMode.TRAIN;
42+
import static org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainPipelineExecutor.LinkPredictionTrainPipelineResult;
4343
import static org.neo4j.procedure.Mode.READ;
4444

4545
@GdsCallable(name = "gds.beta.pipeline.linkPrediction.train", description = "Trains a link prediction model based on a pipeline", executionMode = TRAIN)
4646
public class LinkPredictionPipelineTrainProc extends TrainProc<
4747
LinkPredictionTrainPipelineExecutor,
48-
LinkPredictionTrainResult,
48+
LinkPredictionTrainPipelineResult,
4949
LinkPredictionTrainConfig,
5050
LinkPredictionPipelineTrainProc.LPTrainResult
5151
> {
@@ -82,17 +82,17 @@ public GraphStoreAlgorithmFactory<LinkPredictionTrainPipelineExecutor, LinkPredi
8282

8383
@Override
8484
protected String modelType() {
85-
return LinkPredictionTrain.MODEL_TYPE;
85+
return LinkPredictionTrainingPipeline.MODEL_TYPE;
8686
}
8787

8888
@Override
89-
protected Model<?, ?, ?> extractModel(LinkPredictionTrainResult algoResult) {
89+
protected Model<?, ?, ?> extractModel(LinkPredictionTrainPipelineResult algoResult) {
9090
return algoResult.model();
9191
}
9292

9393
@Override
9494
protected LPTrainResult constructProcResult(
95-
ComputationResult<LinkPredictionTrainPipelineExecutor, LinkPredictionTrainResult, LinkPredictionTrainConfig> computationResult
95+
ComputationResult<LinkPredictionTrainPipelineExecutor, LinkPredictionTrainPipelineResult, LinkPredictionTrainConfig> computationResult
9696
) {
9797
return new LPTrainResult(computationResult.result(), computationResult.computeMillis());
9898
}
@@ -102,7 +102,7 @@ public static class LPTrainResult extends MLTrainResult {
102102

103103
public final Map<String, Object> modelSelectionStats;
104104

105-
public LPTrainResult(LinkPredictionTrainResult algoResult, long trainMillis) {
105+
public LPTrainResult(LinkPredictionTrainPipelineResult algoResult, long trainMillis) {
106106
super(algoResult.model(), trainMillis);
107107

108108
this.modelSelectionStats = algoResult.trainingStatistics().toMap();

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineExecutor.java

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,36 @@
2020
package org.neo4j.gds.ml.linkmodels.pipeline.train;
2121

2222
import org.neo4j.gds.RelationshipType;
23+
import org.neo4j.gds.annotation.ValueClass;
2324
import org.neo4j.gds.api.GraphStore;
25+
import org.neo4j.gds.core.model.Model;
2426
import org.neo4j.gds.core.model.ModelCatalog;
2527
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2628
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
2729
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2830
import org.neo4j.gds.executor.ExecutionContext;
31+
import org.neo4j.gds.ml.models.Classifier;
2932
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
3033
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
34+
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
35+
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionModelInfo;
36+
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPredictPipeline;
3137
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
3238
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrain;
3339
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
34-
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainResult;
3540

3641
import java.util.List;
3742
import java.util.Map;
3843
import java.util.Optional;
3944
import java.util.stream.Collectors;
4045

46+
import static org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainPipelineExecutor.LinkPredictionTrainPipelineResult;
4147
import static org.neo4j.gds.ml.linkmodels.pipeline.train.RelationshipSplitter.splitEstimation;
48+
import static org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline.MODEL_TYPE;
4249
import static org.neo4j.gds.ml.util.TrainingSetWarnings.warnForSmallRelationshipSets;
4350

4451
public class LinkPredictionTrainPipelineExecutor extends PipelineExecutor
45-
<LinkPredictionTrainConfig, LinkPredictionTrainingPipeline, LinkPredictionTrainResult> {
52+
<LinkPredictionTrainConfig, LinkPredictionTrainingPipeline, LinkPredictionTrainPipelineResult> {
4653

4754
private final RelationshipSplitter relationshipSplitter;
4855

@@ -118,7 +125,7 @@ public Map<DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
118125
}
119126

120127
@Override
121-
protected LinkPredictionTrainResult execute(Map<DatasetSplits, GraphFilter> dataSplits) {
128+
protected LinkPredictionTrainPipelineResult execute(Map<DatasetSplits, GraphFilter> dataSplits) {
122129
PipelineExecutor.validateTrainingParameterSpace(pipeline);
123130

124131
var trainDataSplit = dataSplits.get(DatasetSplits.TRAIN);
@@ -141,14 +148,32 @@ protected LinkPredictionTrainResult execute(Map<DatasetSplits, GraphFilter> data
141148
pipeline.splitConfig().validationFolds(),
142149
progressTracker
143150
);
144-
return new LinkPredictionTrain(
151+
152+
var trainResult = new LinkPredictionTrain(
145153
trainGraph,
146154
testGraph,
147155
pipeline,
148156
config,
149157
progressTracker,
150158
terminationFlag
151159
).compute();
160+
161+
var model = Model.of(
162+
config.username(),
163+
config.modelName(),
164+
MODEL_TYPE,
165+
schemaBeforeSteps,
166+
trainResult.classifier().data(),
167+
config,
168+
LinkPredictionModelInfo.of(
169+
trainResult.trainingStatistics().bestParameters(),
170+
trainResult.trainingStatistics().metricsForWinningModel(),
171+
LinkPredictionPredictPipeline.from(pipeline)
172+
)
173+
);
174+
175+
176+
return ImmutableLinkPredictionTrainPipelineResult.of(model, trainResult.trainingStatistics());
152177
}
153178

154179
private void removeDataSplitRelationships(Map<DatasetSplits, GraphFilter> datasets) {
@@ -165,4 +190,10 @@ protected void cleanUpGraphStore(Map<DatasetSplits, GraphFilter> datasets) {
165190
removeDataSplitRelationships(datasets);
166191
super.cleanUpGraphStore(datasets);
167192
}
193+
194+
@ValueClass
195+
public interface LinkPredictionTrainPipelineResult {
196+
Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model();
197+
TrainingStatistics trainingStatistics();
198+
}
168199
}

proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineProcTestBase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
import java.util.Map;
4848
import java.util.stream.Stream;
4949

50-
import static org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrain.MODEL_TYPE;
50+
import static org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline.MODEL_TYPE;
5151

5252
@Neo4jModelCatalogExtension
5353
abstract class LinkPredictionPipelineProcTestBase extends BaseProcTest {

proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineExecutorTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
7373
import static org.neo4j.gds.assertj.Extractors.replaceTimings;
7474
import static org.neo4j.gds.compat.TestLog.INFO;
75-
import static org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrain.MODEL_TYPE;
75+
import static org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline.MODEL_TYPE;
7676

7777
@Neo4jModelCatalogExtension
7878
class LinkPredictionPredictPipelineExecutorTest extends BaseProcTest {

proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineExecutorTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
5757
import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.HadamardFeatureStep;
5858
import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.L2FeatureStep;
59-
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrain;
6059
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
6160
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfigImpl;
6261
import org.neo4j.gds.test.TestMutateProc;
@@ -178,7 +177,7 @@ void testProcedureAndLinkFeatures() {
178177

179178
assertThat(actualModel.name()).isEqualTo("model");
180179

181-
assertThat(actualModel.algoType()).isEqualTo(LinkPredictionTrain.MODEL_TYPE);
180+
assertThat(actualModel.algoType()).isEqualTo(LinkPredictionTrainingPipeline.MODEL_TYPE);
182181
assertThat(actualModel.trainConfig()).isEqualTo(config);
183182
// length of the linkFeatures
184183
assertThat(logisticRegressionData.weights().data().totalSize()).isEqualTo(6);

0 commit comments

Comments
 (0)