@@ -2271,10 +2271,10 @@ def test_dirichlet_multinomial(self, n):
2271
2271
def test_dirichlet_multinomial_matches_beta_binomial (self ):
2272
2272
a , b , n = 2 , 1 , 5
2273
2273
ns = np .arange (n + 1 )
2274
- ns_dm = np .vstack ((ns , n - ns )).T # covert ns=1 to ns_dm=[1, 4], for all ns...
2275
- bb_logp = logpt (pm .BetaBinomial .dist (n = n , alpha = a , beta = b ), ns ).tag .test_value
2274
+ ns_dm = np .vstack ((ns , n - ns )).T # convert ns=1 to ns_dm=[1, 4], for all ns...
2275
+ bb_logp = logpt (pm .BetaBinomial .dist (n = n , alpha = a , beta = b , size = 2 ), ns ).tag .test_value
2276
2276
dm_logp = logpt (
2277
- pm .DirichletMultinomial .dist (n = n , a = [a , b ], size = ( 1 , 2 ) ), ns_dm
2277
+ pm .DirichletMultinomial .dist (n = n , a = [a , b ], size = 2 ), ns_dm
2278
2278
).tag .test_value
2279
2279
dm_logp = dm_logp .ravel ()
2280
2280
assert_almost_equal (
@@ -2289,10 +2289,10 @@ def test_dirichlet_multinomial_vec(self):
2289
2289
n = 10
2290
2290
2291
2291
with Model () as model_single :
2292
- DirichletMultinomial ("m" , n = n , a = a , size = len ( a ) )
2292
+ DirichletMultinomial ("m" , n = n , a = a )
2293
2293
2294
2294
with Model () as model_many :
2295
- DirichletMultinomial ("m" , n = n , a = a , size = vals . shape )
2295
+ DirichletMultinomial ("m" , n = n , a = a , size = 2 )
2296
2296
2297
2297
assert_almost_equal (
2298
2298
np .asarray ([dirichlet_multinomial_logpmf (v , n , a ) for v in vals ]),
@@ -2302,7 +2302,7 @@ def test_dirichlet_multinomial_vec(self):
2302
2302
2303
2303
assert_almost_equal (
2304
2304
np .asarray ([dirichlet_multinomial_logpmf (v , n , a ) for v in vals ]),
2305
- model_many .free_RVs [ 0 ]. logp_elemwise ({ "m" : vals } ).squeeze (),
2305
+ logpt ( model_many .m , vals ). eval ( ).squeeze (),
2306
2306
decimal = 4 ,
2307
2307
)
2308
2308
@@ -2318,7 +2318,7 @@ def test_dirichlet_multinomial_vec_1d_n(self):
2318
2318
ns = np .array ([10 , 11 ])
2319
2319
2320
2320
with Model () as model :
2321
- DirichletMultinomial ("m" , n = ns , a = a , size = vals . shape )
2321
+ DirichletMultinomial ("m" , n = ns , a = a )
2322
2322
2323
2323
assert_almost_equal (
2324
2324
sum (dirichlet_multinomial_logpmf (val , n , a ) for val , n in zip (vals , ns )),
@@ -2332,7 +2332,7 @@ def test_dirichlet_multinomial_vec_1d_n_2d_a(self):
2332
2332
ns = np .array ([10 , 11 ])
2333
2333
2334
2334
with Model () as model :
2335
- DirichletMultinomial ("m" , n = ns , a = as_ , size = vals . shape )
2335
+ DirichletMultinomial ("m" , n = ns , a = as_ )
2336
2336
2337
2337
assert_almost_equal (
2338
2338
sum (dirichlet_multinomial_logpmf (val , n , a ) for val , n , a in zip (vals , ns , as_ )),
@@ -2346,7 +2346,7 @@ def test_dirichlet_multinomial_vec_2d_a(self):
2346
2346
n = 10
2347
2347
2348
2348
with Model () as model :
2349
- DirichletMultinomial ("m" , n = n , a = as_ , size = vals . shape )
2349
+ DirichletMultinomial ("m" , n = n , a = as_ )
2350
2350
2351
2351
assert_almost_equal (
2352
2352
sum (dirichlet_multinomial_logpmf (val , n , a ) for val , a in zip (vals , as_ )),
@@ -2358,7 +2358,7 @@ def test_dirichlet_multinomial_vec_2d_a(self):
2358
2358
def test_batch_dirichlet_multinomial (self ):
2359
2359
# Test that DM can handle a 3d array for `a`
2360
2360
2361
- # Create an almost deterministic DM by setting a to 0.001, everywehere
2361
+ # Create an almost deterministic DM by setting a to 0.001, everywhere
2362
2362
# except for one category / dimension which is given the value of 1000
2363
2363
n = 5
2364
2364
vals = np .zeros ((4 , 5 , 3 ), dtype = "int32" )
@@ -2367,19 +2367,20 @@ def test_batch_dirichlet_multinomial(self):
2367
2367
np .put_along_axis (vals , inds , n , axis = - 1 )
2368
2368
np .put_along_axis (a , inds , 1000 , axis = - 1 )
2369
2369
2370
- dist = DirichletMultinomial .dist (n = n , a = a , size = vals . shape )
2370
+ dist = DirichletMultinomial .dist (n = n , a = a )
2371
2371
2372
2372
# Logp should be approx -9.924431e-06
2373
2373
dist_logp = logpt (dist , vals ).tag .test_value
2374
- expected_logp = np .full (shape = vals .shape [:- 1 ] + ( 1 ,) , fill_value = - 9.924431e-06 )
2374
+ expected_logp = np .full (shape = vals .shape [:- 1 ], fill_value = - 9.924431e-06 )
2375
2375
assert_almost_equal (
2376
2376
dist_logp ,
2377
2377
expected_logp ,
2378
2378
decimal = select_by_precision (float64 = 6 , float32 = 3 ),
2379
2379
)
2380
2380
2381
2381
# Samples should be equal given the almost deterministic DM
2382
- sample = dist .random (size = 2 )
2382
+ dist = DirichletMultinomial .dist (n = n , a = a , size = 2 )
2383
+ sample = dist .eval ()
2383
2384
assert_allclose (sample , np .stack ([vals , vals ], axis = 0 ))
2384
2385
2385
2386
@aesara .config .change_flags (compute_test_value = "raise" )
0 commit comments