Skip to content

Commit 3f4d516

Browse files
michaelosthege authored and twiecki committed
Allow parametrization through either shape, dims or size
Always treats `size` as being in addition to dimensions implied by RV parameters. All resizing beyond parameter-implied dimensionality is done from:
- `shape` or `size` in `Distribution.dist()`
- `dims` or `observed` in `Distribution.__new__`

and only in those two places. Closes #4552.
1 parent 64e63d8 commit 3f4d516

13 files changed

+541
-148
lines changed

RELEASE-NOTES.md

+5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99
### New Features
1010
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
11+
- The dimensionality of model variables can now be parametrized through any one of `shape`, `dims` or `size` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)):
12+
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible.
13+
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects.
14+
- The `size` kwarg creates new dimensions in addition to what is implied by RV parameters.
15+
- An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions.
1116
- ...
1217

1318
### Maintenance

pymc3/distributions/distribution.py

+209-26
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,17 @@
2020

2121
from abc import ABCMeta
2222
from copy import copy
23-
from typing import TYPE_CHECKING
23+
from typing import Any, Optional, Sequence, Tuple, Union
2424

25+
import aesara
26+
import aesara.tensor as at
2527
import dill
2628

29+
from aesara.graph.basic import Variable
2730
from aesara.tensor.random.op import RandomVariable
2831

32+
from pymc3.aesaraf import change_rv_size, pandas_to_array
2933
from pymc3.distributions import _logcdf, _logp
30-
31-
if TYPE_CHECKING:
32-
from typing import Optional, Callable
33-
34-
import aesara
35-
import aesara.graph.basic
36-
import aesara.tensor as at
37-
3834
from pymc3.util import UNSET, get_repr_for_variable
3935
from pymc3.vartypes import string_types
4036

@@ -52,6 +48,10 @@
5248

5349
PLATFORM = sys.platform

# Type aliases for the three ways of parametrizing the dimensionality of an RV.
# Shape: a single int, a symbolic Variable, or a sequence of ints/Variables
#        that may end with an Ellipsis standing in for the implied dimensions.
# Dims:  a single dimension name, or a sequence of names (None for unnamed
#        dimensions) that may end with an Ellipsis.
# Size:  a single int or a tuple of ints.
Shape = Union[int, Variable, Sequence[Union[int, Variable, type(Ellipsis)]]]
Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]]
Size = Union[int, Tuple[int, ...]]
54+
5555

5656
class _Unpickling:
    # Empty marker class: it carries no state or behavior of its own.
    # NOTE(review): presumably used as a sentinel on the unpickling code path;
    # confirm against its users elsewhere in this module.
    pass
@@ -122,13 +122,111 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
122122
return new_cls
123123

124124

125+
def _valid_ellipsis_position(items: Union[None, Shape, Dims, Size]) -> bool:
    """Tell whether an Ellipsis, if present in `items`, sits only in the last position.

    ``None``, a ``Variable`` and any container without an Ellipsis are
    considered valid.
    """
    # Nothing to check: no items, a symbolic Variable, or no Ellipsis at all.
    if items is None or isinstance(items, Variable) or Ellipsis not in items:
        return True
    # An Ellipsis is present; it must not occur anywhere before the last entry.
    leading = items[:-1]
    return not any(entry == Ellipsis for entry in leading)
