Skip to content

Commit 5360939

Browse files
committed
Do not create new Ops in TransformValuesRewrite
1 parent 33a6af1 commit 5360939

File tree

3 files changed

+180
-132
lines changed

3 files changed

+180
-132
lines changed

pymc/logprob/transform_value.py

Lines changed: 131 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from copy import copy
1615
from typing import Dict, Optional, Sequence, Union
1716

1817
import numpy as np
@@ -31,8 +30,11 @@
3130
from pymc.logprob.transforms import RVTransform
3231

3332

34-
class TransformedVariable(Op):
35-
"""A no-op that identifies a transform and its un-transformed input."""
33+
class TransformedValue(Op):
34+
"""A no-op that pairs the original value with its transformed version.
35+
36+
This is introduced by the `TransformValuesRewrite`
37+
"""
3638

3739
view_map = {0: [0]}
3840

@@ -52,7 +54,94 @@ def grad(self, args, g_outs):
5254
return g_outs[0], DisconnectedType()()
5355

5456

55-
transformed_variable = TransformedVariable()
57+
transformed_value = TransformedValue()
58+
59+
60+
class TransformedValueRV(Op):
61+
"""A no-op that identifies RVs whose values were transformed.
62+
63+
This is introduced by the `TransformValuesRewrite`
64+
"""
65+
66+
view_map = {0: [0]}
67+
68+
__props__ = ("transforms",)
69+
70+
def __init__(self, transforms: Sequence[RVTransform]):
71+
self.transforms = tuple(transforms)
72+
super().__init__()
73+
74+
def make_node(self, *rv_outputs):
75+
return Apply(self, rv_outputs, [out.type() for out in rv_outputs])
76+
77+
def perform(self, node, inputs, outputs):
78+
raise NotImplementedError(
79+
"`TransformedRV` `Op`s should be removed from graphs used for computation."
80+
)
81+
82+
def connection_pattern(self, node):
83+
return [[True] for _ in node.outputs]
84+
85+
def infer_shape(self, fgraph, node, input_shapes):
86+
return input_shapes
87+
88+
89+
MeasurableVariable.register(TransformedValueRV)
90+
91+
92+
@_logprob.register(TransformedValueRV)
93+
def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs):
94+
"""Compute the log-probability graph for a `TransformedRV`.
95+
96+
This is introduced by the `TransformValuesRewrite`
97+
"""
98+
rv_op = rv_outs[0].owner.op
99+
rv_inputs = rv_outs[0].owner.inputs
100+
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
101+
102+
if not isinstance(logprobs, Sequence):
103+
logprobs = [logprobs]
104+
105+
# Handle jacobian
106+
assert len(values) == len(logprobs) == len(op.transforms)
107+
logprobs_jac = []
108+
for value, transform, logp in zip(values, op.transforms, logprobs):
109+
if transform is None:
110+
logprobs_jac.append(logp)
111+
continue
112+
113+
assert isinstance(value.owner.op, TransformedValue)
114+
original_forward_value = value.owner.inputs[1]
115+
log_jac_det = transform.log_jac_det(original_forward_value, *rv_inputs).copy()
116+
# The jacobian determinant has less dims than the logp
117+
# when a multivariate transform (like Simplex or Ordered) is applied to univariate distributions.
118+
# In this case we have to reduce the last logp dimensions, as they are no longer independent
119+
if log_jac_det.ndim < logp.ndim:
120+
diff_ndims = logp.ndim - log_jac_det.ndim
121+
logp = logp.sum(axis=np.arange(-diff_ndims, 0))
122+
# This case is sometimes, but not always, trivial to accomodate depending on the "space rank" of the
123+
# multivariate distribution. See https://proceedings.mlr.press/v130/radul21a.html
124+
elif log_jac_det.ndim > logp.ndim:
125+
raise NotImplementedError(
126+
f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
127+
)
128+
else:
129+
# Check there is no broadcasting between logp and jacobian
130+
if logp.type.broadcastable != log_jac_det.type.broadcastable:
131+
raise ValueError(
132+
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
133+
"There is a bug in the implementation of either one."
134+
)
135+
136+
if use_jacobian:
137+
if value.name:
138+
log_jac_det.name = f"{value.name}_jacobian"
139+
logprobs_jac.append(logp + log_jac_det)
140+
else:
141+
# We still want to use the reduced logp, even though the jacobian isn't included
142+
logprobs_jac.append(logp)
143+
144+
return logprobs_jac
56145

57146

58147
@node_rewriter(tracks=None)
@@ -94,10 +183,10 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply]
94183
if all(transform is None for transform in transforms):
95184
return None
96185

97-
new_op = _create_transformed_rv_op(node.op, transforms)
98-
# Create a new `Apply` node and outputs
99-
trans_node = node.clone()
100-
trans_node.op = new_op
186+
transformed_rv_op = TransformedValueRV(transforms)
187+
# Clone outputs so that rewrite doesn't reference original variables circularly
188+
cloned_outputs = node.clone().outputs
189+
transformed_rv_node = transformed_rv_op.make_node(*cloned_outputs)
101190

