@@ -2088,9 +2088,11 @@ TEST_CASE("stress test, boosting") {
20882088 // terms.push_back({0, 1, 2, 3}); // TODO: enable when fast enough
20892089 }
20902090 const size_t cRounds = 200 ;
2091- std::vector<IntEbm> boostFlagsAny{TermBoostFlags_PurifyGain,
2091+ std::vector<IntEbm> boostFlagsAny{// TermBoostFlags_PurifyGain,
20922092 TermBoostFlags_DisableNewtonGain,
20932093 TermBoostFlags_DisableCategorical,
2094+ // TermBoostFlags_PurifyUpdate,
2095+ // TermBoostFlags_GradientSums, // does not return a metric
20942096 TermBoostFlags_DisableNewtonUpdate,
20952097 TermBoostFlags_RandomSplits};
20962098 std::vector<IntEbm> boostFlagsChoose{TermBoostFlags_Default,
@@ -2099,10 +2101,10 @@ TEST_CASE("stress test, boosting") {
20992101 TermBoostFlags_MissingSeparate,
21002102 TermBoostFlags_MissingDrop};
21012103
2102- double validationMetric = 0 .0 ;
2104+ double validationMetric = 1 .0 ;
21032105
21042106 for (IntEbm classesCount = Task_Regression; classesCount < 5 ; ++classesCount) {
2105- if (classesCount != Task_Regression && classesCount < 2 ) {
2107+ if (classesCount != Task_Regression && classesCount < 1 ) {
21062108 continue ;
21072109 }
21082110 const auto train = MakeRandomDataset (rng, classesCount, cTrainSamples, features);
@@ -2159,9 +2161,13 @@ TEST_CASE("stress test, boosting") {
21592161 .validationMetric ;
21602162 }
21612163 }
2162- validationMetric += validationMetricIteration;
2164+ if (classesCount == 1 ) {
2165+ CHECK (std::numeric_limits<double >::infinity () == validationMetricIteration);
2166+ } else {
2167+ validationMetric *= validationMetricIteration;
2168+ }
21632169 }
21642170 }
21652171
2166- CHECK (validationMetric == 42031.143270308334 );
2172+ CHECK (validationMetric == 62013566170252.117 );
21672173}
0 commit comments