Skip to content

Commit 53e4e0e

Browse files
committed
Added RV size testing for Interpolated
1 parent 7e34e86 commit 53e4e0e

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

pymc3/tests/test_distributions_random.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@
3232
import pymc3 as pm
3333

3434
from pymc3.aesaraf import change_rv_size, floatX, intX
35-
from pymc3.distributions.continuous import get_tau_sigma, interpolated
35+
from pymc3.distributions.continuous import (
36+
_interpolated_argcdf,
37+
get_tau_sigma,
38+
interpolated,
39+
)
3640
from pymc3.distributions.dist_math import clipped_beta_rvs
3741
from pymc3.distributions.multivariate import quaddist_matrix
3842
from pymc3.distributions.shape_utils import to_tuple
@@ -1242,7 +1246,28 @@ class TestOrderedProbit(BaseTestDistribution):
12421246
]
12431247

12441248

1245-
class TestInterpolated(SeededTest):
1249+
class TestInterpolated(BaseTestDistribution):
1250+
def interpolated_rng_fn(self, size, mu, sigma, rng):
1251+
return st.norm.rvs(loc=mu, scale=sigma, size=size)
1252+
1253+
pymc_dist = pm.Interpolated
1254+
1255+
# Dummy values for RV size testing
1256+
mu = sigma = 1
1257+
x_points = pdf_points = np.linspace(1, 100, 100)
1258+
1259+
pymc_dist_params = {"x_points": x_points, "pdf_points": pdf_points}
1260+
reference_dist_params = {"mu": mu, "sigma": sigma}
1261+
1262+
reference_dist = lambda self: functools.partial(
1263+
self.interpolated_rng_fn, rng=self.get_random_state()
1264+
)
1265+
tests_to_run = [
1266+
"check_rv_size",
1267+
]
1268+
1269+
1270+
class TestInterpolatedSeeded(SeededTest):
12461271
@pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
12471272
def test_interpolated(self):
12481273
for mu in R.vals:

0 commit comments

Comments
 (0)