Skip to content

Commit 784dec3

Browse files
michaelosthegetwiecki
authored andcommitted
Add Ellipsis-support for the shape kwarg
1 parent 3adca2a commit 784dec3

File tree

3 files changed

+42
-27
lines changed

3 files changed

+42
-27
lines changed

RELEASE-NOTES.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
1313
- The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4696](https://github.com/pymc-devs/pymc3/pull/4696)):
1414
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. Numeric entries in `shape` restrict the model variable to the exact length and re-sizing is no longer possible.
15-
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects. An `Ellipsis` (`...`) in the last position of `dims` can be used as short-hand notation for implied dimensions.
15+
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects.
1616
- The `size` kwarg behaves like it does in Aesara/NumPy. For univariate RVs it is the same as `shape`, but for multivariate RVs it depends on how the RV implements broadcasting to dimensionality greater than `RVOp.ndim_supp`.
17+
- An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions.
1718
- Add `logcdf` method to Kumaraswamy distribution (see [#4706](https://github.com/pymc-devs/pymc3/pull/4706)).
1819
- ...
1920

pymc3/distributions/distribution.py

+32-18
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,9 @@ def dist(
378378
The inputs to the `RandomVariable` `Op`.
379379
shape : int, tuple, Variable, optional
380380
A tuple of sizes for each dimension of the new RV.
381+
382+
An Ellipsis (...) may be inserted in the last position to short-hand refer to
383+
all the dimensions that the RV would get if no shape/size/dims were passed at all.
381384
size : int, tuple, Variable, optional
382385
For creating the RV like in Aesara/NumPy.
383386
initival : optional
@@ -414,9 +417,16 @@ def dist(
414417
create_size = None
415418

416419
if shape is not None:
417-
ndim_expected = len(tuple(shape))
418-
ndim_batch = ndim_expected - ndim_supp
419-
create_size = tuple(shape)[:ndim_batch]
420+
if Ellipsis in shape:
421+
# Ellipsis short-hands all implied dimensions. Therefore
422+
# we don't know how many dimensions to expect.
423+
ndim_expected = ndim_batch = None
424+
# Create the RV with its implied shape and resize later
425+
create_size = None
426+
else:
427+
ndim_expected = len(tuple(shape))
428+
ndim_batch = ndim_expected - ndim_supp
429+
create_size = tuple(shape)[:ndim_batch]
420430
elif size is not None:
421431
ndim_expected = ndim_supp + len(tuple(size))
422432
ndim_batch = ndim_expected - ndim_supp
@@ -429,21 +439,25 @@ def dist(
429439
ndims_unexpected = ndim_actual != ndim_expected
430440

431441
if shape is not None and ndims_unexpected:
432-
# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
433-
# Recreate the RV without passing `size` to created it with just the implied dimensions.
434-
rv_out = cls.rv_op(*dist_params, size=None, **kwargs)
435-
436-
# Now resize by the "extra" dimensions that were not implied from support and parameters
437-
if rv_out.ndim < ndim_expected:
438-
expand_shape = shape[: ndim_expected - rv_out.ndim]
439-
rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True)
440-
if not rv_out.ndim == ndim_expected:
441-
raise ShapeError(
442-
f"Failed to create the RV with the expected dimensionality. "
443-
f"This indicates a severe problem. Please open an issue.",
444-
actual=ndim_actual,
445-
expected=ndim_batch + ndim_supp,
446-
)
442+
if Ellipsis in shape:
443+
# Resize and we're done!
444+
rv_out = change_rv_size(rv_var=rv_out, new_size=shape[:-1], expand=True)
445+
else:
446+
# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
447+
# Recreate the RV without passing `size` to created it with just the implied dimensions.
448+
rv_out = cls.rv_op(*dist_params, size=None, **kwargs)
449+
450+
# Now resize by any remaining "extra" dimensions that were not implied from support and parameters
451+
if rv_out.ndim < ndim_expected:
452+
expand_shape = shape[: ndim_expected - rv_out.ndim]
453+
rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True)
454+
if not rv_out.ndim == ndim_expected:
455+
raise ShapeError(
456+
f"Failed to create the RV with the expected dimensionality. "
457+
f"This indicates a severe problem. Please open an issue.",
458+
actual=ndim_actual,
459+
expected=ndim_batch + ndim_supp,
460+
)
447461

448462
# Warn about the edge cases where the RV Op creates more dimensions than
449463
# it should based on `size` and `RVOp.ndim_supp`.

pymc3/tests/test_shape_handling.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class TestShapeDimsSize:
236236
[
237237
"implicit",
238238
"shape",
239-
# "shape...",
239+
"shape...",
240240
"dims",
241241
"dims...",
242242
"size",
@@ -273,9 +273,9 @@ def test_param_and_batch_shape_combos(
273273
if parametrization == "shape":
274274
rv = pm.Normal("rv", mu=mu, shape=batch_shape + param_shape)
275275
assert rv.eval().shape == expected_shape
276-
# elif parametrization == "shape...":
277-
# rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
278-
# assert rv.eval().shape == batch_shape + param_shape
276+
elif parametrization == "shape...":
277+
rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
278+
assert rv.eval().shape == batch_shape + param_shape
279279
elif parametrization == "dims":
280280
rv = pm.Normal("rv", mu=mu, dims=batch_dims + param_dims)
281281
assert rv.eval().shape == expected_shape
@@ -376,7 +376,7 @@ def test_dist_api_works(self):
376376
pm.Normal.dist(mu=mu, dims=("town",))
377377
assert pm.Normal.dist(mu=mu, shape=(3,)).eval().shape == (3,)
378378
assert pm.Normal.dist(mu=mu, shape=(5, 3)).eval().shape == (5, 3)
379-
# assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
379+
assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
380380
assert pm.Normal.dist(mu=mu, size=(3,)).eval().shape == (3,)
381381
assert pm.Normal.dist(mu=mu, size=(4, 3)).eval().shape == (4, 3)
382382

@@ -402,9 +402,9 @@ def test_mvnormal_shape_size_difference(self):
402402
assert rv.ndim == 3
403403
assert tuple(rv.shape.eval()) == (5, 4, 3)
404404

405-
# rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...))
406-
# assert rv.ndim == 5
407-
# assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)
405+
rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...))
406+
assert rv.ndim == 5
407+
assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)
408408

409409
with pytest.warns(None):
410410
rv = pm.MvNormal.dist(mu=[1, 2, 3], cov=np.eye(3), size=(5, 4))

0 commit comments

Comments
 (0)