@@ -2273,7 +2273,6 @@ def test_batch_multinomial(self):
         assert_allclose(sample, np.stack([vals, vals], axis=0))
 
     @pytest.mark.parametrize("n", [2, 3])
-    @pytest.mark.xfail(reason="Distribution not refactored yet")
    def test_dirichlet_multinomial(self, n):
        self.check_logp(
            DirichletMultinomial,
@@ -2282,43 +2281,47 @@ def test_dirichlet_multinomial(self, n):
            dirichlet_multinomial_logpmf,
        )
 
-    @pytest.mark.xfail(reason="Distribution not refactored yet")
    def test_dirichlet_multinomial_matches_beta_binomial(self):
        a, b, n = 2, 1, 5
        ns = np.arange(n + 1)
-        ns_dm = np.vstack((ns, n - ns)).T  # covert ns=1 to ns_dm=[1, 4], for all ns...
-        bb_logp = logpt(pm.BetaBinomial.dist(n=n, alpha=a, beta=b), ns).tag.test_value
-        dm_logp = logpt(
-            pm.DirichletMultinomial.dist(n=n, a=[a, b], size=(1, 2)), ns_dm
-        ).tag.test_value
-        dm_logp = dm_logp.ravel()
+        ns_dm = np.vstack((ns, n - ns)).T  # convert ns=1 to ns_dm=[1, 4], for all ns...
+
+        bb = pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2)
+        bb_value = bb.type()
+        bb.tag.value_var = bb_value
+        bb_logp = logpt(var=bb, rv_values={bb: bb_value}).eval({bb_value: ns})
+
+        dm = pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2)
+        dm_value = dm.type()
+        dm.tag.value_var = dm_value
+        dm_logp = logpt(var=dm, rv_values={dm: dm_value}).eval({dm_value: ns_dm}).ravel()
+
 
        assert_almost_equal(
            dm_logp,
            bb_logp,
            decimal=select_by_precision(float64=6, float32=3),
        )
 
-    @pytest.mark.xfail(reason="Distribution not refactored yet")
    def test_dirichlet_multinomial_vec(self):
        vals = np.array([[2, 4, 4], [3, 3, 4]])
        a = np.array([0.2, 0.3, 0.5])
        n = 10
 
        with Model() as model_single:
-            DirichletMultinomial("m", n=n, a=a, size=len(a))
+            DirichletMultinomial("m", n=n, a=a)
 
        with Model() as model_many:
-            DirichletMultinomial("m", n=n, a=a, size=vals.shape)
+            DirichletMultinomial("m", n=n, a=a, size=2)
 
        assert_almost_equal(
-            np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
+            np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
            np.asarray([model_single.fastlogp({"m": val}) for val in vals]),
            decimal=4,
        )
 
        assert_almost_equal(
-            np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
-            model_many.free_RVs[0].logp_elemwise({"m": vals}).squeeze(),
+            np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
+            logpt(model_many.m, vals).eval().squeeze(),
            decimal=4,
        )
@@ -2328,56 +2331,52 @@ def test_dirichlet_multinomial_vec(self):
            decimal=4,
        )
 
-    @pytest.mark.xfail(reason="Distribution not refactored yet")
    def test_dirichlet_multinomial_vec_1d_n(self):
        vals = np.array([[2, 4, 4], [4, 3, 4]])
        a = np.array([0.2, 0.3, 0.5])
        ns = np.array([10, 11])
 
        with Model() as model:
-            DirichletMultinomial("m", n=ns, a=a, size=vals.shape)
+            DirichletMultinomial("m", n=ns, a=a)
 
        assert_almost_equal(
            sum(dirichlet_multinomial_logpmf(val, n, a) for val, n in zip(vals, ns)),
            model.fastlogp({"m": vals}),
            decimal=4,
        )
 
-    @pytest.mark.xfail(reason="Distribution not refactored yet")
    def test_dirichlet_multinomial_vec_1d_n_2d_a(self):
        vals = np.array([[2, 4, 4], [4, 3, 4]])
        as_ = np.array([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]])
        ns = np.array([10, 11])
 
        with Model() as model:
-            DirichletMultinomial("m", n=ns, a=as_, size=vals.shape)
+            DirichletMultinomial("m", n=ns, a=as_)
 
        assert_almost_equal(
            sum(dirichlet_multinomial_logpmf(val, n, a) for val, n, a in zip(vals, ns, as_)),
            model.fastlogp({"m": vals}),
            decimal=4,
        )
 
-    @pytest.mark.xfail(reason="Distribution not refactored yet")
    def test_dirichlet_multinomial_vec_2d_a(self):
        vals = np.array([[2, 4, 4], [3, 3, 4]])
        as_ = np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]])
        n = 10
 
        with Model() as model:
-            DirichletMultinomial("m", n=n, a=as_, size=vals.shape)
+            DirichletMultinomial("m", n=n, a=as_)
 
        assert_almost_equal(
            sum(dirichlet_multinomial_logpmf(val, n, a) for val, a in zip(vals, as_)),
            model.fastlogp({"m": vals}),
            decimal=4,
        )
 
-    @pytest.mark.xfail(reason="Distribution not refactored yet")
    def test_batch_dirichlet_multinomial(self):
        # Test that DM can handle a 3d array for `a`
 
-        # Create an almost deterministic DM by setting a to 0.001, everywehere
+        # Create an almost deterministic DM by setting a to 0.001, everywhere
        # except for one category / dimension which is given the value of 1000
        n = 5
        vals = np.zeros((4, 5, 3), dtype="int32")
@@ -2386,19 +2385,23 @@ def test_batch_dirichlet_multinomial(self):
        np.put_along_axis(vals, inds, n, axis=-1)
        np.put_along_axis(a, inds, 1000, axis=-1)
 
-        dist = DirichletMultinomial.dist(n=n, a=a, size=vals.shape)
+        dist = DirichletMultinomial.dist(n=n, a=a)
 
-        # Logp should be approx -9.924431e-06
-        dist_logp = logpt(dist, vals).tag.test_value
-        expected_logp = np.full(shape=vals.shape[:-1] + (1,), fill_value=-9.924431e-06)
+        # Logp should be approx -9.98004998e-06
+        value = at.tensor3(dtype="int32")
+        value.tag.test_value = np.zeros_like(vals, dtype="int32")
+        logp = logpt(dist, value)
+        f = aesara.function(inputs=[value], outputs=logp)
+        expected_logp = np.full(shape=f(vals).shape, fill_value=-9.98004998e-06)
 
        assert_almost_equal(
-            dist_logp,
+            f(vals),
            expected_logp,
            decimal=select_by_precision(float64=6, float32=3),
        )
 
        # Samples should be equal given the almost deterministic DM
-        sample = dist.random(size=2)
+        dist = DirichletMultinomial.dist(n=n, a=a, size=2)
+        sample = dist.eval()
        assert_allclose(sample, np.stack([vals, vals], axis=0))
 
    @aesara.config.change_flags(compute_test_value="raise")