@@ -219,3 +219,200 @@ def test_sample_generate_values(fixture_model, fixture_sizes):
219219 prior = pm .sample_prior_predictive (samples = fixture_sizes )
220220 for rv in RVs :
221221 assert prior [rv .name ].shape == size + tuple (rv .distribution .shape )
222+
223+
224+ class TestShapeDimsSize :
225+ @pytest .mark .parametrize ("param_shape" , [(), (3 ,)])
226+ @pytest .mark .parametrize ("batch_shape" , [(), (3 ,)])
227+ @pytest .mark .parametrize (
228+ "parametrization" ,
229+ [
230+ "implicit" ,
231+ "shape" ,
232+ "shape..." ,
233+ "dims" ,
234+ "dims..." ,
235+ "size" ,
236+ ],
237+ )
238+ def test_param_and_batch_shape_combos (
239+ self , param_shape : tuple , batch_shape : tuple , parametrization : str
240+ ):
241+ coords = {}
242+ param_dims = []
243+ batch_dims = []
244+
245+ # Create coordinates corresponding to the parameter shape
246+ for d in param_shape :
247+ dname = f"param_dim_{ d } "
248+ coords [dname ] = [f"c_{ i } " for i in range (d )]
249+ param_dims .append (dname )
250+ assert len (param_dims ) == len (param_shape )
251+ # Create coordinates corresponding to the batch shape
252+ for d in batch_shape :
253+ dname = f"batch_dim_{ d } "
254+ coords [dname ] = [f"c_{ i } " for i in range (d )]
255+ batch_dims .append (dname )
256+ assert len (batch_dims ) == len (batch_shape )
257+
258+ with pm .Model (coords = coords ) as pmodel :
259+ mu = aesara .shared (np .random .normal (size = param_shape ))
260+
261+ with pytest .warns (None ):
262+ if parametrization == "implicit" :
263+ rv = pm .Normal ("rv" , mu = mu ).shape == param_shape
264+ else :
265+ if parametrization == "shape" :
266+ rv = pm .Normal ("rv" , mu = mu , shape = batch_shape + param_shape )
267+ assert rv .eval ().shape == batch_shape + param_shape
268+ elif parametrization == "shape..." :
269+ rv = pm .Normal ("rv" , mu = mu , shape = (* batch_shape , ...))
270+ assert rv .eval ().shape == batch_shape + param_shape
271+ elif parametrization == "dims" :
272+ rv = pm .Normal ("rv" , mu = mu , dims = batch_dims + param_dims )
273+ assert rv .eval ().shape == batch_shape + param_shape
274+ elif parametrization == "dims..." :
275+ rv = pm .Normal ("rv" , mu = mu , dims = (* batch_dims , ...))
276+ n_size = len (batch_shape )
277+ n_implied = len (param_shape )
278+ ndim = n_size + n_implied
279+ assert len (pmodel .RV_dims ["rv" ]) == ndim , pmodel .RV_dims
280+ assert len (pmodel .RV_dims ["rv" ][:n_size ]) == len (batch_dims )
281+ assert len (pmodel .RV_dims ["rv" ][n_size :]) == len (param_dims )
282+ if n_implied > 0 :
283+ assert pmodel .RV_dims ["rv" ][- 1 ] is None
284+ elif parametrization == "size" :
285+ rv = pm .Normal ("rv" , mu = mu , size = batch_shape )
286+ assert rv .eval ().shape == batch_shape + param_shape
287+ else :
288+ raise NotImplementedError ("Invalid test case parametrization." )
289+
290+ def test_define_dims_on_the_fly (self ):
291+ with pm .Model () as pmodel :
292+ agedata = aesara .shared (np .array ([10 , 20 , 30 ]))
293+
294+ # Associate the "patient" dim with an implied dimension
295+ age = pm .Normal ("age" , agedata , dims = ("patient" ,))
296+ assert "patient" in pmodel .dim_lengths
297+ assert pmodel .dim_lengths ["patient" ].eval () == 3
298+
299+ # Use the dim to replicate a new RV
300+ effect = pm .Normal ("effect" , 0 , dims = ("patient" ,))
301+ assert effect .ndim == 1
302+ assert effect .eval ().shape == (3 ,)
303+
304+ # Now change the length of the implied dimension
305+ agedata .set_value ([1 , 2 , 3 , 4 ])
306+ # The change should propagate all the way through
307+ assert effect .eval ().shape == (4 ,)
308+
309+ @pytest .mark .xfail (reason = "Simultaneous use of size and dims is not implemented" )
310+ def test_data_defined_size_dimension_can_register_dimname (self ):
311+ with pm .Model () as pmodel :
312+ x = pm .Data ("x" , [[1 , 2 , 3 , 4 ]], dims = ("first" , "second" ))
313+ assert "first" in pmodel .dim_lengths
314+ assert "second" in pmodel .dim_lengths
315+ # two dimensions are implied; a "third" dimension is created
316+ y = pm .Normal ("y" , mu = x , size = 2 , dims = ("third" , "first" , "second" ))
317+ assert "third" in pmodel .dim_lengths
318+ assert y .eval ().shape () == (2 , 1 , 4 )
319+
320+ def test_can_resize_data_defined_size (self ):
321+ with pm .Model () as pmodel :
322+ x = pm .Data ("x" , [[1 , 2 , 3 , 4 ]], dims = ("first" , "second" ))
323+ y = pm .Normal ("y" , mu = 0 , dims = ("first" , "second" ))
324+ z = pm .Normal ("z" , mu = y , observed = np .ones ((1 , 4 )))
325+ assert x .eval ().shape == (1 , 4 )
326+ assert y .eval ().shape == (1 , 4 )
327+ assert z .eval ().shape == (1 , 4 )
328+ assert "first" in pmodel .dim_lengths
329+ assert "second" in pmodel .dim_lengths
330+ pmodel .set_data ("x" , [[1 , 2 ], [3 , 4 ], [5 , 6 ]])
331+ assert x .eval ().shape == (3 , 2 )
332+ assert y .eval ().shape == (3 , 2 )
333+ assert z .eval ().shape == (3 , 2 )
334+
335+ @pytest .mark .xfail (
336+ condition = sys .platform == "win32" ,
337+ reason = "See https://github.com/pymc-devs/pymc3/issues/4652." ,
338+ )
339+ def test_observed_with_column_vector (self ):
340+ with pm .Model () as model :
341+ pm .Normal ("x1" , mu = 0 , sd = 1 , observed = np .random .normal (size = (3 , 4 )))
342+ model .logp ()
343+ pm .Normal ("x2" , mu = 0 , sd = 1 , observed = np .random .normal (size = (3 , 1 )))
344+ model .logp ()
345+
346+ def test_dist_api_works (self ):
347+ mu = aesara .shared (np .array ([1 , 2 , 3 ]))
348+ with pytest .raises (NotImplementedError , match = "API is not yet supported" ):
349+ pm .Normal .dist (mu = mu , dims = ("town" ,))
350+ assert pm .Normal .dist (mu = mu , shape = (3 ,)).eval ().shape == (3 ,)
351+ assert pm .Normal .dist (mu = mu , shape = (5 , 3 )).eval ().shape == (5 , 3 )
352+ assert pm .Normal .dist (mu = mu , shape = (7 , ...)).eval ().shape == (7 , 3 )
353+ assert pm .Normal .dist (mu = mu , size = (4 ,)).eval ().shape == (4 , 3 )
354+
355+ def test_auto_assert_shape (self ):
356+ with pytest .raises (AssertionError , match = "will never match" ):
357+ pm .Normal .dist (mu = [1 , 2 ], shape = [])
358+
359+ mu = at .vector (name = "mu_input" )
360+ rv = pm .Normal .dist (mu = mu , shape = [3 , 4 ])
361+ f = aesara .function ([mu ], rv , mode = aesara .Mode ("py" ))
362+ assert f ([1 , 2 , 3 , 4 ]).shape == (3 , 4 )
363+
364+ with pytest .raises (AssertionError , match = r"Got shape \(3, 2\), expected \(3, 4\)." ):
365+ f ([1 , 2 ])
366+
367+ # The `shape` can be symbolic!
368+ s = at .vector (dtype = "int32" )
369+ rv = pm .Uniform .dist (2 , [4 , 5 ], shape = s )
370+ f = aesara .function ([s ], rv , mode = aesara .Mode ("py" ))
371+ f (
372+ [
373+ 2 ,
374+ ]
375+ )
376+ with pytest .raises (
377+ AssertionError ,
378+ match = r"Got 1 dimensions \(shape \(2,\)\), expected 2 dimensions with shape \(3, 4\)." ,
379+ ):
380+ f ([3 , 4 ])
381+ with pytest .raises (
382+ AssertionError ,
383+ match = r"Got 1 dimensions \(shape \(2,\)\), expected 0 dimensions with shape \(\)." ,
384+ ):
385+ f ([])
386+ pass
387+
388+ def test_lazy_flavors (self ):
389+
390+ _validate_shape_dims_size (shape = 5 )
391+ _validate_shape_dims_size (dims = "town" )
392+ _validate_shape_dims_size (size = 7 )
393+
394+ assert pm .Uniform .dist (2 , [4 , 5 ], size = [3 , 4 ]).eval ().shape == (3 , 4 , 2 )
395+ assert pm .Uniform .dist (2 , [4 , 5 ], shape = [3 , 2 ]).eval ().shape == (3 , 2 )
396+ with pm .Model (coords = dict (town = ["Greifswald" , "Madrid" ])):
397+ assert pm .Normal ("n2" , mu = [1 , 2 ], dims = ("town" ,)).eval ().shape == (2 ,)
398+
399+ def test_invalid_flavors (self ):
400+ # redundant parametrizations
401+ with pytest .raises (ValueError , match = "Passing both" ):
402+ _validate_shape_dims_size (shape = (2 ,), dims = ("town" ,))
403+ with pytest .raises (ValueError , match = "Passing both" ):
404+ _validate_shape_dims_size (dims = ("town" ,), size = (2 ,))
405+ with pytest .raises (ValueError , match = "Passing both" ):
406+ _validate_shape_dims_size (shape = (3 ,), size = (3 ,))
407+
408+ # invalid, but not necessarly rare
409+ with pytest .raises (ValueError , match = "must be an int, list or tuple" ):
410+ _validate_shape_dims_size (size = "notasize" )
411+
412+ # invalid ellipsis positions
413+ with pytest .raises (ValueError , match = "may only appear in the last position" ):
414+ _validate_shape_dims_size (shape = (3 , ..., 2 ))
415+ with pytest .raises (ValueError , match = "may only appear in the last position" ):
416+ _validate_shape_dims_size (dims = (..., "town" ))
417+ with pytest .raises (ValueError , match = "cannot contain" ):
418+ _validate_shape_dims_size (size = (3 , ...))
0 commit comments