20
20
21
21
from abc import ABCMeta
22
22
from copy import copy
23
- from typing import TYPE_CHECKING
23
+ from typing import Any , Optional , Sequence , Tuple , Union
24
24
25
+ import aesara
26
+ import aesara .tensor as at
25
27
import dill
26
28
29
+ from aesara .graph .basic import Variable
27
30
from aesara .tensor .random .op import RandomVariable
28
31
32
+ from pymc3 .aesaraf import change_rv_size , pandas_to_array
29
33
from pymc3 .distributions import _logcdf , _logp
30
-
31
- if TYPE_CHECKING :
32
- from typing import Optional , Callable
33
-
34
- import aesara
35
- import aesara .graph .basic
36
- import aesara .tensor as at
37
-
38
34
from pymc3 .util import UNSET , get_repr_for_variable
39
35
from pymc3 .vartypes import string_types
40
36
52
48
53
49
PLATFORM = sys .platform
54
50
51
# Type aliases for the user-facing `shape`, `dims` and `size` parameters.
# `shape` entries are dimension lengths (integers), optionally ending with an Ellipsis;
# a symbolic `Variable` may be passed instead of a sequence.
Shape = Union[int, Sequence[Union[int, type(Ellipsis)]], Variable]
# `dims` entries are dimension names; `None` marks an unnamed dimension.
Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]]
# `size` is a plain integer or a tuple of integers.
Size = Union[int, Tuple[int, ...]]
54
+
55
55
56
56
class _Unpickling:
    """Empty class: it defines no attributes or methods of its own."""

    # NOTE(review): presumably a sentinel used while unpickling distributions --
    # confirm against its usages elsewhere in the file.
@@ -122,13 +122,111 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
122
122
return new_cls
123
123
124
124
125
def _valid_ellipsis_position(items: Union[None, Shape, Dims, Size]) -> bool:
    """Tell whether an Ellipsis, if present in `items`, sits only in the last position.

    Parameters
    ----------
    items : None, Variable or sequence
        The `shape`, `dims` or `size` value to inspect.

    Returns
    -------
    bool
        ``False`` when an Ellipsis occurs anywhere before the last entry.
        ``True`` otherwise, including when `items` is ``None``, is a ``Variable``,
        or contains no Ellipsis at all.
    """
    # Nothing to inspect: `None` and symbolic variables are accepted as they are.
    if items is None or isinstance(items, Variable):
        return True
    if Ellipsis not in items:
        return True
    # An Ellipsis is present; it must not occur in any position but the last.
    return not any(entry == Ellipsis for entry in items[:-1])
130
+
131
+
132
def _validate_shape_dims_size(
    shape: Any = None, dims: Any = None, size: Any = None
) -> Tuple[Optional[Shape], Optional[Dims], Optional[Size]]:
    """Validate the `shape`, `dims` and `size` parameters and normalize them to tuples.

    At most one of the three parameters may be given. An ``int`` (`shape`, `size`) or
    ``str`` (`dims`) is wrapped into a 1-tuple and lists are converted to tuples.
    A ``Variable`` passed as `shape` is returned unchanged.

    Parameters
    ----------
    shape : int, list, tuple or Variable, optional
    dims : str, list or tuple, optional
    size : int, list or tuple, optional

    Returns
    -------
    shape, dims, size
        The normalized parameters; each is ``None`` if it was not given.

    Raises
    ------
    ValueError
        If more than one of the parameters is given, if a parameter has an
        unsupported type, if an Ellipsis appears anywhere but the last position
        of `shape` or `dims`, or if `size` contains an Ellipsis.
    """
    # Raise on unsupported parametrization
    # (the messages are f-strings so that the offending values are actually shown)
    if shape is not None and dims is not None:
        raise ValueError(f"Passing both `shape` ({shape}) and `dims` ({dims}) is not supported!")
    if dims is not None and size is not None:
        raise ValueError(f"Passing both `dims` ({dims}) and `size` ({size}) is not supported!")
    if shape is not None and size is not None:
        raise ValueError(f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!")

    # Raise on invalid types
    if not isinstance(shape, (type(None), int, list, tuple, Variable)):
        raise ValueError("The `shape` parameter must be an int, list or tuple.")
    if not isinstance(dims, (type(None), str, list, tuple)):
        raise ValueError("The `dims` parameter must be a str, list or tuple.")
    if not isinstance(size, (type(None), int, list, tuple)):
        raise ValueError("The `size` parameter must be an int, list or tuple.")

    # Auto-convert non-tupled parameters
    if isinstance(shape, int):
        shape = (shape,)
    if isinstance(dims, str):
        dims = (dims,)
    if isinstance(size, int):
        size = (size,)

    # Convert to actual tuples (a symbolic `shape` is passed through untouched)
    if not isinstance(shape, (type(None), tuple, Variable)):
        shape = tuple(shape)
    if not isinstance(dims, (type(None), tuple)):
        dims = tuple(dims)
    if not isinstance(size, (type(None), tuple)):
        size = tuple(size)

    if not _valid_ellipsis_position(shape):
        raise ValueError(
            f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
        )
    if not _valid_ellipsis_position(dims):
        raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}")
    if size is not None and Ellipsis in size:
        raise ValueError(f"The `size` parameter cannot contain an Ellipsis. Actual: {size}")
    return shape, dims, size
