
Bart: Refactor splitting variables and predictions #4310

Merged
merged 13 commits on Jan 19, 2021
107 changes: 75 additions & 32 deletions pymc3/distributions/bart.py
@@ -14,16 +14,19 @@

import numpy as np

from pandas import DataFrame, Series

from .distribution import NoDistribution
from .tree import LeafNode, SplitNode, Tree

__all__ = ["BART"]


class BaseBART(NoDistribution):
def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
self.X = X
self.Y = Y
def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):

self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)

super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs)

if self.X.ndim != 2:
@@ -48,12 +51,24 @@ def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):

self.num_observations = X.shape[0]
self.num_variates = X.shape[1]
self.available_predictors = list(range(self.num_variates))
self.ssv = sample_splitting_variable(split_prior, self.num_variates)
self.m = m
self.alpha = alpha
self.trees = self.init_list_of_trees()
self.all_trees = []
self.mean = fast_mean()
self.prior_prob_leaf_node = compute_prior_probability(alpha)

def preprocess_XY(self, X, Y):
if isinstance(Y, (Series, DataFrame)):
Y = Y.values
if isinstance(X, (Series, DataFrame)):
X = X.values
missing_data = np.any(np.isnan(X))
X = np.random.normal(X, np.std(X, 0) / 100)
return X, Y, missing_data
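The jitter added in preprocess_XY makes tied covariate values distinct, so repeated values do not collapse into a single splitting candidate. A minimal sketch of the idea on toy data (the arrays here are illustrative, not part of the PR):

import numpy as np

X = np.array([[1.0, 0.5], [1.0, 0.6], [2.0, 0.7]])
missing_data = np.any(np.isnan(X))            # False for this toy matrix
X_jittered = np.random.normal(X, np.std(X, 0) / 100)
# The two tied values in the first column are now distinct,
# so np.unique will no longer merge them when proposing split points.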

def init_list_of_trees(self):
initial_value_leaf_nodes = self.Y.mean() / self.m
initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
Expand All @@ -79,39 +94,26 @@ def __iter__(self):
def __repr_latex(self):
raise NotImplementedError

def get_available_predictors(self, idx_data_points_split_node):
possible_splitting_variables = []
for j in range(self.num_variates):
x_j = self.X[idx_data_points_split_node, j]
x_j = x_j[~np.isnan(x_j)]
for i in range(1, len(x_j)):
if x_j[i - 1] != x_j[i]:
possible_splitting_variables.append(j)
break
return possible_splitting_variables

def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
x_j = self.X[idx_data_points_split_node, idx_split_variable]
x_j = x_j[~np.isnan(x_j)]
values, indices = np.unique(x_j, return_index=True)
# The last value is not considered, since choosing it as the
# splitting rule would leave the right subtree empty.
return values[:-1], indices[:-1]
if self.missing_data:
x_j = x_j[~np.isnan(x_j)]
values = np.unique(x_j)
# The last value is never available as it would leave the right subtree empty.
return values[:-1]
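With this refactor, the available splitting rules for a predictor are just the sorted unique values of that column, minus the largest one. A quick illustration with toy values (not from the PR):

import numpy as np

x_j = np.array([0.3, 0.1, 0.3, 0.2])
values = np.unique(x_j)   # array([0.1, 0.2, 0.3]): sorted, duplicates removed
rules = values[:-1]       # array([0.1, 0.2]); splitting at 0.3 would send
                          # every point left and leave the right subtree empty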

def grow_tree(self, tree, index_leaf_node):
# This can be unsuccessful when there are no available predictors
current_node = tree.get_node(index_leaf_node)

available_predictors = self.get_available_predictors(current_node.idx_data_points)

if not available_predictors:
return False, None

index_selected_predictor = discrete_uniform_sampler(len(available_predictors))
selected_predictor = available_predictors[index_selected_predictor]
available_splitting_rules, _ = self.get_available_splitting_rules(
index_selected_predictor = self.ssv.rvs()
selected_predictor = self.available_predictors[index_selected_predictor]
available_splitting_rules = self.get_available_splitting_rules(
current_node.idx_data_points, selected_predictor
)
# This can be unsuccessful when there are no available splitting rules
if available_splitting_rules.size == 0:
return False, None

index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
new_split_node = SplitNode(
@@ -167,6 +169,18 @@ def draw_leaf_value(self, idx_data_points):
draw = self.mean(R_j)
return draw

def predict(self, X_new):
"""Compute out of sample predictions evaluated at X_new"""
trees = self.all_trees
num_observations = X_new.shape[0]
pred = np.zeros((len(trees), num_observations))
for draw, trees_to_sum in enumerate(trees):
new_Y = np.zeros(X_new.shape[0])
for tree in trees_to_sum:
new_Y += [tree.predict_out_of_sample(x) for x in X_new]
pred[draw] = new_Y
return pred
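Each row of the returned array is one posterior draw of the ensemble evaluated at the new points. A hedged usage sketch, assuming a BaseBART instance named bart whose all_trees list has already been filled by the PG-BART sampler:

import numpy as np

X_new = np.random.normal(size=(5, bart.num_variates))  # 5 unseen points
pred = bart.predict(X_new)     # shape: (number of stored draws, 5)
point_estimate = pred.mean(0)  # average the ensemble sums over draws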


def compute_prior_probability(alpha):
"""
Expand Down Expand Up @@ -217,6 +231,31 @@ def discrete_uniform_sampler(upper_value):
return int(np.random.random() * upper_value)
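discrete_uniform_sampler draws an integer uniformly from {0, ..., upper_value - 1}; it is an inline equivalent of np.random.randint(0, upper_value). For instance:

import numpy as np

draw = int(np.random.random() * 4)  # uniform over 0, 1, 2, 3
assert 0 <= draw < 4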


class sample_splitting_variable:
def __init__(self, prior, num_variates):
self.prior = prior
self.num_variates = num_variates

if self.prior is not None:
self.prior = np.asarray(self.prior)
self.prior = self.prior / self.prior.sum()
if self.prior.size != self.num_variates:
raise ValueError(
f"The size of split_prior ({self.prior.size}) should be the "
f"same as the number of covariates ({self.num_variates})"
)
self.enu = list(enumerate(np.cumsum(self.prior)))

def rvs(self):
if self.prior is None:
return int(np.random.random() * self.num_variates)
else:
r = np.random.random()
for i, v in self.enu:
if r <= v:
return i
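sample_splitting_variable is the core of this refactor: with no prior it falls back to a uniform draw over the predictors, and with a prior it inverts the cumulative distribution precomputed in self.enu. A small check of the weighted path (toy prior, illustrative only):

import numpy as np

ssv = sample_splitting_variable(prior=[1, 1, 2], num_variates=3)
# prior normalizes to [0.25, 0.25, 0.5]; cumsum gives [0.25, 0.5, 1.0]
draws = [ssv.rvs() for _ in range(10_000)]
print(np.bincount(draws) / len(draws))  # roughly [0.25, 0.25, 0.5]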


class BART(BaseBART):
"""
BART distribution.
@@ -225,19 +264,23 @@ class BART(BaseBART):

Parameters
----------
X :
X : array-like
The design matrix.
Y :
Y : array-like
The response vector.
m : int
Number of trees
alpha : float
Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
although it is recommended to be in the interval (0, 0.5].
split_prior : array-like
Each element of split_prior should be in the [0, 1] interval and the elements should sum
to 1; otherwise they will be normalized.
Defaults to None, in which case all variables have the same prior probability.
"""

def __init__(self, X, Y, m=200, alpha=0.25):
super().__init__(X, Y, m, alpha)
def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
super().__init__(X, Y, m, alpha, split_prior)
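A hedged usage sketch, assuming a pymc3 build that includes this PR; the data arrays X and Y and the split_prior values are illustrative:

import pymc3 as pm

with pm.Model():
    mu = pm.BART("mu", X, Y, m=50, alpha=0.25, split_prior=[0.7, 0.3])
    sigma = pm.HalfNormal("sigma", 1)
    y = pm.Normal("y", mu, sigma, observed=Y)
    trace = pm.sample()  # the PG-BART step is assigned to mu automatically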

def _str_repr(self, name=None, dist=None, formatting="plain"):
if dist is None:
26 changes: 20 additions & 6 deletions pymc3/distributions/tree.py
@@ -84,6 +84,22 @@ def predict_output(self, num_observations):
output[current_node.idx_data_points] = current_node.value
return output

def predict_out_of_sample(self, x):
"""
Predict output of tree for an unobserved point x.

Parameters
----------
x : numpy array

Returns
-------
float
Value of the leaf node where the unobserved point lies.
"""
leaf_node = self._traverse_tree(x=x, node_index=0)
return leaf_node.value

def _traverse_tree(self, x, node_index=0):
"""
Traverse the tree starting from a particular node given an unobserved point.
@@ -99,15 +115,13 @@ def _traverse_tree(self, x, node_index=0):
"""
current_node = self.get_node(node_index)
if isinstance(current_node, SplitNode):
if x is not np.NaN:
if x[current_node.idx_split_variable] <= current_node.split_value:
left_child = current_node.get_idx_left_child()
final_node = self._traverse_tree(x, left_child)
current_node = self._traverse_tree(x, left_child)
else:
right_child = current_node.get_idx_right_child()
final_node = self._traverse_tree(x, right_child)
else:
final_node = current_node
return final_node
current_node = self._traverse_tree(x, right_child)
return current_node
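The refactored traversal recurses left or right on the selected split variable until it reaches a leaf, and predict_out_of_sample simply reads that leaf's value. An equivalent iterative walk, as a sketch using the node interface shown in this diff:

def traverse(tree, x):
    # Follow split decisions from the root until a leaf is reached.
    node = tree.get_node(0)
    while isinstance(node, SplitNode):
        if x[node.idx_split_variable] <= node.split_value:
            node = tree.get_node(node.get_idx_left_child())
        else:
            node = tree.get_node(node.get_idx_right_child())
    return node  # a LeafNode; node.value is the prediction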

def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):
"""
14 changes: 12 additions & 2 deletions pymc3/step_methods/pgbart.py
@@ -64,8 +64,13 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m

self.tune = True
self.idx = 0
self.iter = 0
self.sum_trees = []
self.chunk = chunk

if chunk == "auto":
self.chunk = max(1, int(self.bart.m * 0.1))
self.bart.chunk = self.chunk
self.num_particles = num_particles
self.log_num_particles = np.log(num_particles)
self.indices = list(range(1, num_particles))
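With chunk="auto", each PG-BART step now updates roughly ten percent of the m trees instead of all of them, with a floor of one tree. The arithmetic, spelled out:

m = 200
chunk = max(1, int(m * 0.1))        # 20 trees updated per step
chunk_small = max(1, int(5 * 0.1))  # int(0.5) = 0, clamped to 1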
@@ -96,14 +101,14 @@ def astep(self, _):
self.idx = 0

for idx in range(self.idx, self.idx + self.chunk):
if idx > bart.m:
if idx >= bart.m:
break
self.idx += 1
tree = bart.trees[idx]
R_j = bart.get_residuals_loo(tree)
# Generate an initial set of SMC particles
# at the end of the algorithm we return one of these particles as the new tree
particles = self.init_particles(tree.tree_id, R_j, bart.num_observations)
particles = self.init_particles(tree.tree_id, R_j, num_observations)

for t in range(1, max_stages):
# Get old particle at stage t
@@ -147,6 +152,11 @@ def astep(self, _):
bart.sum_trees_output = bart.Y - R_j + new_prediction

if not self.tune:
self.iter += 1
self.sum_trees.append(new_tree.tree)
if not self.iter % bart.m:
bart.all_trees.append(self.sum_trees)
self.sum_trees = []
for index in new_tree.used_variates:
variable_inclusion[index] += 1
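After tuning, each sampled tree is appended to sum_trees, and once m trees have accumulated (one full ensemble) the batch is flushed into bart.all_trees for later use by predict. Schematically, with a toy m:

m = 3
all_trees, sum_trees = [], []
for it in range(1, 7):      # six post-tuning tree draws
    sum_trees.append(f"tree_{it}")
    if not it % m:          # fires at it = 3 and it = 6
        all_trees.append(sum_trees)
        sum_trees = []
# all_trees now holds two complete ensembles of m trees each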
