Skip to content

Commit 7e34e86

Browse files
committed
Refactored tests for new distributions
1 parent bfc384b commit 7e34e86

File tree

1 file changed

+80
-50
lines changed

1 file changed

+80
-50
lines changed

pymc3/tests/test_distributions_random.py

+80-50
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import pymc3 as pm
3333

3434
from pymc3.aesaraf import change_rv_size, floatX, intX
35-
from pymc3.distributions.continuous import get_tau_sigma
35+
from pymc3.distributions.continuous import get_tau_sigma, interpolated
3636
from pymc3.distributions.dist_math import clipped_beta_rvs
3737
from pymc3.distributions.multivariate import quaddist_matrix
3838
from pymc3.distributions.shape_utils import to_tuple
@@ -270,17 +270,6 @@ class TestWald(BaseTestCases.BaseTestCase):
270270
params = {"mu": 1.0, "lam": 1.0, "alpha": 0.0}
271271

272272

273-
class TestAsymmetricLaplace(BaseTestCases.BaseTestCase):
274-
distribution = pm.AsymmetricLaplace
275-
params = {"kappa": 1.0, "b": 1.0, "mu": 0.0}
276-
277-
278-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
279-
class TestExGaussian(BaseTestCases.BaseTestCase):
280-
distribution = pm.ExGaussian
281-
params = {"mu": 0.0, "sigma": 1.0, "nu": 1.0}
282-
283-
284273
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
285274
class TestZeroInflatedNegativeBinomial(BaseTestCases.BaseTestCase):
286275
distribution = pm.ZeroInflatedNegativeBinomial
@@ -464,6 +453,64 @@ class TestLaplace(BaseTestDistribution):
464453
]
465454

466455

456+
class TestAsymmetricLaplace(BaseTestDistribution):
457+
def asymmetriclaplace_rng_fn(self, b, kappa, mu, size, uniform_rng_fct):
458+
u = uniform_rng_fct(size=size)
459+
switch = kappa ** 2 / (1 + kappa ** 2)
460+
non_positive_x = mu + kappa * np.log(u * (1 / switch)) / b
461+
positive_x = mu - np.log((1 - u) * (1 + kappa ** 2)) / (kappa * b)
462+
draws = non_positive_x * (u <= switch) + positive_x * (u > switch)
463+
return draws
464+
465+
def seeded_asymmetriclaplace_rng_fn(self):
466+
uniform_rng_fct = functools.partial(
467+
getattr(np.random.RandomState, "uniform"), self.get_random_state()
468+
)
469+
return functools.partial(self.asymmetriclaplace_rng_fn, uniform_rng_fct=uniform_rng_fct)
470+
471+
pymc_dist = pm.AsymmetricLaplace
472+
473+
pymc_dist_params = {"b": 1.0, "kappa": 1.0, "mu": 0.0}
474+
expected_rv_op_params = {"b": 1.0, "kappa": 1.0, "mu": 0.0}
475+
reference_dist_params = {"b": 1.0, "kappa": 1.0, "mu": 0.0}
476+
reference_dist = seeded_asymmetriclaplace_rng_fn
477+
tests_to_run = [
478+
"check_pymc_params_match_rv_op",
479+
"check_pymc_draws_match_reference",
480+
"check_rv_size",
481+
]
482+
483+
484+
class TestExGaussian(BaseTestDistribution):
485+
def exgaussian_rng_fn(self, mu, sigma, nu, size, normal_rng_fct, exponential_rng_fct):
486+
return normal_rng_fct(mu, sigma, size=size) + exponential_rng_fct(scale=nu, size=size)
487+
488+
def seeded_exgaussian_rng_fn(self):
489+
normal_rng_fct = functools.partial(
490+
getattr(np.random.RandomState, "normal"), self.get_random_state()
491+
)
492+
exponential_rng_fct = functools.partial(
493+
getattr(np.random.RandomState, "exponential"), self.get_random_state()
494+
)
495+
return functools.partial(
496+
self.exgaussian_rng_fn,
497+
normal_rng_fct=normal_rng_fct,
498+
exponential_rng_fct=exponential_rng_fct,
499+
)
500+
501+
pymc_dist = pm.ExGaussian
502+
503+
pymc_dist_params = {"mu": 1.0, "sigma": 1.0, "nu": 1.0}
504+
expected_rv_op_params = {"mu": 1.0, "sigma": 1.0, "nu": 1.0}
505+
reference_dist_params = {"mu": 1.0, "sigma": 1.0, "nu": 1.0}
506+
reference_dist = seeded_exgaussian_rng_fn
507+
tests_to_run = [
508+
"check_pymc_params_match_rv_op",
509+
"check_pymc_draws_match_reference",
510+
"check_rv_size",
511+
]
512+
513+
467514
class TestGumbel(BaseTestDistribution):
468515
pymc_dist = pm.Gumbel
469516
pymc_dist_params = {"mu": 1.5, "beta": 3.0}
@@ -1195,6 +1242,27 @@ class TestOrderedProbit(BaseTestDistribution):
11951242
]
11961243

