@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from copy import copy
 from typing import Dict, Optional, Sequence, Union
 
 import numpy as np
@@ -31,8 +30,11 @@
 from pymc.logprob.transforms import RVTransform
 
 
-class TransformedVariable(Op):
-    """A no-op that identifies a transform and its un-transformed input."""
+class TransformedValue(Op):
+    """A no-op that pairs the original value with its transformed version.
+
+    This is introduced by the `TransformValuesRewrite`
+    """
 
     view_map = {0: [0]}
 
@@ -52,7 +54,94 @@ def grad(self, args, g_outs):
         return g_outs[0], DisconnectedType()()
 
 
-transformed_variable = TransformedVariable()
+transformed_value = TransformedValue()
+
+
+class TransformedValueRV(Op):
+    """A no-op that identifies RVs whose values were transformed.
+
+    This is introduced by the `TransformValuesRewrite`
+    """
+
+    view_map = {0: [0]}
+
+    __props__ = ("transforms",)
+
+    def __init__(self, transforms: Sequence[RVTransform]):
+        self.transforms = tuple(transforms)
+        super().__init__()
+
+    def make_node(self, *rv_outputs):
+        return Apply(self, rv_outputs, [out.type() for out in rv_outputs])
+
+    def perform(self, node, inputs, outputs):
+        raise NotImplementedError(
+            "`TransformedRV` `Op`s should be removed from graphs used for computation."
+        )
+
+    def connection_pattern(self, node):
+        return [[True] for _ in node.outputs]
+
+    def infer_shape(self, fgraph, node, input_shapes):
+        return input_shapes
+
+
+MeasurableVariable.register(TransformedValueRV)
+
+
+@_logprob.register(TransformedValueRV)
+def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs):
+    """Compute the log-probability graph for a `TransformedRV`.
+
+    This is introduced by the `TransformValuesRewrite`
+    """
+    rv_op = rv_outs[0].owner.op
+    rv_inputs = rv_outs[0].owner.inputs
+    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
+
+    if not isinstance(logprobs, Sequence):
+        logprobs = [logprobs]
+
+    # Handle jacobian
+    assert len(values) == len(logprobs) == len(op.transforms)
+    logprobs_jac = []
+    for value, transform, logp in zip(values, op.transforms, logprobs):
+        if transform is None:
+            logprobs_jac.append(logp)
+            continue
+
+        assert isinstance(value.owner.op, TransformedValue)
+        original_forward_value = value.owner.inputs[1]
+        log_jac_det = transform.log_jac_det(original_forward_value, *rv_inputs).copy()
+        # The jacobian determinant has fewer dims than the logp
+        # when a multivariate transform (like Simplex or Ordered) is applied to univariate distributions.
+        # In this case we have to reduce the last logp dimensions, as they are no longer independent
+        if log_jac_det.ndim < logp.ndim:
+            diff_ndims = logp.ndim - log_jac_det.ndim
+            logp = logp.sum(axis=np.arange(-diff_ndims, 0))
+        # This case is sometimes, but not always, trivial to accommodate depending on the "space rank" of the
+        # multivariate distribution. See https://proceedings.mlr.press/v130/radul21a.html
+        elif log_jac_det.ndim > logp.ndim:
+            raise NotImplementedError(
+                f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
+            )
+        else:
+            # Check there is no broadcasting between logp and jacobian
+            if logp.type.broadcastable != log_jac_det.type.broadcastable:
+                raise ValueError(
+                    f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
+                    "There is a bug in the implementation of either one."
+                )
+
+        if use_jacobian:
+            if value.name:
+                log_jac_det.name = f"{value.name}_jacobian"
+            logprobs_jac.append(logp + log_jac_det)
+        else:
+            # We still want to use the reduced logp, even though the jacobian isn't included
+            logprobs_jac.append(logp)
+
+    return logprobs_jac
 
 
 @node_rewriter(tracks=None)
@@ -94,10 +183,10 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply]
     if all(transform is None for transform in transforms):
         return None
 
-    new_op = _create_transformed_rv_op(node.op, transforms)
-    # Create a new `Apply` node and outputs
-    trans_node = node.clone()
-    trans_node.op = new_op
+    transformed_rv_op = TransformedValueRV(transforms)
+    # Clone outputs so that rewrite doesn't reference original variables circularly
+    cloned_outputs = node.clone().outputs
+    transformed_rv_node = transformed_rv_op.make_node(*cloned_outputs)
 
     # We now assume that the old value variable represents the *transformed space*.
     # This means that we need to replace all instance of the old value variable
@@ -108,16 +197,19 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply]
         if transform is None:
             continue
 
-        new_value_var = transformed_variable(
-            transform.backward(value_var, *trans_node.inputs), value_var
+        new_value_var = transformed_value(
+            transform.backward(value_var, *node.inputs),
+            value_var,
         )
 
         if value_var.name and getattr(transform, "name", None):
             new_value_var.name = f"{value_var.name}_{transform.name}"
 
-        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
+        rv_map_feature.update_rv_maps(
+            rv_var, new_value_var, transformed_rv_node.outputs[rv_var_out_idx]
+        )
 
-    return trans_node.outputs
+    return transformed_rv_node.outputs
 
 
 @node_rewriter(tracks=[Scan])
@@ -158,9 +250,10 @@ def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[A
     if all(transform is None for transform in transforms):
         return None
 
-    new_op = _create_transformed_rv_op(node.op, transforms)
-    trans_node = node.clone()
-    trans_node.op = new_op
+    transformed_rv_op = TransformedValueRV(transforms)
+    # Clone outputs so that rewrite doesn't reference original variables circularly
+    cloned_outputs = node.clone().outputs
+    transformed_rv_node = transformed_rv_op.make_node(*cloned_outputs)
 
     # We now assume that the old value variable represents the *transformed space*.
     # This means that we need to replace all instance of the old value variable
@@ -173,7 +266,9 @@ def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[A
 
         # We access the original value variable and apply the transform to that
         original_value_var = rv_map_feature.original_values[value_var]
-        trans_original_value_var = transform.backward(original_value_var, *trans_node.inputs)
+        trans_original_value_var = transform.backward(
+            original_value_var, *transformed_rv_node.inputs
+        )
 
         # We then replace the reference to the original value variable in the scan value
         # variable by the back-transform projection computed above
@@ -188,18 +283,20 @@ def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[A
             (value_var.owner.inputs[0],),
             replace={original_value_var: trans_original_value_var},
         )
-        trans_value_var = value_var.owner.clone_with_new_inputs(
+        transformed_value_var = value_var.owner.clone_with_new_inputs(
             inputs=[trans_original_value_var] + value_var.owner.inputs[1:]
         ).default_output()
 
-        new_value_var = transformed_variable(trans_value_var, original_value_var)
+        new_value_var = transformed_value(transformed_value_var, original_value_var)
 
         if value_var.name and getattr(transform, "name", None):
             new_value_var.name = f"{value_var.name}_{transform.name}"
 
-        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
+        rv_map_feature.update_rv_maps(
+            rv_var, new_value_var, transformed_rv_node.outputs[rv_var_out_idx]
+        )
 
-    return trans_node.outputs
+    return transformed_rv_node.outputs
 
 
 class TransformValuesMapping(Feature):
@@ -247,121 +344,27 @@ def apply(self, fgraph: FunctionGraph):
         self.scan_transform_rewrite.rewrite(fgraph)
 
 
-def _create_transformed_rv_op(
-    rv_op: Op,
-    transforms: Union[RVTransform, Sequence[Union[None, RVTransform]]],
-    *,
-    cls_dict_extra: Optional[Dict] = None,
-) -> Op:
-    """Create a new transformed variable instance given a base `RandomVariable` `Op`.
-
-    This will essentially copy the `type` of the given `Op` instance, create a
-    copy of said `Op` instance and change it's `type` to the new one.
-
-    In the end, we have an `Op` instance that will map to a `RVTransform` while
-    also behaving exactly as it did before.
-
-    Parameters
-    ----------
-    rv_op
-        The `RandomVariable` for which we want to construct a `TransformedRV`.
-    transform
-        The `RVTransform` for `rv_op`.
-    cls_dict_extra
-        Additional class members to add to the constructed `TransformedRV`.
-
-    """
-
-    if not isinstance(transforms, Sequence):
-        transforms = (transforms,)
-
-    trans_names = [
-        getattr(transform, "name", "transformed") if transform is not None else "None"
-        for transform in transforms
-    ]
-    rv_op_type = type(rv_op)
-    rv_type_name = rv_op_type.__name__
-    cls_dict = rv_op_type.__dict__.copy()
-    rv_name = cls_dict.get("name", "")
-    if rv_name:
-        cls_dict["name"] = f"{rv_name}_{'_'.join(trans_names)}"
-    cls_dict["transforms"] = transforms
-
-    if cls_dict_extra is not None:
-        cls_dict.update(cls_dict_extra)
-
-    new_op_type = type(f"Transformed{rv_type_name}", (rv_op_type,), cls_dict)
-
-    MeasurableVariable.register(new_op_type)
-
-    @_logprob.register(new_op_type)
-    def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
-        """Compute the log-likelihood graph for a `TransformedRV`.
-
-        We assume that the value variable was back-transformed to be on the natural
-        support of the respective random variable.
-        """
-        logprobs = _logprob(rv_op, values, *inputs, **kwargs)
-
-        if not isinstance(logprobs, Sequence):
-            logprobs = [logprobs]
-
-        # Handle jacobian
-        assert len(values) == len(logprobs) == len(op.transforms)
-        logprobs_jac = []
-        for value, transform, logp in zip(values, op.transforms, logprobs):
-            if transform is None:
-                logprobs_jac.append(logp)
-                continue
-
-            assert isinstance(value.owner.op, TransformedVariable)
-            original_forward_value = value.owner.inputs[1]
-            log_jac_det = transform.log_jac_det(original_forward_value, *inputs).copy()
-            # The jacobian determinant has less dims than the logp
-            # when a multivariate transform (like Simplex or Ordered) is applied to univariate distributions.
-            # In this case we have to reduce the last logp dimensions, as they are no longer independent
-            if log_jac_det.ndim < logp.ndim:
-                diff_ndims = logp.ndim - log_jac_det.ndim
-                logp = logp.sum(axis=np.arange(-diff_ndims, 0))
-            # This case is sometimes, but not always, trivial to accomodate depending on the "space rank" of the
-            # multivariate distribution. See https://proceedings.mlr.press/v130/radul21a.html
-            elif log_jac_det.ndim > logp.ndim:
-                raise NotImplementedError(
-                    f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
-                )
-            else:
-                # Check there is no broadcasting between logp and jacobian
-                if logp.type.broadcastable != log_jac_det.type.broadcastable:
-                    raise ValueError(
-                        f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
-                        "There is a bug in the implementation of either one."
-                    )
-
-            if use_jacobian:
-                if value.name:
-                    log_jac_det.name = f"{value.name}_jacobian"
-                logprobs_jac.append(logp + log_jac_det)
-            else:
-                # We still want to use the reduced logp, even though the jacobian isn't included
-                logprobs_jac.append(logp)
+@node_rewriter([TransformedValue])
+def remove_TransformedValues(fgraph, node):
+    return [node.inputs[0]]
 
-        return logprobs_jac
 
-    new_op = copy(rv_op)
-    new_op.__class__ = new_op_type
+@node_rewriter([TransformedValueRV])
+def remove_TransformedValueRVs(fgraph, node):
+    return node.inputs
 
-    return new_op
 
-
-@node_rewriter([TransformedVariable])
-def remove_TransformedVariables(fgraph, node):
-    if isinstance(node.op, TransformedVariable):
-        return [node.inputs[0]]
+cleanup_ir_rewrites_db.register(
+    "remove_TransformedValues",
+    remove_TransformedValues,
+    "cleanup",
+    "transform",
+)
 
 
 cleanup_ir_rewrites_db.register(
-    "remove_TransformedVariables",
-    remove_TransformedVariables,
+    "remove_TransformedValueRVs",
+    remove_TransformedValueRVs,
     "cleanup",
     "transform",
 )
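
For reference, here is a minimal numeric sketch (NumPy/SciPy only, with hypothetical example values; this is not the pymc API) of the change-of-variables correction that the new `transformed_value_logprob` dispatch builds symbolically: the logp of the unconstrained value is the base logp evaluated at the back-transformed value, plus the transform's `log_jac_det` when `use_jacobian=True`.

```python
# Sketch of the jacobian correction applied by `transformed_value_logprob`,
# using a log transform on an Exponential variable as a stand-in example.
import numpy as np
from scipy import stats

rate = 1.5
y = 0.3  # unconstrained value, y = log(x) for x ~ Exponential(rate)

# Backward transform and its log-jacobian-determinant for the log transform
x = np.exp(y)       # backward(y)
log_jac_det = y     # log|d exp(y)/dy| = y

# use_jacobian=True: base logp at the back-transformed value, plus log_jac_det
logp_with_jacobian = stats.expon(scale=1 / rate).logpdf(x) + log_jac_det
# use_jacobian=False: the same base logp, without the correction
logp_without_jacobian = stats.expon(scale=1 / rate).logpdf(x)

# Sanity check against the analytically transformed density of Y = log(X):
# log p_Y(y) = log(rate) + y - rate * exp(y)
np.testing.assert_allclose(logp_with_jacobian, np.log(rate) + y - rate * np.exp(y))
```

The `use_jacobian=False` branch in the diff corresponds to returning only the (possibly dimension-reduced) base logp, which is what samplers that already account for the transform expect.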