Skip to content

Commit 1333c23

Browse files
committed
Tested exact Dirichlet.logp values againt scipy implementation
Given a mention in RELEASE-NOTES.md
1 parent 4d7c192 commit 1333c23

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
1414
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).
1515
- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)).
16+
- Fixed `Dirichlet.logp` method to work with unit batch or event shapes (see [#4454](https://github.com/pymc-devs/pymc3/pull/4454)).
1617

1718
## PyMC3 3.11.0 (21 January 2021)
1819

pymc3/tests/test_distributions.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1702,8 +1702,10 @@ def test_dirichlet_with_batch_shapes(self, dist_shape):
17021702
with pm.Model() as model:
17031703
d = pm.Dirichlet("a", a=a)
17041704

1705-
value = d.tag.test_value
1706-
assert_almost_equal(dirichlet_logpdf(value, a), d.distribution.logp(value).eval().sum())
1705+
pymc3_res = d.distribution.logp(d.tag.test_value).eval()
1706+
for idx in np.ndindex(a.shape[:-1]):
1707+
scipy_res = scipy.stats.dirichlet(a[idx]).logpdf(d.tag.test_value[idx])
1708+
assert_almost_equal(pymc3_res[idx], scipy_res)
17071709

17081710
def test_dirichlet_shape(self):
17091711
a = tt.as_tensor_variable(np.r_[1, 2])

0 commit comments

Comments
 (0)