From 5d27ba7fc8d059caae7751f5a8ad6e524aabe999 Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Thu, 8 Jun 2023 16:50:54 -0600 Subject: [PATCH] Update RF algorithm to allow "honesty" feature which is useful for Causal inferencing --- cpp/CMakeLists.txt | 14 + cpp/bench/sg/fil.cu | 8 +- cpp/bench/sg/filex.cu | 8 +- cpp/bench/sg/rf_classifier.cu | 8 +- cpp/include/cuml/ensemble/randomforest.hpp | 28 +- cpp/include/cuml/tree/decisiontree.hpp | 38 +- .../decisiontree/batched-levelalgo/bins.cuh | 79 ++- .../batched-levelalgo/builder.cuh | 68 ++- .../decisiontree/batched-levelalgo/dataset.h | 7 + .../kernels/builder_kernels.cuh | 17 +- .../kernels/builder_kernels_impl.cuh | 89 +++- .../kernels/entropy-double-honest.cu | 34 ++ .../kernels/entropy-double.cu | 2 +- .../kernels/entropy-float-honest.cu | 34 ++ .../kernels/entropy-float.cu | 2 +- .../kernels/gamma-double-honest.cu | 34 ++ .../batched-levelalgo/kernels/gamma-double.cu | 2 +- .../kernels/gamma-float-honest.cu | 34 ++ .../batched-levelalgo/kernels/gamma-float.cu | 2 +- .../kernels/gini-double-honest.cu | 34 ++ .../batched-levelalgo/kernels/gini-double.cu | 2 +- .../kernels/gini-float-honest.cu | 34 ++ .../batched-levelalgo/kernels/gini-float.cu | 2 +- .../kernels/inverse_gaussian-double-honest.cu | 34 ++ .../kernels/inverse_gaussian-double.cu | 2 +- .../kernels/inverse_gaussian-float-honest.cu | 34 ++ .../kernels/inverse_gaussian-float.cu | 2 +- .../kernels/mse-double-honest.cu | 34 ++ .../batched-levelalgo/kernels/mse-double.cu | 2 +- .../kernels/mse-float-honest.cu | 34 ++ .../batched-levelalgo/kernels/mse-float.cu | 2 +- .../kernels/poisson-double-honest.cu | 34 ++ .../kernels/poisson-double.cu | 2 +- .../kernels/poisson-float-honest.cu | 34 ++ .../kernels/poisson-float.cu | 2 +- .../batched-levelalgo/objectives.cuh | 464 ++++++++++++------ .../decisiontree/batched-levelalgo/split.cuh | 60 ++- cpp/src/decisiontree/decisiontree.cu | 49 +- cpp/src/decisiontree/decisiontree.cuh | 130 +++-- 
cpp/src/randomforest/randomforest.cu | 19 +- cpp/src/randomforest/randomforest.cuh | 141 +++++- cpp/test/sg/rf_test.cu | 123 +++-- honesty_test.py | 64 +++ python/cuml/benchmark/algorithms.py | 1 - .../dask/ensemble/randomforestclassifier.py | 47 +- .../dask/ensemble/randomforestregressor.py | 44 +- python/cuml/ensemble/randomforest_common.pyx | 31 +- python/cuml/ensemble/randomforest_shared.pxd | 4 + .../cuml/ensemble/randomforestclassifier.pyx | 51 +- .../cuml/ensemble/randomforestregressor.pyx | 53 +- python/cuml/tests/test_random_forest.py | 18 +- quick_build.sh | 4 + 52 files changed, 1606 insertions(+), 493 deletions(-) create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/entropy-double-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/entropy-float-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/gamma-double-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/gamma-float-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/gini-double-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/gini-float-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-double-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-float-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/mse-double-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/mse-float-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/poisson-double-honest.cu create mode 100644 cpp/src/decisiontree/batched-levelalgo/kernels/poisson-float-honest.cu create mode 100755 honesty_test.py create mode 100755 quick_build.sh diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index fcac774aed..8ac988cd7d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -322,6 +322,20 @@ if(BUILD_CUML_CPP_LIBRARY) 
src/decisiontree/batched-levelalgo/kernels/mse-float.cu src/decisiontree/batched-levelalgo/kernels/poisson-double.cu src/decisiontree/batched-levelalgo/kernels/poisson-float.cu + + src/decisiontree/batched-levelalgo/kernels/entropy-double-honest.cu + src/decisiontree/batched-levelalgo/kernels/entropy-float-honest.cu + src/decisiontree/batched-levelalgo/kernels/gamma-double-honest.cu + src/decisiontree/batched-levelalgo/kernels/gamma-float-honest.cu + src/decisiontree/batched-levelalgo/kernels/gini-double-honest.cu + src/decisiontree/batched-levelalgo/kernels/gini-float-honest.cu + src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-double-honest.cu + src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-float-honest.cu + src/decisiontree/batched-levelalgo/kernels/mse-double-honest.cu + src/decisiontree/batched-levelalgo/kernels/mse-float-honest.cu + src/decisiontree/batched-levelalgo/kernels/poisson-double-honest.cu + src/decisiontree/batched-levelalgo/kernels/poisson-float-honest.cu + src/decisiontree/batched-levelalgo/kernels/quantiles.cu src/decisiontree/decisiontree.cu) endif() diff --git a/cpp/bench/sg/fil.cu b/cpp/bench/sg/fil.cu index 09efc1dfa3..e799db3626 100644 --- a/cpp/bench/sg/fil.cu +++ b/cpp/bench/sg/fil.cu @@ -151,10 +151,14 @@ std::vector getInputs() (1 << 20), /* max_leaves */ 1.f, /* max_features */ 32, /* max_n_bins */ - 3, /* min_samples_leaf */ - 3, /* min_samples_split */ + 3, /* min_samples_leaf_splitting */ + 3, /* min_samples_leaf_averaging */ + 3, /* min_samples_split_splitting */ + 3, /* min_samples_split_averaging */ 0.0f, /* min_impurity_decrease */ true, /* bootstrap */ + false, /* oob_honesty */ + true, /* double_bootstrap */ 1, /* n_trees */ 1.f, /* max_samples */ 1234ULL, /* seed */ diff --git a/cpp/bench/sg/filex.cu b/cpp/bench/sg/filex.cu index 048d89c3d9..c090dcf360 100644 --- a/cpp/bench/sg/filex.cu +++ b/cpp/bench/sg/filex.cu @@ -251,10 +251,14 @@ std::vector getInputs() (1 << 20), /* max_leaves */ 1.f, /* 
max_features */ 32, /* max_n_bins */ - 3, /* min_samples_leaf */ - 3, /* min_samples_split */ + 3, /* min_samples_leaf_splitting */ + 3, /* min_samples_leaf_averaging */ + 3, /* min_samples_split_splitting */ + 3, /* min_samples_split_averaging */ 0.0f, /* min_impurity_decrease */ true, /* bootstrap */ + false, /* oob_honesty */ + true, /* double_bootstrap */ 1, /* n_trees */ 1.f, /* max_samples */ 1234ULL, /* seed */ diff --git a/cpp/bench/sg/rf_classifier.cu b/cpp/bench/sg/rf_classifier.cu index 9f3a9b1a7d..c9703066f9 100644 --- a/cpp/bench/sg/rf_classifier.cu +++ b/cpp/bench/sg/rf_classifier.cu @@ -95,10 +95,14 @@ std::vector getInputs() (1 << 20), /* max_leaves */ 0.3, /* max_features */ 32, /* max_n_bins */ - 3, /* min_samples_leaf */ - 3, /* min_samples_split */ + 3, /* min_samples_leaf_splitting */ + 3, /* min_samples_leaf_averaging */ + 3, /* min_samples_split_splitting */ + 3, /* min_samples_split_averaging */ 0.0f, /* min_impurity_decrease */ true, /* bootstrap */ + false, /* oob_honesty */ + true, /* double_bootstrap */ 500, /* n_trees */ 1.f, /* max_samples */ 1234ULL, /* seed */ diff --git a/cpp/include/cuml/ensemble/randomforest.hpp b/cpp/include/cuml/ensemble/randomforest.hpp index bccc02bac2..b7c131f677 100644 --- a/cpp/include/cuml/ensemble/randomforest.hpp +++ b/cpp/include/cuml/ensemble/randomforest.hpp @@ -75,6 +75,26 @@ struct RF_params { * tree. */ bool bootstrap; + + /** + * Control whether to use honesty features to allow causal inferencing + * + * This indicates that the values used for averaging in the leaf node predictions + * should be a disjoint set with the labels used for splits during training. + * See this issue for more detail: https://github.com/rapidsai/cuml/issues/5253 + */ + bool oob_honesty; + + /** + * Honesty double bootstrapping + * + * With double bootstrapping, the set of samples that was not sampled for training + * is again sampled with replacement. 
This leaves some samples that could be used + * for double OOB prediction. TODO: how can we make the user aware of which + * samples could be used for double OOB prediction? + */ + bool double_bootstrap; + /** * Ratio of dataset rows used while fitting each tree. */ @@ -192,10 +212,14 @@ RF_params set_rf_params(int max_depth, int max_leaves, float max_features, int max_n_bins, - int min_samples_leaf, - int min_samples_split, + int min_samples_leaf_splitting, + int min_samples_leaf_averaging, + int min_samples_split_splitting, + int min_samples_split_averaging, float min_impurity_decrease, bool bootstrap, + bool oob_honesty, + bool double_bootstrap, int n_trees, float max_samples, uint64_t seed, diff --git a/cpp/include/cuml/tree/decisiontree.hpp b/cpp/include/cuml/tree/decisiontree.hpp index b6ccdb21c8..cfedb1e057 100644 --- a/cpp/include/cuml/tree/decisiontree.hpp +++ b/cpp/include/cuml/tree/decisiontree.hpp @@ -44,13 +44,21 @@ struct DecisionTreeParams { */ int max_n_bins; /** - * The minimum number of samples (rows) in each leaf node. + * The minimum number of splitting samples (rows) in each leaf node. */ - int min_samples_leaf; + int min_samples_leaf_splitting; + /** + * The minimum number of averaging samples (rows) in each leaf node. + */ + int min_samples_leaf_averaging; + /** + * The minimum number of splitting samples (rows) needed to split an internal node. + */ + int min_samples_split_splitting; /** - * The minimum number of samples (rows) needed to split an internal node. + * The minimum number of averaging samples (rows) needed to split an internal node. */ - int min_samples_split; + int min_samples_split_averaging; /** * Node split criterion. GINI and Entropy for classification, MSE for regression. 
*/ @@ -66,6 +74,11 @@ struct DecisionTreeParams { * used only for batched-level algo */ int max_batch_size; + + /** + * Whether to use oob honesty features + */ + bool oob_honesty; }; /** @@ -75,9 +88,13 @@ struct DecisionTreeParams { * @param[in] cfg_max_leaves: maximum leaves; default -1 * @param[in] cfg_max_features: maximum number of features; default 1.0f * @param[in] cfg_max_n_bins: maximum number of bins; default 128 - * @param[in] cfg_min_samples_leaf: min. rows in each leaf node; default 1 - * @param[in] cfg_min_samples_split: min. rows needed to split an internal node; + * @param[in] cfg_min_samples_leaf_splitting: min. splitting rows in each leaf node; default 1 + * @param[in] cfg_min_samples_leaf_averaging: min. averaging rows in each leaf node when oob_honesty enabled; + * default 1 + * @param[in] cfg_min_samples_split_splitting: min. splitting rows needed to split an internal node; * default 2 + * @param[in] cfg_min_samples_split_averaging: min. averaging rows needed to split an internal + * node when oob_honesty enabled; default 2 * @param[in] cfg_min_impurity_decrease: split a node only if its reduction in * impurity is more than this value * @param[in] cfg_split_criterion: split criterion; default CRITERION_END, @@ -91,11 +108,14 @@ void set_tree_params(DecisionTreeParams& params, int cfg_max_leaves = -1, float cfg_max_features = 1.0f, int cfg_max_n_bins = 128, - int cfg_min_samples_leaf = 1, - int cfg_min_samples_split = 2, + int cfg_min_samples_leaf_splitting = 1, + int cfg_min_samples_leaf_averaging = 1, + int cfg_min_samples_split_splitting = 2, + int cfg_min_samples_split_averaging = 2, float cfg_min_impurity_decrease = 0.0f, CRITERION cfg_split_criterion = CRITERION_END, - int cfg_max_batch_size = 4096); + int cfg_max_batch_size = 4096, + bool cfg_oob_honesty = false); template struct TreeMetaDataNode { diff --git a/cpp/src/decisiontree/batched-levelalgo/bins.cuh b/cpp/src/decisiontree/batched-levelalgo/bins.cuh index 312c4f2b51..8f3a3440c4
100644 --- a/cpp/src/decisiontree/batched-levelalgo/bins.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/bins.cuh @@ -25,7 +25,7 @@ struct CountBin { HDI CountBin(int x_) : x(x_) {} HDI CountBin() : x(0) {} - DI static void IncrementHistogram(CountBin* hist, int n_bins, int b, int label) + DI static void IncrementHistogram(CountBin* hist, int n_bins, int b, int label, bool /*is_split_row*/, bool /*split_op*/) { auto offset = label * n_bins + b; CountBin::AtomicAdd(hist + offset, {1}); @@ -43,6 +43,41 @@ struct CountBin { } }; +struct HonestCountBin : CountBin { + int x_averaging; + + HonestCountBin(HonestCountBin const&) = default; + HDI HonestCountBin(int x_train, int x_averaging) : CountBin(x_train), x_averaging(x_averaging) {} + HDI HonestCountBin() : CountBin(), x_averaging(0) {} + + DI static void IncrementHistogram(HonestCountBin* hist, int n_bins, int b, int label, bool is_split_row, bool /*is_split_op*/) + { + auto offset = label * n_bins + b; + if (is_split_row) { + atomicAdd(&(hist + offset)->x, {1}); + } else { + atomicAdd(&(hist + offset)->x_averaging, {1}); + } + } + + DI static void AtomicAdd(HonestCountBin* address, HonestCountBin val) + { + atomicAdd(&address->x, val.x); + atomicAdd(&address->x_averaging, val.x_averaging); + } + HDI HonestCountBin& operator+=(const HonestCountBin& b) + { + CountBin::operator+=(b); + x_averaging += b.x_averaging; + return *this; + } + HDI HonestCountBin operator+(HonestCountBin b) const + { + b += *this; + return b; + } +}; + struct AggregateBin { double label_sum; int count; @@ -51,7 +86,7 @@ struct AggregateBin { HDI AggregateBin() : label_sum(0.0), count(0) {} HDI AggregateBin(double label_sum, int count) : label_sum(label_sum), count(count) {} - DI static void IncrementHistogram(AggregateBin* hist, int n_bins, int b, double label) + DI static void IncrementHistogram(AggregateBin* hist, int n_bins, int b, double label, bool, bool) { AggregateBin::AtomicAdd(hist + b, {label, 1}); } @@ -72,5 +107,45 @@ struct 
AggregateBin { return b; } }; + +struct HonestAggregateBin : AggregateBin { + int count_averaging; + + HonestAggregateBin(HonestAggregateBin const&) = default; + HDI HonestAggregateBin() : AggregateBin(), count_averaging(0) {} + HDI HonestAggregateBin(double label_sum, int count, int count_averaging) + : AggregateBin(label_sum, count), count_averaging(count_averaging) + {} + + DI static void IncrementHistogram( + HonestAggregateBin* hist, int n_bins, int b, double label, bool is_split_row, bool is_split_op) + { + HonestAggregateBin* address = hist + b; + const int train_incr = static_cast(is_split_row); + const int avg_incr = 1 - train_incr; + + // Either split row and split op, or neither. Otherwise no increment of the label. + const double label_incr = not (is_split_row xor is_split_op) ? label : 0.0; + HonestAggregateBin::AtomicAdd(address, {label_incr, train_incr, avg_incr}); + } + DI static void AtomicAdd(HonestAggregateBin* address, HonestAggregateBin val) + { + atomicAdd(&address->label_sum, val.label_sum); + atomicAdd(&address->count, val.count); + atomicAdd(&address->count_averaging, val.count_averaging); + } + HDI HonestAggregateBin& operator+=(const HonestAggregateBin& b) + { + AggregateBin::operator+=(b); + count_averaging += b.count_averaging; + return *this; + } + HDI HonestAggregateBin operator+(HonestAggregateBin b) const + { + b += *this; + return b; + } +}; + } // namespace DT } // namespace ML \ No newline at end of file diff --git a/cpp/src/decisiontree/batched-levelalgo/builder.cuh b/cpp/src/decisiontree/batched-levelalgo/builder.cuh index 453110c099..0cc1b6f51f 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder.cuh @@ -47,7 +47,7 @@ class NodeQueue { std::deque work_items_; public: - NodeQueue(DecisionTreeParams params, size_t max_nodes, size_t sampled_rows, int num_outputs) + NodeQueue(DecisionTreeParams params, size_t max_nodes, size_t sampled_rows, size_t avg_row_count, int 
num_outputs) : params(params), tree(std::make_shared>()) { tree->num_outputs = num_outputs; @@ -56,8 +56,8 @@ class NodeQueue { tree->leaf_counter = 1; tree->depth_counter = 0; node_instances_.reserve(max_nodes); - node_instances_.emplace_back(InstanceRange{0, sampled_rows}); - if (this->IsExpandable(tree->sparsetree.back(), 0)) { + node_instances_.emplace_back(InstanceRange{0, sampled_rows, avg_row_count}); + if (this->IsExpandable(node_instances_.back(), 0)) { work_items_.emplace_back(NodeWorkItem{0, 0, node_instances_.back()}); } } @@ -79,16 +79,19 @@ class NodeQueue { } // This node is allowed to be expanded further (if its split gain is high enough) - bool IsExpandable(const NodeT& n, int depth) + bool IsExpandable(const InstanceRange& instance, int depth) { if (depth >= params.max_depth) return false; - if (int(n.InstanceCount()) < params.min_samples_split) return false; + + int nTrain = int(instance.count - instance.avg_count); + if (nTrain < params.min_samples_split_splitting) return false; + if (params.oob_honesty and instance.avg_count < params.min_samples_split_averaging) return false; if (params.max_leaves != -1 && tree->leaf_counter >= params.max_leaves) return false; return true; } - template - void Push(const std::vector& work_items, SplitT* h_splits) + template + void Push(const std::vector& work_items, SplitT* h_splits, DatasetT& dataset) { // Update node queue based on splits for (std::size_t i = 0; i < work_items.size(); i++) { @@ -96,12 +99,11 @@ class NodeQueue { auto item = work_items[i]; auto parent_range = node_instances_.at(item.idx); if (SplitNotValid( - split, params.min_impurity_decrease, params.min_samples_leaf, parent_range.count)) { + split, params.min_impurity_decrease, params.min_samples_leaf_splitting, params.min_samples_leaf_averaging, parent_range.count, params.oob_honesty)) { continue; } if (params.max_leaves != -1 && tree->leaf_counter >= params.max_leaves) break; - // parent tree->sparsetree.at(item.idx) = 
NodeT::CreateSplitNode(split.colid, split.quesval, @@ -111,10 +113,10 @@ class NodeQueue { tree->leaf_counter++; // left tree->sparsetree.emplace_back(NodeT::CreateLeafNode(split.nLeft)); - node_instances_.emplace_back(InstanceRange{parent_range.begin, std::size_t(split.nLeft)}); - + node_instances_.emplace_back(InstanceRange{parent_range.begin, std::size_t(split.nLeft), std::size_t(split.nLeftAveraging)}); + // Do not add a work item if this child is definitely a leaf - if (this->IsExpandable(tree->sparsetree.back(), item.depth + 1)) { + if (this->IsExpandable(node_instances_.back(), item.depth + 1)) { work_items_.emplace_back( NodeWorkItem{tree->sparsetree.size() - 1, item.depth + 1, node_instances_.back()}); } @@ -122,13 +124,13 @@ class NodeQueue { // right tree->sparsetree.emplace_back(NodeT::CreateLeafNode(parent_range.count - split.nLeft)); node_instances_.emplace_back( - InstanceRange{parent_range.begin + split.nLeft, parent_range.count - split.nLeft}); - + InstanceRange{parent_range.begin + split.nLeft, parent_range.count - split.nLeft, std::size_t(split.nRightAveraging)}); + // Do not add a work item if this child is definitely a leaf - if (this->IsExpandable(tree->sparsetree.back(), item.depth + 1)) { + if (this->IsExpandable(node_instances_.back(), item.depth + 1)) { work_items_.emplace_back( NodeWorkItem{tree->sparsetree.size() - 1, item.depth + 1, node_instances_.back()}); - } + } // update depth tree->depth_counter = max(tree->depth_counter, item.depth + 1); @@ -201,33 +203,21 @@ struct Builder { IdxT treeid, uint64_t seed, const DecisionTreeParams& p, - const DataT* data, - const LabelT* labels, - IdxT n_rows, - IdxT n_cols, - rmm::device_uvector* row_ids, - IdxT n_classes, + DatasetT& dataset, const QuantilesT& q) : handle(handle), builder_stream(s), treeid(treeid), seed(seed), params(p), - dataset{data, - labels, - n_rows, - n_cols, - int(row_ids->size()), - max(1, IdxT(params.max_features * n_cols)), - row_ids->data(), - n_classes}, + 
dataset(dataset), quantiles(q), d_buff(0, builder_stream) { max_blocks_dimx = 1 + params.max_batch_size + dataset.n_sampled_rows / TPB_DEFAULT; ASSERT(q.quantiles_array != nullptr && q.n_bins_array != nullptr, "Currently quantiles need to be computed before this call!"); - ASSERT(n_classes >= 1, "n_classes should be at least 1"); + ASSERT(dataset.num_outputs >= 1, "num outputs should be at least 1"); auto [device_workspace_size, host_workspace_size] = workspaceSize(); d_buff.resize(device_workspace_size, builder_stream); @@ -346,11 +336,11 @@ struct Builder { raft::common::nvtx::range fun_scope("Builder::train @builder.cuh [batched-levelalgo]"); MLCommon::TimerCPU timer; NodeQueue queue( - params, this->maxNodes(), dataset.n_sampled_rows, dataset.num_outputs); + params, this->maxNodes(), dataset.n_sampled_rows, dataset.n_avg_rows, dataset.num_outputs); while (queue.HasWork()) { auto work_items = queue.Pop(); auto [splits_host_ptr, splits_count] = doSplit(work_items); - queue.Push(work_items, splits_host_ptr); + queue.Push(work_items, splits_host_ptr, dataset); } auto tree = queue.GetTree(); this->SetLeafPredictions(tree, queue.GetInstanceRanges()); @@ -473,14 +463,13 @@ struct Builder { computeSplit(c, n_blocks_dimx, n_large_nodes); RAFT_CUDA_TRY(cudaPeekAtLastError()); } - // create child nodes (or make the current ones leaf) auto smem_size = 2 * sizeof(IdxT) * TPB_DEFAULT; raft::common::nvtx::push_range("nodeSplitKernel @builder.cuh [batched-levelalgo]"); - nodeSplitKernel + nodeSplitKernel <<>>(params.max_depth, - params.min_samples_leaf, - params.min_samples_split, + params.min_samples_leaf_splitting, + params.min_samples_leaf_averaging, params.max_leaves, params.min_impurity_decrease, dataset, @@ -528,14 +517,13 @@ struct Builder { int len_histograms = n_bins * n_classes * n_blocks_dimy * n_large_nodes; RAFT_CUDA_TRY(cudaMemsetAsync(histograms, 0, sizeof(BinT) * len_histograms, builder_stream)); // create the objective function object - ObjectiveT 
objective(dataset.num_outputs, params.min_samples_leaf); + ObjectiveT objective(dataset.num_outputs, params.min_samples_leaf_splitting, params.min_samples_leaf_averaging); // call the computeSplitKernel raft::common::nvtx::range kernel_scope("computeSplitKernel @builder.cuh [batched-levelalgo]"); computeSplitKernel <<>>(histograms, params.max_n_bins, params.max_depth, - params.min_samples_split, params.max_leaves, dataset, quantiles, @@ -564,7 +552,7 @@ struct Builder { rmm::device_uvector d_instance_ranges(max_batch_size, builder_stream); rmm::device_uvector d_leaves(max_batch_size * dataset.num_outputs, builder_stream); - ObjectiveT objective(dataset.num_outputs, params.min_samples_leaf); + ObjectiveT objective(dataset.num_outputs, params.min_samples_leaf_splitting, params.min_samples_leaf_averaging); for (std::size_t batch_begin = 0; batch_begin < tree->sparsetree.size(); batch_begin += max_batch_size) { std::size_t batch_end = min(batch_begin + max_batch_size, tree->sparsetree.size()); diff --git a/cpp/src/decisiontree/batched-levelalgo/dataset.h b/cpp/src/decisiontree/batched-levelalgo/dataset.h index aad9c8c3c8..b3a2f13276 100644 --- a/cpp/src/decisiontree/batched-levelalgo/dataset.h +++ b/cpp/src/decisiontree/batched-levelalgo/dataset.h @@ -35,6 +35,13 @@ struct Dataset { IdxT n_sampled_cols; /** indices of sampled rows */ IdxT* row_ids; + /** n_avg_rows */ + IdxT n_avg_rows; + + /* boolean mask indicating whether a given row id should be used for splitting. 
+ * Unused / nullptr if oob_honesty is disabled.*/ + bool* split_row_mask; + /** Number of classes or regression outputs*/ IdxT num_outputs; }; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels.cuh b/cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels.cuh index 7daf5341b7..0b524df4f8 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels.cuh @@ -32,6 +32,7 @@ namespace DT { struct InstanceRange { std::size_t begin; std::size_t count; + std::size_t avg_count; }; struct NodeWorkItem { @@ -57,11 +58,16 @@ struct WorkloadInfo { template HDI bool SplitNotValid(const SplitT& split, DataT min_impurity_decrease, - IdxT min_samples_leaf, - std::size_t num_rows) + IdxT min_samples_leaf_splitting, + IdxT min_samples_leaf_averaging, + std::size_t num_rows, + const bool oob_honesty) { - return split.best_metric_val <= min_impurity_decrease || split.nLeft < min_samples_leaf || - (IdxT(num_rows) - split.nLeft) < min_samples_leaf; + const int nLeftSplitting = split.nLeft - split.nLeftAveraging; + const int nRightSplitting = num_rows - split.nLeft - split.nRightAveraging; + return split.best_metric_val <= min_impurity_decrease || + nLeftSplitting < min_samples_leaf_splitting || nRightSplitting < min_samples_leaf_splitting || + (oob_honesty and (split.nLeftAveraging < min_samples_leaf_averaging || split.nRightAveraging < min_samples_leaf_averaging)); } /* Returns 'dataset' rounded up to a correctly-aligned pointer of type OutT* */ @@ -71,7 +77,7 @@ DI OutT* alignPointer(InT dataset) return reinterpret_cast(raft::alignTo(reinterpret_cast(dataset), sizeof(OutT))); } -template +template __global__ void nodeSplitKernel(const IdxT max_depth, const IdxT min_samples_leaf, const IdxT min_samples_split, @@ -347,7 +353,6 @@ template dataset, const Quantiles quantiles, diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels_impl.cuh 
b/cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels_impl.cuh index 6e6e526c78..775cbcaf5f 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels_impl.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels_impl.cuh @@ -36,15 +36,19 @@ static constexpr int TPB_DEFAULT = 128; * @note this should be called by only one block from all participating blocks * 'smem' should be at least of size `sizeof(IdxT) * TPB * 2` */ -template +template DI void partitionSamples(const Dataset& dataset, const Split& split, const NodeWorkItem& work_item, - char* smem) + char* smem, + int nLeftAvg, + int nRightAvg) { typedef cub::BlockScan BlockScanT; __shared__ typename BlockScanT::TempStorage temp1, temp2; - volatile auto* row_ids = reinterpret_cast(dataset.row_ids); + IdxT* row_ids = dataset.row_ids; + bool* split_row_mask = dataset.split_row_mask; + // for compaction size_t smemSize = sizeof(IdxT) * TPB; auto* lcomp = reinterpret_cast(smem); @@ -56,39 +60,64 @@ DI void partitionSamples(const Dataset& dataset, auto end = range_start + range_len; int lflag = 0, rflag = 0, llen = 0, rlen = 0, minlen = 0; auto tid = threadIdx.x; + int lidx, ridx; + while (loffset < part && roffset < end) { // find the samples in the left that belong to right and vice-versa + // Also scan to compute the locations for each 'misfit' in the two partitions auto loff = loffset + tid, roff = roffset + tid; - if (llen == minlen) lflag = loff < part ? col[row_ids[loff]] > split.quesval : 0; - if (rlen == minlen) rflag = roff < end ? col[row_ids[roff]] <= split.quesval : 0; - // scan to compute the locations for each 'misfit' in the two partitions - int lidx, ridx; - BlockScanT(temp1).ExclusiveSum(lflag, lidx, llen); - BlockScanT(temp2).ExclusiveSum(rflag, ridx, rlen); - __syncthreads(); + if (llen == minlen) { + lflag = loff < part ? 
col[row_ids[loff]] > split.quesval : 0; + BlockScanT(temp1).ExclusiveSum(lflag, lidx, llen); + } else { + lidx -= minlen; + llen -= minlen; + } + if (rlen == minlen) { + rflag = roff < end ? col[row_ids[roff]] <= split.quesval : 0; + BlockScanT(temp2).ExclusiveSum(rflag, ridx, rlen); + } else { + ridx -= minlen; + rlen -= minlen; + } + minlen = llen < rlen ? llen : rlen; // compaction to figure out the right locations to swap if (lflag) lcomp[lidx] = loff; if (rflag) rcomp[ridx] = roff; - __syncthreads(); // reset the appropriate flags for the longer of the two if (lidx < minlen) lflag = 0; if (ridx < minlen) rflag = 0; if (llen == minlen) loffset += TPB; if (rlen == minlen) roffset += TPB; + + __syncthreads(); + // swap the 'misfit's if (tid < minlen) { auto a = row_ids[lcomp[tid]]; auto b = row_ids[rcomp[tid]]; + bool c,d; + if (oob_honesty) { + // Also swap the associated training row mask flags + // This is the same dimension as row ids + c = split_row_mask[lcomp[tid]]; + d = split_row_mask[rcomp[tid]]; + } + row_ids[lcomp[tid]] = b; row_ids[rcomp[tid]] = a; + if (oob_honesty) { + split_row_mask[lcomp[tid]] = d; + split_row_mask[rcomp[tid]] = c; + } } - } + } } -template +template __global__ void nodeSplitKernel(const IdxT max_depth, - const IdxT min_samples_leaf, - const IdxT min_samples_split, + const IdxT min_samples_leaf_splitting, + const IdxT min_samples_leaf_averaging, const IdxT max_leaves, const DataT min_impurity_decrease, const Dataset dataset, @@ -99,10 +128,10 @@ __global__ void nodeSplitKernel(const IdxT max_depth, const auto work_item = work_items[blockIdx.x]; const auto split = splits[blockIdx.x]; if (SplitNotValid( - split, min_impurity_decrease, min_samples_leaf, IdxT(work_item.instances.count))) { + split, min_impurity_decrease, min_samples_leaf_splitting, min_samples_leaf_averaging, IdxT(work_item.instances.count), oob_honesty)) { return; } - partitionSamples(dataset, split, work_item, (char*)smem); + partitionSamples(dataset, split, 
work_item, (char*)smem, split.nLeftAveraging, split.nRightAveraging); } template @@ -124,9 +153,12 @@ __global__ void leafKernel(ObjectiveT objective, histogram[i] = BinT(); } __syncthreads(); + for (auto i = range.begin + tid; i < range.begin + range.count; i += blockDim.x) { auto label = dataset.labels[dataset.row_ids[i]]; - BinT::IncrementHistogram(histogram, 1, 0, label); + auto is_split_row = dataset.split_row_mask == nullptr ? true : dataset.split_row_mask[i]; + + BinT::IncrementHistogram(histogram, 1, 0, label, is_split_row, false /* is split op*/); } __syncthreads(); if (tid == 0) { @@ -174,7 +206,6 @@ template dataset, const Quantiles quantiles, @@ -202,6 +233,7 @@ __global__ void computeSplitKernel(BinT* histograms, IdxT offset_blockid = workload_info_cta.offset_blockid; IdxT num_blocks = workload_info_cta.num_blocks; + // obtaining the feature to test split on IdxT col; @@ -236,8 +268,11 @@ __global__ void computeSplitKernel(BinT* histograms, // Must be 64 bit - can easily grow larger than a 32 bit int std::size_t col_offset = std::size_t(col) * dataset.M; + for (auto i = range_start + tid; i < end; i += stride) { + // each thread works over a data point and strides to the next + auto is_split_row = dataset.split_row_mask == nullptr ? 
true : dataset.split_row_mask[i]; auto row = dataset.row_ids[i]; auto data = dataset.data[row + col_offset]; auto label = dataset.labels[row]; @@ -245,7 +280,7 @@ __global__ void computeSplitKernel(BinT* histograms, // `start` is lowest index such that data <= shared_quantiles[start] IdxT start = lower_bound(shared_quantiles, n_bins, data); // ++shared_histogram[start] - BinT::IncrementHistogram(shared_histogram, n_bins, start, label); + BinT::IncrementHistogram(shared_histogram, n_bins, start, label, is_split_row, true /*is split op*/); } // synchronizing above changes across block @@ -288,7 +323,7 @@ __global__ void computeSplitKernel(BinT* histograms, // calculate the best candidate bins (one for each thread in the block) in current feature and // corresponding information gain for splitting Split sp = - objective.Gain(shared_histogram, shared_quantiles, col, range_len, n_bins); + objective.Gain(shared_histogram, shared_quantiles, col, range_len, work_item.instances.avg_count, n_bins); __syncthreads(); @@ -299,7 +334,7 @@ __global__ void computeSplitKernel(BinT* histograms, } // partial template instantiation to avoid code-duplication -template __global__ void nodeSplitKernel<_DataT, _LabelT, _IdxT, TPB_DEFAULT>( +template __global__ void nodeSplitKernel<_DataT, _LabelT, _IdxT, TPB_DEFAULT, false>( const _IdxT max_depth, const _IdxT min_samples_leaf, const _IdxT min_samples_split, @@ -308,6 +343,15 @@ template __global__ void nodeSplitKernel<_DataT, _LabelT, _IdxT, TPB_DEFAULT>( const Dataset<_DataT, _LabelT, _IdxT> dataset, const NodeWorkItem* work_items, const Split<_DataT, _IdxT>* splits); +template __global__ void nodeSplitKernel<_DataT, _LabelT, _IdxT, TPB_DEFAULT, true>( + const _IdxT max_depth, + const _IdxT min_samples_leaf, + const _IdxT min_samples_split, + const _IdxT max_leaves, + const _DataT min_impurity_decrease, + const Dataset<_DataT, _LabelT, _IdxT> dataset, + const NodeWorkItem* work_items, + const Split<_DataT, _IdxT>* splits); template 
__global__ void leafKernel<_DatasetT, _NodeT, _ObjectiveT, _DataT>( _ObjectiveT objective, @@ -320,7 +364,6 @@ computeSplitKernel<_DataT, _LabelT, _IdxT, TPB_DEFAULT, _ObjectiveT, _BinT>( _BinT* histograms, _IdxT n_bins, _IdxT max_depth, - _IdxT min_samples_split, _IdxT max_leaves, const Dataset<_DataT, _LabelT, _IdxT> dataset, const Quantiles<_DataT, _IdxT> quantiles, diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-double-honest.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-double-honest.cu new file mode 100644 index 0000000000..0356a31232 --- /dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-double-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = double; +using _LabelT = int; +using _IdxT = int; +using _ObjectiveT = EntropyObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestCountBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-double.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-double.cu index 1b37a03503..ec59dda0ef 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-double.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-double.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = double; using _LabelT = int; using _IdxT = int; -using _ObjectiveT = EntropyObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = EntropyObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = CountBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-float-honest.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-float-honest.cu new file mode 100644 index 0000000000..31bf7393b7 --- /dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-float-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = float; +using _LabelT = int; +using _IdxT = int; +using _ObjectiveT = EntropyObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestCountBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-float.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-float.cu index 618e03fef5..562f6345c1 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-float.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/entropy-float.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = float; using _LabelT = int; using _IdxT = int; -using _ObjectiveT = EntropyObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = EntropyObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = CountBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-double-honest.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-double-honest.cu new file mode 100644 index 0000000000..cdc793d60a --- /dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-double-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = double; +using _LabelT = double; +using _IdxT = int; +using _ObjectiveT = GammaObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestAggregateBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-double.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-double.cu index 60c35bb9b4..0a69ebfd6a 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-double.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-double.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = double; using _LabelT = double; using _IdxT = int; -using _ObjectiveT = GammaObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = GammaObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = AggregateBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-float-honest.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-float-honest.cu new file mode 100644 index 0000000000..9a0394661e --- /dev/null +++ 
b/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-float-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = float; +using _LabelT = float; +using _IdxT = int; +using _ObjectiveT = GammaObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestAggregateBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-float.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-float.cu index e3c6630ede..d7adace0e8 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-float.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/gamma-float.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = float; using _LabelT = float; using _IdxT = int; -using _ObjectiveT = GammaObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = GammaObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = AggregateBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/gini-double-honest.cu 
b/cpp/src/decisiontree/batched-levelalgo/kernels/gini-double-honest.cu new file mode 100644 index 0000000000..40db2e4ee7 --- /dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/gini-double-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = double; +using _LabelT = int; +using _IdxT = int; +using _ObjectiveT = GiniObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestCountBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/gini-double.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/gini-double.cu index 3666e48f7d..344013b6e2 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/gini-double.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/gini-double.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = double; using _LabelT = int; using _IdxT = int; -using _ObjectiveT = GiniObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = GiniObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = CountBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = 
SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/gini-float-honest.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/gini-float-honest.cu new file mode 100644 index 0000000000..0233ea7acf --- /dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/gini-float-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = float; +using _LabelT = int; +using _IdxT = int; +using _ObjectiveT = GiniObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestCountBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/gini-float.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/gini-float.cu index 0abd98756f..a50c70dd92 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/gini-float.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/gini-float.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = float; using _LabelT = int; using _IdxT = int; -using _ObjectiveT = GiniObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = GiniObjectiveFunction<_DataT, _LabelT, _IdxT, false 
/*oob_honesty*/>; using _BinT = CountBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-double-honest.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-double-honest.cu new file mode 100644 index 0000000000..1c0014e462 --- /dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-double-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = double; +using _LabelT = double; +using _IdxT = int; +using _ObjectiveT = InverseGaussianObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestAggregateBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-double.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-double.cu index a89f27a265..ecad5ed7a6 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-double.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-double.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = double; using _LabelT = double; using _IdxT = int; -using _ObjectiveT = InverseGaussianObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = InverseGaussianObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = AggregateBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-float-honest.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-float-honest.cu new file mode 100644 index 0000000000..7c329f498a --- /dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-float-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = float; +using _LabelT = float; +using _IdxT = int; +using _ObjectiveT = InverseGaussianObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestAggregateBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-float.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-float.cu index b2582ef14a..c3043e722d 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-float.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/inverse_gaussian-float.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = float; using _LabelT = float; using _IdxT = int; -using _ObjectiveT = InverseGaussianObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = InverseGaussianObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = AggregateBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/mse-double-honest.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/mse-double-honest.cu new file mode 100644 index 0000000000..567d852d80 --- /dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/mse-double-honest.cu @@ -0,0 +1,34 @@ +/* + 
* Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = double; +using _LabelT = double; +using _IdxT = int; +using _ObjectiveT = MSEObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestAggregateBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/mse-double.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/mse-double.cu index b9b5ec461b..51043409d0 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/mse-double.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/mse-double.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = double; using _LabelT = double; using _IdxT = int; -using _ObjectiveT = MSEObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = MSEObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = AggregateBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/mse-float-honest.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/mse-float-honest.cu new file mode 100644 index 0000000000..c1d9469be2 --- 
/dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/mse-float-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = float; +using _LabelT = float; +using _IdxT = int; +using _ObjectiveT = MSEObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestAggregateBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/mse-float.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/mse-float.cu index ca1f36425d..27a43a1fab 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/mse-float.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/mse-float.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = float; using _LabelT = float; using _IdxT = int; -using _ObjectiveT = MSEObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = MSEObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = AggregateBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-double-honest.cu 
b/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-double-honest.cu new file mode 100644 index 0000000000..2072e0ebcb --- /dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-double-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = double; +using _LabelT = double; +using _IdxT = int; +using _ObjectiveT = PoissonObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestAggregateBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-double.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-double.cu index fd15ceffda..5041f574ee 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-double.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-double.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = double; using _LabelT = double; using _IdxT = int; -using _ObjectiveT = PoissonObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = PoissonObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = AggregateBin; using _DatasetT = Dataset<_DataT, _LabelT, 
_IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-float-honest.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-float-honest.cu new file mode 100644 index 0000000000..c1147e5735 --- /dev/null +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-float-honest.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../bins.cuh" +#include "../objectives.cuh" + +namespace ML { +namespace DT { +using _DataT = float; +using _LabelT = float; +using _IdxT = int; +using _ObjectiveT = PoissonObjectiveFunction<_DataT, _LabelT, _IdxT, true /*oob_honesty*/>; +using _BinT = HonestAggregateBin; +using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; +using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; +} // namespace DT +} // namespace ML + +#include "builder_kernels_impl.cuh" diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-float.cu b/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-float.cu index db3f9114ce..d5bb073c76 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-float.cu +++ b/cpp/src/decisiontree/batched-levelalgo/kernels/poisson-float.cu @@ -24,7 +24,7 @@ namespace DT { using _DataT = float; using _LabelT = float; using _IdxT = int; -using _ObjectiveT = PoissonObjectiveFunction<_DataT, _LabelT, _IdxT>; +using _ObjectiveT = 
PoissonObjectiveFunction<_DataT, _LabelT, _IdxT, false /*oob_honesty*/>; using _BinT = AggregateBin; using _DatasetT = Dataset<_DataT, _LabelT, _IdxT>; using _NodeT = SparseTreeNode<_DataT, _LabelT, _IdxT>; diff --git a/cpp/src/decisiontree/batched-levelalgo/objectives.cuh b/cpp/src/decisiontree/batched-levelalgo/objectives.cuh index 59b44b3619..d5f2797df6 100644 --- a/cpp/src/decisiontree/batched-levelalgo/objectives.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/objectives.cuh @@ -24,40 +24,50 @@ namespace ML { namespace DT { -template +template class GiniObjectiveFunction { public: using DataT = DataT_; using LabelT = LabelT_; using IdxT = IdxT_; + static const bool oob_honesty = oob_honesty_; private: IdxT nclasses; - IdxT min_samples_leaf; + IdxT min_samples_leaf_splitting; + IdxT min_samples_leaf_averaging; public: - using BinT = CountBin; - GiniObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf) - : nclasses(nclasses), min_samples_leaf(min_samples_leaf) - { - } + using BinT = std::conditional_t; + GiniObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf_splitting, IdxT min_samples_leaf_averaging) + : nclasses(nclasses), + min_samples_leaf_splitting(min_samples_leaf_splitting), + min_samples_leaf_averaging(min_samples_leaf_averaging) + {} DI IdxT NumClasses() const { return nclasses; } /** * @brief compute the gini impurity reduction for each split */ - HDI DataT GainPerSplit(BinT* hist, IdxT i, IdxT n_bins, IdxT len, IdxT nLeft) + HDI DataT GainPerSplit( + BinT* hist, IdxT i, IdxT n_bins, IdxT nLeftSplitting, + IdxT nLeftAveraging, IdxT trainingLen, IdxT nRightAveraging) { - IdxT nRight = len - nLeft; - constexpr DataT One = DataT(1.0); - auto invLen = One / len; - auto invLeft = One / nLeft; - auto invRight = One / nRight; - auto gain = DataT(0.0); + constexpr DataT One = DataT(1.0); + auto invLen = One / trainingLen; + auto invLeft = One / nLeftSplitting; + auto nRightSplitting = trainingLen - nLeftSplitting; + auto invRight = One / 
nRightSplitting; + auto gain = DataT(0.0); // if there aren't enough samples in this split, don't bother! - if (nLeft < min_samples_leaf || nRight < min_samples_leaf) + if constexpr (oob_honesty) { + if (nLeftAveraging < min_samples_leaf_averaging || nRightAveraging < min_samples_leaf_averaging) + return -std::numeric_limits::max(); + } + + if (nLeftSplitting < min_samples_leaf_splitting || nRightSplitting < min_samples_leaf_splitting) return -std::numeric_limits::max(); for (IdxT j = 0; j < nclasses; ++j) { @@ -80,100 +90,147 @@ class GiniObjectiveFunction { return gain; } - DI Split Gain(BinT* shist, DataT* squantiles, IdxT col, IdxT len, IdxT n_bins) + DI Split Gain(BinT* shist, DataT* squantiles, IdxT col, IdxT len, IdxT avg_len, IdxT n_bins) { Split sp; for (IdxT i = threadIdx.x; i < n_bins; i += blockDim.x) { - IdxT nLeft = 0; + IdxT nLeftSplitting = 0; + IdxT nLeftAveraging = 0; + IdxT nRightAveraging = 0; for (IdxT j = 0; j < nclasses; ++j) { - nLeft += shist[n_bins * j + i].x; + nLeftSplitting += shist[n_bins * j + i].x; + if constexpr (oob_honesty) { + nLeftAveraging += shist[n_bins * j + i].x_averaging; + } } - sp.update({squantiles[i], col, GainPerSplit(shist, i, n_bins, len, nLeft), nLeft}); + + if constexpr (oob_honesty) { + nRightAveraging = avg_len - nLeftAveraging; + } + + sp.update({ + squantiles[i], col, + GainPerSplit(shist, i, n_bins, nLeftSplitting, nLeftAveraging, len - avg_len, nRightAveraging), + nLeftAveraging + nLeftSplitting, nLeftAveraging, nRightAveraging}); } return sp; } + static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out) { // Output probability int total = 0; for (int i = 0; i < nclasses; i++) { - total += shist[i].x; + if constexpr (oob_honesty) { + total += shist[i].x_averaging; + } else { + total += shist[i].x; + } } for (int i = 0; i < nclasses; i++) { - out[i] = DataT(shist[i].x) / total; + if constexpr (oob_honesty) { + out[i] = DataT(shist[i].x_averaging) / total; + } else { + out[i] = 
DataT(shist[i].x) / total; + } } } }; -template +template class EntropyObjectiveFunction { public: using DataT = DataT_; using LabelT = LabelT_; using IdxT = IdxT_; - + static const bool oob_honesty = oob_honesty_; + private: IdxT nclasses; - IdxT min_samples_leaf; + IdxT min_samples_leaf_splitting; + IdxT min_samples_leaf_averaging; public: - using BinT = CountBin; - EntropyObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf) - : nclasses(nclasses), min_samples_leaf(min_samples_leaf) - { - } + using BinT = std::conditional_t; + + EntropyObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf_splitting, IdxT min_samples_leaf_averaging) + : nclasses(nclasses), + min_samples_leaf_splitting(min_samples_leaf_splitting), + min_samples_leaf_averaging(min_samples_leaf_averaging) + {} + DI IdxT NumClasses() const { return nclasses; } /** * @brief compute the Entropy (or information gain) for each split */ - HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT n_bins, IdxT len, IdxT nLeft) + HDI DataT GainPerSplit( + BinT* hist, IdxT i, IdxT n_bins, IdxT nLeftSplitting, + IdxT nLeftAveraging, IdxT lenTraining, IdxT nRightAveraging) { - IdxT nRight{len - nLeft}; - auto gain{DataT(0.0)}; + const auto nRightSplitting = lenTraining - nLeftSplitting; + // if there aren't enough samples in this split, don't bother! 
- if (nLeft < min_samples_leaf || nRight < min_samples_leaf) { - return -std::numeric_limits::max(); - } else { - auto invLeft{DataT(1.0) / nLeft}; - auto invRight{DataT(1.0) / nRight}; - auto invLen{DataT(1.0) / len}; - for (IdxT c = 0; c < nclasses; ++c) { - int val_i = 0; - auto lval_i = hist[n_bins * c + i].x; - if (lval_i != 0) { - auto lval = DataT(lval_i); - gain += raft::myLog(lval * invLeft) / raft::myLog(DataT(2)) * lval * invLen; - } + if constexpr (oob_honesty) { + if (nLeftAveraging < min_samples_leaf_averaging || nRightAveraging < min_samples_leaf_averaging) + return -std::numeric_limits::max(); + } - val_i += lval_i; - auto total_sum = hist[n_bins * c + n_bins - 1].x; - auto rval_i = total_sum - lval_i; - if (rval_i != 0) { - auto rval = DataT(rval_i); - gain += raft::myLog(rval * invRight) / raft::myLog(DataT(2)) * rval * invLen; - } + if (nLeftSplitting < min_samples_leaf_splitting || nRightSplitting < min_samples_leaf_splitting) + return -std::numeric_limits::max(); + + auto gain{DataT(0.0)}; + auto invLeft{DataT(1.0) / nLeftSplitting}; + auto invRight{DataT(1.0) / nRightSplitting}; + auto invLen{DataT(1.0) / lenTraining}; + for (IdxT c = 0; c < nclasses; ++c) { + int val_i = 0; + auto lval_i = hist[n_bins * c + i].x; + if (lval_i != 0) { + auto lval = DataT(lval_i); + gain += raft::myLog(lval * invLeft) / raft::myLog(DataT(2)) * lval * invLen; + } - val_i += rval_i; - if (val_i != 0) { - auto val = DataT(val_i) * invLen; - gain -= val * raft::myLog(val) / raft::myLog(DataT(2)); - } + val_i += lval_i; + auto total_sum = hist[n_bins * c + n_bins - 1].x; + auto rval_i = total_sum - lval_i; + if (rval_i != 0) { + auto rval = DataT(rval_i); + gain += raft::myLog(rval * invRight) / raft::myLog(DataT(2)) * rval * invLen; } - return gain; + val_i += rval_i; + if (val_i != 0) { + auto val = DataT(val_i) * invLen; + gain -= val * raft::myLog(val) / raft::myLog(DataT(2)); + } } + + return gain; } - DI Split Gain(BinT* scdf_labels, DataT* squantiles, IdxT 
col, IdxT len, IdxT n_bins) - { + DI Split Gain(BinT* shist, DataT* squantiles, IdxT col, IdxT len, IdxT avg_len, IdxT n_bins) + { Split sp; for (IdxT i = threadIdx.x; i < n_bins; i += blockDim.x) { - IdxT nLeft = 0; + IdxT nLeftSplitting = 0; + IdxT nLeftAveraging = 0; + IdxT nRightAveraging = 0; for (IdxT j = 0; j < nclasses; ++j) { - nLeft += scdf_labels[n_bins * j + i].x; + nLeftSplitting += shist[n_bins * j + i].x; + if constexpr (oob_honesty) { + nLeftAveraging += shist[n_bins * j + i].x_averaging; + } } - sp.update({squantiles[i], col, GainPerSplit(scdf_labels, i, n_bins, len, nLeft), nLeft}); + + if constexpr (oob_honesty) { + nRightAveraging = avg_len - nLeftAveraging; + } + + sp.update({ + squantiles[i], col, + GainPerSplit(shist, i, n_bins, nLeftSplitting, nLeftAveraging, len - avg_len, nRightAveraging), + nLeftAveraging + nLeftSplitting, nLeftAveraging, nRightAveraging}); } return sp; } @@ -182,30 +239,42 @@ class EntropyObjectiveFunction { // Output probability int total = 0; for (int i = 0; i < nclasses; i++) { - total += shist[i].x; + if constexpr (oob_honesty) { + total += shist[i].x_averaging; + } else { + total += shist[i].x; + } } for (int i = 0; i < nclasses; i++) { - out[i] = DataT(shist[i].x) / total; + if constexpr (oob_honesty) { + out[i] = DataT(shist[i].x_averaging) / total; + } else { + out[i] = DataT(shist[i].x) / total; + + } } } }; -template +template class MSEObjectiveFunction { public: using DataT = DataT_; using LabelT = LabelT_; using IdxT = IdxT_; - using BinT = AggregateBin; + static const bool oob_honesty = oob_honesty_; + + using BinT = std::conditional_t; private: - IdxT min_samples_leaf; + IdxT min_samples_leaf_splitting; + IdxT min_samples_leaf_averaging; public: - HDI MSEObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf) - : min_samples_leaf(min_samples_leaf) - { - } + HDI MSEObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf_splitting, IdxT min_samples_leaf_averaging) + : 
min_samples_leaf_splitting(min_samples_leaf_splitting), + min_samples_leaf_averaging(min_samples_leaf_averaging) + {} /** * @brief compute the Mean squared error impurity reduction (or purity gain) for each split @@ -220,34 +289,51 @@ class MSEObjectiveFunction { * and is mathematically equivalent to the respective differences of * mean-squared errors. */ - HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT n_bins, IdxT len, IdxT nLeft) const + HDI DataT GainPerSplit( + BinT const* hist, IdxT i, IdxT n_bins, IdxT nLeftSplitting, + IdxT nLeftAveraging, IdxT trainingLen, IdxT nRightAveraging) const { auto gain{DataT(0)}; - IdxT nRight{len - nLeft}; - auto invLen = DataT(1.0) / len; - // if there aren't enough samples in this split, don't bother! - if (nLeft < min_samples_leaf || nRight < min_samples_leaf) { - return -std::numeric_limits::max(); - } else { - auto label_sum = hist[n_bins - 1].label_sum; - DataT parent_obj = -label_sum * label_sum * invLen; - DataT left_obj = -(hist[i].label_sum * hist[i].label_sum) / nLeft; - DataT right_label_sum = hist[i].label_sum - label_sum; - DataT right_obj = -(right_label_sum * right_label_sum) / nRight; - gain = parent_obj - (left_obj + right_obj); - gain *= DataT(0.5) * invLen; - - return gain; + IdxT nRightSplitting{trainingLen - nLeftSplitting}; + auto invLen = DataT(1.0) / trainingLen; + + if constexpr (oob_honesty) { + if (nLeftAveraging < min_samples_leaf_averaging || nRightAveraging < min_samples_leaf_averaging) + return -std::numeric_limits::max(); } + + if (nLeftSplitting < min_samples_leaf_splitting || nRightSplitting < min_samples_leaf_splitting) + return -std::numeric_limits::max(); + + auto label_sum = hist[n_bins - 1].label_sum; + DataT parent_obj = -label_sum * label_sum * invLen; + DataT left_obj = -(hist[i].label_sum * hist[i].label_sum) / nLeftSplitting; + DataT right_label_sum = hist[i].label_sum - label_sum; + DataT right_obj = -(right_label_sum * right_label_sum) / nRightSplitting; + gain = parent_obj - 
(left_obj + right_obj); + gain *= DataT(0.5) * invLen; + + return gain; } DI Split Gain( - BinT const* shist, DataT const* squantiles, IdxT col, IdxT len, IdxT n_bins) const + BinT const* shist, DataT const* squantiles, IdxT col, IdxT len, IdxT avg_len, IdxT n_bins) const { Split sp; for (IdxT i = threadIdx.x; i < n_bins; i += blockDim.x) { - auto nLeft = shist[i].count; - sp.update({squantiles[i], col, GainPerSplit(shist, i, n_bins, len, nLeft), nLeft}); + auto nLeftSplitting = shist[i].count; + + int nLeftAveraging = 0; + int nRightAveraging = 0; + if constexpr (oob_honesty) { + nLeftAveraging = shist[i].count_averaging; + nRightAveraging = avg_len - nLeftAveraging; + } + + sp.update({ + squantiles[i], col, + GainPerSplit(shist, i, n_bins, nLeftSplitting, nLeftAveraging, len - avg_len, nRightAveraging), + nLeftSplitting + nLeftAveraging, nLeftAveraging, nRightAveraging}); } return sp; } @@ -257,29 +343,36 @@ class MSEObjectiveFunction { static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out) { for (int i = 0; i < nclasses; i++) { - out[i] = shist[i].label_sum / shist[i].count; + if constexpr (oob_honesty) { + out[i] = shist[i].label_sum / shist[i].count_averaging; + } else { + out[i] = shist[i].label_sum / shist[i].count; + } } } }; -template +template class PoissonObjectiveFunction { public: using DataT = DataT_; using LabelT = LabelT_; using IdxT = IdxT_; - using BinT = AggregateBin; + static const bool oob_honesty = oob_honesty_; + + using BinT = std::conditional_t; private: - IdxT min_samples_leaf; + IdxT min_samples_leaf_splitting; + IdxT min_samples_leaf_averaging; public: static constexpr auto eps_ = 10 * std::numeric_limits::epsilon(); - HDI PoissonObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf) - : min_samples_leaf(min_samples_leaf) - { - } + HDI PoissonObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf_splitting, IdxT min_samples_leaf_averaging) + : min_samples_leaf_splitting(min_samples_leaf_splitting), + 
min_samples_leaf_averaging(min_samples_leaf_averaging) + {} /** * @brief compute the poisson impurity reduction (or purity gain) for each split @@ -294,14 +387,21 @@ class PoissonObjectiveFunction { * and is mathematically equivalent to the respective differences of * poisson half deviances. */ - HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT n_bins, IdxT len, IdxT nLeft) const + HDI DataT GainPerSplit( + BinT const* hist, IdxT i, IdxT n_bins, IdxT nLeftSplitting, + IdxT nLeftAveraging, IdxT trainingLen, IdxT nRightAveraging) const { // get the lens' - IdxT nRight = len - nLeft; - auto invLen = DataT(1) / len; + IdxT nRightSplitting = trainingLen - nLeftSplitting; + auto invLen = DataT(1) / trainingLen; // if there aren't enough samples in this split, don't bother! - if (nLeft < min_samples_leaf || nRight < min_samples_leaf) + if constexpr (oob_honesty) { + if (nLeftAveraging < min_samples_leaf_averaging || nRightAveraging < min_samples_leaf_averaging) + return -std::numeric_limits::max(); + } + + if (nLeftSplitting < min_samples_leaf_splitting || nRightSplitting < min_samples_leaf_splitting) return -std::numeric_limits::max(); auto label_sum = hist[n_bins - 1].label_sum; @@ -314,8 +414,8 @@ class PoissonObjectiveFunction { // compute the gain to be DataT parent_obj = -label_sum * raft::myLog(label_sum * invLen); - DataT left_obj = -left_label_sum * raft::myLog(left_label_sum / nLeft); - DataT right_obj = -right_label_sum * raft::myLog(right_label_sum / nRight); + DataT left_obj = -left_label_sum * raft::myLog(left_label_sum / nLeftSplitting); + DataT right_obj = -right_label_sum * raft::myLog(right_label_sum / nRightSplitting); DataT gain = parent_obj - (left_obj + right_obj); gain = gain * invLen; @@ -323,12 +423,23 @@ class PoissonObjectiveFunction { } DI Split Gain( - BinT const* shist, DataT const* squantiles, IdxT col, IdxT len, IdxT n_bins) const + BinT const* shist, DataT const* squantiles, IdxT col, IdxT len, IdxT avg_len, IdxT n_bins) const { Split 
sp; for (IdxT i = threadIdx.x; i < n_bins; i += blockDim.x) { - auto nLeft = shist[i].count; - sp.update({squantiles[i], col, GainPerSplit(shist, i, n_bins, len, nLeft), nLeft}); + auto nLeftSplitting = shist[i].count; + + int nLeftAveraging = 0; + int nRightAveraging = 0; + if constexpr (oob_honesty) { + nLeftAveraging = shist[i].count_averaging; + nRightAveraging = avg_len - nLeftAveraging; + } + + sp.update({ + squantiles[i], col, + GainPerSplit(shist, i, n_bins, nLeftSplitting, nLeftAveraging, len - avg_len, nRightAveraging), + nLeftSplitting + nLeftAveraging, nLeftAveraging, nRightAveraging}); } return sp; } @@ -338,28 +449,35 @@ class PoissonObjectiveFunction { static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out) { for (int i = 0; i < nclasses; i++) { - out[i] = shist[i].label_sum / shist[i].count; + if constexpr (oob_honesty) { + out[i] = shist[i].label_sum / shist[i].count_averaging; + } else { + out[i] = shist[i].label_sum / shist[i].count; + } } } }; -template +template class GammaObjectiveFunction { public: using DataT = DataT_; using LabelT = LabelT_; using IdxT = IdxT_; - using BinT = AggregateBin; + static const bool oob_honesty = oob_honesty_; + + + using BinT = std::conditional_t; static constexpr auto eps_ = 10 * std::numeric_limits::epsilon(); - private: - IdxT min_samples_leaf; + IdxT min_samples_leaf_splitting; + IdxT min_samples_leaf_averaging; public: - HDI GammaObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf) - : min_samples_leaf{min_samples_leaf} - { - } + HDI GammaObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf_splitting, IdxT min_samples_leaf_averaging) + : min_samples_leaf_splitting{min_samples_leaf_splitting}, + min_samples_leaf_averaging{min_samples_leaf_averaging} + {} /** * @brief compute the gamma impurity reduction (or purity gain) for each split @@ -374,15 +492,22 @@ class GammaObjectiveFunction { * and is mathematically equivalent to the respective differences of * gamma half deviances. 
*/ - HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT n_bins, IdxT len, IdxT nLeft) const + HDI DataT GainPerSplit( + BinT const* hist, IdxT i, IdxT n_bins, IdxT nLeftSplitting, + IdxT nLeftAveraging, IdxT trainingLen, IdxT nRightAveraging) const { - IdxT nRight = len - nLeft; - auto invLen = DataT(1) / len; + IdxT nRightSplitting = trainingLen - nLeftSplitting; // if there aren't enough samples in this split, don't bother! - if (nLeft < min_samples_leaf || nRight < min_samples_leaf) + if constexpr (oob_honesty) { + if (nLeftAveraging < min_samples_leaf_averaging || nRightAveraging < min_samples_leaf_averaging) + return -std::numeric_limits::max(); + } + + if (nLeftSplitting < min_samples_leaf_splitting || nRightSplitting < min_samples_leaf_splitting) return -std::numeric_limits::max(); + auto invLen = DataT(1) / trainingLen; DataT label_sum = hist[n_bins - 1].label_sum; DataT left_label_sum = (hist[i].label_sum); DataT right_label_sum = (hist[n_bins - 1].label_sum - hist[i].label_sum); @@ -392,9 +517,9 @@ class GammaObjectiveFunction { return -std::numeric_limits::max(); // compute the gain to be - DataT parent_obj = len * raft::myLog(label_sum * invLen); - DataT left_obj = nLeft * raft::myLog(left_label_sum / nLeft); - DataT right_obj = nRight * raft::myLog(right_label_sum / nRight); + DataT parent_obj = trainingLen * raft::myLog(label_sum * invLen); + DataT left_obj = nLeftSplitting * raft::myLog(left_label_sum / nLeftSplitting); + DataT right_obj = nRightSplitting * raft::myLog(right_label_sum / nRightSplitting); DataT gain = parent_obj - (left_obj + right_obj); gain = gain * invLen; @@ -402,42 +527,62 @@ class GammaObjectiveFunction { } DI Split Gain( - BinT const* shist, DataT const* squantiles, IdxT col, IdxT len, IdxT n_bins) const + BinT const* shist, DataT const* squantiles, IdxT col, IdxT len, IdxT avg_len, IdxT n_bins) const { Split sp; for (IdxT i = threadIdx.x; i < n_bins; i += blockDim.x) { - auto nLeft = shist[i].count; - 
sp.update({squantiles[i], col, GainPerSplit(shist, i, n_bins, len, nLeft), nLeft}); + auto nLeftSplitting = shist[i].count; + + int nLeftAveraging = 0; + int nRightAveraging = 0; + if constexpr (oob_honesty) { + nLeftAveraging = shist[i].count_averaging; + nRightAveraging = avg_len - nLeftAveraging; + } + + sp.update({ + squantiles[i], col, + GainPerSplit(shist, i, n_bins, nLeftSplitting, nLeftAveraging, len - avg_len, nRightAveraging), + nLeftSplitting + nLeftAveraging, nLeftAveraging, nRightAveraging}); } return sp; } + DI IdxT NumClasses() const { return 1; } static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out) { for (int i = 0; i < nclasses; i++) { - out[i] = shist[i].label_sum / shist[i].count; + if constexpr (oob_honesty) { + out[i] = shist[i].label_sum / shist[i].count_averaging; + } else { + out[i] = shist[i].label_sum / shist[i].count; + } } } }; -template +template class InverseGaussianObjectiveFunction { public: using DataT = DataT_; using LabelT = LabelT_; using IdxT = IdxT_; - using BinT = AggregateBin; + + static const bool oob_honesty = oob_honesty_; + + using BinT = std::conditional_t; + static constexpr auto eps_ = 10 * std::numeric_limits::epsilon(); - private: - IdxT min_samples_leaf; + IdxT min_samples_leaf_splitting; + IdxT min_samples_leaf_averaging; public: - HDI InverseGaussianObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf) - : min_samples_leaf{min_samples_leaf} - { - } + HDI InverseGaussianObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf_splitting, IdxT min_samples_leaf_averaging) + : min_samples_leaf_splitting{min_samples_leaf_splitting}, + min_samples_leaf_averaging{min_samples_leaf_averaging} + {} /** * @brief compute the inverse gaussian impurity reduction (or purity gain) for each split @@ -452,13 +597,20 @@ class InverseGaussianObjectiveFunction { * and is mathematically equivalent to the respective differences of * inverse gaussian deviances. 
*/ - HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT n_bins, IdxT len, IdxT nLeft) const + HDI DataT GainPerSplit( + const BinT* hist, IdxT i, IdxT n_bins, IdxT nLeftSplitting, + IdxT nLeftAveraging, IdxT trainingLen, IdxT nRightAveraging) const { // get the lens' - IdxT nRight = len - nLeft; - + IdxT nRightSplitting = trainingLen - nLeftSplitting; + // if there aren't enough samples in this split, don't bother! - if (nLeft < min_samples_leaf || nRight < min_samples_leaf) + if constexpr (oob_honesty) { + if (nLeftAveraging < min_samples_leaf_averaging || nRightAveraging < min_samples_leaf_averaging) + return -std::numeric_limits::max(); + } + + if (nLeftSplitting < min_samples_leaf_splitting || nRightSplitting < min_samples_leaf_splitting) return -std::numeric_limits::max(); auto label_sum = hist[n_bins - 1].label_sum; @@ -470,31 +622,47 @@ class InverseGaussianObjectiveFunction { return -std::numeric_limits::max(); // compute the gain to be - DataT parent_obj = -DataT(len) * DataT(len) / label_sum; - DataT left_obj = -DataT(nLeft) * DataT(nLeft) / left_label_sum; - DataT right_obj = -DataT(nRight) * DataT(nRight) / right_label_sum; + DataT parent_obj = -DataT(trainingLen) * DataT(trainingLen) / label_sum; + DataT left_obj = -DataT(nLeftSplitting) * DataT(nLeftSplitting) / left_label_sum; + DataT right_obj = -DataT(nRightSplitting) * DataT(nRightSplitting) / right_label_sum; DataT gain = parent_obj - (left_obj + right_obj); - gain = gain / (2 * len); + gain = gain / (2 * trainingLen); return gain; } DI Split Gain( - BinT const* shist, DataT const* squantiles, IdxT col, IdxT len, IdxT n_bins) const + BinT const* shist, DataT const* squantiles, IdxT col, IdxT len, IdxT avg_len, IdxT n_bins) const { Split sp; for (IdxT i = threadIdx.x; i < n_bins; i += blockDim.x) { - auto nLeft = shist[i].count; - sp.update({squantiles[i], col, GainPerSplit(shist, i, n_bins, len, nLeft), nLeft}); + auto nLeftSplitting = shist[i].count; + + int nLeftAveraging = 0; + int 
nRightAveraging = 0; + if constexpr (oob_honesty) { + nLeftAveraging = shist[i].count_averaging; + nRightAveraging = avg_len - nLeftAveraging; + } + + sp.update({ + squantiles[i], col, + GainPerSplit(shist, i, n_bins, nLeftSplitting, nLeftAveraging, len - avg_len, nRightAveraging), + nLeftSplitting + nLeftAveraging, nLeftAveraging, nRightAveraging}); } return sp; } + DI IdxT NumClasses() const { return 1; } static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out) { for (int i = 0; i < nclasses; i++) { - out[i] = shist[i].label_sum / shist[i].count; + if constexpr (oob_honesty) { + out[i] = shist[i].label_sum / shist[i].count_averaging; + } else { + out[i] = shist[i].label_sum / shist[i].count; + } } } }; diff --git a/cpp/src/decisiontree/batched-levelalgo/split.cuh b/cpp/src/decisiontree/batched-levelalgo/split.cuh index bb4bd5408a..ac71cadf80 100644 --- a/cpp/src/decisiontree/batched-levelalgo/split.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/split.cuh @@ -18,6 +18,7 @@ #include #include +#include namespace ML { namespace DT { @@ -43,8 +44,15 @@ struct Split { /** number of samples in the left child */ int nLeft; - DI Split(DataT quesval, IdxT colid, DataT best_metric_val, IdxT nLeft) - : quesval(quesval), colid(colid), best_metric_val(best_metric_val), nLeft(nLeft) + /** number of training samples in the left child */ + int nLeftAveraging; + + /** number of training samples in the right child */ + int nRightAveraging; + + DI Split(DataT quesval, IdxT colid, DataT best_metric_val, IdxT nLeft, IdxT nLeftAveraging, IdxT nRightAveraging) + : quesval(quesval), colid(colid), best_metric_val(best_metric_val), + nLeft(nLeft), nLeftAveraging(nLeftAveraging), nRightAveraging(nRightAveraging) { } @@ -53,6 +61,8 @@ struct Split { quesval = best_metric_val = Min; colid = -1; nLeft = 0; + nLeftAveraging = 0; + nRightAveraging = 0; } /** @@ -68,6 +78,8 @@ struct Split { colid = other.colid; best_metric_val = other.best_metric_val; nLeft = other.nLeft; + 
nLeftAveraging = other.nLeftAveraging; + nRightAveraging = other.nRightAveraging; return *this; } @@ -103,7 +115,9 @@ struct Split { auto co = raft::shfl(colid, id); auto be = raft::shfl(best_metric_val, id); auto nl = raft::shfl(nLeft, id); - update({qu, co, be, nl}); + auto nlAvg = raft::shfl(nLeftAveraging, id); + auto nrAvg = raft::shfl(nRightAveraging, id); + update({qu, co, be, nl, nlAvg, nrAvg}); } } @@ -142,13 +156,19 @@ struct Split { split_reg.colid = split->colid; split_reg.best_metric_val = split->best_metric_val; split_reg.nLeft = split->nLeft; + split_reg.nLeftAveraging = split->nLeftAveraging; + split_reg.nRightAveraging = split->nRightAveraging; bool update_result = - split_reg.update({this->quesval, this->colid, this->best_metric_val, this->nLeft}); + split_reg.update({ + this->quesval, this->colid, this->best_metric_val, this->nLeft, + this->nLeftAveraging, this->nRightAveraging}); if (update_result) { split->quesval = split_reg.quesval; split->colid = split_reg.colid; split->best_metric_val = split_reg.best_metric_val; split->nLeft = split_reg.nLeft; + split->nLeftAveraging = split_reg.nLeftAveraging; + split->nRightAveraging = split_reg.nRightAveraging; } __threadfence(); atomicExch(mutex, 0); @@ -164,22 +184,46 @@ struct Split { * @param[in] len length of this array * @param[in] s cuda stream where to schedule work */ +template +__global__ void init_splits_kernel(Split* splits, IdxT len) +{ + const int ix_thread = threadIdx.x + blockDim.x * blockIdx.x; + if (ix_thread < len) { + splits[ix_thread] = Split{}; + } +} + template void initSplit(Split* splits, IdxT len, cudaStream_t s) { - auto op = [] __device__(Split * ptr, IdxT idx) { *ptr = Split(); }; - raft::linalg::writeOnlyUnaryOp, decltype(op), IdxT, TPB>(splits, len, op, s); + // The below is a potential replacement for "writeOnlyUnaryOp", which stopped working + // during my testing + // old version: + // auto op = [] __device__(Split * ptr, IdxT idx) { *ptr = Split(); }; + // 
raft::linalg::writeOnlyUnaryOp, decltype(op), IdxT, TPB>(splits, len, op, s); + + // potential new version + // auto op = [] __device__(auto idx) { return Split{}; }; + // raft::handle_t handle{s}; + // auto split_view = raft::make_device_vector_view(splits, len); + // raft::linalg::map_offset(handle, split_view, op); + + // custom kernel version + const int grid_dim = (len + TPB - 1) / TPB; + init_splits_kernel<<>>(splits, len); } template void printSplits(Split* splits, IdxT len, cudaStream_t s) { auto op = [] __device__(Split * ptr, IdxT idx) { - printf("quesval = %e, colid = %d, best_metric_val = %e, nLeft = %d\n", + printf("quesval = %e, colid = %d, best_metric_val = %e, nLeft = %d, nLeftAveraging = %d, nRightAveraging = %d\n", ptr->quesval, ptr->colid, ptr->best_metric_val, - ptr->nLeft); + ptr->nLeft, + ptr->nLeftAveraging, + ptr->nRightAveraging); }; raft::linalg::writeOnlyUnaryOp, decltype(op), IdxT, TPB>(splits, len, op, s); RAFT_CUDA_TRY(cudaDeviceSynchronize()); diff --git a/cpp/src/decisiontree/decisiontree.cu b/cpp/src/decisiontree/decisiontree.cu index 46516aaa4a..349f559e33 100644 --- a/cpp/src/decisiontree/decisiontree.cu +++ b/cpp/src/decisiontree/decisiontree.cu @@ -35,12 +35,18 @@ void validity_check(const DecisionTreeParams params) ASSERT((params.max_n_bins > 0), "Invalid max_n_bins %d", params.max_n_bins); ASSERT((params.max_n_bins <= 1024), "max_n_bins should not be larger than 1024"); ASSERT((params.split_criterion != 3), "MAE not supported."); - ASSERT((params.min_samples_leaf >= 1), - "Invalid value for min_samples_leaf %d. Should be >= 1.", - params.min_samples_leaf); - ASSERT((params.min_samples_split >= 2), - "Invalid value for min_samples_split: %d. Should be >= 2.", - params.min_samples_split); + ASSERT((params.min_samples_leaf_splitting >= 1), + "Invalid value for min_samples_leaf_splitting %d. 
Should be >= 1.", + params.min_samples_leaf_splitting); + ASSERT((not params.oob_honesty or params.min_samples_leaf_averaging >= 1), + "Invalid value for min_samples_leaf_averaging %d. Should be >= 1 if honesty enabled.", + params.min_samples_leaf_averaging); + ASSERT((params.min_samples_split_splitting >= 2), + "Invalid value for min_samples_split_splitting: %d. Should be >= 2.", + params.min_samples_split_splitting); + ASSERT((not params.oob_honesty or params.min_samples_split_averaging >= 2), + "Invalid value for min_samples_split_averaging: %d. Should be >= 2 if honesty enabled.", + params.min_samples_split_averaging); } /** @@ -56,27 +62,34 @@ void validity_check(const DecisionTreeParams params) * @param[in] cfg_split_criterion: split criterion; default CRITERION_END, * i.e., GINI for classification or MSE for regression * @param[in] cfg_max_batch_size: batch size for experimental backend + * @param[in] cfg_oob_honesty: Whether to use oob_honesty features */ void set_tree_params(DecisionTreeParams& params, int cfg_max_depth, int cfg_max_leaves, float cfg_max_features, int cfg_max_n_bins, - int cfg_min_samples_leaf, - int cfg_min_samples_split, + int cfg_min_samples_leaf_splitting, + int cfg_min_samples_leaf_averaging, + int cfg_min_samples_split_splitting, + int cfg_min_samples_split_averaging, float cfg_min_impurity_decrease, CRITERION cfg_split_criterion, - int cfg_max_batch_size) + int cfg_max_batch_size, + bool cfg_oob_honesty) { - params.max_depth = cfg_max_depth; - params.max_leaves = cfg_max_leaves; - params.max_features = cfg_max_features; - params.max_n_bins = cfg_max_n_bins; - params.min_samples_leaf = cfg_min_samples_leaf; - params.min_samples_split = cfg_min_samples_split; - params.split_criterion = cfg_split_criterion; - params.min_impurity_decrease = cfg_min_impurity_decrease; - params.max_batch_size = cfg_max_batch_size; + params.max_depth = cfg_max_depth; + params.max_leaves = cfg_max_leaves; + params.max_features = cfg_max_features; + 
params.max_n_bins = cfg_max_n_bins; + params.min_samples_leaf_splitting = cfg_min_samples_leaf_splitting; + params.min_samples_leaf_averaging = cfg_min_samples_leaf_averaging; + params.min_samples_split_splitting = cfg_min_samples_split_splitting; + params.min_samples_split_averaging = cfg_min_samples_split_averaging; + params.split_criterion = cfg_split_criterion; + params.min_impurity_decrease = cfg_min_impurity_decrease; + params.max_batch_size = cfg_max_batch_size; + params.oob_honesty = cfg_oob_honesty; validity_check(params); } diff --git a/cpp/src/decisiontree/decisiontree.cuh b/cpp/src/decisiontree/decisiontree.cuh index eac66f1e16..6a27fcf6c8 100644 --- a/cpp/src/decisiontree/decisiontree.cuh +++ b/cpp/src/decisiontree/decisiontree.cuh @@ -222,7 +222,7 @@ tl::Tree build_treelite_tree(const DT::TreeMetaDataNode& rf_tree, class DecisionTree { public: - template + template static std::shared_ptr> fit( const raft::handle_t& handle, const cudaStream_t s, @@ -231,6 +231,8 @@ class DecisionTree { const int nrows, const LabelT* labels, rmm::device_uvector* row_ids, + bool* split_row_mask, + const size_t n_avg_samples, int unique_labels, DecisionTreeParams params, uint64_t seed, @@ -244,93 +246,75 @@ class DecisionTree { params.split_criterion = default_criterion; } using IdxT = int; + Dataset dataset{ + data, + labels, + nrows, + ncols, + int(row_ids->size()), + max(1, IdxT(params.max_features * ncols)), + row_ids->data(), + int(n_avg_samples), + split_row_mask, + unique_labels}; + // Dispatch objective if (not std::is_same::value and params.split_criterion == CRITERION::GINI) { - return Builder>(handle, - s, - treeid, - seed, - params, - data, - labels, - nrows, - ncols, - row_ids, - unique_labels, - quantiles) + return Builder>(handle, + s, + treeid, + seed, + params, + dataset, + quantiles) .train(); } else if (not std::is_same::value and params.split_criterion == CRITERION::ENTROPY) { - return Builder>(handle, - s, - treeid, - seed, - params, - data, - labels, 
- nrows, - ncols, - row_ids, - unique_labels, - quantiles) + return Builder>(handle, + s, + treeid, + seed, + params, + dataset, + quantiles) .train(); } else if (std::is_same::value and params.split_criterion == CRITERION::MSE) { - return Builder>(handle, - s, - treeid, - seed, - params, - data, - labels, - nrows, - ncols, - row_ids, - unique_labels, - quantiles) + return Builder>(handle, + s, + treeid, + seed, + params, + dataset, + quantiles) .train(); } else if (std::is_same::value and params.split_criterion == CRITERION::POISSON) { - return Builder>(handle, - s, - treeid, - seed, - params, - data, - labels, - nrows, - ncols, - row_ids, - unique_labels, - quantiles) + return Builder>(handle, + s, + treeid, + seed, + params, + dataset, + quantiles) .train(); } else if (std::is_same::value and params.split_criterion == CRITERION::GAMMA) { - return Builder>(handle, - s, - treeid, - seed, - params, - data, - labels, - nrows, - ncols, - row_ids, - unique_labels, - quantiles) + return Builder>(handle, + s, + treeid, + seed, + params, + dataset, + quantiles) .train(); } else if (std::is_same::value and params.split_criterion == CRITERION::INVERSE_GAUSSIAN) { - return Builder>(handle, - s, - treeid, - seed, - params, - data, - labels, - nrows, - ncols, - row_ids, - unique_labels, - quantiles) + return Builder>(handle, + s, + treeid, + seed, + params, + dataset, + quantiles) .train(); } else { ASSERT(false, "Unknown split criterion."); diff --git a/cpp/src/randomforest/randomforest.cu b/cpp/src/randomforest/randomforest.cu index a97422d841..e5f300311d 100644 --- a/cpp/src/randomforest/randomforest.cu +++ b/cpp/src/randomforest/randomforest.cu @@ -581,10 +581,14 @@ RF_params set_rf_params(int max_depth, int max_leaves, float max_features, int max_n_bins, - int min_samples_leaf, - int min_samples_split, + int min_samples_leaf_splitting, + int min_samples_leaf_averaging, + int min_samples_split_splitting, + int min_samples_split_averaging, float min_impurity_decrease, bool 
bootstrap, + bool oob_honesty, + bool double_bootstrap, int n_trees, float max_samples, uint64_t seed, @@ -598,14 +602,19 @@ RF_params set_rf_params(int max_depth, max_leaves, max_features, max_n_bins, - min_samples_leaf, - min_samples_split, + min_samples_leaf_splitting, + min_samples_leaf_averaging, + min_samples_split_splitting, + min_samples_split_averaging, min_impurity_decrease, split_criterion, - max_batch_size); + max_batch_size, + oob_honesty); RF_params rf_params; rf_params.n_trees = n_trees; rf_params.bootstrap = bootstrap; + rf_params.oob_honesty = oob_honesty; + rf_params.double_bootstrap = double_bootstrap; rf_params.max_samples = max_samples; rf_params.seed = seed; rf_params.n_streams = min(cfg_n_streams, omp_get_max_threads()); diff --git a/cpp/src/randomforest/randomforest.cuh b/cpp/src/randomforest/randomforest.cuh index 9b1d14bb00..d5022b4769 100644 --- a/cpp/src/randomforest/randomforest.cuh +++ b/cpp/src/randomforest/randomforest.cuh @@ -31,6 +31,10 @@ #include #include +#include +#include +#include +#include #ifdef _OPENMP #include @@ -41,6 +45,19 @@ #include +struct set_mask_functor { + const int n_rows; + set_mask_functor(const int n_rows) + : n_rows(n_rows) + {} + + __host__ __device__ + void operator()(const int& index, bool& output) + { + output = index < n_rows; + } +}; + namespace ML { template class RandomForest { @@ -48,10 +65,13 @@ class RandomForest { RF_params rf_params; // structure containing RF hyperparameters int rf_type; // 0 for classification 1 for regression - void get_row_sample(int tree_id, - int n_rows, - rmm::device_uvector* selected_rows, - const cudaStream_t stream) + size_t get_row_sample(int tree_id, + int n_rows, + int n_sampled_rows, + rmm::device_uvector* selected_rows, + rmm::device_uvector* split_row_mask, + rmm::device_uvector* tmp_row_vec, + const cudaStream_t stream) { raft::common::nvtx::range fun_scope("bootstrapping row IDs @randomforest.cuh"); @@ -60,14 +80,58 @@ class RandomForest { rs = DT::fnv1a32(rs, 
rf_params.seed); rs = DT::fnv1a32(rs, tree_id); raft::random::Rng rng(rs, raft::random::GenPhilox); + if (rf_params.bootstrap) { // Use bootstrapped sample set - rng.uniformInt(selected_rows->data(), selected_rows->size(), 0, n_rows, stream); - + rng.uniformInt(selected_rows->data(), n_sampled_rows, 0, n_rows, stream); } else { // Use all the samples from the dataset thrust::sequence(thrust::cuda::par.on(stream), selected_rows->begin(), selected_rows->end()); } + size_t num_avg_samples = 0; + + if (rf_params.oob_honesty and rf_params.bootstrap) { + // honesty doesn't make sense without bootstrapping -- all the obs were otherwise selected + num_avg_samples = n_sampled_rows; + assert(rf_params.bootstrap); + + // We'll have n_rows samples for splitting. + // Need to sort the selected rows to be able to use thrust set difference + thrust::sort(thrust::cuda::par.on(stream), selected_rows->begin(), selected_rows->end()); + + // Get the set of observations that are not used for split + tmp_row_vec->resize(n_rows, stream); + auto iter_end = thrust::set_difference( + thrust::cuda::par.on(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(n_sampled_rows), + selected_rows->begin(), + selected_rows->end(), + tmp_row_vec->begin()); + + // Now tmp_row_vec is the observations available for the averaging set + size_t num_remaining_samples = iter_end - tmp_row_vec->begin(); + tmp_row_vec->resize(num_remaining_samples, stream); + + // Get the avg selected rows either as the remaining data, or bootstrapped again + if (rf_params.double_bootstrap) { + rng.uniformInt(selected_rows->data() + n_sampled_rows, num_avg_samples, 0, num_remaining_samples, stream); + auto index_iter = thrust::make_permutation_iterator(tmp_row_vec->begin(), selected_rows->begin() + n_sampled_rows); + + thrust::copy(thrust::cuda::par.on(stream), index_iter, index_iter + num_avg_samples, selected_rows->begin() + n_sampled_rows); + } else { + thrust::copy(thrust::cuda::par.on(stream), 
tmp_row_vec->begin(), tmp_row_vec->end(), selected_rows->begin() + n_sampled_rows); + } + + // First n_rows goes to training, num_avg_samples goes to averaging. + auto begin_zip_iterator = thrust::make_zip_iterator(thrust::make_tuple( + thrust::make_counting_iterator(0), split_row_mask->begin())); + auto end_zip_iterator = thrust::make_zip_iterator(thrust::make_tuple( + thrust::make_counting_iterator(2 * n_sampled_rows), split_row_mask->end())); + thrust::for_each(thrust::cuda::par.on(stream), begin_zip_iterator, end_zip_iterator, thrust::make_zip_function(set_mask_functor(n_sampled_rows))); + } + + return num_avg_samples; } void error_checking(const T* input, L* predictions, int n_rows, int n_cols, bool predict) const @@ -156,16 +220,31 @@ class RandomForest { // Use a deque instead of vector because it can be used on objects with a deleted copy // constructor std::deque> selected_rows; + std::deque> split_row_masks; + std::deque> sampling_staging_vecs; + + size_t max_sample_row_size = this->rf_params.oob_honesty ? n_sampled_rows * 2 : n_sampled_rows; for (int i = 0; i < n_streams; i++) { - selected_rows.emplace_back(n_sampled_rows, handle.get_stream_from_stream_pool(i)); + auto s = handle.get_stream_from_stream_pool(i); + selected_rows.emplace_back(max_sample_row_size, s); + if (this->rf_params.oob_honesty) { + split_row_masks.emplace_back(max_sample_row_size, s); + sampling_staging_vecs.emplace_back(n_rows, s); + } } #pragma omp parallel for num_threads(n_streams) for (int i = 0; i < this->rf_params.n_trees; i++) { int stream_id = omp_get_thread_num(); auto s = handle.get_stream_from_stream_pool(stream_id); - - this->get_row_sample(i, n_rows, &selected_rows[stream_id], s); + + rmm::device_uvector* sampling_staging_vec = this->rf_params.oob_honesty ? &sampling_staging_vecs[stream_id] : nullptr; + rmm::device_uvector* split_row_mask = this->rf_params.oob_honesty ? 
&split_row_masks[stream_id] : nullptr; + auto n_avg_samples = this->get_row_sample( + i, n_rows, n_sampled_rows, &selected_rows[stream_id], + split_row_mask, + sampling_staging_vec, + s); /* Build individual tree in the forest. - input is a pointer to orig data that have n_cols features and n_rows rows. @@ -175,19 +254,37 @@ class RandomForest { Expectation: Each tree node will contain (a) # n_sampled_rows and (b) a pointer to a list of row numbers w.r.t original data. */ - - forest->trees[i] = DT::DecisionTree::fit(handle, - s, - input, - n_cols, - n_rows, - labels, - &selected_rows[stream_id], - n_unique_labels, - this->rf_params.tree_params, - this->rf_params.seed, - quantiles, - i); + if (this->rf_params.oob_honesty) { + forest->trees[i] = DT::DecisionTree::fit(handle, + s, + input, + n_cols, + n_rows, + labels, + &selected_rows[stream_id], + split_row_masks[stream_id].data(), + n_avg_samples, + n_unique_labels, + this->rf_params.tree_params, + this->rf_params.seed, + quantiles, + i); + } else { + forest->trees[i] = DT::DecisionTree::fit(handle, + s, + input, + n_cols, + n_rows, + labels, + &selected_rows[stream_id], + nullptr, + n_avg_samples, + n_unique_labels, + this->rf_params.tree_params, + this->rf_params.seed, + quantiles, + i); + } } // Cleanup handle.sync_stream_pool(); diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu index de8ef17010..627fddc97c 100644 --- a/cpp/test/sg/rf_test.cu +++ b/cpp/test/sg/rf_test.cu @@ -117,14 +117,18 @@ struct RfTestParams { int max_leaves; bool bootstrap; int max_n_bins; - int min_samples_leaf; - int min_samples_split; + int min_samples_leaf_splitting; + int min_samples_leaf_averaging; + int min_samples_split_splitting; + int min_samples_split_averaging; float min_impurity_decrease; int n_streams; CRITERION split_criterion; int seed; int n_labels; bool double_precision; + bool double_bootstrap; + bool oob_honesty; // c++ has no reflection, so we enumerate the types here // This must be updated if new fields are 
added using types = std::tuple; }; @@ -152,12 +160,14 @@ std::ostream& operator<<(std::ostream& os, const RfTestParams& ps) os << ", n_trees = " << ps.n_trees << ", max_features = " << ps.max_features; os << ", max_samples = " << ps.max_samples << ", max_depth = " << ps.max_depth; os << ", max_leaves = " << ps.max_leaves << ", bootstrap = " << ps.bootstrap; - os << ", max_n_bins = " << ps.max_n_bins << ", min_samples_leaf = " << ps.min_samples_leaf; - os << ", min_samples_split = " << ps.min_samples_split; + os << ", max_n_bins = " << ps.max_n_bins; + os << ", min_samples_leaf_splitting = " << ps.min_samples_leaf_splitting << ", min_samples_split_splitting = " << ps.min_samples_split_splitting; + os << ", min_samples_leaf_averaging = " << ps.min_samples_leaf_averaging << ", min_samples_split_averaging = " << ps.min_samples_split_averaging; os << ", min_impurity_decrease = " << ps.min_impurity_decrease << ", n_streams = " << ps.n_streams; os << ", split_criterion = " << ps.split_criterion << ", seed = " << ps.seed; os << ", n_labels = " << ps.n_labels << ", double_precision = " << ps.double_precision; + os << ", oob_honesty = " << ps.oob_honesty << ", double_bootstrap = " << ps.double_bootstrap; return os; } @@ -214,10 +224,14 @@ auto TrainScore( params.max_leaves, params.max_features, params.max_n_bins, - params.min_samples_leaf, - params.min_samples_split, + params.min_samples_leaf_splitting, + params.min_samples_leaf_averaging, + params.min_samples_split_splitting, + params.min_samples_split_averaging, params.min_impurity_decrease, params.bootstrap, + params.oob_honesty, + params.double_bootstrap, params.n_trees, params.max_samples, 0, @@ -341,8 +355,9 @@ class RfSpecialisedTest { if (params.max_leaves > 0) { EXPECT_LE(forest->trees[i]->leaf_counter, params.max_leaves); } EXPECT_LE(forest->trees[i]->depth_counter, params.max_depth); + const int exp_min_avging = params.oob_honesty ? 
params.min_samples_leaf_averaging : 0; EXPECT_LE(forest->trees[i]->leaf_counter, - raft::ceildiv(int(params.n_rows), params.min_samples_leaf)); + raft::ceildiv(int(params.n_rows), params.min_samples_leaf_splitting + exp_min_avging)); } } @@ -465,6 +480,7 @@ class RfTest : public ::testing::TestWithParam { void SetUp() override { RfTestParams params = ::testing::TestWithParam::GetParam(); + std::cout << "Params " << params << std::endl; bool is_regression = params.split_criterion != GINI and params.split_criterion != ENTROPY; if (params.double_precision) { if (is_regression) { @@ -494,8 +510,10 @@ std::vector max_depth = {1, 10, 30}; std::vector max_leaves = {-1, 16, 50}; std::vector bootstrap = {false, true}; std::vector max_n_bins = {2, 57, 128, 256}; -std::vector min_samples_leaf = {1, 10, 30}; -std::vector min_samples_split = {2, 10}; +std::vector min_samples_leaf_splitting = {1, 10, 30}; +std::vector min_samples_leaf_averaging = {1, 10, 30}; +std::vector min_samples_split_splitting = {2, 10}; +std::vector min_samples_split_averaging = {2, 10}; std::vector min_impurity_decrease = {0.0f, 1.0f, 10.0f}; std::vector n_streams = {1, 2, 10}; std::vector split_criterion = {CRITERION::INVERSE_GAUSSIAN, @@ -507,6 +525,8 @@ std::vector split_criterion = {CRITERION::INVERSE_GAUSSIAN, std::vector seed = {0, 17}; std::vector n_labels = {2, 10, 20}; std::vector double_precision = {false, true}; +std::vector double_bootstrap = {true}; +std::vector oob_honesty = {false}; int n_tests = 100; @@ -523,14 +543,18 @@ INSTANTIATE_TEST_CASE_P(RfTests, max_leaves, bootstrap, max_n_bins, - min_samples_leaf, - min_samples_split, + min_samples_leaf_splitting, + min_samples_leaf_averaging, + min_samples_split_splitting, + min_samples_split_averaging, min_impurity_decrease, n_streams, split_criterion, seed, n_labels, - double_precision))); + double_precision, + double_bootstrap, + oob_honesty))); TEST(RfTests, IntegerOverflow) { @@ -547,7 +571,7 @@ TEST(RfTests, IntegerOverflow) auto 
stream_pool = std::make_shared(4); raft::handle_t handle(rmm::cuda_stream_per_thread, stream_pool); RF_params rf_params = - set_rf_params(3, 100, 1.0, 256, 1, 2, 0.0, false, 1, 1.0, 0, CRITERION::MSE, 4, 128); + set_rf_params(3, 100, 1.0, 256, 1, 0, 2, 0, 0.0, false, false, false, 1, 1.0, 0, CRITERION::MSE, 4, 128); fit(handle, forest_ptr, X.data().get(), m, n, y.data().get(), rf_params); // Check we have actually learned something @@ -789,7 +813,7 @@ INSTANTIATE_TEST_CASE_P(RfTests, RFQuantileVariableBinsTestD, ::testing::ValuesI TEST(RfTest, TextDump) { - RF_params rf_params = set_rf_params(2, 2, 1.0, 2, 1, 2, 0.0, false, 1, 1.0, 0, GINI, 1, 128); + RF_params rf_params = set_rf_params(2, 2, 1.0, 2, 1, 0, 2, 0, 0.0, false, false, false, 1, 1.0, 0, GINI, 1, 128); auto forest = std::make_shared>(); std::vector X_host = {1, 2, 3, 6, 7, 8}; @@ -831,7 +855,8 @@ struct ObjectiveTestParameters { int n_rows; int max_n_bins; int n_classes; - int min_samples_leaf; + int min_samples_leaf_splitting; + int min_samples_leaf_averaging; double tolerance; }; @@ -929,7 +954,8 @@ class ObjectiveTest : public ::testing::TestWithParam { (n_right / n) * right_mse); // gain in long form without proxy // edge cases - if (n_left < params.min_samples_leaf or n_right < params.min_samples_leaf) + printf("gain %f\n", gain); + if (n_left < params.min_samples_leaf_splitting or n_right < params.min_samples_leaf_splitting) return -std::numeric_limits::max(); else return gain; @@ -966,7 +992,7 @@ class ObjectiveTest : public ::testing::TestWithParam { (n_right / n) * right_ighd); // gain in long form without proxy // edge cases - if (n_left < params.min_samples_leaf or n_right < params.min_samples_leaf or + if (n_left < params.min_samples_leaf_splitting or n_right < params.min_samples_leaf_splitting or label_sum < ObjectiveT::eps_ or label_sum_right < ObjectiveT::eps_ or label_sum_left < ObjectiveT::eps_) return -std::numeric_limits::max(); @@ -1006,7 +1032,7 @@ class ObjectiveTest : public 
::testing::TestWithParam { (n_right / n) * right_ghd); // gain in long form without proxy // edge cases - if (n_left < params.min_samples_leaf or n_right < params.min_samples_leaf or + if (n_left < params.min_samples_leaf_splitting or n_right < params.min_samples_leaf_splitting or label_sum < ObjectiveT::eps_ or label_sum_right < ObjectiveT::eps_ or label_sum_left < ObjectiveT::eps_) return -std::numeric_limits::max(); @@ -1044,7 +1070,7 @@ class ObjectiveTest : public ::testing::TestWithParam { (n_right / n) * right_phd); // gain in long form without proxy // edge cases - if (n_left < params.min_samples_leaf or n_right < params.min_samples_leaf or + if (n_left < params.min_samples_leaf_splitting or n_right < params.min_samples_leaf_splitting or label_sum < ObjectiveT::eps_ or label_sum_right < ObjectiveT::eps_ or label_sum_left < ObjectiveT::eps_) return -std::numeric_limits::max(); @@ -1083,7 +1109,7 @@ class ObjectiveTest : public ::testing::TestWithParam { auto gain = parent_entropy - ((left_n / n) * left_entropy + (right_n / n) * right_entropy); // edge cases - if (left_n < params.min_samples_leaf or right_n < params.min_samples_leaf) { + if (left_n < params.min_samples_leaf_splitting or right_n < params.min_samples_leaf_splitting) { return -std::numeric_limits::max(); } else { return gain; @@ -1120,7 +1146,7 @@ class ObjectiveTest : public ::testing::TestWithParam { auto gain = parent_gini - ((left_n / n) * left_gini + (right_n / n) * right_gini); // edge cases - if (left_n < params.min_samples_leaf or right_n < params.min_samples_leaf) { + if (left_n < params.min_samples_leaf_splitting or right_n < params.min_samples_leaf_splitting) { return -std::numeric_limits::max(); } else { return gain; @@ -1177,7 +1203,7 @@ class ObjectiveTest : public ::testing::TestWithParam { { srand(params.seed); params = ::testing::TestWithParam::GetParam(); - ObjectiveT objective(params.n_classes, params.min_samples_leaf); + ObjectiveT objective(params.n_classes, 
params.min_samples_leaf_splitting, params.min_samples_leaf_averaging); auto data = GenRandomData(); auto [cdf_hist, pdf_hist] = GenHist(data); @@ -1187,53 +1213,56 @@ class ObjectiveTest : public ::testing::TestWithParam { auto hypothesis_gain = objective.GainPerSplit(&cdf_hist[0], split_bin_index, params.max_n_bins, + NumLeftOfBin(cdf_hist, split_bin_index), + 0, // nLeftAveraging NumLeftOfBin(cdf_hist, params.max_n_bins - 1), - NumLeftOfBin(cdf_hist, split_bin_index)); + 0 // nRightAveraging + ); ASSERT_NEAR(ground_truth_gain, hypothesis_gain, params.tolerance); } }; const std::vector mse_objective_test_parameters = { - {9507819643927052255LLU, 2048, 64, 1, 0, 0.00001}, - {9507819643927052259LLU, 2048, 128, 1, 1, 0.00001}, - {9507819643927052251LLU, 2048, 256, 1, 1, 0.00001}, - {9507819643927052258LLU, 2048, 512, 1, 5, 0.00001}, + {9507819643927052255LLU, 2048, 64, 1, 1, 0, 0.00001}, + {9507819643927052259LLU, 2048, 128, 1, 1, 0, 0.00001}, + {9507819643927052251LLU, 2048, 256, 1, 1, 0, 0.00001}, + {9507819643927052258LLU, 2048, 512, 1, 5, 0, 0.00001}, }; const std::vector poisson_objective_test_parameters = { - {9507819643927052255LLU, 2048, 64, 1, 0, 0.00001}, - {9507819643927052259LLU, 2048, 128, 1, 1, 0.00001}, - {9507819643927052251LLU, 2048, 256, 1, 1, 0.00001}, - {9507819643927052258LLU, 2048, 512, 1, 5, 0.00001}, + {9507819643927052255LLU, 2048, 64, 1, 1, 0, 0.00001}, + {9507819643927052259LLU, 2048, 128, 1, 1, 0, 0.00001}, + {9507819643927052251LLU, 2048, 256, 1, 1, 0, 0.00001}, + {9507819643927052258LLU, 2048, 512, 1, 5, 0, 0.00001}, }; const std::vector gamma_objective_test_parameters = { - {9507819643927052255LLU, 2048, 64, 1, 0, 0.00001}, - {9507819643927052259LLU, 2048, 128, 1, 1, 0.00001}, - {9507819643927052251LLU, 2048, 256, 1, 1, 0.00001}, - {9507819643927052258LLU, 2048, 512, 1, 5, 0.00001}, + {9507819643927052255LLU, 2048, 64, 1, 1, 0, 0.00001}, + {9507819643927052259LLU, 2048, 128, 1, 1, 0, 0.00001}, + {9507819643927052251LLU, 2048, 256, 1, 1, 
0, 0.00001}, + {9507819643927052258LLU, 2048, 512, 1, 5, 0, 0.00001}, }; const std::vector invgauss_objective_test_parameters = { - {9507819643927052255LLU, 2048, 64, 1, 0, 0.00001}, - {9507819643927052259LLU, 2048, 128, 1, 1, 0.00001}, - {9507819643927052251LLU, 2048, 256, 1, 1, 0.00001}, - {9507819643927052258LLU, 2048, 512, 1, 5, 0.00001}, + {9507819643927052255LLU, 2048, 64, 1, 1, 0, 0.00001}, + {9507819643927052259LLU, 2048, 128, 1, 1, 0, 0.00001}, + {9507819643927052251LLU, 2048, 256, 1, 1, 0, 0.00001}, + {9507819643927052258LLU, 2048, 512, 1, 5, 0, 0.00001}, }; const std::vector entropy_objective_test_parameters = { - {9507819643927052255LLU, 2048, 64, 2, 0, 0.00001}, - {9507819643927052256LLU, 2048, 128, 10, 1, 0.00001}, - {9507819643927052257LLU, 2048, 256, 100, 1, 0.00001}, - {9507819643927052258LLU, 2048, 512, 100, 5, 0.00001}, + {9507819643927052255LLU, 2048, 64, 2, 0, 0, 0.00001}, + {9507819643927052256LLU, 2048, 128, 10, 1, 0, 0.00001}, + {9507819643927052257LLU, 2048, 256, 100, 1, 0, 0.00001}, + {9507819643927052258LLU, 2048, 512, 100, 5, 0, 0.00001}, }; const std::vector gini_objective_test_parameters = { - {9507819643927052255LLU, 2048, 64, 2, 0, 0.00001}, - {9507819643927052256LLU, 2048, 128, 10, 1, 0.00001}, - {9507819643927052257LLU, 2048, 256, 100, 1, 0.00001}, - {9507819643927052258LLU, 2048, 512, 100, 5, 0.00001}, + {9507819643927052255LLU, 2048, 64, 2, 0, 0, 0.00001}, + {9507819643927052256LLU, 2048, 128, 10, 1, 0, 0.00001}, + {9507819643927052257LLU, 2048, 256, 100, 1, 0, 0.00001}, + {9507819643927052258LLU, 2048, 512, 100, 5, 0, 0.00001}, }; // mse objective test diff --git a/honesty_test.py b/honesty_test.py new file mode 100755 index 0000000000..ea90c52f29 --- /dev/null +++ b/honesty_test.py @@ -0,0 +1,64 @@ +#! 
/usr/bin/env python + +import pandas as pd +import numpy as np +import cuml + +from cuml.ensemble import RandomForestClassifier as RFC +from cuml.ensemble import RandomForestRegressor as RFR + +import time + + +input = pd.read_parquet("/home/scratch.eschmidt_sw/gotvBIG.parquet") + +# input = input.iloc[:50000, :] + +# cuML doesn't handle string inputs +input["vh_stratum"] = input["vh_stratum"].replace({"below": -1.0, "average":0.0, "above":1.0, "":np.nan}).astype(float) + +states = input['state'].unique() +states_map = {} +for ix_state,state in enumerate(states): + states_map[state] = float(ix_state) + +input["state"] = input["state"].map(states_map) + +# state fields are redundant with state value +# state_fields = ["d_st_AK","d_st_AR","d_st_AZ","d_st_CO","d_st_FL","d_st_GA","d_st_IA","d_st_KS","d_st_KY","d_st_LA","d_st_ME","d_st_MI","d_st_NC","d_st_NH","d_st_SD","d_st_TX","d_st_WI"] +# input = input.drop(labels=state_fields, axis=1) + +input = input.astype('float32') +input = input.dropna() + +# Choose how many index include for random selection +num_rows = input.shape[0] +ix_train = np.random.choice(num_rows, replace=False, size=int(num_rows*0.7)) +ix_test = np.setdiff1d(np.arange(num_rows), ix_train) + +x = input.drop(labels=["voted14"], axis=1) +y = input["voted14"] + +x_train = x.iloc[ix_train, :] +y_train = y.iloc[ix_train] + +x_test = x.iloc[ix_test, :] +y_test = y.iloc[ix_test] + +n_trees = 100 + +random_forest_regress = RFR(n_estimators=n_trees, split_criterion=2, random_state=42) +start = time.time() +trainedRFR = random_forest_regress.fit(x_train, y_train) +end = time.time() +pred_test_regress = trainedRFR.predict(x_test) +mse = cuml.metrics.mean_squared_error(y_test, pred_test_regress) +print(f"No honesty {mse} time {end-start}") + +random_forest_regress = RFR(n_estimators=n_trees, oob_honesty=True, split_criterion=2, random_state=42) +start = time.time() +trainedRFR = random_forest_regress.fit(x_train, y_train) +end = time.time() +pred_test_regress = 
trainedRFR.predict(x_test) +mse = cuml.metrics.mean_squared_error(y_test, pred_test_regress) +print(f"Honesty {mse} time {end-start}") diff --git a/python/cuml/benchmark/algorithms.py b/python/cuml/benchmark/algorithms.py index 2b83aad1ee..18ed19b116 100644 --- a/python/cuml/benchmark/algorithms.py +++ b/python/cuml/benchmark/algorithms.py @@ -171,7 +171,6 @@ def run_cuml(self, data, bench_args={}, **override_setup_args): """Runs the cuml-based algorithm's fit method on specified data""" all_args = {**self.shared_args, **self.cuml_args} all_args = {**all_args, **override_setup_args} - if "cuml_setup_result" not in all_args: cuml_obj = self.cuml_class(**all_args) else: diff --git a/python/cuml/dask/ensemble/randomforestclassifier.py b/python/cuml/dask/ensemble/randomforestclassifier.py index 4dfd7c3ddb..1e1203d39c 100755 --- a/python/cuml/dask/ensemble/randomforestclassifier.py +++ b/python/cuml/dask/ensemble/randomforestclassifier.py @@ -96,6 +96,14 @@ class RandomForestClassifier( * If ``True``, each tree in the forest is built on a bootstrapped sample with replacement. * If ``False``, the whole dataset is used to build each tree. + oob_honesty : boolean (default = True) + Control oob_honesty.\n + * If ``True``, eachtree in the forest is built using disjoint sets for splitting and averaging + * If ``False``, the whole dataset is used to build each tree. + double_bootstrap : boolean (default = True) + Control bootstrapping in the averaging set. Only applies if oob_honesty is truer.\n + * If ``True``, each tree uses an averaging set which is sampled with replacement from the samples not used for splitting + * If ``False``, each tree uses an averaging set which is the set difference of all samples presented to the tree and the splitting set max_samples : float (default = 1.0) Ratio of dataset rows used while fitting each tree. 
max_depth : int (default = 16) @@ -119,22 +127,34 @@ class RandomForestClassifier( n_bins : int (default = 128) Maximum number of bins used by the split algorithm per feature. - min_samples_leaf : int or float (default = 1) - The minimum number of samples (rows) in each leaf node.\n - * If type ``int``, then ``min_samples_leaf`` represents the minimum + min_samples_leaf_splitting : int or float (default = 1) + The minimum number of training samples (rows) in each leaf node.\n + * If type ``int``, then ``min_samples_leaf_splitting`` represents the minimum number. - * If ``float``, then ``min_samples_leaf`` represents a fraction - and ``ceil(min_samples_leaf * n_rows)`` is the minimum number of + * If ``float``, then ``min_samples_leaf_splitting`` represents a fraction and + ``ceil(min_samples_leaf_splitting * n_rows)`` is the minimum number of samples for each leaf node. - - min_samples_split : int or float (default = 2) - The minimum number of samples required to split an internal - node.\n - * If type ``int``, then ``min_samples_split`` represents the minimum + min_samples_leaf_averaging : int or float (default = 2) + The minimum number of averaging samples (rows, oob_honesty) in each leaf node.\n + * If type ``int``, then ``min_samples_leaf_averaging`` represents the minimum + number. + * If ``float``, then ``min_samples_leaf_averaging`` represents a fraction and + ``ceil(min_samples_leaf_averaging * n_rows)`` is the minimum number of + samples for each leaf node. + min_samples_split_splitting : int or float (default = 2) + The minimum number of trainingsamples required to split an internal node.\n + * If type ``int``, then min_samples_split_splitting represents the minimum number. - * If type ``float``, then ``min_samples_split`` represents a fraction - and ``ceil(min_samples_split * n_rows)`` is the minimum number of - samples for each split. 
+ * If type ``float``, then ``min_samples_split_splitting`` represents a fraction + and ``max(2, ceil(min_samples_split_splitting * n_rows))`` is the minimum + number of samples for each split. + min_samples_split_averaging : int or float (default = 2) + The minimum number of averaging samples (oob_honesty) required to split an internal node.\n + * If type ``int``, then min_samples_split_splitting represents the minimum + number. + * If type ``float``, then ``min_samples_split_splitting`` represents a fraction + and ``max(2, ceil(min_samples_split_splitting * n_rows))`` is the minimum + number of samples for each split. n_streams : int (default = 4 ) Number of parallel streams used for forest building @@ -169,7 +189,6 @@ def __init__( ignore_empty_partitions=False, **kwargs, ): - super().__init__(client=client, verbose=verbose, **kwargs) self._create_model( model_func=RandomForestClassifier._construct_rf, diff --git a/python/cuml/dask/ensemble/randomforestregressor.py b/python/cuml/dask/ensemble/randomforestregressor.py index f2c7d283eb..a38ff0dcc9 100755 --- a/python/cuml/dask/ensemble/randomforestregressor.py +++ b/python/cuml/dask/ensemble/randomforestregressor.py @@ -82,6 +82,14 @@ class RandomForestRegressor( * If ``True``, each tree in the forest is built on a bootstrapped sample with replacement. * If ``False``, the whole dataset is used to build each tree. + oob_honesty : boolean (default = True) + Control oob_honesty.\n + * If ``True``, eachtree in the forest is built using disjoint sets for splitting and averaging + * If ``False``, the whole dataset is used to build each tree. + double_bootstrap : boolean (default = True) + Control bootstrapping in the averaging set. 
Only applies if oob_honesty is truer.\n + * If ``True``, each tree uses an averaging set which is sampled with replacement from the samples not used for splitting + * If ``False``, each tree uses an averaging set which is the set difference of all samples presented to the tree and the splitting set max_samples : float (default = 1.0) Ratio of dataset rows used while fitting each tree. max_depth : int (default = 16) @@ -104,20 +112,34 @@ class RandomForestRegressor( * If ``None``, then ``max_features = 1.0``. n_bins : int (default = 128) Maximum number of bins used by the split algorithm per feature. - min_samples_leaf : int or float (default = 1) - The minimum number of samples (rows) in each leaf node.\n - * If type ``int``, then ``min_samples_leaf`` represents the minimum + min_samples_leaf_splitting : int or float (default = 1) + The minimum number of training samples (rows) in each leaf node.\n + * If type ``int``, then ``min_samples_leaf_splitting`` represents the minimum number. - * If ``float``, then ``min_samples_leaf`` represents a fraction and - ``ceil(min_samples_leaf * n_rows)`` is the minimum number of + * If ``float``, then ``min_samples_leaf_splitting`` represents a fraction and + ``ceil(min_samples_leaf_splitting * n_rows)`` is the minimum number of samples for each leaf node. - min_samples_split : int or float (default = 2) - The minimum number of samples required to split an internal node.\n - * If type ``int``, then ``min_samples_split`` represents the minimum + min_samples_leaf_averaging : int or float (default = 2) + The minimum number of averaging samples (rows, oob_honesty) in each leaf node.\n + * If type ``int``, then ``min_samples_leaf_averaging`` represents the minimum number. - * If type ``float``, then ``min_samples_split`` represents a fraction - and ``ceil(min_samples_split * n_rows)`` is the minimum number of - samples for each split. 
+ * If ``float``, then ``min_samples_leaf_averaging`` represents a fraction and + ``ceil(min_samples_leaf_averaging * n_rows)`` is the minimum number of + samples for each leaf node. + min_samples_split_splitting : int or float (default = 2) + The minimum number of trainingsamples required to split an internal node.\n + * If type ``int``, then min_samples_split_splitting represents the minimum + number. + * If type ``float``, then ``min_samples_split_splitting`` represents a fraction + and ``max(2, ceil(min_samples_split_splitting * n_rows))`` is the minimum + number of samples for each split. + min_samples_split_averaging : int or float (default = 2) + The minimum number of averaging samples (oob_honesty) required to split an internal node.\n + * If type ``int``, then min_samples_split_splitting represents the minimum + number. + * If type ``float``, then ``min_samples_split_splitting`` represents a fraction + and ``max(2, ceil(min_samples_split_splitting * n_rows))`` is the minimum + number of samples for each split. accuracy_metric : string (default = 'r2') Decides the metric used to evaluate the performance of the model. 
In the 0.16 release, the default scoring metric was changed diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx index 77fe304d92..88f559b097 100644 --- a/python/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/ensemble/randomforest_common.pyx @@ -71,8 +71,9 @@ class BaseRandomForestModel(Base): def __init__(self, *, split_criterion, n_streams=4, n_estimators=100, max_depth=16, handle=None, max_features='auto', n_bins=128, - bootstrap=True, - verbose=False, min_samples_leaf=1, min_samples_split=2, + bootstrap=True, oob_honesty=False, double_bootstrap=True, + verbose=False, min_samples_leaf_splitting=1, min_samples_leaf_averaging=2, + min_samples_split_splitting=2, min_samples_split_averaging=2, max_samples=1.0, max_leaves=-1, accuracy_metric=None, dtype=None, output_type=None, min_weight_fraction_leaf=None, n_jobs=None, max_leaf_nodes=None, min_impurity_decrease=0.0, @@ -135,8 +136,10 @@ class BaseRandomForestModel(Base): self.split_criterion = \ BaseRandomForestModel.criterion_dict[str(split_criterion)] - self.min_samples_leaf = min_samples_leaf - self.min_samples_split = min_samples_split + self.min_samples_leaf_averaging = min_samples_leaf_averaging + self.min_samples_split_averaging = min_samples_split_averaging + self.min_samples_leaf_splitting = min_samples_leaf_splitting + self.min_samples_split_splitting = min_samples_split_splitting self.min_impurity_decrease = min_impurity_decrease self.max_samples = max_samples self.max_leaves = max_leaves @@ -144,6 +147,8 @@ class BaseRandomForestModel(Base): self.max_depth = max_depth self.max_features = max_features self.bootstrap = bootstrap + self.oob_honesty = oob_honesty + self.double_bootstrap = double_bootstrap self.n_bins = n_bins self.n_cols = None self.dtype = dtype @@ -304,12 +309,18 @@ class BaseRandomForestModel(Base): "to fit the estimator") max_feature_val = self._get_max_feat_val() - if type(self.min_samples_leaf) == float: - 
self.min_samples_leaf = \ - math.ceil(self.min_samples_leaf * self.n_rows) - if type(self.min_samples_split) == float: - self.min_samples_split = \ - max(2, math.ceil(self.min_samples_split * self.n_rows)) + if type(self.min_samples_leaf_splitting) == float: + self.min_samples_leaf_splitting = \ + math.ceil(self.min_samples_leaf_splitting * self.n_rows) + if type(self.min_samples_split_splitting) == float: + self.min_samples_split_splitting = \ + max(2, math.ceil(self.min_samples_split_splitting * self.n_rows)) + if type(self.min_samples_leaf_averaging) == float: + self.min_samples_leaf_averaging = \ + math.ceil(self.min_samples_leaf_averaging * self.n_rows) + if type(self.min_samples_split_averaging) == float: + self.min_samples_split_averaging = \ + max(2, math.ceil(self.min_samples_split_averaging * self.n_rows)) return X_m, y_m, max_feature_val def _tl_handle_from_bytes(self, treelite_serialized_model): diff --git a/python/cuml/ensemble/randomforest_shared.pxd b/python/cuml/ensemble/randomforest_shared.pxd index bd4e8ca0b0..0e8d80d563 100644 --- a/python/cuml/ensemble/randomforest_shared.pxd +++ b/python/cuml/ensemble/randomforest_shared.pxd @@ -99,8 +99,12 @@ cdef extern from "cuml/ensemble/randomforest.hpp" namespace "ML": int, int, int, + int, + int, float, bool, + bool, + bool, int, float, uint64_t, diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index 2afdfffbc6..4c22cee406 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -172,6 +172,14 @@ class RandomForestClassifier(BaseRandomForestModel, * If ``True``, eachtree in the forest is built on a bootstrapped sample with replacement. * If ``False``, the whole dataset is used to build each tree. 
+ oob_honesty : boolean (default = True) + Control oob_honesty.\n + * If ``True``, each tree in the forest is built using disjoint sets for splitting and averaging + * If ``False``, the whole dataset is used to build each tree. + double_bootstrap : boolean (default = True) + Control bootstrapping in the averaging set. Only applies if oob_honesty is true.\n + * If ``True``, each tree uses an averaging set which is sampled with replacement from the samples not used for splitting + * If ``False``, each tree uses an averaging set which is the set difference of all samples presented to the tree and the splitting set max_samples : float (default = 1.0) Ratio of dataset rows used while fitting each tree. max_depth : int (default = 16) @@ -198,19 +206,33 @@ class RandomForestClassifier(BaseRandomForestModel, increasing the number of bins may improve accuracy. n_streams : int (default = 4) Number of parallel streams used for forest building. - min_samples_leaf : int or float (default = 1) - The minimum number of samples (rows) in each leaf node.\n - * If type ``int``, then ``min_samples_leaf`` represents the minimum + min_samples_leaf_splitting : int or float (default = 1) + The minimum number of training samples (rows) in each leaf node.\n + * If type ``int``, then ``min_samples_leaf_splitting`` represents the minimum + number. + * If ``float``, then ``min_samples_leaf_splitting`` represents a fraction and + ``ceil(min_samples_leaf_splitting * n_rows)`` is the minimum number of + samples for each leaf node. + min_samples_leaf_averaging : int or float (default = 0) + The minimum number of averaging samples (rows, oob_honesty) in each leaf node.\n + * If type ``int``, then ``min_samples_leaf_averaging`` represents the minimum number.
- * If ``float``, then ``min_samples_leaf`` represents a fraction and - ``ceil(min_samples_leaf * n_rows)`` is the minimum number of + * If ``float``, then ``min_samples_leaf_averaging`` represents a fraction and + ``ceil(min_samples_leaf_averaging * n_rows)`` is the minimum number of samples for each leaf node. - min_samples_split : int or float (default = 2) - The minimum number of samples required to split an internal node.\n - * If type ``int``, then min_samples_split represents the minimum + min_samples_split_splitting : int or float (default = 2) + The minimum number of trainingsamples required to split an internal node.\n + * If type ``int``, then min_samples_split_splitting represents the minimum number. - * If type ``float``, then ``min_samples_split`` represents a fraction - and ``max(2, ceil(min_samples_split * n_rows))`` is the minimum + * If type ``float``, then ``min_samples_split_splitting`` represents a fraction + and ``max(2, ceil(min_samples_split_splitting * n_rows))`` is the minimum + number of samples for each split. + min_samples_split_averaging : int or float (default = 0) + The minimum number of averaging samples (oob_honesty) required to split an internal node.\n + * If type ``int``, then min_samples_split_splitting represents the minimum + number. + * If type ``float``, then ``min_samples_split_splitting`` represents a fraction + and ``max(2, ceil(min_samples_split_splitting * n_rows))`` is the minimum number of samples for each split. 
min_impurity_decrease : float (default = 0.0) Minimum decrease in impurity required for @@ -257,7 +279,6 @@ class RandomForestClassifier(BaseRandomForestModel, def __init__(self, *, split_criterion=0, handle=None, verbose=False, output_type=None, **kwargs): - self.RF_type = CLASSIFICATION self.num_classes = 2 super().__init__( @@ -467,10 +488,14 @@ class RandomForestClassifier(BaseRandomForestModel, self.max_leaves, max_feature_val, self.n_bins, - self.min_samples_leaf, - self.min_samples_split, + self.min_samples_leaf_splitting, + self.min_samples_leaf_averaging, + self.min_samples_split_splitting, + self.min_samples_split_averaging, self.min_impurity_decrease, self.bootstrap, + self.oob_honesty, + self.double_bootstrap, self.n_estimators, self.max_samples, seed_val, diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index e88c1f7325..5a79b92d9b 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -168,6 +168,14 @@ class RandomForestRegressor(BaseRandomForestModel, * If ``True``, eachtree in the forest is built on a bootstrapped sample with replacement. * If ``False``, the whole dataset is used to build each tree. + oob_honesty : boolean (default = True) + Control oob_honesty.\n + * If ``True``, eachtree in the forest is built using disjoint sets for splitting and averaging + * If ``False``, the whole dataset is used to build each tree. + double_bootstrap : boolean (default = True) + Control bootstrapping in the averaging set. Only applies if oob_honesty is truer.\n + * If ``True``, each tree uses an averaging set which is sampled with replacement from the samples not used for splitting + * If ``False``, each tree uses an averaging set which is the set difference of all samples presented to the tree and the splitting set max_samples : float (default = 1.0) Ratio of dataset rows used while fitting each tree. 
max_depth : int (default = 16) @@ -194,20 +202,33 @@ class RandomForestRegressor(BaseRandomForestModel, increasing the number of bins may improve accuracy. n_streams : int (default = 4 ) Number of parallel streams used for forest building - min_samples_leaf : int or float (default = 1) - The minimum number of samples (rows) in each leaf node.\n - * If type ``int``, then ``min_samples_leaf`` represents the minimum - number.\n - * If ``float``, then ``min_samples_leaf`` represents a fraction and - ``ceil(min_samples_leaf * n_rows)`` is the minimum number of + min_samples_leaf_splitting : int or float (default = 1) + The minimum number of training samples (rows) in each leaf node.\n + * If type ``int``, then ``min_samples_leaf_splitting`` represents the minimum + number. + * If ``float``, then ``min_samples_leaf_splitting`` represents a fraction and + ``ceil(min_samples_leaf_splitting * n_rows)`` is the minimum number of + samples for each leaf node. + min_samples_leaf_averaging : int or float (default = 2) + The minimum number of averaging samples (rows, oob_honesty) in each leaf node.\n + * If type ``int``, then ``min_samples_leaf_averaging`` represents the minimum + number. + * If ``float``, then ``min_samples_leaf_averaging`` represents a fraction and + ``ceil(min_samples_leaf_averaging * n_rows)`` is the minimum number of samples for each leaf node. - min_samples_split : int or float (default = 2) - The minimum number of samples required to split an internal - node.\n - * If type ``int``, then min_samples_split represents the minimum + min_samples_split_splitting : int or float (default = 2) + The minimum number of training samples required to split an internal node.\n + * If type ``int``, then min_samples_split_splitting represents the minimum + number. + * If type ``float``, then ``min_samples_split_splitting`` represents a fraction + and ``max(2, ceil(min_samples_split_splitting * n_rows))`` is the minimum + number of samples for each split. 
+ min_samples_split_averaging : int or float (default = 2) + The minimum number of averaging samples (oob_honesty) required to split an internal node.\n + * If type ``int``, then min_samples_split_averaging represents the minimum number. - * If type ``float``, then ``min_samples_split`` represents a fraction - and ``max(2, ceil(min_samples_split * n_rows))`` is the minimum + * If type ``float``, then ``min_samples_split_averaging`` represents a fraction + and ``max(2, ceil(min_samples_split_averaging * n_rows))`` is the minimum number of samples for each split. min_impurity_decrease : float (default = 0.0) The minimum decrease in impurity required for node to be split @@ -452,10 +473,14 @@ class RandomForestRegressor(BaseRandomForestModel, self.max_leaves, max_feature_val, self.n_bins, - self.min_samples_leaf, - self.min_samples_split, + self.min_samples_leaf_splitting, + self.min_samples_leaf_averaging, + self.min_samples_split_splitting, + self.min_samples_split_averaging, self.min_impurity_decrease, self.bootstrap, + self.oob_honesty, + self.double_bootstrap, self.n_estimators, self.max_samples, seed_val, diff --git a/python/cuml/tests/test_random_forest.py b/python/cuml/tests/test_random_forest.py index 591dc515ac..0345ebff98 100644 --- a/python/cuml/tests/test_random_forest.py +++ b/python/cuml/tests/test_random_forest.py @@ -295,7 +295,7 @@ def test_rf_classification(small_clf, datatype, max_samples, max_features): max_samples=max_samples, n_bins=16, split_criterion=0, - min_samples_leaf=2, + min_samples_leaf_splitting=2, random_state=123, n_streams=1, n_estimators=40, @@ -356,7 +356,7 @@ def test_rf_classification_unorder( max_samples=max_samples, n_bins=16, split_criterion=0, - min_samples_leaf=2, + min_samples_leaf_splitting=2, random_state=123, n_streams=1, n_estimators=40, @@ -428,7 +428,7 @@ def test_rf_regression( max_samples=max_samples, n_bins=n_bins, split_criterion=2, - min_samples_leaf=2, + min_samples_leaf_splitting=2, random_state=123, 
n_streams=1, n_estimators=50, @@ -620,7 +620,7 @@ def rf_classification( max_samples=max_samples, n_bins=16, split_criterion=0, - min_samples_leaf=2, + min_samples_leaf_splitting=2, random_state=999, n_estimators=40, handle=handle, @@ -716,7 +716,7 @@ def test_rf_classification_sparse( cuml_model = curfc( n_bins=16, split_criterion=0, - min_samples_leaf=2, + min_samples_leaf_splitting=2, random_state=123, n_streams=1, n_estimators=num_treees, @@ -802,7 +802,7 @@ def test_rf_regression_sparse(special_reg, datatype, fil_sparse_format, algo): cuml_model = curfr( n_bins=16, split_criterion=2, - min_samples_leaf=2, + min_samples_leaf_splitting=2, random_state=123, n_streams=1, n_estimators=num_treees, @@ -1026,7 +1026,7 @@ def test_rf_get_text(n_estimators, detailed_text): max_samples=1.0, n_bins=16, split_criterion=0, - min_samples_leaf=2, + min_samples_leaf_splitting=2, random_state=23707, n_streams=1, n_estimators=n_estimators, @@ -1074,7 +1074,7 @@ def test_rf_get_json(estimator_type, max_depth, n_estimators): max_samples=1.0, n_bins=16, split_criterion=0, - min_samples_leaf=2, + min_samples_leaf_splitting=2, random_state=23707, n_streams=1, n_estimators=n_estimators, @@ -1087,7 +1087,7 @@ def test_rf_get_json(estimator_type, max_depth, n_estimators): max_features=1.0, max_samples=1.0, n_bins=16, - min_samples_leaf=2, + min_samples_leaf_splitting=2, random_state=23707, n_streams=1, n_estimators=n_estimators, diff --git a/quick_build.sh b/quick_build.sh new file mode 100755 index 0000000000..ec5c383383 --- /dev/null +++ b/quick_build.sh @@ -0,0 +1,4 @@ +cd cpp/build +ninja -j 12 +cp libcuml*.so /home/scratch.eschmidt_sw/miniconda3/envs/all_cuda-118_arch-x86_64/lib +cd ../.. \ No newline at end of file