Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions .idea/deployment.xml

This file was deleted.

24 changes: 15 additions & 9 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

72 changes: 0 additions & 72 deletions .idea/markdown-navigator.xml

This file was deleted.

3 changes: 0 additions & 3 deletions .idea/markdown-navigator/profiles_settings.xml

This file was deleted.

1 change: 1 addition & 0 deletions bartpy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .sklearnmodel import BART, ShrunkBART, ShrunkBARTCV
28 changes: 21 additions & 7 deletions bartpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ class Model:
def __init__(self,
data: Optional[Data],
sigma: Sigma,
trees: Optional[List[Tree]]=None,
n_trees: int=50,
alpha: float=0.95,
beta: float=2.,
k: int=2.,
initializer: Initializer=SklearnTreeInitializer()):
trees: Optional[List[Tree]] = None,
n_trees: int = 50,
alpha: float = 0.95,
beta: float = 2.,
k: int = 2.,
initializer: Initializer = SklearnTreeInitializer(),
classification: bool = False):

self.data = deepcopy(data)
self.alpha = float(alpha)
Expand All @@ -31,6 +32,7 @@ def __init__(self,
self._sigma = sigma
self._prediction = None
self._initializer = initializer
self.classification = classification

if trees is None:
self.n_trees = n_trees
Expand All @@ -53,7 +55,7 @@ def residuals(self) -> np.ndarray:
def unnormalized_residuals(self) -> np.ndarray:
return self.data.y.unnormalized_y - self.data.y.unnormalize_y(self.predict())

def predict(self, X: np.ndarray=None) -> np.ndarray:
def predict(self, X: Optional[np.ndarray] = None) -> np.ndarray:
    """Predict responses on the model's normalized scale.

    :param X: optional design matrix.  When given, an out-of-sample
        prediction is produced for it; when ``None``, the in-sample
        prediction is returned as the sum of every tree's current
        predicted values.
    :return: array of predictions
    """
    if X is not None:
        return self._out_of_sample_predict(X)
    return np.sum([tree.predict() for tree in self.trees], axis=0)
Expand All @@ -77,8 +79,20 @@ def refreshed_trees(self) -> Generator[Tree, None, None]:
yield tree
self._prediction += tree.predict()

def update_z_values(self, y):
    """Gibbs update of the latent probit variables for classification BART.

    Implements the Albert & Chib (1993) data-augmentation step: each latent
    z_i is drawn from a unit-variance normal centred on the model's current
    prediction, truncated to be positive where ``y == 1`` and negative where
    ``y == 0``.  Entries of ``y`` that are neither 0 nor 1 (if any) keep an
    unconstrained draw, matching the original behaviour.  No-op for
    regression models.

    Fix: the previous implementation clipped unconstrained draws at zero
    with ``np.maximum``/``np.minimum``, which places a point mass at exactly
    0 rather than sampling from the truncated normal the augmentation
    scheme requires.

    :param y: array of binary (0/1) observed labels.
    """
    if not self.classification:
        return
    # Local import: only the classification path needs scipy.
    from scipy.stats import truncnorm

    mean = self.predict(self.data.X.values)
    z = np.random.normal(loc=mean)
    pos = y == 1
    neg = y == 0
    # truncnorm's (a, b) bounds are expressed in standard deviations from
    # ``loc``; with scale=1 a truncation point at zero becomes -mean.
    if np.any(pos):
        z[pos] = truncnorm.rvs(a=-mean[pos], b=np.inf, loc=mean[pos], scale=1.)
    if np.any(neg):
        z[neg] = truncnorm.rvs(a=-np.inf, b=-mean[neg], loc=mean[neg], scale=1.)
    self.data.update_y(z)

@property
def sigma_m(self) -> float:
    """Prior standard deviation of a single leaf-node value.

    The spread shrinks as the number of trees grows so the ensemble total
    keeps a fixed prior spread; classification uses a wider numerator (3)
    than regression (0.5) — presumably the BART paper's k-scaling, TODO
    confirm against the reference.
    """
    numerator = 3. if self.classification else 0.5
    return numerator / (self.k * np.power(self.n_trees, 0.5))

@property
Expand Down
31 changes: 30 additions & 1 deletion bartpy/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class TreeNode(object):
- Data relevant for the node
- Links to children nodes
"""
def __init__(self, split: Split, depth: int, left_child: 'TreeNode'=None, right_child: 'TreeNode'=None):

def __init__(self, split: Split, depth: int, left_child: 'TreeNode' = None, right_child: 'TreeNode' = None):
self.depth = depth
self._split = split
self._left_child = left_child
Expand Down Expand Up @@ -55,6 +56,9 @@ def __init__(self, split: Split, depth=0, value=0.0):
def set_value(self, value: float) -> None:
    # Assign the sampled prediction value for this leaf (read back via
    # current_value / predict()).
    self._value = value

def set_mean_response(self, value: float) -> None:
    # Cache the raw likelihood mean of the responses in this leaf; set by
    # the leaf-node sampler and exposed through the mean_response property.
    self._mean_response = value

@property
def current_value(self):
    # The most recently sampled prediction value for this leaf.
    return self._value
Expand All @@ -65,6 +69,14 @@ def predict(self) -> float:
def is_splittable(self) -> bool:
return self.data.X.is_at_least_one_splittable_variable()

@property
def n_obs(self):
    # Number of observations held by this leaf's data split.
    return self.data.X.n_obsv

@property
def mean_response(self):
    # Raw mean response of this leaf, cached via set_mean_response().
    # NOTE(review): _mean_response may be unset before the sampler has run —
    # confirm it is initialised in __init__.
    return self._mean_response


class DecisionNode(TreeNode):
"""
Expand All @@ -81,6 +93,23 @@ def is_prunable(self) -> bool:
def most_recent_split_condition(self) -> SplitCondition:
return self.left_child.split.most_recent_split_condition()

@property
def n_obs(self):
    """Total number of observations in this subtree (left + right children)."""
    return self.left_child.n_obs + self.right_child.n_obs


@property
def current_value(self):
    """Observation-weighted mean response of this subtree.

    Recurses into decision-node children via ``current_value`` and reads the
    cached ``mean_response`` of leaf children, weighting each side by its
    observation count so the result equals the mean over all observations
    under this node.

    Fix: the original used ``type(x) == DecisionNode`` comparisons, which
    break for subclasses; ``isinstance`` is the idiomatic check.
    """
    def _subtree_value(child):
        # Decision nodes aggregate recursively; leaves expose a cached mean.
        if isinstance(child, DecisionNode):
            return child.current_value
        return child.mean_response

    left, right = self.left_child, self.right_child
    weighted_sum = _subtree_value(left) * left.n_obs + _subtree_value(right) * right.n_obs
    return weighted_sum / self.n_obs


def split_node(node: LeafNode, split_conditions: Tuple[SplitCondition, SplitCondition]) -> DecisionNode:
"""
Expand Down
7 changes: 3 additions & 4 deletions bartpy/samplers/leafnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ def sample(self, model: Model, node: LeafNode) -> float:
n = node.data.X.n_obsv
likihood_var = (model.sigma.current_value() ** 2) / n
likihood_mean = node.data.y.summed_y() / n
node.set_mean_response(likihood_mean)
posterior_variance = 1. / (1. / prior_var + 1. / likihood_var)
posterior_mean = likihood_mean * (prior_var / (likihood_var + prior_var))
return posterior_mean + (self._scalar_sampler.sample() * np.power(posterior_variance / model.n_trees, 0.5))

val = posterior_mean + (self._scalar_sampler.sample() * np.power(posterior_variance / model.n_trees, 0.5))
return val

# class VectorizedLeafNodeSampler(Sampler):

Expand All @@ -45,5 +46,3 @@ def sample(self, model: Model, node: LeafNode) -> float:
# prior_var = model.sigma_m ** 2
# n_s = []
# sum_s = []


19 changes: 14 additions & 5 deletions bartpy/samplers/modelsampler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from collections import defaultdict
from typing import List, Mapping, Union, Any, Type

Expand All @@ -16,9 +17,11 @@ class ModelSampler(Sampler):

def __init__(self,
schedule: SampleSchedule,
trace_logger_class: Type[TraceLogger]=TraceLogger):
trace_logger_class: Type[TraceLogger]=TraceLogger,
n_rules: int=None):
self.schedule = schedule
self.trace_logger_class = trace_logger_class
self.n_rules = n_rules

def step(self, model: Model, trace_logger: TraceLogger):
step_result = defaultdict(list)
Expand All @@ -35,21 +38,27 @@ def samples(self, model: Model,
thin: float=0.1,
store_in_sample_predictions: bool=True,
store_acceptance: bool=True) -> Chain:
print("Starting burn")
# print("Starting burn")

trace_logger = self.trace_logger_class()
y = copy.deepcopy(model.data.y.unnormalized_y)

for _ in tqdm(range(n_burn)):
for _ in range(n_burn):
model.update_z_values(y)
self.step(model, trace_logger)

trace = []
model_trace = []
acceptance_trace = []
print("Starting sampling")
# print("Starting sampling")

thin_inverse = 1. / thin

for ss in tqdm(range(n_samples)):
for ss in range(n_samples):
model.update_z_values(y)
step_trace_dict = self.step(model, trace_logger)
# print(step_trace_dict)

if ss % thin_inverse == 0:
if store_in_sample_predictions:
in_sample_log = trace_logger["In Sample Prediction"](model.predict())
Expand Down
6 changes: 4 additions & 2 deletions bartpy/sigma.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@


class Sigma:
"""
A representation of the sigma term in the model.
Expand All @@ -18,16 +17,19 @@ class Sigma:

"""

def __init__(self, alpha: float, beta: float, scaling_factor: float):
def __init__(self, alpha: float, beta: float, scaling_factor: float, classification: bool = False):
    """Container for the model's noise-level term sigma.

    :param alpha: prior parameter for sigma — presumably the shape of an
        inverse-gamma-style prior; TODO confirm against the sigma sampler
    :param beta: prior parameter for sigma (scale-style counterpart of alpha)
    :param scaling_factor: factor used to convert between the normalized and
        unnormalized response scales
    :param classification: when True, sigma is pinned at 1, as required by
        the latent-variable probit formulation

    Fix: the original annotated the last parameter as ``classification :bool``
    (space before the colon), which violates PEP 8 annotation formatting.
    """
    self.alpha = alpha
    self.beta = beta
    # Sampled value on the normalized scale; starts at 1.0 before sampling.
    self._current_value = 1.0
    self.scaling_factor = scaling_factor
    self._classification = classification

def set_value(self, value: float) -> None:
    # Record a freshly sampled sigma value (normalized scale; see
    # current_unnormalized_value for the conversion).
    self._current_value = value

def current_value(self) -> float:
    """Return the current sigma on the normalized scale.

    In the classification (probit) formulation the latent variables have
    unit variance by construction, so sigma is fixed at 1.0 regardless of
    any sampled value.

    Fix: return ``1.0`` rather than the int ``1`` so the value matches the
    declared ``-> float`` return type.
    """
    if self._classification:
        return 1.0
    return self._current_value

def current_unnormalized_value(self) -> float:
Expand Down
Loading