Skip to content

Commit 65f4f2b

Browse files
Separate shape logic into a separate file (#4708)
Co-authored-by: Michael Osthege <[email protected]>
1 parent 784dec3 commit 65f4f2b

File tree

3 files changed

+326
-224
lines changed

3 files changed

+326
-224
lines changed

pymc3/distributions/distribution.py

+34-207
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,29 @@
1919
import warnings
2020

2121
from abc import ABCMeta
22-
from typing import Optional, Sequence, Tuple, Union
22+
from typing import Optional
2323

2424
import aesara
2525
import aesara.tensor as at
2626
import dill
27-
import numpy as np
2827

29-
from aesara.graph.basic import Variable
3028
from aesara.tensor.random.op import RandomVariable
3129
from aesara.tensor.random.var import RandomStateSharedVariable
32-
from aesara.tensor.var import TensorVariable
3330

34-
from pymc3.aesaraf import change_rv_size, pandas_to_array
31+
from pymc3.aesaraf import change_rv_size
3532
from pymc3.distributions import _logcdf, _logp
36-
from pymc3.exceptions import ShapeError, ShapeWarning
33+
from pymc3.distributions.shape_utils import (
34+
Dims,
35+
Shape,
36+
Size,
37+
convert_dims,
38+
convert_shape,
39+
convert_size,
40+
find_size,
41+
maybe_resize,
42+
resize_from_dims,
43+
resize_from_observed,
44+
)
3745
from pymc3.util import UNSET, get_repr_for_variable
3846
from pymc3.vartypes import string_types
3947

@@ -51,20 +59,6 @@
5159

5260
PLATFORM = sys.platform
5361

54-
# User-provided can be lazily specified as scalars
55-
Shape = Union[int, TensorVariable, Sequence[Union[int, TensorVariable, type(Ellipsis)]]]
56-
Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]]
57-
Size = Union[int, TensorVariable, Sequence[Union[int, TensorVariable]]]
58-
59-
# After conversion to vectors
60-
WeakShape = Union[TensorVariable, Tuple[Union[int, TensorVariable, type(Ellipsis)], ...]]
61-
WeakDims = Tuple[Union[str, None, type(Ellipsis)], ...]
62-
63-
# After Ellipsis were substituted
64-
StrongShape = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
65-
StrongDims = Sequence[Union[str, None]]
66-
StrongSize = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
67-
6862

6963
class _Unpickling:
7064
pass
@@ -120,135 +114,6 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
120114
return new_cls
121115

122116

123-
def _convert_dims(dims: Dims) -> Optional[WeakDims]:
124-
""" Process a user-provided dims variable into None or a valid dims tuple. """
125-
if dims is None:
126-
return None
127-
128-
if isinstance(dims, str):
129-
dims = (dims,)
130-
elif isinstance(dims, (list, tuple)):
131-
dims = tuple(dims)
132-
else:
133-
raise ValueError(f"The `dims` parameter must be a tuple, str or list. Actual: {type(dims)}")
134-
135-
if any(d == Ellipsis for d in dims[:-1]):
136-
raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}")
137-
138-
return dims
139-
140-
141-
def _convert_shape(shape: Shape) -> Optional[WeakShape]:
142-
""" Process a user-provided shape variable into None or a valid shape object. """
143-
if shape is None:
144-
return None
145-
146-
if isinstance(shape, int) or (isinstance(shape, TensorVariable) and shape.ndim == 0):
147-
shape = (shape,)
148-
elif isinstance(shape, (list, tuple)):
149-
shape = tuple(shape)
150-
else:
151-
raise ValueError(
152-
f"The `shape` parameter must be a tuple, TensorVariable, int or list. Actual: {type(shape)}"
153-
)
154-
155-
if isinstance(shape, tuple) and any(s == Ellipsis for s in shape[:-1]):
156-
raise ValueError(
157-
f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
158-
)
159-
160-
return shape
161-
162-
163-
def _convert_size(size: Size) -> Optional[StrongSize]:
164-
""" Process a user-provided size variable into None or a valid size object. """
165-
if size is None:
166-
return None
167-
168-
if isinstance(size, int) or (isinstance(size, TensorVariable) and size.ndim == 0):
169-
size = (size,)
170-
elif isinstance(size, (list, tuple)):
171-
size = tuple(size)
172-
else:
173-
raise ValueError(
174-
f"The `size` parameter must be a tuple, TensorVariable, int or list. Actual: {type(size)}"
175-
)
176-
177-
if isinstance(size, tuple) and Ellipsis in size:
178-
raise ValueError(f"The `size` parameter cannot contain an Ellipsis. Actual: {size}")
179-
180-
return size
181-
182-
183-
def _resize_from_dims(
184-
dims: WeakDims, ndim_implied: int, model
185-
) -> Tuple[int, StrongSize, StrongDims]:
186-
"""Determines a potential resize shape from a `dims` tuple.
187-
188-
Parameters
189-
----------
190-
dims : array-like
191-
A vector of dimension names, None or Ellipsis.
192-
ndim_implied : int
193-
Number of RV dimensions that were implied from its inputs alone.
194-
model : pm.Model
195-
The current model on stack.
196-
197-
Returns
198-
-------
199-
ndim_resize : int
200-
Number of dimensions that should be added through resizing.
201-
resize_shape : array-like
202-
The shape of the new dimensions.
203-
"""
204-
if Ellipsis in dims:
205-
# Auto-complete the dims tuple to the full length.
206-
# We don't have a way to know the names of implied
207-
# dimensions, so they will be `None`.
208-
dims = (*dims[:-1], *[None] * ndim_implied)
209-
210-
ndim_resize = len(dims) - ndim_implied
211-
212-
# All resize dims must be known already (numerically or symbolically).
213-
unknowndim_resize_dims = set(dims[:ndim_resize]) - set(model.dim_lengths)
214-
if unknowndim_resize_dims:
215-
raise KeyError(
216-
f"Dimensions {unknowndim_resize_dims} are unknown to the model and cannot be used to specify a `size`."
217-
)
218-
219-
# The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
220-
resize_shape = tuple(model.dim_lengths[dname] for dname in dims[:ndim_resize])
221-
return ndim_resize, resize_shape, dims
222-
223-
224-
def _resize_from_observed(
225-
observed, ndim_implied: int
226-
) -> Tuple[int, StrongSize, Union[np.ndarray, Variable]]:
227-
"""Determines a potential resize shape from observations.
228-
229-
Parameters
230-
----------
231-
observed : scalar, array-like
232-
The value of the `observed` kwarg to the RV creation.
233-
ndim_implied : int
234-
Number of RV dimensions that were implied from its inputs alone.
235-
236-
Returns
237-
-------
238-
ndim_resize : int
239-
Number of dimensions that should be added through resizing.
240-
resize_shape : array-like
241-
The shape of the new dimensions.
242-
observed : scalar, array-like
243-
Observations as numpy array or `Variable`.
244-
"""
245-
if not hasattr(observed, "shape"):
246-
observed = pandas_to_array(observed)
247-
ndim_resize = observed.ndim - ndim_implied
248-
resize_shape = tuple(observed.shape[d] for d in range(ndim_resize))
249-
return ndim_resize, resize_shape, observed
250-
251-
252117
class Distribution(metaclass=DistributionMeta):
253118
"""Statistical distribution"""
254119

