Skip to content

Commit 9ea772b

Browse files
authored
Merge pull request #5240 from FlorentinD/regression-pipeline-executor
Implement NodeRegressionPipelineExecutor
2 parents 9c5c51c + 0c34dfc commit 9ea772b

File tree

15 files changed

+522
-97
lines changed

15 files changed

+522
-97
lines changed

executor/src/main/java/org/neo4j/gds/executor/GdsCallableFinder.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,7 @@ public final class GdsCallableFinder {
4242
"org.neo4j.gds"
4343
);
4444

45-
private static final List<String> DEFAULT_PACKAGE_BLACKLIST = List.of(
46-
"org.neo4j.gds.pregel",
47-
"org.neo4j.gds.test"
48-
);
45+
private static final List<String> DEFAULT_PACKAGE_BLACKLIST = List.of("org.neo4j.gds.pregel");
4946

5047
public static Stream<GdsCallableDefinition> findAll() {
5148
return findAll(DEFAULT_PACKAGE_BLACKLIST);

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/PipelineExecutor.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.neo4j.gds.RelationshipType;
2525
import org.neo4j.gds.annotation.ValueClass;
2626
import org.neo4j.gds.api.GraphStore;
27+
import org.neo4j.gds.api.schema.GraphSchema;
2728
import org.neo4j.gds.config.AlgoBaseConfig;
2829
import org.neo4j.gds.core.model.ModelCatalog;
2930
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
@@ -35,6 +36,7 @@
3536
import java.util.Collection;
3637
import java.util.List;
3738
import java.util.Map;
39+
import java.util.Set;
3840
import java.util.stream.Collectors;
3941

4042
import static org.neo4j.gds.config.MutatePropertyConfig.MUTATE_PROPERTY_KEY;
@@ -56,6 +58,7 @@ public enum DatasetSplits {
5658
protected final PIPELINE_CONFIG config;
5759
protected final ExecutionContext executionContext;
5860
protected final GraphStore graphStore;
61+
protected final GraphSchema schemaBeforeSteps;
5962
protected final String graphName;
6063

6164
protected PipelineExecutor(
@@ -72,6 +75,10 @@ protected PipelineExecutor(
7275
this.executionContext = executionContext;
7376
this.graphStore = graphStore;
7477
this.graphName = graphName;
78+
this.schemaBeforeSteps = graphStore
79+
.schema()
80+
.filterNodeLabels(Set.copyOf(config.nodeLabelIdentifiers(graphStore)))
81+
.filterRelationshipTypes(Set.copyOf(config.internalRelationshipTypes(graphStore)));
7582
}
7683

7784
public static MemoryEstimation estimateNodePropertySteps(ModelCatalog modelCatalog, List<ExecutableNodePropertyStep> nodePropertySteps, List<String> nodeLabels, List<String> relationshipTypes) {

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrain.java

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import org.jetbrains.annotations.NotNull;
2323
import org.neo4j.gds.api.Graph;
24-
import org.neo4j.gds.core.model.Model;
2524
import org.neo4j.gds.core.utils.TerminationFlag;
2625
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2726
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
@@ -51,7 +50,6 @@
5150
import org.neo4j.gds.ml.nodeClassification.ClassificationMetricComputer;
5251
import org.neo4j.gds.ml.nodePropertyPrediction.NodeSplitter;
5352
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
54-
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPredictPipeline;
5553
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
5654
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
5755
import org.neo4j.gds.ml.splitting.FractionSplitter;
@@ -73,7 +71,6 @@
7371

7472
public final class NodeClassificationTrain {
7573

76-
private final Graph graph;
7774
private final NodeClassificationPipelineTrainConfig config;
7875
private final NodeClassificationTrainingPipeline pipeline;
7976
private final Features features;
@@ -232,7 +229,6 @@ public static NodeClassificationTrain create(
232229
}
233230

234231
return new NodeClassificationTrain(
235-
graph,
236232
pipeline,
237233
config,
238234
features,
@@ -246,7 +242,6 @@ public static NodeClassificationTrain create(
246242
}
247243

248244
private NodeClassificationTrain(
249-
Graph graph,
250245
NodeClassificationTrainingPipeline pipeline,
251246
NodeClassificationPipelineTrainConfig config,
252247
Features features,
@@ -259,7 +254,6 @@ private NodeClassificationTrain(
259254
) {
260255
this.progressTracker = progressTracker;
261256
this.terminationFlag = terminationFlag;
262-
this.graph = graph;
263257
this.pipeline = pipeline;
264258
this.config = config;
265259
this.features = features;
@@ -293,10 +287,7 @@ public NodeClassificationTrainResult compute() {
293287

294288
Classifier retrainedModelData = retrainBestModel(nodeSplits.allTrainingExamples(), trainingStatistics);
295289

296-
return ImmutableNodeClassificationTrainResult.of(
297-
createModel(retrainedModelData, trainingStatistics),
298-
trainingStatistics
299-
);
290+
return ImmutableNodeClassificationTrainResult.of(retrainedModelData, trainingStatistics);
300291
}
301292

302293
private void selectBestModel(List<TrainingExamplesSplit> nodeSplits, TrainingStatistics trainingStatistics) {
@@ -412,29 +403,6 @@ private Classifier retrainBestModel(HugeLongArray trainSet, TrainingStatistics t
412403
return retrainedClassifier;
413404
}
414405

415-
private Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> createModel(
416-
Classifier classifier,
417-
TrainingStatistics trainingStatistics
418-
) {
419-
420-
var modelInfo = NodeClassificationPipelineModelInfo.builder()
421-
.classes(classIdMap.originalIdsList())
422-
.bestParameters(trainingStatistics.bestParameters())
423-
.metrics(trainingStatistics.metricsForWinningModel())
424-
.pipeline(NodeClassificationPredictPipeline.from(pipeline))
425-
.build();
426-
427-
return Model.of(
428-
config.username(),
429-
config.modelName(),
430-
NodeClassificationTrainingPipeline.MODEL_TYPE,
431-
graph.schema(),
432-
classifier.data(),
433-
config,
434-
modelInfo
435-
);
436-
}
437-
438406
private Classifier trainModel(
439407
HugeLongArray trainSet,
440408
TrainerConfig trainerConfig,

proc/machine-learning/src/main/java/org/neo4j/gds/ml/nodemodels/NodeClassificationTrainPipelineAlgorithmFactory.java renamed to pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainPipelineAlgorithmFactory.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
* You should have received a copy of the GNU General Public License
1818
* along with this program. If not, see <http://www.gnu.org/licenses/>.
1919
*/
20-
package org.neo4j.gds.ml.nodemodels;
20+
package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;
2121

2222
import org.neo4j.gds.GraphStoreAlgorithmFactory;
2323
import org.neo4j.gds.api.GraphStore;
@@ -27,11 +27,8 @@
2727
import org.neo4j.gds.core.utils.progress.tasks.Task;
2828
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
2929
import org.neo4j.gds.executor.ExecutionContext;
30-
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationTrainPipelineExecutor;
3130
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
3231
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
33-
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
34-
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrain;
3532

3633
import java.util.ArrayList;
3734
import java.util.List;

proc/machine-learning/src/main/java/org/neo4j/gds/ml/nodemodels/pipeline/NodeClassificationTrainPipelineExecutor.java renamed to pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainPipelineExecutor.java

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,33 @@
1717
* You should have received a copy of the GNU General Public License
1818
* along with this program. If not, see <http://www.gnu.org/licenses/>.
1919
*/
20-
package org.neo4j.gds.ml.nodemodels.pipeline;
20+
package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;
2121

22+
import org.neo4j.gds.annotation.ValueClass;
2223
import org.neo4j.gds.api.GraphStore;
24+
import org.neo4j.gds.core.model.Model;
2325
import org.neo4j.gds.core.model.ModelCatalog;
2426
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2527
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
2628
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2729
import org.neo4j.gds.executor.ExecutionContext;
30+
import org.neo4j.gds.ml.models.Classifier;
2831
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
2932
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
33+
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
34+
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPredictPipeline;
3035
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
31-
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
32-
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrain;
33-
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrainResult;
3436

3537
import java.util.List;
3638
import java.util.Map;
3739
import java.util.Optional;
3840

41+
import static org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrainPipelineExecutor.NodeClassificationTrainPipelineResult;
42+
3943
public class NodeClassificationTrainPipelineExecutor extends PipelineExecutor<
4044
NodeClassificationPipelineTrainConfig,
4145
NodeClassificationTrainingPipeline,
42-
NodeClassificationTrainResult
46+
NodeClassificationTrainPipelineResult
4347
> {
4448

4549
public NodeClassificationTrainPipelineExecutor(
@@ -92,7 +96,7 @@ public Map<DatasetSplits, GraphFilter> splitDataset() {
9296
}
9397

9498
@Override
95-
protected NodeClassificationTrainResult execute(Map<DatasetSplits, GraphFilter> dataSplits) {
99+
protected NodeClassificationTrainPipelineResult execute(Map<DatasetSplits, GraphFilter> dataSplits) {
96100
PipelineExecutor.validateTrainingParameterSpace(pipeline);
97101

98102
var nodeLabels = config.nodeLabelIdentifiers(graphStore);
@@ -101,8 +105,31 @@ protected NodeClassificationTrainResult execute(Map<DatasetSplits, GraphFilter>
101105

102106
this.pipeline.splitConfig().validateMinNumNodesInSplitSets(graph);
103107

104-
return NodeClassificationTrain
108+
var trainResult = NodeClassificationTrain
105109
.create(graph, pipeline, config, progressTracker, terminationFlag)
106110
.compute();
111+
112+
var catalogModel = Model.of(
113+
config.username(),
114+
config.modelName(),
115+
NodeClassificationTrainingPipeline.MODEL_TYPE,
116+
schemaBeforeSteps,
117+
trainResult.classifier().data(),
118+
config,
119+
NodeClassificationPipelineModelInfo.builder()
120+
.classes(trainResult.classifier().classIdMap().originalIdsList())
121+
.bestParameters(trainResult.trainingStatistics().bestParameters())
122+
.metrics(trainResult.trainingStatistics().metricsForWinningModel())
123+
.pipeline(NodeClassificationPredictPipeline.from(pipeline))
124+
.build()
125+
);
126+
127+
return ImmutableNodeClassificationTrainPipelineResult.of(catalogModel, trainResult.trainingStatistics());
128+
}
129+
130+
@ValueClass
131+
public interface NodeClassificationTrainPipelineResult {
132+
Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model();
133+
TrainingStatistics trainingStatistics();
107134
}
108135
}

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainResult.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;
2121

2222
import org.neo4j.gds.annotation.ValueClass;
23-
import org.neo4j.gds.core.model.Model;
2423
import org.neo4j.gds.ml.models.Classifier;
2524
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
2625

2726
@ValueClass
2827
public interface NodeClassificationTrainResult {
29-
Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model();
28+
Classifier classifier();
3029
TrainingStatistics trainingStatistics();
3130
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.ml.pipeline.nodePipeline.regression;
21+
22+
import org.immutables.value.Value;
23+
import org.neo4j.gds.annotation.ValueClass;
24+
import org.neo4j.gds.config.ToMapConvertible;
25+
import org.neo4j.gds.ml.metrics.BestMetricData;
26+
import org.neo4j.gds.ml.metrics.Metric;
27+
import org.neo4j.gds.ml.models.TrainerConfig;
28+
29+
import java.util.Map;
30+
import java.util.stream.Collectors;
31+
32+
@ValueClass
33+
public interface NodeRegressionPipelineModelInfo extends ToMapConvertible {
34+
35+
/**
36+
* The parameters that yielded the best fold-averaged validation score
37+
* for the selection metric.
38+
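* @return the trainer configuration that yielded the best fold-averaged validation score for the selection metric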
39+
*/
40+
TrainerConfig bestParameters();
41+
42+
Map<Metric, BestMetricData> metrics();
43+
44+
@Override
45+
@Value.Auxiliary
46+
@Value.Derived
47+
default Map<String, Object> toMap() {
48+
return Map.of(
49+
"bestParameters", bestParameters().toMapWithTrainerMethod(),
50+
"metrics", metrics().entrySet().stream().collect(Collectors.toMap(
51+
entry -> entry.getKey().toString(),
52+
entry -> entry.getValue().toMap()
53+
))
54+
);
55+
}
56+
57+
static ImmutableNodeRegressionPipelineModelInfo.Builder builder() {
58+
return ImmutableNodeRegressionPipelineModelInfo.builder();
59+
}
60+
}

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/regression/NodeRegressionTrain.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public final class NodeRegressionTrain {
5656
private final ProgressTracker progressTracker;
5757
private final TerminationFlag terminationFlag;
5858

59-
public static List<Task> progressTask(int validationFolds, int numberOfModelSelectionTrials) {
59+
public static List<Task> progressTasks(int validationFolds, int numberOfModelSelectionTrials) {
6060
return List.of(
6161
Tasks.leaf("Shuffle and Split"),
6262
Tasks.iterativeFixed(

0 commit comments

Comments
 (0)