130+
131+
132+
def _validate_shape_dims_size(
    shape: Any = None, dims: Any = None, size: Any = None
) -> Tuple[Optional[Shape], Optional[Dims], Optional[Size]]:
    """Validate the `shape`, `dims` and `size` parametrizations and normalize them.

    At most one of the three arguments may be given. Scalars are wrapped into
    1-tuples and lists are converted to tuples; a ``Variable`` passed as
    `shape` is returned unchanged.

    Parameters
    ----------
    shape : int, list, tuple or Variable, optional
        Shape specification. An Ellipsis is only allowed in the last position.
    dims : str, list or tuple, optional
        Dimension names. An Ellipsis is only allowed in the last position.
    size : int, list or tuple, optional
        Size specification. Must not contain an Ellipsis.

    Returns
    -------
    shape : tuple, Variable or None
    dims : tuple or None
    size : tuple or None

    Raises
    ------
    ValueError
        If more than one of the arguments is given, if an argument has an
        unsupported type, or if an Ellipsis appears in an invalid position.
    """
    # Raise on unsupported parametrization
    if shape is not None and dims is not None:
        raise ValueError(f"Passing both `shape` ({shape}) and `dims` ({dims}) is not supported!")
    if dims is not None and size is not None:
        raise ValueError(f"Passing both `dims` ({dims}) and `size` ({size}) is not supported!")
    if shape is not None and size is not None:
        raise ValueError(f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!")

    # Raise on invalid types
    if not isinstance(shape, (type(None), int, list, tuple, Variable)):
        raise ValueError("The `shape` parameter must be an int, list or tuple.")
    if not isinstance(dims, (type(None), str, list, tuple)):
        raise ValueError("The `dims` parameter must be a str, list or tuple.")
    if not isinstance(size, (type(None), int, list, tuple)):
        raise ValueError("The `size` parameter must be an int, list or tuple.")

    # Auto-convert non-tupled parameters
    if isinstance(shape, int):
        shape = (shape,)
    if isinstance(dims, str):
        dims = (dims,)
    if isinstance(size, int):
        size = (size,)

    # Convert to actual tuples (a symbolic `shape` Variable is left untouched)
    if not isinstance(shape, (type(None), tuple, Variable)):
        shape = tuple(shape)
    if not isinstance(dims, (type(None), tuple)):
        dims = tuple(dims)
    if not isinstance(size, (type(None), tuple)):
        size = tuple(size)

    if not _valid_ellipsis_position(shape):
        raise ValueError(
            f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
        )
    if not _valid_ellipsis_position(dims):
        raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}")
    if size is not None and Ellipsis in size:
        raise ValueError(f"The `size` parameter cannot contain an Ellipsis. Actual: {size}")
    return shape, dims, size
176+
177+
125178
class Distribution(metaclass=DistributionMeta):
126179
"""Statistical distribution"""
127180

128181
rv_class = None
129182
rv_op = None
130183

131-
def __new__(cls, name, *args, **kwargs):
184+
def __new__(
185+
cls,
186+
name: str,
187+
*args,
188+
rng=None,
189+
dims: Optional[Dims] = None,
190+
testval=None,
191+
observed=None,
192+
total_size=None,
193+
transform=UNSET,
194+
**kwargs,
195+
) -> RandomVariable:
196+
"""Adds a RandomVariable corresponding to a PyMC3 distribution to the current model.
197+
198+
Note that all remaining kwargs must be compatible with ``.dist()``
199+
200+
Parameters
201+
----------
202+
cls : type
203+
A PyMC3 distribution.
204+
name : str
205+
Name for the new model variable.
206+
rng : optional
207+
Random number generator to use with the RandomVariable.
208+
dims : tuple, optional
209+
A tuple of dimension names known to the model.
210+
testval : optional
211+
Test value to be attached to the output RV.
212+
Must match its shape exactly.
213+
observed : optional
214+
Observed data to be passed when registering the random variable in the model.
215+
See ``Model.register_rv``.
216+
total_size : float, optional
217+
See ``Model.register_rv``.
218+
transform : optional
219+
See ``Model.register_rv``.
220+
**kwargs
221+
Keyword arguments that will be forwarded to ``.dist()``.
222+
Most prominently: ``shape`` and ``size``
223+
224+
Returns
225+
-------
226+
rv : RandomVariable
227+
The created RV, registered in the Model.
228+
"""
229+
132230
try:
133231
from pymc3.model import Model
134232

@@ -141,40 +239,125 @@ def __new__(cls, name, *args, **kwargs):
141239
"for a standalone distribution."
142240
)
143241

144-
rng = kwargs.pop("rng", None)
242+
if not isinstance(name, string_types):
243+
raise TypeError(f"Name needs to be a string but got: {name}")
145244

146245
if rng is None:
147246
rng = model.default_rng
148247

149-
if not isinstance(name, string_types):
150-
raise TypeError(f"Name needs to be a string but got: {name}")
248+
_, dims, _ = _validate_shape_dims_size(dims=dims)
249+
resize = None
151250

152-
data = kwargs.pop("observed", None)
251+
# Create the RV without specifying testval, because the testval may have a shape
252+
# that only matches after replicating with a size implied by dims (see below).
253+
rv_out = cls.dist(*args, rng=rng, testval=None, **kwargs)
254+
n_implied = rv_out.ndim
153255

154-
total_size = kwargs.pop("total_size", None)
256+
# `dims` are only available with this API, because `.dist()` can be used
257+
# without a modelcontext and dims are not tracked at the Aesara level.
258+
if dims is not None:
259+
if Ellipsis in dims:
260+
# Auto-complete the dims tuple to the full length
261+
dims = (*dims[:-1], *[None] * rv_out.ndim)
155262

