@@ -140,7 +140,11 @@ def new_tree(
140
140
) -> "Tree" :
141
141
return cls (
142
142
tree_structure = {
143
- 0 : Node .new_leaf_node (value = leaf_node_value , idx_data_points = idx_data_points )
143
+ 0 : Node .new_leaf_node (
144
+ value = leaf_node_value ,
145
+ nvalue = len (idx_data_points ) if idx_data_points is not None else 0 ,
146
+ idx_data_points = idx_data_points ,
147
+ )
144
148
},
145
149
idx_leaf_nodes = [0 ],
146
150
output = np .zeros ((num_observations , shape )).astype (config .floatX ).squeeze (),
@@ -215,7 +219,11 @@ def _predict(self) -> npt.NDArray[np.float_]:
215
219
return output .T
216
220
217
221
def predict(
    self,
    x: npt.NDArray[np.float64],
    m: int,
    excluded: Optional[List[int]] = None,
    shape: int = 1,
) -> npt.NDArray[np.float64]:
    """
    Predict output of tree for an (un)observed point x.

    Parameters
    ----------
    x : npt.NDArray[np.float64]
        (Un)observed point
    m : int
        Number of trees
    excluded : Optional[List[int]]
        Indexes of the variables to exclude; ``None`` is forwarded as an
        empty list, by default None
    shape : int
        Forwarded unchanged to the tree traversal, by default 1

    Returns
    -------
    npt.NDArray[np.float64]
        Whatever the traversal from the root node returns for ``x``
    """
    # ``np.float64`` is used in the annotations instead of the ``np.float_``
    # alias, which no longer exists in NumPy >= 2.0.
    excluded_variables = [] if excluded is None else excluded
    return self._traverse_tree(x=x, m=m, excluded=excluded_variables, shape=shape)
240
248
241
249
def _traverse_tree (
242
250
self ,
243
251
x : npt .NDArray [np .float_ ],
244
252
m : int ,
245
- node_index : int ,
246
- split_variable : int = - 1 ,
247
253
excluded : Optional [List [int ]] = None ,
254
+ shape : int = 1 ,
248
255
) -> npt .NDArray [np .float_ ]:
249
256
"""
250
- Traverse the tree starting from a particular node given an unobserved point.
257
+ Traverse the tree starting from the root node given an (un)observed point.
251
258
252
259
Parameters
253
260
----------
254
261
x : npt.NDArray[np.float_]
255
- Unobserved point
262
+ (Un)observed point
256
263
m : int
257
264
Number of trees
258
265
node_index : int
@@ -267,33 +274,37 @@ def _traverse_tree(
267
274
npt.NDArray[np.float_]
268
275
Leaf node value or mean of leaf node values
269
276
"""
270
- current_node = self .get_node (node_index )
271
- if current_node .is_leaf_node ():
272
- if current_node .linear_params is None :
273
- return np .array (current_node .value )
274
-
275
- x = x [split_variable ].item ()
276
- y_x = current_node .linear_params [0 ] + current_node .linear_params [1 ] * x
277
- return np .array (y_x / m )
278
-
279
- split_variable = current_node .idx_split_variable
280
-
281
- if excluded is not None and current_node .idx_split_variable in excluded :
282
- leaf_values : List [npt .NDArray [np .float_ ]] = []
283
- leaf_n_values : List [int ] = []
284
- self ._traverse_leaf_values (leaf_values , leaf_n_values , node_index )
285
- return (
286
- leaf_values [0 ].mean (axis = 0 ) * leaf_n_values [0 ]
287
- + leaf_values [1 ].mean (axis = 0 ) * leaf_n_values [1 ]
288
- )
289
-
290
- if x [current_node .idx_split_variable ] <= current_node .value :
291
- next_node = get_idx_left_child (node_index )
292
- else :
293
- next_node = get_idx_right_child (node_index )
294
- return self ._traverse_tree (
295
- x = x , m = m , node_index = next_node , split_variable = split_variable , excluded = excluded
296
- )
277
+ stack = [(0 , 1.0 )] # (node_index, prop) initial state
278
+ p_d = np .zeros (shape )
279
+ while stack :
280
+ node_index , weight = stack .pop ()
281
+ node = self .get_node (node_index )
282
+ if node .is_leaf_node ():
283
+ params = node .linear_params
284
+ if params is None :
285
+ p_d += weight * node .value
286
+ else :
287
+ # this produce nonsensical results
288
+ p_d += weight * ((params [0 ] + params [1 ] * x [node .idx_split_variable ]) / m )
289
+ # this produce reasonable result
290
+ # p_d += weight * node.value.mean()
291
+ else :
292
+ if excluded is not None and node .idx_split_variable in excluded :
293
+ left_node_index , right_node_index = get_idx_left_child (
294
+ node_index
295
+ ), get_idx_right_child (node_index )
296
+ prop_nvalue_left = self .get_node (left_node_index ).nvalue / node .nvalue
297
+ stack .append ((left_node_index , weight * prop_nvalue_left ))
298
+ stack .append ((right_node_index , weight * (1 - prop_nvalue_left )))
299
+ else :
300
+ next_node = (
301
+ get_idx_left_child (node_index )
302
+ if x [node .idx_split_variable ] <= node .value
303
+ else get_idx_right_child (node_index )
304
+ )
305
+ stack .append ((next_node , weight ))
306
+
307
+ return p_d
297
308
298
309
def _traverse_leaf_values (
299
310
self , leaf_values : List [npt .NDArray [np .float_ ]], leaf_n_values : List [int ], node_index : int
0 commit comments