4646from aesara .sandbox .rng_mrg import MRG_RandomStream as RandomStream
4747from aesara .tensor .elemwise import Elemwise
4848from aesara .tensor .random .op import RandomVariable
49+ from aesara .tensor .shape import SpecifyShape
4950from aesara .tensor .sharedvar import SharedVariable
5051from aesara .tensor .subtensor import AdvancedIncSubtensor , AdvancedIncSubtensor1
5152from aesara .tensor .var import TensorVariable
5253
54+ from pymc3 .exceptions import ShapeError
5355from pymc3 .vartypes import continuous_types , int_types , isgenerator , typefilter
5456
5557PotentialShapeType = Union [
@@ -153,6 +155,16 @@ def change_rv_size(
153155 Expand the existing size by `new_size`.
154156
155157 """
158+ # Check the dimensionality of the `new_size` kwarg
159+ new_size_ndim = np .ndim (new_size )
160+ if new_size_ndim > 1 :
161+ raise ShapeError ("The `new_size` must be ≤1-dimensional." , actual = new_size_ndim )
162+ elif new_size_ndim == 0 :
163+ new_size = (new_size ,)
164+
165+ # Extract the RV node that is to be resized, together with its inputs, name and tag
166+ if isinstance (rv_var .owner .op , SpecifyShape ):
167+ rv_var = rv_var .owner .inputs [0 ]
156168 rv_node = rv_var .owner
157169 rng , size , dtype , * dist_params = rv_node .inputs
158170 name = rv_var .name
@@ -161,10 +173,10 @@ def change_rv_size(
161173 if expand :
162174 if rv_node .op .ndim_supp == 0 and at .get_vector_length (size ) == 0 :
163175 size = rv_node .op ._infer_shape (size , dist_params )
164- new_size = tuple (at . atleast_1d ( new_size ) ) + tuple (size )
176+ new_size = tuple (new_size ) + tuple (size )
165177
166- # Make sure the new size is a tensor. This helps to not unnecessarily pick
167- # up a `Cast` in some cases
178+ # Make sure the new size is a tensor. This dtype-aware conversion helps
179+ # to not unnecessarily pick up a `Cast` in some cases (see #4652).
168180 new_size = at .as_tensor (new_size , ndim = 1 , dtype = "int64" )
169181
170182 new_rv_node = rv_node .op .make_node (rng , new_size , dtype , * dist_params )
0 commit comments