Skip to content

Commit d452ee5

Browse files
authored
Merge pull request #5231 from FlorentinD/config-trainer-method-cleanup
Simplify TrainerConfig::methodName
2 parents c7d2f16 + 8be1a99 commit d452ee5

File tree

14 files changed

+33
-32
lines changed

14 files changed

+33
-32
lines changed

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/ClassifierFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public static MemoryEstimation dataMemoryEstimation(
7575
int featureDimension,
7676
boolean isReduced
7777
) {
78-
switch (TrainingMethod.valueOf(trainerConfig.methodName())) {
78+
switch (trainerConfig.method()) {
7979
case LogisticRegression:
8080
return LogisticRegressionData.memoryEstimation(isReduced, numberOfClasses, MemoryRange.of(featureDimension));
8181
case RandomForest:

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/ClassifierTrainerFactory.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import java.util.Optional;
3333
import java.util.function.LongUnaryOperator;
3434

35-
public class ClassifierTrainerFactory {
35+
public final class ClassifierTrainerFactory {
3636

3737
private ClassifierTrainerFactory() {}
3838

@@ -45,7 +45,7 @@ public static ClassifierTrainer create(
4545
Optional<Long> randomSeed,
4646
boolean reduceClassCount
4747
) {
48-
switch (TrainingMethod.valueOf(config.methodName())) {
48+
switch (config.method()) {
4949
case LogisticRegression: {
5050
return new LogisticRegressionTrainer(
5151
concurrency,
@@ -79,7 +79,7 @@ public static MemoryEstimation memoryEstimation(
7979
MemoryRange featureDimension,
8080
boolean isReduced
8181
) {
82-
switch (TrainingMethod.valueOf(config.methodName())) {
82+
switch (config.method()) {
8383
case LogisticRegression:
8484
return LogisticRegressionTrainer.memoryEstimation(
8585
isReduced,

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/RegressionTrainerFactory.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public static RegressorTrainer create(
4141
int concurrency,
4242
Optional<Long> randomSeed
4343
) {
44-
switch (TrainingMethod.valueOf(config.methodName())) {
44+
switch (config.method()) {
4545
case LinearRegression: {
4646
return new LinearRegressionTrainer(
4747
concurrency,
@@ -59,7 +59,7 @@ public static RegressorTrainer create(
5959
);
6060
}
6161
default:
62-
throw new IllegalStateException(formatWithLocale("Method %s is not a regression method", config.methodName()));
62+
throw new IllegalStateException(formatWithLocale("Method %s is not a regression method", config.method()));
6363
}
6464
}
6565
}

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/TrainerConfig.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,18 @@
3030
public interface TrainerConfig extends ToMapConvertible {
3131

3232
@Configuration.Ignore
33-
String methodName();
33+
TrainingMethod method();
3434

3535
@Value.Derived
3636
@Configuration.Ignore
3737
default TunableTrainerConfig toTunableConfig() {
38-
return TunableTrainerConfig.of(toMap(), TrainingMethod.valueOf(methodName()));
38+
return TunableTrainerConfig.of(toMap(), method());
3939
}
4040

4141
@Configuration.Ignore
4242
default Map<String, Object> toMapWithTrainerMethod() {
4343
var mapWithTrainerMethod = new HashMap<>(toMap());
44-
mapWithTrainerMethod.put("methodName", methodName());
44+
mapWithTrainerMethod.put("methodName", method().name());
4545

4646
return mapWithTrainerMethod;
4747
}

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/linearregression/LinearRegressionTrainConfig.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ static LinearRegressionTrainConfig of(Map<String, Object> params) {
5252

5353
@Override
5454
@Configuration.Ignore
55-
default String methodName() {
56-
return TrainingMethod.LinearRegression.name();
55+
default TrainingMethod method() {
56+
return TrainingMethod.LinearRegression;
5757
}
5858
}

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/logisticregression/LogisticRegressionTrainConfig.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ static LogisticRegressionTrainConfig of(Map<String, Object> params) {
5252

5353
@Override
5454
@Configuration.Ignore
55-
default String methodName() {
56-
return TrainingMethod.LogisticRegression.name();
55+
default TrainingMethod method() {
56+
return TrainingMethod.LogisticRegression;
5757
}
5858
}

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/randomforest/RandomForestTrainerConfig.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ default int numberOfDecisionTrees() {
5555

5656
@Override
5757
@Configuration.Ignore
58-
default String methodName() {
59-
return TrainingMethod.RandomForest.name();
58+
default TrainingMethod method() {
59+
return TrainingMethod.RandomForest;
6060
}
6161

6262
static RandomForestTrainerConfig of(Map<String, Object> params) {

ml/ml-algo/src/main/java/org/neo4j/gds/ml/nodeClassification/ClassificationMetricComputer.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import org.neo4j.gds.ml.models.ClassifierFactory;
3030
import org.neo4j.gds.ml.models.Features;
3131
import org.neo4j.gds.ml.models.TrainerConfig;
32-
import org.neo4j.gds.ml.models.TrainingMethod;
3332
import org.openjdk.jol.util.Multiset;
3433

3534
import java.util.function.LongUnaryOperator;
@@ -120,7 +119,7 @@ public static MemoryEstimation estimateEvaluation(
120119
.rangePerNode(
121120
"classifier runtime",
122121
nodeCount -> ClassifierFactory.runtimeOverheadMemoryEstimation(
123-
TrainingMethod.valueOf(config.methodName()),
122+
config.method(),
124123
batchSize,
125124
fudgedClassCount,
126125
fudgedFeatureCount,

ml/ml-algo/src/test/java/org/neo4j/gds/ml/metrics/StatsMapTest.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.junit.jupiter.api.Test;
2323
import org.neo4j.gds.core.GraphDimensions;
2424
import org.neo4j.gds.ml.models.TrainerConfig;
25+
import org.neo4j.gds.ml.models.TrainingMethod;
2526

2627
import java.util.List;
2728
import java.util.Map;
@@ -79,18 +80,18 @@ void toMap() {
7980

8081
private static final class TestTrainerConfig implements TrainerConfig {
8182

82-
private final String method;
83+
private final String name;
8384

84-
private TestTrainerConfig(String method) {this.method = method;}
85+
private TestTrainerConfig(String name) {this.name = name;}
8586

8687
@Override
8788
public Map<String, Object> toMap() {
88-
return Map.of("method", method);
89+
return Map.of("name", name);
8990
}
9091

9192
@Override
92-
public String methodName() {
93-
return method;
93+
public TrainingMethod method() {
94+
return TrainingMethod.RandomForest;
9495
}
9596
}
9697

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/TrainingPipeline.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ public void addTrainerConfig(TunableTrainerConfig trainingConfig) {
129129
}
130130

131131
public void addTrainerConfig(TrainerConfig trainingConfig) {
132-
this.trainingParameterSpace.get(TrainingMethod.valueOf(trainingConfig.methodName())).add(trainingConfig.toTunableConfig());
132+
this.trainingParameterSpace.get(trainingConfig.method()).add(trainingConfig.toTunableConfig());
133133
}
134134

135135
public AutoTuningConfig autoTuningConfig() {

0 commit comments

Comments
 (0)