Skip to content

Commit 7b5d362

Browse files
Mats-SXFlorentinD
authored andcommitted
Remove Model dependency from NodeClassificationTrain
It returns a Classifier directly. Creating the model is the job of the pipeline.
1 parent a6d432e commit 7b5d362

File tree

4 files changed

+34
-52
lines changed

4 files changed

+34
-52
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrain.java

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import org.jetbrains.annotations.NotNull;
2323
import org.neo4j.gds.api.Graph;
24-
import org.neo4j.gds.core.model.Model;
2524
import org.neo4j.gds.core.utils.TerminationFlag;
2625
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2726
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
@@ -51,7 +50,6 @@
5150
import org.neo4j.gds.ml.nodeClassification.ClassificationMetricComputer;
5251
import org.neo4j.gds.ml.nodePropertyPrediction.NodeSplitter;
5352
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
54-
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPredictPipeline;
5553
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
5654
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
5755
import org.neo4j.gds.ml.splitting.FractionSplitter;
@@ -293,10 +291,7 @@ public NodeClassificationTrainResult compute() {
293291

294292
Classifier retrainedModelData = retrainBestModel(nodeSplits.allTrainingExamples(), trainingStatistics);
295293

296-
return ImmutableNodeClassificationTrainResult.of(
297-
createModel(retrainedModelData, trainingStatistics),
298-
trainingStatistics
299-
);
294+
return ImmutableNodeClassificationTrainResult.of(retrainedModelData, trainingStatistics);
300295
}
301296

302297
private void selectBestModel(List<TrainingExamplesSplit> nodeSplits, TrainingStatistics trainingStatistics) {
@@ -412,29 +407,6 @@ private Classifier retrainBestModel(HugeLongArray trainSet, TrainingStatistics t
412407
return retrainedClassifier;
413408
}
414409

415-
private Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> createModel(
416-
Classifier classifier,
417-
TrainingStatistics trainingStatistics
418-
) {
419-
420-
var modelInfo = NodeClassificationPipelineModelInfo.builder()
421-
.classes(classIdMap.originalIdsList())
422-
.bestParameters(trainingStatistics.bestParameters())
423-
.metrics(trainingStatistics.metricsForWinningModel())
424-
.pipeline(NodeClassificationPredictPipeline.from(pipeline))
425-
.build();
426-
427-
return Model.of(
428-
config.username(),
429-
config.modelName(),
430-
NodeClassificationTrainingPipeline.MODEL_TYPE,
431-
graph.schema(),
432-
classifier.data(),
433-
config,
434-
modelInfo
435-
);
436-
}
437-
438410
private Classifier trainModel(
439411
HugeLongArray trainSet,
440412
TrainerConfig trainerConfig,

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainPipelineExecutor.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
3232
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
3333
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
34+
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPredictPipeline;
3435
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
3536

3637
import java.util.List;
@@ -108,7 +109,22 @@ protected NodeClassificationTrainPipelineResult execute(Map<DatasetSplits, Graph
108109
.create(graph, pipeline, config, progressTracker, terminationFlag)
109110
.compute();
110111

111-
return ImmutableNodeClassificationTrainPipelineResult.of(trainResult.model(), trainResult.trainingStatistics());
112+
var catalogModel = Model.of(
113+
config.username(),
114+
config.modelName(),
115+
NodeClassificationTrainingPipeline.MODEL_TYPE,
116+
schemaBeforeSteps,
117+
trainResult.classifier().data(),
118+
config,
119+
NodeClassificationPipelineModelInfo.builder()
120+
.classes(trainResult.classifier().classIdMap().originalIdsList())
121+
.bestParameters(trainResult.trainingStatistics().bestParameters())
122+
.metrics(trainResult.trainingStatistics().metricsForWinningModel())
123+
.pipeline(NodeClassificationPredictPipeline.from(pipeline))
124+
.build()
125+
);
126+
127+
return ImmutableNodeClassificationTrainPipelineResult.of(catalogModel, trainResult.trainingStatistics());
112128
}
113129

114130
@ValueClass

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainResult.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;
2121

2222
import org.neo4j.gds.annotation.ValueClass;
23-
import org.neo4j.gds.core.model.Model;
2423
import org.neo4j.gds.ml.models.Classifier;
2524
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
2625

2726
@ValueClass
2827
public interface NodeClassificationTrainResult {
29-
Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model();
28+
Classifier classifier();
3029
TrainingStatistics trainingStatistics();
3130
}

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainTest.java

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,7 @@ void selectsTheBestModel(ClassificationMetricSpecification metricSpecification)
149149
);
150150

151151
var result = ncTrain.compute();
152-
var model = result.model();
153152

154-
var customInfo = model.customInfo();
155153
List<ModelStats> validationScores = result.trainingStatistics().getValidationStats(metric);
156154

157155
assertThat(validationScores).hasSize(MAX_TRIALS);
@@ -162,7 +160,7 @@ void selectsTheBestModel(ClassificationMetricSpecification metricSpecification)
162160
.isNotCloseTo(validationScores.get(i).avg(), Percentage.withPercentage(0.2));
163161
}
164162