11971244

1245+
class TestInterpolated(SeededTest):
1246+
@pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
1247+
def test_interpolated(self):
1248+
for mu in R.vals:
1249+
for sigma in Rplus.vals:
1250+
# pylint: disable=cell-var-from-loop
1251+
def ref_rand(size):
1252+
return st.norm.rvs(loc=mu, scale=sigma, size=size)
1253+
1254+
class TestedInterpolated(pm.Interpolated):
1255+
rv_op = interpolated
1256+
1257+
@classmethod
1258+
def dist(cls, **kwargs):
1259+
x_points = np.linspace(mu - 5 * sigma, mu + 5 * sigma, 100)
1260+
pdf_points = st.norm.pdf(x_points, loc=mu, scale=sigma)
1261+
return super().dist(x_points=x_points, pdf_points=pdf_points, **kwargs)
1262+
1263+
pymc3_random(TestedInterpolated, {}, ref_rand=ref_rand)
1264+
1265+
11981266
class TestScalarParameterSamples(SeededTest):
11991267
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
12001268
def test_bounded(self):
@@ -1256,23 +1324,6 @@ def ref_rand(size, mu, lam, alpha):
12561324
ref_rand=ref_rand,
12571325
)
12581326

1259-
def test_laplace_asymmetric(self):
1260-
def ref_rand(size, kappa, b, mu):
1261-
u = np.random.uniform(size=size)
1262-
switch = kappa ** 2 / (1 + kappa ** 2)
1263-
non_positive_x = mu + kappa * np.log(u * (1 / switch)) / b
1264-
positive_x = mu - np.log((1 - u) * (1 + kappa ** 2)) / (kappa * b)
1265-
draws = non_positive_x * (u <= switch) + positive_x * (u > switch)
1266-
return draws
1267-
1268-
pymc3_random(pm.AsymmetricLaplace, {"b": Rplus, "kappa": Rplus, "mu": R}, ref_rand=ref_rand)
1269-
1270-
def test_ex_gaussian(self):
1271-
def ref_rand(size, mu, sigma, nu):
1272-
return nr.normal(mu, sigma, size=size) + nr.exponential(scale=nu, size=size)
1273-
1274-
pymc3_random(pm.ExGaussian, {"mu": R, "sigma": Rplus, "nu": Rplus}, ref_rand=ref_rand)
1275-
12761327
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
12771328
def test_matrix_normal(self):
12781329
def ref_rand(size, mu, rowcov, colcov):
@@ -1494,27 +1545,6 @@ def ref_rand(size, mu, sigma):
14941545

14951546
pymc3_random(pm.Moyal, {"mu": R, "sigma": Rplus}, ref_rand=ref_rand)
14961547

1497-
@pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
1498-
def test_interpolated(self):
1499-
for mu in R.vals:
1500-
for sigma in Rplus.vals:
1501-
# pylint: disable=cell-var-from-loop
1502-
def ref_rand(size):
1503-
return st.norm.rvs(loc=mu, scale=sigma, size=size)
1504-
1505-
from pymc3.distributions.continuous import interpolated
1506-
1507-
class TestedInterpolated(pm.Interpolated):
1508-
rv_op = interpolated
1509-
1510-
@classmethod
1511-
def dist(cls, **kwargs):
1512-
x_points = np.linspace(mu - 5 * sigma, mu + 5 * sigma, 100)
1513-
pdf_points = st.norm.pdf(x_points, loc=mu, scale=sigma)
1514-
return super().dist(x_points=x_points, pdf_points=pdf_points, **kwargs)
1515-
1516-
pymc3_random(TestedInterpolated, {}, ref_rand=ref_rand)
1517-
15181548
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
15191549
@pytest.mark.skip(
15201550
"Wishart random sampling not implemented.\n"

0 commit comments

Comments
 (0)