@@ -143,6 +143,11 @@ def test_to_idata(self, data, eight_schools_params, chains, draws):
143143 np .isclose (ivalues [chain ], values [chain * draws : (chain + 1 ) * draws ])
144144 )
145145
146+ chains = inference_data .posterior .dims ["chain" ]
147+ draws = inference_data .posterior .dims ["draw" ]
148+ obs = inference_data .observed_data ["obs" ]
149+ assert inference_data .log_likelihood ["obs" ].shape == (chains , draws ) + obs .shape
150+
146151 def test_predictions_to_idata (self , data , eight_schools_params ):
147152 "Test that we can add predictions to a previously-existing InferenceData."
148153 test_dict = {
@@ -329,6 +334,11 @@ def test_missing_data_model(self):
329334 fails = check_multiple_attrs (test_dict , inference_data )
330335 assert not fails
331336
337+ # The missing part of partial observed RVs is not included in log_likelihood
338+ # See https://github.com/pymc-devs/pymc/issues/5255
339+ assert inference_data .log_likelihood ["y_observed" ].shape == (2 , 100 , 3 )
340+
341+ @pytest .mark .xfal (reason = "Multivariate partial observed RVs not implemented for V4" )
332342 @pytest .mark .xfail (reason = "LKJCholeskyCov not refactored for v4" )
333343 def test_mv_missing_data_model (self ):
334344 data = ma .masked_values ([[1 , 2 ], [2 , 2 ], [- 1 , 4 ], [2 , - 1 ], [- 1 , - 1 ]], value = - 1 )
@@ -375,8 +385,12 @@ def test_multiple_observed_rv(self, log_likelihood):
375385 if not log_likelihood :
376386 test_dict .pop ("log_likelihood" )
377387 test_dict ["~log_likelihood" ] = []
378- if isinstance (log_likelihood , list ):
388+ elif isinstance (log_likelihood , list ):
379389 test_dict ["log_likelihood" ] = ["y1" , "~y2" ]
390+ assert inference_data .log_likelihood ["y1" ].shape == (2 , 100 , 10 )
391+ else :
392+ assert inference_data .log_likelihood ["y1" ].shape == (2 , 100 , 10 )
393+ assert inference_data .log_likelihood ["y2" ].shape == (2 , 100 , 100 )
380394
381395 fails = check_multiple_attrs (test_dict , inference_data )
382396 assert not fails
@@ -445,12 +459,12 @@ def test_single_observation(self):
445459 inference_data = pm .sample (500 , chains = 2 , return_inferencedata = True )
446460
447461 assert inference_data
462+ assert inference_data .log_likelihood ["w" ].shape == (2 , 500 , 1 )
448463
449- @pytest .mark .xfail (reason = "Potential not refactored for v4" )
450464 def test_potential (self ):
451465 with pm .Model ():
452466 x = pm .Normal ("x" , 0.0 , 1.0 )
453- pm .Potential ("z" , logpt (pm .Normal .dist (x , 1.0 ), np .random .randn (10 )))
467+ pm .Potential ("z" , pm . logp (pm .Normal .dist (x , 1.0 ), np .random .randn (10 )))
454468 inference_data = pm .sample (100 , chains = 2 , return_inferencedata = True )
455469
456470 assert inference_data
@@ -463,7 +477,7 @@ def test_constant_data(self, use_context):
463477 y = pm .Data ("y" , [1.0 , 2.0 , 3.0 ])
464478 beta = pm .Normal ("beta" , 0 , 1 )
465479 obs = pm .Normal ("obs" , x * beta , 1 , observed = y ) # pylint: disable=unused-variable
466- trace = pm .sample (100 , tune = 100 , return_inferencedata = False )
480+ trace = pm .sample (100 , chains = 2 , tune = 100 , return_inferencedata = False )
467481 if use_context :
468482 inference_data = to_inference_data (trace = trace )
469483
@@ -472,6 +486,7 @@ def test_constant_data(self, use_context):
472486 test_dict = {"posterior" : ["beta" ], "observed_data" : ["obs" ], "constant_data" : ["x" ]}
473487 fails = check_multiple_attrs (test_dict , inference_data )
474488 assert not fails
489+ assert inference_data .log_likelihood ["obs" ].shape == (2 , 100 , 3 )
475490
476491 def test_predictions_constant_data (self ):
477492 with pm .Model ():
@@ -570,7 +585,7 @@ def test_multivariate_observations(self):
570585 with pm .Model (coords = coords ):
571586 p = pm .Beta ("p" , 1 , 1 , size = (3 ,))
572587 pm .Multinomial ("y" , 20 , p , dims = ("experiment" , "direction" ), observed = data )
573- idata = pm .sample (draws = 50 , tune = 100 , return_inferencedata = True )
588+ idata = pm .sample (draws = 50 , chains = 2 , tune = 100 , return_inferencedata = True )
574589 test_dict = {
575590 "posterior" : ["p" ],
576591 "sample_stats" : ["lp" ],
@@ -581,6 +596,7 @@ def test_multivariate_observations(self):
581596 assert not fails
582597 assert "direction" not in idata .log_likelihood .dims
583598 assert "direction" in idata .observed_data .dims
599+ assert idata .log_likelihood ["y" ].shape == (2 , 50 , 20 )
584600
585601 def test_constant_data_coords_issue_5046 (self ):
586602 """This is a regression test against a bug where a local coords variable was overwritten."""
0 commit comments