Skip to content

Commit aef5af2

Browse files
committed
Fix test_dirichlet_multinomial_matches_beta_binomial
1 parent 66104af commit aef5af2

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pymc3/tests/test_distributions.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -2246,6 +2246,7 @@ def test_batch_multinomial(self):
22462246
sample = dist.eval()
22472247
assert_allclose(sample, np.stack([vals, vals], axis=0))
22482248

2249+
# https://github.com/pymc-devs/pymc3/pull/4508/files
22492250
@pytest.mark.parametrize("n", [2, 3])
22502251
def test_dirichlet_multinomial(self, n):
22512252
self.check_logp(
@@ -2259,11 +2260,11 @@ def test_dirichlet_multinomial_matches_beta_binomial(self):
22592260
a, b, n = 2, 1, 5
22602261
ns = np.arange(n + 1)
22612262
ns_dm = np.vstack((ns, n - ns)).T # convert ns=1 to ns_dm=[1, 4], for all ns...
2262-
bb_logp = logpt(pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2), ns).tag.test_value
2263+
bb_logp = logpt(var=pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2), rv_values=ns).eval()
22632264
dm_logp = logpt(
2264-
pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2), ns_dm
2265-
).tag.test_value
2266-
dm_logp = dm_logp.ravel()
2265+
var=pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2),
2266+
rv_values=ns_dm,
2267+
).eval().ravel()
22672268
assert_almost_equal(
22682269
dm_logp,
22692270
bb_logp,

0 commit comments

Comments
 (0)