@@ -31,16 +31,18 @@ class Node:
31
31
linear_params: Optional[List[float]] = None
32
32
"""
33
33
34
- __slots__ = "value" , "idx_split_variable" , "idx_data_points" , "linear_params"
34
+ __slots__ = "value" , "nvalue" , " idx_split_variable" , "idx_data_points" , "linear_params"
35
35
36
36
def __init__ (
37
37
self ,
38
38
value : npt .NDArray [np .float_ ] = np .array ([- 1.0 ]),
39
+ nvalue : int = 0 ,
39
40
idx_data_points : Optional [npt .NDArray [np .int_ ]] = None ,
40
41
idx_split_variable : int = - 1 ,
41
42
linear_params : Optional [List [float ]] = None ,
42
43
) -> None :
43
44
self .value = value
45
+ self .nvalue = nvalue
44
46
self .idx_data_points = idx_data_points
45
47
self .idx_split_variable = idx_split_variable
46
48
self .linear_params = linear_params
@@ -49,12 +51,14 @@ def __init__(
49
51
def new_leaf_node (
50
52
cls ,
51
53
value : npt .NDArray [np .float_ ],
54
+ nvalue : int = 0 ,
52
55
idx_data_points : Optional [npt .NDArray [np .int_ ]] = None ,
53
56
idx_split_variable : int = - 1 ,
54
57
linear_params : Optional [List [float ]] = None ,
55
58
) -> "Node" :
56
59
return cls (
57
60
value = value ,
61
+ nvalue = nvalue ,
58
62
idx_data_points = idx_data_points ,
59
63
idx_split_variable = idx_split_variable ,
60
64
linear_params = linear_params ,
@@ -152,6 +156,7 @@ def copy(self) -> "Tree":
152
156
tree : Dict [int , Node ] = {
153
157
k : Node (
154
158
value = v .value ,
159
+ nvalue = v .nvalue ,
155
160
idx_data_points = v .idx_data_points ,
156
161
idx_split_variable = v .idx_split_variable ,
157
162
linear_params = v .linear_params ,
@@ -186,6 +191,7 @@ def trim(self) -> "Tree":
186
191
tree : Dict [int , Node ] = {
187
192
k : Node (
188
193
value = v .value ,
194
+ nvalue = v .nvalue ,
189
195
idx_data_points = None ,
190
196
idx_split_variable = v .idx_split_variable ,
191
197
linear_params = v .linear_params ,
@@ -274,8 +280,12 @@ def _traverse_tree(
274
280
275
281
if excluded is not None and current_node .idx_split_variable in excluded :
276
282
leaf_values : List [npt .NDArray [np .float_ ]] = []
277
- self ._traverse_leaf_values (leaf_values , node_index )
278
- return np .mean (leaf_values , axis = 0 )
283
+ leaf_n_values : List [int ] = []
284
+ self ._traverse_leaf_values (leaf_values , leaf_n_values , node_index )
285
+ return (
286
+ leaf_values [0 ].mean (axis = 0 ) * leaf_n_values [0 ]
287
+ + leaf_values [1 ].mean (axis = 0 ) * leaf_n_values [1 ]
288
+ )
279
289
280
290
if x [current_node .idx_split_variable ] <= current_node .value :
281
291
next_node = get_idx_left_child (node_index )
@@ -286,7 +296,7 @@ def _traverse_tree(
286
296
)
287
297
288
298
def _traverse_leaf_values (
289
- self , leaf_values : List [npt .NDArray [np .float_ ]], node_index : int
299
+ self , leaf_values : List [npt .NDArray [np .float_ ]], leaf_n_values : List [ int ], node_index : int
290
300
) -> None :
291
301
"""
292
302
Traverse the tree appending leaf values starting from a particular node.
@@ -299,6 +309,7 @@ def _traverse_leaf_values(
299
309
node = self .get_node (node_index )
300
310
if node .is_leaf_node ():
301
311
leaf_values .append (node .value )
312
+ leaf_n_values .append (node .nvalue )
302
313
else :
303
- self ._traverse_leaf_values (leaf_values , get_idx_left_child (node_index ))
304
- self ._traverse_leaf_values (leaf_values , get_idx_right_child (node_index ))
314
+ self ._traverse_leaf_values (leaf_values , leaf_n_values , get_idx_left_child (node_index ))
315
+ self ._traverse_leaf_values (leaf_values , leaf_n_values , get_idx_right_child (node_index ))
0 commit comments