50
50
from aesara .tensor .subtensor import AdvancedIncSubtensor , AdvancedIncSubtensor1
51
51
from aesara .tensor .var import TensorVariable
52
52
53
+ from pymc3 .exceptions import ShapeError
53
54
from pymc3 .vartypes import continuous_types , int_types , isgenerator , typefilter
54
55
55
56
PotentialShapeType = Union [
@@ -147,6 +148,10 @@ def change_rv_size(
147
148
Expand the existing size by `new_size`.
148
149
149
150
"""
151
+ new_size_ndim = new_size .ndim if isinstance (new_size , Variable ) else np .ndim (new_size )
152
+ if new_size_ndim > 1 :
153
+ raise ShapeError ("The `new_size` must be ≤1-dimensional." , actual = new_size_ndim )
154
+ new_size = at .as_tensor_variable (new_size , ndim = 1 )
150
155
if isinstance (rv_var .owner .op , SpecifyShape ):
151
156
rv_var = rv_var .owner .inputs [0 ]
152
157
rv_node = rv_var .owner
@@ -157,7 +162,7 @@ def change_rv_size(
157
162
if expand :
158
163
if rv_node .op .ndim_supp == 0 and at .get_vector_length (size ) == 0 :
159
164
size = rv_node .op ._infer_shape (size , dist_params )
160
- new_size = tuple (np . atleast_1d ( new_size ) ) + tuple (size )
165
+ new_size = tuple (new_size ) + tuple (size )
161
166
162
167
new_rv_node = rv_node .op .make_node (rng , new_size , dtype , * dist_params )
163
168
rv_var = new_rv_node .outputs [- 1 ]
0 commit comments