Skip to content

Commit 1d37d31

Browse files
authored
Fix Triangular bounds (#4470)
* Small fix Triangular logp and logcdf methods * Add tests for invalid parameters Uniform, Triangular, DiscreteUniform
1 parent e46f490 commit 1d37d31

File tree

3 files changed

+38
-19
lines changed

3 files changed

+38
-19
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).
1616
- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)).
1717
- Fixed `Dirichlet.logp` method to work with unit batch or event shapes (see [#4454](https://github.com/pymc-devs/pymc3/pull/4454)).
18+
- Bugfix in logp and logcdf methods of `Triangular` distribution (see[#4470](https://github.com/pymc-devs/pymc3/pull/4470)).
1819

1920
## PyMC3 3.11.0 (21 January 2021)
2021

pymc3/distributions/continuous.py

+21-19
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from pymc3.distributions import transforms
3030
from pymc3.distributions.dist_math import (
3131
SplineWrapper,
32-
alltrue_elemwise,
3332
betaln,
3433
bound,
3534
clipped_beta_rvs,
@@ -3649,18 +3648,14 @@ def logp(self, value):
36493648
c = self.c
36503649
lower = self.lower
36513650
upper = self.upper
3652-
return tt.switch(
3653-
alltrue_elemwise([lower <= value, value < c]),
3654-
tt.log(2 * (value - lower) / ((upper - lower) * (c - lower))),
3651+
return bound(
36553652
tt.switch(
3656-
tt.eq(value, c),
3657-
tt.log(2 / (upper - lower)),
3658-
tt.switch(
3659-
alltrue_elemwise([c < value, value <= upper]),
3660-
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),
3661-
np.inf,
3662-
),
3653+
tt.lt(value, c),
3654+
tt.log(2 * (value - lower) / ((upper - lower) * (c - lower))),
3655+
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),
36633656
),
3657+
lower <= value,
3658+
value <= upper,
36643659
)
36653660

36663661
def logcdf(self, value):
@@ -3678,17 +3673,24 @@ def logcdf(self, value):
36783673
-------
36793674
TensorVariable
36803675
"""
3681-
l = self.lower
3682-
u = self.upper
36833676
c = self.c
3684-
return tt.switch(
3685-
tt.le(value, l),
3686-
-np.inf,
3677+
lower = self.lower
3678+
upper = self.upper
3679+
return bound(
36873680
tt.switch(
3688-
tt.le(value, c),
3689-
tt.log(((value - l) ** 2) / ((u - l) * (c - l))),
3690-
tt.switch(tt.lt(value, u), tt.log1p(-((u - value) ** 2) / ((u - l) * (u - c))), 0),
3681+
tt.le(value, lower),
3682+
-np.inf,
3683+
tt.switch(
3684+
tt.le(value, c),
3685+
tt.log(((value - lower) ** 2) / ((upper - lower) * (c - lower))),
3686+
tt.switch(
3687+
tt.lt(value, upper),
3688+
tt.log1p(-((upper - value) ** 2) / ((upper - lower) * (upper - c))),
3689+
0,
3690+
),
3691+
),
36913692
),
3693+
lower <= upper,
36923694
)
36933695

36943696

pymc3/tests/test_distributions.py

+16
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,10 @@ def test_uniform(self):
802802
lambda value, lower, upper: sp.uniform.logcdf(value, lower, upper - lower),
803803
skip_paramdomain_outside_edge_test=True,
804804
)
805+
# Custom logp / logcdf check for invalid parameters
806+
invalid_dist = Uniform.dist(lower=1, upper=0)
807+
assert invalid_dist.logp(0.5).tag.test_value == -np.inf
808+
assert invalid_dist.logcdf(2).tag.test_value == -np.inf
805809

806810
def test_triangular(self):
807811
self.check_logp(
@@ -817,6 +821,14 @@ def test_triangular(self):
817821
lambda value, c, lower, upper: sp.triang.logcdf(value, c - lower, lower, upper - lower),
818822
skip_paramdomain_outside_edge_test=True,
819823
)
824+
# Custom logp check for invalid value
825+
valid_dist = Triangular.dist(lower=0, upper=1, c=2.0)
826+
assert np.all(valid_dist.logp(np.array([1.9, 2.0, 2.1])).tag.test_value == -np.inf)
827+
828+
# Custom logp / logcdf check for invalid parameters
829+
invalid_dist = Triangular.dist(lower=1, upper=0, c=2.0)
830+
assert invalid_dist.logp(0.5).tag.test_value == -np.inf
831+
assert invalid_dist.logcdf(2).tag.test_value == -np.inf
820832

821833
def test_bound_normal(self):
822834
PositiveNormal = Bound(Normal, lower=0.0)
@@ -850,6 +862,10 @@ def test_discrete_unif(self):
850862
Rdunif,
851863
{"lower": -Rplusdunif, "upper": Rplusdunif},
852864
)
865+
# Custom logp / logcdf check for invalid parameters
866+
invalid_dist = DiscreteUniform.dist(lower=1, upper=0)
867+
assert invalid_dist.logp(0.5).tag.test_value == -np.inf
868+
assert invalid_dist.logcdf(2).tag.test_value == -np.inf
853869

854870
def test_flat(self):
855871
self.check_logp(Flat, Runif, {}, lambda value: 0)

0 commit comments

Comments
 (0)