@@ -2273,11 +2273,17 @@ def test_dirichlet_multinomial_matches_beta_binomial(self):
2273
2273
a , b , n = 2 , 1 , 5
2274
2274
ns = np .arange (n + 1 )
2275
2275
ns_dm = np .vstack ((ns , n - ns )).T # convert ns=1 to ns_dm=[1, 4], for all ns...
2276
- bb_logp = logpt (var = pm .BetaBinomial .dist (n = n , alpha = a , beta = b , size = 2 ), rv_values = ns ).eval ()
2277
- dm_logp = logpt (
2278
- var = pm .DirichletMultinomial .dist (n = n , a = [a , b ], size = 2 ),
2279
- rv_values = ns_dm ,
2280
- ).eval ().ravel ()
2276
+
2277
+ bb = pm .BetaBinomial .dist (n = n , alpha = a , beta = b , size = 2 )
2278
+ bb_value = bb .type ()
2279
+ bb .tag .value_var = bb_value
2280
+ bb_logp = logpt (var = bb , rv_values = {bb : bb_value }).eval ({bb_value : ns })
2281
+
2282
+ dm = pm .DirichletMultinomial .dist (n = n , a = [a , b ], size = 2 )
2283
+ dm_value = dm .type ()
2284
+ dm .tag .value_var = dm_value
2285
+ dm_logp = logpt (var = dm , rv_values = {dm : dm_value }).eval ({dm_value : ns_dm }).ravel ()
2286
+
2281
2287
assert_almost_equal (
2282
2288
dm_logp ,
2283
2289
bb_logp ,
@@ -2290,19 +2296,19 @@ def test_dirichlet_multinomial_vec(self):
2290
2296
n = 10
2291
2297
2292
2298
with Model () as model_single :
2293
- DirichletMultinomial ("m" , n = n , a = a )
2299
+ pm . DirichletMultinomial ("m" , n = n , a = a )
2294
2300
2295
2301
with Model () as model_many :
2296
- DirichletMultinomial ("m" , n = n , a = a , size = 2 )
2302
+ pm . DirichletMultinomial ("m" , n = n , a = a , size = 2 )
2297
2303
2298
2304
assert_almost_equal (
2299
- np .asarray ([dirichlet_multinomial_logpmf (v , n , a ) for v in vals ]),
2305
+ np .asarray ([dirichlet_multinomial_logpmf (val , n , a ) for val in vals ]),
2300
2306
np .asarray ([model_single .fastlogp ({"m" : val }) for val in vals ]),
2301
2307
decimal = 4 ,
2302
2308
)
2303
2309
2304
2310
assert_almost_equal (
2305
- np .asarray ([dirichlet_multinomial_logpmf (v , n , a ) for v in vals ]),
2311
+ np .asarray ([dirichlet_multinomial_logpmf (val , n , a ) for val in vals ]),
2306
2312
logpt (model_many .m , vals ).eval ().squeeze (),
2307
2313
decimal = 4 ,
2308
2314
)
@@ -2319,7 +2325,7 @@ def test_dirichlet_multinomial_vec_1d_n(self):
2319
2325
ns = np .array ([10 , 11 ])
2320
2326
2321
2327
with Model () as model :
2322
- DirichletMultinomial ("m" , n = ns , a = a )
2328
+ pm . DirichletMultinomial ("m" , n = ns , a = a )
2323
2329
2324
2330
assert_almost_equal (
2325
2331
sum (dirichlet_multinomial_logpmf (val , n , a ) for val , n in zip (vals , ns )),
@@ -2355,7 +2361,6 @@ def test_dirichlet_multinomial_vec_2d_a(self):
2355
2361
decimal = 4 ,
2356
2362
)
2357
2363
2358
- @pytest .mark .xfail (reason = "Distribution not refactored yet" )
2359
2364
def test_batch_dirichlet_multinomial (self ):
2360
2365
# Test that DM can handle a 3d array for `a`
2361
2366
0 commit comments