
Commit 54a3f9d

Merge pull request #5356 from adamnsch/dt-opt
Improve decision tree training performance on datasets with many classes
2 parents dcc7b98 + 8f69730, commit 54a3f9d

8 files changed: +37 -50 lines

ml/ml-algo/src/main/java/org/neo4j/gds/ml/decisiontree/Splitter.java

Lines changed: 11 additions & 24 deletions
@@ -32,8 +32,6 @@ public class Splitter {
     private final FeatureBagger featureBagger;
     private final int minLeafSize;
     private final HugeLongArray sortCache;
-    private final ImpurityCriterion.ImpurityData bestLeftImpurityDataForIdx;
-    private final ImpurityCriterion.ImpurityData bestRightImpurityDataForIdx;
     private final ImpurityCriterion.ImpurityData rightImpurityData;
 
     Splitter(long trainSetSize, ImpurityCriterion impurityCriterion, FeatureBagger featureBagger, Features features, int minLeafSize) {
@@ -42,8 +40,6 @@ public class Splitter {
         this.features = features;
         this.minLeafSize = minLeafSize;
         this.sortCache = HugeLongArray.newArray(trainSetSize);
-        this.bestLeftImpurityDataForIdx = impurityCriterion.groupImpurity(HugeLongArray.of(), 0, 0);
-        this.bestRightImpurityDataForIdx = impurityCriterion.groupImpurity(HugeLongArray.of(), 0, 0);
         this.rightImpurityData = impurityCriterion.groupImpurity(HugeLongArray.of(), 0, 0);
     }
 
@@ -52,7 +48,7 @@ static long memoryEstimation(long numberOfTrainingSamples, long sizeOfImpurityDa
             // sort cache
             + HugeLongArray.memoryEstimation(numberOfTrainingSamples)
             // impurity data cache
-            + 6 * sizeOfImpurityData
+            + 4 * sizeOfImpurityData
             // group cache
             + 4 * HugeLongArray.memoryEstimation(numberOfTrainingSamples);
     }
@@ -77,10 +73,6 @@ DecisionTreeTrainer.Split findBestSplit(Group group) {
         int[] featureBag = featureBagger.sample();
 
         for (int featureIdx : featureBag) {
-            double bestImpurityForIdx = Double.MAX_VALUE;
-            double bestValueForIdx = Double.MAX_VALUE;
-            long bestLeftGroupSizeForIdx = -1;
-
             // By doing a sort of the group by this particular feature, all possible splits will simply be represented
             // by each index in the ordered group.
             HugeSerialIndirectMergeSort.sort(rightChildArray, group.size(), (long l) -> features.get(l)[featureIdx], sortCache);
@@ -99,6 +91,7 @@ DecisionTreeTrainer.Split findBestSplit(Group group) {
             }
 
             var leftImpurityData = impurityCriterion.groupImpurity(leftChildArray, 0, minLeafSize - 1L);
+            boolean foundImprovementWithIdx = false;
 
             // Continue moving feature vectors, but now actually compute combined impurity since left group is large enough.
             for (long leftGroupSize = minLeafSize; leftGroupSize <= group.size() - minLeafSize; leftGroupSize++) {
@@ -112,24 +105,18 @@ DecisionTreeTrainer.Split findBestSplit(Group group) {
 
                 // We track best split for a single feature idx in order to keep using `leftChildArray` and `rightChildArray`
                 // throughout search for splits for this particular idx.
-                if (combinedImpurity < bestImpurityForIdx) {
-                    bestValueForIdx = features.get(splittingFeatureVectorIdx)[featureIdx];
-                    bestImpurityForIdx = combinedImpurity;
-                    leftImpurityData.copyTo(bestLeftImpurityDataForIdx);
-                    rightImpurityData.copyTo(bestRightImpurityDataForIdx);
-                    bestLeftGroupSizeForIdx = leftGroupSize;
+                if (combinedImpurity < bestImpurity) {
+                    foundImprovementWithIdx = true;
+                    bestIdx = featureIdx;
+                    bestValue = features.get(splittingFeatureVectorIdx)[featureIdx];
+                    bestImpurity = combinedImpurity;
+                    bestLeftGroupSize = leftGroupSize;
+                    leftImpurityData.copyTo(bestLeftImpurityData);
+                    rightImpurityData.copyTo(bestRightImpurityData);
                 }
             }
 
-            if (bestImpurityForIdx < bestImpurity) {
-                bestIdx = featureIdx;
-                bestValue = bestValueForIdx;
-                bestImpurity = bestImpurityForIdx;
-                bestLeftGroupSize = bestLeftGroupSizeForIdx;
-
-                bestLeftImpurityDataForIdx.copyTo(bestLeftImpurityData);
-                bestRightImpurityDataForIdx.copyTo(bestRightImpurityData);
-
+            if (foundImprovementWithIdx) {
                 // At this time it's fine to swap array pointers since we will have to do a resort for the next feature
                 // anyway.
                 var tmpChildArray = bestRightChildArray;

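The change above removes the per-feature-index best-split bookkeeping: candidate splits now compare against and update the tree-wide best directly, and a `foundImprovementWithIdx` flag records whether the current feature produced any improvement so that the cached child arrays can still be swapped afterwards. The payoff is fewer `ImpurityData` copies and two fewer cached `ImpurityData` instances per `Splitter` (hence `6 * sizeOfImpurityData` becoming `4 * sizeOfImpurityData` in the memory estimation). Because a classification impurity histogram grows with the number of classes, the savings are largest on datasets with many classes. The sketch below is a simplified, hypothetical illustration of that pattern, not the GDS `Splitter` itself; the `ImpurityData` stand-in and the `maybeUpdateBest` helper are invented here for clarity.

// A minimal, self-contained sketch (hypothetical names, not the org.neo4j.gds classes)
// of the split-search pattern after this commit: improvements update the tree-wide best
// directly, and the large class-count histograms are copied only on an actual improvement.
final class BestSplitSketch {

    // Stand-in for ImpurityCriterion.ImpurityData: a per-class count histogram whose
    // size grows with the number of classes, which is what makes copies expensive.
    static final class ImpurityData {
        final long[] classCounts;
        double impurity;

        ImpurityData(int numberOfClasses) {
            this.classCounts = new long[numberOfClasses];
        }

        void copyTo(ImpurityData other) {
            System.arraycopy(classCounts, 0, other.classCounts, 0, classCounts.length);
            other.impurity = impurity;
        }
    }

    double bestImpurity = Double.MAX_VALUE;
    int bestIdx = -1;
    double bestValue = Double.MAX_VALUE;
    long bestLeftGroupSize = -1;
    final ImpurityData bestLeftImpurityData;
    final ImpurityData bestRightImpurityData;

    BestSplitSketch(int numberOfClasses) {
        this.bestLeftImpurityData = new ImpurityData(numberOfClasses);
        this.bestRightImpurityData = new ImpurityData(numberOfClasses);
    }

    // Called once per candidate split while scanning one feature. Returns true when the
    // tree-wide best improved, so the caller can set its foundImprovementWithIdx flag and
    // later swap its cached child arrays for this feature.
    boolean maybeUpdateBest(
        int featureIdx,
        double splitValue,
        long leftGroupSize,
        double combinedImpurity,
        ImpurityData leftImpurityData,
        ImpurityData rightImpurityData
    ) {
        if (combinedImpurity >= bestImpurity) {
            return false; // common path: no histogram copies at all
        }
        bestIdx = featureIdx;
        bestValue = splitValue;
        bestImpurity = combinedImpurity;
        bestLeftGroupSize = leftGroupSize;
        leftImpurityData.copyTo(bestLeftImpurityData);   // copies happen only on improvement
        rightImpurityData.copyTo(bestRightImpurityData);
        return true;
    }
}

Previously, every candidate that improved on the per-feature best paid two histogram copies into the per-feature buffers, and every feature that beat the tree-wide best paid two more copies at the end; now copies happen only when the tree-wide best actually improves, which can only be as often as before and is usually much rarer.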
ml/ml-algo/src/test/java/org/neo4j/gds/ml/decisiontree/DecisionTreeClassifierTest.java

Lines changed: 3 additions & 3 deletions
@@ -244,10 +244,10 @@ void estimateDecisionTree(
     @ParameterizedTest
     @CsvSource(value = {
         // Scales with training set size even if maxDepth limits tree size.
-        " 6, 1_000, 41_272, 56_032",
-        " 6, 10_000, 401_272, 488_032",
+        " 6, 1_000, 41_008, 55_768",
+        " 6, 10_000, 401_008, 487_768",
         // Scales with maxDepth when maxDepth is limiting tree size.
-        " 20, 10_000, 401_272, 1_443_712",
+        " 20, 10_000, 401_008, 1_443_448",
     })
     void trainMemoryEstimation(int maxDepth, long numberOfTrainingSamples, long expectedMin, long expectedMax) {
         var config = DecisionTreeTrainerConfigImpl.builder()

ml/ml-algo/src/test/java/org/neo4j/gds/ml/decisiontree/DecisionTreeRegressorTest.java

Lines changed: 3 additions & 3 deletions
@@ -227,10 +227,10 @@ void considersMinLeafSize() {
     @ParameterizedTest
     @CsvSource(value = {
         // Scales with training set size even if maxDepth limits tree size.
-        " 6, 1_000, 40_704, 55_968",
-        " 6, 10_000, 400_704, 487_968",
+        " 6, 1_000, 40_600, 55_864",
+        " 6, 10_000, 400_600, 487_864",
         // Scales with maxDepth when maxDepth is limiting tree size.
-        " 20, 10_000, 400_704, 1_523_136",
+        " 20, 10_000, 400_600, 1_523_032",
     })
     void trainMemoryEstimation(int maxDepth, long numberOfTrainingSamples, long expectedMin, long expectedMax) {
         var config = DecisionTreeTrainerConfigImpl.builder()

ml/ml-algo/src/test/java/org/neo4j/gds/ml/decisiontree/SplitterTest.java

Lines changed: 3 additions & 3 deletions
@@ -195,10 +195,10 @@ void shouldFindBestSplit(
     @ParameterizedTest
     @CsvSource(value = {
         // Scales with training set size.
-        " 1_000, 20, 40_368",
-        " 10_000, 20, 400_368",
+        " 1_000, 20, 40_320",
+        " 10_000, 20, 400_320",
         // Changes a little with impurity data size.
-        " 1_000, 100, 40_848",
+        " 1_000, 100, 40_640",
     })
     void memoryEstimation(long numberOfTrainingSamples, long sizeOfImpurityData, long expectedSize) {
         long size = Splitter.memoryEstimation(numberOfTrainingSamples, sizeOfImpurityData);

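The updated expectations are consistent with the memoryEstimation change: each drops by 2 × sizeOfImpurityData plus 8 bytes (40_368 − 40_320 = 48 = 2 × 20 + 8; 40_848 − 40_640 = 208 = 2 × 100 + 8). The two multiples of sizeOfImpurityData correspond to the impurity data cache term changing from 6 to 4; the remaining 8 bytes presumably come from the Splitter instance-size term (not shown in this diff) shrinking by two object references, which is an inference rather than something visible in the change.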
ml/ml-algo/src/test/java/org/neo4j/gds/ml/models/randomforest/RandomForestClassifierTest.java

Lines changed: 7 additions & 7 deletions
@@ -270,19 +270,19 @@ void predictOverheadMemoryEstimation(
 
     @ParameterizedTest
     @CsvSource(value = {
-        " 6, 100_000, 10, 10, 1, 1, 0.1, 1.0, 5_214_370, 6_027_194",
+        " 6, 100_000, 10, 10, 1, 1, 0.1, 1.0, 5_214_106, 6_026_930",
         // Should increase fairly little with more trees if training set big.
-        " 10, 100_000, 10, 10, 1, 10, 0.1, 1.0, 5_215_018, 7_096_578",
+        " 10, 100_000, 10, 10, 1, 10, 0.1, 1.0, 5_214_754, 7_096_314",
         // Should be capped by number of training examples, despite high max depth.
-        " 8_000, 500, 10, 10, 1, 1, 0.1, 1.0, 27_930, 187_730",
+        " 8_000, 500, 10, 10, 1, 1, 0.1, 1.0, 27_666, 187_466",
         // Should increase very little when having more classes.
-        " 10, 100_000, 100, 10, 1, 10, 0.1, 1.0, 5_220_058, 7_101_618",
+        " 10, 100_000, 100, 10, 1, 10, 0.1, 1.0, 5_218_354, 7_099_914",
         // Should increase very little when using more features for splits.
-        " 10, 100_000, 100, 10, 1, 10, 0.9, 1.0, 5_220_098, 7_101_750",
+        " 10, 100_000, 100, 10, 1, 10, 0.9, 1.0, 5_218_394, 7_100_046",
         // Should decrease a lot when sampling fewer training examples per tree.
-        " 10, 100_000, 100, 10, 1, 10, 0.1, 0.2, 1_370_058, 2_611_618",
+        " 10, 100_000, 100, 10, 1, 10, 0.1, 0.2, 1_368_354, 2_609_914",
         // Should almost be x4 when concurrency * 4.
-        " 10, 100_000, 100, 10, 4, 10, 0.1, 1.0, 19_677_808, 24_257_808",
+        " 10, 100_000, 100, 10, 4, 10, 0.1, 1.0, 19_670_992, 24_250_992",
     })
     void trainMemoryEstimation(
         int maxDepth,

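Note that the rows which appear to use 100 classes (third parameter) drop by noticeably more than the 10-class rows: for example 5_215_018 − 5_214_754 = 264 with 10 classes versus 5_220_058 − 5_218_354 = 1_704 with 100 classes, and the concurrency-4 row drops by 4 × 1_704 = 6_816. This matches the commit's intent: the two cached impurity-data instances that were removed are per-class histograms, so the saving, like the speedup, grows with the number of classes.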
ml/ml-algo/src/test/java/org/neo4j/gds/ml/models/randomforest/RandomForestRegressorTest.java

Lines changed: 6 additions & 6 deletions
@@ -166,17 +166,17 @@ void predictOverheadMemoryEstimation() {
 
     @ParameterizedTest
     @CsvSource(value = {
-        " 6, 100_000, 10, 1, 1, 0.1, 1.0, 4_813_754, 5_627_586",
+        " 6, 100_000, 10, 1, 1, 0.1, 1.0, 4_813_650, 5_627_482",
         // Should increase fairly little with more trees if training set big.
-        " 10, 100_000, 10, 1, 10, 0.1, 1.0, 4_814_474, 6_786_058",
+        " 10, 100_000, 10, 1, 10, 0.1, 1.0, 4_814_370, 6_785_954",
         // Should be capped by number of training examples, despite high max depth.
-        " 8_000, 500, 10, 1, 1, 0.1, 1.0, 25_314, 193_098",
+        " 8_000, 500, 10, 1, 1, 0.1, 1.0, 25_210, 192_994",
         // Should increase very little when using more features for splits.
-        " 10, 100_000, 10, 1, 10, 0.9, 1.0, 4_814_514, 6_786_190",
+        " 10, 100_000, 10, 1, 10, 0.9, 1.0, 4_814_410, 6_786_086",
         // Should decrease a lot when sampling fewer training examples per tree.
-        " 10, 100_000, 10, 1, 10, 0.1, 0.2, 964_474, 2_296_058",
+        " 10, 100_000, 10, 1, 10, 0.1, 0.2, 964_370, 2_295_954",
         // Should almost be x4 when concurrency * 4.
-        " 10, 100_000, 10, 4, 10, 0.1, 1.0, 19_255_376, 23_949_952",
+        " 10, 100_000, 10, 4, 10, 0.1, 1.0, 19_254_960, 23_949_536",
     })
     void trainMemoryEstimation(
         int maxDepth,

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainTest.java

Lines changed: 3 additions & 3 deletions
@@ -180,15 +180,15 @@ static Stream<Arguments> paramsForEstimationsWithParamSpace() {
                     .build()
                     .toTunableConfig()
             ),
-            MemoryRange.of(66_896, 899_376)
+            MemoryRange.of(66_352, 898_832)
         ),
         Arguments.of(
             "Default RF and default LR",
             List.of(
                 LogisticRegressionTrainConfig.DEFAULT.toTunableConfig(),
                 RandomForestClassifierTrainerConfig.DEFAULT.toTunableConfig()
             ),
-            MemoryRange.of(73_976, 2_738_824)
+            MemoryRange.of(73_432, 2_738_280)
         ),
         Arguments.of(
             "Default RF and default LR with range",
@@ -199,7 +199,7 @@ static Stream<Arguments> paramsForEstimationsWithParamSpace() {
                 ),
                 RandomForestClassifierTrainerConfig.DEFAULT.toTunableConfig()
             ),
-            MemoryRange.of(73_976, 2_738_824)
+            MemoryRange.of(73_432, 2_738_280)
         ),
         Arguments.of(
             "Default RF and default LR with batch size range",

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainPipelineExecutorTest.java

Lines changed: 1 addition & 1 deletion
@@ -323,7 +323,7 @@ public static Stream<Arguments> trainerMethodConfigs() {
             ),
             Arguments.of(
                 List.of(RandomForestClassifierTrainerConfig.DEFAULT.toTunableConfig()),
-                MemoryRange.of(91_186, 207_958)
+                MemoryRange.of(90_938, 207_710)
             ),
             Arguments.of(
                 List.of(LogisticRegressionTrainConfig.DEFAULT.toTunableConfig(), RandomForestClassifierTrainerConfig.DEFAULT.toTunableConfig()),