102191
# We now assume that the old value variable represents the *transformed space*.
103192
# This means that we need to replace all instance of the old value variable
@@ -108,16 +197,19 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply]
108197
if transform is None:
109198
continue
110199

111-
new_value_var = transformed_variable(
112-
transform.backward(value_var, *trans_node.inputs), value_var
200+
new_value_var = transformed_value(
201+
transform.backward(value_var, *node.inputs),
202+
value_var,
113203
)
114204

115205
if value_var.name and getattr(transform, "name", None):
116206
new_value_var.name = f"{value_var.name}_{transform.name}"
117207

118-
rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
208+
rv_map_feature.update_rv_maps(
209+
rv_var, new_value_var, transformed_rv_node.outputs[rv_var_out_idx]
210+
)
119211

120-
return trans_node.outputs
212+
return transformed_rv_node.outputs
121213

122214

123215
@node_rewriter(tracks=[Scan])
@@ -158,9 +250,10 @@ def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[A
158250
if all(transform is None for transform in transforms):
159251
return None
160252

161-
new_op = _create_transformed_rv_op(node.op, transforms)
162-
trans_node = node.clone()
163-
trans_node.op = new_op
253+
transformed_rv_op = TransformedValueRV(transforms)
254+
# Clone outputs so that rewrite doesn't reference original variables circularly
255+
cloned_outputs = node.clone().outputs
256+
transformed_rv_node = transformed_rv_op.make_node(*cloned_outputs)
164257

165258
# We now assume that the old value variable represents the *transformed space*.
166259
# This means that we need to replace all instance of the old value variable
@@ -173,7 +266,9 @@ def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[A
173266

174267
# We access the original value variable and apply the transform to that
175268
original_value_var = rv_map_feature.original_values[value_var]
176-
trans_original_value_var = transform.backward(original_value_var, *trans_node.inputs)
269+
trans_original_value_var = transform.backward(
270+
original_value_var, *transformed_rv_node.inputs
271+
)
177272

178273
# We then replace the reference to the original value variable in the scan value
179274
# variable by the back-transform projection computed above
@@ -188,18 +283,20 @@ def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[A
188283
(value_var.owner.inputs[0],),
189284
replace={original_value_var: trans_original_value_var},
190285
)
191-
trans_value_var = value_var.owner.clone_with_new_inputs(
286+
transformed_value_var = value_var.owner.clone_with_new_inputs(
192287
inputs=[trans_original_value_var] + value_var.owner.inputs[1:]
193288
).default_output()
194289

195-
new_value_var = transformed_variable(trans_value_var, original_value_var)
290+
new_value_var = transformed_value(transformed_value_var, original_value_var)
196291

197292
if value_var.name and getattr(transform, "name", None):
198293
new_value_var.name = f"{value_var.name}_{transform.name}"
199294

200-
rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
295+
rv_map_feature.update_rv_maps(
296+
rv_var, new_value_var, transformed_rv_node.outputs[rv_var_out_idx]
297+
)
201298

202-
return trans_node.outputs
299+
return transformed_rv_node.outputs
203300

204301

205302
class TransformValuesMapping(Feature):
@@ -247,121 +344,27 @@ def apply(self, fgraph: FunctionGraph):
247344
self.scan_transform_rewrite.rewrite(fgraph)
248345

249346

250-
def _create_transformed_rv_op(
251-
rv_op: Op,
252-
transforms: Union[RVTransform, Sequence[Union[None, RVTransform]]],
253-
*,
254-
cls_dict_extra: Optional[Dict] = None,
255-
) -> Op:
256-
"""Create a new transformed variable instance given a base `RandomVariable` `Op`.
257-
258-
This will essentially copy the `type` of the given `Op` instance, create a
259-
copy of said `Op` instance and change it's `type` to the new one.
260-
261-
In the end, we have an `Op` instance that will map to a `RVTransform` while
262-
also behaving exactly as it did before.
263-
264-
Parameters
265-
----------
266-
rv_op
267-
The `RandomVariable` for which we want to construct a `TransformedRV`.
268-
transform
269-
The `RVTransform` for `rv_op`.
270-
cls_dict_extra
271-
Additional class members to add to the constructed `TransformedRV`.
272-
273-
"""
274-
275-
if not isinstance(transforms, Sequence):
276-
transforms = (transforms,)
277-
278-
trans_names = [
279-
getattr(transform, "name", "transformed") if transform is not None else "None"
280-
for transform in transforms
281-
]
282-
rv_op_type = type(rv_op)
283-
rv_type_name = rv_op_type.__name__
284-
cls_dict = rv_op_type.__dict__.copy()
285-
rv_name = cls_dict.get("name", "")
286-
if rv_name:
287-
cls_dict["name"] = f"{rv_name}_{'_'.join(trans_names)}"
288-
cls_dict["transforms"] = transforms
289-
290-
if cls_dict_extra is not None:
291-
cls_dict.update(cls_dict_extra)
292-
293-
new_op_type = type(f"Transformed{rv_type_name}", (rv_op_type,), cls_dict)
294-
295-
MeasurableVariable.register(new_op_type)
296-
297-
@_logprob.register(new_op_type)
298-
def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
299-
"""Compute the log-likelihood graph for a `TransformedRV`.
300-
301-
We assume that the value variable was back-transformed to be on the natural
302-
support of the respective random variable.
303-
"""
304-
logprobs = _logprob(rv_op, values, *inputs, **kwargs)
305-
306-
if not isinstance(logprobs, Sequence):
307-
logprobs = [logprobs]
308-
309-
# Handle jacobian
310-
assert len(values) == len(logprobs) == len(op.transforms)
311-
logprobs_jac = []
312-
for value, transform, logp in zip(values, op.transforms, logprobs):
313-
if transform is None:
314-
logprobs_jac.append(logp)
315-
continue
316-
317-
assert isinstance(value.owner.op, TransformedVariable)
318-
original_forward_value = value.owner.inputs[1]
319-
log_jac_det = transform.log_jac_det(original_forward_value, *inputs).copy()
320-
# The jacobian determinant has less dims than the logp
321-
# when a multivariate transform (like Simplex or Ordered) is applied to univariate distributions.
322-
# In this case we have to reduce the last logp dimensions, as they are no longer independent
323-
if log_jac_det.ndim < logp.ndim:
324-
diff_ndims = logp.ndim - log_jac_det.ndim
325-
logp = logp.sum(axis=np.arange(-diff_ndims, 0))
326-
# This case is sometimes, but not always, trivial to accomodate depending on the "space rank" of the
327-
# multivariate distribution. See https://proceedings.mlr.press/v130/radul21a.html
328-
elif log_jac_det.ndim > logp.ndim:
329-
raise NotImplementedError(
330-
f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
331-
)
332-
else:
333-
# Check there is no broadcasting between logp and jacobian
334-
if logp.type.broadcastable != log_jac_det.type.broadcastable:
335-
raise ValueError(
336-
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
337-
"There is a bug in the implementation of either one."
338-
)
339-
340-
if use_jacobian:
341-
if value.name:
342-
log_jac_det.name = f"{value.name}_jacobian"
343-
logprobs_jac.append(logp + log_jac_det)
344-
else:
345-
# We still want to use the reduced logp, even though the jacobian isn't included
346-
logprobs_jac.append(logp)
347+
@node_rewriter([TransformedValue])
348+
def remove_TransformedValues(fgraph, node):
349+
return [node.inputs[0]]
347350

348-
return logprobs_jac
349351

350-
new_op = copy(rv_op)
351-
new_op.__class__ = new_op_type
352+
@node_rewriter([TransformedValueRV])
353+
def remove_TransformedValueRVs(fgraph, node):
354+
return node.inputs
352355

353-
return new_op
354356

355-
356-
@node_rewriter([TransformedVariable])
357-
def remove_TransformedVariables(fgraph, node):
358-
if isinstance(node.op, TransformedVariable):
359-
return [node.inputs[0]]
357+
cleanup_ir_rewrites_db.register(
358+
"remove_TransformedValues",
359+
remove_TransformedValues,
360+
"cleanup",
361+
"transform",
362+
)
360363

361364

362365
cleanup_ir_rewrites_db.register(
363-
"remove_TransformedVariables",
364-
remove_TransformedVariables,
366+
"remove_TransformedValueRVs",
367+
remove_TransformedValueRVs,
365368
"cleanup",
366369
"transform",
367370
)

tests/logprob/test_rewriting.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from pymc.distributions.transforms import logodds
5555
from pymc.logprob.basic import conditional_logp
5656
from pymc.logprob.rewriting import cleanup_ir, local_lift_DiracDelta
57-
from pymc.logprob.transform_value import TransformedVariable, TransformValuesRewrite
57+
from pymc.logprob.transform_value import TransformedValue, TransformValuesRewrite
5858
from pymc.logprob.utils import DiracDelta, dirac_delta
5959

6060

@@ -105,9 +105,7 @@ def test_local_remove_TransformedVariable():
105105
tr = TransformValuesRewrite({p_vv: logodds})
106106
[p_logp] = conditional_logp({p_rv: p_vv}, extra_rewrites=tr).values()
107107

108-
assert not any(
109-
isinstance(v.owner.op, TransformedVariable) for v in ancestors([p_logp]) if v.owner
110-
)
108+
assert not any(isinstance(v.owner.op, TransformedValue) for v in ancestors([p_logp]) if v.owner)
111109

112110

113111
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)