165-
var actualWinnerParams = customInfo.bestParameters();
163+
var actualWinnerParams = result.trainingStatistics().bestParameters();
166164
assertThat(actualWinnerParams.toMap()).isEqualTo(expectedWinner.toMap());
167165
}
168166

@@ -226,30 +224,27 @@ void shouldProduceDifferentMetricsForDifferentTrainings(ClassificationMetricSpec
226224
);
227225

228226
var bananasModelTrainResult = bananasTrain.compute();
229-
var bananasModel = bananasModelTrainResult.model();
227+
var bananasClassifier = bananasModelTrainResult.classifier();
230228
var arrayModelTrainResult = arrayPropertyTrain.compute();
231-
var arrayPropertyModel = arrayModelTrainResult.model();
229+
var arrayPropertyClassifier = arrayModelTrainResult.classifier();
232230

233-
assertThat(arrayPropertyModel)
231+
assertThat(arrayPropertyClassifier)
234232
.usingRecursiveComparison()
235-
.withFailMessage("The trained models are exactly the same instance!")
236-
.isNotSameAs(bananasModel);
233+
.withFailMessage("The trained classifiers are exactly the same instance!")
234+
.isNotSameAs(bananasClassifier);
237235

238-
assertThat(arrayPropertyModel.data())
236+
assertThat(arrayPropertyClassifier.data())
239237
.usingRecursiveComparison()
240238
.withFailMessage("Should not produce the same trained `data`!")
241-
.isNotEqualTo(bananasModel.data());
239+
.isNotEqualTo(bananasClassifier.data());
242240

243-
var bananasCustomInfo = bananasModel.customInfo();
244-
var bananasValidationScore = bananasCustomInfo.metrics().get(metric);
241+
var bananasMetrics = bananasModelTrainResult.trainingStatistics().metricsForWinningModel().get(metric);
242+
var arrayPropertyMetrics = arrayModelTrainResult.trainingStatistics().metricsForWinningModel().get(metric);
245243

246-
var arrayPropertyCustomInfo = arrayPropertyModel.customInfo();
247-
var arrayPropertyValidationScores = arrayPropertyCustomInfo.metrics().get(metric);
248-
249-
assertThat(arrayPropertyValidationScores)
244+
assertThat(arrayPropertyMetrics)
250245
.usingRecursiveComparison()
251-
.isNotSameAs(bananasValidationScore)
252-
.isNotEqualTo(bananasValidationScore);
246+
.isNotSameAs(bananasMetrics)
247+
.isNotEqualTo(bananasMetrics);
253248
}
254249

255250
@Test
@@ -450,9 +445,9 @@ void seededNodeClassification(int concurrency) {
450445
var firstResult = algoSupplier.get().compute();
451446
var secondResult = algoSupplier.get().compute();
452447

453-
assertThat(((LogisticRegressionData) firstResult.model().data()).weights().data())
448+
assertThat(((LogisticRegressionData) firstResult.classifier().data()).weights().data())
454449
.matches(matrix -> matrix.equals(
455-
((LogisticRegressionData) secondResult.model().data()).weights().data(),
450+
((LogisticRegressionData) secondResult.classifier().data()).weights().data(),
456451
1e-10
457452
));
458453
}

0 commit comments

Comments
 (0)