|
44 | 44 | import org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification; |
45 | 45 | import org.neo4j.gds.ml.models.TrainingMethod; |
46 | 46 | import org.neo4j.gds.ml.models.automl.TunableTrainerConfig; |
| 47 | +import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionData; |
47 | 48 | import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig; |
48 | 49 | import org.neo4j.gds.ml.models.randomforest.RandomForestTrainerConfig; |
49 | 50 | import org.neo4j.gds.ml.pipeline.AutoTuningConfigImpl; |
@@ -124,7 +125,8 @@ void trainsAModel() { |
124 | 125 | var metricSpecification = ClassificationMetricSpecification.parse("F1(class=1)"); |
125 | 126 | var metric = metricSpecification.createMetrics(List.of()).findFirst().orElseThrow(); |
126 | 127 |
|
127 | | - pipeline.addTrainerConfig(LogisticRegressionTrainConfig.of(Map.of("penalty", 1, "maxEpochs", 1))); |
| 128 | + var modelCandidate = LogisticRegressionTrainConfig.of(Map.of("penalty", 1, "maxEpochs", 1)); |
| 129 | + pipeline.addTrainerConfig(modelCandidate); |
128 | 130 |
|
129 | 131 | pipeline.setSplitConfig(NodePropertyPredictionSplitConfigImpl.builder() |
130 | 132 | .testFraction(0.3) |
@@ -152,6 +154,14 @@ void trainsAModel() { |
152 | 154 | var model = result.model(); |
153 | 155 |
|
154 | 156 | assertThat(model.creator()).isEqualTo(getUsername()); |
| 157 | + assertThat(model.algoType()).isEqualTo(NodeClassificationTrainingPipeline.MODEL_TYPE); |
| 158 | + assertThat(model.data()).isInstanceOf(LogisticRegressionData.class); |
| 159 | + assertThat(model.trainConfig()).isEqualTo(config); |
| 160 | + assertThat(model.graphSchema()).isEqualTo(graphStore.schema()); |
| 161 | + assertThat(model.name()).isEqualTo("model"); |
| 162 | + assertThat(model.stored()).isFalse(); |
| 163 | + assertThat(model.customInfo().bestParameters().toMap()).isEqualTo(modelCandidate.toMap()); |
| 164 | + assertThat(model.customInfo().metrics()).containsOnlyKeys(metric); |
155 | 165 |
|
156 | 166 | // using explicit type intentionally :) |
157 | 167 | NodeClassificationPipelineModelInfo customInfo = model.customInfo(); |
|
0 commit comments