Skip to content

Commit 33a6af1

Browse files
committed
Move value transform logic to its own logprob file
1 parent ec4407d commit 33a6af1

File tree

9 files changed

+948
-895
lines changed

9 files changed

+948
-895
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ jobs:
117117
tests/logprob/test_rewriting.py
118118
tests/logprob/test_scan.py
119119
tests/logprob/test_tensor.py
120+
tests/logprob/test_transform_value.py
120121
tests/logprob/test_transforms.py
121122
tests/logprob/test_utils.py
122123

pymc/logprob/basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@
6363
_logprob_helper,
6464
)
6565
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
66-
from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
66+
from pymc.logprob.transform_value import TransformValuesRewrite
67+
from pymc.logprob.transforms import RVTransform
6768
from pymc.logprob.utils import find_rvs_in_graph, rvs_to_value_vars
6869

6970
TensorLike: TypeAlias = Union[Variable, float, np.ndarray]

pymc/logprob/transform_value.py

Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from copy import copy
16+
from typing import Dict, Optional, Sequence, Union
17+
18+
import numpy as np
19+
20+
from pytensor.gradient import DisconnectedType
21+
from pytensor.graph import Apply, Op
22+
from pytensor.graph.features import AlreadyThere, Feature
23+
from pytensor.graph.fg import FunctionGraph
24+
from pytensor.graph.replace import clone_replace
25+
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
26+
from pytensor.scan.op import Scan
27+
from pytensor.tensor.variable import TensorVariable
28+
29+
from pymc.logprob.abstract import MeasurableVariable, _logprob
30+
from pymc.logprob.rewriting import PreserveRVMappings, cleanup_ir_rewrites_db
31+
from pymc.logprob.transforms import RVTransform
32+
33+
34+
class TransformedVariable(Op):
    """Placeholder `Op` tying a back-transformed value to the value it was computed from.

    The first input is the back-transformed value and the second input is the
    value variable it was derived from. The node only records that pairing: it
    cannot be evaluated and must be stripped from any graph used for computation.
    """

    # Output 0 is declared to be a view of input 0
    view_map = {0: [0]}

    def make_node(self, tran_value: TensorVariable, value: TensorVariable):
        """Create a node whose single output has the same type as `tran_value`."""
        node_inputs = [tran_value, value]
        node_outputs = [tran_value.type()]
        return Apply(self, node_inputs, node_outputs)

    def perform(self, node, inputs, outputs):
        """Always raise, since this `Op` has no numerical implementation."""
        raise NotImplementedError("These `Op`s should be removed from graphs used for computation.")

    def connection_pattern(self, node):
        """Mark the output as connected to the first input only."""
        tran_value_pattern = [True]
        value_pattern = [False]
        return [tran_value_pattern, value_pattern]

    def infer_shape(self, fgraph, node, input_shapes):
        """Report the shape of the first input as the shape of the output."""
        tran_value_shape = input_shapes[0]
        return [tran_value_shape]

    def grad(self, args, g_outs):
        """Pass the output gradient to the first input; the second input is disconnected."""
        output_grad = g_outs[0]
        disconnected_grad = DisconnectedType()()
        return output_grad, disconnected_grad
53+
54+
55+
# Module-level `TransformedVariable` instance that the value-transform rewrites
# in this module call to wrap back-transformed value variables.
transformed_variable = TransformedVariable()
56+
57+
58+
@node_rewriter(tracks=None)
def transform_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply]]:
    """Back-transform the value variables attached to the outputs of `node`.

    Value variables are taken to live in the forward-transformed space, which
    is usually picked so that the values are unconstrained on the real line.
    With ``Y = halfnormal(...)``, for instance, the value variable is assumed
    to be on the log scale and is back-transformed to get ``Y`` on the natural
    scale.

    Returns
    -------
    The outputs of a clone of `node` that uses the `Op` built by
    `_create_transformed_rv_op`, or ``None`` when `fgraph` lacks the required
    features, no output of `node` has a value variable, or none of those value
    variables has a transform.
    """
    rv_mappings: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
    transform_lookup: Optional[TransformValuesMapping] = getattr(
        fgraph, "values_to_transforms", None
    )
    if rv_mappings is None or transform_lookup is None:
        return None  # pragma: no cover

    # Keep only the outputs that have a value variable registered
    candidates = [(out, rv_mappings.rv_values.get(out, None)) for out in node.outputs]
    valued = [(rv_var, value_var) for rv_var, value_var in candidates if value_var is not None]
    if not valued:
        return None

    transforms = [transform_lookup.get(value_var, None) for _, value_var in valued]
    if not any(transform is not None for transform in transforms):
        return None

    # Build the transform-aware `Op` and install it on a clone of the node
    new_op = _create_transformed_rv_op(node.op, transforms)
    trans_node = node.clone()
    trans_node.op = new_op

    # From here on the old value variable is read as a point in the
    # *transformed space*, so every valued output with a transform gets its
    # value replaced by the back-transformed ("un-transformed") version.
    for (rv_var, value_var), transform in zip(valued, transforms):
        if transform is None:
            continue

        out_idx = node.outputs.index(rv_var)
        natural_value = transform.backward(value_var, *trans_node.inputs)
        new_value_var = transformed_variable(natural_value, value_var)

        if value_var.name and getattr(transform, "name", None):
            new_value_var.name = f"{value_var.name}_{transform.name}"

        rv_mappings.update_rv_maps(rv_var, new_value_var, trans_node.outputs[out_idx])

    return trans_node.outputs
