Skip to content

Commit 713b391

Browse files
AlexAndorrabrandonwillard
authored andcommitted
Fix call of logpt in test_dirichlet_multinomial_matches_beta_binomial
1 parent d25c5c4 commit 713b391

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

pymc3/tests/test_distributions.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -2273,11 +2273,17 @@ def test_dirichlet_multinomial_matches_beta_binomial(self):
22732273
a, b, n = 2, 1, 5
22742274
ns = np.arange(n + 1)
22752275
ns_dm = np.vstack((ns, n - ns)).T # convert ns=1 to ns_dm=[1, 4], for all ns...
2276-
bb_logp = logpt(var=pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2), rv_values=ns).eval()
2277-
dm_logp = logpt(
2278-
var=pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2),
2279-
rv_values=ns_dm,
2280-
).eval().ravel()
2276+
2277+
bb = pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2)
2278+
bb_value = bb.type()
2279+
bb.tag.value_var = bb_value
2280+
bb_logp = logpt(var=bb, rv_values={bb: bb_value}).eval({bb_value: ns})
2281+
2282+
dm = pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2)
2283+
dm_value = dm.type()
2284+
dm.tag.value_var = dm_value
2285+
dm_logp = logpt(var=dm, rv_values={dm: dm_value}).eval({dm_value: ns_dm}).ravel()
2286+
22812287
assert_almost_equal(
22822288
dm_logp,
22832289
bb_logp,
@@ -2290,19 +2296,19 @@ def test_dirichlet_multinomial_vec(self):
22902296
n = 10
22912297

22922298
with Model() as model_single:
2293-
DirichletMultinomial("m", n=n, a=a)
2299+
pm.DirichletMultinomial("m", n=n, a=a)
22942300

22952301
with Model() as model_many:
2296-
DirichletMultinomial("m", n=n, a=a, size=2)
2302+
pm.DirichletMultinomial("m", n=n, a=a, size=2)
22972303

22982304
assert_almost_equal(
2299-
np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
2305+
np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
23002306
np.asarray([model_single.fastlogp({"m": val}) for val in vals]),
23012307
decimal=4,
23022308
)
23032309

23042310
assert_almost_equal(
2305-
np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
2311+
np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
23062312
logpt(model_many.m, vals).eval().squeeze(),
23072313
decimal=4,
23082314
)
@@ -2319,7 +2325,7 @@ def test_dirichlet_multinomial_vec_1d_n(self):
23192325
ns = np.array([10, 11])
23202326

23212327
with Model() as model:
2322-
DirichletMultinomial("m", n=ns, a=a)
2328+
pm.DirichletMultinomial("m", n=ns, a=a)
23232329

23242330
assert_almost_equal(
23252331
sum(dirichlet_multinomial_logpmf(val, n, a) for val, n in zip(vals, ns)),
@@ -2355,7 +2361,6 @@ def test_dirichlet_multinomial_vec_2d_a(self):
23552361
decimal=4,
23562362
)
23572363

2358-
@pytest.mark.xfail(reason="Distribution not refactored yet")
23592364
def test_batch_dirichlet_multinomial(self):
23602365
# Test that DM can handle a 3d array for `a`
23612366

0 commit comments

Comments
 (0)