3131class ParticleTree :
3232 """Particle tree."""
3333
34- __slots__ = "tree" , "expansion_nodes" , "log_weight" , "kfactor"
34+ __slots__ = "tree" , "expansion_nodes" , "log_weight"
3535
36- def __init__ (self , tree : Tree , kfactor : float = 0.75 ):
36+ def __init__ (self , tree : Tree ):
3737 self .tree : Tree = tree .copy ()
3838 self .expansion_nodes : List [int ] = [0 ]
3939 self .log_weight : float = 0
40- self .kfactor : float = kfactor
4140
4241 def copy (self ) -> "ParticleTree" :
4342 p = ParticleTree (self .tree )
4443 p .expansion_nodes = self .expansion_nodes .copy ()
45- p .kfactor = self .kfactor
4644 return p
4745
4846 def sample_tree (
@@ -53,6 +51,7 @@ def sample_tree(
5351 X ,
5452 missing_data ,
5553 sum_trees ,
54+ leaf_sd ,
5655 m ,
5756 response ,
5857 normal ,
@@ -73,10 +72,10 @@ def sample_tree(
7372 X ,
7473 missing_data ,
7574 sum_trees ,
75+ leaf_sd ,
7676 m ,
7777 response ,
7878 normal ,
79- self .kfactor ,
8079 shape ,
8180 )
8281 if idx_new_nodes is not None :
@@ -95,7 +94,7 @@ class PGBART(ArrayStepShared):
9594 vars: list
9695 List of value variables for sampler
9796 num_particles : tuple
98- Number of particles. Defaults to 20
97+ Number of particles. Defaults to 10
9998 batch : int or tuple
10099 Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
101100 during tuning and after tuning. If a tuple is passed the first element is the batch size
@@ -112,7 +111,7 @@ class PGBART(ArrayStepShared):
112111 def __init__ (
113112 self ,
114113 vars = None , # pylint: disable=redefined-builtin
115- num_particles : int = 20 ,
114+ num_particles : int = 10 ,
116115 batch : Tuple [float , float ] = (0.1 , 0.1 ),
117116 model : Optional [Model ] = None ,
118117 ):
@@ -141,17 +140,20 @@ def __init__(
141140 self .alpha_vec = self .bart .split_prior
142141 else :
143142 self .alpha_vec = np .ones (self .X .shape [1 ], dtype = np .int32 )
143+
144144 init_mean = self .bart .Y .mean ()
145+ self .num_observations = self .X .shape [0 ]
146+ self .num_variates = self .X .shape [1 ]
147+ self .available_predictors = list (range (self .num_variates ))
148+
145149 # if data is binary
146150 y_unique = np .unique (self .bart .Y )
147151 if y_unique .size == 2 and np .all (y_unique == [0 , 1 ]):
148- mu_std = 3 / self .m ** 0.5
152+ self . leaf_sd = 3 / self .m ** 0.5
149153 else :
150- mu_std = self .bart .Y .std () / self .m ** 0.5
154+ self . leaf_sd = self .bart .Y .std () / self .m ** 0.5
151155
152- self .num_observations = self .X .shape [0 ]
153- self .num_variates = self .X .shape [1 ]
154- self .available_predictors = list (range (self .num_variates ))
156+ self .running_sd = RunningSd (shape )
155157
156158 self .sum_trees = np .full ((self .shape , self .bart .Y .shape [0 ]), init_mean ).astype (
157159 config .floatX
@@ -164,10 +166,9 @@ def __init__(
164166 shape = self .shape ,
165167 )
166168
167- self .normal = NormalSampler (mu_std , self .shape )
169+ self .normal = NormalSampler (1 , self .shape )
168170 self .uniform = UniformSampler (0 , 1 )
169- self .uniform_kf = UniformSampler (0.33 , 0.75 , self .shape )
170- self .prior_prob_leaf_node = compute_prior_probability (self .bart .alpha )
171+ self .prior_prob_leaf_node = compute_prior_probability (self .bart .alpha , self .bart .beta )
171172 self .ssv = SampleSplittingVariable (self .alpha_vec )
172173
173174 self .tune = True
@@ -212,6 +213,7 @@ def astep(self, _):
212213 self .X ,
213214 self .missing_data ,
214215 self .sum_trees ,
216+ self .leaf_sd ,
215217 self .m ,
216218 self .response ,
217219 self .normal ,
@@ -235,16 +237,25 @@ def astep(self, _):
235237 particles , normalized_weights
236238 )
237239 # Update the sum of trees
238- self .sum_trees = self .sum_trees_noi + new_tree ._predict ()
240+ new = new_tree ._predict ()
241+ self .sum_trees = self .sum_trees_noi + new
239242 # To reduce memory usage, we trim the tree
240243 self .all_trees [tree_id ] = new_tree .trim ()
241244
242245 if self .tune :
243246 # Update the splitting variable and the splitting variable sampler
244247 if self .iter > self .m :
245248 self .ssv = SampleSplittingVariable (self .alpha_vec )
249+
246250 for index in new_tree .get_split_variables ():
247251 self .alpha_vec [index ] += 1
252+
253+ # update standard deviation at leaf nodes
254+ if self .iter > 2 :
255+ self .leaf_sd = self .running_sd .update (new )
256+ else :
257+ self .running_sd .update (new )
258+
248259 else :
249260 # update the variable inclusion
250261 for index in new_tree .get_split_variables ():
@@ -320,10 +331,7 @@ def init_particles(self, tree_id: int) -> List[ParticleTree]:
320331 self .update_weight (p0 )
321332 particles : List [ParticleTree ] = [p0 ]
322333
323- particles .extend (
324- ParticleTree (self .a_tree , self .uniform_kf .rvs () if self .tune else p0 .kfactor )
325- for _ in self .indices
326- )
334+ particles .extend (ParticleTree (self .a_tree ) for _ in self .indices )
327335 return particles
328336
329337 def update_weight (self , particle : ParticleTree ) -> None :
@@ -344,6 +352,34 @@ def competence(var, has_grad):
344352 return Competence .INCOMPATIBLE
345353
346354
class RunningSd:
    """Incrementally track the standard deviation of a stream of arrays.

    Each call to :meth:`update` folds one new array into the running
    statistics, so the standard deviation can be refreshed without storing
    the previous values.

    Parameters
    ----------
    shape : tuple
        Shape of the arrays that will be passed to :meth:`update`; the
        internal accumulators are allocated with this shape.
    """

    def __init__(self, shape: tuple) -> None:
        self.count = 0  # number of arrays folded in so far
        self.mean = np.zeros(shape)  # running element-wise mean
        # running element-wise sum of squared deviations from the mean
        self.m_2 = np.zeros(shape)

    def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
        """Fold ``new_value`` into the running statistics.

        Parameters
        ----------
        new_value : array
            New observation; it is combined element-wise with the
            accumulators allocated in ``__init__``.

        Returns
        -------
        The element-wise running standard deviation reduced by ``fast_mean``.
        """
        # count is incremented first so `_update` never divides by zero
        self.count += 1
        self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
        return fast_mean(std)
365+
366+
@njit
def _update(
    count: int,
    mean: npt.NDArray[np.float64],
    m_2: npt.NDArray[np.float64],
    new_value: npt.NDArray[np.float64],
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
    """Perform one step of an online (Welford-style) mean/variance update.

    Parameters
    ----------
    count : int
        Number of observations seen so far, *including* ``new_value``.
        Must be at least 1 because it is used as a divisor.
    mean : array
        Running mean before this observation. Updated in place.
    m_2 : array
        Running sum of squared deviations from the mean before this
        observation. Updated in place.
    new_value : array
        The new observation.

    Returns
    -------
    tuple
        ``(mean, m_2, std)`` where ``mean`` and ``m_2`` are the (mutated)
        input accumulators and ``std`` is the population standard deviation
        ``sqrt(m_2 / count)``.
    """
    # deviation from the mean *before* it absorbs the new observation
    delta = new_value - mean
    mean += delta / count
    # deviation from the mean *after* the update; the product of the two
    # deviations is the increment of the sum of squared deviations
    delta2 = new_value - mean
    m_2 += delta * delta2

    std = (m_2 / count) ** 0.5
    return mean, m_2, std
381+
382+
347383class SampleSplittingVariable :
348384 def __init__ (self , alpha_vec : npt .NDArray [np .float_ ]) -> None :
349385 """
@@ -362,30 +398,26 @@ def rvs(self) -> Union[int, Tuple[int, float]]:
362398 return self .enu [- 1 ]
363399
364400
def compute_prior_probability(alpha: float, beta: float) -> List[float]:
    """
    Calculate the probability of the node being a leaf node (1 - p(being split node)).

    The returned list starts with ``0``, followed by
    ``1 - alpha * (1 + depth) ** (-beta)`` for ``depth = 0, 1, 2, ...`` until
    that value reaches ``0.9999``, and ends with a final ``1``.

    Parameters
    ----------
    alpha : float
    beta : float
        Must be positive unless ``1 - alpha`` is already at least ``0.9999``;
        otherwise the sequence never approaches 1. Note that the number of
        entries grows roughly like ``(alpha / 1e-4) ** (1 / beta)``, so very
        small positive values of ``beta`` produce very long lists.

    Returns
    -------
    list with probabilities for leaf nodes

    Raises
    ------
    ValueError
        If ``beta <= 0`` and ``1 - alpha < 0.9999``, in which case the
        sequence would never reach the stopping threshold.
    """
    threshold = 0.9999

    # With a non-positive beta the term alpha * (1 + depth) ** (-beta) does not
    # shrink with depth, so unless the very first value already passes the
    # threshold the loop below would never terminate.
    if beta <= 0 and 1 - alpha < threshold:
        raise ValueError(
            "beta must be positive for the leaf probabilities to converge to 1, "
            f"got alpha={alpha}, beta={beta}"
        )

    prior_leaf_prob: List[float] = [0]
    depth = 0
    while prior_leaf_prob[-1] < threshold:
        prior_leaf_prob.append(1 - (alpha * ((1 + depth) ** (-beta))))
        depth += 1
    # close the list with a certain leaf so that lookups end at probability 1
    prior_leaf_prob.append(1)

    return prior_leaf_prob
390422
391423
@@ -397,10 +429,10 @@ def grow_tree(
397429 X ,
398430 missing_data ,
399431 sum_trees ,
432+ leaf_sd ,
400433 m ,
401434 response ,
402435 normal ,
403- kfactor ,
404436 shape ,
405437):
406438 current_node = tree .get_node (index_leaf_node )
@@ -432,7 +464,7 @@ def grow_tree(
432464 y_mu_pred = sum_trees [:, idx_data_point ],
433465 x_mu = X [idx_data_point , selected_predictor ],
434466 m = m ,
435- norm = normal .rvs () * kfactor ,
467+ norm = normal .rvs () * leaf_sd ,
436468 shape = shape ,
437469 response = response ,
438470 )
@@ -493,7 +525,7 @@ def draw_leaf_value(
493525 if response == "linear" :
494526 mu_mean , linear_params = fast_linear_fit (x = x_mu , y = y_mu_pred , m = m )
495527
496- draw = norm + mu_mean
528+ draw = mu_mean + norm
497529 return draw , linear_params
498530
499531
0 commit comments