Skip to content

Commit 3114ec2

Browse files
committed
Fix call of logpt in test_dirichlet_multinomial_matches_beta_binomial
1 parent aef5af2 commit 3114ec2

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

pymc3/tests/test_distributions.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -2260,11 +2260,17 @@ def test_dirichlet_multinomial_matches_beta_binomial(self):
22602260
a, b, n = 2, 1, 5
22612261
ns = np.arange(n + 1)
22622262
ns_dm = np.vstack((ns, n - ns)).T # convert ns=1 to ns_dm=[1, 4], for all ns...
2263-
bb_logp = logpt(var=pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2), rv_values=ns).eval()
2264-
dm_logp = logpt(
2265-
var=pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2),
2266-
rv_values=ns_dm,
2267-
).eval().ravel()
2263+
2264+
bb = pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2)
2265+
bb_value = bb.type()
2266+
bb.tag.value_var = bb_value
2267+
bb_logp = logpt(var=bb, rv_values={bb: bb_value}).eval({bb_value: ns})
2268+
2269+
dm = pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2)
2270+
dm_value = dm.type()
2271+
dm.tag.value_var = dm_value
2272+
dm_logp = logpt(var=dm, rv_values={dm: dm_value}).eval({dm_value: ns_dm}).ravel()
2273+
22682274
assert_almost_equal(
22692275
dm_logp,
22702276
bb_logp,
@@ -2277,19 +2283,19 @@ def test_dirichlet_multinomial_vec(self):
22772283
n = 10
22782284

22792285
with Model() as model_single:
2280-
DirichletMultinomial("m", n=n, a=a)
2286+
pm.DirichletMultinomial("m", n=n, a=a)
22812287

22822288
with Model() as model_many:
2283-
DirichletMultinomial("m", n=n, a=a, size=2)
2289+
pm.DirichletMultinomial("m", n=n, a=a, size=2)
22842290

22852291
assert_almost_equal(
2286-
np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
2292+
np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
22872293
np.asarray([model_single.fastlogp({"m": val}) for val in vals]),
22882294
decimal=4,
22892295
)
22902296

22912297
assert_almost_equal(
2292-
np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
2298+
np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
22932299
logpt(model_many.m, vals).eval().squeeze(),
22942300
decimal=4,
22952301
)
@@ -2306,7 +2312,7 @@ def test_dirichlet_multinomial_vec_1d_n(self):
23062312
ns = np.array([10, 11])
23072313

23082314
with Model() as model:
2309-
DirichletMultinomial("m", n=ns, a=a)
2315+
pm.DirichletMultinomial("m", n=ns, a=a)
23102316

23112317
assert_almost_equal(
23122318
sum([dirichlet_multinomial_logpmf(val, n, a) for val, n in zip(vals, ns)]),
@@ -2342,7 +2348,6 @@ def test_dirichlet_multinomial_vec_2d_a(self):
23422348
decimal=4,
23432349
)
23442350

2345-
@pytest.mark.xfail(reason="Distribution not refactored yet")
23462351
def test_batch_dirichlet_multinomial(self):
23472352
# Test that DM can handle a 3d array for `a`
23482353

0 commit comments

Comments
 (0)