46
46
from aesara .sandbox .rng_mrg import MRG_RandomStream as RandomStream
47
47
from aesara .tensor .elemwise import Elemwise
48
48
from aesara .tensor .random .op import RandomVariable
49
+ from aesara .tensor .shape import SpecifyShape
49
50
from aesara .tensor .sharedvar import SharedVariable
50
51
from aesara .tensor .subtensor import AdvancedIncSubtensor , AdvancedIncSubtensor1
51
52
from aesara .tensor .var import TensorVariable
52
53
54
+ from pymc3 .exceptions import ShapeError
53
55
from pymc3 .vartypes import continuous_types , int_types , isgenerator , typefilter
54
56
55
57
PotentialShapeType = Union [
@@ -153,6 +155,16 @@ def change_rv_size(
153
155
Expand the existing size by `new_size`.
154
156
155
157
"""
158
+ # Check the dimensionality of the `new_size` kwarg
159
+ new_size_ndim = np .ndim (new_size )
160
+ if new_size_ndim > 1 :
161
+ raise ShapeError ("The `new_size` must be ≤1-dimensional." , actual = new_size_ndim )
162
+ elif new_size_ndim == 0 :
163
+ new_size = (new_size ,)
164
+
165
+ # Extract the RV node that is to be resized, together with its inputs, name and tag
166
+ if isinstance (rv_var .owner .op , SpecifyShape ):
167
+ rv_var = rv_var .owner .inputs [0 ]
156
168
rv_node = rv_var .owner
157
169
rng , size , dtype , * dist_params = rv_node .inputs
158
170
name = rv_var .name
@@ -161,10 +173,10 @@ def change_rv_size(
161
173
if expand :
162
174
if rv_node .op .ndim_supp == 0 and at .get_vector_length (size ) == 0 :
163
175
size = rv_node .op ._infer_shape (size , dist_params )
164
- new_size = tuple (at . atleast_1d ( new_size ) ) + tuple (size )
176
+ new_size = tuple (new_size ) + tuple (size )
165
177
166
- # Make sure the new size is a tensor. This helps to not unnecessarily pick
167
- # up a `Cast` in some cases
178
+ # Make sure the new size is a tensor. This dtype-aware conversion helps
179
+ # to not unnecessarily pick up a `Cast` in some cases (see #4652).
168
180
new_size = at .as_tensor (new_size , ndim = 1 , dtype = "int64" )
169
181
170
182
new_rv_node = rv_node .op .make_node (rng , new_size , dtype , * dist_params )
0 commit comments