Skip to content

Commit 4d7c192

Browse files
committed
Added a test to check Dirichlet.logp with different batch shapes.
1 parent a6a08bb commit 4d7c192

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

pymc3/tests/test_distributions.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1696,10 +1696,14 @@ def test_lkj(self, x, eta, n, lp):
16961696
def test_dirichlet(self, n):
16971697
self.pymc3_matches_scipy(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf)
16981698

1699-
def test_dirichlet_with_unit_batch_shape(self):
1699+
@pytest.mark.parametrize("dist_shape", [1, (2, 1), (1, 2), (2, 4, 3)])
1700+
def test_dirichlet_with_batch_shapes(self, dist_shape):
1701+
a = np.ones(dist_shape)
17001702
with pm.Model() as model:
1701-
a = pm.Dirichlet("a", a=np.ones((1, 2)))
1702-
assert np.isfinite(model.check_test_point()[0])
1703+
d = pm.Dirichlet("a", a=a)
1704+
1705+
value = d.tag.test_value
1706+
assert_almost_equal(dirichlet_logpdf(value, a), d.distribution.logp(value).eval().sum())
17031707

17041708
def test_dirichlet_shape(self):
17051709
a = tt.as_tensor_variable(np.r_[1, 2])

0 commit comments

Comments
 (0)