121+
122+
123+
@node_rewriter(tracks=[Scan])
def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply]]:
    """Back-transform the value variables attached to the outputs of a `Scan` node.

    A dedicated rewrite is required because Scan swaps the original value
    variables for a more complex graph. The transform is applied to the
    original value variable inside that subgraph and the remainder of the
    subgraph is left as it is.

    Returns
    -------
    The outputs of a clone of `node` that uses the `Op` built by
    `_create_transformed_rv_op`, or ``None`` when `fgraph` lacks the required
    features, no output of `node` has a value variable, or none of the
    original value variables has a transform.
    """
    rv_mappings: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
    transform_lookup: Optional[TransformValuesMapping] = getattr(
        fgraph, "values_to_transforms", None
    )
    if rv_mappings is None or transform_lookup is None:
        return None  # pragma: no cover

    # Keep only the outputs that have a value variable registered
    candidates = [(out, rv_mappings.rv_values.get(out, None)) for out in node.outputs]
    valued = [(rv_var, value_var) for rv_var, value_var in candidates if value_var is not None]
    if not valued:
        return None

    # Transforms are keyed by the *original* value variables
    transforms = [
        transform_lookup.get(rv_mappings.original_values[value_var], None)
        for _, value_var in valued
    ]
    if not any(transform is not None for transform in transforms):
        return None

    new_op = _create_transformed_rv_op(node.op, transforms)
    trans_node = node.clone()
    trans_node.op = new_op

    # From here on the old value variable is read as a point in the
    # *transformed space*, so every valued output with a transform gets its
    # value replaced by the back-transformed ("un-transformed") version.
    for (rv_var, value_var), transform in zip(valued, transforms):
        if transform is None:
            continue

        out_idx = node.outputs.index(rv_var)

        # Look up the original value variable and back-transform that one
        original_value_var = rv_mappings.original_values[value_var]
        backward_value = transform.backward(original_value_var, *trans_node.inputs)

        # Substitute the back-transformed projection for the original value
        # variable inside the scan value variable graph.
        #
        # Only the first input, which corresponds to the original value
        # variable, goes through `clone_replace`: the rest of the scan value
        # variable graph is likely to hold other rvs whose mappings must not
        # be broken.
        # TODO: Is it true that the original value only appears in the first input
        #  and that no other RV can appear there?
        value_owner = value_var.owner
        (replaced_first_input,) = clone_replace(
            (value_owner.inputs[0],),
            replace={original_value_var: backward_value},
        )
        rebuilt_inputs = [replaced_first_input] + value_owner.inputs[1:]
        rebuilt_node = value_owner.clone_with_new_inputs(inputs=rebuilt_inputs)
        trans_value_var = rebuilt_node.default_output()

        new_value_var = transformed_variable(trans_value_var, original_value_var)

        if value_var.name and getattr(transform, "name", None):
            new_value_var.name = f"{value_var.name}_{transform.name}"

        rv_mappings.update_rv_maps(rv_var, new_value_var, trans_node.outputs[out_idx])

    return trans_node.outputs
203+
204+
205+
class TransformValuesMapping(Feature):
    r"""`Feature` exposing a value-variable-to-transform map on a function graph."""

    def __init__(self, values_to_transforms):
        """Store a copy of `values_to_transforms` (made with its ``copy`` method)."""
        mapping_copy = values_to_transforms.copy()
        self.values_to_transforms = mapping_copy

    def on_attach(self, fgraph):
        """Set ``fgraph.values_to_transforms``, raising `AlreadyThere` if it already exists."""
        already_attached = hasattr(fgraph, "values_to_transforms")
        if already_attached:
            raise AlreadyThere()

        setattr(fgraph, "values_to_transforms", self.values_to_transforms)
216+
217+
218+
class TransformValuesRewrite(GraphRewriter):
    r"""Graph rewriter that transforms value variables as dictated by a mapping."""

    # Both node rewrites are wrapped with `in2out` and `ignore_newtrees=True`
    transform_rewrite = in2out(transform_values, ignore_newtrees=True)
    scan_transform_rewrite = in2out(transform_scan_values, ignore_newtrees=True)

    def __init__(
        self,
        values_to_transforms: Dict[TensorVariable, Union[RVTransform, None]],
    ):
        """
        Parameters
        ----------
        values_to_transforms
            Mapping from value variables to their transformations. The entry
            for a value variable is either an `RVTransform` or ``None``. A
            value variable without a specified transform is left
            untransformed.

        """
        self.values_to_transforms = values_to_transforms

    def add_requirements(self, fgraph):
        """Attach a `TransformValuesMapping` built from the stored map to `fgraph`."""
        fgraph.attach_feature(TransformValuesMapping(self.values_to_transforms))

    def apply(self, fgraph: FunctionGraph):
        """Run the general value rewrite on `fgraph`, followed by the Scan-specific one."""
        for rewrite in (self.transform_rewrite, self.scan_transform_rewrite):
            rewrite.rewrite(fgraph)
