Skip to content

Commit c34ae3f

Browse files
authored
Check that concentration parameters of Dirichlet distribution are all > 0 (#3853)
* Added check that a>0 in Dirichlet * Cast a as array for tests * Test a>0 only when a not an RV and convert to array when list * Added test for init of Dirichlet with negative values * Added release note * Resolved conflict in release notes * Escaped parenthesis in match regexp
1 parent 0456f39 commit c34ae3f

File tree

3 files changed

+43
-6
lines changed

3 files changed

+43
-6
lines changed

Diff for: RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
- `pm.sample` now takes 1000 draws and 1000 tuning samples by default, instead of 500 previously (see [#3855](https://github.com/pymc-devs/pymc3/pull/3855)).
2121
- Dropped the outdated 'nuts' initialization method for `pm.sample` (see [#3863](https://github.com/pymc-devs/pymc3/pull/3863)).
2222
- Moved argument division out of `NegativeBinomial` `random` method. Fixes [#3864](https://github.com/pymc-devs/pymc3/issues/3864) in the style of [#3509](https://github.com/pymc-devs/pymc3/pull/3509).
23+
- The Dirichlet distribution now raises a ValueError when it's initialized with <= 0 values (see [#3853](https://github.com/pymc-devs/pymc3/pull/3853)).
2324

2425
## PyMC3 3.8 (November 29 2019)
2526

Diff for: pymc3/distributions/multivariate.py

+10
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,16 @@ class Dirichlet(Continuous):
488488

489489
def __init__(self, a, transform=transforms.stick_breaking,
490490
*args, **kwargs):
491+
492+
if not isinstance(a, pm.model.TensorVariable):
493+
if not isinstance(a, list) and not isinstance(a, np.ndarray):
494+
raise TypeError(
495+
'The vector of concentration parameters (a) must be a python list '
496+
'or numpy array.')
497+
a = np.array(a)
498+
if (a <= 0).any():
499+
raise ValueError("All concentration parameters (a) must be > 0.")
500+
491501
shape = np.atleast_1d(a.shape)[-1]
492502

493503
kwargs.setdefault("shape", shape)

Diff for: pymc3/tests/test_distributions.py

+32-6
Original file line numberDiff line numberDiff line change
@@ -944,17 +944,43 @@ def test_lkj(self, x, eta, n, lp):
944944

945945
@pytest.mark.parametrize('n', [2, 3])
946946
def test_dirichlet(self, n):
947-
self.pymc3_matches_scipy(Dirichlet, Simplex(
948-
n), {'a': Vector(Rplus, n)}, dirichlet_logpdf)
947+
self.pymc3_matches_scipy(
948+
Dirichlet,
949+
Simplex(n),
950+
{'a': Vector(Rplus, n)},
951+
dirichlet_logpdf
952+
)
953+
954+
@pytest.mark.parametrize('n', [3, 4])
955+
def test_dirichlet_init_fail(self, n):
956+
with Model():
957+
with pytest.raises(
958+
ValueError,
959+
match=r"All concentration parameters \(a\) must be > 0."
960+
):
961+
_ = Dirichlet('x', a=np.zeros(n), shape=n)
962+
with pytest.raises(
963+
ValueError,
964+
match=r"All concentration parameters \(a\) must be > 0."
965+
):
966+
_ = Dirichlet('x', a=np.array([-1.] * n), shape=n)
949967

950968
def test_dirichlet_2D(self):
951-
self.pymc3_matches_scipy(Dirichlet, MultiSimplex(2, 2),
952-
{'a': Vector(Vector(Rplus, 2), 2)}, dirichlet_logpdf)
969+
self.pymc3_matches_scipy(
970+
Dirichlet,
971+
MultiSimplex(2, 2),
972+
{'a': Vector(Vector(Rplus, 2), 2)},
973+
dirichlet_logpdf
974+
)
953975

954976
@pytest.mark.parametrize('n', [2, 3])
955977
def test_multinomial(self, n):
956-
self.pymc3_matches_scipy(Multinomial, Vector(Nat, n), {'p': Simplex(n), 'n': Nat},
957-
multinomial_logpdf)
978+
self.pymc3_matches_scipy(
979+
Multinomial,
980+
Vector(Nat, n),
981+
{'p': Simplex(n), 'n': Nat},
982+
multinomial_logpdf
983+
)
958984

959985
@pytest.mark.parametrize('p,n', [
960986
[[.25, .25, .25, .25], 1],

0 commit comments

Comments
 (0)