Skip to content

Commit 8d5293d

Browse files
Mats-SXFlorentinD
andcommitted
Implement NodeRegressionTrainPipelineExecutor
- It owns creating a Model instance for the catalog - Schema in Model is the schema before node property steps have run Co-authored-by: Florentin Dörre <florentin.dorre@neotechnology.com>
1 parent 72f6846 commit 8d5293d

File tree

5 files changed

+435
-0
lines changed

5 files changed

+435
-0
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/PipelineExecutor.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.neo4j.gds.RelationshipType;
2525
import org.neo4j.gds.annotation.ValueClass;
2626
import org.neo4j.gds.api.GraphStore;
27+
import org.neo4j.gds.api.schema.GraphSchema;
2728
import org.neo4j.gds.config.AlgoBaseConfig;
2829
import org.neo4j.gds.core.model.ModelCatalog;
2930
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
@@ -35,6 +36,7 @@
3536
import java.util.Collection;
3637
import java.util.List;
3738
import java.util.Map;
39+
import java.util.Set;
3840
import java.util.stream.Collectors;
3941

4042
import static org.neo4j.gds.config.MutatePropertyConfig.MUTATE_PROPERTY_KEY;
@@ -56,6 +58,7 @@ public enum DatasetSplits {
5658
protected final PIPELINE_CONFIG config;
5759
protected final ExecutionContext executionContext;
5860
protected final GraphStore graphStore;
61+
protected final GraphSchema schemaBeforeSteps;
5962
protected final String graphName;
6063

6164
protected PipelineExecutor(
@@ -72,6 +75,10 @@ protected PipelineExecutor(
7275
this.executionContext = executionContext;
7376
this.graphStore = graphStore;
7477
this.graphName = graphName;
78+
this.schemaBeforeSteps = graphStore
79+
.schema()
80+
.filterNodeLabels(Set.copyOf(config.nodeLabelIdentifiers(graphStore)))
81+
.filterRelationshipTypes(Set.copyOf(config.internalRelationshipTypes(graphStore)));
7582
}
7683

7784
public static MemoryEstimation estimateNodePropertySteps(ModelCatalog modelCatalog, List<ExecutableNodePropertyStep> nodePropertySteps, List<String> nodeLabels, List<String> relationshipTypes) {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.ml.pipeline.nodePipeline.regression;
21+
22+
import org.immutables.value.Value;
23+
import org.neo4j.gds.annotation.ValueClass;
24+
import org.neo4j.gds.config.ToMapConvertible;
25+
import org.neo4j.gds.ml.metrics.BestMetricData;
26+
import org.neo4j.gds.ml.metrics.Metric;
27+
import org.neo4j.gds.ml.models.TrainerConfig;
28+
29+
import java.util.Map;
30+
import java.util.stream.Collectors;
31+
32+
@ValueClass
33+
public interface NodeRegressionPipelineModelInfo extends ToMapConvertible {
34+
35+
/**
36+
* The parameters that yielded the best fold-averaged validation score
37+
* for the selection metric.
38+
* @return
39+
*/
40+
TrainerConfig bestParameters();
41+
42+
Map<Metric, BestMetricData> metrics();
43+
44+
@Override
45+
@Value.Auxiliary
46+
@Value.Derived
47+
default Map<String, Object> toMap() {
48+
return Map.of(
49+
"bestParameters", bestParameters().toMapWithTrainerMethod(),
50+
"metrics", metrics().entrySet().stream().collect(Collectors.toMap(
51+
entry -> entry.getKey().toString(),
52+
entry -> entry.getValue().toMap()
53+
))
54+
);
55+
}
56+
57+
static ImmutableNodeRegressionPipelineModelInfo.Builder builder() {
58+
return ImmutableNodeRegressionPipelineModelInfo.builder();
59+
}
60+
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.ml.pipeline.nodePipeline.regression;
21+
22+
import org.neo4j.gds.annotation.ValueClass;
23+
import org.neo4j.gds.api.GraphStore;
24+
import org.neo4j.gds.core.model.Model;
25+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
26+
import org.neo4j.gds.core.utils.progress.tasks.Task;
27+
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
28+
import org.neo4j.gds.executor.ExecutionContext;
29+
import org.neo4j.gds.ml.models.Regressor;
30+
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
31+
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
32+
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
33+
34+
import java.util.ArrayList;
35+
import java.util.List;
36+
import java.util.Map;
37+
import java.util.Optional;
38+
39+
import static org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainPipelineExecutor.NodeRegressionTrainPipelineResult;
40+
41+
public class NodeRegressionTrainPipelineExecutor extends PipelineExecutor<
42+
NodeRegressionPipelineTrainConfig,
43+
NodeRegressionTrainingPipeline,
44+
NodeRegressionTrainPipelineResult
45+
> {
46+
47+
public static Task progressTask(NodeRegressionTrainingPipeline pipeline) {
48+
return Tasks.task(
49+
"Node Regression Train Pipeline",
50+
new ArrayList<>() {{
51+
add(Tasks.iterativeFixed(
52+
"Execute node property steps",
53+
() -> List.of(Tasks.leaf("Step")),
54+
pipeline.nodePropertySteps().size()
55+
));
56+
addAll(NodeRegressionTrain.progressTasks(
57+
pipeline.splitConfig().validationFolds(),
58+
pipeline.numberOfModelSelectionTrials()
59+
));
60+
61+
}}
62+
);
63+
}
64+
65+
public NodeRegressionTrainPipelineExecutor(
66+
NodeRegressionTrainingPipeline pipeline,
67+
NodeRegressionPipelineTrainConfig config,
68+
ExecutionContext executionContext,
69+
GraphStore graphStore,
70+
ProgressTracker progressTracker
71+
) {
72+
super(pipeline, config, executionContext, graphStore, config.graphName(), progressTracker);
73+
}
74+
75+
@Override
76+
public Map<DatasetSplits, GraphFilter> splitDataset() {
77+
// we don't split the input graph but generate the features and predict over the whole graph.
78+
// Inside the training algo we split the nodes into multiple sets.
79+
return Map.of(
80+
DatasetSplits.FEATURE_INPUT,
81+
ImmutableGraphFilter.of(
82+
config.nodeLabelIdentifiers(graphStore),
83+
config.internalRelationshipTypes(graphStore)
84+
)
85+
);
86+
}
87+
88+
@Override
89+
protected NodeRegressionTrainPipelineResult execute(Map<DatasetSplits, GraphFilter> dataSplits) {
90+
PipelineExecutor.validateTrainingParameterSpace(pipeline);
91+
92+
var nodeLabels = config.nodeLabelIdentifiers(graphStore);
93+
var relationshipTypes = config.internalRelationshipTypes(graphStore);
94+
var graph = graphStore.getGraph(nodeLabels, relationshipTypes, Optional.empty());
95+
96+
this.pipeline.splitConfig().validateMinNumNodesInSplitSets(graph);
97+
98+
NodeRegressionTrainResult trainResult = NodeRegressionTrain
99+
.create(graph, pipeline, config, progressTracker, terminationFlag)
100+
.compute();
101+
102+
var catalogModel = Model.of(
103+
config.username(),
104+
config.modelName(),
105+
NodeRegressionTrainingPipeline.MODEL_TYPE,
106+
schemaBeforeSteps,
107+
trainResult.regressor().data(),
108+
config,
109+
NodeRegressionPipelineModelInfo.builder()
110+
.bestParameters(trainResult.trainingStatistics().bestParameters())
111+
.metrics(trainResult.trainingStatistics().metricsForWinningModel())
112+
.build()
113+
);
114+
115+
return ImmutableNodeRegressionTrainPipelineResult.of(catalogModel, trainResult.trainingStatistics());
116+
}
117+
118+
@ValueClass
119+
interface NodeRegressionTrainPipelineResult {
120+
Model<Regressor.RegressorData, NodeRegressionPipelineTrainConfig, NodeRegressionPipelineModelInfo> model();
121+
TrainingStatistics trainingStatistics();
122+
}
123+
}

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/regression/NodeRegressionTrainingPipeline.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
public class NodeRegressionTrainingPipeline extends NodePropertyTrainingPipeline {
2727

2828
public static final String PIPELINE_TYPE = "Node regression training pipeline";
29+
public static final String MODEL_TYPE = "NodeRegression";
2930

3031
public NodeRegressionTrainingPipeline() {
3132
super(TrainingType.REGRESSION);

0 commit comments

Comments
 (0)