Skip to content

Commit 74d09f6

Browse files
authored
Fix bart to work with latest version of aesara (#25)
* update tests, update aesara import * infer shape
1 parent a30ea8c commit 74d09f6

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

Diff for: pymc_experimental/bart/bart.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import numpy as np
1717

1818
from aeppl.logprob import _logprob
19-
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
19+
from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
2020
from pandas import DataFrame, Series
2121

2222
from pymc.distributions.distribution import NoDistribution, _get_moment
@@ -37,7 +37,12 @@ class BARTRV(RandomVariable):
3737
all_trees = None
3838

3939
def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
40-
return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)
40+
return default_supp_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)
41+
42+
43+
def _infer_shape(cls, size, dist_params, param_shapes=None):
44+
dist_shape = (cls.X.shape[0],)
45+
return dist_shape
4146

4247
@classmethod
4348
def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):

Diff for: pymc_experimental/tests/test_bart.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from numpy.testing import assert_almost_equal, assert_array_equal
66

77
import pymc as pm
8-
import pymcx as pmx
8+
import pymc_experimental as pmx
99

1010
from pymc.tests.test_distributions_moments import assert_moment_is_expected
1111

@@ -37,7 +37,7 @@ def test_bart_vi():
3737
X[:, 0] = np.random.normal(Y, 0.1)
3838

3939
with pm.Model() as model:
40-
mu = pmx.bart("mu", X, Y, m=10)
40+
mu = pmx.BART("mu", X, Y, m=10)
4141
sigma = pm.HalfNormal("sigma", 1)
4242
y = pm.Normal("y", mu, sigma, observed=Y)
4343
idata = pm.sample(random_seed=3415)
@@ -57,7 +57,7 @@ def test_missing_data():
5757
X[10:20, 0] = np.nan
5858

5959
with pm.Model() as model:
60-
mu = pmx.bart("mu", X, Y, m=10)
60+
mu = pmx.BART("mu", X, Y, m=10)
6161
sigma = pm.HalfNormal("sigma", 1)
6262
y = pm.Normal("y", mu, sigma, observed=Y)
6363
idata = pm.sample(tune=10, draws=10, chains=1, random_seed=3415)
@@ -70,7 +70,7 @@ class TestUtils:
7070
Y = np.random.normal(0, 1, size=50)
7171

7272
with pm.Model() as model:
73-
mu = pmx.bart("mu", X, Y, m=10)
73+
mu = pmx.BART("mu", X, Y, m=10)
7474
sigma = pm.HalfNormal("sigma", 1)
7575
y = pm.Normal("y", mu, sigma, observed=Y)
7676
idata = pm.sample(random_seed=3415)
@@ -137,5 +137,5 @@ def test_bart_moment(size, expected):
137137
X = np.zeros((50, 2))
138138
Y = np.zeros(50)
139139
with pm.Model() as model:
140-
pmx.bart("x", X=X, Y=Y, size=size)
140+
pmx.BART("x", X=X, Y=Y, size=size)
141141
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)