# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import aesara.tensor as at
import numpy as np

from aeppl.logprob import _logprob
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
from pandas import DataFrame, Series

from pymc.distributions.distribution import NoDistribution, _get_moment

__all__ = ["BART"]


class BARTRV(RandomVariable):
    """Base class for BART."""

    name = "BART"
    ndim_supp = 1
    ndims_params = [2, 1, 0, 0, 0, 1]
    dtype = "floatX"
    _print_name = ("BART", "\\operatorname{BART}")
    all_trees = None

    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
        return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)

    @classmethod
    def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
        # Placeholder draw: a vector of the same shape as Y filled with the mean of Y.
        return np.full_like(cls.Y, cls.Y.mean())


bart = BARTRV()


class BART(NoDistribution):
    """
    Bayesian Additive Regression Tree distribution.

    Distribution representing a sum over trees.

    Parameters
    ----------
    X : array-like
        The covariate matrix.
    Y : array-like
        The response vector.
    m : int
        Number of trees.
    alpha : float
        Controls the prior probability over the depth of the trees. Although it can take
        values in the interval (0, 1), it is recommended to keep it in the interval (0, 0.5].
    k : float
        Scale parameter for the values of the leaf nodes. Defaults to 2. Recommended to be
        between 1 and 3.
    split_prior : array-like
        Prior probability of each covariate being chosen for a split. Each element should be
        in the interval [0, 1] and the elements should sum to 1; otherwise they are normalized.
        Defaults to None, i.e. all covariates have the same prior probability of being selected.
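
    Examples
    --------
    A minimal usage sketch. It assumes ``BART`` is re-exported at the package
    level as ``pm.BART`` and that ``X`` and ``Y`` hold the user's covariates and
    response::

        import pymc as pm

        with pm.Model() as model:
            mu = pm.BART("mu", X, Y, m=50)
            sigma = pm.HalfNormal("sigma", 1.0)
            obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=Y)
            idata = pm.sample()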
    """

    def __new__(
        cls,
        name,
        X,
        Y,
        m=50,
        alpha=0.25,
        k=2,
        split_prior=None,
        **kwargs,
    ):
        X, Y = preprocess_XY(X, Y)

        # Build a per-instance subclass of BARTRV so that the data and hyperparameters
        # are available as class attributes of the resulting random variable op.
        bart_op = type(
            f"BART_{name}",
            (BARTRV,),
            dict(
                name="BART",
                inplace=False,
                initval=Y.mean(),
                X=X,
                Y=Y,
                m=m,
                alpha=alpha,
                k=k,
                split_prior=split_prior,
            ),
        )()

        NoDistribution.register(BARTRV)

        @_get_moment.register(BARTRV)
        def get_moment(rv, size, *rv_inputs):
            return cls.get_moment(rv, size, *rv_inputs)

        cls.rv_op = bart_op
        params = [X, Y, m, alpha, k]
        return super().__new__(cls, name, *params, **kwargs)

    @classmethod
    def dist(cls, *params, **kwargs):
        return super().dist(params, **kwargs)

    def logp(x, *inputs):
        """Calculate log probability.

        Parameters
        ----------
        x : numeric, TensorVariable
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        # BART contributes a flat (zero) log-probability; the sum-of-trees itself is
        # fit by a dedicated BART sampler rather than through this density term.
        return at.zeros_like(x)

    @classmethod
    def get_moment(cls, rv, size, *rv_inputs):
        # Use the mean of the observed response as the moment/initial value.
        mean = at.fill(size, rv.Y.mean())
        return mean


def preprocess_XY(X, Y):
    if isinstance(Y, (Series, DataFrame)):
        Y = Y.to_numpy()
    if isinstance(X, (Series, DataFrame)):
        X = X.to_numpy()
    # X = np.random.normal(X, X.std(0)/100)
    Y = Y.astype(float)
    X = X.astype(float)
    return X, Y


@_logprob.register(BARTRV)
def logp(op, value_var, *dist_params, **kwargs):
    # Skip the first three node inputs (rng, size, dtype), which precede the
    # distribution parameters, and evaluate the flat BART log-probability.
    _dist_params = dist_params[3:]
    value_var = value_var[0]
    return BART.logp(value_var, *_dist_params)
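
# A quick sanity check of the flat log-probability defined above (a sketch for
# illustration only, not part of the module's API; the variable name is arbitrary):
#
#     value = at.vector("value")
#     BART.logp(value).eval({value: np.zeros(3)})  # -> array([0., 0., 0.])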