class ParticleTree:
    """Particle tree."""

-    __slots__ = "tree", "expansion_nodes", "log_weight", "kfactor"
+    __slots__ = "tree", "expansion_nodes", "log_weight"

-    def __init__(self, tree: Tree, kfactor: float = 0.75):
+    def __init__(self, tree: Tree):
        self.tree: Tree = tree.copy()
        self.expansion_nodes: List[int] = [0]
        self.log_weight: float = 0
-        self.kfactor: float = kfactor

    def copy(self) -> "ParticleTree":
        p = ParticleTree(self.tree)
        p.expansion_nodes = self.expansion_nodes.copy()
-        p.kfactor = self.kfactor
        return p

    def sample_tree(
@@ -53,6 +51,7 @@ def sample_tree(
        X,
        missing_data,
        sum_trees,
+        leaf_sd,
        m,
        response,
        normal,
@@ -73,10 +72,10 @@ def sample_tree(
                    X,
                    missing_data,
                    sum_trees,
+                    leaf_sd,
                    m,
                    response,
                    normal,
-                    self.kfactor,
                    shape,
                )
                if idx_new_nodes is not None:
@@ -95,7 +94,7 @@ class PGBART(ArrayStepShared):
    vars : list
        List of value variables for the sampler
    num_particles : int
-        Number of particles. Defaults to 20
+        Number of particles. Defaults to 10
    batch : int or tuple
        Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
        during tuning and after tuning. If a tuple is passed the first element is the batch size
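
For context, a hedged usage sketch (not part of this diff; the toy data and model are assumptions, and `pm.sample` normally assigns PGBART to BART variables automatically):

    import numpy as np
    import pymc as pm
    import pymc_bart as pmb
    from pymc_bart import PGBART

    rng = np.random.default_rng(0)
    X = rng.uniform(size=(100, 3))
    Y = np.sin(6 * X[:, 0]) + rng.normal(0, 0.2, size=100)

    with pm.Model():
        mu = pmb.BART("mu", X, Y, m=50)
        pm.Normal("y_obs", mu=mu, sigma=0.2, observed=Y)
        step = PGBART(num_particles=10, batch=(0.1, 0.1))  # the new defaults, set explicitly
        idata = pm.sample(step=step)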
@@ -112,7 +111,7 @@ class PGBART(ArrayStepShared):
    def __init__(
        self,
        vars=None,  # pylint: disable=redefined-builtin
-        num_particles: int = 20,
+        num_particles: int = 10,
        batch: Tuple[float, float] = (0.1, 0.1),
        model: Optional[Model] = None,
    ):
@@ -141,17 +140,20 @@ def __init__(
            self.alpha_vec = self.bart.split_prior
        else:
            self.alpha_vec = np.ones(self.X.shape[1], dtype=np.int32)
+
        init_mean = self.bart.Y.mean()
+        self.num_observations = self.X.shape[0]
+        self.num_variates = self.X.shape[1]
+        self.available_predictors = list(range(self.num_variates))
+
        # if data is binary
        y_unique = np.unique(self.bart.Y)
        if y_unique.size == 2 and np.all(y_unique == [0, 1]):
-            mu_std = 3 / self.m**0.5
+            self.leaf_sd = 3 / self.m**0.5
        else:
-            mu_std = self.bart.Y.std() / self.m**0.5
+            self.leaf_sd = self.bart.Y.std() / self.m**0.5

-        self.num_observations = self.X.shape[0]
-        self.num_variates = self.X.shape[1]
-        self.available_predictors = list(range(self.num_variates))
+        self.running_sd = RunningSd(shape)

        self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype(
            config.floatX
@@ -164,10 +166,9 @@ def __init__(
            shape=self.shape,
        )

-        self.normal = NormalSampler(mu_std, self.shape)
+        self.normal = NormalSampler(1, self.shape)
        self.uniform = UniformSampler(0, 1)
-        self.uniform_kf = UniformSampler(0.33, 0.75, self.shape)
-        self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha)
+        self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha, self.bart.beta)
        self.ssv = SampleSplittingVariable(self.alpha_vec)

        self.tune = True
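
Net effect of the hunks above: the per-particle `kfactor` and its `uniform_kf` sampler are gone; leaf proposals are now drawn from a standard normal (`NormalSampler(1, ...)`) and scaled by a single adaptive `leaf_sd`, initialized to `3 / m**0.5` for binary targets and `Y.std() / m**0.5` otherwise. A small numeric sketch (toy `m` and `Y` are assumed names, not part of the diff):

    import numpy as np

    m = 50                                     # number of trees
    Y = np.random.normal(0.0, 2.0, size=200)   # toy continuous response

    leaf_sd_binary = 3 / m**0.5                # ~0.424 when the response is 0/1
    leaf_sd = Y.std() / m**0.5                 # ~0.28 for this toy response

    norm = np.random.normal(size=1) * leaf_sd  # a leaf-value perturbation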
@@ -212,6 +213,7 @@ def astep(self, _):
                        self.X,
                        self.missing_data,
                        self.sum_trees,
+                        self.leaf_sd,
                        self.m,
                        self.response,
                        self.normal,
@@ -235,16 +237,25 @@ def astep(self, _):
                particles, normalized_weights
            )
            # Update the sum of trees
-            self.sum_trees = self.sum_trees_noi + new_tree._predict()
+            new = new_tree._predict()
+            self.sum_trees = self.sum_trees_noi + new
            # To reduce memory usage, we trim the tree
            self.all_trees[tree_id] = new_tree.trim()

            if self.tune:
                # Update the splitting variable and the splitting variable sampler
                if self.iter > self.m:
                    self.ssv = SampleSplittingVariable(self.alpha_vec)
+
                for index in new_tree.get_split_variables():
                    self.alpha_vec[index] += 1
+
+                # update standard deviation at leaf nodes
+                if self.iter > 2:
+                    self.leaf_sd = self.running_sd.update(new)
+                else:
+                    self.running_sd.update(new)
+
            else:
                # update the variable inclusion
                for index in new_tree.get_split_variables():
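
During tuning, `leaf_sd` now adapts to the spread of each accepted tree's predictions: the first two iterations only warm up the running estimator, and from then on its estimate replaces `leaf_sd`. A minimal sketch of that control flow (toy shapes; `RunningSd` is the helper added later in this diff, and `new` stands in for `new_tree._predict()`):

    import numpy as np

    rng = np.random.default_rng(0)
    running_sd = RunningSd(shape=(1, 100))  # toy shape: (outputs, observations)
    leaf_sd = 0.3                           # initial value from __init__

    for iteration in range(1, 6):
        new = rng.normal(size=(1, 100))     # stand-in for new_tree._predict()
        if iteration > 2:
            leaf_sd = running_sd.update(new)  # adopt the running estimate
        else:
            running_sd.update(new)            # warm-up only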
@@ -320,10 +331,7 @@ def init_particles(self, tree_id: int) -> List[ParticleTree]:
        self.update_weight(p0)
        particles: List[ParticleTree] = [p0]

-        particles.extend(
-            ParticleTree(self.a_tree, self.uniform_kf.rvs() if self.tune else p0.kfactor)
-            for _ in self.indices
-        )
+        particles.extend(ParticleTree(self.a_tree) for _ in self.indices)
        return particles

    def update_weight(self, particle: ParticleTree) -> None:
@@ -344,6 +352,34 @@ def competence(var, has_grad):
        return Competence.INCOMPATIBLE


+class RunningSd:
+    def __init__(self, shape: tuple) -> None:
+        self.count = 0  # number of data points
+        self.mean = np.zeros(shape)  # running mean
+        self.m_2 = np.zeros(shape)  # running second moment
+
+    def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
+        self.count = self.count + 1
+        self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
+        return fast_mean(std)
+
+
+@njit
+def _update(
+    count: int,
+    mean: npt.NDArray[np.float_],
+    m_2: npt.NDArray[np.float_],
+    new_value: npt.NDArray[np.float_],
+) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], Union[float, npt.NDArray[np.float_]]]:
+    delta = new_value - mean
+    mean += delta / count
+    delta2 = new_value - mean
+    m_2 += delta * delta2
+
+    std = (m_2 / count) ** 0.5
+    return mean, m_2, std
+
+
class SampleSplittingVariable:
    def __init__(self, alpha_vec: npt.NDArray[np.float_]) -> None:
        """
@@ -362,30 +398,26 @@ def rvs(self) -> Union[int, Tuple[int, float]]:
        return self.enu[-1]


-def compute_prior_probability(alpha) -> List[float]:
+def compute_prior_probability(alpha: float, beta: float) -> List[float]:
    """
    Calculate the probability of the node being a leaf node (1 - p(being split node)).

-    Taken from equation 19 in [Rockova2018].
-
    Parameters
    ----------
    alpha : float
+    beta : float

    Returns
    -------
    list with probabilities for leaf nodes
-
-    References
-    ----------
-    .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
-       arXiv, `link <https://arxiv.org/abs/1810.00787>`__
    """
    prior_leaf_prob: List[float] = [0]
-    depth = 1
-    while prior_leaf_prob[-1] < 1:
-        prior_leaf_prob.append(1 - alpha**depth)
+    depth = 0
+    while prior_leaf_prob[-1] < 0.9999:
+        prior_leaf_prob.append(1 - (alpha * ((1 + depth) ** (-beta))))
        depth += 1
+    prior_leaf_prob.append(1)
+
    return prior_leaf_prob
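
The rewritten prior follows the standard BART tree prior of Chipman et al., where a node at depth `d` is split with probability `alpha * (1 + d) ** (-beta)`, rather than the previous `1 - alpha**depth` rule. A worked example (the values `alpha=0.95`, `beta=2` are illustrative assumptions):

    probs = compute_prior_probability(alpha=0.95, beta=2)
    # probs[d] is the prior probability that a node at depth d stays a leaf:
    # probs[0] = 0        -> the root is always split
    # probs[1] = 0.05     =  1 - 0.95 * 1**-2
    # probs[2] = 0.7625   =  1 - 0.95 * 2**-2
    # probs[3] ~= 0.8944  =  1 - 0.95 * 3**-2
    # the loop stops once the value exceeds 0.9999, then a final 1 is appended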
@@ -397,10 +429,10 @@ def grow_tree(
    X,
    missing_data,
    sum_trees,
+    leaf_sd,
    m,
    response,
    normal,
-    kfactor,
    shape,
):
    current_node = tree.get_node(index_leaf_node)
@@ -432,7 +464,7 @@ def grow_tree(
        y_mu_pred=sum_trees[:, idx_data_point],
        x_mu=X[idx_data_point, selected_predictor],
        m=m,
-        norm=normal.rvs() * kfactor,
+        norm=normal.rvs() * leaf_sd,
        shape=shape,
        response=response,
    )
@@ -493,7 +525,7 @@ def draw_leaf_value(
    if response == "linear":
        mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m)

-    draw = norm + mu_mean
+    draw = mu_mean + norm
    return draw, linear_params
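
Putting the pieces together, a leaf value is now proposed as the conditional mean plus standard-normal noise that `grow_tree` pre-scales by the adaptive `leaf_sd` (the reordering to `mu_mean + norm` is purely cosmetic). A toy sketch (all values illustrative):

    import numpy as np

    rng = np.random.default_rng(42)
    mu_mean = 0.3                     # conditional mean for the leaf
    norm = rng.normal(size=1) * 0.12  # normal.rvs() scaled by leaf_sd
    draw = mu_mean + norm             # the proposed leaf value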