|
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import copy
from typing import Dict, Optional, Sequence, Union

import numpy as np

from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply, Op
from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.scan.op import Scan
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import MeasurableVariable, _logprob
from pymc.logprob.rewriting import PreserveRVMappings, cleanup_ir_rewrites_db
from pymc.logprob.transforms import RVTransform


class TransformedVariable(Op):
    """A no-op that identifies a transform and its un-transformed input."""

    view_map = {0: [0]}

    def make_node(self, tran_value: TensorVariable, value: TensorVariable):
        return Apply(self, [tran_value, value], [tran_value.type()])

    def perform(self, node, inputs, outputs):
        raise NotImplementedError("These `Op`s should be removed from graphs used for computation.")

    def connection_pattern(self, node):
        return [[True], [False]]

    def infer_shape(self, fgraph, node, input_shapes):
        return [input_shapes[0]]

    def grad(self, args, g_outs):
        return g_outs[0], DisconnectedType()()


transformed_variable = TransformedVariable()
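
# A minimal, hypothetical sketch of the convention (`pt` and `y_log` are not
# defined in this module): if `y_log` is a value variable living on the log
# scale, the rewrites below pair its back-transformed version with the raw
# input, so downstream logprob code can still recover the original value:
#
#     y_nat = transformed_variable(pt.exp(y_log), y_log)
#     y_nat.owner.inputs[1]  # -> y_log, the original forward value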


@node_rewriter(tracks=None)
def transform_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[TensorVariable]]:
    """Apply transforms to value variables.

    It is assumed that the input value variables are given in the transformed
    space, i.e. that they correspond to forward transformations, usually
    chosen in such a way that the values are unconstrained on the real line.

    For example, if ``Y = halfnormal(...)``, we assume the respective value
    variable is specified on the log scale and back-transform it to obtain
    ``Y`` on the natural scale.
    """

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
    values_to_transforms: Optional[TransformValuesMapping] = getattr(
        fgraph, "values_to_transforms", None
    )

    if rv_map_feature is None or values_to_transforms is None:
        return None  # pragma: no cover

    rv_vars = []
    value_vars = []

    for out in node.outputs:
        value = rv_map_feature.rv_values.get(out, None)
        if value is None:
            continue
        rv_vars.append(out)
        value_vars.append(value)

    if not value_vars:
        return None

    transforms = [values_to_transforms.get(value_var, None) for value_var in value_vars]

    if all(transform is None for transform in transforms):
        return None

    new_op = _create_transformed_rv_op(node.op, transforms)
    # Create a new `Apply` node and outputs
    trans_node = node.clone()
    trans_node.op = new_op

    # We now assume that the old value variable represents the *transformed space*.
    # This means that we need to replace all instances of the old value variable
    # with "inversely/un-" transformed versions of itself.
    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
        rv_var_out_idx = node.outputs.index(rv_var)

        if transform is None:
            continue

        new_value_var = transformed_variable(
            transform.backward(value_var, *trans_node.inputs), value_var
        )

        if value_var.name and getattr(transform, "name", None):
            new_value_var.name = f"{value_var.name}_{transform.name}"

        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])

    return trans_node.outputs


@node_rewriter(tracks=[Scan])
def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[TensorVariable]]:
    """Apply transforms to Scan value variables.

    This specialized rewrite is needed because Scan replaces the original value
    variables with a more complex graph. We want to apply the transform to the
    original value variable in this subgraph, leaving the rest intact.
    """

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
    values_to_transforms: Optional[TransformValuesMapping] = getattr(
        fgraph, "values_to_transforms", None
    )

    if rv_map_feature is None or values_to_transforms is None:
        return None  # pragma: no cover

    rv_vars = []
    value_vars = []

    for out in node.outputs:
        value = rv_map_feature.rv_values.get(out, None)
        if value is None:
            continue
        rv_vars.append(out)
        value_vars.append(value)

    if not value_vars:
        return None

    transforms = [
        values_to_transforms.get(rv_map_feature.original_values[value_var], None)
        for value_var in value_vars
    ]

    if all(transform is None for transform in transforms):
        return None

    new_op = _create_transformed_rv_op(node.op, transforms)
    trans_node = node.clone()
    trans_node.op = new_op

    # We now assume that the old value variable represents the *transformed space*.
    # This means that we need to replace all instances of the old value variable
    # with "inversely/un-" transformed versions of itself.
    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
        rv_var_out_idx = node.outputs.index(rv_var)

        if transform is None:
            continue

        # We access the original value variable and apply the transform to that
        original_value_var = rv_map_feature.original_values[value_var]
        trans_original_value_var = transform.backward(original_value_var, *trans_node.inputs)

        # We then replace the reference to the original value variable in the scan value
        # variable by the back-transformed projection computed above

        # The first input corresponds to the original value variable. We are careful to
        # only clone_replace that part of the graph, as we don't want to break the
        # mappings between other rvs that are likely to be present in the rest of the
        # scan value variable graph
        # TODO: Is it true that the original value only appears in the first input
        # and that no other RV can appear there?
        (trans_original_value_var,) = clone_replace(
            (value_var.owner.inputs[0],),
            replace={original_value_var: trans_original_value_var},
        )
        trans_value_var = value_var.owner.clone_with_new_inputs(
            inputs=[trans_original_value_var] + value_var.owner.inputs[1:]
        ).default_output()

        new_value_var = transformed_variable(trans_value_var, original_value_var)

        if value_var.name and getattr(transform, "name", None):
            new_value_var.name = f"{value_var.name}_{transform.name}"

        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])

    return trans_node.outputs


class TransformValuesMapping(Feature):
    r"""A `Feature` that maintains a map between value variables and their transforms."""

    def __init__(self, values_to_transforms):
        self.values_to_transforms = values_to_transforms.copy()

    def on_attach(self, fgraph):
        if hasattr(fgraph, "values_to_transforms"):
            raise AlreadyThere()

        fgraph.values_to_transforms = self.values_to_transforms


class TransformValuesRewrite(GraphRewriter):
    r"""Transforms value variables according to a map."""

    transform_rewrite = in2out(transform_values, ignore_newtrees=True)
    scan_transform_rewrite = in2out(transform_scan_values, ignore_newtrees=True)

    def __init__(
        self,
        values_to_transforms: Dict[TensorVariable, Union[RVTransform, None]],
    ):
        """
        Parameters
        ----------
        values_to_transforms
            Mapping between value variables and their transformations. Each
            value variable can be assigned an `RVTransform` or ``None``. If
            a transform is not specified for a given value variable, it will
            not be transformed.
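
        Examples
        --------
        A minimal sketch of intended use. The exact entry point is an
        assumption here: at the time of writing, ``factorized_joint_logprob``
        in ``pymc.logprob.joint_logprob`` accepts an ``extra_rewrites``
        argument, but this may differ across PyMC versions::

            import pytensor.tensor as pt

            from pymc.logprob.joint_logprob import factorized_joint_logprob
            from pymc.logprob.transforms import LogTransform

            x = pt.random.halfnormal(name="x")
            x_vv = x.clone()
            x_vv.name = "x_log"

            # Derive the logp with ``x_vv`` interpreted on the log scale
            transform_rewrite = TransformValuesRewrite({x_vv: LogTransform()})
            logp_dict = factorized_joint_logprob(
                {x: x_vv}, extra_rewrites=transform_rewrite
            )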
        """

        self.values_to_transforms = values_to_transforms

    def add_requirements(self, fgraph):
        values_transforms_feature = TransformValuesMapping(self.values_to_transforms)
        fgraph.attach_feature(values_transforms_feature)

    def apply(self, fgraph: FunctionGraph):
        self.transform_rewrite.rewrite(fgraph)
        self.scan_transform_rewrite.rewrite(fgraph)


def _create_transformed_rv_op(
    rv_op: Op,
    transforms: Union[RVTransform, Sequence[Union[None, RVTransform]]],
    *,
    cls_dict_extra: Optional[Dict] = None,
) -> Op:
    """Create a new transformed variable instance given a base `RandomVariable` `Op`.

    This will essentially copy the `type` of the given `Op` instance, create a
    copy of said `Op` instance and change its `type` to the new one.

    In the end, we have an `Op` instance that will map to an `RVTransform` while
    also behaving exactly as it did before.

    Parameters
    ----------
    rv_op
        The `RandomVariable` for which we want to construct a `TransformedRV`.
    transforms
        The `RVTransform`s for the measurable outputs of `rv_op`; ``None``
        entries are left untransformed.
    cls_dict_extra
        Additional class members to add to the constructed `TransformedRV`.

    """

    if not isinstance(transforms, Sequence):
        transforms = (transforms,)

    trans_names = [
        getattr(transform, "name", "transformed") if transform is not None else "None"
        for transform in transforms
    ]
    rv_op_type = type(rv_op)
    rv_type_name = rv_op_type.__name__
    cls_dict = rv_op_type.__dict__.copy()
    rv_name = cls_dict.get("name", "")
    if rv_name:
        cls_dict["name"] = f"{rv_name}_{'_'.join(trans_names)}"
    cls_dict["transforms"] = transforms

    if cls_dict_extra is not None:
        cls_dict.update(cls_dict_extra)

    new_op_type = type(f"Transformed{rv_type_name}", (rv_op_type,), cls_dict)

    MeasurableVariable.register(new_op_type)

    @_logprob.register(new_op_type)
    def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
        """Compute the log-likelihood graph for a `TransformedRV`.

        We assume that the value variable was back-transformed to be on the natural
        support of the respective random variable.
        """
        logprobs = _logprob(rv_op, values, *inputs, **kwargs)

        if not isinstance(logprobs, Sequence):
            logprobs = [logprobs]

        # Handle jacobian
        assert len(values) == len(logprobs) == len(op.transforms)
        logprobs_jac = []
        for value, transform, logp in zip(values, op.transforms, logprobs):
            if transform is None:
                logprobs_jac.append(logp)
                continue

            assert isinstance(value.owner.op, TransformedVariable)
            original_forward_value = value.owner.inputs[1]
            log_jac_det = transform.log_jac_det(original_forward_value, *inputs).copy()
            # The Jacobian determinant has fewer dims than the logp when a
            # multivariate transform (like Simplex or Ordered) is applied to
            # univariate distributions. In this case we have to reduce the last
            # logp dimensions, as they are no longer independent.
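            # For example (illustrative): an Ordered transform applied to a vector
            # of three iid univariate normals yields a scalar log_jac_det while the
            # logp has shape (3,), so the pointwise terms are summed into one.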
            if log_jac_det.ndim < logp.ndim:
                diff_ndims = logp.ndim - log_jac_det.ndim
                logp = logp.sum(axis=np.arange(-diff_ndims, 0))
            # This case is sometimes, but not always, trivial to accommodate,
            # depending on the "space rank" of the multivariate distribution.
            # See https://proceedings.mlr.press/v130/radul21a.html
            elif log_jac_det.ndim > logp.ndim:
                raise NotImplementedError(
                    f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
                )
            else:
                # Check there is no broadcasting between logp and jacobian
                if logp.type.broadcastable != log_jac_det.type.broadcastable:
                    raise ValueError(
                        f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
                        "There is a bug in the implementation of either one."
                    )

            if use_jacobian:
                if value.name:
                    log_jac_det.name = f"{value.name}_jacobian"
                logprobs_jac.append(logp + log_jac_det)
            else:
                # We still want to use the reduced logp, even though the
                # Jacobian isn't included
                logprobs_jac.append(logp)

        return logprobs_jac
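
    # Illustrative note: for, e.g., a halfnormal RV `Op` paired with a
    # `LogTransform`, `new_op_type` above is named `TransformedHalfNormalRV`,
    # and its logprob dispatches to the `transformed_logprob` registered above.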
    new_op = copy(rv_op)
    new_op.__class__ = new_op_type

    return new_op


@node_rewriter([TransformedVariable])
def remove_TransformedVariables(fgraph, node):
    if isinstance(node.op, TransformedVariable):
        return [node.inputs[0]]


cleanup_ir_rewrites_db.register(
    "remove_TransformedVariables",
    remove_TransformedVariables,
    "cleanup",
    "transform",
)