Skip to content

Commit d4894a7

Browse files
Revert functionality change of size kwarg
Corresponding tests were reverted, or edited to use other parametrization flavors. The Ellipsis feature now works with all three dimensionality kwargs. The MultinomialRV implementation was removed, because the broadcasting behavior was implemented in Aesara. Closes pymc-devs#4662
1 parent d604881 commit d4894a7

10 files changed

+241
-205
lines changed

RELEASE-NOTES.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
### New Features
1010
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
1111
- The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)):
12-
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. A `SpecifyShape` `Op` is added automatically unless `Ellipsis` is used. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible.
12+
- With `shape` the length of all dimensions must be given numerically or as scalar Aesara `Variables`. A `SpecifyShape` `Op` is added automatically unless `Ellipsis` is used. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible.
1313
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects.
14-
- The `size` kwarg creates new dimensions in addition to what is implied by RV parameters.
15-
- An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions.
14+
- The `size` kwarg resembles the behavior found in Aesara and NumPy: It does not include _support_ dimensions.
15+
- An `Ellipsis` (`...`) in the last position of either kwarg can be used as short-hand notation for implied dimensions.
1616
- ...
1717

1818
### Maintenance

pymc3/distributions/distribution.py

+63-46
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@
4949

5050
PLATFORM = sys.platform
5151

52-
Shape = Union[int, Sequence[Union[str, type(Ellipsis)]], Variable]
53-
Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]]
54-
Size = Union[int, Tuple[int, ...]]
52+
Shape = Union[int, Sequence[Union[int, type(Ellipsis), Variable]], Variable]
53+
Dims = Union[str, Sequence[Union[str, type(Ellipsis), None]]]
54+
Size = Union[int, Sequence[Union[int, type(Ellipsis), Variable]], Variable]
5555

5656

