@@ -2260,11 +2260,17 @@ def test_dirichlet_multinomial_matches_beta_binomial(self):
2260
2260
a , b , n = 2 , 1 , 5
2261
2261
ns = np .arange (n + 1 )
2262
2262
ns_dm = np .vstack ((ns , n - ns )).T # convert ns=1 to ns_dm=[1, 4], for all ns...
2263
- bb_logp = logpt (var = pm .BetaBinomial .dist (n = n , alpha = a , beta = b , size = 2 ), rv_values = ns ).eval ()
2264
- dm_logp = logpt (
2265
- var = pm .DirichletMultinomial .dist (n = n , a = [a , b ], size = 2 ),
2266
- rv_values = ns_dm ,
2267
- ).eval ().ravel ()
2263
+
2264
+ bb = pm .BetaBinomial .dist (n = n , alpha = a , beta = b , size = 2 )
2265
+ bb_value = bb .type ()
2266
+ bb .tag .value_var = bb_value
2267
+ bb_logp = logpt (var = bb , rv_values = {bb : bb_value }).eval ({bb_value : ns })
2268
+
2269
+ dm = pm .DirichletMultinomial .dist (n = n , a = [a , b ], size = 2 )
2270
+ dm_value = dm .type ()
2271
+ dm .tag .value_var = dm_value
2272
+ dm_logp = logpt (var = dm , rv_values = {dm : dm_value }).eval ({dm_value : ns_dm }).ravel ()
2273
+
2268
2274
assert_almost_equal (
2269
2275
dm_logp ,
2270
2276
bb_logp ,
@@ -2277,19 +2283,19 @@ def test_dirichlet_multinomial_vec(self):
2277
2283
n = 10
2278
2284
2279
2285
with Model () as model_single :
2280
- DirichletMultinomial ("m" , n = n , a = a )
2286
+ pm . DirichletMultinomial ("m" , n = n , a = a )
2281
2287
2282
2288
with Model () as model_many :
2283
- DirichletMultinomial ("m" , n = n , a = a , size = 2 )
2289
+ pm . DirichletMultinomial ("m" , n = n , a = a , size = 2 )
2284
2290
2285
2291
assert_almost_equal (
2286
- np .asarray ([dirichlet_multinomial_logpmf (v , n , a ) for v in vals ]),
2292
+ np .asarray ([dirichlet_multinomial_logpmf (val , n , a ) for val in vals ]),
2287
2293
np .asarray ([model_single .fastlogp ({"m" : val }) for val in vals ]),
2288
2294
decimal = 4 ,
2289
2295
)
2290
2296
2291
2297
assert_almost_equal (
2292
- np .asarray ([dirichlet_multinomial_logpmf (v , n , a ) for v in vals ]),
2298
+ np .asarray ([dirichlet_multinomial_logpmf (val , n , a ) for val in vals ]),
2293
2299
logpt (model_many .m , vals ).eval ().squeeze (),
2294
2300
decimal = 4 ,
2295
2301
)
@@ -2306,7 +2312,7 @@ def test_dirichlet_multinomial_vec_1d_n(self):
2306
2312
ns = np .array ([10 , 11 ])
2307
2313
2308
2314
with Model () as model :
2309
- DirichletMultinomial ("m" , n = ns , a = a )
2315
+ pm . DirichletMultinomial ("m" , n = ns , a = a )
2310
2316
2311
2317
assert_almost_equal (
2312
2318
sum ([dirichlet_multinomial_logpmf (val , n , a ) for val , n in zip (vals , ns )]),
@@ -2342,7 +2348,6 @@ def test_dirichlet_multinomial_vec_2d_a(self):
2342
2348
decimal = 4 ,
2343
2349
)
2344
2350
2345
- @pytest .mark .xfail (reason = "Distribution not refactored yet" )
2346
2351
def test_batch_dirichlet_multinomial (self ):
2347
2352
# Test that DM can handle a 3d array for `a`
2348
2353
0 commit comments