Skip to content

Commit 697f8dc

Browse files
authored
Merge pull request #14 from pymc-devs/bart
Move bart from PyMC
2 parents 8679539 + fa390cf commit 697f8dc

File tree

8 files changed

+1427
-0
lines changed

8 files changed

+1427
-0
lines changed

pymc_experimental/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@
99
if len(_log.handlers) == 0:
1010
handler = logging.StreamHandler()
1111
_log.addHandler(handler)
12+
13+
14+
from pymc_experimental.bart import *

pymc_experimental/bart/__init__.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from pymc_experimental.bart.bart import BART
17+
from pymc_experimental.bart.pgbart import PGBART
18+
from pymc_experimental.bart.utils import plot_dependence, plot_variable_importance, predict
19+
20+
__all__ = ["BART", "PGBART"]
21+
22+
23+
24+
import pymc as pm
25+
pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]

pymc_experimental/bart/bart.py

+155
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import aesara.tensor as at
16+
import numpy as np
17+
18+
from aeppl.logprob import _logprob
19+
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
20+
from pandas import DataFrame, Series
21+
22+
from pymc.distributions.distribution import NoDistribution, _get_moment
23+
24+
__all__ = ["BART"]
25+
26+
27+
class BARTRV(RandomVariable):
28+
"""
29+
Base class for BART
30+
"""
31+
32+
name = "BART"
33+
ndim_supp = 1
34+
ndims_params = [2, 1, 0, 0, 0, 1]
35+
dtype = "floatX"
36+
_print_name = ("BART", "\\operatorname{BART}")
37+
all_trees = None
38+
39+
def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
40+
return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)
41+
42+
@classmethod
43+
def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
44+
return np.full_like(cls.Y, cls.Y.mean())
45+
46+
47+
bart = BARTRV()
48+
49+
50+
class BART(NoDistribution):
51+
"""
52+
Bayesian Additive Regression Tree distribution.
53+
54+
Distribution representing a sum over trees
55+
56+
Parameters
57+
----------
58+
X : array-like
59+
The covariate matrix.
60+
Y : array-like
61+
The response vector.
62+
m : int
63+
Number of trees
64+
alpha : float
65+
Control the prior probability over the depth of the trees. Even when it can takes values in
66+
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
67+
k : float
68+
Scale parameter for the values of the leaf nodes. Defaults to 2. Recomended to be between 1
69+
and 3.
70+
split_prior : array-like
71+
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
72+
1. Otherwise they will be normalized.
73+
Defaults to None, i.e. all covariates have the same prior probability to be selected.
74+
"""
75+
76+
def __new__(
77+
cls,
78+
name,
79+
X,
80+
Y,
81+
m=50,
82+
alpha=0.25,
83+
k=2,
84+
split_prior=None,
85+
**kwargs,
86+
):
87+
88+
X, Y = preprocess_XY(X, Y)
89+
90+
bart_op = type(
91+
f"BART_{name}",
92+
(BARTRV,),
93+
dict(
94+
name="BART",
95+
inplace=False,
96+
initval=Y.mean(),
97+
X=X,
98+
Y=Y,
99+
m=m,
100+
alpha=alpha,
101+
k=k,
102+
split_prior=split_prior,
103+
),
104+
)()
105+
106+
NoDistribution.register(BARTRV)
107+
108+
@_get_moment.register(BARTRV)
109+
def get_moment(rv, size, *rv_inputs):
110+
return cls.get_moment(rv, size, *rv_inputs)
111+
112+
cls.rv_op = bart_op
113+
params = [X, Y, m, alpha, k]
114+
return super().__new__(cls, name, *params, **kwargs)
115+
116+
@classmethod
117+
def dist(cls, *params, **kwargs):
118+
return super().dist(params, **kwargs)
119+
120+
def logp(x, *inputs):
121+
"""Calculate log probability.
122+
123+
Parameters
124+
----------
125+
x: numeric, TensorVariable
126+
Value for which log-probability is calculated.
127+
128+
Returns
129+
-------
130+
TensorVariable
131+
"""
132+
return at.zeros_like(x)
133+
134+
@classmethod
135+
def get_moment(cls, rv, size, *rv_inputs):
136+
mean = at.fill(size, rv.Y.mean())
137+
return mean
138+
139+
140+
def preprocess_XY(X, Y):
141+
if isinstance(Y, (Series, DataFrame)):
142+
Y = Y.to_numpy()
143+
if isinstance(X, (Series, DataFrame)):
144+
X = X.to_numpy()
145+
# X = np.random.normal(X, X.std(0)/100)
146+
Y = Y.astype(float)
147+
X = X.astype(float)
148+
return X, Y
149+
150+
151+
@_logprob.register(BARTRV)
152+
def logp(op, value_var, *dist_params, **kwargs):
153+
_dist_params = dist_params[3:]
154+
value_var = value_var[0]
155+
return BART.logp(value_var, *_dist_params)

0 commit comments

Comments
 (0)