Skip to content

Commit 0c21de4

Browse files
authored
Fix Dirichlet.logp (#4454)
* Fix Dirichlet.logp by checking number of categories > 1 only at event dims * Update test_distributions.py * Removed the shape validation check to even work for last dimensional shape as 1. Modified the `test_dirichlet` function to check for the same. * Added a test to check Dirichlet.logp with different batch shapes. * Tested exact Dirichlet.logp values againt scipy implementation Given a mention in RELEASE-NOTES.md
1 parent b6660f9 commit 0c21de4

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
1414
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).
1515
- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)).
16+
- Fixed `Dirichlet.logp` method to work with unit batch or event shapes (see [#4454](https://github.com/pymc-devs/pymc3/pull/4454)).
1617

1718
## PyMC3 3.11.0 (21 January 2021)
1819

pymc3/distributions/multivariate.py

-1
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,6 @@ def logp(self, value):
522522
tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(tt.sum(a, axis=-1)),
523523
tt.all(value >= 0),
524524
tt.all(value <= 1),
525-
np.logical_not(a.broadcastable),
526525
tt.all(a > 0),
527526
broadcast_conditions=False,
528527
)

pymc3/tests/test_distributions.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1692,10 +1692,21 @@ def test_lkj(self, x, eta, n, lp):
16921692
decimals = select_by_precision(float64=6, float32=4)
16931693
assert_almost_equal(model.fastlogp(pt), lp, decimal=decimals, err_msg=str(pt))
16941694

1695-
@pytest.mark.parametrize("n", [2, 3])
1695+
@pytest.mark.parametrize("n", [1, 2, 3])
16961696
def test_dirichlet(self, n):
16971697
self.pymc3_matches_scipy(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf)
16981698

1699+
@pytest.mark.parametrize("dist_shape", [1, (2, 1), (1, 2), (2, 4, 3)])
1700+
def test_dirichlet_with_batch_shapes(self, dist_shape):
1701+
a = np.ones(dist_shape)
1702+
with pm.Model() as model:
1703+
d = pm.Dirichlet("a", a=a)
1704+
1705+
pymc3_res = d.distribution.logp(d.tag.test_value).eval()
1706+
for idx in np.ndindex(a.shape[:-1]):
1707+
scipy_res = scipy.stats.dirichlet(a[idx]).logpdf(d.tag.test_value[idx])
1708+
assert_almost_equal(pymc3_res[idx], scipy_res)
1709+
16991710
def test_dirichlet_shape(self):
17001711
a = tt.as_tensor_variable(np.r_[1, 2])
17011712
with pytest.warns(DeprecationWarning):

0 commit comments

Comments
 (0)