diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py
index 20fd32a801..4914844555 100644
--- a/pymc3/distributions/bart.py
+++ b/pymc3/distributions/bart.py
@@ -14,6 +14,8 @@
 
 import numpy as np
 
+from pandas import DataFrame, Series
+
 from pymc3.distributions.distribution import NoDistribution
 from pymc3.distributions.tree import LeafNode, SplitNode, Tree
 
@@ -21,9 +23,10 @@ class BaseBART(NoDistribution):
-    def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
-        self.X = X
-        self.Y = Y
+    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):
+
+        self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)
+
         super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs)
 
         if self.X.ndim != 2:
@@ -48,12 +51,24 @@ def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
         self.num_observations = X.shape[0]
         self.num_variates = X.shape[1]
+        self.available_predictors = list(range(self.num_variates))
+        self.ssv = SampleSplittingVariable(split_prior, self.num_variates)
         self.m = m
         self.alpha = alpha
         self.trees = self.init_list_of_trees()
+        self.all_trees = []
         self.mean = fast_mean()
         self.prior_prob_leaf_node = compute_prior_probability(alpha)
 
+    def preprocess_XY(self, X, Y):
+        if isinstance(Y, (Series, DataFrame)):
+            Y = Y.to_numpy()
+        if isinstance(X, (Series, DataFrame)):
+            X = X.to_numpy()
+        missing_data = np.any(np.isnan(X))
+        X = np.random.normal(X, np.std(X, 0) / 100)  # jitter X so duplicated values become distinct
+        return X, Y, missing_data
+
     def init_list_of_trees(self):
         initial_value_leaf_nodes = self.Y.mean() / self.m
         initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
@@ -79,39 +94,26 @@ def __iter__(self):
     def __repr_latex(self):
         raise NotImplementedError
 
-    def get_available_predictors(self, idx_data_points_split_node):
-        possible_splitting_variables = []
-        for j in range(self.num_variates):
-            x_j = self.X[idx_data_points_split_node, j]
-            x_j = x_j[~np.isnan(x_j)]
-            for i in range(1, len(x_j)):
-                if x_j[i - 1] != x_j[i]:
-                    possible_splitting_variables.append(j)
-                    break
-        return possible_splitting_variables
-
     def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
         x_j = self.X[idx_data_points_split_node, idx_split_variable]
-        x_j = x_j[~np.isnan(x_j)]
-        values, indices = np.unique(x_j, return_index=True)
-        # The last value is not consider since if we choose it as the value of
-        # the splitting rule assignment, it would leave the right subtree empty.
-        return values[:-1], indices[:-1]
+        if self.missing_data:
+            x_j = x_j[~np.isnan(x_j)]
+        values = np.unique(x_j)
+        # The last value is never available as it would leave the right subtree empty.
+        return values[:-1]
 
     def grow_tree(self, tree, index_leaf_node):
-        # This can be unsuccessful when there are not available predictors
         current_node = tree.get_node(index_leaf_node)
 
-        available_predictors = self.get_available_predictors(current_node.idx_data_points)
-
-        if not available_predictors:
-            return False, None
-
-        index_selected_predictor = discrete_uniform_sampler(len(available_predictors))
-        selected_predictor = available_predictors[index_selected_predictor]
-        available_splitting_rules, _ = self.get_available_splitting_rules(
+        index_selected_predictor = self.ssv.rvs()
+        selected_predictor = self.available_predictors[index_selected_predictor]
+        available_splitting_rules = self.get_available_splitting_rules(
             current_node.idx_data_points, selected_predictor
         )
+        # This can be unsuccessful when there are no available splitting rules
+        if available_splitting_rules.size == 0:
+            return False, None
+
         index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
         selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
         new_split_node = SplitNode(
@@ -167,6 +169,18 @@ def draw_leaf_value(self, idx_data_points):
         draw = self.mean(R_j)
         return draw
 
+    def predict(self, X_new):
+        """Compute out-of-sample predictions evaluated at X_new."""
+        trees = self.all_trees
+        num_observations = X_new.shape[0]
+        pred = np.zeros((len(trees), num_observations))
+        for draw, trees_to_sum in enumerate(trees):
+            new_Y = np.zeros(num_observations)
+            for tree in trees_to_sum:
+                new_Y += [tree.predict_out_of_sample(x) for x in X_new]
+            pred[draw] = new_Y
+        return pred
+
 
 def compute_prior_probability(alpha):
     """
@@ -217,6 +232,31 @@ def discrete_uniform_sampler(upper_value):
     return int(np.random.random() * upper_value)
 
 
+class SampleSplittingVariable:
+    def __init__(self, prior, num_variates):
+        self.prior = prior
+        self.num_variates = num_variates
+
+        if self.prior is not None:
+            self.prior = np.asarray(self.prior)
+            self.prior = self.prior / self.prior.sum()
+            if self.prior.size != self.num_variates:
+                raise ValueError(
+                    f"The size of split_prior ({self.prior.size}) should be the "
+                    f"same as the number of covariates ({self.num_variates})"
+                )
+            self.enu = list(enumerate(np.cumsum(self.prior)))  # CDF over predictor indices
+
+    def rvs(self):
+        if self.prior is None:
+            return int(np.random.random() * self.num_variates)
+        else:
+            r = np.random.random()  # inverse-CDF sampling of a predictor index
+            for i, v in self.enu:
+                if r <= v:
+                    return i
+
+
 class BART(BaseBART):
     """
     BART distribution.
@@ -225,19 +265,23 @@
 
     Parameters
     ----------
-    X :
+    X : array-like
         The design matrix.
-    Y :
+    Y : array-like
         The response vector.
     m : int
         Number of trees
     alpha : float
         Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
         although it is recommended to be in the interval (0, 0.5].
+    split_prior : array-like
+        Each element of split_prior should be in the [0, 1] interval and the elements should sum
+        to 1; otherwise they will be normalized.
+        Defaults to None, in which case all variables have the same prior probability.
     """
 
-    def __init__(self, X, Y, m=200, alpha=0.25):
-        super().__init__(X, Y, m, alpha)
+    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
+        super().__init__(X, Y, m, alpha, split_prior)
 
     def _str_repr(self, name=None, dist=None, formatting="plain"):
         if dist is None:
diff --git a/pymc3/distributions/tree.py b/pymc3/distributions/tree.py
index 81c727232a..8e84bd9a7c 100644
--- a/pymc3/distributions/tree.py
+++ b/pymc3/distributions/tree.py
@@ -84,6 +84,22 @@ def predict_output(self, num_observations):
             output[current_node.idx_data_points] = current_node.value
         return output
 
+    def predict_out_of_sample(self, x):
+        """
+        Predict the output of the tree for an unobserved point x.
+
+        Parameters
+        ----------
+        x : numpy array
+
+        Returns
+        -------
+        float
+            Value of the leaf node where the unobserved point lies.
+        """
+        leaf_node = self._traverse_tree(x=x, node_index=0)
+        return leaf_node.value
+
     def _traverse_tree(self, x, node_index=0):
         """
         Traverse the tree starting from a particular node given an unobserved point.
@@ -99,15 +115,13 @@ def _traverse_tree(self, x, node_index=0):
         """
         current_node = self.get_node(node_index)
         if isinstance(current_node, SplitNode):
-            if x is not np.NaN:
+            if x[current_node.idx_split_variable] <= current_node.split_value:
                 left_child = current_node.get_idx_left_child()
-                final_node = self._traverse_tree(x, left_child)
+                current_node = self._traverse_tree(x, left_child)
             else:
                 right_child = current_node.get_idx_right_child()
-                final_node = self._traverse_tree(x, right_child)
-        else:
-            final_node = current_node
-        return final_node
+                current_node = self._traverse_tree(x, right_child)
+        return current_node
 
     def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):
         """
diff --git a/pymc3/step_methods/pgbart.py b/pymc3/step_methods/pgbart.py
index ca0c61aa25..c3bac3ade9 100644
--- a/pymc3/step_methods/pgbart.py
+++ b/pymc3/step_methods/pgbart.py
@@ -64,8 +64,13 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
         self.tune = True
         self.idx = 0
+        self.iter = 0
+        self.sum_trees = []
+        self.chunk = chunk
+
         if chunk == "auto":
             self.chunk = max(1, int(self.bart.m * 0.1))
+        self.bart.chunk = self.chunk
         self.num_particles = num_particles
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
@@ -96,14 +101,14 @@ def astep(self, _):
             self.idx = 0
 
         for idx in range(self.idx, self.idx + self.chunk):
-            if idx > bart.m:
+            if idx >= bart.m:
                 break
             self.idx += 1
             tree = bart.trees[idx]
             R_j = bart.get_residuals_loo(tree)
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
-            particles = self.init_particles(tree.tree_id, R_j, bart.num_observations)
+            particles = self.init_particles(tree.tree_id, R_j, num_observations)
 
             for t in range(1, max_stages):
                 # Get old particle at stage t
@@ -147,6 +152,11 @@ def astep(self, _):
             bart.sum_trees_output = bart.Y - R_j + new_prediction
 
             if not self.tune:
+                self.iter += 1
+                self.sum_trees.append(new_tree.tree)
+                if not self.iter % bart.m:  # a full sweep over the m trees is complete
+                    bart.all_trees.append(self.sum_trees)
+                    self.sum_trees = []
                 for index in new_tree.used_variates:
                     variable_inclusion[index] += 1
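
Usage note (not part of the patch): a minimal sketch of how the two user-facing additions, the `split_prior` argument and `BaseBART.predict`, fit together. The model, data, and names below are hypothetical; `predict` is reached through the random variable's `.distribution` attribute, and `all_trees` is only populated by PGBART once tuning has finished.

```python
import numpy as np
import pymc3 as pm

# Toy data: Y depends mostly on the first two covariates (hypothetical example).
X = np.random.normal(size=(100, 3))
Y = X[:, 0] - 2 * X[:, 1] + np.random.normal(scale=0.5, size=100)

with pm.Model() as model:
    # Weight the splitting variables 2:1:1; SampleSplittingVariable
    # normalizes the weights, so they need not sum to 1.
    mu = pm.BART("mu", X, Y, m=50, split_prior=[2, 1, 1])
    sigma = pm.HalfNormal("sigma", 1)
    y = pm.Normal("y", mu, sigma, observed=Y)
    trace = pm.sample()

# Out-of-sample prediction: one row per stored draw of the whole ensemble.
X_new = np.random.normal(size=(10, 3))
pred = mu.distribution.predict(X_new)  # shape (len(all_trees), 10)

# The splitting-variable sampler can also be checked on its own; rvs() draws
# a predictor index by inverting the cumulative prior.
from pymc3.distributions.bart import SampleSplittingVariable

ssv = SampleSplittingVariable([2, 1, 1], 3)
counts = np.bincount([ssv.rvs() for _ in range(4000)], minlength=3)
print(counts / counts.sum())  # roughly [0.5, 0.25, 0.25]
```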