Skip to content

Commit d6cd156

Browse files
authored
Do weighted mean when pruning (#83)
* do weighted mean when prunning * generalize to constant response * black * remove test value
1 parent 144e048 commit d6cd156

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

pymc_bart/pgbart.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def grow_tree(
439439

440440
new_node = Node.new_leaf_node(
441441
value=node_value,
442+
nvalue=len(idx_data_point),
442443
idx_data_points=idx_data_point,
443444
linear_params=linear_params,
444445
)

pymc_bart/tree.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,18 @@ class Node:
3131
linear_params: Optional[List[float]] = None
3232
"""
3333

34-
__slots__ = "value", "idx_split_variable", "idx_data_points", "linear_params"
34+
__slots__ = "value", "nvalue", "idx_split_variable", "idx_data_points", "linear_params"
3535

3636
def __init__(
3737
self,
3838
value: npt.NDArray[np.float_] = np.array([-1.0]),
39+
nvalue: int = 0,
3940
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
4041
idx_split_variable: int = -1,
4142
linear_params: Optional[List[float]] = None,
4243
) -> None:
4344
self.value = value
45+
self.nvalue = nvalue
4446
self.idx_data_points = idx_data_points
4547
self.idx_split_variable = idx_split_variable
4648
self.linear_params = linear_params
@@ -49,12 +51,14 @@ def __init__(
4951
def new_leaf_node(
5052
cls,
5153
value: npt.NDArray[np.float_],
54+
nvalue: int = 0,
5255
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
5356
idx_split_variable: int = -1,
5457
linear_params: Optional[List[float]] = None,
5558
) -> "Node":
5659
return cls(
5760
value=value,
61+
nvalue=nvalue,
5862
idx_data_points=idx_data_points,
5963
idx_split_variable=idx_split_variable,
6064
linear_params=linear_params,
@@ -152,6 +156,7 @@ def copy(self) -> "Tree":
152156
tree: Dict[int, Node] = {
153157
k: Node(
154158
value=v.value,
159+
nvalue=v.nvalue,
155160
idx_data_points=v.idx_data_points,
156161
idx_split_variable=v.idx_split_variable,
157162
linear_params=v.linear_params,
@@ -186,6 +191,7 @@ def trim(self) -> "Tree":
186191
tree: Dict[int, Node] = {
187192
k: Node(
188193
value=v.value,
194+
nvalue=v.nvalue,
189195
idx_data_points=None,
190196
idx_split_variable=v.idx_split_variable,
191197
linear_params=v.linear_params,
@@ -274,8 +280,12 @@ def _traverse_tree(
274280

275281
if excluded is not None and current_node.idx_split_variable in excluded:
276282
leaf_values: List[npt.NDArray[np.float_]] = []
277-
self._traverse_leaf_values(leaf_values, node_index)
278-
return np.mean(leaf_values, axis=0)
283+
leaf_n_values: List[int] = []
284+
self._traverse_leaf_values(leaf_values, leaf_n_values, node_index)
285+
return (
286+
leaf_values[0].mean(axis=0) * leaf_n_values[0]
287+
+ leaf_values[1].mean(axis=0) * leaf_n_values[1]
288+
)
279289

280290
if x[current_node.idx_split_variable] <= current_node.value:
281291
next_node = get_idx_left_child(node_index)
@@ -286,7 +296,7 @@ def _traverse_tree(
286296
)
287297

288298
def _traverse_leaf_values(
289-
self, leaf_values: List[npt.NDArray[np.float_]], node_index: int
299+
self, leaf_values: List[npt.NDArray[np.float_]], leaf_n_values: List[int], node_index: int
290300
) -> None:
291301
"""
292302
Traverse the tree appending leaf values starting from a particular node.
@@ -299,6 +309,7 @@ def _traverse_leaf_values(
299309
node = self.get_node(node_index)
300310
if node.is_leaf_node():
301311
leaf_values.append(node.value)
312+
leaf_n_values.append(node.nvalue)
302313
else:
303-
self._traverse_leaf_values(leaf_values, get_idx_left_child(node_index))
304-
self._traverse_leaf_values(leaf_values, get_idx_right_child(node_index))
314+
self._traverse_leaf_values(leaf_values, leaf_n_values, get_idx_left_child(node_index))
315+
self._traverse_leaf_values(leaf_values, leaf_n_values, get_idx_right_child(node_index))

0 commit comments

Comments
 (0)