@@ -5,21 +5,20 @@
 from pytensor.graph import ancestors
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
-from pytensor.scalar import integer_types
-from pytensor.tensor import NoneConst
+from pytensor.tensor import NoneConst, TensorVariable
 from pytensor.tensor.basic import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import broadcast_params
-from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
+from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor,
     AdvancedSubtensor1,
     Subtensor,
-    as_index_variable,
     get_idx_list,
 )
+from pytensor.tensor.type import integer_dtypes
 from pytensor.tensor.type_other import NoneTypeT, SliceType
 
 
@@ -127,22 +126,23 @@ def local_dimshuffle_rv_lift(fgraph, node):
 
     ds_op = node.op
 
-    if not isinstance(ds_op, DimShuffle):
+    # Dimshuffle which drop dimensions not supported yet
+    if ds_op.drop:
         return False
 
-    base_rv = node.inputs[0]
-    rv_node = base_rv.owner
+    rv_node = node.inputs[0].owner
 
     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False
 
-    # Dimshuffle which drop dimensions not supported yet
-    if ds_op.drop:
-        return False
-
     rv_op = rv_node.op
     rng, size, *dist_params = rv_node.inputs
-    rv = rv_node.default_output()
+    next_rng, rv = rv_node.outputs
+
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False
 
     # Check that Dimshuffle does not affect support dims
     supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
@@ -153,31 +153,24 @@ def local_dimshuffle_rv_lift(fgraph, node):
 
     # If no one else is using the underlying RandomVariable, then we can
     # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
+    if is_rv_used_in_graph(rv, node, fgraph):
         return False
 
     batched_dims = rv.ndim - rv_op.ndim_supp
     batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
 
     if isinstance(size.type, NoneTypeT):
-        # Make size explicit
-        shape = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
-        size = shape[:batched_dims]
-
-    # Update the size to reflect the DimShuffled dimensions
-    new_size = [
-        constant(1, dtype="int64") if o == "x" else size[o]
-        for o in batched_dims_ds_order
-    ]
+        new_size = size
+    else:
+        # Update the size to reflect the DimShuffled dimensions
+        new_size = [
+            constant(1, dtype="int64") if o == "x" else size[o]
+            for o in batched_dims_ds_order
+        ]
 
     # Updates the params to reflect the Dimshuffled dimensions
     new_dist_params = []
     for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
-        # Add broadcastable dimensions to the parameters that would have been expanded by the size
-        padleft = batched_dims - (param.ndim - param_ndim_supp)
-        if padleft > 0:
-            param = shape_padleft(param, padleft)
-
         # Add the parameter support dimension indexes to the batched dimensions Dimshuffle
         param_new_order = batched_dims_ds_order + tuple(
             range(batched_dims, batched_dims + param_ndim_supp)
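Note: the new `else` branch above remaps an explicit `size` through the DimShuffle's batch-dim order, with each `"x"` entry becoming a broadcastable length-1 entry. A minimal sketch of that remapping (not from this commit), using plain Python ints in place of the symbolic `size` entries and `constant(1, dtype="int64")`:

size = (2, 3)                        # explicit size of the original RV
batched_dims_ds_order = ("x", 1, 0)  # e.g. from rv.dimshuffle("x", 1, 0)
new_size = [1 if o == "x" else size[o] for o in batched_dims_ds_order]
assert new_size == [1, 3, 2]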
@@ -189,10 +182,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
     if config.compute_test_value != "off":
         compute_test_value(new_node)
 
-    out = new_node.outputs[1]
-    if base_rv.name:
-        out.name = f"{base_rv.name}_lifted"
-    return [out]
+    new_rv = new_node.default_output()
+    if rv.name:
+        new_rv.name = f"{rv.name}_lifted"
+    return [new_rv]
 
 
 @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
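For intuition, the DimShuffle lift relies on the fact that transposing (or otherwise dim-shuffling) the draws of a `RandomVariable` is distribution-equivalent to dim-shuffling its parameters and size, which lets the `RandomVariable` become the outermost node again. A minimal sketch of that equivalence only, not of the rewrite machinery; the import path `pytensor.tensor.random.basic` and the variable names are assumptions for illustration:

import numpy as np
import pytensor
from pytensor.tensor.random.basic import normal

mu = np.arange(6).reshape(2, 3)

# "Before": a DimShuffle (here a transpose) applied to the RV's draws
before = normal(mu, 1.0, size=(2, 3)).T
# "After": the DimShuffle pushed into the parameter and the size
after = normal(mu.T, 1.0, size=(3, 2))

assert pytensor.function([], before)().shape == (3, 2)
assert pytensor.function([], after)().shape == (3, 2)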
@@ -206,47 +199,38 @@ def local_subtensor_rv_lift(fgraph, node):
     ``mvnormal(mu, cov, size=(2,))[0, 0]``.
     """
 
-    def is_nd_advanced_idx(idx, dtype):
+    def is_nd_advanced_idx(idx, dtype) -> bool:
+        if not isinstance(idx, TensorVariable):
+            return False
         if isinstance(dtype, str):
             return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1)
         else:
             return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1)
 
     subtensor_op = node.op
 
-    old_subtensor = node.outputs[0]
-    rv = node.inputs[0]
-    rv_node = rv.owner
+    [indexed_rv] = node.outputs
+    rv_node = node.inputs[0].owner
 
     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False
 
-    shape_feature = getattr(fgraph, "shape_feature", None)
-    if not shape_feature:
-        return None
-
-    # Use shape_feature to facilitate inferring final shape.
-    # Check that neither the RV nor the old Subtensor are in the shape graph.
-    output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None)
-    if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)):
-        return None
-
     rv_op = rv_node.op
     rng, size, *dist_params = rv_node.inputs
+    rv = rv_node.default_output()
+
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False
 
     # Parse indices
-    idx_list = getattr(subtensor_op, "idx_list", None)
-    if idx_list:
-        idx_vars = get_idx_list(node.inputs, idx_list)
-    else:
-        idx_vars = node.inputs[1:]
-    indices = tuple(as_index_variable(idx) for idx in idx_vars)
+    indices = get_idx_list(node.inputs, getattr(subtensor_op, "idx_list", None))
 
     # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
     # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
     # If we wanted to support that we could rewrite it as subtensor + dimshuffle
     # and make use of the dimshuffle lift rewrite
-    integer_dtypes = {type.dtype for type in integer_types}
     if any(
         is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
         for idx in indices
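As the comment above says, the rewrite bails out on n-dimensional integer (advanced) indices and `np.newaxis`, since those can duplicate or broadcast draws; basic slices and scalar integers remain lift candidates. A small illustrative sketch (not part of this diff; the import path `pytensor.tensor.random.basic` is an assumption):

from pytensor.tensor.random.basic import normal
from pytensor.tensor.subtensor import Subtensor

x = normal(0, 1, size=(5, 3))

x[[0, 0, 2]]       # advanced integer indexing can repeat rows: not lifted
basic = x[1:3, 0]  # basic slice/integer indexing: eligible for lifting
assert isinstance(basic.owner.op, Subtensor)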
@@ -277,34 +261,35 @@ def is_nd_advanced_idx(idx, dtype):
         n_discarded_idxs = len(supp_indices)
         indices = indices[:-n_discarded_idxs]
 
-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(rv, node, fgraph):
-        return False
-
     # Update the size to reflect the indexed dimensions
-    new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]
+    if isinstance(size.type, NoneTypeT):
+        new_size = size
+    else:
+        shape_feature = getattr(fgraph, "shape_feature", None)
+        if not shape_feature:
+            return None
+
+        # Use shape_feature to facilitate inferring final shape.
+        # Check that neither the RV nor the old Subtensor are in the shape graph.
+        output_shape = fgraph.shape_feature.shape_of.get(indexed_rv, None)
+        if output_shape is None or {indexed_rv, rv} & set(ancestors(output_shape)):
+            return None
+
+        new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]
 
     # Propagate indexing to the parameters' batch dims.
     # We try to avoid broadcasting the parameters together (and with size), by only indexing
     # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size
     # should still correctly broadcast any degenerate parameter dims.
     new_dist_params = []
     for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
-        # We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
-        batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
-        batch_param = (
-            shape_padleft(param, batch_param_dims_missing)
-            if batch_param_dims_missing
-            else param
-        )
-        # Check which dims are actually broadcasted
-        bcast_batch_param_dims = tuple(
+        # Check which dims are broadcasted by either size or other parameters
+        bcast_param_dims = tuple(
             dim
-            for dim, (param_dim, output_dim) in enumerate(
-                zip(batch_param.type.shape, rv.type.shape)
+            for dim, (param_dim_bcast, output_dim_bcast) in enumerate(
+                zip(param.type.broadcastable, rv.type.broadcastable)
             )
-            if (param_dim == 1) and (output_dim != 1)
+            if param_dim_bcast and not output_dim_bcast
         )
         batch_indices = []
         curr_dim = 0
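The new `bcast_param_dims` computation reads broadcastable-only batch dims straight off the static types instead of padding the parameter first. A minimal standalone sketch of that comparison (not from this diff; the names and the import path are illustrative assumptions):

import pytensor.tensor as pt
from pytensor.tensor.random.basic import normal

mu = pt.zeros((1, 3))             # batch dim 0 is degenerate (length 1)
rv = normal(mu, 1.0, size=(5, 3)) # the draws broadcast mu along dim 0

bcast_param_dims = tuple(
    dim
    for dim, (param_bcast, out_bcast) in enumerate(
        zip(mu.type.broadcastable, rv.type.broadcastable)
    )
    if param_bcast and not out_bcast
)
assert bcast_param_dims == (0,)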
@@ -315,23 +300,23 @@ def is_nd_advanced_idx(idx, dtype):
                 # If not, we use that directly, instead of the more inefficient `nonzero` form
                 bool_dims = range(curr_dim, curr_dim + idx.type.ndim)
                 # There's an overlap, we have to decompose the boolean mask as a `nonzero`
-                if set(bool_dims) & set(bcast_batch_param_dims):
+                if set(bool_dims) & set(bcast_param_dims):
                     int_indices = list(idx.nonzero())
                     # Indexing by 0 drops the degenerate dims
                     for bool_dim in bool_dims:
-                        if bool_dim in bcast_batch_param_dims:
+                        if bool_dim in bcast_param_dims:
                             int_indices[bool_dim - curr_dim] = 0
                     batch_indices.extend(int_indices)
-                # No overlap, use index as is
+                # No overlap, use boolean index as is
                 else:
                     batch_indices.append(idx)
                 curr_dim += len(bool_dims)
             # Basic-indexing (slice or integer)
             else:
                 # Broadcasted dim
-                if curr_dim in bcast_batch_param_dims:
+                if curr_dim in bcast_param_dims:
                     # Slice indexing, keep degenerate dim by none-slicing
-                    if isinstance(idx.type, SliceType):
+                    if isinstance(idx, slice) or isinstance(idx.type, SliceType):
                         batch_indices.append(slice(None))
                     # Integer indexing, drop degenerate dim by 0-indexing
                     else:
@@ -342,7 +327,7 @@ def is_nd_advanced_idx(idx, dtype):
                     batch_indices.append(idx)
                 curr_dim += 1
 
-        new_dist_params.append(batch_param[tuple(batch_indices)])
+        new_dist_params.append(param[tuple(batch_indices)])
 
     # Create new RV
     new_node = rv_op.make_node(rng, new_size, *new_dist_params)
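For intuition, indexing a batch dimension of the draws is distribution-equivalent to indexing the batched parameter and shrinking the size, which is the shape of graph this rewrite produces; support dimensions (the trailing length-3 axis of the multivariate normal below) are never indexed away. A minimal sketch of the equivalence only, not of the rewrite machinery; the import path `pytensor.tensor.random.basic` is an assumption:

import numpy as np
import pytensor
from pytensor.tensor.random.basic import multivariate_normal

mu = np.zeros((5, 3))
cov = np.eye(3)

# "Before": index a batch dimension of the RV's draws
before = multivariate_normal(mu, cov, size=(5,))[0]
# "After": push the index onto the batched parameter and drop the size
after = multivariate_normal(mu[0], cov)

assert pytensor.function([], before)().shape == (3,)
assert pytensor.function([], after)().shape == (3,)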