Skip to content

Commit d25c5c4

Browse files
AlexAndorrabrandonwillard
authored andcommitted
Fix test_dirichlet_multinomial_matches_beta_binomial
1 parent f0e4fdb commit d25c5c4

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pymc3/tests/test_distributions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,6 +2259,7 @@ def test_batch_multinomial(self):
22592259
sample = dist.eval()
22602260
assert_allclose(sample, np.stack([vals, vals], axis=0))
22612261

2262+
# https://github.com/pymc-devs/pymc3/pull/4508/files
22622263
@pytest.mark.parametrize("n", [2, 3])
22632264
def test_dirichlet_multinomial(self, n):
22642265
self.check_logp(
@@ -2272,11 +2273,11 @@ def test_dirichlet_multinomial_matches_beta_binomial(self):
22722273
a, b, n = 2, 1, 5
22732274
ns = np.arange(n + 1)
22742275
ns_dm = np.vstack((ns, n - ns)).T # convert ns=1 to ns_dm=[1, 4], for all ns...
2275-
bb_logp = logpt(pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2), ns).tag.test_value
2276+
bb_logp = logpt(var=pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2), rv_values=ns).eval()
22762277
dm_logp = logpt(
2277-
pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2), ns_dm
2278-
).tag.test_value
2279-
dm_logp = dm_logp.ravel()
2278+
var=pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2),
2279+
rv_values=ns_dm,
2280+
).eval().ravel()
22802281
assert_almost_equal(
22812282
dm_logp,
22822283
bb_logp,

0 commit comments

Comments
 (0)