Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cpp/bench/sg/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ std::vector<Params> getInputs()
1234ULL, /* seed */
ML::CRITERION::MSE, /* split_criterion */
8, /* n_streams */
128 /* max_batch_size */
128 /* max_batch_size */,
0, /* minTreesPerGroupFold */
0, /* foldGroupSize */
-1 /* group_col_idx */
);

using ML::fil::algo_t;
Expand Down
5 changes: 4 additions & 1 deletion cpp/bench/sg/filex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,10 @@ std::vector<Params> getInputs()
1234ULL, /* seed */
ML::CRITERION::MSE, /* split_criterion */
8, /* n_streams */
128 /* max_batch_size */
128, /* max_batch_size */
0, /* minTreesPerGroupFold */
0, /* foldGroupSize */
-1 /* group_col_idx */
);

using ML::fil::algo_t;
Expand Down
5 changes: 4 additions & 1 deletion cpp/bench/sg/rf_classifier.cu
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ std::vector<Params> getInputs()
1234ULL, /* seed */
ML::CRITERION::GINI, /* split_criterion */
8, /* n_streams */
128 /* max_batch_size */
128, /* max_batch_size */
0, /* minTreesPerGroupFold */
0, /* foldGroupSize */
-1 /* group_col_idx */
);

std::vector<Triplets> rowcols = {
Expand Down
49 changes: 48 additions & 1 deletion cpp/include/cuml/ensemble/randomforest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ struct RF_metrics {
double median_abs_error;
};

/**
 * Declaration of lower_bound_test; the definition is not part of this header.
 *
 * NOTE(review): judging by the name and parameters, this looks like a test
 * hook that runs a lower-bound search of each input value against
 * search_array and writes one result per input to output_array — confirm
 * against the definition before relying on this.
 *
 * @param search_array    Read-only int array; presumably num_search_vals
 *                        sorted values to search in — TODO confirm ordering
 *                        requirement and whether this is a host or device
 *                        pointer.
 * @param input_vals      Read-only int array; presumably num_inputs values
 *                        to look up — TODO confirm.
 * @param num_search_vals Element count associated with search_array
 *                        (assumed from the name).
 * @param num_inputs      Element count associated with input_vals
 *                        (assumed from the name).
 * @param output_array    Writable int array; presumably receives num_inputs
 *                        results — TODO confirm size and contents.
 * @param stream          CUDA stream handle; presumably the stream on which
 *                        the work is enqueued — confirm whether the call
 *                        synchronizes before returning.
 */
void lower_bound_test(
const int* search_array,
const int* input_vals,
const int num_search_vals,
const int num_inputs,
int* output_array,
const cudaStream_t stream);

RF_metrics set_all_rf_metrics(RF_type rf_type,
float accuracy,
double mean_abs_error,
Expand Down Expand Up @@ -99,6 +107,42 @@ struct RF_params {
* Ratio of dataset rows used while fitting each tree.
*/
float max_samples;

/**
* Comment from rforestry:
* The number of trees which we make sure have been created leaving
* out each fold (each fold is a set of randomly selected groups).
* This is 0 by default, in which case the groups receive no special
* treatment when sampling observations. If it is set to a positive integer,
* the bootstrap sampling scheme is modified to ensure that exactly that many
* trees have each group left out. We do this by, for each fold, creating
* minTreesPerGroupFold trees which are built on observations sampled from the
* set of training observations which are not in a group in the current fold.
* The folds form a random partition of all of the possible groups, with each
* fold containing foldGroupSize groups. This means we create at
* least # folds * minTreesPerGroupFold trees for the forest.
* In total we create max(# folds * minTreesPerGroupFold, ntree) trees
* (ntree being the requested number of trees), of which at least
* minTreesPerGroupFold are created leaving out each fold.
*/
int minTreesPerGroupFold;

/**
* Comment from rforestry:
* The number of groups that are selected randomly for each fold to be
* left out when using minTreesPerGroupFold. When both minTreesPerGroupFold and
* foldGroupSize are set, all possible groups are partitioned into folds, each
* containing foldGroupSize unique groups (if foldGroupSize doesn't evenly
* divide the number of groups, a single fold will be smaller, as it will
* contain the remaining groups). Then minTreesPerGroupFold trees are grown
* with each entire fold of groups left out.
*/
int foldGroupSize;

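/**
* The numeric index of the column to be used for group processing
* (see minTreesPerGroupFold and foldGroupSize).
* NOTE(review): presumably a negative value such as -1 means that no group
* column is used — confirm against the training code.
*/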
int group_col_idx;

/**
* Decision tree training hyper parameter struct.
*/
Expand Down Expand Up @@ -225,7 +269,10 @@ RF_params set_rf_params(int max_depth,
uint64_t seed,
CRITERION split_criterion,
int cfg_n_streams,
int max_batch_size);
int max_batch_size,
int minTreesPerGroupFold,
int foldGroupSize,
int group_col_idx);

// ----------------------------- Regression ----------------------------------- //

Expand Down
8 changes: 7 additions & 1 deletion cpp/src/randomforest/randomforest.cu
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,10 @@ RF_params set_rf_params(int max_depth,
uint64_t seed,
CRITERION split_criterion,
int cfg_n_streams,
int max_batch_size)
int max_batch_size,
int minTreesPerGroupFold,
int foldGroupSize,
int group_col_idx)
{
DT::DecisionTreeParams tree_params;
DT::set_tree_params(tree_params,
Expand All @@ -620,6 +623,9 @@ RF_params set_rf_params(int max_depth,
rf_params.n_streams = min(cfg_n_streams, omp_get_max_threads());
if (n_trees < rf_params.n_streams) rf_params.n_streams = n_trees;
rf_params.tree_params = tree_params;
rf_params.minTreesPerGroupFold = minTreesPerGroupFold;
rf_params.foldGroupSize = foldGroupSize;
rf_params.group_col_idx = group_col_idx;
validity_check(rf_params);
return rf_params;
}
Expand Down
Loading