Skip to content

Commit 791a1c4

Browse files
refactor pareto and laplace (#4691)
* refactor pareto
* refactor laplace
* Reintroduce `Pareto` default transform

Co-authored-by: Farhan Reynaldo <[email protected]>
Co-authored-by: Ricardo <[email protected]>
1 parent faed5f1 commit 791a1c4

File tree

3 files changed

+47
-76
lines changed

3 files changed

+47
-76
lines changed

pymc3/distributions/continuous.py

+21-62
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
halfcauchy,
3535
halfnormal,
3636
invgamma,
37+
laplace,
3738
logistic,
3839
lognormal,
3940
normal,
@@ -152,10 +153,16 @@ def default_transform(cls):
152153

153154
def transform_params(rv_var):
154155
_, _, _, *args = rv_var.owner.inputs
155-
lower = args[cls.bound_args_indices[0]]
156-
upper = args[cls.bound_args_indices[1]]
156+
157+
lower, upper = None, None
158+
if cls.bound_args_indices[0] is not None:
159+
lower = args[cls.bound_args_indices[0]]
160+
if cls.bound_args_indices[1] is not None:
161+
upper = args[cls.bound_args_indices[1]]
162+
157163
lower = at.as_tensor_variable(lower) if lower is not None else None
158164
upper = at.as_tensor_variable(upper) if upper is not None else None
165+
159166
return lower, upper
160167

161168
return transforms.interval(transform_params)
@@ -1505,37 +1512,17 @@ class Laplace(Continuous):
15051512
b: float
15061513
Scale parameter (b > 0).
15071514
"""
1515+
rv_op = laplace
15081516

1509-
def __init__(self, mu, b, *args, **kwargs):
1510-
super().__init__(*args, **kwargs)
1511-
self.b = b = at.as_tensor_variable(floatX(b))
1512-
self.mean = self.median = self.mode = self.mu = mu = at.as_tensor_variable(floatX(mu))
1513-
1514-
self.variance = 2 * self.b ** 2
1517+
@classmethod
1518+
def dist(cls, mu, b, *args, **kwargs):
1519+
b = at.as_tensor_variable(floatX(b))
1520+
mu = at.as_tensor_variable(floatX(mu))
15151521

15161522
assert_negative_support(b, "b", "Laplace")
1523+
return super().dist([mu, b], *args, **kwargs)
15171524

1518-
def random(self, point=None, size=None):
1519-
"""
1520-
Draw random values from Laplace distribution.
1521-
1522-
Parameters
1523-
----------
1524-
point: dict, optional
1525-
Dict of variable values on which random values are to be
1526-
conditioned (uses default point if not specified).
1527-
size: int, optional
1528-
Desired size of random sample (returns one sample if not
1529-
specified).
1530-
1531-
Returns
1532-
-------
1533-
array
1534-
"""
1535-
# mu, b = draw_values([self.mu, self.b], point=point, size=size)
1536-
# return generate_samples(np.random.laplace, mu, b, dist_shape=self.shape, size=size)
1537-
1538-
def logp(self, value):
1525+
def logp(value, mu, b):
15391526
"""
15401527
Calculate log-probability of Laplace distribution at specified value.
15411528
@@ -1549,12 +1536,9 @@ def logp(self, value):
15491536
-------
15501537
TensorVariable
15511538
"""
1552-
mu = self.mu
1553-
b = self.b
1554-
15551539
return -at.log(2 * b) - abs(value - mu) / b
15561540

1557-
def logcdf(self, value):
1541+
def logcdf(value, mu, b):
15581542
"""
15591543
Compute the log of the cumulative distribution function for Laplace distribution
15601544
at the specified value.
@@ -1569,12 +1553,10 @@ def logcdf(self, value):
15691553
-------
15701554
TensorVariable
15711555
"""
1572-
a = self.mu
1573-
b = self.b
1574-
y = (value - a) / b
1556+
y = (value - mu) / b
15751557
return bound(
15761558
at.switch(
1577-
at.le(value, a),
1559+
at.le(value, mu),
15781560
at.log(0.5) + y,
15791561
at.switch(
15801562
at.gt(y, 1),
@@ -1980,7 +1962,7 @@ def logcdf(self, value):
19801962
)
19811963

19821964

1983-
class Pareto(Continuous):
1965+
class Pareto(BoundedContinuous):
19841966
r"""
19851967
Pareto log-likelihood.
19861968
@@ -2026,6 +2008,7 @@ class Pareto(Continuous):
20262008
Scale parameter (m > 0).
20272009
"""
20282010
rv_op = pareto
2011+
bound_args_indices = (1, None) # lower-bounded by `m`
20292012

20302013
@classmethod
20312014
def dist(
@@ -2039,30 +2022,6 @@ def dist(
20392022

20402023
return super().dist([alpha, m], **kwargs)
20412024

2042-
def _random(self, alpha, m, size=None):
2043-
u = np.random.uniform(size=size)
2044-
return m * (1.0 - u) ** (-1.0 / alpha)
2045-
2046-
def random(self, point=None, size=None):
2047-
"""
2048-
Draw random values from Pareto distribution.
2049-
2050-
Parameters
2051-
----------
2052-
point: dict, optional
2053-
Dict of variable values on which random values are to be
2054-
conditioned (uses default point if not specified).
2055-
size: int, optional
2056-
Desired size of random sample (returns one sample if not
2057-
specified).
2058-
2059-
Returns
2060-
-------
2061-
array
2062-
"""
2063-
# alpha, m = draw_values([self.alpha, self.m], point=point, size=size)
2064-
# return generate_samples(self._random, alpha, m, dist_shape=self.shape, size=size)
2065-
20662025
def logp(
20672026
value: Union[float, np.ndarray, TensorVariable],
20682027
alpha: Union[float, np.ndarray, TensorVariable],

pymc3/tests/test_distributions.py

-1
Original file line numberDiff line numberDiff line change
@@ -1250,7 +1250,6 @@ def test_negative_binomial_init_fail(self, mu, p, alpha, n, expected):
12501250
with pytest.raises(ValueError, match=f"Incompatible parametrization. {expected}"):
12511251
NegativeBinomial("x", mu=mu, p=p, alpha=alpha, n=n)
12521252

1253-
@pytest.mark.xfail(reason="Distribution not refactored yet")
12541253
def test_laplace(self):
12551254
self.check_logp(
12561255
Laplace,

pymc3/tests/test_distributions_random.py

+26-13
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,6 @@ class TestKumaraswamy(BaseTestCases.BaseTestCase):
277277
params = {"a": 1.0, "b": 1.0}
278278

279279

280-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
281-
class TestLaplace(BaseTestCases.BaseTestCase):
282-
distribution = pm.Laplace
283-
params = {"mu": 1.0, "b": 1.0}
284-
285-
286280
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
287281
class TestAsymmetricLaplace(BaseTestCases.BaseTestCase):
288282
distribution = pm.AsymmetricLaplace
@@ -449,6 +443,32 @@ def seeded_discrete_weibul_rng_fn(self):
449443
]
450444

451445

446+
class TestPareto(BaseTestDistribution):
447+
pymc_dist = pm.Pareto
448+
pymc_dist_params = {"alpha": 3.0, "m": 2.0}
449+
expected_rv_op_params = {"alpha": 3.0, "m": 2.0}
450+
reference_dist_params = {"b": 3.0, "scale": 2.0}
451+
reference_dist = seeded_scipy_distribution_builder("pareto")
452+
tests_to_run = [
453+
"check_pymc_params_match_rv_op",
454+
"check_pymc_draws_match_reference",
455+
"check_rv_size",
456+
]
457+
458+
459+
class TestLaplace(BaseTestDistribution):
460+
pymc_dist = pm.Laplace
461+
pymc_dist_params = {"mu": 0.0, "b": 1.0}
462+
expected_rv_op_params = {"mu": 0.0, "b": 1.0}
463+
reference_dist_params = {"loc": 0.0, "scale": 1.0}
464+
reference_dist = seeded_scipy_distribution_builder("laplace")
465+
tests_to_run = [
466+
"check_pymc_params_match_rv_op",
467+
"check_pymc_draws_match_reference",
468+
"check_rv_size",
469+
]
470+
471+
452472
class TestGumbel(BaseTestDistribution):
453473
pymc_dist = pm.Gumbel
454474
pymc_dist_params = {"mu": 1.5, "beta": 3.0}
@@ -1102,13 +1122,6 @@ def ref_rand(size, mu, lam, alpha):
11021122
ref_rand=ref_rand,
11031123
)
11041124

1105-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
1106-
def test_laplace(self):
1107-
def ref_rand(size, mu, b):
1108-
return st.laplace.rvs(mu, b, size=size)
1109-
1110-
pymc3_random(pm.Laplace, {"mu": R, "b": Rplus}, ref_rand=ref_rand)
1111-
11121125
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
11131126
def test_laplace_asymmetric(self):
11141127
def ref_rand(size, kappa, b, mu):

0 commit comments

Comments (0)