|
32 | 32 | import pymc3 as pm
|
33 | 33 |
|
34 | 34 | from pymc3.aesaraf import change_rv_size, floatX, intX
|
35 |
| -from pymc3.distributions.continuous import get_tau_sigma |
| 35 | +from pymc3.distributions.continuous import get_tau_sigma, interpolated |
36 | 36 | from pymc3.distributions.dist_math import clipped_beta_rvs
|
37 | 37 | from pymc3.distributions.multivariate import quaddist_matrix
|
38 | 38 | from pymc3.distributions.shape_utils import to_tuple
|
@@ -270,17 +270,6 @@ class TestWald(BaseTestCases.BaseTestCase):
|
270 | 270 | params = {"mu": 1.0, "lam": 1.0, "alpha": 0.0}
|
271 | 271 |
|
272 | 272 |
|
273 |
| -class TestAsymmetricLaplace(BaseTestCases.BaseTestCase): |
274 |
| - distribution = pm.AsymmetricLaplace |
275 |
| - params = {"kappa": 1.0, "b": 1.0, "mu": 0.0} |
276 |
| - |
277 |
| - |
278 |
| -@pytest.mark.xfail(reason="This distribution has not been refactored for v4") |
279 |
| -class TestExGaussian(BaseTestCases.BaseTestCase): |
280 |
| - distribution = pm.ExGaussian |
281 |
| - params = {"mu": 0.0, "sigma": 1.0, "nu": 1.0} |
282 |
| - |
283 |
| - |
284 | 273 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
285 | 274 | class TestZeroInflatedNegativeBinomial(BaseTestCases.BaseTestCase):
|
286 | 275 | distribution = pm.ZeroInflatedNegativeBinomial
|
@@ -464,6 +453,64 @@ class TestLaplace(BaseTestDistribution):
|
464 | 453 | ]
|
465 | 454 |
|
466 | 455 |
|
| 456 | +class TestAsymmetricLaplace(BaseTestDistribution): |
| 457 | + def asymmetriclaplace_rng_fn(self, b, kappa, mu, size, uniform_rng_fct): |
| 458 | + u = uniform_rng_fct(size=size) |
| 459 | + switch = kappa ** 2 / (1 + kappa ** 2) |
| 460 | + non_positive_x = mu + kappa * np.log(u * (1 / switch)) / b |
| 461 | + positive_x = mu - np.log((1 - u) * (1 + kappa ** 2)) / (kappa * b) |
| 462 | + draws = non_positive_x * (u <= switch) + positive_x * (u > switch) |
| 463 | + return draws |
| 464 | + |
| 465 | + def seeded_asymmetriclaplace_rng_fn(self): |
| 466 | + uniform_rng_fct = functools.partial( |
| 467 | + getattr(np.random.RandomState, "uniform"), self.get_random_state() |
| 468 | + ) |
| 469 | + return functools.partial(self.asymmetriclaplace_rng_fn, uniform_rng_fct=uniform_rng_fct) |
| 470 | + |
| 471 | + pymc_dist = pm.AsymmetricLaplace |
| 472 | + |
| 473 | + pymc_dist_params = {"b": 1.0, "kappa": 1.0, "mu": 0.0} |
| 474 | + expected_rv_op_params = {"b": 1.0, "kappa": 1.0, "mu": 0.0} |
| 475 | + reference_dist_params = {"b": 1.0, "kappa": 1.0, "mu": 0.0} |
| 476 | + reference_dist = seeded_asymmetriclaplace_rng_fn |
| 477 | + tests_to_run = [ |
| 478 | + "check_pymc_params_match_rv_op", |
| 479 | + "check_pymc_draws_match_reference", |
| 480 | + "check_rv_size", |
| 481 | + ] |
| 482 | + |
| 483 | + |
| 484 | +class TestExGaussian(BaseTestDistribution): |
| 485 | + def exgaussian_rng_fn(self, mu, sigma, nu, size, normal_rng_fct, exponential_rng_fct): |
| 486 | + return normal_rng_fct(mu, sigma, size=size) + exponential_rng_fct(scale=nu, size=size) |
| 487 | + |
| 488 | + def seeded_exgaussian_rng_fn(self): |
| 489 | + normal_rng_fct = functools.partial( |
| 490 | + getattr(np.random.RandomState, "normal"), self.get_random_state() |
| 491 | + ) |
| 492 | + exponential_rng_fct = functools.partial( |
| 493 | + getattr(np.random.RandomState, "exponential"), self.get_random_state() |
| 494 | + ) |
| 495 | + return functools.partial( |
| 496 | + self.exgaussian_rng_fn, |
| 497 | + normal_rng_fct=normal_rng_fct, |
| 498 | + exponential_rng_fct=exponential_rng_fct, |
| 499 | + ) |
| 500 | + |
| 501 | + pymc_dist = pm.ExGaussian |
| 502 | + |
| 503 | + pymc_dist_params = {"mu": 1.0, "sigma": 1.0, "nu": 1.0} |
| 504 | + expected_rv_op_params = {"mu": 1.0, "sigma": 1.0, "nu": 1.0} |
| 505 | + reference_dist_params = {"mu": 1.0, "sigma": 1.0, "nu": 1.0} |
| 506 | + reference_dist = seeded_exgaussian_rng_fn |
| 507 | + tests_to_run = [ |
| 508 | + "check_pymc_params_match_rv_op", |
| 509 | + "check_pymc_draws_match_reference", |
| 510 | + "check_rv_size", |
| 511 | + ] |
| 512 | + |
| 513 | + |
467 | 514 | class TestGumbel(BaseTestDistribution):
|
468 | 515 | pymc_dist = pm.Gumbel
|
469 | 516 | pymc_dist_params = {"mu": 1.5, "beta": 3.0}
|
@@ -1195,6 +1242,27 @@ class TestOrderedProbit(BaseTestDistribution):
|
1195 | 1242 | ]
|
1196 | 1243 |
|
1197 | 1244 |
|
| 1245 | +class TestInterpolated(SeededTest): |
| 1246 | + @pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32") |
| 1247 | + def test_interpolated(self): |
| 1248 | + for mu in R.vals: |
| 1249 | + for sigma in Rplus.vals: |
| 1250 | + # pylint: disable=cell-var-from-loop |
| 1251 | + def ref_rand(size): |
| 1252 | + return st.norm.rvs(loc=mu, scale=sigma, size=size) |
| 1253 | + |
| 1254 | + class TestedInterpolated(pm.Interpolated): |
| 1255 | + rv_op = interpolated |
| 1256 | + |
| 1257 | + @classmethod |
| 1258 | + def dist(cls, **kwargs): |
| 1259 | + x_points = np.linspace(mu - 5 * sigma, mu + 5 * sigma, 100) |
| 1260 | + pdf_points = st.norm.pdf(x_points, loc=mu, scale=sigma) |
| 1261 | + return super().dist(x_points=x_points, pdf_points=pdf_points, **kwargs) |
| 1262 | + |
| 1263 | + pymc3_random(TestedInterpolated, {}, ref_rand=ref_rand) |
| 1264 | + |
| 1265 | + |
1198 | 1266 | class TestScalarParameterSamples(SeededTest):
|
1199 | 1267 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
1200 | 1268 | def test_bounded(self):
|
@@ -1256,23 +1324,6 @@ def ref_rand(size, mu, lam, alpha):
|
1256 | 1324 | ref_rand=ref_rand,
|
1257 | 1325 | )
|
1258 | 1326 |
|
1259 |
| - def test_laplace_asymmetric(self): |
1260 |
| - def ref_rand(size, kappa, b, mu): |
1261 |
| - u = np.random.uniform(size=size) |
1262 |
| - switch = kappa ** 2 / (1 + kappa ** 2) |
1263 |
| - non_positive_x = mu + kappa * np.log(u * (1 / switch)) / b |
1264 |
| - positive_x = mu - np.log((1 - u) * (1 + kappa ** 2)) / (kappa * b) |
1265 |
| - draws = non_positive_x * (u <= switch) + positive_x * (u > switch) |
1266 |
| - return draws |
1267 |
| - |
1268 |
| - pymc3_random(pm.AsymmetricLaplace, {"b": Rplus, "kappa": Rplus, "mu": R}, ref_rand=ref_rand) |
1269 |
| - |
1270 |
| - def test_ex_gaussian(self): |
1271 |
| - def ref_rand(size, mu, sigma, nu): |
1272 |
| - return nr.normal(mu, sigma, size=size) + nr.exponential(scale=nu, size=size) |
1273 |
| - |
1274 |
| - pymc3_random(pm.ExGaussian, {"mu": R, "sigma": Rplus, "nu": Rplus}, ref_rand=ref_rand) |
1275 |
| - |
1276 | 1327 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
1277 | 1328 | def test_matrix_normal(self):
|
1278 | 1329 | def ref_rand(size, mu, rowcov, colcov):
|
@@ -1494,27 +1545,6 @@ def ref_rand(size, mu, sigma):
|
1494 | 1545 |
|
1495 | 1546 | pymc3_random(pm.Moyal, {"mu": R, "sigma": Rplus}, ref_rand=ref_rand)
|
1496 | 1547 |
|
1497 |
| - @pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32") |
1498 |
| - def test_interpolated(self): |
1499 |
| - for mu in R.vals: |
1500 |
| - for sigma in Rplus.vals: |
1501 |
| - # pylint: disable=cell-var-from-loop |
1502 |
| - def ref_rand(size): |
1503 |
| - return st.norm.rvs(loc=mu, scale=sigma, size=size) |
1504 |
| - |
1505 |
| - from pymc3.distributions.continuous import interpolated |
1506 |
| - |
1507 |
| - class TestedInterpolated(pm.Interpolated): |
1508 |
| - rv_op = interpolated |
1509 |
| - |
1510 |
| - @classmethod |
1511 |
| - def dist(cls, **kwargs): |
1512 |
| - x_points = np.linspace(mu - 5 * sigma, mu + 5 * sigma, 100) |
1513 |
| - pdf_points = st.norm.pdf(x_points, loc=mu, scale=sigma) |
1514 |
| - return super().dist(x_points=x_points, pdf_points=pdf_points, **kwargs) |
1515 |
| - |
1516 |
| - pymc3_random(TestedInterpolated, {}, ref_rand=ref_rand) |
1517 |
| - |
1518 | 1548 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
1519 | 1549 | @pytest.mark.skip(
|
1520 | 1550 | "Wishart random sampling not implemented.\n"
|
|
0 commit comments