Skip to content

Commit 1446732

Browse files
michaelosthegetwiecki
authored andcommitted
Check ndim of new_size kwarg in change_rv_size
And don't use np.atleast_1d on things that can be tensors.
1 parent faed5f1 commit 1446732

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

pymc3/aesaraf.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
5151
from aesara.tensor.var import TensorVariable
5252

53+
from pymc3.exceptions import ShapeError
5354
from pymc3.vartypes import continuous_types, int_types, isgenerator, typefilter
5455

5556
PotentialShapeType = Union[
@@ -147,6 +148,10 @@ def change_rv_size(
147148
Expand the existing size by `new_size`.
148149
149150
"""
151+
new_size_ndim = new_size.ndim if isinstance(new_size, Variable) else np.ndim(new_size)
152+
if new_size_ndim > 1:
153+
raise ShapeError("The `new_size` must be ≤1-dimensional.", actual=new_size_ndim)
154+
new_size = at.as_tensor_variable(new_size, ndim=1)
150155
if isinstance(rv_var.owner.op, SpecifyShape):
151156
rv_var = rv_var.owner.inputs[0]
152157
rv_node = rv_var.owner
@@ -157,7 +162,7 @@ def change_rv_size(
157162
if expand:
158163
if rv_node.op.ndim_supp == 0 and at.get_vector_length(size) == 0:
159164
size = rv_node.op._infer_shape(size, dist_params)
160-
new_size = tuple(np.atleast_1d(new_size)) + tuple(size)
165+
new_size = tuple(new_size) + tuple(size)
161166

162167
new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
163168
rv_var = new_rv_node.outputs[-1]

pymc3/tests/test_aesaraf.py

+6
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
take_along_axis,
4242
walk_model,
4343
)
44+
from pymc3.exceptions import ShapeError
4445
from pymc3.vartypes import int_types
4546

4647
FLOATX = str(aesara.config.floatX)
@@ -53,6 +54,11 @@ def test_change_rv_size():
5354
assert rv.ndim == 1
5455
assert rv.eval().shape == (2,)
5556

57+
with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
58+
change_rv_size(rv, new_size=[[2, 3]])
59+
with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
60+
change_rv_size(rv, new_size=at.as_tensor_variable([[2, 3], [4, 5]]))
61+
5662
rv_new = change_rv_size(rv, new_size=(3,), expand=True)
5763
assert rv_new.ndim == 2
5864
assert rv_new.eval().shape == (3, 2)

0 commit comments

Comments
 (0)