@@ -14,271 +14,127 @@
 
 import numpy as np
 
-from pandas import DataFrame, Series
+from aesara.tensor.random.op import RandomVariable, default_shape_from_params
 
 from pymc3.distributions.distribution import NoDistribution
-from pymc3.distributions.tree import LeafNode, SplitNode, Tree
 
 __all__ = ["BART"]
 
 
-class BaseBART(NoDistribution):
-    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):
-
-        self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)
-
-        super().__init__(shape=X.shape[0], dtype="float64", initval=0, *args, **kwargs)
-
-        if self.X.ndim != 2:
-            raise ValueError("The design matrix X must have two dimensions")
-
-        if self.Y.ndim != 1:
-            raise ValueError("The response matrix Y must have one dimension")
-        if self.X.shape[0] != self.Y.shape[0]:
-            raise ValueError(
-                "The design matrix X and the response matrix Y must have the same number of elements"
-            )
-        if not isinstance(m, int):
-            raise ValueError("The number of trees m type must be int")
-        if m < 1:
-            raise ValueError("The number of trees m must be greater than zero")
-
-        if alpha <= 0 or 1 <= alpha:
-            raise ValueError(
-                "The value for the alpha parameter for the tree structure "
-                "must be in the interval (0, 1)"
-            )
-
-        self.num_observations = X.shape[0]
-        self.num_variates = X.shape[1]
-        self.available_predictors = list(range(self.num_variates))
-        self.ssv = SampleSplittingVariable(split_prior, self.num_variates)
-        self.m = m
-        self.alpha = alpha
-        self.trees = self.init_list_of_trees()
-        self.all_trees = []
-        self.mean = fast_mean()
-        self.prior_prob_leaf_node = compute_prior_probability(alpha)
-
-    def preprocess_XY(self, X, Y):
-        if isinstance(Y, (Series, DataFrame)):
-            Y = Y.to_numpy()
-        if isinstance(X, (Series, DataFrame)):
-            X = X.to_numpy()
-        missing_data = np.any(np.isnan(X))
-        X = np.random.normal(X, np.std(X, 0) / 100)
-        return X, Y, missing_data
-
-    def init_list_of_trees(self):
-        initial_value_leaf_nodes = self.Y.mean() / self.m
-        initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
-        list_of_trees = []
-        for i in range(self.m):
-            new_tree = Tree.init_tree(
-                tree_id=i,
-                leaf_node_value=initial_value_leaf_nodes,
-                idx_data_points=initial_idx_data_points_leaf_nodes,
-            )
-            list_of_trees.append(new_tree)
-        # Diff trick to speed computation of residuals. From Section 3.1 of Kapelner, A and Bleich, J.
-        # bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
-        # The sum_trees_output will contain the sum of the predicted output for all trees.
-        # When R_j is needed we subtract the current predicted output for tree T_j.
-        self.sum_trees_output = np.full_like(self.Y, self.Y.mean())
-
-        return list_of_trees
-
-    def __iter__(self):
-        return iter(self.trees)
-
-    def __repr_latex(self):
-        raise NotImplementedError
-
-    def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
-        x_j = self.X[idx_data_points_split_node, idx_split_variable]
-        if self.missing_data:
-            x_j = x_j[~np.isnan(x_j)]
-        values = np.unique(x_j)
-        # The last value is never available as it would leave the right subtree empty.
-        return values[:-1]
-
-    def grow_tree(self, tree, index_leaf_node):
-        current_node = tree.get_node(index_leaf_node)
-
-        index_selected_predictor = self.ssv.rvs()
-        selected_predictor = self.available_predictors[index_selected_predictor]
-        available_splitting_rules = self.get_available_splitting_rules(
-            current_node.idx_data_points, selected_predictor
-        )
-        # This can be unsuccessful when there are not available splitting rules
-        if available_splitting_rules.size == 0:
-            return False, None
-
-        index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
-        selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
-        new_split_node = SplitNode(
-            index=index_leaf_node,
-            idx_split_variable=selected_predictor,
-            split_value=selected_splitting_rule,
-        )
-
-        left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points(
-            new_split_node, current_node.idx_data_points
-        )
-
-        left_node_value = self.draw_leaf_value(left_node_idx_data_points)
-        right_node_value = self.draw_leaf_value(right_node_idx_data_points)
-
-        new_left_node = LeafNode(
-            index=current_node.get_idx_left_child(),
-            value=left_node_value,
-            idx_data_points=left_node_idx_data_points,
-        )
-        new_right_node = LeafNode(
-            index=current_node.get_idx_right_child(),
-            value=right_node_value,
-            idx_data_points=right_node_idx_data_points,
-        )
-        tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)
-
-        return True, index_selected_predictor
-
-    def get_new_idx_data_points(self, current_split_node, idx_data_points):
-        idx_split_variable = current_split_node.idx_split_variable
-        split_value = current_split_node.split_value
-
-        left_idx = self.X[idx_data_points, idx_split_variable] <= split_value
-        left_node_idx_data_points = idx_data_points[left_idx]
-        right_node_idx_data_points = idx_data_points[~left_idx]
-
-        return left_node_idx_data_points, right_node_idx_data_points
-
-    def get_residuals(self):
-        """Compute the residuals."""
-        R_j = self.Y - self.sum_trees_output
-        return R_j
-
-    def get_residuals_loo(self, tree):
-        """Compute the residuals without leaving the passed tree out."""
-        R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations))
-        return R_j
-
-    def draw_leaf_value(self, idx_data_points):
-        """Draw the residual mean."""
-        R_j = self.get_residuals()[idx_data_points]
-        draw = self.mean(R_j)
-        return draw
-
-    def predict(self, X_new):
-        """Compute out of sample predictions evaluated at X_new"""
-        trees = self.all_trees
-        num_observations = X_new.shape[0]
-        pred = np.zeros((len(trees), num_observations))
-        np.random.randint(len(trees))
-        for draw, trees_to_sum in enumerate(trees):
-            new_Y = np.zeros(num_observations)
-            for tree in trees_to_sum:
-                new_Y += [tree.predict_out_of_sample(x) for x in X_new]
-            pred[draw] = new_Y
-        return pred
-
-
-def compute_prior_probability(alpha):
+class BARTRV(RandomVariable):
     """
-    Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
-    Taken from equation 19 in [Rockova2018].
-
-    Parameters
-    ----------
-    alpha : float
-
-    Returns
-    -------
-    list with probabilities for leaf nodes
-
-    References
-    ----------
-    .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
-    arXiv, `link <https://arxiv.org/abs/1810.00787>`__
+    Base class for BART
     """
-    prior_leaf_prob = [0]
-    depth = 1
-    while prior_leaf_prob[-1] < 1:
-        prior_leaf_prob.append(1 - alpha ** depth)
-        depth += 1
-    return prior_leaf_prob
-
-
-def fast_mean():
-    """If available use Numba to speed up the computation of the mean."""
-    try:
-        from numba import jit
-    except ImportError:
-        return np.mean
-
-    @jit
-    def mean(a):
-        count = a.shape[0]
-        suma = 0
-        for i in range(count):
-            suma += a[i]
-        return suma / count
-
-    return mean
 
+    name = "BART"
+    ndim_supp = 1
+    ndims_params = [2, 1, 0, 0, 0, 1]
+    dtype = "floatX"
+    _print_name = ("BART", "\\operatorname{BART}")
+    all_trees = None
+
+    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
+        return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)
+
+    @classmethod
+    def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
+        size = kwargs.pop("size", None)
+        X_new = kwargs.pop("X_new", None)
+        all_trees = cls.all_trees
+        if all_trees:
+
+            if size is None:
+                size = ()
+            elif isinstance(size, int):
+                size = [size]
+
+            flatten_size = 1
+            for s in size:
+                flatten_size *= s
+
+            idx = rng.integers(len(all_trees), size=flatten_size)  # one stored ensemble per draw
+
+            if X_new is None:
+                pred = np.zeros((flatten_size, all_trees[0][0].num_observations))
+                for ind, p in enumerate(pred):
+                    for tree in all_trees[idx[ind]]:
+                        p += tree.predict_output()
+            else:
+                pred = np.zeros((flatten_size, X_new.shape[0]))
+                for ind, p in enumerate(pred):
+                    for tree in all_trees[idx[ind]]:
+                        p += np.array([tree.predict_out_of_sample(x) for x in X_new])
+            return pred.reshape((*size, -1))
+        else:
+            return np.full_like(cls.Y, cls.Y.mean())
 
-def discrete_uniform_sampler(upper_value):
-    """Draw from the uniform distribution with bounds [0, upper_value)."""
-    return int(np.random.random() * upper_value)
-
-
-class SampleSplittingVariable:
-    def __init__(self, prior, num_variates):
-        self.prior = prior
-        self.num_variates = num_variates
-
-        if self.prior is not None:
-            self.prior = np.asarray(self.prior)
-            self.prior = self.prior / self.prior.sum()
-            if self.prior.size != self.num_variates:
-                raise ValueError(
-                    f"The size of split_prior ({self.prior.size}) should be the "
-                    f"same as the number of covariates ({self.num_variates})"
-                )
-            self.enu = list(enumerate(np.cumsum(self.prior)))
 
-    def rvs(self):
-        if self.prior is None:
-            return int(np.random.random() * self.num_variates)
-        else:
-            r = np.random.random()
-            for i, v in self.enu:
-                if r <= v:
-                    return i
+bart = BARTRV()
 
 
-class BART(BaseBART):
+class BART(NoDistribution):
     """
-    BART distribution.
+    Bayesian Additive Regression Tree distribution.
 
     Distribution representing a sum over trees
 
     Parameters
     ----------
     X : array-like
-        The design matrix.
+        The covariate matrix.
     Y : array-like
         The response vector.
     m : int
         Number of trees
     alpha : float
-        Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
-        altought it is recomenned to be in the interval (0, 0.5].
+        Controls the prior probability over the depth of the trees. Although it can take values in
+        the interval (0, 1), it is recommended to be in the interval (0, 0.5].
+    k : float
+        Scale parameter for the values of the leaf nodes. Defaults to 2. Recommended to be between
+        1 and 3.
     split_prior : array-like
-        Each element of split_prior should be in the [0, 1] interval and the elements should sum
-        to 1. Otherwise they will be normalized.
-        Defaults to None, all variable have the same a prior probability
+        Each element of split_prior should be in the [0, 1] interval and the elements should sum to
+        1. Otherwise they will be normalized.
+        Defaults to None, i.e. all covariates have the same prior probability of being selected.
     """
 
-    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
-        super().__init__(X, Y, m, alpha, split_prior)
+    def __new__(
+        cls,
+        name,
+        X,
+        Y,
+        m=50,
+        alpha=0.25,
+        k=2,
+        split_prior=None,
+        **kwargs,
+    ):
+
+        cls.all_trees = []
+
+        bart_op = type(
+            f"BART_{name}",
+            (BARTRV,),
+            dict(
+                name="BART",
+                all_trees=cls.all_trees,
+                inplace=False,
+                initval=Y.mean(),
+                X=X,
+                Y=Y,
+                m=m,
+                alpha=alpha,
+                k=k,
+                split_prior=split_prior,
+            ),
+        )()
+
+        NoDistribution.register(BARTRV)
+
+        cls.rv_op = bart_op
+        params = [X, Y, m, alpha, k]
+        return super().__new__(cls, name, *params, **kwargs)
+
+    @classmethod
+    def dist(cls, *params, **kwargs):
+        return super().dist(params, **kwargs)
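
The size handling in the new rng_fn is the least obvious part of the diff: each requested draw picks one stored tree ensemble at random, sums its per-tree predictions into a flat (flatten_size, n) array, and only reshapes at the end. A minimal NumPy sketch of that bookkeeping, with made-up numbers (10 stored ensembles, 5 observations):

import numpy as np

rng = np.random.default_rng(42)
size = (2, 3)                              # requested draw shape
flatten_size = int(np.prod(size))          # 6 independent draws
idx = rng.integers(10, size=flatten_size)  # one stored ensemble per draw
pred = np.zeros((flatten_size, 5))         # one row of predictions per draw
# in rng_fn each row accumulates tree.predict_output() over the chosen ensemble
print(pred.reshape((*size, -1)).shape)     # -> (2, 3, 5)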
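For context, a minimal usage sketch of the refactored distribution. This is hypothetical: it assumes BART remains exported at the package level and that step-method assignment on this branch routes the BART variable to the PGBART sampler automatically.

import numpy as np
import pymc3 as pm

# toy data, made up for illustration
X = np.random.normal(0, 1, size=(100, 2))
Y = np.random.normal(X[:, 0] - X[:, 1], 1)

with pm.Model():
    # sum-of-trees prior over the conditional mean
    mu = pm.BART("mu", X, Y, m=50, alpha=0.25, k=2)
    sigma = pm.HalfNormal("sigma", 1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=Y)
    idata = pm.sample()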