5757
class _Unpickling:
@@ -146,7 +146,7 @@ def _validate_shape_dims_size(
146146
raise ValueError("The `shape` parameter must be an int, list or tuple.")
147147
if not isinstance(dims, (type(None), str, list, tuple)):
148148
raise ValueError("The `dims` parameter must be a str, list or tuple.")
149-
if not isinstance(size, (type(None), int, list, tuple)):
149+
if not isinstance(size, (type(None), int, list, tuple, Variable)):
150150
raise ValueError("The `size` parameter must be an int, list or tuple.")
151151

152152
# Auto-convert non-tupled parameters
@@ -162,17 +162,14 @@ def _validate_shape_dims_size(
162162
shape = tuple(shape)
163163
if not isinstance(dims, (type(None), tuple)):
164164
dims = tuple(dims)
165-
if not isinstance(size, (type(None), tuple)):
165+
if not isinstance(size, (type(None), tuple, Variable)):
166166
size = tuple(size)
167167

168-
if not _valid_ellipsis_position(shape):
169-
raise ValueError(
170-
f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
171-
)
172-
if not _valid_ellipsis_position(dims):
173-
raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}")
174-
if size is not None and Ellipsis in size:
175-
raise ValueError(f"The `size` parameter cannot contain an Ellipsis. Actual: {size}")
168+
for kwarg, val in dict(shape=shape, dims=dims, size=size).items():
169+
if not _valid_ellipsis_position(val):
170+
raise ValueError(
171+
f"Ellipsis in `{kwarg}` may only appear in the last position. Actual: {val}"
172+
)
176173
return shape, dims, size
177174

178175

@@ -247,12 +244,12 @@ def __new__(
247244
rng = model.default_rng
248245

249246
_, dims, _ = _validate_shape_dims_size(dims=dims)
250-
resize = None
247+
batch_shape = None
251248

252249
# Create the RV without specifying testval, because the testval may have a shape
253250
# that only matches after replicating with a size implied by dims (see below).
254251
rv_out = cls.dist(*args, rng=rng, testval=None, **kwargs)
255-
n_implied = rv_out.ndim
252+
ndim_implied = rv_out.ndim
256253

257254
# The `.dist()` can wrap automatically with a SpecifyShape Op which brings informative
258255
# error messages earlier in model construction.
@@ -268,33 +265,33 @@ def __new__(
268265
# Auto-complete the dims tuple to the full length
269266
dims = (*dims[:-1], *[None] * rv_out.ndim)
270267

271-
n_resize = len(dims) - n_implied
268+
ndim_batch = len(dims) - ndim_implied
272269

273-
# All resize dims must be known already (numerically or symbolically).
274-
unknown_resize_dims = set(dims[:n_resize]) - set(model.dim_lengths)
275-
if unknown_resize_dims:
270+
# All batch dims must be known already (numerically or symbolically).
271+
unknown_batch_dims = set(dims[:ndim_batch]) - set(model.dim_lengths)
272+
if unknown_batch_dims:
276273
raise KeyError(
277-
f"Dimensions {unknown_resize_dims} are unknown to the model and cannot be used to specify a `size`."
274+
f"Dimensions {unknown_batch_dims} are unknown to the model and cannot be used to specify a `size`."
278275
)
279276

280-
# The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
281-
resize = tuple(model.dim_lengths[dname] for dname in dims[:n_resize])
277+
# The numeric/symbolic batch shape can be created using model.RV_dim_lengths
278+
batch_shape = tuple(model.dim_lengths[dname] for dname in dims[:ndim_batch])
282279
elif observed is not None:
283280
if not hasattr(observed, "shape"):
284281
observed = pandas_to_array(observed)
285-
n_resize = observed.ndim - n_implied
286-
resize = tuple(observed.shape[d] for d in range(n_resize))
282+
ndim_batch = observed.ndim - ndim_implied
283+
batch_shape = tuple(observed.shape[d] for d in range(ndim_batch))
287284

288-
if resize:
285+
if batch_shape:
289286
# A batch size was specified through `dims`, or implied by `observed`.
290-
rv_out = change_rv_size(rv_var=rv_out, new_size=resize, expand=True)
287+
rv_out = change_rv_size(rv_var=rv_out, new_size=batch_shape, expand=True)
291288

292289
if dims is not None:
293290
# Now that we have a handle on the output RV, we can register named implied dimensions that
294291
# were not yet known to the model, such that they can be used for size further downstream.
295-
for di, dname in enumerate(dims[n_resize:]):
292+
for di, dname in enumerate(dims[ndim_batch:]):
296293
if not dname in model.dim_lengths:
297-
model.add_coord(dname, values=None, length=rv_out.shape[n_resize + di])
294+
model.add_coord(dname, values=None, length=rv_out.shape[ndim_batch + di])
298295

299296
if testval is not None:
300297
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
@@ -336,6 +333,9 @@ def dist(
336333
size : int, tuple, Variable, optional
337334
A scalar or tuple for replicating the RV in addition
338335
to its implied shape/dimensionality.
336+
337+
Ellipsis (...) may be used in the last position of the tuple,
338+
such that only batch dimensions must be specified.
339339
testval : optional
340340
Test value to be attached to the output RV.
341341
Must match its shape exactly.
@@ -351,29 +351,46 @@ def dist(
351351
shape, _, size = _validate_shape_dims_size(shape=shape, size=size)
352352
assert_shape = None
353353

354-
# Create the RV without specifying size or testval.
355-
# The size will be expanded later (if necessary) and only then the testval fits.
354+
# Create the RV without specifying size or testval, because we don't know
355+
# a-priori if `size` contains a batch shape.
356+
# In the end `testval` must match the final shape, but it is
357+
# not taken into consideration when creating the RV.
358+
# Batch-shape expansion (if necessary) of the RV happens later.
356359
rv_native = cls.rv_op(*dist_params, size=None, **kwargs)
357360

358-
if shape is None and size is None:
359-
size = ()
361+
# Now we know the implied dimensionality and can figure out the batch shape
362+
batch_shape = ()
363+
if size is not None:
364+
if not isinstance(size, Variable) and Ellipsis in size:
365+
batch_shape = size[:-1]
366+
else:
367+
# Parametrization through size does not include support dimensions,
368+
# but may include a batch shape.
369+
ndim_support = rv_native.owner.op.ndim_supp
370+
ndim_inputs = rv_native.ndim - ndim_support
371+
# Be careful to avoid len(size) because it may be symbolic
372+
if ndim_inputs == 0:
373+
batch_shape = size
374+
else:
375+
batch_shape = size[:-ndim_inputs]
360376
elif shape is not None:
361-
# SpecifyShape is automatically applied for symbolic and non-Ellipsis shapes
362-
if isinstance(shape, Variable):
363-
assert_shape = shape
364-
size = ()
377+
if not isinstance(shape, Variable) and Ellipsis in shape:
378+
# Can't assert a shape without knowing all dimension lengths.
379+
# The batch shape is made up of all entries before the ...
380+
batch_shape = tuple(shape[:-1])
365381
else:
366-
if Ellipsis in shape:
367-
size = tuple(shape[:-1])
382+
# Fully symbolic shapes, or shapes without an Ellipsis, can be asserted.
383+
assert_shape = shape
384+
# The batch shape is made up of the entries that precede the implied dimensions.
385+
# Be careful to avoid len(shape) because it may be symbolic
386+
if rv_native.ndim == 0:
387+
batch_shape = shape
368388
else:
369-
size = tuple(shape[: len(shape) - rv_native.ndim])
370-
assert_shape = shape
371-
# no-op conditions:
372-
# `elif size is not None` (User already specified how to expand the RV)
373-
# `else` (Unreachable)
374-
375-
if size:
376-
rv_out = change_rv_size(rv_var=rv_native, new_size=size, expand=True)
389+
batch_shape = shape[: -rv_native.ndim]
390+
# else: both dimensionality kwargs are None
391+
392+
if batch_shape:
393+
rv_out = change_rv_size(rv_var=rv_native, new_size=batch_shape, expand=True)
377394
else:
378395
rv_out = rv_native
379396

pymc3/distributions/multivariate.py

-22
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from aesara.graph.op import Op
2727
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
2828
from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
29-
from aesara.tensor.random.utils import broadcast_params
3029
from aesara.tensor.slinalg import (
3130
Cholesky,
3231
Solve,
@@ -427,27 +426,6 @@ def _distr_parameters_for_repr(self):
427426
return ["a"]
428427

429428

430-
class MultinomialRV(MultinomialRV):
431-
"""Aesara's `MultinomialRV` doesn't broadcast; this one does."""
432-
433-
@classmethod
434-
def rng_fn(cls, rng, n, p, size):
435-
if n.ndim > 0 or p.ndim > 1:
436-
n, p = broadcast_params([n, p], cls.ndims_params)
437-
size = tuple(size or ())
438-
439-
if size:
440-
n = np.broadcast_to(n, size + n.shape)
441-
p = np.broadcast_to(p, size + p.shape)
442-
443-
res = np.empty(p.shape)
444-
for idx in np.ndindex(p.shape[:-1]):
445-
res[idx] = rng.multinomial(n[idx], p[idx])
446-
return res
447-
else:
448-
return rng.multinomial(n, p, size=size)
449-
450-
451429
multinomial = MultinomialRV()
452430

453431

pymc3/tests/test_data_container.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def test_shared_data_as_rv_input(self):
163163
with pm.Model() as m:
164164
x = pm.Data("x", [1.0, 2.0, 3.0])
165165
assert x.eval().shape == (3,)
166-
y = pm.Normal("y", mu=x, size=2)
166+
y = pm.Normal("y", mu=x, shape=(2, ...))
167167
assert y.eval().shape == (2, 3)
168168
idata = pm.sample(
169169
chains=1,

pymc3/tests/test_distributions.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -1986,29 +1986,33 @@ def test_multinomial_mode(self, p, n):
19861986
@pytest.mark.parametrize(
19871987
"p, size, n",
19881988
[
1989-
[[0.25, 0.25, 0.25, 0.25], (4,), 2],
1990-
[[0.25, 0.25, 0.25, 0.25], (1, 4), 3],
1989+
[[0.25, 0.25, 0.25, 0.25], (7,), 2],
1990+
[[0.25, 0.25, 0.25, 0.25], (1, 7), 3],
19911991
# 3: expect to fail
1992-
# [[.25, .25, .25, .25], (10, 4)],
1993-
[[0.25, 0.25, 0.25, 0.25], (10, 1, 4), 5],
1992+
# [[.25, .25, .25, .25], (10, 7)],
1993+
[[0.25, 0.25, 0.25, 0.25], (10, 1, 7), 5],
19941994
# 5: expect to fail
1995-
# [[[.25, .25, .25, .25]], (2, 4), [7, 11]],
1996-
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), 13],
1997-
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (1, 2, 4), [23, 29]],
1995+
# [[[.25, .25, .25, .25]], (2, 5), [7, 11]],
1996+
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (5, 2), 13],
1997+
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (5, 7, 2), [23, 29]],
19981998
[
19991999
[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]],
2000-
(10, 2, 4),
2000+
(10, 8, 2),
20012001
[31, 37],
20022002
],
2003-
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), [17, 19]],
2003+
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (3, 2), [17, 19]],
20042004
],
20052005
)
20062006
def test_multinomial_random(self, p, size, n):
20072007
p = np.asarray(p)
20082008
with Model() as model:
20092009
m = Multinomial("m", n=n, p=p, size=size)
2010-
2011-
assert m.eval().shape == size + p.shape
2010+
# The support has length 4 in all test parametrizations!
2011+
# Broadcasting of the `p` parameter does not affect the ndim_supp
2012+
# of the Op, hence the broadcasted p must be included in `size`.
2013+
support_shape = (p.shape[-1],)
2014+
assert support_shape == (4,)
2015+
assert m.eval().shape == size + support_shape
20122016

