Skip to content

Commit ff45994

Browse files
authored
Bring back BART to V4 and make it more general (#4914)
* frowardporting from unreleased v3 plus generalization * aesarize * improve docstrings * small fix docstring and variable names * fix format variable importance * fix broadcasting issue and other minor fixes * add test and fix pylint * fix float32 * sample splitting variables non-uniformly * remove xfail * add back xfail on windows * add back xfail on windows and for float32 * fix test * clean rnd * add size argument and check for NoDistribution * stop updating split_prior after tuning * clean code and small speed-up * clean code and small speed-up * revert xfail * add tests * fix number of chains * revert test * clean code, refactor and small speed-up * test random * test random * add missing data test
1 parent 61fa834 commit ff45994

File tree

8 files changed

+563
-379
lines changed

8 files changed

+563
-379
lines changed

Diff for: pymc3/distributions/bart.py

+96-240
Original file line numberDiff line numberDiff line change
@@ -14,271 +14,127 @@
1414

1515
import numpy as np
1616

17-
from pandas import DataFrame, Series
17+
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
1818

1919
from pymc3.distributions.distribution import NoDistribution
20-
from pymc3.distributions.tree import LeafNode, SplitNode, Tree
2120

2221
__all__ = ["BART"]
2322

2423

25-
class BaseBART(NoDistribution):
26-
def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):
27-
28-
self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)
29-
30-
super().__init__(shape=X.shape[0], dtype="float64", initval=0, *args, **kwargs)
31-
32-
if self.X.ndim != 2:
33-
raise ValueError("The design matrix X must have two dimensions")
34-
35-
if self.Y.ndim != 1:
36-
raise ValueError("The response matrix Y must have one dimension")
37-
if self.X.shape[0] != self.Y.shape[0]:
38-
raise ValueError(
39-
"The design matrix X and the response matrix Y must have the same number of elements"
40-
)
41-
if not isinstance(m, int):
42-
raise ValueError("The number of trees m type must be int")
43-
if m < 1:
44-
raise ValueError("The number of trees m must be greater than zero")
45-
46-
if alpha <= 0 or 1 <= alpha:
47-
raise ValueError(
48-
"The value for the alpha parameter for the tree structure "
49-
"must be in the interval (0, 1)"
50-
)
51-
52-
self.num_observations = X.shape[0]
53-
self.num_variates = X.shape[1]
54-
self.available_predictors = list(range(self.num_variates))
55-
self.ssv = SampleSplittingVariable(split_prior, self.num_variates)
56-
self.m = m
57-
self.alpha = alpha
58-
self.trees = self.init_list_of_trees()
59-
self.all_trees = []
60-
self.mean = fast_mean()
61-
self.prior_prob_leaf_node = compute_prior_probability(alpha)
62-
63-
def preprocess_XY(self, X, Y):
64-
if isinstance(Y, (Series, DataFrame)):
65-
Y = Y.to_numpy()
66-
if isinstance(X, (Series, DataFrame)):
67-
X = X.to_numpy()
68-
missing_data = np.any(np.isnan(X))
69-
X = np.random.normal(X, np.std(X, 0) / 100)
70-
return X, Y, missing_data
71-
72-
def init_list_of_trees(self):
73-
initial_value_leaf_nodes = self.Y.mean() / self.m
74-
initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
75-
list_of_trees = []
76-
for i in range(self.m):
77-
new_tree = Tree.init_tree(
78-
tree_id=i,
79-
leaf_node_value=initial_value_leaf_nodes,
80-
idx_data_points=initial_idx_data_points_leaf_nodes,
81-
)
82-
list_of_trees.append(new_tree)
83-
# Diff trick to speed computation of residuals. From Section 3.1 of Kapelner, A and Bleich, J.
84-
# bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
85-
# The sum_trees_output will contain the sum of the predicted output for all trees.
86-
# When R_j is needed we subtract the current predicted output for tree T_j.
87-
self.sum_trees_output = np.full_like(self.Y, self.Y.mean())
88-
89-
return list_of_trees
90-
91-
def __iter__(self):
92-
return iter(self.trees)
93-
94-
def __repr_latex(self):
95-
raise NotImplementedError
96-
97-
def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
98-
x_j = self.X[idx_data_points_split_node, idx_split_variable]
99-
if self.missing_data:
100-
x_j = x_j[~np.isnan(x_j)]
101-
values = np.unique(x_j)
102-
# The last value is never available as it would leave the right subtree empty.
103-
return values[:-1]
104-
105-
def grow_tree(self, tree, index_leaf_node):
106-
current_node = tree.get_node(index_leaf_node)
107-
108-
index_selected_predictor = self.ssv.rvs()
109-
selected_predictor = self.available_predictors[index_selected_predictor]
110-
available_splitting_rules = self.get_available_splitting_rules(
111-
current_node.idx_data_points, selected_predictor
112-
)
113-
# This can be unsuccessful when there are not available splitting rules
114-
if available_splitting_rules.size == 0:
115-
return False, None
116-
117-
index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
118-
selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
119-
new_split_node = SplitNode(
120-
index=index_leaf_node,
121-
idx_split_variable=selected_predictor,
122-
split_value=selected_splitting_rule,
123-
)
124-
125-
left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points(
126-
new_split_node, current_node.idx_data_points
127-
)
128-
129-
left_node_value = self.draw_leaf_value(left_node_idx_data_points)
130-
right_node_value = self.draw_leaf_value(right_node_idx_data_points)
131-
132-
new_left_node = LeafNode(
133-
index=current_node.get_idx_left_child(),
134-
value=left_node_value,
135-
idx_data_points=left_node_idx_data_points,
136-
)
137-
new_right_node = LeafNode(
138-
index=current_node.get_idx_right_child(),
139-
value=right_node_value,
140-
idx_data_points=right_node_idx_data_points,
141-
)
142-
tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)
143-
144-
return True, index_selected_predictor
145-
146-
def get_new_idx_data_points(self, current_split_node, idx_data_points):
147-
idx_split_variable = current_split_node.idx_split_variable
148-
split_value = current_split_node.split_value
149-
150-
left_idx = self.X[idx_data_points, idx_split_variable] <= split_value
151-
left_node_idx_data_points = idx_data_points[left_idx]
152-
right_node_idx_data_points = idx_data_points[~left_idx]
153-
154-
return left_node_idx_data_points, right_node_idx_data_points
155-
156-
def get_residuals(self):
157-
"""Compute the residuals."""
158-
R_j = self.Y - self.sum_trees_output
159-
return R_j
160-
161-
def get_residuals_loo(self, tree):
162-
"""Compute the residuals without leaving the passed tree out."""
163-
R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations))
164-
return R_j
165-
166-
def draw_leaf_value(self, idx_data_points):
167-
"""Draw the residual mean."""
168-
R_j = self.get_residuals()[idx_data_points]
169-
draw = self.mean(R_j)
170-
return draw
171-
172-
def predict(self, X_new):
173-
"""Compute out of sample predictions evaluated at X_new"""
174-
trees = self.all_trees
175-
num_observations = X_new.shape[0]
176-
pred = np.zeros((len(trees), num_observations))
177-
np.random.randint(len(trees))
178-
for draw, trees_to_sum in enumerate(trees):
179-
new_Y = np.zeros(num_observations)
180-
for tree in trees_to_sum:
181-
new_Y += [tree.predict_out_of_sample(x) for x in X_new]
182-
pred[draw] = new_Y
183-
return pred
184-
185-
186-
def compute_prior_probability(alpha):
24+
class BARTRV(RandomVariable):
18725
"""
188-
Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
189-
Taken from equation 19 in [Rockova2018].
190-
191-
Parameters
192-
----------
193-
alpha : float
194-
195-
Returns
196-
-------
197-
list with probabilities for leaf nodes
198-
199-
References
200-
----------
201-
.. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
202-
arXiv, `link <https://arxiv.org/abs/1810.00787>`__
26+
Base class for BART
20327
"""
204-
prior_leaf_prob = [0]
205-
depth = 1
206-
while prior_leaf_prob[-1] < 1:
207-
prior_leaf_prob.append(1 - alpha ** depth)
208-
depth += 1
209-
return prior_leaf_prob
210-
211-
212-
def fast_mean():
213-
"""If available use Numba to speed up the computation of the mean."""
214-
try:
215-
from numba import jit
216-
except ImportError:
217-
return np.mean
218-
219-
@jit
220-
def mean(a):
221-
count = a.shape[0]
222-
suma = 0
223-
for i in range(count):
224-
suma += a[i]
225-
return suma / count
226-
227-
return mean
22828

