diff --git a/cpp/bench/sg/fil.cu b/cpp/bench/sg/fil.cu index e799db3626..7c9f0b6f7d 100644 --- a/cpp/bench/sg/fil.cu +++ b/cpp/bench/sg/fil.cu @@ -164,7 +164,10 @@ std::vector getInputs() 1234ULL, /* seed */ ML::CRITERION::MSE, /* split_criterion */ 8, /* n_streams */ - 128 /* max_batch_size */ + 128 /* max_batch_size */, + 0, /* minTreesPerGroupFold */ + 0, /* foldGroupSize */ + -1 /* group_col_idx */ ); using ML::fil::algo_t; diff --git a/cpp/bench/sg/filex.cu b/cpp/bench/sg/filex.cu index c090dcf360..ffc0d2b157 100644 --- a/cpp/bench/sg/filex.cu +++ b/cpp/bench/sg/filex.cu @@ -264,7 +264,10 @@ std::vector getInputs() 1234ULL, /* seed */ ML::CRITERION::MSE, /* split_criterion */ 8, /* n_streams */ - 128 /* max_batch_size */ + 128, /* max_batch_size */ + 0, /* minTreesPerGroupFold */ + 0, /* foldGroupSize */ + -1 /* group_col_idx */ ); using ML::fil::algo_t; diff --git a/cpp/bench/sg/rf_classifier.cu b/cpp/bench/sg/rf_classifier.cu index c9703066f9..f0936ee31e 100644 --- a/cpp/bench/sg/rf_classifier.cu +++ b/cpp/bench/sg/rf_classifier.cu @@ -108,7 +108,10 @@ std::vector getInputs() 1234ULL, /* seed */ ML::CRITERION::GINI, /* split_criterion */ 8, /* n_streams */ - 128 /* max_batch_size */ + 128, /* max_batch_size */ + 0, /* minTreesPerGroupFold */ + 0, /* foldGroupSize */ + -1 /* group_col_idx */ ); std::vector rowcols = { diff --git a/cpp/include/cuml/ensemble/randomforest.hpp b/cpp/include/cuml/ensemble/randomforest.hpp index b7c131f677..d978339891 100644 --- a/cpp/include/cuml/ensemble/randomforest.hpp +++ b/cpp/include/cuml/ensemble/randomforest.hpp @@ -48,6 +48,14 @@ struct RF_metrics { double median_abs_error; }; +void lower_bound_test( + const int* search_array, + const int* input_vals, + const int num_search_vals, + const int num_inputs, + int* output_array, + const cudaStream_t stream); + RF_metrics set_all_rf_metrics(RF_type rf_type, float accuracy, double mean_abs_error, @@ -99,6 +107,42 @@ struct RF_params { * Ratio of dataset rows used while fitting each tree. */ float max_samples; + + /** + * Comment from rforestry: + * The number of trees which we make sure have been created leaving + * out each fold (each fold is a set of randomly selected groups). + * This is 0 by default, so we will not give any special treatment to + * the groups when sampling observations, however if this is set to a positive integer, we + * modify the bootstrap sampling scheme to ensure that exactly that many trees + * have each group left out. We do this by, for each fold, creating minTreesPerGroupFold + * trees which are built on observations sampled from the set of training observations + * which are not in a group in the current fold. The folds form a random partition of + * all of the possible groups, each of size foldGroupSize. This means we create at + * least # folds * minTreesPerGroupFold trees for the forest. + * If ntree > # folds * minTreesPerGroupFold, we create + * max(# folds * minTreesPerGroupFold, ntree) total trees, in which at least minTreesPerGroupFold + * are created leaving out each fold. + */ + int minTreesPerGroupFold; + + /** + * Comment from rforestry: + * The number of groups that are selected randomly for each fold to be + * left out when using minTreesPerGroupFold. When minTreesPerGroupFold is set and foldGroupSize is + * set, all possible groups will be partitioned into folds, each containing foldGroupSize unique groups + * (if foldGroupSize doesn't evenly divide the number of groups, a single fold will be smaller, + * as it will contain the remaining groups). Then minTreesPerGroupFold are grown with each + * entire fold of groups left out. + */ + int foldGroupSize; + + /** + * group_col_idx + * The numeric index of the column to be used for group processing + */ + int group_col_idx; + /** * Decision tree training hyper parameter struct. */ @@ -225,7 +269,10 @@ RF_params set_rf_params(int max_depth, uint64_t seed, CRITERION split_criterion, int cfg_n_streams, - int max_batch_size); + int max_batch_size, + int minTreesPerGroupFold, + int foldGroupSize, + int group_col_idx); // ----------------------------- Regression ----------------------------------- // diff --git a/cpp/src/randomforest/randomforest.cu b/cpp/src/randomforest/randomforest.cu index e5f300311d..3817bcef9a 100644 --- a/cpp/src/randomforest/randomforest.cu +++ b/cpp/src/randomforest/randomforest.cu @@ -594,7 +594,10 @@ RF_params set_rf_params(int max_depth, uint64_t seed, CRITERION split_criterion, int cfg_n_streams, - int max_batch_size) + int max_batch_size, + int minTreesPerGroupFold, + int foldGroupSize, + int group_col_idx) { DT::DecisionTreeParams tree_params; DT::set_tree_params(tree_params, @@ -620,6 +623,9 @@ RF_params set_rf_params(int max_depth, rf_params.n_streams = min(cfg_n_streams, omp_get_max_threads()); if (n_trees < rf_params.n_streams) rf_params.n_streams = n_trees; rf_params.tree_params = tree_params; + rf_params.minTreesPerGroupFold = minTreesPerGroupFold; + rf_params.foldGroupSize = foldGroupSize; + rf_params.group_col_idx = group_col_idx; validity_check(rf_params); return rf_params; } diff --git a/cpp/src/randomforest/randomforest.cuh b/cpp/src/randomforest/randomforest.cuh index d5022b4769..18e140db9f 100644 --- a/cpp/src/randomforest/randomforest.cuh +++ b/cpp/src/randomforest/randomforest.cuh @@ -35,6 +35,7 @@ #include #include #include +#include #ifdef _OPENMP #include @@ -58,28 +59,331 @@ struct set_mask_functor { } }; +namespace { + +__global__ void log10(int* array) { + for (int ix = 0; ix < 10; ++ix) { + printf("array %d = %d\n", ix, array[ix]); + } +} + +__global__ void log10groups(const int* row_ids, const int* group_ids) { + for (int ix = 0; ix < 10; ++ix) { + printf("group ix %d, row %d = %d\n", ix, row_ids[ix], group_ids[row_ids[ix]]); + } +} + +void assign_groups_to_folds( + int n_groups, + int n_folds, + int fold_size, + std::vector> & fold_memberships, + std::mt19937& rng) +{ + std::vector group_indices(n_groups); + std::iota(group_indices.begin(), group_indices.end(), 0); + + std::shuffle(group_indices.begin(), group_indices.end(), rng); + + for (int ix_fold = 0; ix_fold < n_folds - 1; ++ix_fold) { + std::copy(group_indices.begin() + ix_fold*fold_size, + group_indices.begin() + (ix_fold+1)*fold_size, + fold_memberships[ix_fold].begin()); + // std::sort(fold_memberships[ix_fold].begin(), fold_memberships[ix_fold].end()); + } + + // Last fold could be smaller + const int last_fold_start = (n_folds - 1) * fold_size; + const int last_fold_size = n_groups - last_fold_start; + fold_memberships[n_folds - 1].resize(last_fold_size); + for (int ix = 0; ix < last_fold_size; ++ix) { + fold_memberships[n_folds - 1][ix] = group_indices[last_fold_start + ix]; + } +} + +template +__device__ int lower_bound(const T search_val, const U* array, int count) { + int it, step; + int first = 0; + while (count > 0) { + step = count / 2; + it = first + step; + if (array[it] < search_val) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + return first; +} + +template +struct UniqueTransformFunctor { + const T* unique_groups; + const int num_groups; + UniqueTransformFunctor(const T* unique_groups, const int num_groups) + : unique_groups(unique_groups), + num_groups(num_groups) + {} + + __device__ int operator()(T group_val) { + int res = lower_bound(group_val, unique_groups, num_groups); + return res; + } +}; + +struct LeaveOutSamplesCopyIfFunctor { + + int* remaining_groups; + const int* sample_group_ids; + const int num_rem_groups; + LeaveOutSamplesCopyIfFunctor( + rmm::device_uvector* remaining_groups, + const int* sample_group_ids) + : remaining_groups(remaining_groups->data()), + sample_group_ids(sample_group_ids), + num_rem_groups(remaining_groups->size()) + {} + + __device__ bool operator()(const int ix_sample) { + // Do a quick lower_bound search + const int group_id = sample_group_ids[ix_sample]; + int it = lower_bound(group_id, remaining_groups, num_rem_groups); + return remaining_groups[it] == group_id; + } +}; + +void generate_row_indices_from_remaining_groups( + rmm::device_uvector* remaining_groups, + rmm::device_uvector* remaining_samples, + const int* sample_group_ids, + const size_t num_samples, + const cudaStream_t stream, + raft::random::Rng& rng) +{ + // From the remaining groups, we need to generate the remaining samples + // We're going to copy_if the indices to remaining_samples, only if the sample group id is part of the remaining groups + + // We want to sort remaining groups so that we can do the search in logN time. + thrust::sort(thrust::cuda::par.on(stream), remaining_groups->begin(), remaining_groups->end()); + + LeaveOutSamplesCopyIfFunctor predicate{remaining_groups, sample_group_ids}; + + auto output_it = thrust::copy_if( + thrust::cuda::par.on(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_samples), + remaining_samples->begin(), + predicate); + auto dist = thrust::distance(remaining_samples->begin(), output_it); + remaining_samples->resize(dist, stream); +} + +void sample_rows_from_remaining_rows( + rmm::device_uvector* remaining_samples, + rmm::device_uvector* sample_output, + rmm::device_uvector* workspace, + const size_t num_samples, + const size_t sample_output_offset, + const cudaStream_t stream, + raft::random::Rng& rng) +{ + // This will do the actual permutation / shuffle operation + rng.uniformInt(workspace->data(), num_samples, 0, remaining_samples->size(), stream); + + auto index_iter = thrust::make_permutation_iterator(remaining_samples->begin(), workspace->begin()); + + thrust::copy(thrust::cuda::par.on(stream), index_iter, index_iter + num_samples, sample_output->begin() + sample_output_offset); +} + +void leave_groups_out_sample( + rmm::device_uvector* remaining_groups, + rmm::device_uvector* remaining_samples, + rmm::device_uvector* sample_output, + rmm::device_uvector* workspace, + const int* sample_group_ids, + std::vector& remaining_groups_host, + const size_t num_samples, + const size_t sample_output_offset, + const cudaStream_t stream, + raft::random::Rng& rng) +{ + raft::update_device(remaining_groups->data(), remaining_groups_host.data(), + remaining_groups_host.size(), stream); + remaining_groups->resize(remaining_groups_host.size(), stream); + + generate_row_indices_from_remaining_groups( + remaining_groups, remaining_samples, sample_group_ids, + num_samples, stream, rng); + sample_rows_from_remaining_rows( + remaining_samples, + sample_output, + workspace, + num_samples, + sample_output_offset, + stream, + rng); +} + +void update_averaging_mask( + rmm::device_uvector* split_row_mask, + const size_t n_sampled_rows, + const cudaStream_t stream) +{ + // First n_rows goes to training, num_avg_samples goes to averaging. + auto begin_zip_iterator = thrust::make_zip_iterator(thrust::make_tuple( + thrust::make_counting_iterator(0), split_row_mask->begin())); + auto end_zip_iterator = thrust::make_zip_iterator(thrust::make_tuple( + thrust::make_counting_iterator(2 * n_sampled_rows), split_row_mask->end())); + thrust::for_each(thrust::cuda::par.on(stream), begin_zip_iterator, end_zip_iterator, + thrust::make_zip_function(set_mask_functor(n_sampled_rows))); +} + +} // end anon namespace + namespace ML { + +void lower_bound_test( + const int* search_array, + const int* input_vals, + const int num_search_vals, + const int num_inputs, + int* output_array, + const cudaStream_t stream) +{ + UniqueTransformFunctor transform_fn{search_array, num_search_vals}; + thrust::transform( + thrust::cuda::par.on(stream), + input_vals, + input_vals + num_inputs, + output_array, + transform_fn); +} + template class RandomForest { protected: RF_params rf_params; // structure containing RF hyperparameters int rf_type; // 0 for classification 1 for regression - size_t get_row_sample(int tree_id, - int n_rows, + size_t get_row_sample(const int tree_id, + const int n_rows, int n_sampled_rows, - rmm::device_uvector* selected_rows, - rmm::device_uvector* split_row_mask, - rmm::device_uvector* tmp_row_vec, + rmm::device_uvector* selected_rows, // 2x n sampled rows + rmm::device_uvector* split_row_mask, // 2x n_sampled rows + rmm::device_uvector* remaining_groups, // 1x n_groups + rmm::device_uvector* remaining_samples, // 1x n_sampled_rows + rmm::device_uvector* workspace, // 1x n sampled rows + int* groups, + const int n_groups, + const int group_tree_count, + const std::vector>& fold_memberships, // each group belongs to one fold const cudaStream_t stream) { + // Todo: split group_fold_rng across threads raft::common::nvtx::range fun_scope("bootstrapping row IDs @randomforest.cuh"); // Hash these together so they are uncorrelated - auto rs = DT::fnv1a32_basis; - rs = DT::fnv1a32(rs, rf_params.seed); - rs = DT::fnv1a32(rs, tree_id); - raft::random::Rng rng(rs, raft::random::GenPhilox); + auto random_seed = DT::fnv1a32_basis; + random_seed = DT::fnv1a32(random_seed, rf_params.seed); + random_seed = DT::fnv1a32(random_seed, tree_id); + raft::random::Rng rng(random_seed, raft::random::GenPhilox); + + // generate the random state needed for cpu-side sampling + auto cpu_random_seed = DT::fnv1a32(random_seed, 1); + std::random_device rd; + std::mt19937 group_fold_rng(rd()); + group_fold_rng.seed(cpu_random_seed); + + std::vector> honest_group_assignments(2); + auto& splitting_groups = honest_group_assignments[0]; + auto& averaging_groups = honest_group_assignments[1]; + if (n_groups > 0) { + // Special handling for groups. We don't support split ratio honesty + const std::vector* current_fold_groups; + std::vector restricted_group_ixs; + std::vector restricted_group_ixs_diff; + int restricted_ix_size = n_groups; + if (rf_params.minTreesPerGroupFold > 0 and tree_id < group_tree_count) { + const int current_fold = tree_id / rf_params.minTreesPerGroupFold; + current_fold_groups = &fold_memberships[current_fold]; + restricted_group_ixs.resize(n_groups); + std::iota(restricted_group_ixs.begin(), restricted_group_ixs.end(), 0); + + restricted_ix_size = n_groups - current_fold_groups->size(); + restricted_group_ixs_diff.reserve(restricted_ix_size); + std::set_difference(restricted_group_ixs.begin(), + restricted_group_ixs.end(), + current_fold_groups->begin(), + current_fold_groups->end(), + std::inserter(restricted_group_ixs_diff, restricted_group_ixs_diff.begin())); + } + + if (rf_params.oob_honesty) { + const float split_ratio = 0.632; + + if (rf_params.minTreesPerGroupFold > 0 and tree_id < group_tree_count) { + // Doing group / fold "leave-out" logic + int honest_split_size = split_ratio * (n_groups - current_fold_groups->size()); + + splitting_groups.resize(honest_split_size); + averaging_groups.resize(restricted_ix_size - honest_split_size); + + assign_groups_to_folds( + restricted_ix_size, + 2, + honest_split_size, + honest_group_assignments, + group_fold_rng); + + // Replace indices with the actual groups + for (int ix_group = 0; ix_group < honest_split_size; ix_group++) { + honest_group_assignments[0][ix_group] = restricted_group_ixs_diff[honest_group_assignments[0][ix_group]]; + } + + for (int ix_group = 0; ix_group < restricted_ix_size - honest_split_size; ix_group++) { + honest_group_assignments[1][ix_group] = restricted_group_ixs_diff[honest_group_assignments[1][ix_group]]; + } + } else { + // Here we're not doing folds, we're partitioning the groups directly. We're also not leaving out groups? + // Easy enough so I'll add it, to match rforestry functionality + int honest_split_size = std::round(split_ratio * static_cast(n_groups)); + + // Avoid empty set + if (honest_split_size == n_groups) { + honest_split_size = n_groups - 1; + } else if (honest_split_size == 0) { + honest_split_size = 1; + } + + splitting_groups.resize(honest_split_size); + averaging_groups.resize(honest_split_size); + assign_groups_to_folds( + n_groups, + 2, + honest_split_size, + honest_group_assignments, + group_fold_rng); + } + + leave_groups_out_sample(remaining_groups, remaining_samples, selected_rows, workspace, + groups, splitting_groups, n_sampled_rows, 0, stream, rng); + + leave_groups_out_sample(remaining_groups, remaining_samples, selected_rows, workspace, + groups, averaging_groups, n_sampled_rows, n_sampled_rows, stream, rng); + + update_averaging_mask(split_row_mask, n_sampled_rows, stream); + + return n_sampled_rows; // averaging sample count + + } else if (rf_params.minTreesPerGroupFold > 0 and tree_id < group_tree_count) { + // Just don't use samples from the current fold for splitting. No averaging. + leave_groups_out_sample(remaining_groups, remaining_samples, selected_rows, workspace, + groups, restricted_group_ixs_diff, n_sampled_rows, 0, stream, rng); + return 0; // no averaging samples + } + } if (rf_params.bootstrap) { // Use bootstrapped sample set @@ -91,6 +395,7 @@ class RandomForest { size_t num_avg_samples = 0; if (rf_params.oob_honesty and rf_params.bootstrap) { + selected_rows->resize(n_sampled_rows, stream); // honesty doesn't make sense without bootstrapping -- all the obs were otherwise selected num_avg_samples = n_sampled_rows; assert(rf_params.bootstrap); @@ -100,35 +405,27 @@ class RandomForest { thrust::sort(thrust::cuda::par.on(stream), selected_rows->begin(), selected_rows->end()); // Get the set of observations that are not used for split - tmp_row_vec->resize(n_rows, stream); auto iter_end = thrust::set_difference( thrust::cuda::par.on(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(n_sampled_rows), selected_rows->begin(), selected_rows->end(), - tmp_row_vec->begin()); + remaining_samples->begin()); - // Now tmp_row_vec is the observations available for the averaging set - size_t num_remaining_samples = iter_end - tmp_row_vec->begin(); - tmp_row_vec->resize(num_remaining_samples, stream); + // Now remaining_samples is the observations available for the averaging set + size_t num_remaining_samples = iter_end - remaining_samples->begin(); + remaining_samples->resize(num_remaining_samples, stream); // Get the avg selected rows either as the remaining data, or bootstrapped again + selected_rows->resize(n_sampled_rows * 2, stream); if (rf_params.double_bootstrap) { - rng.uniformInt(selected_rows->data() + n_sampled_rows, num_avg_samples, 0, num_remaining_samples, stream); - auto index_iter = thrust::make_permutation_iterator(tmp_row_vec->begin(), selected_rows->begin() + n_sampled_rows); - - thrust::copy(thrust::cuda::par.on(stream), index_iter, index_iter + num_avg_samples, selected_rows->begin() + n_sampled_rows); + sample_rows_from_remaining_rows(remaining_samples, selected_rows, workspace, n_sampled_rows, n_sampled_rows, stream, rng); } else { - thrust::copy(thrust::cuda::par.on(stream), tmp_row_vec->begin(), tmp_row_vec->end(), selected_rows->begin() + n_sampled_rows); + thrust::copy(thrust::cuda::par.on(stream), remaining_samples->begin(), remaining_samples->end(), selected_rows->begin() + n_sampled_rows); } - // First n_rows goes to training, num_avg_samples goes to averaging. - auto begin_zip_iterator = thrust::make_zip_iterator(thrust::make_tuple( - thrust::make_counting_iterator(0), split_row_mask->begin())); - auto end_zip_iterator = thrust::make_zip_iterator(thrust::make_tuple( - thrust::make_counting_iterator(2 * n_sampled_rows), split_row_mask->end())); - thrust::for_each(thrust::cuda::par.on(stream), begin_zip_iterator, end_zip_iterator, thrust::make_zip_function(set_mask_functor(n_sampled_rows))); + update_averaging_mask(split_row_mask, n_sampled_rows, stream); } return num_avg_samples; @@ -206,13 +503,79 @@ class RandomForest { n_streams, handle.get_stream_pool_size()); + int n_trees = this->rf_params.n_trees; + // Compute the number of trees. This might change based on the "groups" logic + std::vector> foldMemberships; + const int foldGroupSize = this->rf_params.foldGroupSize; + const int minTreesPerGroupFold = this->rf_params.minTreesPerGroupFold; + std::unique_ptr> groups; + int n_groups = 0; + + if (this->rf_params.group_col_idx >= 0) { + // Here we'll going to do a unique, then build a vector of indices into the unique vector + cudaStream_t stream = handle.get_stream_from_stream_pool(0); + rmm::device_uvector input_groups(n_rows, stream); + rmm::device_uvector input_groups_unique(n_rows, stream); + groups = std::make_unique>(n_rows, stream); + cudaMemcpyAsync(input_groups.data(), + input + n_rows * this->rf_params.group_col_idx, + n_rows * sizeof(T), cudaMemcpyDefault, stream); + cudaMemcpyAsync(input_groups_unique.data(), + input + n_rows * this->rf_params.group_col_idx, + n_rows * sizeof(T), cudaMemcpyDefault, stream); + // Sadly we have to sort the entire array for unique to work. Is there + // a way to just unique the unsorted array? + thrust::sort(thrust::cuda::par.on(stream), + input_groups_unique.data(), + input_groups_unique.data() + n_rows); + T* new_end = thrust::unique(thrust::cuda::par.on(stream), + input_groups_unique.data(), + input_groups_unique.data() + n_rows); + // Now we'll have n_groups and can use some iterator to find the values for each group + n_groups = new_end - input_groups_unique.data(); + + UniqueTransformFunctor transform_fn{input_groups_unique.data(), n_groups}; + thrust::transform( + thrust::cuda::par.on(stream), + input_groups.data(), + input_groups.data() + n_rows, + groups->data(), + transform_fn); + } + + int group_tree_count = 0; + if (minTreesPerGroupFold > 0 and n_groups > 0) { + // Use a separate RNG and the std functions for group membership. + std::random_device rd; + std::mt19937 group_fold_rng(rd()); + auto random_seed = DT::fnv1a32_basis; + random_seed = DT::fnv1a32(random_seed, this->rf_params.seed); + random_seed = DT::fnv1a32(random_seed, std::numeric_limits::max()); + group_fold_rng.seed(random_seed); + + int n_folds = (n_groups + foldGroupSize - 1) / foldGroupSize; + group_tree_count = n_folds * minTreesPerGroupFold; + n_trees = std::max(n_trees, group_tree_count); + + forest->trees.resize(n_trees); + // TODO: Why are there 2 separate rf_params structs? + // I think it would be best for this class to hold a pointer to the one that's passed in to fit. + forest->rf_params.n_trees = n_trees; + this->rf_params.n_trees = n_trees; + + foldMemberships = std::vector>(n_folds, std::vector(minTreesPerGroupFold)); + + //assign group to fold + assign_groups_to_folds(n_groups, n_folds, foldGroupSize, foldMemberships, group_fold_rng); + } + // computing the quantiles: last two return values are shared pointers to device memory // encapsulated by quantiles struct auto [quantiles, quantiles_array, n_bins_array] = DT::computeQuantiles(handle, input, this->rf_params.tree_params.max_n_bins, n_rows, n_cols); // n_streams should not be less than n_trees - if (this->rf_params.n_trees < n_streams) n_streams = this->rf_params.n_trees; + if (n_trees < n_streams) n_streams = n_trees; // Select n_sampled_rows (with replacement) numbers from [0, n_rows) per tree. // selected_rows: randomly generated IDs for bootstrapped samples (w/ replacement); a device @@ -221,29 +584,46 @@ class RandomForest { // constructor std::deque> selected_rows; std::deque> split_row_masks; - std::deque> sampling_staging_vecs; + std::deque> workspaces; + std::deque> remaining_groups_vec; + std::deque> remaining_samples_vec; + const bool use_extra_vecs = this->rf_params.oob_honesty or this->rf_params.minTreesPerGroupFold > 0; size_t max_sample_row_size = this->rf_params.oob_honesty ? n_sampled_rows * 2 : n_sampled_rows; for (int i = 0; i < n_streams; i++) { auto s = handle.get_stream_from_stream_pool(i); selected_rows.emplace_back(max_sample_row_size, s); - if (this->rf_params.oob_honesty) { + if (use_extra_vecs) { split_row_masks.emplace_back(max_sample_row_size, s); - sampling_staging_vecs.emplace_back(n_rows, s); + workspaces.emplace_back(n_rows, s); + remaining_samples_vec.emplace_back(n_rows, s); + } + if (n_groups > 0) { + remaining_groups_vec.emplace_back(n_groups, s); } } #pragma omp parallel for num_threads(n_streams) - for (int i = 0; i < this->rf_params.n_trees; i++) { + for (int i = 0; i < n_trees; i++) { int stream_id = omp_get_thread_num(); auto s = handle.get_stream_from_stream_pool(stream_id); - rmm::device_uvector* sampling_staging_vec = this->rf_params.oob_honesty ? &sampling_staging_vecs[stream_id] : nullptr; - rmm::device_uvector* split_row_mask = this->rf_params.oob_honesty ? &split_row_masks[stream_id] : nullptr; + rmm::device_uvector* remaining_samples = use_extra_vecs ? &remaining_samples_vec[stream_id] : nullptr; + rmm::device_uvector* workspace = use_extra_vecs ? &workspaces[stream_id] : nullptr; + rmm::device_uvector* remaining_groups = n_groups > 0 ? &remaining_groups_vec[stream_id] : nullptr; + rmm::device_uvector* split_row_mask = use_extra_vecs ? &split_row_masks[stream_id] : nullptr; + int* this_groups = n_groups > 0 ? groups->data() : nullptr; auto n_avg_samples = this->get_row_sample( - i, n_rows, n_sampled_rows, &selected_rows[stream_id], + i, n_rows, n_sampled_rows, + &selected_rows[stream_id], split_row_mask, - sampling_staging_vec, + remaining_groups, + remaining_samples, + workspace, + this_groups, + n_groups, + group_tree_count, + foldMemberships, s); /* Build individual tree in the forest. @@ -286,6 +666,7 @@ class RandomForest { i); } } + // Cleanup handle.sync_stream_pool(); handle.sync_stream(); diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu index 627fddc97c..e2f185d30f 100644 --- a/cpp/test/sg/rf_test.cu +++ b/cpp/test/sg/rf_test.cu @@ -43,6 +43,8 @@ #include #include #include +#include +#include #include @@ -237,7 +239,10 @@ auto TrainScore( 0, params.split_criterion, params.n_streams, - 128); + 128, + 0, + 0, + -1); auto forest = std::make_shared>(); auto forest_ptr = forest.get(); @@ -571,9 +576,62 @@ TEST(RfTests, IntegerOverflow) auto stream_pool = std::make_shared(4); raft::handle_t handle(rmm::cuda_stream_per_thread, stream_pool); RF_params rf_params = - set_rf_params(3, 100, 1.0, 256, 1, 0, 2, 0, 0.0, false, false, false, 1, 1.0, 0, CRITERION::MSE, 4, 128); + set_rf_params(3, 100, 1.0, 256, 1, 0, 2, 0, 0.0, false, false, false, 1, 1.0, 0, CRITERION::MSE, 4, 128, 0, 0, -1); fit(handle, forest_ptr, X.data().get(), m, n, y.data().get(), rf_params); + // Check we have actually learned something + EXPECT_GT(forest->trees[0]->leaf_counter, 1); + + // See if fil overflows + thrust::device_vector pred(m); + ModelHandle model; + build_treelite_forest(&model, forest_ptr, n); + + std::size_t num_outputs = 1; + fil::treelite_params_t tl_params{fil::algo_t::ALGO_AUTO, + num_outputs > 1, + 1.f / num_outputs, + fil::storage_type_t::AUTO, + 8, + 1, + 0, + nullptr}; + + fil::forest_variant forest_variant; + fil::from_treelite(handle, &forest_variant, model, &tl_params); + fil::forest_t fil_forest = std::get>(forest_variant); + fil::predict(handle, fil_forest, pred.data().get(), X.data().get(), m, false); +} +namespace { + struct TransformFunctor { + __device__ float operator()(float input) { + return roundf(input); + } + }; +} + + +TEST(RfTests, Honesty) +{ + std::size_t m = 10000; + std::size_t n = 2150; + thrust::device_vector X(m * n); + thrust::device_vector y(m); + raft::random::Rng r(4); + r.normal(X.data().get(), X.size(), 0.0f, 2.0f, nullptr); + cudaStream_t stream; + cudaStreamCreate(&stream); + thrust::transform(thrust::cuda::par.on(stream), X.data(), + X.data() + m, X.data(), TransformFunctor{}); + // quantize the first column so that we can use it for meaningful groups + r.normal(y.data().get(), y.size(), 0.0f, 2.0f, nullptr); + auto forest = std::make_shared>(); + auto forest_ptr = forest.get(); + auto stream_pool = std::make_shared(4); + raft::handle_t handle(rmm::cuda_stream_per_thread, stream_pool); + RF_params rf_params = + set_rf_params(3, 100, 1.0, 256, 1, 1, 2, 2, 0.0, true, true, true, 5, 1.0, 0, CRITERION::MSE, 1, 128, 0, 0, -1); + fit(handle, forest_ptr, X.data().get(), m, n, y.data().get(), rf_params); // Check we have actually learned something EXPECT_GT(forest->trees[0]->leaf_counter, 1); @@ -591,6 +649,278 @@ TEST(RfTests, IntegerOverflow) 1, 0, nullptr}; + + fil::forest_variant forest_variant; + fil::from_treelite(handle, &forest_variant, model, &tl_params); + fil::forest_t fil_forest = std::get>(forest_variant); + fil::predict(handle, fil_forest, pred.data().get(), X.data().get(), m, false); +} + +TEST(RfTests, SmallHonestFolds) +{ + std::size_t m = 1000; + std::size_t n = 10; + thrust::device_vector X(m * n); + thrust::device_vector y(m); + raft::random::Rng r(42); + r.normal(X.data().get(), X.size(), 0.0f, 1.0f, nullptr); + cudaStream_t stream; + cudaStreamCreate(&stream); + thrust::transform(thrust::cuda::par.on(stream), X.data(), + X.data() + m, X.data(), TransformFunctor{}); + // quantize the first column so that we can use it for meaningful groups + r.normal(y.data().get(), y.size(), 0.0f, 2.0f, nullptr); + auto forest = std::make_shared>(); + auto forest_ptr = forest.get(); + auto stream_pool = std::make_shared(4); + raft::handle_t handle(rmm::cuda_stream_per_thread, stream_pool); + RF_params rf_params = + set_rf_params(3, 100, 1.0, 256, 1, 1, 2, 2, 0.0, true, true, true, 1, 1.0, 0, CRITERION::MSE, 1, 128, 1, 5, 0); + fit(handle, forest_ptr, X.data().get(), m, n, y.data().get(), rf_params); + // Check we have actually learned something + EXPECT_GT(forest->trees[0]->leaf_counter, 1); + + // See if fil overflows + thrust::device_vector pred(m); + ModelHandle model; + build_treelite_forest(&model, forest_ptr, n); + + std::size_t num_outputs = 1; + fil::treelite_params_t tl_params{fil::algo_t::ALGO_AUTO, + num_outputs > 1, + 1.f / num_outputs, + fil::storage_type_t::AUTO, + 8, + 1, + 0, + nullptr}; + + fil::forest_variant forest_variant; + fil::from_treelite(handle, &forest_variant, model, &tl_params); + fil::forest_t fil_forest = std::get>(forest_variant); + fil::predict(handle, fil_forest, pred.data().get(), X.data().get(), m, false); +} + + +TEST(RfTests, SmallHonestFoldsWithFallback) +{ + std::size_t m = 1000; + std::size_t n = 10; + thrust::device_vector X(m * n); + thrust::device_vector y(m); + raft::random::Rng r(42); + r.normal(X.data().get(), X.size(), 0.0f, 1.0f, nullptr); + cudaStream_t stream; + cudaStreamCreate(&stream); + thrust::transform(thrust::cuda::par.on(stream), X.data(), + X.data() + m, X.data(), TransformFunctor{}); + // quantize the first column so that we can use it for meaningful groups + r.normal(y.data().get(), y.size(), 0.0f, 2.0f, nullptr); + auto forest = std::make_shared>(); + auto forest_ptr = forest.get(); + auto stream_pool = std::make_shared(4); + raft::handle_t handle(rmm::cuda_stream_per_thread, stream_pool); + RF_params rf_params = + set_rf_params(3, 100, 1.0, 256, 1, 1, 2, 2, 0.0, true, true, true, 100, 1.0, 0, CRITERION::MSE, 1, 128, 1, 5, 0); + fit(handle, forest_ptr, X.data().get(), m, n, y.data().get(), rf_params); + // Check we have actually learned something + EXPECT_GT(forest->trees[0]->leaf_counter, 1); + + // See if fil overflows + thrust::device_vector pred(m); + ModelHandle model; + build_treelite_forest(&model, forest_ptr, n); + + std::size_t num_outputs = 1; + fil::treelite_params_t tl_params{fil::algo_t::ALGO_AUTO, + num_outputs > 1, + 1.f / num_outputs, + fil::storage_type_t::AUTO, + 8, + 1, + 0, + nullptr}; + + fil::forest_variant forest_variant; + fil::from_treelite(handle, &forest_variant, model, &tl_params); + fil::forest_t fil_forest = std::get>(forest_variant); + fil::predict(handle, fil_forest, pred.data().get(), X.data().get(), m, false); +} + +TEST(RfTests, SmallDishonestFoldsWithFallback) +{ + std::size_t m = 1000; + std::size_t n = 10; + thrust::device_vector X(m * n); + thrust::device_vector y(m); + raft::random::Rng r(42); + r.normal(X.data().get(), X.size(), 0.0f, 1.0f, nullptr); + cudaStream_t stream; + cudaStreamCreate(&stream); + thrust::transform(thrust::cuda::par.on(stream), X.data(), + X.data() + m, X.data(), TransformFunctor{}); + // quantize the first column so that we can use it for meaningful groups + r.normal(y.data().get(), y.size(), 0.0f, 2.0f, nullptr); + auto forest = std::make_shared>(); + auto forest_ptr = forest.get(); + auto stream_pool = std::make_shared(4); + raft::handle_t handle(rmm::cuda_stream_per_thread, stream_pool); + RF_params rf_params = + set_rf_params(3, 100, 1.0, 256, 1, 1, 2, 2, 0.0, true, false, false, 100, 1.0, 0, CRITERION::MSE, 1, 128, 1, 5, 0); + fit(handle, forest_ptr, X.data().get(), m, n, y.data().get(), rf_params); + // Check we have actually learned something + EXPECT_GT(forest->trees[0]->leaf_counter, 1); + + // See if fil overflows + thrust::device_vector pred(m); + ModelHandle model; + build_treelite_forest(&model, forest_ptr, n); + + std::size_t num_outputs = 1; + fil::treelite_params_t tl_params{fil::algo_t::ALGO_AUTO, + num_outputs > 1, + 1.f / num_outputs, + fil::storage_type_t::AUTO, + 8, + 1, + 0, + nullptr}; + + fil::forest_variant forest_variant; + fil::from_treelite(handle, &forest_variant, model, &tl_params); + fil::forest_t fil_forest = std::get>(forest_variant); + fil::predict(handle, fil_forest, pred.data().get(), X.data().get(), m, false); +} + +TEST(RfTests, HonestFolds) +{ + std::size_t m = 10000; + std::size_t n = 2150; + thrust::device_vector X(m * n); + thrust::device_vector y(m); + raft::random::Rng r(4); + r.normal(X.data().get(), X.size(), 0.0f, 2.0f, nullptr); + cudaStream_t stream; + cudaStreamCreate(&stream); + thrust::transform(thrust::cuda::par.on(stream), X.data(), + X.data() + m, X.data(), TransformFunctor{}); + // quantize the first column so that we can use it for meaningful groups + r.normal(y.data().get(), y.size(), 0.0f, 2.0f, nullptr); + auto forest = std::make_shared>(); + auto forest_ptr = forest.get(); + auto stream_pool = std::make_shared(4); + raft::handle_t handle(rmm::cuda_stream_per_thread, stream_pool); + RF_params rf_params = + set_rf_params(3, 100, 1.0, 256, 1, 1, 2, 2, 0.0, true, true, true, 5, 1.0, 0, CRITERION::MSE, 4, 128, 2, 2, 0); + fit(handle, forest_ptr, X.data().get(), m, n, y.data().get(), rf_params); + // Check we have actually learned something + EXPECT_GT(forest->trees[0]->leaf_counter, 1); + + // See if fil overflows + thrust::device_vector pred(m); + ModelHandle model; + build_treelite_forest(&model, forest_ptr, n); + + std::size_t num_outputs = 1; + fil::treelite_params_t tl_params{fil::algo_t::ALGO_AUTO, + num_outputs > 1, + 1.f / num_outputs, + fil::storage_type_t::AUTO, + 8, + 1, + 0, + nullptr}; + + fil::forest_variant forest_variant; + fil::from_treelite(handle, &forest_variant, model, &tl_params); + fil::forest_t fil_forest = std::get>(forest_variant); + fil::predict(handle, fil_forest, pred.data().get(), X.data().get(), m, false); +} + +TEST(RfTests, HonestGroups) +{ + std::size_t m = 10000; + std::size_t n = 2150; + thrust::device_vector X(m * n); + thrust::device_vector y(m); + raft::random::Rng r(4); + r.normal(X.data().get(), X.size(), 0.0f, 2.0f, nullptr); + cudaStream_t stream; + cudaStreamCreate(&stream); + thrust::transform(thrust::cuda::par.on(stream), X.data(), + X.data() + m, X.data(), TransformFunctor{}); + // quantize the first column so that we can use it for meaningful groups + r.normal(y.data().get(), y.size(), 0.0f, 2.0f, nullptr); + auto forest = std::make_shared>(); + auto forest_ptr = forest.get(); + auto stream_pool = std::make_shared(4); + raft::handle_t handle(rmm::cuda_stream_per_thread, stream_pool); + RF_params rf_params = + set_rf_params(3, 100, 1.0, 256, 1, 1, 2, 2, 0.0, true, true, true, 5, 1.0, 0, CRITERION::MSE, 4, 128, 0, 0, 0); + fit(handle, forest_ptr, X.data().get(), m, n, y.data().get(), rf_params); + // Check we have actually learned something + EXPECT_GT(forest->trees[0]->leaf_counter, 1); + + // See if fil overflows + thrust::device_vector pred(m); + ModelHandle model; + build_treelite_forest(&model, forest_ptr, n); + + std::size_t num_outputs = 1; + fil::treelite_params_t tl_params{fil::algo_t::ALGO_AUTO, + num_outputs > 1, + 1.f / num_outputs, + fil::storage_type_t::AUTO, + 8, + 1, + 0, + nullptr}; + + fil::forest_variant forest_variant; + fil::from_treelite(handle, &forest_variant, model, &tl_params); + fil::forest_t fil_forest = std::get>(forest_variant); + fil::predict(handle, fil_forest, pred.data().get(), X.data().get(), m, false); +} + +TEST(RfTests, DishonestFolds) +{ + std::size_t m = 10000; + std::size_t n = 2150; + thrust::device_vector X(m * n); + thrust::device_vector y(m); + raft::random::Rng r(4); + r.normal(X.data().get(), X.size(), 0.0f, 2.0f, nullptr); + cudaStream_t stream; + cudaStreamCreate(&stream); + thrust::transform(thrust::cuda::par.on(stream), X.data(), + X.data() + m, X.data(), TransformFunctor{}); + // quantize the first column so that we can use it for meaningful groups + r.normal(y.data().get(), y.size(), 0.0f, 2.0f, nullptr); + auto forest = std::make_shared>(); + auto forest_ptr = forest.get(); + auto stream_pool = std::make_shared(4); + raft::handle_t handle(rmm::cuda_stream_per_thread, stream_pool); + RF_params rf_params = + set_rf_params(3, 100, 1.0, 256, 1, 1, 2, 2, 0.0, true, false, false, 5, 1.0, 0, CRITERION::MSE, 4, 128, 2, 2, 0); + fit(handle, forest_ptr, X.data().get(), m, n, y.data().get(), rf_params); + // Check we have actually learned something + EXPECT_GT(forest->trees[0]->leaf_counter, 1); + + // See if fil overflows + thrust::device_vector pred(m); + ModelHandle model; + build_treelite_forest(&model, forest_ptr, n); + + std::size_t num_outputs = 1; + fil::treelite_params_t tl_params{fil::algo_t::ALGO_AUTO, + num_outputs > 1, + 1.f / num_outputs, + fil::storage_type_t::AUTO, + 8, + 1, + 0, + nullptr}; + fil::forest_variant forest_variant; fil::from_treelite(handle, &forest_variant, model, &tl_params); fil::forest_t fil_forest = std::get>(forest_variant); @@ -813,7 +1143,7 @@ INSTANTIATE_TEST_CASE_P(RfTests, RFQuantileVariableBinsTestD, ::testing::ValuesI TEST(RfTest, TextDump) { - RF_params rf_params = set_rf_params(2, 2, 1.0, 2, 1, 0, 2, 0, 0.0, false, false, false, 1, 1.0, 0, GINI, 1, 128); + RF_params rf_params = set_rf_params(2, 2, 1.0, 2, 1, 0, 2, 0, 0.0, false, false, false, 1, 1.0, 0, GINI, 1, 128, 0, 0, -1); auto forest = std::make_shared>(); std::vector X_host = {1, 2, 3, 6, 7, 8}; @@ -954,7 +1284,6 @@ class ObjectiveTest : public ::testing::TestWithParam { (n_right / n) * right_mse); // gain in long form without proxy // edge cases - printf("gain %f\n", gain); if (n_left < params.min_samples_leaf_splitting or n_right < params.min_samples_leaf_splitting) return -std::numeric_limits::max(); else @@ -1339,5 +1668,54 @@ INSTANTIATE_TEST_CASE_P(RfTests, GiniObjectiveTestF, ::testing::ValuesIn(gini_objective_test_parameters)); +TEST(rf_unit_tests, lower_bound) { + // Initialize sorted array + cudaStream_t stream; + cudaStreamCreate(&stream); + int* sorted_array; + int* input_array; + int* output_array; + const int num_inputs = 200; + const int max_val = 100; + cudaMalloc(&sorted_array, num_inputs * sizeof(int)); + cudaMalloc(&input_array, num_inputs * sizeof(int)); + cudaMalloc(&output_array, num_inputs * sizeof(int)); + + // Generate random samples + raft::random::Rng rng(42, raft::random::GenPhilox); + rng.uniformInt(input_array, num_inputs, 0, max_val, stream); + cudaMemcpy(sorted_array, input_array, num_inputs * sizeof(int), cudaMemcpyDefault); + thrust::sort(thrust::cuda::par.on(stream), + sorted_array, + sorted_array + num_inputs); + + int* new_end = thrust::unique(thrust::cuda::par.on(stream), + sorted_array, + sorted_array + num_inputs); + + const int n_unique = new_end - sorted_array; + + lower_bound_test(sorted_array, input_array, n_unique, num_inputs, output_array, stream); + + std::vector input_vec(num_inputs); + std::vector sorted_vec(n_unique); + std::vector gpu_output_vec(num_inputs); + + cudaMemcpy(input_vec.data(), input_array, num_inputs * sizeof(int), cudaMemcpyDefault); + cudaMemcpy(gpu_output_vec.data(), output_array, num_inputs * sizeof(int), cudaMemcpyDefault); + cudaMemcpy(sorted_vec.data(), sorted_array, n_unique * sizeof(int), cudaMemcpyDefault); + + for (int ix = 0; ix < num_inputs; ++ix) { + int res = std::distance(sorted_vec.begin(), + std::lower_bound(sorted_vec.begin(), sorted_vec.end(), input_vec[ix])); + ASSERT_EQ(res, gpu_output_vec[ix]); + } + + // Compare against std::lower_bound + cudaFree(sorted_array); + cudaFree(input_array); + cudaFree(output_array); +} + } // end namespace DT } // end namespace ML diff --git a/honesty_test.py b/honesty_test.py index ea90c52f29..1d6c044cf9 100755 --- a/honesty_test.py +++ b/honesty_test.py @@ -47,6 +47,19 @@ n_trees = 100 +# Start group call -- note we're not using groups to specify OOB predictions +# Note, here we specify a column index to use for groups. Then the fit() function +# will use the GPU to compute unique group ids for every sample. +group_col_idx = x.columns.get_loc("state") +random_forest_regress = RFR(n_estimators=n_trees, oob_honesty=True, split_criterion=2, + random_state=42, minTreesPerGroupFold=5, foldGroupSize=1, group_col_idx=group_col_idx) +start = time.time() +trainedRFR = random_forest_regress.fit(x_train, y_train) +end = time.time() +pred_test_regress = trainedRFR.predict(x_test) +mse = cuml.metrics.mean_squared_error(y_test, pred_test_regress) +print(f"Group Honesty {mse} time {end-start}") + random_forest_regress = RFR(n_estimators=n_trees, split_criterion=2, random_state=42) start = time.time() trainedRFR = random_forest_regress.fit(x_train, y_train) @@ -62,3 +75,4 @@ pred_test_regress = trainedRFR.predict(x_test) mse = cuml.metrics.mean_squared_error(y_test, pred_test_regress) print(f"Honesty {mse} time {end-start}") + diff --git a/python/cuml/dask/ensemble/randomforestclassifier.py b/python/cuml/dask/ensemble/randomforestclassifier.py index 1e1203d39c..6bb0c34a89 100755 --- a/python/cuml/dask/ensemble/randomforestclassifier.py +++ b/python/cuml/dask/ensemble/randomforestclassifier.py @@ -172,6 +172,35 @@ class RandomForestClassifier( This is an experimental parameter, and may be removed in the future. + minTreesPerGroupFold : int (default = 0) + Comment from rforestry: + The number of trees which we make sure have been created leaving + out each fold (each fold is a set of randomly selected groups). + This is 0 by default, so we will not give any special treatment to + the groups when sampling observations, however if this is set to a positive integer, we + modify the bootstrap sampling scheme to ensure that exactly that many trees + have each group left out. We do this by, for each fold, creating minTreesPerGroupFold + trees which are built on observations sampled from the set of training observations + which are not in a group in the current fold. The folds form a random partition of + all of the possible groups, each of size foldGroupSize. This means we create at + least # folds * minTreesPerGroupFold trees for the forest. + If ntree > # folds * minTreesPerGroupFold, we create + max(# folds * minTreesPerGroupFold, ntree) total trees, in which at least minTreesPerGroupFold + are created leaving out each fold. + + foldGroupSize : int (default = 0) + Comment from rforestry: + The number of groups that are selected randomly for each fold to be + left out when using minTreesPerGroupFold. When minTreesPerGroupFold is set and foldGroupSize is + set, all possible groups will be partitioned into folds, each containing foldGroupSize unique groups + (if foldGroupSize doesn't evenly divide the number of groups, a single fold will be smaller, + as it will contain the remaining groups). Then minTreesPerGroupFold are grown with each + entire fold of groups left out. + + group_col_idx : int (default = -1) + The numeric index of the column to be used for group processing + + Examples -------- For usage examples, please see the RAPIDS notebooks repository: diff --git a/python/cuml/dask/ensemble/randomforestregressor.py b/python/cuml/dask/ensemble/randomforestregressor.py index a38ff0dcc9..1ea049b678 100755 --- a/python/cuml/dask/ensemble/randomforestregressor.py +++ b/python/cuml/dask/ensemble/randomforestregressor.py @@ -163,6 +163,33 @@ class RandomForestRegressor( n_estimators) When False, throws a RuntimeError. This is an experimental parameter, and may be removed in the future. + minTreesPerGroupFold : int (default = 0) + Comment from rforestry: + The number of trees which we make sure have been created leaving + out each fold (each fold is a set of randomly selected groups). + This is 0 by default, so we will not give any special treatment to + the groups when sampling observations, however if this is set to a positive integer, we + modify the bootstrap sampling scheme to ensure that exactly that many trees + have each group left out. We do this by, for each fold, creating minTreesPerGroupFold + trees which are built on observations sampled from the set of training observations + which are not in a group in the current fold. The folds form a random partition of + all of the possible groups, each of size foldGroupSize. This means we create at + least # folds * minTreesPerGroupFold trees for the forest. + If ntree > # folds * minTreesPerGroupFold, we create + max(# folds * minTreesPerGroupFold, ntree) total trees, in which at least minTreesPerGroupFold + are created leaving out each fold. + + foldGroupSize : int (default = 0) + Comment from rforestry: + The number of groups that are selected randomly for each fold to be + left out when using minTreesPerGroupFold. When minTreesPerGroupFold is set and foldGroupSize is + set, all possible groups will be partitioned into folds, each containing foldGroupSize unique groups + (if foldGroupSize doesn't evenly divide the number of groups, a single fold will be smaller, + as it will contain the remaining groups). Then minTreesPerGroupFold are grown with each + entire fold of groups left out. + + group_col_idx : int (default = -1) + The numeric index of the column to be used for group processing """ diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx index 88f559b097..52ca2f3ceb 100644 --- a/python/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/ensemble/randomforest_common.pyx @@ -44,8 +44,11 @@ from cuml.prims.label.classlabels import make_monotonic, check_labels class BaseRandomForestModel(Base): _param_names = ['n_estimators', 'max_depth', 'handle', 'max_features', 'n_bins', - 'split_criterion', 'min_samples_leaf', - 'min_samples_split', + 'split_criterion', + 'min_samples_leaf_splitting', + 'min_samples_leaf_averaging', + 'min_samples_split_splitting', + 'min_samples_split_averaging', 'min_impurity_decrease', 'bootstrap', 'verbose', 'max_samples', @@ -55,7 +58,8 @@ class BaseRandomForestModel(Base): 'output_type', 'min_weight_fraction_leaf', 'n_jobs', 'max_leaf_nodes', 'min_impurity_split', 'oob_score', 'random_state', 'warm_start', 'class_weight', - 'criterion'] + 'criterion', 'minTreesPerGroupFold', 'foldGroupSize', + 'group_col_idx'] criterion_dict = {'0': GINI, 'gini': GINI, '1': ENTROPY, 'entropy': ENTROPY, @@ -80,7 +84,11 @@ class BaseRandomForestModel(Base): min_impurity_split=None, oob_score=None, random_state=None, warm_start=None, class_weight=None, criterion=None, - max_batch_size=4096, **kwargs): + max_batch_size=4096, + minTreesPerGroupFold=0, + foldGroupSize=1, + group_col_idx=-1, + **kwargs): sklearn_params = {"criterion": criterion, "min_weight_fraction_leaf": min_weight_fraction_leaf, @@ -161,6 +169,9 @@ class BaseRandomForestModel(Base): self.model_pbuf_bytes = bytearray() self.treelite_handle = None self.treelite_serialized_model = None + self.minTreesPerGroupFold = minTreesPerGroupFold + self.foldGroupSize = foldGroupSize + self.group_col_idx = group_col_idx def _get_max_feat_val(self) -> float: if type(self.max_features) == int: diff --git a/python/cuml/ensemble/randomforest_shared.pxd b/python/cuml/ensemble/randomforest_shared.pxd index 0e8d80d563..f4d03d89ca 100644 --- a/python/cuml/ensemble/randomforest_shared.pxd +++ b/python/cuml/ensemble/randomforest_shared.pxd @@ -110,6 +110,9 @@ cdef extern from "cuml/ensemble/randomforest.hpp" namespace "ML": uint64_t, CRITERION, int, + int, + int, + int, int) except + cdef vector[unsigned char] save_model(ModelHandle) diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index 4c22cee406..284f811a7c 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -258,7 +258,34 @@ class RandomForestClassifier(BaseRandomForestModel, type. If None, the output type set at the module level (`cuml.global_settings.output_type`) will be used. See :ref:`output-data-type-configuration` for more info. - + minTreesPerGroupFold : int (default = 0) + Comment from rforestry: + The number of trees which we make sure have been created leaving + out each fold (each fold is a set of randomly selected groups). + This is 0 by default, so we will not give any special treatment to + the groups when sampling observations, however if this is set to a positive integer, we + modify the bootstrap sampling scheme to ensure that exactly that many trees + have each group left out. We do this by, for each fold, creating minTreesPerGroupFold + trees which are built on observations sampled from the set of training observations + which are not in a group in the current fold. The folds form a random partition of + all of the possible groups, each of size foldGroupSize. This means we create at + least # folds * minTreesPerGroupFold trees for the forest. + If ntree > # folds * minTreesPerGroupFold, we create + max(# folds * minTreesPerGroupFold, ntree) total trees, in which at least minTreesPerGroupFold + are created leaving out each fold. + + foldGroupSize : int (default = 1) + Comment from rforestry: + The number of groups that are selected randomly for each fold to be + left out when using minTreesPerGroupFold. When minTreesPerGroupFold is set and foldGroupSize is + set, all possible groups will be partitioned into folds, each containing foldGroupSize unique groups + (if foldGroupSize doesn't evenly divide the number of groups, a single fold will be smaller, + as it will contain the remaining groups). Then minTreesPerGroupFold are grown with each + entire fold of groups left out. + + group_col_idx : int (default = -1) + The numeric index of the column to be used for group processing + Notes ----- **Known Limitations**\n @@ -501,7 +528,10 @@ class RandomForestClassifier(BaseRandomForestModel, seed_val, self.split_criterion, self.n_streams, - self.max_batch_size) + self.max_batch_size, + self.minTreesPerGroupFold, + self.foldGroupSize, + self.group_col_idx) if self.dtype == np.float32: fit(handle_[0], diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index 5a79b92d9b..85d25b1fde 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -261,6 +261,33 @@ class RandomForestRegressor(BaseRandomForestModel, type. If None, the output type set at the module level (`cuml.global_settings.output_type`) will be used. See :ref:`output-data-type-configuration` for more info. + minTreesPerGroupFold : int (default = 0) + Comment from rforestry: + The number of trees which we make sure have been created leaving + out each fold (each fold is a set of randomly selected groups). + This is 0 by default, so we will not give any special treatment to + the groups when sampling observations, however if this is set to a positive integer, we + modify the bootstrap sampling scheme to ensure that exactly that many trees + have each group left out. We do this by, for each fold, creating minTreesPerGroupFold + trees which are built on observations sampled from the set of training observations + which are not in a group in the current fold. The folds form a random partition of + all of the possible groups, each of size foldGroupSize. This means we create at + least # folds * minTreesPerGroupFold trees for the forest. + If ntree > # folds * minTreesPerGroupFold, we create + max(# folds * minTreesPerGroupFold, ntree) total trees, in which at least minTreesPerGroupFold + are created leaving out each fold. + + foldGroupSize : int (default = 1) + Comment from rforestry: + The number of groups that are selected randomly for each fold to be + left out when using minTreesPerGroupFold. When minTreesPerGroupFold is set and foldGroupSize is + set, all possible groups will be partitioned into folds, each containing foldGroupSize unique groups + (if foldGroupSize doesn't evenly divide the number of groups, a single fold will be smaller, + as it will contain the remaining groups). Then minTreesPerGroupFold are grown with each + entire fold of groups left out. + + group_col_idx : int (default = -1) + The numeric index of the column to be used for group processing Notes ----- @@ -486,7 +513,10 @@ class RandomForestRegressor(BaseRandomForestModel, seed_val, self.split_criterion, self.n_streams, - self.max_batch_size) + self.max_batch_size, + self.minTreesPerGroupFold, + self.foldGroupSize, + self.group_col_idx) if self.dtype == np.float32: fit(handle_[0],