20132017
@pytest.mark.skip(reason="Moment calculations have not been refactored yet")
20142018
def test_multinomial_mode_with_shape(self):
@@ -2109,7 +2113,7 @@ def test_batch_multinomial(self):
21092113
decimal=select_by_precision(float64=6, float32=3),
21102114
)
21112115

2112-
dist = Multinomial.dist(n=n, p=p, size=2)
2116+
dist = Multinomial.dist(n=n, p=p, size=(2, ...))
21132117
sample = dist.eval()
21142118
assert_allclose(sample, np.stack([vals, vals], axis=0))
21152119

pymc3/tests/test_model.py

+5
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from pymc3 import Deterministic, Potential
3636
from pymc3.blocking import DictToArrayBijection, RaveledVars
3737
from pymc3.distributions import Normal, logpt_sum, transforms
38+
from pymc3.exceptions import ShapeError
3839
from pymc3.model import Point, ValueGradFunction
3940
from pymc3.tests.helpers import SeededTest
4041

@@ -465,6 +466,10 @@ def test_make_obs_var():
465466
# Create the testval attribute simply for the sake of model testing
466467
fake_distribution.name = input_name
467468

469+
# The function requires data and RV dimensionality to be compatible
470+
with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."):
471+
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None)
472+
468473
# Check function behavior using the various inputs
469474
# dense, sparse: Ensure that the missing values are appropriately set to None
470475
# masked: a deterministic variable is returned

pymc3/tests/test_sampling.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -967,14 +967,14 @@ def test_multivariate2(self):
967967

968968
def test_layers(self):
969969
with pm.Model() as model:
970-
a = pm.Uniform("a", lower=0, upper=1, size=5)
971-
b = pm.Binomial("b", n=1, p=a, size=7)
970+
a = pm.Uniform("a", lower=0, upper=1, size=10)
971+
b = pm.Binomial("b", n=1, p=a)
972972

973973
model.default_rng.get_value(borrow=True).seed(232093)
974974

975975
b_sampler = aesara.function([], b)
976976
avg = np.stack([b_sampler() for i in range(10000)]).mean(0)
977-
npt.assert_array_almost_equal(avg, 0.5 * np.ones((7, 5)), decimal=2)
977+
npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2)
978978

979979
def test_transformed(self):
980980
n = 18

0 commit comments

Comments
 (0)