29+
name = "BART"
30+
ndim_supp = 1
31+
ndims_params = [2, 1, 0, 0, 0, 1]
32+
dtype = "floatX"
33+
_print_name = ("BART", "\\operatorname{BART}")
34+
all_trees = None
35+
36+
def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
37+
return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)
38+
39+
@classmethod
40+
def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
41+
size = kwargs.pop("size", None)
42+
X_new = kwargs.pop("X_new", None)
43+
all_trees = cls.all_trees
44+
if all_trees:
45+
46+
if size is None:
47+
size = ()
48+
elif isinstance(size, int):
49+
size = [size]
50+
51+
flatten_size = 1
52+
for s in size:
53+
flatten_size *= s
54+
55+
idx = rng.randint(len(all_trees), size=flatten_size)
56+
57+
if X_new is None:
58+
pred = np.zeros((flatten_size, all_trees[0][0].num_observations))
59+
for ind, p in enumerate(pred):
60+
for tree in all_trees[idx[ind]]:
61+
p += tree.predict_output()
62+
else:
63+
pred = np.zeros((flatten_size, X_new.shape[0]))
64+
for ind, p in enumerate(pred):
65+
for tree in all_trees[idx[ind]]:
66+
p += np.array([tree.predict_out_of_sample(x) for x in X_new])
67+
return pred.reshape((*size, -1))
68+
else:
69+
return np.full_like(cls.Y, cls.Y.mean())
22970

