Commit d1bb5b7

compute running variance for leaf nodes (#91)
* use running variance for leaf nodes
1 parent de582f7 · commit d1bb5b7

File tree: 2 files changed (+74, -38 lines)


pymc_bart/bart.py (+8 -4)
@@ -36,7 +36,7 @@ class BARTRV(RandomVariable):
 
     name: str = "BART"
     ndim_supp = 1
-    ndims_params: List[int] = [2, 1, 0, 0, 1]
+    ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
     dtype: str = "floatX"
     _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
     all_trees = List[List[Tree]]
@@ -45,7 +45,9 @@ def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
         return dist_params[0].shape[:1]
 
     @classmethod
-    def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None, size=None):
+    def rng_fn(
+        cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None
+    ):
         if not cls.all_trees:
             if size is not None:
                 return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
@@ -94,7 +96,8 @@ def __new__(
         X: TensorLike,
         Y: TensorLike,
         m: int = 50,
-        alpha: float = 0.25,
+        alpha: float = 0.95,
+        beta: float = 2,
         response: str = "constant",
         split_prior: Optional[List[float]] = None,
         **kwargs,
@@ -120,6 +123,7 @@ def __new__(
                 m=m,
                 response=response,
                 alpha=alpha,
+                beta=beta,
                 split_prior=split_prior,
             ),
         )()
@@ -131,7 +135,7 @@ def get_moment(rv, size, *rv_inputs):
            return cls.get_moment(rv, size, *rv_inputs)

        cls.rv_op = bart_op
-        params = [X, Y, m, alpha, split_prior]
+        params = [X, Y, m, alpha, beta, split_prior]
        return super().__new__(cls, name, *params, **kwargs)

    @classmethod
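
With this change the tree-depth prior takes two parameters, `alpha` (default 0.95) and `beta` (default 2), the standard defaults from Chipman et al.'s BART prior, replacing the single `alpha=0.25` default used before. A minimal usage sketch of the public API after this commit; the data, priors, and variable names below are illustrative, not from the commit:

import numpy as np
import pymc as pm
import pymc_bart as pmb

# Toy data: any (n, p) predictor matrix and length-n response will do
X = np.random.uniform(size=(100, 3))
Y = np.sin(6 * X[:, 0]) + np.random.normal(0, 0.1, size=100)

with pm.Model():
    mu = pmb.BART("mu", X=X, Y=Y, m=50, alpha=0.95, beta=2)  # new defaults written out
    sigma = pm.HalfNormal("sigma", 1)
    pm.Normal("y", mu=mu, sigma=sigma, observed=Y)
    idata = pm.sample()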

pymc_bart/pgbart.py (+66 -34)
@@ -31,18 +31,16 @@
 class ParticleTree:
     """Particle tree."""
 
-    __slots__ = "tree", "expansion_nodes", "log_weight", "kfactor"
+    __slots__ = "tree", "expansion_nodes", "log_weight"
 
-    def __init__(self, tree: Tree, kfactor: float = 0.75):
+    def __init__(self, tree: Tree):
         self.tree: Tree = tree.copy()
         self.expansion_nodes: List[int] = [0]
         self.log_weight: float = 0
-        self.kfactor: float = kfactor
 
     def copy(self) -> "ParticleTree":
         p = ParticleTree(self.tree)
         p.expansion_nodes = self.expansion_nodes.copy()
-        p.kfactor = self.kfactor
         return p
 
     def sample_tree(
@@ -53,6 +51,7 @@ def sample_tree(
         X,
         missing_data,
         sum_trees,
+        leaf_sd,
         m,
         response,
         normal,
@@ -73,10 +72,10 @@
                 X,
                 missing_data,
                 sum_trees,
+                leaf_sd,
                 m,
                 response,
                 normal,
-                self.kfactor,
                 shape,
             )
             if idx_new_nodes is not None:
@@ -95,7 +94,7 @@ class PGBART(ArrayStepShared):
     vars: list
         List of value variables for sampler
     num_particles : tuple
-        Number of particles. Defaults to 20
+        Number of particles. Defaults to 10
     batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
         during tuning and after tuning. If a tuple is passed the first element is the batch size
@@ -112,7 +111,7 @@
     def __init__(
         self,
         vars=None,  # pylint: disable=redefined-builtin
-        num_particles: int = 20,
+        num_particles: int = 10,
         batch: Tuple[float, float] = (0.1, 0.1),
         model: Optional[Model] = None,
     ):
@@ -141,17 +140,20 @@ def __init__(
             self.alpha_vec = self.bart.split_prior
         else:
             self.alpha_vec = np.ones(self.X.shape[1], dtype=np.int32)
+
         init_mean = self.bart.Y.mean()
+        self.num_observations = self.X.shape[0]
+        self.num_variates = self.X.shape[1]
+        self.available_predictors = list(range(self.num_variates))
+
         # if data is binary
         y_unique = np.unique(self.bart.Y)
         if y_unique.size == 2 and np.all(y_unique == [0, 1]):
-            mu_std = 3 / self.m**0.5
+            self.leaf_sd = 3 / self.m**0.5
         else:
-            mu_std = self.bart.Y.std() / self.m**0.5
+            self.leaf_sd = self.bart.Y.std() / self.m**0.5
 
-        self.num_observations = self.X.shape[0]
-        self.num_variates = self.X.shape[1]
-        self.available_predictors = list(range(self.num_variates))
+        self.running_sd = RunningSd(shape)
 
         self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype(
             config.floatX
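
The initial `leaf_sd` plays the role of the old `mu_std`: `Y.std() / sqrt(m)` for continuous responses and `3 / sqrt(m)` for binary ones. A quick back-of-the-envelope check of the magnitudes involved (numbers are made up):

import numpy as np

m = 50                                 # number of trees
Y = np.random.normal(0, 2, size=500)   # stand-in continuous response
leaf_sd = Y.std() / m**0.5             # roughly 2 / sqrt(50), about 0.28
leaf_sd_binary = 3 / m**0.5            # binary case: 3 / sqrt(50), about 0.42
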
@@ -164,10 +166,9 @@ def __init__(
             shape=self.shape,
         )
 
-        self.normal = NormalSampler(mu_std, self.shape)
+        self.normal = NormalSampler(1, self.shape)
         self.uniform = UniformSampler(0, 1)
-        self.uniform_kf = UniformSampler(0.33, 0.75, self.shape)
-        self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha)
+        self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha, self.bart.beta)
         self.ssv = SampleSplittingVariable(self.alpha_vec)
 
         self.tune = True
@@ -212,6 +213,7 @@ def astep(self, _):
                 self.X,
                 self.missing_data,
                 self.sum_trees,
+                self.leaf_sd,
                 self.m,
                 self.response,
                 self.normal,
@@ -235,16 +237,25 @@ def astep(self, _):
                 particles, normalized_weights
             )
             # Update the sum of trees
-            self.sum_trees = self.sum_trees_noi + new_tree._predict()
+            new = new_tree._predict()
+            self.sum_trees = self.sum_trees_noi + new
             # To reduce memory usage, we trim the tree
             self.all_trees[tree_id] = new_tree.trim()
 
             if self.tune:
                 # Update the splitting variable and the splitting variable sampler
                 if self.iter > self.m:
                     self.ssv = SampleSplittingVariable(self.alpha_vec)
+
                 for index in new_tree.get_split_variables():
                     self.alpha_vec[index] += 1
+
+                # update standard deviation at leaf nodes
+                if self.iter > 2:
+                    self.leaf_sd = self.running_sd.update(new)
+                else:
+                    self.running_sd.update(new)
+
             else:
                 # update the variable inclusion
                 for index in new_tree.get_split_variables():
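
The `self.iter > 2` guard exists because the running estimate is degenerate early on: after a single update the running standard deviation is exactly zero, which would collapse the leaf proposals, so the first iterations only warm the estimator up while `leaf_sd` keeps its initial value. A standalone sketch of that control flow, assuming this commit is installed and using made-up arrays in place of `new_tree._predict()`:

import numpy as np
from pymc_bart.pgbart import RunningSd  # added later in this diff

leaf_sd = 0.3                                     # initial estimate from __init__
running_sd = RunningSd(shape=(1, 4))              # hypothetical: 1 output, 4 observations
predictions = [np.random.normal(size=(1, 4)) for _ in range(5)]  # stand-in tree predictions
for iteration, new in enumerate(predictions, start=1):
    if iteration > 2:
        leaf_sd = running_sd.update(new)          # switch to the running estimate
    else:
        running_sd.update(new)                    # warm-up: keep the initial leaf_sd
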
@@ -320,10 +331,7 @@ def init_particles(self, tree_id: int) -> List[ParticleTree]:
         self.update_weight(p0)
         particles: List[ParticleTree] = [p0]
 
-        particles.extend(
-            ParticleTree(self.a_tree, self.uniform_kf.rvs() if self.tune else p0.kfactor)
-            for _ in self.indices
-        )
+        particles.extend(ParticleTree(self.a_tree) for _ in self.indices)
         return particles
 
     def update_weight(self, particle: ParticleTree) -> None:
@@ -344,6 +352,34 @@ def competence(var, has_grad):
         return Competence.INCOMPATIBLE
 
 
+class RunningSd:
+    def __init__(self, shape: tuple) -> None:
+        self.count = 0  # number of data points
+        self.mean = np.zeros(shape)  # running mean
+        self.m_2 = np.zeros(shape)  # running second moment
+
+    def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
+        self.count = self.count + 1
+        self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
+        return fast_mean(std)
+
+
+@njit
+def _update(
+    count: int,
+    mean: npt.NDArray[np.float_],
+    m_2: npt.NDArray[np.float_],
+    new_value: npt.NDArray[np.float_],
+) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], Union[float, npt.NDArray[np.float_]]]:
+    delta = new_value - mean
+    mean += delta / count
+    delta2 = new_value - mean
+    m_2 += delta * delta2
+
+    std = (m_2 / count) ** 0.5
+    return mean, m_2, std
+
+
 class SampleSplittingVariable:
     def __init__(self, alpha_vec: npt.NDArray[np.float_]) -> None:
         """
@@ -362,30 +398,26 @@ def rvs(self) -> Union[int, Tuple[int, float]]:
         return self.enu[-1]
 
 
-def compute_prior_probability(alpha) -> List[float]:
+def compute_prior_probability(alpha: int, beta: int) -> List[float]:
     """
     Calculate the probability of the node being a leaf node (1 - p(being split node)).
 
-    Taken from equation 19 in [Rockova2018].
-
     Parameters
     ----------
     alpha : float
+    beta: float
 
     Returns
     -------
     list with probabilities for leaf nodes
-
-    References
-    ----------
-    .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
-    arXiv, `link <https://arxiv.org/abs/1810.00787>`__
     """
     prior_leaf_prob: List[float] = [0]
-    depth = 1
-    while prior_leaf_prob[-1] < 1:
-        prior_leaf_prob.append(1 - alpha**depth)
+    depth = 0
+    while prior_leaf_prob[-1] < 0.9999:
+        prior_leaf_prob.append(1 - (alpha * ((1 + depth) ** (-beta))))
         depth += 1
+    prior_leaf_prob.append(1)
+
     return prior_leaf_prob
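
With the new defaults `alpha=0.95` and `beta=2`, a node at depth d is a leaf with probability 1 - 0.95 * (1 + d)^(-2), the Chipman-style depth prior, so deeper nodes are increasingly likely to stay terminal; the leading 0 in the returned list appears to ensure the root always attempts a split. A worked evaluation of the first few entries (values rounded):

# p(leaf at depth d) = 1 - alpha * (1 + d) ** (-beta)
alpha, beta = 0.95, 2
for depth in range(4):
    print(depth, round(1 - alpha * (1 + depth) ** (-beta), 4))
# 0 0.05
# 1 0.7625
# 2 0.8944
# 3 0.9406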

@@ -397,10 +429,10 @@ def grow_tree(
     X,
     missing_data,
     sum_trees,
+    leaf_sd,
     m,
     response,
     normal,
-    kfactor,
     shape,
 ):
     current_node = tree.get_node(index_leaf_node)
@@ -432,7 +464,7 @@ def grow_tree(
             y_mu_pred=sum_trees[:, idx_data_point],
             x_mu=X[idx_data_point, selected_predictor],
             m=m,
-            norm=normal.rvs() * kfactor,
+            norm=normal.rvs() * leaf_sd,
             shape=shape,
             response=response,
         )
@@ -493,7 +525,7 @@ def draw_leaf_value(
     if response == "linear":
         mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m)
 
-    draw = norm + mu_mean
+    draw = mu_mean + norm
     return draw, linear_params
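
Because `NormalSampler` now draws with unit standard deviation, the proposed leaf value is `mu_mean + z * leaf_sd` with z ~ N(0, 1): centered on the leaf's mean response and spread according to the running estimate. A minimal sketch of that arithmetic with placeholder numbers:

import numpy as np

rng = np.random.default_rng(1)
mu_mean = 0.7                    # placeholder: mean of the leaf's residuals (scaled by m)
leaf_sd = 0.28                   # placeholder: current running estimate
norm = rng.normal() * leaf_sd    # corresponds to normal.rvs() * leaf_sd in grow_tree
draw = mu_mean + norm            # proposed leaf value ~ N(mu_mean, leaf_sd**2)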
