diff --git a/.idea/deployment.xml b/.idea/deployment.xml
deleted file mode 100644
index 9aeac29..0000000
--- a/.idea/deployment.xml
+++ /dev/null
@@ -1,21 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
index 4f208a9..09d1f7e 100644
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -1,18 +1,24 @@
+
+
+
-
-
-
-
-
+
diff --git a/.idea/markdown-navigator.xml b/.idea/markdown-navigator.xml
deleted file mode 100644
index 20f3a39..0000000
--- a/.idea/markdown-navigator.xml
+++ /dev/null
@@ -1,72 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/markdown-navigator/profiles_settings.xml b/.idea/markdown-navigator/profiles_settings.xml
deleted file mode 100644
index 57927c5..0000000
--- a/.idea/markdown-navigator/profiles_settings.xml
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-
\ No newline at end of file
diff --git a/bartpy/__init__.py b/bartpy/__init__.py
index e69de29..90c624a 100644
--- a/bartpy/__init__.py
+++ b/bartpy/__init__.py
@@ -0,0 +1 @@
+from .sklearnmodel import BART, ShrunkBART, ShrunkBARTCV
\ No newline at end of file
diff --git a/bartpy/model.py b/bartpy/model.py
index b1f1742..7b57e88 100644
--- a/bartpy/model.py
+++ b/bartpy/model.py
@@ -17,12 +17,13 @@ class Model:
def __init__(self,
data: Optional[Data],
sigma: Sigma,
- trees: Optional[List[Tree]]=None,
- n_trees: int=50,
- alpha: float=0.95,
- beta: float=2.,
- k: int=2.,
- initializer: Initializer=SklearnTreeInitializer()):
+ trees: Optional[List[Tree]] = None,
+ n_trees: int = 50,
+ alpha: float = 0.95,
+ beta: float = 2.,
+ k: int = 2.,
+ initializer: Initializer = SklearnTreeInitializer(),
+ classification: bool = False):
self.data = deepcopy(data)
self.alpha = float(alpha)
@@ -31,6 +32,7 @@ def __init__(self,
self._sigma = sigma
self._prediction = None
self._initializer = initializer
+ self.classification = classification
if trees is None:
self.n_trees = n_trees
@@ -53,7 +55,7 @@ def residuals(self) -> np.ndarray:
def unnormalized_residuals(self) -> np.ndarray:
return self.data.y.unnormalized_y - self.data.y.unnormalize_y(self.predict())
- def predict(self, X: np.ndarray=None) -> np.ndarray:
+ def predict(self, X: np.ndarray = None) -> np.ndarray:
if X is not None:
return self._out_of_sample_predict(X)
return np.sum([tree.predict() for tree in self.trees], axis=0)
@@ -77,8 +79,20 @@ def refreshed_trees(self) -> Generator[Tree, None, None]:
yield tree
self._prediction += tree.predict()
+ def update_z_values(self, y):
+ if not self.classification:
+ return
+ z = np.random.normal(loc=self.predict(self.data.X.values))
+ one_label = np.maximum(z[y == 1], 0)
+ zero_label = np.minimum(z[y == 0], 0)
+ z[y == 1] = one_label
+ z[y == 0] = zero_label
+ self.data.update_y(z)
+
@property
def sigma_m(self) -> float:
+ if self.classification:
+ return 3 / (self.k * np.power(self.n_trees, 0.5))
return 0.5 / (self.k * np.power(self.n_trees, 0.5))
@property
diff --git a/bartpy/node.py b/bartpy/node.py
index 80bd841..253a70a 100644
--- a/bartpy/node.py
+++ b/bartpy/node.py
@@ -11,7 +11,8 @@ class TreeNode(object):
- Data relevant for the node
- Links to children nodes
"""
- def __init__(self, split: Split, depth: int, left_child: 'TreeNode'=None, right_child: 'TreeNode'=None):
+
+ def __init__(self, split: Split, depth: int, left_child: 'TreeNode' = None, right_child: 'TreeNode' = None):
self.depth = depth
self._split = split
self._left_child = left_child
@@ -55,6 +56,9 @@ def __init__(self, split: Split, depth=0, value=0.0):
def set_value(self, value: float) -> None:
self._value = value
+ def set_mean_response(self, value: float) -> None:
+ self._mean_response = value
+
@property
def current_value(self):
return self._value
@@ -65,6 +69,14 @@ def predict(self) -> float:
def is_splittable(self) -> bool:
return self.data.X.is_at_least_one_splittable_variable()
+ @property
+ def n_obs(self):
+ return self.data.X.n_obsv
+
+ @property
+ def mean_response(self):
+ return self._mean_response
+
class DecisionNode(TreeNode):
"""
@@ -81,6 +93,23 @@ def is_prunable(self) -> bool:
def most_recent_split_condition(self) -> SplitCondition:
return self.left_child.split.most_recent_split_condition()
+ @property
+ def n_obs(self):
+ n_l = self.left_child.n_obs
+ n_r = self.right_child.n_obs
+ return n_l + n_r
+
+
+ @property
+ def current_value(self):
+ n_l = self.left_child.n_obs
+ n_r = self.right_child.n_obs
+ l_val = self.left_child.current_value if type(self.left_child) == DecisionNode else self.left_child.mean_response
+ r_val = self.right_child.current_value if type(self.right_child) == DecisionNode else self.right_child.mean_response
+ l_sum = l_val * n_l
+ r_sum = r_val * n_r
+ return (l_sum + r_sum) / self.n_obs
+
def split_node(node: LeafNode, split_conditions: Tuple[SplitCondition, SplitCondition]) -> DecisionNode:
"""
diff --git a/bartpy/samplers/leafnode.py b/bartpy/samplers/leafnode.py
index d70b702..c853453 100644
--- a/bartpy/samplers/leafnode.py
+++ b/bartpy/samplers/leafnode.py
@@ -28,10 +28,11 @@ def sample(self, model: Model, node: LeafNode) -> float:
n = node.data.X.n_obsv
likihood_var = (model.sigma.current_value() ** 2) / n
likihood_mean = node.data.y.summed_y() / n
+ node.set_mean_response(likihood_mean)
posterior_variance = 1. / (1. / prior_var + 1. / likihood_var)
posterior_mean = likihood_mean * (prior_var / (likihood_var + prior_var))
- return posterior_mean + (self._scalar_sampler.sample() * np.power(posterior_variance / model.n_trees, 0.5))
-
+ val = posterior_mean + (self._scalar_sampler.sample() * np.power(posterior_variance / model.n_trees, 0.5))
+ return val
# class VectorizedLeafNodeSampler(Sampler):
@@ -45,5 +46,3 @@ def sample(self, model: Model, node: LeafNode) -> float:
# prior_var = model.sigma_m ** 2
# n_s = []
# sum_s = []
-
-
diff --git a/bartpy/samplers/modelsampler.py b/bartpy/samplers/modelsampler.py
index 4d8b997..37d09a3 100644
--- a/bartpy/samplers/modelsampler.py
+++ b/bartpy/samplers/modelsampler.py
@@ -1,3 +1,4 @@
+import copy
from collections import defaultdict
from typing import List, Mapping, Union, Any, Type
@@ -16,9 +17,11 @@ class ModelSampler(Sampler):
def __init__(self,
schedule: SampleSchedule,
- trace_logger_class: Type[TraceLogger]=TraceLogger):
+                 trace_logger_class: Type[TraceLogger] = TraceLogger,
+                 n_rules: int = None):
self.schedule = schedule
self.trace_logger_class = trace_logger_class
+ self.n_rules = n_rules
def step(self, model: Model, trace_logger: TraceLogger):
step_result = defaultdict(list)
@@ -35,21 +38,27 @@ def samples(self, model: Model,
thin: float=0.1,
store_in_sample_predictions: bool=True,
store_acceptance: bool=True) -> Chain:
- print("Starting burn")
+ # print("Starting burn")
trace_logger = self.trace_logger_class()
+ y = copy.deepcopy(model.data.y.unnormalized_y)
- for _ in tqdm(range(n_burn)):
+ for _ in range(n_burn):
+ model.update_z_values(y)
self.step(model, trace_logger)
+
trace = []
model_trace = []
acceptance_trace = []
- print("Starting sampling")
+ # print("Starting sampling")
thin_inverse = 1. / thin
- for ss in tqdm(range(n_samples)):
+ for ss in range(n_samples):
+ model.update_z_values(y)
step_trace_dict = self.step(model, trace_logger)
+ # print(step_trace_dict)
+
if ss % thin_inverse == 0:
if store_in_sample_predictions:
in_sample_log = trace_logger["In Sample Prediction"](model.predict())
diff --git a/bartpy/sigma.py b/bartpy/sigma.py
index b289527..4a254df 100644
--- a/bartpy/sigma.py
+++ b/bartpy/sigma.py
@@ -1,5 +1,4 @@
-
class Sigma:
"""
A representation of the sigma term in the model.
@@ -18,16 +17,19 @@ class Sigma:
"""
- def __init__(self, alpha: float, beta: float, scaling_factor: float):
+    def __init__(self, alpha: float, beta: float, scaling_factor: float, classification: bool = False):
self.alpha = alpha
self.beta = beta
self._current_value = 1.0
self.scaling_factor = scaling_factor
+ self._classification = classification
def set_value(self, value: float) -> None:
self._current_value = value
def current_value(self) -> float:
+ if self._classification:
+ return 1
return self._current_value
def current_unnormalized_value(self) -> float:
diff --git a/bartpy/sklearnmodel.py b/bartpy/sklearnmodel.py
index 7e92b1a..3876401 100644
--- a/bartpy/sklearnmodel.py
+++ b/bartpy/sklearnmodel.py
@@ -1,15 +1,22 @@
+import copy
from copy import deepcopy
from typing import List, Callable, Mapping, Union, Optional
+import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
+import scipy.stats
from joblib import Parallel, delayed
from sklearn.base import RegressorMixin, BaseEstimator
+from sklearn.model_selection import cross_val_score
+from sklearn.tree import DecisionTreeClassifier
+from sklearn import datasets, model_selection
from bartpy.data import Data
from bartpy.initializers.initializer import Initializer
from bartpy.initializers.sklearntreeinitializer import SklearnTreeInitializer
from bartpy.model import Model
+from bartpy.node import LeafNode, DecisionNode
from bartpy.samplers.leafnode import LeafNodeSampler
from bartpy.samplers.modelsampler import ModelSampler, Chain
from bartpy.samplers.schedule import SampleSchedule
@@ -24,7 +31,17 @@ def run_chain(model: 'SklearnModel', X: np.ndarray, y: np.ndarray):
Run a single chain for a model
Primarily used as a building block for constructing a parallel run of multiple chains
"""
+
+ # TODO: support for classification F^{-1} (y) ~ N(G(x), 1)
+
+ # if model.classification:
+ # z = np.random.normal(loc=model.predict(X))
+ # z[y == 1] = np.maximum(z[y == 1], 0)
+ # z[y == 0] = np.minimum(z[y == 0], 0)
+ # y = z
+
model.model = model._construct_model(X, y)
+
return model.sampler.samples(model.model,
model.n_samples,
model.n_burn,
@@ -37,6 +54,92 @@ def delayed_run_chain():
return run_chain
+def get_nodes(root_node):
+ decision_nodes = []
+ leaf_nodes = []
+
+ def _add_nodes(n):
+ if type(n) == LeafNode:
+ leaf_nodes.append(n)
+ elif type(n) == DecisionNode:
+ decision_nodes.append(n)
+ if n.left_child:
+ _add_nodes(n.left_child)
+ if n.right_child:
+ _add_nodes(n.right_child)
+
+ _add_nodes(root_node)
+ return decision_nodes, leaf_nodes
+
+
+def get_root_node(tree):
+ for n in tree.decision_nodes:
+ if n.depth == 0:
+ return n
+
+
+def shrink_tree(tree, reg_param):
+ root = get_root_node(tree)
+ tree_d_node = shrink_node(root, reg_param)
+ d, l = get_nodes(tree_d_node)
+ tree._nodes = d + l
+ return tree
+
+
+def expand_node(node):
+ mask_int = 1 - node.data.mask.astype(int)
+ # y = node.data.y.values
+ val = np.sum(node.data.y.values
+ * mask_int) / np.sum(mask_int)
+ node.set_value(val)
+ return node
+
+
+def expand_tree(tree):
+ decision, leaves = get_nodes(get_root_node(tree))
+ leaves_new = [expand_node(n) for n in leaves]
+ tree._nodes = decision + leaves_new
+ return tree
+
+
+def shrink_node(node, reg_param, parent_val, parent_num, cum_sum, scheme, constant):
+ """Shrink the tree
+ """
+
+ # node.set_value(node.mean_response)
+
+ left = node.left_child
+ right = node.right_child
+ is_leaf = type(node) == LeafNode
+ # if self.prediction_task == 'regression':
+ val = node.current_value
+ is_root = parent_val is None and parent_num is None
+ n_samples = node.n_obs if (scheme != "leaf_based" or is_root) else parent_num
+
+ if is_root:
+ val_new = val
+
+ else:
+ reg_term = reg_param if scheme == "constant" else reg_param / parent_num
+
+ val_new = (val - parent_val) / (1 + reg_term)
+
+ cum_sum += val_new
+
+ if is_leaf:
+ if scheme == "leaf_based":
+ v = constant + (val - constant) / (1 + reg_param / node.n_obs)
+ node.set_value(v)
+ else:
+ node.set_value(cum_sum)
+
+ else:
+ shrink_node(left, reg_param, val, parent_num=n_samples, cum_sum=cum_sum, scheme=scheme, constant=constant)
+ shrink_node(right, reg_param, val, parent_num=n_samples, cum_sum=cum_sum, scheme=scheme, constant=constant)
+
+ return node
+
+
class SklearnModel(BaseEstimator, RegressorMixin):
"""
The main access point to building BART models in BartPy
@@ -94,11 +197,12 @@ def __init__(self,
thin: float = 0.1,
alpha: float = 0.95,
beta: float = 2.,
- store_in_sample_predictions: bool=False,
- store_acceptance_trace: bool=False,
- tree_sampler: TreeMutationSampler=get_tree_sampler(0.5, 0.5),
- initializer: Optional[Initializer]=None,
- n_jobs=-1):
+ store_in_sample_predictions: bool = False,
+ store_acceptance_trace: bool = False,
+ tree_sampler: TreeMutationSampler = get_tree_sampler(0.5, 0.5),
+ initializer: Optional[Initializer] = None,
+ n_jobs=-1,
+ classification: bool = False):
self.n_trees = n_trees
self.n_chains = n_chains
self.sigma_a = sigma_a
@@ -118,6 +222,7 @@ def __init__(self,
self.initializer = initializer
self.schedule = SampleSchedule(self.tree_sampler, LeafNodeSampler(), SigmaSampler())
self.sampler = ModelSampler(self.schedule)
+ self.classification = classification
self.sigma, self.data, self.model, self._prediction_samples, self._model_samples, self.extract = [None] * 6
@@ -140,10 +245,18 @@ def fit(self, X: Union[np.ndarray, pd.DataFrame], y: np.ndarray) -> 'SklearnMode
self.model = self._construct_model(X, y)
self.extract = Parallel(n_jobs=self.n_jobs)(self.f_delayed_chains(X, y))
self.combined_chains = self._combine_chains(self.extract)
- self._model_samples, self._prediction_samples = self.combined_chains["model"], self.combined_chains["in_sample_predictions"]
+ self._model_samples, self._prediction_samples = self.combined_chains["model"], self.combined_chains[
+ "in_sample_predictions"]
self._acceptance_trace = self.combined_chains["acceptance"]
+ self.fitted_ = True
return self
+ @property
+ def fitted(self):
+ if hasattr(self, "fitted_"):
+ return self.fitted_
+ return False
+
@staticmethod
def _combine_chains(extract: List[Chain]) -> Chain:
keys = list(extract[0].keys())
@@ -164,13 +277,14 @@ def _construct_model(self, X: np.ndarray, y: np.ndarray) -> Model:
if len(X) == 0 or X.shape[1] == 0:
raise ValueError("Empty covariate matrix passed")
self.data = self._convert_covariates_to_data(X, y)
- self.sigma = Sigma(self.sigma_a, self.sigma_b, self.data.y.normalizing_scale)
+ self.sigma = Sigma(self.sigma_a, self.sigma_b, self.data.y.normalizing_scale, self.classification)
self.model = Model(self.data,
self.sigma,
n_trees=self.n_trees,
alpha=self.alpha,
beta=self.beta,
- initializer=self.initializer)
+ initializer=self.initializer,
+ classification=self.classification)
return self.model
def f_delayed_chains(self, X: np.ndarray, y: np.ndarray):
@@ -205,7 +319,7 @@ def f_chains(self) -> List[Callable[[], Chain]]:
"""
return [delayed_run_chain() for _ in range(self.n_chains)]
- def predict(self, X: np.ndarray=None) -> np.ndarray:
+ def predict(self, X: np.ndarray = None) -> np.ndarray:
"""
Predict the target corresponding to the provided covariate matrix
If X is None, will predict based on training covariates
@@ -228,7 +342,14 @@ def predict(self, X: np.ndarray=None) -> np.ndarray:
raise ValueError(
"In sample predictions only possible if model.store_in_sample_predictions is `True`. Either set the parameter to True or pass a non-None X parameter")
else:
- return self._out_of_sample_predict(X)
+ predictions = self._out_of_sample_predict(X)
+ if self.classification:
+ return np.round(predictions, 0)
+ return predictions
+
+ def predict_proba(self, X: np.ndarray = None) -> np.ndarray:
+ return self._out_of_sample_predict(X)
+
def residuals(self, X=None, y=None) -> np.ndarray:
"""
@@ -288,7 +409,12 @@ def rmse(self, X, y) -> float:
return np.sqrt(np.sum(self.l2_error(X, y)))
def _out_of_sample_predict(self, X):
- return self.data.y.unnormalize_y(np.mean([x.predict(X) for x in self._model_samples], axis=0))
+ samples = self._model_samples
+ predictions_transformed = [x.predict(X) for x in samples]
+ predictions = self.data.y.unnormalize_y(np.mean(predictions_transformed, axis=0))
+ if self.classification:
+ predictions = scipy.stats.norm.cdf(predictions)
+ return predictions
def fit_predict(self, X, y):
self.fit(X, y)
@@ -364,3 +490,209 @@ def from_extract(self, extract: List[Chain], X: np.ndarray, y: np.ndarray) -> 'S
self._acceptance_trace = combined_chain["acceptance"]
new_model.data = self._convert_covariates_to_data(X, y)
return new_model
+
+
+class BART(SklearnModel):
+
+ @staticmethod
+ def _get_n_nodes(trees):
+ nodes = 0
+ for tree in trees:
+ nodes += len(tree.decision_nodes)
+ return nodes
+
+ @property
+ def sample_complexity(self):
+ # samples = self._model_samples
+ # trees = [s.trees for s in samples]
+ complexities = [self._get_n_nodes(t) for t in self.trees]
+ return np.sum(complexities)
+
+ @staticmethod
+ def sub_forest(trees, n_nodes):
+ nodes = 0
+ for i, tree in enumerate(trees):
+ nodes += len(tree.decision_nodes)
+ if nodes >= n_nodes:
+ return trees[0:i + 1]
+
+ @property
+ def trees(self):
+ trs = [s.trees for s in self._model_samples]
+ return trs
+
+ def update_complexity(self, i):
+ samples_complexity = [self._get_n_nodes(t) for t in self.trees]
+
+ # complexity_sum = 0
+ arg_sort_complexity = np.argsort(samples_complexity)
+ self._model_samples = self._model_samples[arg_sort_complexity[:i + 1]]
+
+ return self
+
+
+class ImputedBART(BaseEstimator):
+ def __init__(self, estimator_):
+ # super(ShrunkBARTRegressor, self).__init__()
+ self.estimator_ = estimator_
+
+ def predict(self, *args, **kwargs):
+ return self.estimator_.predict(*args, **kwargs)
+
+ def predict_proba(self, *args, **kwargs):
+ if hasattr(self.estimator_, 'predict_proba'):
+ return self.estimator_.predict_proba(*args, **kwargs)
+ else:
+ return NotImplemented
+
+ def score(self, *args, **kwargs):
+ if hasattr(self.estimator_, 'score'):
+ return self.estimator_.score(*args, **kwargs)
+ else:
+ return NotImplemented
+
+
+class ShrunkBART(ImputedBART):
+
+ def __init__(self, estimator_, reg_param, scheme):
+ super(ShrunkBART, self).__init__(estimator_)
+ self.reg_param = reg_param
+ self.scheme = scheme
+
+ def shrink_tree(self, tree):
+ root = get_root_node(tree)
+ tree_d_node = shrink_node(root, self.reg_param, parent_val=None, parent_num=None, cum_sum=0, scheme=self.scheme,
+ constant=np.mean(self.estimator_.data.y.values))
+ d, l = get_nodes(tree_d_node)
+ tree._nodes = d + l
+ return tree
+
+ def fit(self, *args, **kwargs):
+ if not self.estimator_.fitted:
+ self.estimator_.fit(*args, **kwargs)
+ samples = []
+ for s in self.estimator_.model_samples:
+ for i, tree in enumerate(s._trees):
+ s_tree = self.shrink_tree(expand_tree(copy.deepcopy(tree)))
+ s._trees[i] = s_tree
+ samples.append(s)
+ self.estimator_._model_samples = samples
+ self.fitted_ = True
+
+
+class ExpandedBART(ImputedBART):
+
+ def fit(self, *args, **kwargs):
+ if not self.estimator_.fitted:
+ self.estimator_.fit(*args, **kwargs)
+ samples = []
+ for s in self.estimator_.model_samples:
+ for i, tree in enumerate(s._trees):
+ s_tree = expand_tree(copy.deepcopy(tree))
+ s._trees[i] = s_tree
+ samples.append(s)
+ self.estimator_._model_samples = samples
+ self.fitted_ = True
+
+
+# class ExpandedBARTRegressor(ImputedBARTRegressor):
+#
+# def fit(self, *args, **kwargs):
+# if not self.estimator_.fitted:
+# self.estimator_.fit(*args, **kwargs)
+# samples = []
+# for s in self.estimator_.model_samples:
+# for i, tree in enumerate(s._trees):
+#             s_tree = expand_tree(copy.deepcopy(tree), args[1])
+# s._trees[i] = s_tree
+# samples.append(s)
+# self.estimator_._model_samples = samples
+# self.fitted_ = True
+
+
+class ShrunkBARTCV(ShrunkBART):
+ def __init__(self, estimator_: BaseEstimator, scheme: str,
+ reg_param_list: List[float] = [0.1, 1, 10, 50, 100, 500],
+ cv: int = 3, scoring=None):
+ super(ShrunkBARTCV, self).__init__(estimator_, None, scheme)
+ self.reg_param_list = np.array(reg_param_list)
+ self.cv = cv
+ self.scoring = scoring
+
+ def fit(self, X, y, *args, **kwargs):
+ self.scores_ = []
+ for reg_param in self.reg_param_list:
+ est = ShrunkBART(deepcopy(self.estimator_), reg_param, self.scheme)
+ cv_scores = cross_val_score(est, X, y, cv=self.cv, scoring=self.scoring)
+ self.scores_.append(np.mean(cv_scores))
+ self.reg_param = self.reg_param_list[np.argmax(self.scores_)]
+ super().fit(X=X, y=y)
+
+
+def main():
+ # iris = datasets.load_iris()
+ # idx = np.logical_or(iris.target == 0, iris.target == 1)
+ # X, y = iris.data[idx, ...], iris.target[idx]
+ X, y = datasets.load_diabetes(return_X_y=True)
+
+ X_train, X_test, y_train, y_test = model_selection.train_test_split(
+ X, y, test_size=0.3, random_state=1)
+ bart = BART(classification=False)
+ bart.fit(X_train, y_train)
+ preds_org = bart.predict(X_test)
+ mse = np.linalg.norm(preds_org - y_test)
+ print(mse)
+ # tree = DecisionTreeClassifier()
+ # tree.fit(X, y)
+ # preds_tree = tree.predict_proba(X)
+ # bart_s = ShrunkBARTCV(copy.deepcopy(bart), scheme="node_based")
+ # bart_s.fit(X, y)
+ #
+ # # bart_s_c = ShrunkBART(copy.deepcopy(bart), reg_param=2, scheme="constant")
+ # # bart_s_c.fit(X, y)
+ # #
+ # # bart_s_l = ShrunkBART(copy.deepcopy(bart), reg_param=2, scheme="leaf_based")
+ # # bart_s_l.fit(X, y)
+ # # bart_s_cv = ShrunkBARTRegressorCV(estimator_=copy.deepcopy(bart))
+ # # bart_s_cv.fit(X, y)
+ # e_bart = ExpandedBART(estimator_=copy.deepcopy(bart))
+ # e_bart.fit(X, y)
+ #
+ # preds = bart_s.predict(X)
+ #
+ # # preds_c = bart_s_c.predict(X)
+ # # preds_l = bart_s_l.predict(X)
+ # # # preds_cv = bart_s_cv.predict(X)
+ # preds_bart_e = e_bart.predict(X)
+ # fig, ax = plt.subplots(1)
+ #
+ # ax.scatter(np.arange(len(y)), preds_org, c="orange", label="bart")
+ # ax.scatter(np.arange(len(y)), preds, c="purple", alpha=0.3, label="shrunk node")
+ # # ax.scatter(np.arange(len(y)), preds_c, c="blue", alpha=0.3, label="shrunk constant")
+ # # ax.scatter(np.arange(len(y)), preds_l, c="red", alpha=0.3, label="shrunk leaf")
+ # ax.scatter(np.arange(len(y)), preds_bart_e, c="green", alpha=0.3, label="average")
+ # # preds_all = [preds_org, preds_c, preds_l, preds_bart_e, preds]
+ # # shift = 0.5
+ # # rng = (np.min([np.min(p) for p in preds_all]) - shift, np.max([np.max(p) for p in preds_all]) + shift)
+ # # n_bins = 200
+ # # alpha = 0.8
+ # # ax.hist(preds_org, color="orange", alpha=alpha, label="bart", bins=n_bins, range=rng)
+ # # ax.hist(preds, color="purple", alpha=alpha, label="shrunk node", bins=n_bins, range=rng)
+ # # ax.hist(preds_c, color="blue", alpha=alpha, label="shrunk constant", bins=n_bins, range=rng)
+ # # ax.hist(preds_l, color="red", alpha=alpha, label="shrunk leaf", bins=n_bins, range=rng)
+ # # ax.hist(preds_bart_e, color="green", alpha=alpha, label="average", bins=n_bins, range=rng)
+ # #
+ # # ax.set_xlabel("Predicted Value")
+ # # ax.set_ylabel("Count")
+ #
+ # plt.title(np.mean(y))
+ #
+ # plt.legend(loc="upper left")
+ # plt.savefig("bart_shrink.png")
+ # # plt.show()
+ # #
+ # # plt.close()
+ #
+
+if __name__ == '__main__':
+ main()