|
25 | 25 |
|
26 | 26 | __all__ = ("compute_log_likelihood", "compute_log_prior")
|
27 | 27 |
|
| 28 | +from pymc.model.transform.conditioning import remove_value_transforms |
| 29 | + |
28 | 30 |
|
29 | 31 | def compute_log_likelihood(
|
30 | 32 | idata: InferenceData,
|
@@ -126,46 +128,35 @@ def compute_log_density(
|
126 | 128 | if kind not in ("likelihood", "prior"):
|
127 | 129 | raise ValueError("kind must be either 'likelihood' or 'prior'")
|
128 | 130 |
|
| 131 | + # We need to disable transforms, because the InferenceData only keeps the untransformed values |
| 132 | + umodel = remove_value_transforms(model) |
| 133 | + |
129 | 134 | if kind == "likelihood":
|
130 |
| - target_rvs = model.observed_RVs |
| 135 | + target_rvs = list(umodel.observed_RVs) |
131 | 136 | target_str = "observed_RVs"
|
132 | 137 | else:
|
133 |
| - target_rvs = model.free_RVs |
| 138 | + target_rvs = list(umodel.free_RVs) |
134 | 139 | target_str = "free_RVs"
|
135 | 140 |
|
136 | 141 | if var_names is None:
|
137 | 142 | vars = target_rvs
|
138 | 143 | var_names = tuple(rv.name for rv in vars)
|
139 | 144 | else:
|
140 |
| - vars = [model.named_vars[name] for name in var_names] |
| 145 | + vars = [umodel.named_vars[name] for name in var_names] |
141 | 146 | if not set(vars).issubset(target_rvs):
|
142 | 147 | raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}")
|
143 | 148 |
|
144 |
| - # We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values |
145 |
| - try: |
146 |
| - original_rvs_to_values = model.rvs_to_values |
147 |
| - original_rvs_to_transforms = model.rvs_to_transforms |
148 |
| - |
149 |
| - model.rvs_to_values = { |
150 |
| - rv: rv.clone() if rv not in model.observed_RVs else value |
151 |
| - for rv, value in model.rvs_to_values.items() |
152 |
| - } |
153 |
| - model.rvs_to_transforms = {rv: None for rv in model.basic_RVs} |
154 |
| - |
155 |
| - elemwise_logdens_fn = model.compile_fn( |
156 |
| - inputs=model.value_vars, |
157 |
| - outs=model.logp(vars=vars, sum=False), |
158 |
| - on_unused_input="ignore", |
159 |
| - ) |
160 |
| - finally: |
161 |
| - model.rvs_to_values = original_rvs_to_values |
162 |
| - model.rvs_to_transforms = original_rvs_to_transforms |
| 149 | + elemwise_logdens_fn = umodel.compile_fn( |
| 150 | + inputs=umodel.value_vars, |
| 151 | + outs=umodel.logp(vars=vars, sum=False), |
| 152 | + on_unused_input="ignore", |
| 153 | + ) |
163 | 154 |
|
164 |
| - coords, dims = coords_and_dims_for_inferencedata(model) |
| 155 | + coords, dims = coords_and_dims_for_inferencedata(umodel) |
165 | 156 |
|
166 | 157 | logdens_dataset = apply_function_over_dataset(
|
167 | 158 | elemwise_logdens_fn,
|
168 |
| - posterior[[rv.name for rv in model.free_RVs]], |
| 159 | + posterior[[rv.name for rv in umodel.free_RVs]], |
169 | 160 | output_var_names=var_names,
|
170 | 161 | sample_dims=sample_dims,
|
171 | 162 | dims=dims,
|
|
0 commit comments