248+
249+
250+
def _create_transformed_rv_op(
    rv_op: Op,
    transforms: Union[RVTransform, Sequence[Union[None, RVTransform]]],
    *,
    cls_dict_extra: Optional[Dict] = None,
) -> Op:
    """Build a transformed counterpart of a base `RandomVariable` `Op` instance.

    A subclass of the type of `rv_op` is created on the fly, registered as a
    `MeasurableVariable` and given its own `_logprob` implementation. A shallow
    copy of `rv_op` then has its type switched to that new subclass, so the
    returned `Op` carries the transforms while otherwise behaving like `rv_op`.

    Parameters
    ----------
    rv_op
        The `RandomVariable` for which we want to construct a `TransformedRV`.
    transforms
        The `RVTransform` for `rv_op`, or a sequence with one entry per value
        variable where each entry is an `RVTransform` or ``None``.
    cls_dict_extra
        Additional class members to add to the constructed `TransformedRV`.

    """
    if not isinstance(transforms, Sequence):
        transforms = (transforms,)

    def _transform_label(transform):
        # ``None`` entries are spelled out; unnamed transforms get a generic label
        if transform is None:
            return "None"
        return getattr(transform, "name", "transformed")

    trans_names = [_transform_label(transform) for transform in transforms]

    base_type = type(rv_op)
    members = base_type.__dict__.copy()
    base_name = members.get("name", "")
    if base_name:
        members["name"] = f"{base_name}_{'_'.join(trans_names)}"
    members["transforms"] = transforms

    if cls_dict_extra is not None:
        members.update(cls_dict_extra)

    new_op_type = type(f"Transformed{base_type.__name__}", (base_type,), members)

    MeasurableVariable.register(new_op_type)

    @_logprob.register(new_op_type)
    def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
        """Compute the log-likelihood graph for a `TransformedRV`.

        The value variables are expected to have been back-transformed already,
        so that they lie on the natural support of the random variable.
        """
        logprobs = _logprob(rv_op, values, *inputs, **kwargs)

        if not isinstance(logprobs, Sequence):
            logprobs = [logprobs]

        # Handle jacobian
        assert len(values) == len(logprobs) == len(op.transforms)
        logprobs_jac = []
        for value, transform, logp in zip(values, op.transforms, logprobs):
            if transform is None:
                logprobs_jac.append(logp)
                continue

            assert isinstance(value.owner.op, TransformedVariable)
            # The second input of the marker node is the value the
            # back-transformed one was computed from
            original_forward_value = value.owner.inputs[1]
            log_jac_det = transform.log_jac_det(original_forward_value, *inputs).copy()

            ndim_gap = logp.ndim - log_jac_det.ndim
            if ndim_gap > 0:
                # The jacobian determinant has fewer dims than the logp when a
                # multivariate transform (like Simplex or Ordered) is applied to
                # univariate distributions. The trailing logp dimensions are no
                # longer independent in that case and have to be reduced.
                logp = logp.sum(axis=np.arange(-ndim_gap, 0))
            elif ndim_gap < 0:
                # This case is sometimes, but not always, trivial to accommodate
                # depending on the "space rank" of the multivariate distribution.
                # See https://proceedings.mlr.press/v130/radul21a.html
                raise NotImplementedError(
                    f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
                )
            elif logp.type.broadcastable != log_jac_det.type.broadcastable:
                # Same number of dims: logp and jacobian must not broadcast
                raise ValueError(
                    f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
                    "There is a bug in the implementation of either one."
                )

            if not use_jacobian:
                # The reduced logp is still the one to use when the jacobian is left out
                logprobs_jac.append(logp)
                continue

            if value.name:
                log_jac_det.name = f"{value.name}_jacobian"
            logprobs_jac.append(logp + log_jac_det)

        return logprobs_jac

    new_op = copy(rv_op)
    new_op.__class__ = new_op_type

    return new_op
354+
355+
356+
@node_rewriter([TransformedVariable])
def remove_TransformedVariables(fgraph, node):
    """Replace a `TransformedVariable` node's output with the node's first input."""
    if not isinstance(node.op, TransformedVariable):
        return None

    tran_value = node.inputs[0]
    return [tran_value]
360+
361+
362+
# Register the rewrite in the IR cleanup database under the "cleanup" and "transform" tags
cleanup_ir_rewrites_db.register(
    "remove_TransformedVariables", remove_TransformedVariables, "cleanup", "transform"
)

0 commit comments

Comments
 (0)