Skip to content

Commit f0e4fdb

Browse files
AlexAndorrabrandonwillard
authored andcommitted
Start adapting tests to new v4 implementation
1 parent 8a8503a commit f0e4fdb

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

pymc3/tests/test_distributions.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2271,10 +2271,10 @@ def test_dirichlet_multinomial(self, n):
22712271
def test_dirichlet_multinomial_matches_beta_binomial(self):
22722272
a, b, n = 2, 1, 5
22732273
ns = np.arange(n + 1)
2274-
ns_dm = np.vstack((ns, n - ns)).T # covert ns=1 to ns_dm=[1, 4], for all ns...
2275-
bb_logp = logpt(pm.BetaBinomial.dist(n=n, alpha=a, beta=b), ns).tag.test_value
2274+
ns_dm = np.vstack((ns, n - ns)).T # convert ns=1 to ns_dm=[1, 4], for all ns...
2275+
bb_logp = logpt(pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2), ns).tag.test_value
22762276
dm_logp = logpt(
2277-
pm.DirichletMultinomial.dist(n=n, a=[a, b], size=(1, 2)), ns_dm
2277+
pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2), ns_dm
22782278
).tag.test_value
22792279
dm_logp = dm_logp.ravel()
22802280
assert_almost_equal(
@@ -2289,10 +2289,10 @@ def test_dirichlet_multinomial_vec(self):
22892289
n = 10
22902290

22912291
with Model() as model_single:
2292-
DirichletMultinomial("m", n=n, a=a, size=len(a))
2292+
DirichletMultinomial("m", n=n, a=a)
22932293

22942294
with Model() as model_many:
2295-
DirichletMultinomial("m", n=n, a=a, size=vals.shape)
2295+
DirichletMultinomial("m", n=n, a=a, size=2)
22962296

22972297
assert_almost_equal(
22982298
np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
@@ -2302,7 +2302,7 @@ def test_dirichlet_multinomial_vec(self):
23022302

23032303
assert_almost_equal(
23042304
np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
2305-
model_many.free_RVs[0].logp_elemwise({"m": vals}).squeeze(),
2305+
logpt(model_many.m, vals).eval().squeeze(),
23062306
decimal=4,
23072307
)
23082308

@@ -2318,7 +2318,7 @@ def test_dirichlet_multinomial_vec_1d_n(self):
23182318
ns = np.array([10, 11])
23192319

23202320
with Model() as model:
2321-
DirichletMultinomial("m", n=ns, a=a, size=vals.shape)
2321+
DirichletMultinomial("m", n=ns, a=a)
23222322

23232323
assert_almost_equal(
23242324
sum(dirichlet_multinomial_logpmf(val, n, a) for val, n in zip(vals, ns)),
@@ -2332,7 +2332,7 @@ def test_dirichlet_multinomial_vec_1d_n_2d_a(self):
23322332
ns = np.array([10, 11])
23332333

23342334
with Model() as model:
2335-
DirichletMultinomial("m", n=ns, a=as_, size=vals.shape)
2335+
DirichletMultinomial("m", n=ns, a=as_)
23362336

23372337
assert_almost_equal(
23382338
sum(dirichlet_multinomial_logpmf(val, n, a) for val, n, a in zip(vals, ns, as_)),
@@ -2346,7 +2346,7 @@ def test_dirichlet_multinomial_vec_2d_a(self):
23462346
n = 10
23472347

23482348
with Model() as model:
2349-
DirichletMultinomial("m", n=n, a=as_, size=vals.shape)
2349+
DirichletMultinomial("m", n=n, a=as_)
23502350

23512351
assert_almost_equal(
23522352
sum(dirichlet_multinomial_logpmf(val, n, a) for val, a in zip(vals, as_)),
@@ -2358,7 +2358,7 @@ def test_dirichlet_multinomial_vec_2d_a(self):
23582358
def test_batch_dirichlet_multinomial(self):
23592359
# Test that DM can handle a 3d array for `a`
23602360

2361-
# Create an almost deterministic DM by setting a to 0.001, everywehere
2361+
# Create an almost deterministic DM by setting a to 0.001, everywhere
23622362
# except for one category / dimension which is given the value of 1000
23632363
n = 5
23642364
vals = np.zeros((4, 5, 3), dtype="int32")
@@ -2367,19 +2367,20 @@ def test_batch_dirichlet_multinomial(self):
23672367
np.put_along_axis(vals, inds, n, axis=-1)
23682368
np.put_along_axis(a, inds, 1000, axis=-1)
23692369

2370-
dist = DirichletMultinomial.dist(n=n, a=a, size=vals.shape)
2370+
dist = DirichletMultinomial.dist(n=n, a=a)
23712371

23722372
# Logp should be approx -9.924431e-06
23732373
dist_logp = logpt(dist, vals).tag.test_value
2374-
expected_logp = np.full(shape=vals.shape[:-1] + (1,), fill_value=-9.924431e-06)
2374+
expected_logp = np.full(shape=vals.shape[:-1], fill_value=-9.924431e-06)
23752375
assert_almost_equal(
23762376
dist_logp,
23772377
expected_logp,
23782378
decimal=select_by_precision(float64=6, float32=3),
23792379
)
23802380

23812381
# Samples should be equal given the almost deterministic DM
2382-
sample = dist.random(size=2)
2382+
dist = DirichletMultinomial.dist(n=n, a=a, size=2)
2383+
sample = dist.eval()
23832384
assert_allclose(sample, np.stack([vals, vals], axis=0))
23842385

23852386
@aesara.config.change_flags(compute_test_value="raise")

0 commit comments

Comments
 (0)