Skip to content

Commit 77658b2

Browse files
authored
refactor plot_dependence and implement fast version of pdp (#85)
1 parent d6cd156 commit 77658b2

File tree

6 files changed

+536
-197
lines changed

6 files changed

+536
-197
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ good-names=i,
254254
p0,
255255
p1,
256256
rv,
257+
fake_X,
257258
new_X,
258259
new_y,
259260
a,

pymc_bart/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515

1616
from pymc_bart.bart import BART
1717
from pymc_bart.pgbart import PGBART
18-
from pymc_bart.utils import plot_convergence, plot_dependence, plot_variable_importance
18+
from pymc_bart.utils import (
19+
plot_convergence,
20+
plot_pdp,
21+
plot_ice,
22+
plot_dependence,
23+
plot_variable_importance,
24+
)
1925

2026
__all__ = ["BART", "PGBART"]
2127
__version__ = "0.4.0"

pymc_bart/bart.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None,
5252
else:
5353
return np.full(cls.Y.shape[0], cls.Y.mean())
5454
else:
55-
return _sample_posterior(cls.all_trees, cls.X, cls.m, rng=rng).squeeze().T
55+
if size is not None:
56+
shape = size[0]
57+
else:
58+
shape = 1
59+
return _sample_posterior(cls.all_trees, cls.X, cls.m, rng=rng, shape=shape).squeeze().T
5660

5761

5862
bart = BARTRV()

pymc_bart/tree.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,11 @@ def new_tree(
140140
) -> "Tree":
141141
return cls(
142142
tree_structure={
143-
0: Node.new_leaf_node(value=leaf_node_value, idx_data_points=idx_data_points)
143+
0: Node.new_leaf_node(
144+
value=leaf_node_value,
145+
nvalue=len(idx_data_points) if idx_data_points is not None else 0,
146+
idx_data_points=idx_data_points,
147+
)
144148
},
145149
idx_leaf_nodes=[0],
146150
output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(),
@@ -215,7 +219,11 @@ def _predict(self) -> npt.NDArray[np.float_]:
215219
return output.T
216220

217221
def predict(
218-
self, x: npt.NDArray[np.float_], m: int, excluded: Optional[List[int]] = None
222+
self,
223+
x: npt.NDArray[np.float_],
224+
m: int,
225+
excluded: Optional[List[int]] = None,
226+
shape: int = 1,
219227
) -> npt.NDArray[np.float_]:
220228
"""
221229
Predict output of tree for an (un)observed point x.
@@ -236,23 +244,22 @@ def predict(
236244
"""
237245
if excluded is None:
238246
excluded = []
239-
return self._traverse_tree(x=x, m=m, node_index=0, split_variable=-1, excluded=excluded)
247+
return self._traverse_tree(x=x, m=m, excluded=excluded, shape=shape)
240248

241249
def _traverse_tree(
242250
self,
243251
x: npt.NDArray[np.float_],
244252
m: int,
245-
node_index: int,
246-
split_variable: int = -1,
247253
excluded: Optional[List[int]] = None,
254+
shape: int = 1,
248255
) -> npt.NDArray[np.float_]:
249256
"""
250-
Traverse the tree starting from a particular node given an unobserved point.
257+
Traverse the tree starting from the root node given an (un)observed point.
251258
252259
Parameters
253260
----------
254261
x : npt.NDArray[np.float_]
255-
Unobserved point
262+
(Un)observed point
256263
m : int
257264
Number of trees
258265
node_index : int
@@ -267,33 +274,37 @@ def _traverse_tree(
267274
npt.NDArray[np.float_]
268275
Leaf node value or mean of leaf node values
269276
"""
270-
current_node = self.get_node(node_index)
271-
if current_node.is_leaf_node():
272-
if current_node.linear_params is None:
273-
return np.array(current_node.value)
274-
275-
x = x[split_variable].item()
276-
y_x = current_node.linear_params[0] + current_node.linear_params[1] * x
277-
return np.array(y_x / m)
278-
279-
split_variable = current_node.idx_split_variable
280-
281-
if excluded is not None and current_node.idx_split_variable in excluded:
282-
leaf_values: List[npt.NDArray[np.float_]] = []
283-
leaf_n_values: List[int] = []
284-
self._traverse_leaf_values(leaf_values, leaf_n_values, node_index)
285-
return (
286-
leaf_values[0].mean(axis=0) * leaf_n_values[0]
287-
+ leaf_values[1].mean(axis=0) * leaf_n_values[1]
288-
)
289-
290-
if x[current_node.idx_split_variable] <= current_node.value:
291-
next_node = get_idx_left_child(node_index)
292-
else:
293-
next_node = get_idx_right_child(node_index)
294-
return self._traverse_tree(
295-
x=x, m=m, node_index=next_node, split_variable=split_variable, excluded=excluded
296-
)
277+
stack = [(0, 1.0)] # (node_index, prop) initial state
278+
p_d = np.zeros(shape)
279+
while stack:
280+
node_index, weight = stack.pop()
281+
node = self.get_node(node_index)
282+
if node.is_leaf_node():
283+
params = node.linear_params
284+
if params is None:
285+
p_d += weight * node.value
286+
else:
287+
# this produces nonsensical results
288+
p_d += weight * ((params[0] + params[1] * x[node.idx_split_variable]) / m)
289+
# this produces reasonable results
290+
# p_d += weight * node.value.mean()
291+
else:
292+
if excluded is not None and node.idx_split_variable in excluded:
293+
left_node_index, right_node_index = get_idx_left_child(
294+
node_index
295+
), get_idx_right_child(node_index)
296+
prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue
297+
stack.append((left_node_index, weight * prop_nvalue_left))
298+
stack.append((right_node_index, weight * (1 - prop_nvalue_left)))
299+
else:
300+
next_node = (
301+
get_idx_left_child(node_index)
302+
if x[node.idx_split_variable] <= node.value
303+
else get_idx_right_child(node_index)
304+
)
305+
stack.append((next_node, weight))
306+
307+
return p_d
297308

298309
def _traverse_leaf_values(
299310
self, leaf_values: List[npt.NDArray[np.float_]], leaf_n_values: List[int], node_index: int

0 commit comments

Comments
 (0)