Skip to content

Commit 7cf8595

Browse files
authored
refactor so predict does not need m (#88)
1 parent 77658b2 commit 7cf8595

File tree

6 files changed

+16
-32
lines changed

6 files changed

+16
-32
lines changed

pymc_bart/bart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None,
5656
shape = size[0]
5757
else:
5858
shape = 1
59-
return _sample_posterior(cls.all_trees, cls.X, cls.m, rng=rng, shape=shape).squeeze().T
59+
return _sample_posterior(cls.all_trees, cls.X, rng=rng, shape=shape).squeeze().T
6060

6161

6262
bart = BARTRV()

pymc_bart/pgbart.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -491,8 +491,7 @@ def draw_leaf_value(
491491
if response == "constant":
492492
mu_mean = fast_mean(y_mu_pred) / m
493493
if response == "linear":
494-
y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred)
495-
mu_mean = y_fit / m
494+
mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m)
496495

497496
draw = norm + mu_mean
498497
return draw, linear_params
@@ -518,9 +517,12 @@ def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_
518517

519518
@njit
520519
def fast_linear_fit(
521-
x: npt.NDArray[np.float_], y: npt.NDArray[np.float_]
520+
x: npt.NDArray[np.float_],
521+
y: npt.NDArray[np.float_],
522+
m: int,
522523
) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]:
523524
n = len(x)
525+
y = y / m
524526

525527
xbar = np.sum(x) / n
526528
ybar = np.sum(y, axis=1) / n

