Skip to content

Commit 512491a

Browse files
committed
Merge branch 'fix-4662' into reintro_shape
2 parents 08754cd + 3f8ea24 commit 512491a

File tree

4 files changed

+216
-1
lines changed

4 files changed

+216
-1
lines changed

Diff for: pymc3/aesaraf.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
5050
from aesara.tensor.var import TensorVariable
5151

52+
from pymc3.exceptions import ShapeError
5253
from pymc3.vartypes import continuous_types, int_types, isgenerator, typefilter
5354

5455
PotentialShapeType = Union[
@@ -146,6 +147,12 @@ def change_rv_size(
146147
Expand the existing size by `new_size`.
147148
148149
"""
150+
new_size_ndim = new_size.ndim if isinstance(new_size, Variable) else np.ndim(new_size)
151+
if new_size_ndim > 1:
152+
raise ShapeError("The `new_size` must be ≤1-dimensional.", actual=new_size_ndim)
153+
new_size = at.as_tensor_variable(new_size, ndim=1)
154+
if isinstance(rv_var.owner.op, SpecifyShape):
155+
rv_var = rv_var.owner.inputs[0]
149156
rv_node = rv_var.owner
150157
rng, size, dtype, *dist_params = rv_node.inputs
151158
name = rv_var.name
@@ -154,7 +161,7 @@ def change_rv_size(
154161
if expand:
155162
if rv_node.op.ndim_supp == 0 and at.get_vector_length(size) == 0:
156163
size = rv_node.op._infer_shape(size, dist_params)
157-
new_size = tuple(np.atleast_1d(new_size)) + tuple(size)
164+
new_size = tuple(new_size) + tuple(size)
158165

159166
new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
160167
rv_var = new_rv_node.outputs[-1]

Diff for: pymc3/tests/test_aesaraf.py

+6
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
take_along_axis,
4242
walk_model,
4343
)
44+
from pymc3.exceptions import ShapeError
4445
from pymc3.vartypes import int_types
4546

4647
FLOATX = str(aesara.config.floatX)
@@ -53,6 +54,11 @@ def test_change_rv_size():
5354
assert rv.ndim == 1
5455
assert rv.eval().shape == (2,)
5556

57+
with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
58+
change_rv_size(rv, new_size=[[2, 3]])
59+
with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
60+
change_rv_size(rv, new_size=at.as_tensor_variable([[2, 3], [4, 5]]))
61+
5662
rv_new = change_rv_size(rv, new_size=(3,), expand=True)
5763
assert rv_new.ndim == 2
5864
assert rv_new.eval().shape == (3, 2)

Diff for: pymc3/tests/test_ode.py

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
16+
1517
import aesara
1618
import numpy as np
1719
import pytest
@@ -168,6 +170,9 @@ def ode_func_5(y, t, p):
168170
np.testing.assert_array_equal(np.ravel(model5_sens_ic), model5._sens_ic)
169171

170172

173+
@pytest.mark.xfail(
174+
condition=sys.platform == "win32", reason="See https://github.com/pymc-devs/pymc3/issues/4652."
175+
)
171176
def test_logp_scalar_ode():
172177
"""Test the computation of the log probability for these models"""
173178

Diff for: pymc3/tests/test_shape_handling.py

+197
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,200 @@ def test_sample_generate_values(fixture_model, fixture_sizes):
219219
prior = pm.sample_prior_predictive(samples=fixture_sizes)
220220
for rv in RVs:
221221
assert prior[rv.name].shape == size + tuple(rv.distribution.shape)
222+
223+
224+
class TestShapeDimsSize:
225+
@pytest.mark.parametrize("param_shape", [(), (3,)])
226+
@pytest.mark.parametrize("batch_shape", [(), (3,)])
227+
@pytest.mark.parametrize(
228+
"parametrization",
229+
[
230+
"implicit",
231+
"shape",
232+
"shape...",
233+
"dims",
234+
"dims...",
235+
"size",
236+
],
237+
)
238+
def test_param_and_batch_shape_combos(
239+
self, param_shape: tuple, batch_shape: tuple, parametrization: str
240+
):
241+
coords = {}
242+
param_dims = []
243+
batch_dims = []
244+
245+
# Create coordinates corresponding to the parameter shape
246+
for d in param_shape:
247+
dname = f"param_dim_{d}"
248+
coords[dname] = [f"c_{i}" for i in range(d)]
249+
param_dims.append(dname)
250+
assert len(param_dims) == len(param_shape)
251+
# Create coordinates corresponding to the batch shape
252+
for d in batch_shape:
253+
dname = f"batch_dim_{d}"
254+
coords[dname] = [f"c_{i}" for i in range(d)]
255+
batch_dims.append(dname)
256+
assert len(batch_dims) == len(batch_shape)
257+
258+
with pm.Model(coords=coords) as pmodel:
259+
mu = aesara.shared(np.random.normal(size=param_shape))
260+
261+
with pytest.warns(None):
262+
if parametrization == "implicit":
263+
rv = pm.Normal("rv", mu=mu).shape == param_shape
264+
else:
265+
if parametrization == "shape":
266+
rv = pm.Normal("rv", mu=mu, shape=batch_shape + param_shape)
267+
assert rv.eval().shape == batch_shape + param_shape
268+
elif parametrization == "shape...":
269+
rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
270+
assert rv.eval().shape == batch_shape + param_shape
271+
elif parametrization == "dims":
272+
rv = pm.Normal("rv", mu=mu, dims=batch_dims + param_dims)
273+
assert rv.eval().shape == batch_shape + param_shape
274+
elif parametrization == "dims...":
275+
rv = pm.Normal("rv", mu=mu, dims=(*batch_dims, ...))
276+
n_size = len(batch_shape)
277+
n_implied = len(param_shape)
278+
ndim = n_size + n_implied
279+
assert len(pmodel.RV_dims["rv"]) == ndim, pmodel.RV_dims
280+
assert len(pmodel.RV_dims["rv"][:n_size]) == len(batch_dims)
281+
assert len(pmodel.RV_dims["rv"][n_size:]) == len(param_dims)
282+
if n_implied > 0:
283+
assert pmodel.RV_dims["rv"][-1] is None
284+
elif parametrization == "size":
285+
rv = pm.Normal("rv", mu=mu, size=batch_shape)
286+
assert rv.eval().shape == batch_shape + param_shape
287+
else:
288+
raise NotImplementedError("Invalid test case parametrization.")
289+
290+
def test_define_dims_on_the_fly(self):
291+
with pm.Model() as pmodel:
292+
agedata = aesara.shared(np.array([10, 20, 30]))
293+
294+
# Associate the "patient" dim with an implied dimension
295+
age = pm.Normal("age", agedata, dims=("patient",))
296+
assert "patient" in pmodel.dim_lengths
297+
assert pmodel.dim_lengths["patient"].eval() == 3
298+
299+
# Use the dim to replicate a new RV
300+
effect = pm.Normal("effect", 0, dims=("patient",))
301+
assert effect.ndim == 1
302+
assert effect.eval().shape == (3,)
303+
304+
# Now change the length of the implied dimension
305+
agedata.set_value([1, 2, 3, 4])
306+
# The change should propagate all the way through
307+
assert effect.eval().shape == (4,)
308+
309+
@pytest.mark.xfail(reason="Simultaneous use of size and dims is not implemented")
310+
def test_data_defined_size_dimension_can_register_dimname(self):
311+
with pm.Model() as pmodel:
312+
x = pm.Data("x", [[1, 2, 3, 4]], dims=("first", "second"))
313+
assert "first" in pmodel.dim_lengths
314+
assert "second" in pmodel.dim_lengths
315+
# two dimensions are implied; a "third" dimension is created
316+
y = pm.Normal("y", mu=x, size=2, dims=("third", "first", "second"))
317+
assert "third" in pmodel.dim_lengths
318+
assert y.eval().shape() == (2, 1, 4)
319+
320+
def test_can_resize_data_defined_size(self):
321+
with pm.Model() as pmodel:
322+
x = pm.Data("x", [[1, 2, 3, 4]], dims=("first", "second"))
323+
y = pm.Normal("y", mu=0, dims=("first", "second"))
324+
z = pm.Normal("z", mu=y, observed=np.ones((1, 4)))
325+
assert x.eval().shape == (1, 4)
326+
assert y.eval().shape == (1, 4)
327+
assert z.eval().shape == (1, 4)
328+
assert "first" in pmodel.dim_lengths
329+
assert "second" in pmodel.dim_lengths
330+
pmodel.set_data("x", [[1, 2], [3, 4], [5, 6]])
331+
assert x.eval().shape == (3, 2)
332+
assert y.eval().shape == (3, 2)
333+
assert z.eval().shape == (3, 2)
334+
335+
@pytest.mark.xfail(
336+
condition=sys.platform == "win32",
337+
reason="See https://github.com/pymc-devs/pymc3/issues/4652.",
338+
)
339+
def test_observed_with_column_vector(self):
340+
with pm.Model() as model:
341+
pm.Normal("x1", mu=0, sd=1, observed=np.random.normal(size=(3, 4)))
342+
model.logp()
343+
pm.Normal("x2", mu=0, sd=1, observed=np.random.normal(size=(3, 1)))
344+
model.logp()
345+
346+
def test_dist_api_works(self):
347+
mu = aesara.shared(np.array([1, 2, 3]))
348+
with pytest.raises(NotImplementedError, match="API is not yet supported"):
349+
pm.Normal.dist(mu=mu, dims=("town",))
350+
assert pm.Normal.dist(mu=mu, shape=(3,)).eval().shape == (3,)
351+
assert pm.Normal.dist(mu=mu, shape=(5, 3)).eval().shape == (5, 3)
352+
assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
353+
assert pm.Normal.dist(mu=mu, size=(4,)).eval().shape == (4, 3)
354+
355+
def test_auto_assert_shape(self):
356+
with pytest.raises(AssertionError, match="will never match"):
357+
pm.Normal.dist(mu=[1, 2], shape=[])
358+
359+
mu = at.vector(name="mu_input")
360+
rv = pm.Normal.dist(mu=mu, shape=[3, 4])
361+
f = aesara.function([mu], rv, mode=aesara.Mode("py"))
362+
assert f([1, 2, 3, 4]).shape == (3, 4)
363+
364+
with pytest.raises(AssertionError, match=r"Got shape \(3, 2\), expected \(3, 4\)."):
365+
f([1, 2])
366+
367+
# The `shape` can be symbolic!
368+
s = at.vector(dtype="int32")
369+
rv = pm.Uniform.dist(2, [4, 5], shape=s)
370+
f = aesara.function([s], rv, mode=aesara.Mode("py"))
371+
f(
372+
[
373+
2,
374+
]
375+
)
376+
with pytest.raises(
377+
AssertionError,
378+
match=r"Got 1 dimensions \(shape \(2,\)\), expected 2 dimensions with shape \(3, 4\).",
379+
):
380+
f([3, 4])
381+
with pytest.raises(
382+
AssertionError,
383+
match=r"Got 1 dimensions \(shape \(2,\)\), expected 0 dimensions with shape \(\).",
384+
):
385+
f([])
386+
pass
387+
388+
def test_lazy_flavors(self):
389+
390+
_validate_shape_dims_size(shape=5)
391+
_validate_shape_dims_size(dims="town")
392+
_validate_shape_dims_size(size=7)
393+
394+
assert pm.Uniform.dist(2, [4, 5], size=[3, 4]).eval().shape == (3, 4, 2)
395+
assert pm.Uniform.dist(2, [4, 5], shape=[3, 2]).eval().shape == (3, 2)
396+
with pm.Model(coords=dict(town=["Greifswald", "Madrid"])):
397+
assert pm.Normal("n2", mu=[1, 2], dims=("town",)).eval().shape == (2,)
398+
399+
def test_invalid_flavors(self):
400+
# redundant parametrizations
401+
with pytest.raises(ValueError, match="Passing both"):
402+
_validate_shape_dims_size(shape=(2,), dims=("town",))
403+
with pytest.raises(ValueError, match="Passing both"):
404+
_validate_shape_dims_size(dims=("town",), size=(2,))
405+
with pytest.raises(ValueError, match="Passing both"):
406+
_validate_shape_dims_size(shape=(3,), size=(3,))
407+
408+
# invalid, but not necessarly rare
409+
with pytest.raises(ValueError, match="must be an int, list or tuple"):
410+
_validate_shape_dims_size(size="notasize")
411+
412+
# invalid ellipsis positions
413+
with pytest.raises(ValueError, match="may only appear in the last position"):
414+
_validate_shape_dims_size(shape=(3, ..., 2))
415+
with pytest.raises(ValueError, match="may only appear in the last position"):
416+
_validate_shape_dims_size(dims=(..., "town"))
417+
with pytest.raises(ValueError, match="cannot contain"):
418+
_validate_shape_dims_size(size=(3, ...))

0 commit comments

Comments
 (0)