Commit 2d81cca

Simplify RV rewrites
1 parent 94e9ef0 commit 2d81cca

File tree (1 file changed, +61 -76 lines):
  • pytensor/tensor/random/rewriting

pytensor/tensor/random/rewriting/basic.py

+61 -76
@@ -5,21 +5,20 @@
 from pytensor.graph import ancestors
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
-from pytensor.scalar import integer_types
-from pytensor.tensor import NoneConst
+from pytensor.tensor import NoneConst, TensorVariable
 from pytensor.tensor.basic import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import broadcast_params
-from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
+from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor,
     AdvancedSubtensor1,
     Subtensor,
-    as_index_variable,
     get_idx_list,
 )
+from pytensor.tensor.type import integer_dtypes
 from pytensor.tensor.type_other import NoneTypeT, SliceType
 
 
@@ -127,22 +126,23 @@ def local_dimshuffle_rv_lift(fgraph, node):
 
     ds_op = node.op
 
-    if not isinstance(ds_op, DimShuffle):
+    # Dimshuffle which drop dimensions not supported yet
+    if ds_op.drop:
         return False
 
-    base_rv = node.inputs[0]
-    rv_node = base_rv.owner
+    rv_node = node.inputs[0].owner
 
     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False
 
-    # Dimshuffle which drop dimensions not supported yet
-    if ds_op.drop:
-        return False
-
     rv_op = rv_node.op
     rng, size, *dist_params = rv_node.inputs
-    rv = rv_node.default_output()
+    next_rng, rv = rv_node.outputs
+
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False
 
     # Check that Dimshuffle does not affect support dims
     supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
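For orientation, here is a minimal sketch of the graph pattern this rewrite matches: a DimShuffle applied to the draws of a RandomVariable. It is not part of the commit and only assumes the public pt.random API and TensorVariable.dimshuffle.

# Illustrative sketch, not from the commit: the pattern local_dimshuffle_rv_lift matches.
import pytensor.tensor as pt
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.op import RandomVariable

x = pt.random.normal(0, 1, size=(3, 2))  # RandomVariable node; outputs are (next_rng, draws)
y = x.dimshuffle(1, 0)                   # a DimShuffle sitting on top of the draws

assert isinstance(y.owner.op, DimShuffle)
assert isinstance(y.owner.inputs[0].owner.op, RandomVariable)
# The rewrite moves the transpose into the RV's size and parameters, so downstream
# rewrites and samplers see the RandomVariable directly rather than the DimShuffle.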
@@ -153,31 +153,24 @@ def local_dimshuffle_rv_lift(fgraph, node):
 
     # If no one else is using the underlying RandomVariable, then we can
     # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
+    if is_rv_used_in_graph(rv, node, fgraph):
         return False
 
     batched_dims = rv.ndim - rv_op.ndim_supp
     batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
 
     if isinstance(size.type, NoneTypeT):
-        # Make size explicit
-        shape = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
-        size = shape[:batched_dims]
-
-    # Update the size to reflect the DimShuffled dimensions
-    new_size = [
-        constant(1, dtype="int64") if o == "x" else size[o]
-        for o in batched_dims_ds_order
-    ]
+        new_size = size
+    else:
+        # Update the size to reflect the DimShuffled dimensions
+        new_size = [
+            constant(1, dtype="int64") if o == "x" else size[o]
+            for o in batched_dims_ds_order
+        ]
 
     # Updates the params to reflect the Dimshuffled dimensions
     new_dist_params = []
     for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
-        # Add broadcastable dimensions to the parameters that would have been expanded by the size
-        padleft = batched_dims - (param.ndim - param_ndim_supp)
-        if padleft > 0:
-            param = shape_padleft(param, padleft)
-
         # Add the parameter support dimension indexes to the batched dimensions Dimshuffle
         param_new_order = batched_dims_ds_order + tuple(
             range(batched_dims, batched_dims + param_ndim_supp)
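To make the new size computation above concrete, here is a plain-Python sketch with illustrative values; none of these numbers come from the commit.

# Illustrative values only: how new_size follows the batch part of the DimShuffle order.
size = (3, 2)                        # explicit batch size of the original RV
batched_dims_ds_order = (1, "x", 0)  # swap the two batch dims and insert a new axis
new_size = [1 if o == "x" else size[o] for o in batched_dims_ds_order]
assert new_size == [2, 1, 3]         # "x" entries become length-1 dims (constant(1) in the rewrite)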
@@ -189,10 +182,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
     if config.compute_test_value != "off":
         compute_test_value(new_node)
 
-    out = new_node.outputs[1]
-    if base_rv.name:
-        out.name = f"{base_rv.name}_lifted"
-    return [out]
+    new_rv = new_node.default_output()
+    if rv.name:
+        new_rv.name = f"{rv.name}_lifted"
+    return [new_rv]
 
 
 @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
@@ -206,47 +199,38 @@ def local_subtensor_rv_lift(fgraph, node):
     ``mvnormal(mu, cov, size=(2,))[0, 0]``.
     """
 
-    def is_nd_advanced_idx(idx, dtype):
+    def is_nd_advanced_idx(idx, dtype) -> bool:
+        if not isinstance(idx, TensorVariable):
+            return False
         if isinstance(dtype, str):
             return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1)
         else:
             return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1)
 
     subtensor_op = node.op
 
-    old_subtensor = node.outputs[0]
-    rv = node.inputs[0]
-    rv_node = rv.owner
+    [indexed_rv] = node.outputs
+    rv_node = node.inputs[0].owner
 
     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False
 
-    shape_feature = getattr(fgraph, "shape_feature", None)
-    if not shape_feature:
-        return None
-
-    # Use shape_feature to facilitate inferring final shape.
-    # Check that neither the RV nor the old Subtensor are in the shape graph.
-    output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None)
-    if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)):
-        return None
-
     rv_op = rv_node.op
     rng, size, *dist_params = rv_node.inputs
+    rv = rv_node.default_output()
+
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False
 
     # Parse indices
-    idx_list = getattr(subtensor_op, "idx_list", None)
-    if idx_list:
-        idx_vars = get_idx_list(node.inputs, idx_list)
-    else:
-        idx_vars = node.inputs[1:]
-    indices = tuple(as_index_variable(idx) for idx in idx_vars)
+    indices = get_idx_list(node.inputs, getattr(subtensor_op, "idx_list", None))
 
     # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
     # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
     # If we wanted to support that we could rewrite it as subtensor + dimshuffle
     # and make use of the dimshuffle lift rewrite
-    integer_dtypes = {type.dtype for type in integer_types}
     if any(
         is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
         for idx in indices
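A NumPy sketch (not part of the commit) of why vector integer indices and np.newaxis are excluded: advanced indexing can repeat draws, and lifting the index into the parameters would replace those exact duplicates with independent samples.

import numpy as np

draws = np.random.default_rng(0).normal(size=(3, 2))
dup = draws[np.array([0, 0])]            # advanced integer index repeats the same row
assert np.array_equal(dup[0], dup[1])    # exact duplicates of one draw
# Lifting the index into the parameters would instead produce two independent rows,
# changing the distribution of the graph, so the rewrite bails out here.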
@@ -277,34 +261,35 @@ def is_nd_advanced_idx(idx, dtype):
         n_discarded_idxs = len(supp_indices)
         indices = indices[:-n_discarded_idxs]
 
-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(rv, node, fgraph):
-        return False
-
     # Update the size to reflect the indexed dimensions
-    new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]
+    if isinstance(size.type, NoneTypeT):
+        new_size = size
+    else:
+        shape_feature = getattr(fgraph, "shape_feature", None)
+        if not shape_feature:
+            return None
+
+        # Use shape_feature to facilitate inferring final shape.
+        # Check that neither the RV nor the old Subtensor are in the shape graph.
+        output_shape = fgraph.shape_feature.shape_of.get(indexed_rv, None)
+        if output_shape is None or {indexed_rv, rv} & set(ancestors(output_shape)):
+            return None
+
+        new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]
 
     # Propagate indexing to the parameters' batch dims.
     # We try to avoid broadcasting the parameters together (and with size), by only indexing
     # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size
     # should still correctly broadcast any degenerate parameter dims.
     new_dist_params = []
     for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
-        # We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
-        batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
-        batch_param = (
-            shape_padleft(param, batch_param_dims_missing)
-            if batch_param_dims_missing
-            else param
-        )
-        # Check which dims are actually broadcasted
-        bcast_batch_param_dims = tuple(
+        # Check which dims are broadcasted by either size or other parameters
+        bcast_param_dims = tuple(
             dim
-            for dim, (param_dim, output_dim) in enumerate(
-                zip(batch_param.type.shape, rv.type.shape)
+            for dim, (param_dim_bcast, output_dim_bcast) in enumerate(
+                zip(param.type.broadcastable, rv.type.broadcastable)
             )
-            if (param_dim == 1) and (output_dim != 1)
+            if param_dim_bcast and not output_dim_bcast
         )
         batch_indices = []
         curr_dim = 0
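A NumPy shape sketch (illustrative shapes, not from the commit) of the broadcast check above: a parameter dim of length 1 that the output broadcasts over must not receive the user's index directly.

import numpy as np

mu = np.zeros((1, 5))                        # parameter broadcastable (length 1) on dim 0
rv = np.random.default_rng(0).normal(mu, 1.0, size=(3, 5))
# Here bcast_param_dims would be (0,): mu is degenerate on dim 0 while the output is not,
# so an index into that dim is kept with a full slice or dropped by 0-indexing instead.
assert mu[0].shape == (5,)                   # 0-indexing drops the degenerate dim
assert np.random.default_rng(0).normal(mu[0], 1.0, size=(5,)).shape == rv[2].shape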
@@ -315,23 +300,23 @@ def is_nd_advanced_idx(idx, dtype):
                # If not, we use that directly, instead of the more inefficient `nonzero` form
                bool_dims = range(curr_dim, curr_dim + idx.type.ndim)
                # There's an overlap, we have to decompose the boolean mask as a `nonzero`
-                if set(bool_dims) & set(bcast_batch_param_dims):
+                if set(bool_dims) & set(bcast_param_dims):
                    int_indices = list(idx.nonzero())
                    # Indexing by 0 drops the degenerate dims
                    for bool_dim in bool_dims:
-                        if bool_dim in bcast_batch_param_dims:
+                        if bool_dim in bcast_param_dims:
                            int_indices[bool_dim - curr_dim] = 0
                    batch_indices.extend(int_indices)
-                # No overlap, use index as is
+                # No overlap, use boolean index as is
                else:
                    batch_indices.append(idx)
                curr_dim += len(bool_dims)
            # Basic-indexing (slice or integer)
            else:
                # Broadcasted dim
-                if curr_dim in bcast_batch_param_dims:
+                if curr_dim in bcast_param_dims:
                    # Slice indexing, keep degenerate dim by none-slicing
-                    if isinstance(idx.type, SliceType):
+                    if isinstance(idx, slice) or isinstance(idx.type, SliceType):
                        batch_indices.append(slice(None))
                    # Integer indexing, drop degenerate dim by 0-indexing
                    else:
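A NumPy sketch (not from the commit) of the boolean branch above: decomposing a mask with nonzero gives integer indices per dim, and the entry for a degenerate (broadcast) dim can then be replaced by 0 to drop that dim.

import numpy as np

mask = np.array([[True, False, True]])           # boolean index spanning dims (0, 1)
x = np.arange(3)[None, :]                        # shape (1, 3); dim 0 is degenerate
rows, cols = mask.nonzero()                      # equivalent integer indices per dim
assert np.array_equal(x[mask], x[rows, cols])    # same selection, nonzero form
assert np.array_equal(x[mask], x[0, cols])       # 0 replaces the index on the degenerate dim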
@@ -342,7 +327,7 @@ def is_nd_advanced_idx(idx, dtype):
                    batch_indices.append(idx)
            curr_dim += 1
 
-        new_dist_params.append(batch_param[tuple(batch_indices)])
+        new_dist_params.append(param[tuple(batch_indices)])
 
     # Create new RV
     new_node = rv_op.make_node(rng, new_size, *new_dist_params)
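Finally, a NumPy shape sketch (illustrative, not from the commit) of what the lift amounts to: the index moves from the draws onto the parameters' batch dims, and the new size is the indexed output shape minus any support dims.

import numpy as np

mu = np.zeros((5, 3))                            # batch parameters, shape (5, 3)
rv = np.random.default_rng(0).normal(mu, 1.0)    # draws, shape (5, 3)
indexed = rv[1:]                                 # Subtensor over the batch dim -> (4, 3)

lifted = np.random.default_rng(0).normal(mu[1:], 1.0)  # index lifted into the parameter
assert lifted.shape == indexed.shape == (4, 3)
# The lifted graph is distributionally equivalent, not numerically identical:
# the RNG is consumed differently once the indexing happens before sampling.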
