Skip to content

Commit a50b386

Browse files
ricardoV94twiecki
authored andcommitted
Use model.logp_elemwise in InferenceDataConverter
1 parent 69815d9 commit a50b386

File tree

3 files changed

+28
-9
lines changed

3 files changed

+28
-9
lines changed

pymc/backends/arviz.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import pymc
2525

2626
from pymc.aesaraf import extract_obs_data
27-
from pymc.distributions import logpt
2827
from pymc.model import modelcontext
2928
from pymc.util import get_default_varnames
3029

@@ -264,11 +263,15 @@ def _extract_log_likelihood(self, trace):
264263
if self.model is None:
265264
return None
266265

266+
# TODO: We no longer need one function per observed variable
267267
if self.log_likelihood is True:
268-
cached = [(var, self.model.fn(logpt(var))) for var in self.model.observed_RVs]
268+
cached = [
269+
(var, self.model.fn(self.model.logp_elemwiset(var)[0]))
270+
for var in self.model.observed_RVs
271+
]
269272
else:
270273
cached = [
271-
(var, self.model.fn(logpt(var)))
274+
(var, self.model.fn(self.model.logp_elemwiset(var)[0]))
272275
for var in self.model.observed_RVs
273276
if var.name in self.log_likelihood
274277
]

pymc/sampling_jax.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _get_log_likelihood(model, samples):
123123
"Compute log-likelihood for all observations"
124124
data = {}
125125
for v in model.observed_RVs:
126-
logp_v = replace_shared_variables([logpt(v)])
126+
logp_v = replace_shared_variables([model.logp_elemwiset(v)[0]])
127127
fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
128128
optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
129129
jax_fn = jax_funcify(fgraph)

pymc/tests/test_idata_conversion.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def test_to_idata(self, data, eight_schools_params, chains, draws):
143143
np.isclose(ivalues[chain], values[chain * draws : (chain + 1) * draws])
144144
)
145145

146+
chains = inference_data.posterior.dims["chain"]
147+
draws = inference_data.posterior.dims["draw"]
148+
obs = inference_data.observed_data["obs"]
149+
assert inference_data.log_likelihood["obs"].shape == (chains, draws) + obs.shape
150+
146151
def test_predictions_to_idata(self, data, eight_schools_params):
147152
"Test that we can add predictions to a previously-existing InferenceData."
148153
test_dict = {
@@ -329,6 +334,11 @@ def test_missing_data_model(self):
329334
fails = check_multiple_attrs(test_dict, inference_data)
330335
assert not fails
331336

337+
# The missing part of partial observed RVs is not included in log_likelihood
338+
# See https://github.com/pymc-devs/pymc/issues/5255
339+
assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3)
340+
341+
@pytest.mark.xfal(reason="Multivariate partial observed RVs not implemented for V4")
332342
@pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4")
333343
def test_mv_missing_data_model(self):
334344
data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1)
@@ -375,8 +385,12 @@ def test_multiple_observed_rv(self, log_likelihood):
375385
if not log_likelihood:
376386
test_dict.pop("log_likelihood")
377387
test_dict["~log_likelihood"] = []
378-
if isinstance(log_likelihood, list):
388+
elif isinstance(log_likelihood, list):
379389
test_dict["log_likelihood"] = ["y1", "~y2"]
390+
assert inference_data.log_likelihood["y1"].shape == (2, 100, 10)
391+
else:
392+
assert inference_data.log_likelihood["y1"].shape == (2, 100, 10)
393+
assert inference_data.log_likelihood["y2"].shape == (2, 100, 100)
380394

381395
fails = check_multiple_attrs(test_dict, inference_data)
382396
assert not fails
@@ -445,12 +459,12 @@ def test_single_observation(self):
445459
inference_data = pm.sample(500, chains=2, return_inferencedata=True)
446460

447461
assert inference_data
462+
assert inference_data.log_likelihood["w"].shape == (2, 500, 1)
448463

449-
@pytest.mark.xfail(reason="Potential not refactored for v4")
450464
def test_potential(self):
451465
with pm.Model():
452466
x = pm.Normal("x", 0.0, 1.0)
453-
pm.Potential("z", logpt(pm.Normal.dist(x, 1.0), np.random.randn(10)))
467+
pm.Potential("z", pm.logp(pm.Normal.dist(x, 1.0), np.random.randn(10)))
454468
inference_data = pm.sample(100, chains=2, return_inferencedata=True)
455469

456470
assert inference_data
@@ -463,7 +477,7 @@ def test_constant_data(self, use_context):
463477
y = pm.Data("y", [1.0, 2.0, 3.0])
464478
beta = pm.Normal("beta", 0, 1)
465479
obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable
466-
trace = pm.sample(100, tune=100, return_inferencedata=False)
480+
trace = pm.sample(100, chains=2, tune=100, return_inferencedata=False)
467481
if use_context:
468482
inference_data = to_inference_data(trace=trace)
469483

@@ -472,6 +486,7 @@ def test_constant_data(self, use_context):
472486
test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]}
473487
fails = check_multiple_attrs(test_dict, inference_data)
474488
assert not fails
489+
assert inference_data.log_likelihood["obs"].shape == (2, 100, 3)
475490

476491
def test_predictions_constant_data(self):
477492
with pm.Model():
@@ -570,7 +585,7 @@ def test_multivariate_observations(self):
570585
with pm.Model(coords=coords):
571586
p = pm.Beta("p", 1, 1, size=(3,))
572587
pm.Multinomial("y", 20, p, dims=("experiment", "direction"), observed=data)
573-
idata = pm.sample(draws=50, tune=100, return_inferencedata=True)
588+
idata = pm.sample(draws=50, chains=2, tune=100, return_inferencedata=True)
574589
test_dict = {
575590
"posterior": ["p"],
576591
"sample_stats": ["lp"],
@@ -581,6 +596,7 @@ def test_multivariate_observations(self):
581596
assert not fails
582597
assert "direction" not in idata.log_likelihood.dims
583598
assert "direction" in idata.observed_data.dims
599+
assert idata.log_likelihood["y"].shape == (2, 50, 20)
584600

585601
def test_constant_data_coords_issue_5046(self):
586602
"""This is a regression test against a bug where a local coords variable was overwritten."""

0 commit comments

Comments
 (0)