Skip to content

Commit 18aa09f

Browse files
Mats-SXFlorentinD
authored andcommitted
Add additional assertions on Model metadata
1 parent 7b5d362 commit 18aa09f

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainPipelineExecutorTest.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification;
4545
import org.neo4j.gds.ml.models.TrainingMethod;
4646
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
47+
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionData;
4748
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
4849
import org.neo4j.gds.ml.models.randomforest.RandomForestTrainerConfig;
4950
import org.neo4j.gds.ml.pipeline.AutoTuningConfigImpl;
@@ -124,7 +125,8 @@ void trainsAModel() {
124125
var metricSpecification = ClassificationMetricSpecification.parse("F1(class=1)");
125126
var metric = metricSpecification.createMetrics(List.of()).findFirst().orElseThrow();
126127

127-
pipeline.addTrainerConfig(LogisticRegressionTrainConfig.of(Map.of("penalty", 1, "maxEpochs", 1)));
128+
var modelCandidate = LogisticRegressionTrainConfig.of(Map.of("penalty", 1, "maxEpochs", 1));
129+
pipeline.addTrainerConfig(modelCandidate);
128130

129131
pipeline.setSplitConfig(NodePropertyPredictionSplitConfigImpl.builder()
130132
.testFraction(0.3)
@@ -152,6 +154,14 @@ void trainsAModel() {
152154
var model = result.model();
153155

154156
assertThat(model.creator()).isEqualTo(getUsername());
157+
assertThat(model.algoType()).isEqualTo(NodeClassificationTrainingPipeline.MODEL_TYPE);
158+
assertThat(model.data()).isInstanceOf(LogisticRegressionData.class);
159+
assertThat(model.trainConfig()).isEqualTo(config);
160+
assertThat(model.graphSchema()).isEqualTo(graphStore.schema());
161+
assertThat(model.name()).isEqualTo("model");
162+
assertThat(model.stored()).isFalse();
163+
assertThat(model.customInfo().bestParameters().toMap()).isEqualTo(modelCandidate.toMap());
164+
assertThat(model.customInfo().metrics()).containsOnlyKeys(metric);
155165

156166
// using explicit type intentionally :)
157167
NodeClassificationPipelineModelInfo customInfo = model.customInfo();

0 commit comments

Comments
 (0)