Skip to content

Commit 6b6e94c

Browse files
committed
Retrieve nodes-only graph for NC/NR
1 parent 6bc2b86 commit 6b6e94c

File tree

4 files changed

+10
-25
lines changed

4 files changed

+10
-25
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainPipelineExecutor.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import java.util.ArrayList;
4040
import java.util.List;
4141
import java.util.Map;
42-
import java.util.Optional;
4342

4443
import static org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrainPipelineExecutor.NodeClassificationTrainPipelineResult;
4544

@@ -121,13 +120,12 @@ protected NodeClassificationTrainPipelineResult execute(Map<DatasetSplits, Graph
121120
PipelineExecutor.validateTrainingParameterSpace(pipeline);
122121

123122
var nodeLabels = config.nodeLabelIdentifiers(graphStore);
124-
var relationshipTypes = config.internalRelationshipTypes(graphStore);
125-
var graph = graphStore.getGraph(nodeLabels, relationshipTypes, Optional.empty());
123+
var nodesGraph = graphStore.getGraph(nodeLabels);
126124

127-
this.pipeline.splitConfig().validateMinNumNodesInSplitSets(graph);
125+
this.pipeline.splitConfig().validateMinNumNodesInSplitSets(nodesGraph);
128126

129127
var trainResult = NodeClassificationTrain
130-
.create(graph, pipeline, config, progressTracker, terminationFlag)
128+
.create(nodesGraph, pipeline, config, progressTracker, terminationFlag)
131129
.compute();
132130

133131
var catalogModel = Model.of(

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/regression/NodeRegressionTrainPipelineExecutor.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import java.util.ArrayList;
3636
import java.util.List;
3737
import java.util.Map;
38-
import java.util.Optional;
3938

4039
import static org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainPipelineExecutor.NodeRegressionTrainPipelineResult;
4140

@@ -91,13 +90,12 @@ protected NodeRegressionTrainPipelineResult execute(Map<DatasetSplits, GraphFilt
9190
PipelineExecutor.validateTrainingParameterSpace(pipeline);
9291

9392
var nodeLabels = config.nodeLabelIdentifiers(graphStore);
94-
var relationshipTypes = config.internalRelationshipTypes(graphStore);
95-
var graph = graphStore.getGraph(nodeLabels, relationshipTypes, Optional.empty());
93+
var nodesGraph = graphStore.getGraph(nodeLabels);
9694

97-
this.pipeline.splitConfig().validateMinNumNodesInSplitSets(graph);
95+
this.pipeline.splitConfig().validateMinNumNodesInSplitSets(nodesGraph);
9896

9997
NodeRegressionTrainResult trainResult = NodeRegressionTrain
100-
.create(graph, pipeline, config, progressTracker, terminationFlag)
98+
.create(nodesGraph, pipeline, config, progressTracker, terminationFlag)
10199
.compute();
102100

103101
var catalogModel = Model.of(

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictPipelineExecutor.java

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242

4343
import java.util.List;
4444
import java.util.Map;
45-
import java.util.Optional;
4645

4746
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
4847

@@ -124,13 +123,8 @@ public Map<DatasetSplits, GraphFilter> splitDataset() {
124123

125124
@Override
126125
protected NodeClassificationPredict.NodeClassificationResult execute(Map<DatasetSplits, GraphFilter> dataSplits) {
127-
var graph = graphStore.getGraph(
128-
config.nodeLabelIdentifiers(graphStore),
129-
config.internalRelationshipTypes(graphStore),
130-
Optional.empty()
131-
);
132-
133-
var features = FeaturesFactory.extractLazyFeatures(graph, pipeline.featureProperties());
126+
var nodesGraph = graphStore.getGraph(config.nodeLabelIdentifiers(graphStore));
127+
var features = FeaturesFactory.extractLazyFeatures(nodesGraph, pipeline.featureProperties());
134128

135129
if (features.featureDimension() != modelData.featureDimension()) {
136130
throw new IllegalArgumentException(formatWithLocale(

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPredictPipelineExecutor.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737

3838
import java.util.List;
3939
import java.util.Map;
40-
import java.util.Optional;
4140

4241
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
4342

@@ -87,12 +86,8 @@ public Map<DatasetSplits, GraphFilter> splitDataset() {
8786

8887
@Override
8988
protected HugeDoubleArray execute(Map<DatasetSplits, GraphFilter> dataSplits) {
90-
var graph = graphStore.getGraph(
91-
config.nodeLabelIdentifiers(graphStore),
92-
config.internalRelationshipTypes(graphStore),
93-
Optional.empty()
94-
);
95-
Features features = FeaturesFactory.extractLazyFeatures(graph, pipeline.featureProperties());
89+
var nodesGraph = graphStore.getGraph(config.nodeLabelIdentifiers(graphStore));
90+
Features features = FeaturesFactory.extractLazyFeatures(nodesGraph, pipeline.featureProperties());
9691

9792
if (features.featureDimension() != regressor.data().featureDimension()) {
9893
throw new IllegalArgumentException(formatWithLocale(

0 commit comments

Comments
 (0)