Skip to content

Commit 47ed9b4

Browse files
FlorentinDMats-SX
andcommitted
Add proc for creating node regression pipelines
Co-authored-by: Mats Rydberg <mats@neotechnology.com>
1 parent 6a75422 commit 47ed9b4

File tree

8 files changed

+122
-1
lines changed

8 files changed

+122
-1
lines changed

doc/antora/content-nav.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
**** xref:machine-learning/node-property-prediction/nodeclassification-pipelines/config/index.adoc[]
105105
**** xref:machine-learning/node-property-prediction/nodeclassification-pipelines/training/index.adoc[]
106106
**** xref:machine-learning/node-property-prediction/nodeclassification-pipelines/predict/index.adoc[]
107+
*** xref:machine-learning/node-property-prediction/noderegression-pipelines/index.adoc[]
107108
** xref:machine-learning/linkprediction-pipelines/index.adoc[]
108109
*** xref:machine-learning/linkprediction-pipelines/config/index.adoc[]
109110
*** xref:machine-learning/linkprediction-pipelines/training/index.adoc[]

doc/asciidoc/machine-learning/node-property-prediction/node-property-prediction.adoc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,9 @@ The Neo4j Graph Data Science library support the following node property predict
1111

1212
* Beta
1313
** <<nodeclassification-pipelines>>
14+
* Alpha
15+
** <<noderegression-pipelines>>
1416

1517
include::nodeclassification-pipeline/nodeclassification.adoc[leveloffset=+1]
18+
include::noderegression-pipeline/noderegression.adoc[leveloffset=+1]
19+
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[[noderegression-pipelines]]
2+
= Node regression pipelines
3+
4+
// TODO add config content

doc/asciidoc/operations-reference/appendix-a-machine-learning.adoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ include::model-operation-references.adoc[]
6262
.2+<.^| <<nodeclassification-pipelines, Node Classification Pipeline>>
6363
| `gds.alpha.pipeline.nodeClassification.addRandomForest`
6464
| `gds.alpha.pipeline.nodeClassification.configureAutoTuning`
65+
.1+<.^| <<noderegression-pipelines, Node Regression Pipeline>>
66+
| `gds.alpha.pipeline.nodeRegression.create`
6567
|===
6668

6769

doc/docbook/content-map.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,8 @@
304304
<d:tocentry linkend="nodeclassification-pipelines-predict"><?dbhtml filename="machine-learning/node-property-prediction/nodeclassification-pipelines/predict/index.html"?>
305305
</d:tocentry>
306306
</d:tocentry>
307+
<d:tocentry linkend="noderegression-pipelines"><?dbhtml filename="machine-learning/node-property-prediction/noderegression-pipelines/index.html"?>
308+
</d:tocentry>
307309
</d:tocentry>
308310

309311
<d:tocentry linkend="linkprediction-pipelines"><?dbhtml filename="machine-learning/linkprediction-pipelines/index.html"?>

open-packaging/src/test/java/org/neo4j/gds/OpenGdsProcedureSmokeTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ class OpenGdsProcedureSmokeTest extends BaseProcTest {
9494
"gds.beta.pipeline.linkPrediction.train",
9595
"gds.beta.pipeline.linkPrediction.train.estimate",
9696

97+
"gds.alpha.pipeline.nodeRegression.create",
98+
9799
"gds.beta.pipeline.nodeClassification.selectFeatures",
98100
"gds.beta.pipeline.nodeClassification.addNodeProperty",
99101
"gds.beta.pipeline.nodeClassification.addLogisticRegression",
@@ -444,7 +446,7 @@ void countShouldMatch() {
444446
);
445447

446448
// If you find yourself updating this count, please also update the count in SmokeTest.kt
447-
int expectedCount = 304;
449+
int expectedCount = 305;
448450
assertEquals(
449451
expectedCount,
450452
registeredProcedures.size(),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.ml.pipeline.node.regression.configure;
21+
22+
import org.neo4j.gds.BaseProc;
23+
import org.neo4j.gds.core.StringIdentifierValidations;
24+
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
25+
import org.neo4j.gds.ml.pipeline.node.NodePipelineInfoResult;
26+
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainingPipeline;
27+
import org.neo4j.procedure.Description;
28+
import org.neo4j.procedure.Name;
29+
import org.neo4j.procedure.Procedure;
30+
31+
import java.util.stream.Stream;
32+
33+
import static org.neo4j.procedure.Mode.READ;
34+
35+
public class NodeRegressionPipelineCreateProc extends BaseProc {
36+
37+
@Procedure(name = "gds.alpha.pipeline.nodeRegression.create", mode = READ)
38+
@Description("Creates a node regression training pipeline in the pipeline catalog.")
39+
public Stream<NodePipelineInfoResult> create(@Name("pipelineName") String pipelineName) {
40+
StringIdentifierValidations.validateNoWhiteCharacter(pipelineName, "pipelineName");
41+
42+
var pipeline = new NodeRegressionTrainingPipeline();
43+
44+
PipelineCatalog.set(username(), pipelineName, pipeline);
45+
46+
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
47+
}
48+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.ml.pipeline.node.regression.configure;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.BaseProcTest;
25+
import org.neo4j.gds.ml.models.TrainingMethod;
26+
import org.neo4j.gds.ml.pipeline.AutoTuningConfig;
27+
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
28+
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
29+
30+
import java.util.List;
31+
import java.util.Map;
32+
33+
import static org.assertj.core.api.Assertions.assertThat;
34+
35+
class NodeRegressionPipelineCreateProcTest extends BaseProcTest {
36+
37+
@BeforeEach
38+
void setUp() throws Exception {
39+
registerProcedures(NodeRegressionPipelineCreateProc.class);
40+
}
41+
42+
@Test
43+
void createPipeline() {
44+
assertCypherResult("CALL gds.alpha.pipeline.nodeRegression.create('p')", List.of(Map.of(
45+
"name", "p",
46+
"nodePropertySteps", List.of(),
47+
"featureProperties", List.of(),
48+
"splitConfig", NodePropertyPredictionSplitConfig.DEFAULT_CONFIG.toMap(),
49+
"autoTuningConfig", AutoTuningConfig.DEFAULT_CONFIG.toMap(),
50+
"parameterSpace", Map.of(
51+
TrainingMethod.LinearRegression.name(), List.of(),
52+
TrainingMethod.RandomForest.name(), List.of()
53+
)
54+
)));
55+
56+
assertThat(PipelineCatalog.exists(getUsername(), "p")).isTrue();
57+
}
58+
}

0 commit comments

Comments
 (0)