19
19
import warnings
20
20
21
21
from abc import ABCMeta
22
- from typing import Optional , Sequence , Tuple , Union
22
+ from typing import Optional
23
23
24
24
import aesara
25
25
import aesara .tensor as at
26
26
import dill
27
- import numpy as np
28
27
29
- from aesara .graph .basic import Variable
30
28
from aesara .tensor .random .op import RandomVariable
31
29
from aesara .tensor .random .var import RandomStateSharedVariable
32
- from aesara .tensor .var import TensorVariable
33
30
34
- from pymc3 .aesaraf import change_rv_size , pandas_to_array
31
+ from pymc3 .aesaraf import change_rv_size
35
32
from pymc3 .distributions import _logcdf , _logp
36
- from pymc3 .exceptions import ShapeError , ShapeWarning
33
+ from pymc3 .distributions .shape_utils import (
34
+ Dims ,
35
+ Shape ,
36
+ Size ,
37
+ convert_dims ,
38
+ convert_shape ,
39
+ convert_size ,
40
+ find_size ,
41
+ maybe_resize ,
42
+ resize_from_dims ,
43
+ resize_from_observed ,
44
+ )
37
45
from pymc3 .util import UNSET , get_repr_for_variable
38
46
from pymc3 .vartypes import string_types
39
47
51
59
52
60
PLATFORM = sys .platform
53
61
54
- # User-provided can be lazily specified as scalars
55
- Shape = Union [int , TensorVariable , Sequence [Union [int , TensorVariable , type (Ellipsis )]]]
56
- Dims = Union [str , Sequence [Union [str , None , type (Ellipsis )]]]
57
- Size = Union [int , TensorVariable , Sequence [Union [int , TensorVariable ]]]
58
-
59
- # After conversion to vectors
60
- WeakShape = Union [TensorVariable , Tuple [Union [int , TensorVariable , type (Ellipsis )], ...]]
61
- WeakDims = Tuple [Union [str , None , type (Ellipsis )], ...]
62
-
63
- # After Ellipsis were substituted
64
- StrongShape = Union [TensorVariable , Tuple [Union [int , TensorVariable ], ...]]
65
- StrongDims = Sequence [Union [str , None ]]
66
- StrongSize = Union [TensorVariable , Tuple [Union [int , TensorVariable ], ...]]
67
-
68
62
69
63
class _Unpickling :
70
64
pass
@@ -120,135 +114,6 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
120
114
return new_cls
121
115
122
116
123
- def _convert_dims (dims : Dims ) -> Optional [WeakDims ]:
124
- """ Process a user-provided dims variable into None or a valid dims tuple. """
125
- if dims is None :
126
- return None
127
-
128
- if isinstance (dims , str ):
129
- dims = (dims ,)
130
- elif isinstance (dims , (list , tuple )):
131
- dims = tuple (dims )
132
- else :
133
- raise ValueError (f"The `dims` parameter must be a tuple, str or list. Actual: { type (dims )} " )
134
-
135
- if any (d == Ellipsis for d in dims [:- 1 ]):
136
- raise ValueError (f"Ellipsis in `dims` may only appear in the last position. Actual: { dims } " )
137
-
138
- return dims
139
-
140
-
141
- def _convert_shape (shape : Shape ) -> Optional [WeakShape ]:
142
- """ Process a user-provided shape variable into None or a valid shape object. """
143
- if shape is None :
144
- return None
145
-
146
- if isinstance (shape , int ) or (isinstance (shape , TensorVariable ) and shape .ndim == 0 ):
147
- shape = (shape ,)
148
- elif isinstance (shape , (list , tuple )):
149
- shape = tuple (shape )
150
- else :
151
- raise ValueError (
152
- f"The `shape` parameter must be a tuple, TensorVariable, int or list. Actual: { type (shape )} "
153
- )
154
-
155
- if isinstance (shape , tuple ) and any (s == Ellipsis for s in shape [:- 1 ]):
156
- raise ValueError (
157
- f"Ellipsis in `shape` may only appear in the last position. Actual: { shape } "
158
- )
159
-
160
- return shape
161
-
162
-
163
- def _convert_size (size : Size ) -> Optional [StrongSize ]:
164
- """ Process a user-provided size variable into None or a valid size object. """
165
- if size is None :
166
- return None
167
-
168
- if isinstance (size , int ) or (isinstance (size , TensorVariable ) and size .ndim == 0 ):
169
- size = (size ,)
170
- elif isinstance (size , (list , tuple )):
171
- size = tuple (size )
172
- else :
173
- raise ValueError (
174
- f"The `size` parameter must be a tuple, TensorVariable, int or list. Actual: { type (size )} "
175
- )
176
-
177
- if isinstance (size , tuple ) and Ellipsis in size :
178
- raise ValueError (f"The `size` parameter cannot contain an Ellipsis. Actual: { size } " )
179
-
180
- return size
181
-
182
-
183
- def _resize_from_dims (
184
- dims : WeakDims , ndim_implied : int , model
185
- ) -> Tuple [int , StrongSize , StrongDims ]:
186
- """Determines a potential resize shape from a `dims` tuple.
187
-
188
- Parameters
189
- ----------
190
- dims : array-like
191
- A vector of dimension names, None or Ellipsis.
192
- ndim_implied : int
193
- Number of RV dimensions that were implied from its inputs alone.
194
- model : pm.Model
195
- The current model on stack.
196
-
197
- Returns
198
- -------
199
- ndim_resize : int
200
- Number of dimensions that should be added through resizing.
201
- resize_shape : array-like
202
- The shape of the new dimensions.
203
- """
204
- if Ellipsis in dims :
205
- # Auto-complete the dims tuple to the full length.
206
- # We don't have a way to know the names of implied
207
- # dimensions, so they will be `None`.
208
- dims = (* dims [:- 1 ], * [None ] * ndim_implied )
209
-
210
- ndim_resize = len (dims ) - ndim_implied
211
-
212
- # All resize dims must be known already (numerically or symbolically).
213
- unknowndim_resize_dims = set (dims [:ndim_resize ]) - set (model .dim_lengths )
214
- if unknowndim_resize_dims :
215
- raise KeyError (
216
- f"Dimensions { unknowndim_resize_dims } are unknown to the model and cannot be used to specify a `size`."
217
- )
218
-
219
- # The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
220
- resize_shape = tuple (model .dim_lengths [dname ] for dname in dims [:ndim_resize ])
221
- return ndim_resize , resize_shape , dims
222
-
223
-
224
- def _resize_from_observed (
225
- observed , ndim_implied : int
226
- ) -> Tuple [int , StrongSize , Union [np .ndarray , Variable ]]:
227
- """Determines a potential resize shape from observations.
228
-
229
- Parameters
230
- ----------
231
- observed : scalar, array-like
232
- The value of the `observed` kwarg to the RV creation.
233
- ndim_implied : int
234
- Number of RV dimensions that were implied from its inputs alone.
235
-
236
- Returns
237
- -------
238
- ndim_resize : int
239
- Number of dimensions that should be added through resizing.
240
- resize_shape : array-like
241
- The shape of the new dimensions.
242
- observed : scalar, array-like
243
- Observations as numpy array or `Variable`.
244
- """
245
- if not hasattr (observed , "shape" ):
246
- observed = pandas_to_array (observed )
247
- ndim_resize = observed .ndim - ndim_implied
248
- resize_shape = tuple (observed .shape [d ] for d in range (ndim_resize ))
249
- return ndim_resize , resize_shape , observed
250
-
251
-
252
117
class Distribution (metaclass = DistributionMeta ):
253
118
"""Statistical distribution"""
254
119
@@ -335,7 +200,7 @@ def __new__(
335
200
raise ValueError (
336
201
f"Passing both `dims` ({ dims } ) and `size` ({ kwargs ['size' ]} ) is not supported!"
337
202
)
338
- dims = _convert_dims (dims )
203
+ dims = convert_dims (dims )
339
204
340
205
# Create the RV without specifying testval, because the testval may have a shape
341
206
# that only matches after replicating with a size implied by dims (see below).
@@ -346,9 +211,9 @@ def __new__(
346
211
# `dims` are only available with this API, because `.dist()` can be used
347
212
# without a modelcontext and dims are not tracked at the Aesara level.
348
213
if dims is not None :
349
- ndim_resize , resize_shape , dims = _resize_from_dims (dims , ndim_actual , model )
214
+ ndim_resize , resize_shape , dims = resize_from_dims (dims , ndim_actual , model )
350
215
elif observed is not None :
351
- ndim_resize , resize_shape , observed = _resize_from_observed (observed , ndim_actual )
216
+ ndim_resize , resize_shape , observed = resize_from_observed (observed , ndim_actual )
352
217
353
218
if resize_shape :
354
219
# A batch size was specified through `dims`, or implied by `observed`.
@@ -408,65 +273,27 @@ def dist(
408
273
raise ValueError (
409
274
f"Passing both `shape` ({ shape } ) and `size` ({ size } ) is not supported!"
410
275
)
411
- shape = _convert_shape (shape )
412
- size = _convert_size (size )
413
-
414
- ndim_supp = cls .rv_op .ndim_supp
415
- ndim_expected = None
416
- ndim_batch = None
417
- create_size = None
418
-
419
- if shape is not None :
420
- if Ellipsis in shape :
421
- # Ellipsis short-hands all implied dimensions. Therefore
422
- # we don't know how many dimensions to expect.
423
- ndim_expected = ndim_batch = None
424
- # Create the RV with its implied shape and resize later
425
- create_size = None
426
- else :
427
- ndim_expected = len (tuple (shape ))
428
- ndim_batch = ndim_expected - ndim_supp
429
- create_size = tuple (shape )[:ndim_batch ]
430
- elif size is not None :
431
- ndim_expected = ndim_supp + len (tuple (size ))
432
- ndim_batch = ndim_expected - ndim_supp
433
- create_size = size
434
276
277
+ shape = convert_shape (shape )
278
+ size = convert_size (size )
279
+
280
+ create_size , ndim_expected , ndim_batch , ndim_supp = find_size (
281
+ shape = shape , size = size , ndim_supp = cls .rv_op .ndim_supp
282
+ )
435
283
# Create the RV with a `size` right away.
436
284
# This is not necessarily the final result.
437
285
rv_out = cls .rv_op (* dist_params , size = create_size , ** kwargs )
438
- ndim_actual = rv_out .ndim
439
- ndims_unexpected = ndim_actual != ndim_expected
440
-
441
- if shape is not None and ndims_unexpected :
442
- if Ellipsis in shape :
443
- # Resize and we're done!
444
- rv_out = change_rv_size (rv_var = rv_out , new_size = shape [:- 1 ], expand = True )
445
- else :
446
- # This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
447
- # Recreate the RV without passing `size` to created it with just the implied dimensions.
448
- rv_out = cls .rv_op (* dist_params , size = None , ** kwargs )
449
-
450
- # Now resize by any remaining "extra" dimensions that were not implied from support and parameters
451
- if rv_out .ndim < ndim_expected :
452
- expand_shape = shape [: ndim_expected - rv_out .ndim ]
453
- rv_out = change_rv_size (rv_var = rv_out , new_size = expand_shape , expand = True )
454
- if not rv_out .ndim == ndim_expected :
455
- raise ShapeError (
456
- f"Failed to create the RV with the expected dimensionality. "
457
- f"This indicates a severe problem. Please open an issue." ,
458
- actual = ndim_actual ,
459
- expected = ndim_batch + ndim_supp ,
460
- )
461
-
462
- # Warn about the edge cases where the RV Op creates more dimensions than
463
- # it should based on `size` and `RVOp.ndim_supp`.
464
- if size is not None and ndims_unexpected :
465
- warnings .warn (
466
- f"You may have expected a ({ len (tuple (size ))} +{ ndim_supp } )-dimensional RV, but the resulting RV will be { ndim_actual } -dimensional."
467
- ' To silence this warning use `warnings.simplefilter("ignore", pm.ShapeWarning)`.' ,
468
- ShapeWarning ,
469
- )
286
+ rv_out = maybe_resize (
287
+ rv_out ,
288
+ cls .rv_op ,
289
+ dist_params ,
290
+ ndim_expected ,
291
+ ndim_batch ,
292
+ ndim_supp ,
293
+ shape ,
294
+ size ,
295
+ ** kwargs ,
296
+ )
470
297
471
298
rng = kwargs .pop ("rng" , None )
472
299
if (
0 commit comments