176
+
177
+
125
178
class Distribution (metaclass = DistributionMeta ):
126
179
"""Statistical distribution"""
127
180
128
181
rv_class = None
129
182
rv_op = None
130
183
131
- def __new__ (cls , name , * args , ** kwargs ):
184
+ def __new__ (
185
+ cls ,
186
+ name : str ,
187
+ * args ,
188
+ rng = None ,
189
+ dims : Optional [Dims ] = None ,
190
+ testval = None ,
191
+ observed = None ,
192
+ total_size = None ,
193
+ transform = UNSET ,
194
+ ** kwargs ,
195
+ ) -> RandomVariable :
196
+ """Adds a RandomVariable corresponding to a PyMC3 distribution to the current model.
197
+
198
+ Note that all remaining kwargs must be compatible with ``.dist()``
199
+
200
+ Parameters
201
+ ----------
202
+ cls : type
203
+ A PyMC3 distribution.
204
+ name : str
205
+ Name for the new model variable.
206
+ rng : optional
207
+ Random number generator to use with the RandomVariable.
208
+ dims : tuple, optional
209
+ A tuple of dimension names known to the model.
210
+ testval : optional
211
+ Test value to be attached to the output RV.
212
+ Must match its shape exactly.
213
+ observed : optional
214
+ Observed data to be passed when registering the random variable in the model.
215
+ See ``Model.register_rv``.
216
+ total_size : float, optional
217
+ See ``Model.register_rv``.
218
+ transform : optional
219
+ See ``Model.register_rv``.
220
+ **kwargs
221
+ Keyword arguments that will be forwarded to ``.dist()``.
222
+ Most prominently: ``shape`` and ``size``
223
+
224
+ Returns
225
+ -------
226
+ rv : RandomVariable
227
+ The created RV, registered in the Model.
228
+ """
229
+
132
230
try :
133
231
from pymc3 .model import Model
134
232
@@ -141,40 +239,125 @@ def __new__(cls, name, *args, **kwargs):
141
239
"for a standalone distribution."
142
240
)
143
241
144
- rng = kwargs .pop ("rng" , None )
242
+ if not isinstance (name , string_types ):
243
+ raise TypeError (f"Name needs to be a string but got: { name } " )
145
244
146
245
if rng is None :
147
246
rng = model .default_rng
148
247
149
- if not isinstance ( name , string_types ):
150
- raise TypeError ( f"Name needs to be a string but got: { name } " )
248
+ _ , dims , _ = _validate_shape_dims_size ( dims = dims )
249
+ resize = None
151
250
152
- data = kwargs .pop ("observed" , None )
251
+ # Create the RV without specifying testval, because the testval may have a shape
252
+ # that only matches after replicating with a size implied by dims (see below).
253
+ rv_out = cls .dist (* args , rng = rng , testval = None , ** kwargs )
254
+ n_implied = rv_out .ndim
153
255
154
- total_size = kwargs .pop ("total_size" , None )
256
+ # `dims` are only available with this API, because `.dist()` can be used
257
+ # without a modelcontext and dims are not tracked at the Aesara level.
258
+ if dims is not None :
259
+ if Ellipsis in dims :
260
+ # Auto-complete the dims tuple to the full length
261
+ dims = (* dims [:- 1 ], * [None ] * rv_out .ndim )
155
262
156
- dims = kwargs . pop ( " dims" , None )
263
+ n_resize = len ( dims ) - n_implied
157
264
158
- if "shape" in kwargs :
159
- raise DeprecationWarning ("The `shape` keyword is deprecated; use `size`." )
265
+ # All resize dims must be known already (numerically or symbolically).
266
+ unknown_resize_dims = set (dims [:n_resize ]) - set (model .dim_lengths )
267
+ if unknown_resize_dims :
268
+ raise KeyError (
269
+ f"Dimensions { unknown_resize_dims } are unknown to the model and cannot be used to specify a `size`."
270
+ )
160
271
161
- transform = kwargs .pop ("transform" , UNSET )
272
+ # The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
273
+ resize = tuple (model .dim_lengths [dname ] for dname in dims [:n_resize ])
274
+ elif observed is not None :
275
+ if not hasattr (observed , "shape" ):
276
+ observed = pandas_to_array (observed )
277
+ n_resize = observed .ndim - n_implied
278
+ resize = tuple (observed .shape [d ] for d in range (n_resize ))
279
+
280
+ if resize :
281
+ # A batch size was specified through `dims`, or implied by `observed`.
282
+ rv_out = change_rv_size (rv_var = rv_out , new_size = resize , expand = True )
283
+
284
+ if dims is not None :
285
+ # Now that we have a handle on the output RV, we can register named implied dimensions that
286
+ # were not yet known to the model, such that they can be used for size further downstream.
287
+ for di , dname in enumerate (dims [n_resize :]):
288
+ if not dname in model .dim_lengths :
289
+ model .add_coord (dname , values = None , length = rv_out .shape [n_resize + di ])
162
290
163
- rv_out = cls .dist (* args , rng = rng , ** kwargs )
291
+ if testval is not None :
292
+ # Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
293
+ rv_out .tag .test_value = testval
164
294
165
- return model .register_rv (rv_out , name , data , total_size , dims = dims , transform = transform )
295
+ return model .register_rv (rv_out , name , observed , total_size , dims = dims , transform = transform )
166
296
167
297
@classmethod
def dist(
    cls,
    dist_params,
    *,
    shape: Optional[Shape] = None,
    size: Optional[Size] = None,
    testval=None,
    **kwargs,
) -> RandomVariable:
    """Create a RandomVariable for the `cls` distribution without registering it in a model.

    Parameters
    ----------
    dist_params
        Positional parameters that are unpacked into ``cls.rv_op``.
    shape : tuple, optional
        A tuple of sizes for each dimension of the new RV.

        An Ellipsis (...) may be used in the last position of the tuple; the entries
        before it are then used as the size by which the RV is replicated.
    size : int, tuple, optional
        A scalar or tuple for replicating the RV in addition
        to its implied shape/dimensionality.
    testval : optional
        Test value to be attached to the output RV.
        Must match its shape exactly.
    **kwargs
        Further keyword arguments forwarded to ``cls.rv_op``. ``dims`` is rejected.

    Returns
    -------
    rv : RandomVariable
        The created RV.

    Raises
    ------
    NotImplementedError
        If ``dims`` is passed as a keyword argument.
    """
    if "dims" in kwargs:
        raise NotImplementedError("The use of a `.dist(dims=...)` API is not yet supported.")

    shape, _, size = _validate_shape_dims_size(shape=shape, size=size)

    # The RV is first built with its implied shape only; replication happens afterwards,
    # and the test value is attached last because it only fits the final shape.
    base_rv = cls.rv_op(*dist_params, size=None, **kwargs)

    if shape is not None:
        if isinstance(shape, Variable):
            size = ()
        elif Ellipsis in shape:
            size = tuple(shape[:-1])
        else:
            # NOTE(review): if len(shape) < base_rv.ndim the stop index is negative and the
            # slice keeps leading entries -- confirm whether such shapes should be rejected.
            n_batch = len(shape) - base_rv.ndim
            size = tuple(shape[:n_batch])
    elif size is None:
        size = ()
    # Otherwise the caller gave an explicit `size`, which is used as it is.

    rv_out = change_rv_size(rv_var=base_rv, new_size=size, expand=True) if size else base_rv

    if testval is not None:
        rv_out.tag.test_value = testval

    return rv_out
178
361
179
362
def _distr_parameters_for_repr (self ):
180
363
"""Return the names of the parameters for this distribution (e.g. "mu"
0 commit comments