Skip to content

Commit 131c829

Browse files
michaelosthegericardoV94twiecki
committed
Make change_rv_size more robust
Co-authored-by: Ricardo <[email protected]> Co-authored-by: Thomas Wiecki <[email protected]>
1 parent 738c9de commit 131c829

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

pymc3/aesaraf.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@
4646
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
4747
from aesara.tensor.elemwise import Elemwise
4848
from aesara.tensor.random.op import RandomVariable
49+
from aesara.tensor.shape import SpecifyShape
4950
from aesara.tensor.sharedvar import SharedVariable
5051
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
5152
from aesara.tensor.var import TensorVariable
5253

54+
from pymc3.exceptions import ShapeError
5355
from pymc3.vartypes import continuous_types, int_types, isgenerator, typefilter
5456

5557
PotentialShapeType = Union[
@@ -153,6 +155,16 @@ def change_rv_size(
153155
Expand the existing size by `new_size`.
154156
155157
"""
158+
# Check the dimensionality of the `new_size` kwarg
159+
new_size_ndim = np.ndim(new_size)
160+
if new_size_ndim > 1:
161+
raise ShapeError("The `new_size` must be ≤1-dimensional.", actual=new_size_ndim)
162+
elif new_size_ndim == 0:
163+
new_size = (new_size,)
164+
165+
# Extract the RV node that is to be resized, together with its inputs, name and tag
166+
if isinstance(rv_var.owner.op, SpecifyShape):
167+
rv_var = rv_var.owner.inputs[0]
156168
rv_node = rv_var.owner
157169
rng, size, dtype, *dist_params = rv_node.inputs
158170
name = rv_var.name
@@ -161,10 +173,10 @@ def change_rv_size(
161173
if expand:
162174
if rv_node.op.ndim_supp == 0 and at.get_vector_length(size) == 0:
163175
size = rv_node.op._infer_shape(size, dist_params)
164-
new_size = tuple(at.atleast_1d(new_size)) + tuple(size)
176+
new_size = tuple(new_size) + tuple(size)
165177

166-
# Make sure the new size is a tensor. This helps to not unnecessarily pick
167-
# up a `Cast` in some cases
178+
# Make sure the new size is a tensor. This dtype-aware conversion helps
179+
# to not unnecessarily pick up a `Cast` in some cases (see #4652).
168180
new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
169181

170182
new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)

pymc3/tests/test_aesaraf.py

+6
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
take_along_axis,
4242
walk_model,
4343
)
44+
from pymc3.exceptions import ShapeError
4445
from pymc3.vartypes import int_types
4546

4647
FLOATX = str(aesara.config.floatX)
@@ -53,6 +54,11 @@ def test_change_rv_size():
5354
assert rv.ndim == 1
5455
assert tuple(rv.shape.eval()) == (2,)
5556

57+
with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
58+
change_rv_size(rv, new_size=[[2, 3]])
59+
with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
60+
change_rv_size(rv, new_size=at.as_tensor_variable([[2, 3], [4, 5]]))
61+
5662
rv_new = change_rv_size(rv, new_size=(3,), expand=True)
5763
assert rv_new.ndim == 2
5864
assert tuple(rv_new.shape.eval()) == (3, 2)

pymc3/tests/test_distributions_random.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,10 @@ def get_random_variable(self, shape, with_vector_params=False, name=None):
187187

188188
@staticmethod
189189
def sample_random_variable(random_variable, size):
190-
"""Draws samples from a RandomVariable using its .random() method."""
191-
if size is None:
192-
return random_variable.eval()
193-
else:
194-
return change_rv_size(random_variable, size, expand=True).eval()
190+
""" Draws samples from a RandomVariable. """
191+
if size:
192+
random_variable = change_rv_size(random_variable, size, expand=True)
193+
return random_variable.eval()
195194

196195
@pytest.mark.parametrize("size", [None, (), 1, (1,), 5, (4, 5)], ids=str)
197196
@pytest.mark.parametrize("shape", [None, ()], ids=str)

0 commit comments

Comments
 (0)