156-
dims = kwargs.pop("dims", None)
263+
n_resize = len(dims) - n_implied
157264

158-
if "shape" in kwargs:
159-
raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.")
265+
# All resize dims must be known already (numerically or symbolically).
266+
unknown_resize_dims = set(dims[:n_resize]) - set(model.dim_lengths)
267+
if unknown_resize_dims:
268+
raise KeyError(
269+
f"Dimensions {unknown_resize_dims} are unknown to the model and cannot be used to specify a `size`."
270+
)
160271

161-
transform = kwargs.pop("transform", UNSET)
272+
# The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
273+
resize = tuple(model.dim_lengths[dname] for dname in dims[:n_resize])
274+
elif observed is not None:
275+
if not hasattr(observed, "shape"):
276+
observed = pandas_to_array(observed)
277+
n_resize = observed.ndim - n_implied
278+
resize = tuple(observed.shape[d] for d in range(n_resize))
279+
280+
if resize:
281+
# A batch size was specified through `dims`, or implied by `observed`.
282+
rv_out = change_rv_size(rv_var=rv_out, new_size=resize, expand=True)
283+
284+
if dims is not None:
285+
# Now that we have a handle on the output RV, we can register named implied dimensions that
286+
# were not yet known to the model, such that they can be used for size further downstream.
287+
for di, dname in enumerate(dims[n_resize:]):
288+
if not dname in model.dim_lengths:
289+
model.add_coord(dname, values=None, length=rv_out.shape[n_resize + di])
162290

163-
rv_out = cls.dist(*args, rng=rng, **kwargs)
291+
if testval is not None:
292+
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
293+
rv_out.tag.test_value = testval
164294

165-
return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform)
295+
return model.register_rv(rv_out, name, observed, total_size, dims=dims, transform=transform)
166296

167297
@classmethod
def dist(
    cls,
    dist_params,
    *,
    shape: Optional[Shape] = None,
    size: Optional[Size] = None,
    testval=None,
    **kwargs,
) -> RandomVariable:
    """Creates a RandomVariable corresponding to the `cls` distribution.

    Parameters
    ----------
    dist_params
        Distribution parameters; they are unpacked as positional arguments
        into ``cls.rv_op``.
    shape : int, list, tuple, Variable, optional
        A tuple of sizes for each dimension of the new RV.

        Ellipsis (...) may be used in the last position of the tuple,
        and automatically expand to the shape implied by RV inputs.
        A symbolic ``Variable`` shape currently adds no extra dimensions.
    size : int, list, tuple, optional
        A scalar or tuple for replicating the RV in addition
        to its implied shape/dimensionality.
    testval : optional
        Test value to be attached to the output RV.
        Must match its shape exactly.
    **kwargs
        Keyword arguments that are forwarded to ``cls.rv_op``.

    Returns
    -------
    rv : RandomVariable
        The created RV.

    Raises
    ------
    NotImplementedError
        If ``dims`` is passed as a keyword argument.
    ValueError
        If `shape` and `size` are both given or are otherwise invalid.
    """
    if "dims" in kwargs:
        raise NotImplementedError("The use of a `.dist(dims=...)` API is not yet supported.")

    shape, _, size = _validate_shape_dims_size(shape=shape, size=size)

    # Create the RV without specifying size or testval.
    # The size will be expanded later (if necessary) and only then the testval fits.
    rv_native = cls.rv_op(*dist_params, size=None, **kwargs)

    if shape is None and size is None:
        size = ()
    elif shape is not None:
        if isinstance(shape, Variable):
            size = ()
        elif Ellipsis in shape:
            # Everything before the Ellipsis is in addition to the implied dimensions.
            size = tuple(shape[:-1])
        else:
            # Only the leading entries beyond the implied dimensionality resize the RV.
            # Clamp at 0: a negative slice bound would otherwise select leading
            # entries even though `shape` is shorter than the implied ndim.
            n_batch = max(0, len(shape) - rv_native.ndim)
            size = tuple(shape[:n_batch])
    # no-op conditions:
    # `elif size is not None` (User already specified how to expand the RV)
    # `else` (Unreachable)

    if size:
        rv_out = change_rv_size(rv_var=rv_native, new_size=size, expand=True)
    else:
        rv_out = rv_native

    if testval is not None:
        rv_out.tag.test_value = testval

    return rv_out
178361

179362
def _distr_parameters_for_repr(self):
180363
"""Return the names of the parameters for this distribution (e.g. "mu"

0 commit comments

Comments
 (0)