230-
def discrete_uniform_sampler(upper_value):
231-
"""Draw from the uniform distribution with bounds [0, upper_value)."""
232-
return int(np.random.random() * upper_value)
233-
234-
235-
class SampleSplittingVariable:
236-
def __init__(self, prior, num_variates):
237-
self.prior = prior
238-
self.num_variates = num_variates
239-
240-
if self.prior is not None:
241-
self.prior = np.asarray(self.prior)
242-
self.prior = self.prior / self.prior.sum()
243-
if self.prior.size != self.num_variates:
244-
raise ValueError(
245-
f"The size of split_prior ({self.prior.size}) should be the "
246-
f"same as the number of covariates ({self.num_variates})"
247-
)
248-
self.enu = list(enumerate(np.cumsum(self.prior)))
24971

250-
def rvs(self):
251-
if self.prior is None:
252-
return int(np.random.random() * self.num_variates)
253-
else:
254-
r = np.random.random()
255-
for i, v in self.enu:
256-
if r <= v:
257-
return i
72+
bart = BARTRV()
25873

25974

260-
class BART(BaseBART):
75+
class BART(NoDistribution):
26176
"""
262-
BART distribution.
77+
Bayesian Additive Regression Tree distribution.
26378
26479
Distribution representing a sum over trees
26580
26681
Parameters
26782
----------
26883
X : array-like
269-
The design matrix.
84+
The covariate matrix.
27085
Y : array-like
27186
The response vector.
27287
m : int
27388
Number of trees
27489
alpha : float
275-
Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
276-
altought it is recomenned to be in the interval (0, 0.5].
90+
Control the prior probability over the depth of the trees. Even when it can takes values in
91+
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
92+
k : float
93+
Scale parameter for the values of the leaf nodes. Defaults to 2. Recomended to be between 1
94+
and 3.
27795
split_prior : array-like
278-
Each element of split_prior should be in the [0, 1] interval and the elements should sum
279-
to 1. Otherwise they will be normalized.
280-
Defaults to None, all variable have the same a prior probability
96+
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
97+
1. Otherwise they will be normalized.
98+
Defaults to None, i.e. all covariates have the same prior probability to be selected.
28199
"""
282100

283-
def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
284-
super().__init__(X, Y, m, alpha, split_prior)
101+
def __new__(
102+
cls,
103+
name,
104+
X,
105+
Y,
106+
m=50,
107+
alpha=0.25,
108+
k=2,
109+
split_prior=None,
110+
**kwargs,
111+
):
112+
113+
cls.all_trees = []
114+
115+
bart_op = type(
116+
f"BART_{name}",
117+
(BARTRV,),
118+
dict(
119+
name="BART",
120+
all_trees=cls.all_trees,
121+
inplace=False,
122+
initval=Y.mean(),
123+
X=X,
124+
Y=Y,
125+
m=m,
126+
alpha=alpha,
127+
k=k,
128+
split_prior=split_prior,
129+
),
130+
)()
131+
132+
NoDistribution.register(BARTRV)
133+
134+
cls.rv_op = bart_op
135+
params = [X, Y, m, alpha, k]
136+
return super().__new__(cls, name, *params, **kwargs)
137+
138+
@classmethod
139+
def dist(cls, *params, **kwargs):
140+
return super().dist(params, **kwargs)

0 commit comments

Comments
 (0)