pymc_bart/tree.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ def _predict(self) -> npt.NDArray[np.float_]:
221221
def predict(
222222
self,
223223
x: npt.NDArray[np.float_],
224-
m: int,
225224
excluded: Optional[List[int]] = None,
226225
shape: int = 1,
227226
) -> npt.NDArray[np.float_]:
@@ -232,8 +231,6 @@ def predict(
232231
----------
233232
x : npt.NDArray[np.float_]
234233
Unobserved point
235-
m : int
236-
Number of trees
237234
excluded: Optional[List[int]]
238235
Indexes of the variables to exclude when computing predictions
239236
@@ -244,12 +241,11 @@ def predict(
244241
"""
245242
if excluded is None:
246243
excluded = []
247-
return self._traverse_tree(x=x, m=m, excluded=excluded, shape=shape)
244+
return self._traverse_tree(x=x, excluded=excluded, shape=shape)
248245

249246
def _traverse_tree(
250247
self,
251248
x: npt.NDArray[np.float_],
252-
m: int,
253249
excluded: Optional[List[int]] = None,
254250
shape: int = 1,
255251
) -> npt.NDArray[np.float_]:
@@ -260,8 +256,6 @@ def _traverse_tree(
260256
----------
261257
x : npt.NDArray[np.float_]
262258
(Un)observed point
263-
m : int
264-
Number of trees
265259
node_index : int
266260
Index of the node to start the traversal from
267261
split_variable : int
@@ -274,7 +268,7 @@ def _traverse_tree(
274268
npt.NDArray[np.float_]
275269
Leaf node value or mean of leaf node values
276270
"""
277-
stack = [(0, 1.0)] # (node_index, prop) initial state
271+
stack = [(0, 1.0)] # (node_index, weight) initial state
278272
p_d = np.zeros(shape)
279273
while stack:
280274
node_index, weight = stack.pop()
@@ -285,7 +279,7 @@ def _traverse_tree(
285279
p_d += weight * node.value
286280
else:
287281
# this produces nonsensical results
288-
p_d += weight * ((params[0] + params[1] * x[node.idx_split_variable]) / m)
282+
p_d += weight * (params[0] + params[1] * x[node.idx_split_variable])
289283
# this produces reasonable results
290284
# p_d += weight * node.value.mean()
291285
else:

pymc_bart/utils.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
def _sample_posterior(
2222
all_trees: List[List[Tree]],
2323
X: TensorLike,
24-
m: int,
2524
rng: np.random.Generator,
2625
size: Optional[Union[int, Tuple[int, ...]]] = None,
2726
excluded: Optional[List[int]] = None,
@@ -37,8 +36,6 @@ def _sample_posterior(
3736
X : tensor-like
3837
A covariate matrix. Use the same one used to fit BART for in-sample predictions or a new one for
3938
out-of-sample predictions.
40-
m : int
41-
Number of trees
4239
rng : NumPy RandomGenerator
4340
size : int or tuple
4441
Number of samples.
@@ -66,7 +63,7 @@ def _sample_posterior(
6663

6764
for ind, p in enumerate(pred):
6865
for tree in stacked_trees[idx[ind]]:
69-
p += np.vstack([tree.predict(x=x, m=m, excluded=excluded, shape=shape) for x in X])
66+
p += np.vstack([tree.predict(x=x, excluded=excluded, shape=shape) for x in X])
7067
pred.reshape((*size_iter, shape, -1))
7168
return pred
7269

@@ -239,7 +236,6 @@ def plot_ice(
239236
axes: matplotlib axes
240237
"""
241238
all_trees = bartrv.owner.op.all_trees
242-
m: int = bartrv.owner.op.m
243239
rng = np.random.default_rng(random_seed)
244240

245241
if func is None:
@@ -271,7 +267,7 @@ def plot_ice(
271267
fake_X[:, indices_mi] = X[:, indices_mi][instance]
272268
y_pred.append(
273269
np.mean(
274-
_sample_posterior(all_trees, X=fake_X, m=m, rng=rng, size=samples, shape=shape),
270+
_sample_posterior(all_trees, X=fake_X, rng=rng, size=samples, shape=shape),
275271
0,
276272
)
277273
)
@@ -386,7 +382,6 @@ def plot_pdp(
386382
axes: matplotlib axes
387383
"""
388384
all_trees: list = bartrv.owner.op.all_trees
389-
m: int = bartrv.owner.op.m
390385
rng = np.random.default_rng(random_seed)
391386

392387
if func is None:
@@ -411,7 +406,7 @@ def plot_pdp(
411406
excluded.remove(var)
412407
fake_X, new_x = _create_pdp_data(X, xs_interval, var, xs_values, var_discrete)
413408
p_d = _sample_posterior(
414-
all_trees, X=fake_X, m=m, rng=rng, size=samples, excluded=excluded, shape=shape
409+
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
415410
)
416411

417412
for s_i in range(shape):
@@ -738,8 +733,6 @@ def plot_variable_importance(
738733
"""
739734
_, axes = plt.subplots(2, 1, figsize=figsize)
740735

741-
m: int = bartrv.owner.op.m
742-
743736
if bartrv.ndim == 1: # type: ignore
744737
shape = 1
745738
else:
@@ -775,7 +768,7 @@ def plot_variable_importance(
775768
all_trees = bartrv.owner.op.all_trees
776769

777770
predicted_all = _sample_posterior(
778-
all_trees, X=X, m=m, rng=rng, size=samples, excluded=None, shape=shape
771+
all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
779772
)
780773

781774
ev_mean = np.zeros(len(var_imp))
@@ -784,7 +777,6 @@ def plot_variable_importance(
784777
predicted_subset = _sample_posterior(
785778
all_trees=all_trees,
786779
X=X,
787-
m=m,
788780
rng=rng,
789781
size=samples,
790782
excluded=subset,

tests/test_bart.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,9 @@ class TestUtils:
139139
def test_sample_posterior(self):
140140
all_trees = self.mu.owner.op.all_trees
141141
rng = np.random.default_rng(3)
142-
pred_all = pmb.utils._sample_posterior(
143-
all_trees, X=self.X, m=self.mu.owner.op.m, rng=rng, size=2
144-
)
142+
pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2)
145143
rng = np.random.default_rng(3)
146-
pred_first = pmb.utils._sample_posterior(
147-
all_trees, X=self.X[:10], m=self.mu.owner.op.m, rng=rng
148-
)
144+
pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng)
149145

150146
assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4)
151147
assert pred_all.shape == (2, 50, 1)

tests/test_pgbart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_fast_mean():
5454
ids=["1d-id", "1d-const"],
5555
)
5656
def test_fast_linear_fit(x, y, a_expected, b_expected):
57-
y_fit, linear_params = fast_linear_fit(x, y)
57+
y_fit, linear_params = fast_linear_fit(x, y, m=1)
5858
assert linear_params[0] == a_expected
5959
assert linear_params[1] == b_expected
6060
np.testing.assert_almost_equal(

0 commit comments

Comments
 (0)