@@ -335,7 +200,7 @@ def __new__(
335200
raise ValueError(
336201
f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!"
337202
)
338-
dims = _convert_dims(dims)
203+
dims = convert_dims(dims)
339204

340205
# Create the RV without specifying testval, because the testval may have a shape
341206
# that only matches after replicating with a size implied by dims (see below).
@@ -346,9 +211,9 @@ def __new__(
346211
# `dims` are only available with this API, because `.dist()` can be used
347212
# without a modelcontext and dims are not tracked at the Aesara level.
348213
if dims is not None:
349-
ndim_resize, resize_shape, dims = _resize_from_dims(dims, ndim_actual, model)
214+
ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
350215
elif observed is not None:
351-
ndim_resize, resize_shape, observed = _resize_from_observed(observed, ndim_actual)
216+
ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual)
352217

353218
if resize_shape:
354219
# A batch size was specified through `dims`, or implied by `observed`.
@@ -408,65 +273,27 @@ def dist(
408273
raise ValueError(
409274
f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!"
410275
)
411-
shape = _convert_shape(shape)
412-
size = _convert_size(size)
413-
414-
ndim_supp = cls.rv_op.ndim_supp
415-
ndim_expected = None
416-
ndim_batch = None
417-
create_size = None
418-
419-
if shape is not None:
420-
if Ellipsis in shape:
421-
# Ellipsis short-hands all implied dimensions. Therefore
422-
# we don't know how many dimensions to expect.
423-
ndim_expected = ndim_batch = None
424-
# Create the RV with its implied shape and resize later
425-
create_size = None
426-
else:
427-
ndim_expected = len(tuple(shape))
428-
ndim_batch = ndim_expected - ndim_supp
429-
create_size = tuple(shape)[:ndim_batch]
430-
elif size is not None:
431-
ndim_expected = ndim_supp + len(tuple(size))
432-
ndim_batch = ndim_expected - ndim_supp
433-
create_size = size
434276

277+
shape = convert_shape(shape)
278+
size = convert_size(size)
279+
280+
create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
281+
shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp
282+
)
435283
# Create the RV with a `size` right away.
436284
# This is not necessarily the final result.
437285
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
438-
ndim_actual = rv_out.ndim
439-
ndims_unexpected = ndim_actual != ndim_expected
440-
441-
if shape is not None and ndims_unexpected:
442-
if Ellipsis in shape:
443-
# Resize and we're done!
444-
rv_out = change_rv_size(rv_var=rv_out, new_size=shape[:-1], expand=True)
445-
else:
446-
# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
447-
# Recreate the RV without passing `size` to created it with just the implied dimensions.
448-
rv_out = cls.rv_op(*dist_params, size=None, **kwargs)
449-
450-
# Now resize by any remaining "extra" dimensions that were not implied from support and parameters
451-
if rv_out.ndim < ndim_expected:
452-
expand_shape = shape[: ndim_expected - rv_out.ndim]
453-
rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True)
454-
if not rv_out.ndim == ndim_expected:
455-
raise ShapeError(
456-
f"Failed to create the RV with the expected dimensionality. "
457-
f"This indicates a severe problem. Please open an issue.",
458-
actual=ndim_actual,
459-
expected=ndim_batch + ndim_supp,
460-
)
461-
462-
# Warn about the edge cases where the RV Op creates more dimensions than
463-
# it should based on `size` and `RVOp.ndim_supp`.
464-
if size is not None and ndims_unexpected:
465-
warnings.warn(
466-
f"You may have expected a ({len(tuple(size))}+{ndim_supp})-dimensional RV, but the resulting RV will be {ndim_actual}-dimensional."
467-
' To silence this warning use `warnings.simplefilter("ignore", pm.ShapeWarning)`.',
468-
ShapeWarning,
469-
)
286+
rv_out = maybe_resize(
287+
rv_out,
288+
cls.rv_op,
289+
dist_params,
290+
ndim_expected,
291+
ndim_batch,
292+
ndim_supp,
293+
shape,
294+
size,
295+
**kwargs,
296+
)
470297

471298
rng = kwargs.pop("rng", None)
472299
if (

0 commit comments

Comments
 (0)