@@ -143,6 +143,11 @@ def test_to_idata(self, data, eight_schools_params, chains, draws):
143
143
np .isclose (ivalues [chain ], values [chain * draws : (chain + 1 ) * draws ])
144
144
)
145
145
146
+ chains = inference_data .posterior .dims ["chain" ]
147
+ draws = inference_data .posterior .dims ["draw" ]
148
+ obs = inference_data .observed_data ["obs" ]
149
+ assert inference_data .log_likelihood ["obs" ].shape == (chains , draws ) + obs .shape
150
+
146
151
def test_predictions_to_idata (self , data , eight_schools_params ):
147
152
"Test that we can add predictions to a previously-existing InferenceData."
148
153
test_dict = {
@@ -329,6 +334,11 @@ def test_missing_data_model(self):
329
334
fails = check_multiple_attrs (test_dict , inference_data )
330
335
assert not fails
331
336
337
+ # The missing part of partial observed RVs is not included in log_likelihood
338
+ # See https://github.com/pymc-devs/pymc/issues/5255
339
+ assert inference_data .log_likelihood ["y_observed" ].shape == (2 , 100 , 3 )
340
+
341
+ @pytest .mark .xfal (reason = "Multivariate partial observed RVs not implemented for V4" )
332
342
@pytest .mark .xfail (reason = "LKJCholeskyCov not refactored for v4" )
333
343
def test_mv_missing_data_model (self ):
334
344
data = ma .masked_values ([[1 , 2 ], [2 , 2 ], [- 1 , 4 ], [2 , - 1 ], [- 1 , - 1 ]], value = - 1 )
@@ -375,8 +385,12 @@ def test_multiple_observed_rv(self, log_likelihood):
375
385
if not log_likelihood :
376
386
test_dict .pop ("log_likelihood" )
377
387
test_dict ["~log_likelihood" ] = []
378
- if isinstance (log_likelihood , list ):
388
+ elif isinstance (log_likelihood , list ):
379
389
test_dict ["log_likelihood" ] = ["y1" , "~y2" ]
390
+ assert inference_data .log_likelihood ["y1" ].shape == (2 , 100 , 10 )
391
+ else :
392
+ assert inference_data .log_likelihood ["y1" ].shape == (2 , 100 , 10 )
393
+ assert inference_data .log_likelihood ["y2" ].shape == (2 , 100 , 100 )
380
394
381
395
fails = check_multiple_attrs (test_dict , inference_data )
382
396
assert not fails
@@ -445,12 +459,12 @@ def test_single_observation(self):
445
459
inference_data = pm .sample (500 , chains = 2 , return_inferencedata = True )
446
460
447
461
assert inference_data
462
+ assert inference_data .log_likelihood ["w" ].shape == (2 , 500 , 1 )
448
463
449
- @pytest .mark .xfail (reason = "Potential not refactored for v4" )
450
464
def test_potential (self ):
451
465
with pm .Model ():
452
466
x = pm .Normal ("x" , 0.0 , 1.0 )
453
- pm .Potential ("z" , logpt (pm .Normal .dist (x , 1.0 ), np .random .randn (10 )))
467
+ pm .Potential ("z" , pm . logp (pm .Normal .dist (x , 1.0 ), np .random .randn (10 )))
454
468
inference_data = pm .sample (100 , chains = 2 , return_inferencedata = True )
455
469
456
470
assert inference_data
@@ -463,7 +477,7 @@ def test_constant_data(self, use_context):
463
477
y = pm .Data ("y" , [1.0 , 2.0 , 3.0 ])
464
478
beta = pm .Normal ("beta" , 0 , 1 )
465
479
obs = pm .Normal ("obs" , x * beta , 1 , observed = y ) # pylint: disable=unused-variable
466
- trace = pm .sample (100 , tune = 100 , return_inferencedata = False )
480
+ trace = pm .sample (100 , chains = 2 , tune = 100 , return_inferencedata = False )
467
481
if use_context :
468
482
inference_data = to_inference_data (trace = trace )
469
483
@@ -472,6 +486,7 @@ def test_constant_data(self, use_context):
472
486
test_dict = {"posterior" : ["beta" ], "observed_data" : ["obs" ], "constant_data" : ["x" ]}
473
487
fails = check_multiple_attrs (test_dict , inference_data )
474
488
assert not fails
489
+ assert inference_data .log_likelihood ["obs" ].shape == (2 , 100 , 3 )
475
490
476
491
def test_predictions_constant_data (self ):
477
492
with pm .Model ():
@@ -570,7 +585,7 @@ def test_multivariate_observations(self):
570
585
with pm .Model (coords = coords ):
571
586
p = pm .Beta ("p" , 1 , 1 , size = (3 ,))
572
587
pm .Multinomial ("y" , 20 , p , dims = ("experiment" , "direction" ), observed = data )
573
- idata = pm .sample (draws = 50 , tune = 100 , return_inferencedata = True )
588
+ idata = pm .sample (draws = 50 , chains = 2 , tune = 100 , return_inferencedata = True )
574
589
test_dict = {
575
590
"posterior" : ["p" ],
576
591
"sample_stats" : ["lp" ],
@@ -581,6 +596,7 @@ def test_multivariate_observations(self):
581
596
assert not fails
582
597
assert "direction" not in idata .log_likelihood .dims
583
598
assert "direction" in idata .observed_data .dims
599
+ assert idata .log_likelihood ["y" ].shape == (2 , 50 , 20 )
584
600
585
601
def test_constant_data_coords_issue_5046 (self ):
586
602
"""This is a regression test against a bug where a local coords variable was overwritten."""
0 commit comments