Commit f326166

michaelosthege authored and twiecki committed
Revert functionality change of size kwarg
Corresponding tests were reverted, or edited to use other parametrization flavors. The Ellipsis feature now works with all three dimensionality kwargs. The MultinomialRV implementation was removed, because the broadcasting behavior was implemented in Aesara. Closes #4662
1 parent 47fd9b8 commit f326166

10 files changed: +236 −198 lines changed

Diff for: RELEASE-NOTES.md (+3 −3)

@@ -10,10 +10,10 @@
 ### New Features
 - The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
 - The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)):
-  - With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. A `SpecifyShape` `Op` is added automatically unless `Ellipsis` is used. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible.
+  - With `shape` the length of all dimensions must be given numerically or as scalar Aesara `Variables`. A `SpecifyShape` `Op` is added automatically unless `Ellipsis` is used. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible.
   - `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects.
-  - The `size` kwarg creates new dimensions in addition to what is implied by RV parameters.
-  - An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions.
+  - The `size` kwarg resembles the behavior found in Aesara and NumPy: It does not include _support_ dimensions.
+  - An `Ellipsis` (`...`) in the last position of either kwarg can be used as short-hand notation for implied dimensions.
 - ...

 ### Maintenance
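To make the release note concrete, here is a small usage sketch; the model, coords, and numbers are illustrative and assume the post-commit v4 development API, not code from this diff:

```python
import numpy as np
import pymc3 as pm

with pm.Model(coords={"city": ["Berlin", "Paris", "Rome"]}) as model:
    # `shape`: exact lengths, a SpecifyShape Op is added (no Ellipsis used)
    x = pm.Normal("x", mu=0, shape=(3,))
    # `dims`: stays re-sizeable and registers named coordinates in InferenceData
    y = pm.Normal("y", mu=0, dims=("city",))
    # `size`: batch dimensions only; the length-3 support dim comes from mu/cov
    z = pm.MvNormal("z", mu=np.zeros(3), cov=np.eye(3), size=(4,))
    # Ellipsis short-hand: spell out only the leading dims, the rest is implied
    w = pm.Normal("w", mu=np.zeros(3), shape=(2, ...))

assert z.eval().shape == (4, 3)
assert w.eval().shape == (2, 3)
```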

Diff for: pymc3/distributions/distribution.py (+63 −46)
@@ -49,9 +49,9 @@

 PLATFORM = sys.platform

-Shape = Union[int, Sequence[Union[str, type(Ellipsis)]], Variable]
-Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]]
-Size = Union[int, Tuple[int, ...]]
+Shape = Union[int, Sequence[Union[int, type(Ellipsis), Variable]], Variable]
+Dims = Union[str, Sequence[Union[str, type(Ellipsis), None]]]
+Size = Union[int, Sequence[Union[int, type(Ellipsis), Variable]], Variable]


 class _Unpickling:
@@ -139,7 +139,7 @@ def _validate_shape_dims_size(
         raise ValueError("The `shape` parameter must be an int, list or tuple.")
     if not isinstance(dims, (type(None), str, list, tuple)):
         raise ValueError("The `dims` parameter must be a str, list or tuple.")
-    if not isinstance(size, (type(None), int, list, tuple)):
+    if not isinstance(size, (type(None), int, list, tuple, Variable)):
         raise ValueError("The `size` parameter must be an int, list or tuple.")

     # Auto-convert non-tupled parameters
@@ -155,17 +155,14 @@
         shape = tuple(shape)
     if not isinstance(dims, (type(None), tuple)):
         dims = tuple(dims)
-    if not isinstance(size, (type(None), tuple)):
+    if not isinstance(size, (type(None), tuple, Variable)):
         size = tuple(size)

-    if not _valid_ellipsis_position(shape):
-        raise ValueError(
-            f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
-        )
-    if not _valid_ellipsis_position(dims):
-        raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}")
-    if size is not None and Ellipsis in size:
-        raise ValueError(f"The `size` parameter cannot contain an Ellipsis. Actual: {size}")
+    for kwarg, val in dict(shape=shape, dims=dims, size=size).items():
+        if not _valid_ellipsis_position(val):
+            raise ValueError(
+                f"Ellipsis in `{kwarg}` may only appear in the last position. Actual: {val}"
+            )
     return shape, dims, size
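The loop above delegates to `_valid_ellipsis_position`, which is outside this diff. A hypothetical equivalent, just to show what the check enforces (import path and signature are assumptions):

```python
from typing import Any, Optional, Sequence, Union

from aesara.graph.basic import Variable  # import path assumed


def _valid_ellipsis_position(items: Optional[Union[Sequence[Any], Variable]]) -> bool:
    """True if `items` has no Ellipsis, or exactly one in the last position."""
    if items is not None and not isinstance(items, Variable) and Ellipsis in items:
        return items.count(Ellipsis) == 1 and items[-1] is Ellipsis
    return True
```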

@@ -240,12 +237,12 @@ def __new__(
         rng = model.default_rng

         _, dims, _ = _validate_shape_dims_size(dims=dims)
-        resize = None
+        batch_shape = None

         # Create the RV without specifying testval, because the testval may have a shape
         # that only matches after replicating with a size implied by dims (see below).
         rv_out = cls.dist(*args, rng=rng, testval=None, **kwargs)
-        n_implied = rv_out.ndim
+        ndim_implied = rv_out.ndim

         # The `.dist()` can wrap automatically with a SpecifyShape Op which brings informative
         # error messages earlier in model construction.
@@ -261,33 +258,33 @@ def __new__(
             # Auto-complete the dims tuple to the full length
             dims = (*dims[:-1], *[None] * rv_out.ndim)

-            n_resize = len(dims) - n_implied
+            ndim_batch = len(dims) - ndim_implied

-            # All resize dims must be known already (numerically or symbolically).
-            unknown_resize_dims = set(dims[:n_resize]) - set(model.dim_lengths)
-            if unknown_resize_dims:
+            # All batch dims must be known already (numerically or symbolically).
+            unknown_batch_dims = set(dims[:ndim_batch]) - set(model.dim_lengths)
+            if unknown_batch_dims:
                 raise KeyError(
-                    f"Dimensions {unknown_resize_dims} are unknown to the model and cannot be used to specify a `size`."
+                    f"Dimensions {unknown_batch_dims} are unknown to the model and cannot be used to specify a `size`."
                 )

-            # The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
-            resize = tuple(model.dim_lengths[dname] for dname in dims[:n_resize])
+            # The numeric/symbolic batch shape can be created using model.RV_dim_lengths
+            batch_shape = tuple(model.dim_lengths[dname] for dname in dims[:ndim_batch])
         elif observed is not None:
             if not hasattr(observed, "shape"):
                 observed = pandas_to_array(observed)
-            n_resize = observed.ndim - n_implied
-            resize = tuple(observed.shape[d] for d in range(n_resize))
+            ndim_batch = observed.ndim - ndim_implied
+            batch_shape = tuple(observed.shape[d] for d in range(ndim_batch))

-        if resize:
+        if batch_shape:
             # A batch size was specified through `dims`, or implied by `observed`.
-            rv_out = change_rv_size(rv_var=rv_out, new_size=resize, expand=True)
+            rv_out = change_rv_size(rv_var=rv_out, new_size=batch_shape, expand=True)

         if dims is not None:
             # Now that we have a handle on the output RV, we can register named implied dimensions that
             # were not yet known to the model, such that they can be used for size further downstream.
-            for di, dname in enumerate(dims[n_resize:]):
+            for di, dname in enumerate(dims[ndim_batch:]):
                 if not dname in model.dim_lengths:
-                    model.add_coord(dname, values=None, length=rv_out.shape[n_resize + di])
+                    model.add_coord(dname, values=None, length=rv_out.shape[ndim_batch + di])

         if testval is not None:
             # Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
@@ -329,6 +326,9 @@ def dist(
         size : int, tuple, Variable, optional
             A scalar or tuple for replicating the RV in addition
             to its implied shape/dimensionality.
+
+            Ellipsis (...) may be used in the last position of the tuple,
+            such that only batch dimensions must be specified.
         testval : optional
             Test value to be attached to the output RV.
             Must match its shape exactly.
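A quick sketch of the documented behavior (distribution and numbers illustrative, not from the diff):

```python
import numpy as np
import pymc3 as pm

# With a trailing Ellipsis only the new batch dimensions are listed;
# the length-3 support dimension stays implied by mu/cov.
d = pm.MvNormal.dist(mu=np.zeros(3), cov=np.eye(3), size=(4, ...))
assert d.eval().shape == (4, 3)
```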
@@ -344,29 +344,46 @@
         shape, _, size = _validate_shape_dims_size(shape=shape, size=size)
         assert_shape = None

-        # Create the RV without specifying size or testval.
-        # The size will be expanded later (if necessary) and only then the testval fits.
+        # Create the RV without specifying size or testval, because we don't know
+        # a-priori if `size` contains a batch shape.
+        # In the end `testval` must match the final shape, but it is
+        # not taken into consideration when creating the RV.
+        # Batch-shape expansion (if necessary) of the RV happens later.
         rv_native = cls.rv_op(*dist_params, size=None, **kwargs)

-        if shape is None and size is None:
-            size = ()
+        # Now we know the implied dimensionality and can figure out the batch shape
+        batch_shape = ()
+        if size is not None:
+            if not isinstance(size, Variable) and Ellipsis in size:
+                batch_shape = size[:-1]
+            else:
+                # Parametrization through size does not include support dimensions,
+                # but may include a batch shape.
+                ndim_support = rv_native.owner.op.ndim_supp
+                ndim_inputs = rv_native.ndim - ndim_support
+                # Be careful to avoid len(size) because it may be symbolic
+                if ndim_inputs == 0:
+                    batch_shape = size
+                else:
+                    batch_shape = size[:-ndim_inputs]
         elif shape is not None:
-            # SpecifyShape is automatically applied for symbolic and non-Ellipsis shapes
-            if isinstance(shape, Variable):
-                assert_shape = shape
-                size = ()
+            if not isinstance(shape, Variable) and Ellipsis in shape:
+                # Can't assert a shape without knowing all dimension lengths.
+                # The batch shape are all entries before ...
+                batch_shape = tuple(shape[:-1])
             else:
-                if Ellipsis in shape:
-                    size = tuple(shape[:-1])
+                # Fully symbolic, or without Ellipsis shapes can be asserted.
+                assert_shape = shape
+                # The batch shape are the entries that preceed the implied dimensions.
+                # Be careful to avoid len(shape) because it may be symbolic
+                if rv_native.ndim == 0:
+                    batch_shape = shape
                 else:
-                    size = tuple(shape[: len(shape) - rv_native.ndim])
-                assert_shape = shape
-        # no-op conditions:
-        # `elif size is not None` (User already specified how to expand the RV)
-        # `else` (Unreachable)
-
-        if size:
-            rv_out = change_rv_size(rv_var=rv_native, new_size=size, expand=True)
+                    batch_shape = shape[: -rv_native.ndim]
+        # else: both dimensionality kwargs are None

+        if batch_shape:
+            rv_out = change_rv_size(rv_var=rv_native, new_size=batch_shape, expand=True)
         else:
             rv_out = rv_native
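Tracing the `shape` branch above with concrete values, e.g. a Dirichlet whose `rv_native.ndim` is 1 (numbers hypothetical, not from the diff):

```python
import numpy as np
import pymc3 as pm

# Without Ellipsis the full shape is asserted via SpecifyShape and the batch
# shape is everything preceding the single implied dimension: (10,).
d_exact = pm.Dirichlet.dist(a=np.ones(4), shape=(10, 4))
# With Ellipsis nothing can be asserted; the batch shape is everything before `...`.
d_implied = pm.Dirichlet.dist(a=np.ones(4), shape=(10, ...))
assert d_exact.eval().shape == d_implied.eval().shape == (10, 4)
```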

Diff for: pymc3/distributions/multivariate.py (−22)

@@ -26,7 +26,6 @@
 from aesara.graph.op import Op
 from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
 from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
-from aesara.tensor.random.utils import broadcast_params
 from aesara.tensor.slinalg import (
     Cholesky,
     Solve,
@@ -427,27 +426,6 @@ def _distr_parameters_for_repr(self):
         return ["a"]


-class MultinomialRV(MultinomialRV):
-    """Aesara's `MultinomialRV` doesn't broadcast; this one does."""
-
-    @classmethod
-    def rng_fn(cls, rng, n, p, size):
-        if n.ndim > 0 or p.ndim > 1:
-            n, p = broadcast_params([n, p], cls.ndims_params)
-            size = tuple(size or ())
-
-            if size:
-                n = np.broadcast_to(n, size + n.shape)
-                p = np.broadcast_to(p, size + p.shape)
-
-            res = np.empty(p.shape)
-            for idx in np.ndindex(p.shape[:-1]):
-                res[idx] = rng.multinomial(n[idx], p[idx])
-            return res
-        else:
-            return rng.multinomial(n, p, size=size)
-
-
 multinomial = MultinomialRV()
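The override became removable because the broadcasting now happens upstream. A minimal check of that claim, assuming the Aesara version pinned by this commit:

```python
import numpy as np

from aesara.tensor.random.basic import multinomial

# Batched trial counts and probability vectors broadcast against each other
# without the PyMC3-side rng_fn deleted above.
n = np.array([5, 10])
p = np.full((2, 4), 0.25)
draws = multinomial(n, p).eval()
assert draws.shape == (2, 4)
assert (draws.sum(axis=-1) == n).all()
```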

Diff for: pymc3/tests/test_data_container.py (+1 −1)

@@ -163,7 +163,7 @@ def test_shared_data_as_rv_input(self):
         with pm.Model() as m:
             x = pm.Data("x", [1.0, 2.0, 3.0])
             assert x.eval().shape == (3,)
-            y = pm.Normal("y", mu=x, size=2)
+            y = pm.Normal("y", mu=x, shape=(2, ...))
             assert y.eval().shape == (2, 3)
             idata = pm.sample(
                 chains=1,
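The switch away from `size=2` follows from the revert: `size` now also counts the dimension implied by `mu`, so `size=2` alone would no longer produce a `(2, 3)` draw. Equivalent spellings under the new semantics (illustrative, reusing the `x` from the test):

```python
y1 = pm.Normal.dist(mu=x, shape=(2, ...))  # as in the test
y2 = pm.Normal.dist(mu=x, size=(2, ...))   # Ellipsis now works for `size` too
y3 = pm.Normal.dist(mu=x, size=(2, 3))     # NumPy-style full size
assert y1.eval().shape == y2.eval().shape == y3.eval().shape == (2, 3)
```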

Diff for: pymc3/tests/test_distributions.py (+16 −12)

@@ -2074,29 +2074,33 @@ def test_multinomial_mode(self, p, n):
     @pytest.mark.parametrize(
         "p, size, n",
         [
-            [[0.25, 0.25, 0.25, 0.25], (4,), 2],
-            [[0.25, 0.25, 0.25, 0.25], (1, 4), 3],
+            [[0.25, 0.25, 0.25, 0.25], (7,), 2],
+            [[0.25, 0.25, 0.25, 0.25], (1, 7), 3],
             # 3: expect to fail
-            # [[.25, .25, .25, .25], (10, 4)],
-            [[0.25, 0.25, 0.25, 0.25], (10, 1, 4), 5],
+            # [[.25, .25, .25, .25], (10, 7)],
+            [[0.25, 0.25, 0.25, 0.25], (10, 1, 7), 5],
             # 5: expect to fail
-            # [[[.25, .25, .25, .25]], (2, 4), [7, 11]],
-            [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), 13],
-            [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (1, 2, 4), [23, 29]],
+            # [[[.25, .25, .25, .25]], (2, 5), [7, 11]],
+            [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (5, 2), 13],
+            [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (5, 7, 2), [23, 29]],
             [
                 [[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]],
-                (10, 2, 4),
+                (10, 8, 2),
                 [31, 37],
             ],
-            [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), [17, 19]],
+            [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (3, 2), [17, 19]],
         ],
     )
     def test_multinomial_random(self, p, size, n):
         p = np.asarray(p)
         with Model() as model:
             m = Multinomial("m", n=n, p=p, size=size)
-
-        assert m.eval().shape == size + p.shape
+        # The support has length 4 in all test parametrizations!
+        # Broadcasting of the `p` parameter does not affect the ndim_supp
+        # of the Op, hence the broadcasted p must be included in `size`.
+        support_shape = (p.shape[-1],)
+        assert support_shape == (4,)
+        assert m.eval().shape == size + support_shape

     @pytest.mark.skip(reason="Moment calculations have not been refactored yet")
     def test_multinomial_mode_with_shape(self):
@@ -2197,7 +2201,7 @@ def test_batch_multinomial(self):
             decimal=select_by_precision(float64=6, float32=3),
         )

-        dist = Multinomial.dist(n=n, p=p, size=2)
+        dist = Multinomial.dist(n=n, p=p, size=(2, ...))
         sample = dist.eval()
         assert_allclose(sample, np.stack([vals, vals], axis=0))
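Working through one of the new parametrizations (values taken from the test above):

```python
import numpy as np

p = np.asarray([[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]])
size = (5, 2)                    # must end with p's broadcasted batch shape (2,)
support_shape = (p.shape[-1],)   # (4,)
assert size + support_shape == (5, 2, 4)  # the expected draw shape
```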

Diff for: pymc3/tests/test_model.py (+5)

@@ -35,6 +35,7 @@
 from pymc3 import Deterministic, Potential
 from pymc3.blocking import DictToArrayBijection, RaveledVars
 from pymc3.distributions import Normal, logpt_sum, transforms
+from pymc3.exceptions import ShapeError
 from pymc3.model import Point, ValueGradFunction
 from pymc3.tests.helpers import SeededTest

@@ -465,6 +466,10 @@ def test_make_obs_var():
     # Create the testval attribute simply for the sake of model testing
     fake_distribution.name = input_name

+    # The function requires data and RV dimensionality to be compatible
+    with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."):
+        fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None)
+
     # Check function behavior using the various inputs
     # dense, sparse: Ensure that the missing values are appropriately set to None
     # masked: a deterministic variable is returned

Diff for: pymc3/tests/test_sampling.py (+3 −3)

@@ -967,14 +967,14 @@

     def test_layers(self):
         with pm.Model() as model:
-            a = pm.Uniform("a", lower=0, upper=1, size=5)
-            b = pm.Binomial("b", n=1, p=a, size=7)
+            a = pm.Uniform("a", lower=0, upper=1, size=10)
+            b = pm.Binomial("b", n=1, p=a)

         model.default_rng.get_value(borrow=True).seed(232093)

         b_sampler = aesara.function([], b)
         avg = np.stack([b_sampler() for i in range(10000)]).mean(0)
-        npt.assert_array_almost_equal(avg, 0.5 * np.ones((7, 5)), decimal=2)
+        npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2)

     def test_transformed(self):
         n = 18
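The test now relies on `b` inheriting its `(10,)` shape from `p=a`; stacking an extra batch dimension, as the old `size=7` did, would be spelled with an Ellipsis under the reverted semantics. A sketch (variable names illustrative):

```python
import pymc3 as pm

with pm.Model():
    a = pm.Uniform("a", lower=0, upper=1, size=10)
    b = pm.Binomial("b", n=1, p=a)                    # shape (10,) implied by p
    b7 = pm.Binomial("b7", n=1, p=a, size=(7, ...))   # stacks to (7, 10)

assert b.eval().shape == (10,)
assert b7.eval().shape == (7, 10)
```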
