@@ -219,3 +219,200 @@ def test_sample_generate_values(fixture_model, fixture_sizes):
219
219
prior = pm .sample_prior_predictive (samples = fixture_sizes )
220
220
for rv in RVs :
221
221
assert prior [rv .name ].shape == size + tuple (rv .distribution .shape )
222
+
223
+
224
+ class TestShapeDimsSize :
225
+ @pytest .mark .parametrize ("param_shape" , [(), (3 ,)])
226
+ @pytest .mark .parametrize ("batch_shape" , [(), (3 ,)])
227
+ @pytest .mark .parametrize (
228
+ "parametrization" ,
229
+ [
230
+ "implicit" ,
231
+ "shape" ,
232
+ "shape..." ,
233
+ "dims" ,
234
+ "dims..." ,
235
+ "size" ,
236
+ ],
237
+ )
238
+ def test_param_and_batch_shape_combos (
239
+ self , param_shape : tuple , batch_shape : tuple , parametrization : str
240
+ ):
241
+ coords = {}
242
+ param_dims = []
243
+ batch_dims = []
244
+
245
+ # Create coordinates corresponding to the parameter shape
246
+ for d in param_shape :
247
+ dname = f"param_dim_{ d } "
248
+ coords [dname ] = [f"c_{ i } " for i in range (d )]
249
+ param_dims .append (dname )
250
+ assert len (param_dims ) == len (param_shape )
251
+ # Create coordinates corresponding to the batch shape
252
+ for d in batch_shape :
253
+ dname = f"batch_dim_{ d } "
254
+ coords [dname ] = [f"c_{ i } " for i in range (d )]
255
+ batch_dims .append (dname )
256
+ assert len (batch_dims ) == len (batch_shape )
257
+
258
+ with pm .Model (coords = coords ) as pmodel :
259
+ mu = aesara .shared (np .random .normal (size = param_shape ))
260
+
261
+ with pytest .warns (None ):
262
+ if parametrization == "implicit" :
263
+ rv = pm .Normal ("rv" , mu = mu ).shape == param_shape
264
+ else :
265
+ if parametrization == "shape" :
266
+ rv = pm .Normal ("rv" , mu = mu , shape = batch_shape + param_shape )
267
+ assert rv .eval ().shape == batch_shape + param_shape
268
+ elif parametrization == "shape..." :
269
+ rv = pm .Normal ("rv" , mu = mu , shape = (* batch_shape , ...))
270
+ assert rv .eval ().shape == batch_shape + param_shape
271
+ elif parametrization == "dims" :
272
+ rv = pm .Normal ("rv" , mu = mu , dims = batch_dims + param_dims )
273
+ assert rv .eval ().shape == batch_shape + param_shape
274
+ elif parametrization == "dims..." :
275
+ rv = pm .Normal ("rv" , mu = mu , dims = (* batch_dims , ...))
276
+ n_size = len (batch_shape )
277
+ n_implied = len (param_shape )
278
+ ndim = n_size + n_implied
279
+ assert len (pmodel .RV_dims ["rv" ]) == ndim , pmodel .RV_dims
280
+ assert len (pmodel .RV_dims ["rv" ][:n_size ]) == len (batch_dims )
281
+ assert len (pmodel .RV_dims ["rv" ][n_size :]) == len (param_dims )
282
+ if n_implied > 0 :
283
+ assert pmodel .RV_dims ["rv" ][- 1 ] is None
284
+ elif parametrization == "size" :
285
+ rv = pm .Normal ("rv" , mu = mu , size = batch_shape )
286
+ assert rv .eval ().shape == batch_shape + param_shape
287
+ else :
288
+ raise NotImplementedError ("Invalid test case parametrization." )
289
+
290
+ def test_define_dims_on_the_fly (self ):
291
+ with pm .Model () as pmodel :
292
+ agedata = aesara .shared (np .array ([10 , 20 , 30 ]))
293
+
294
+ # Associate the "patient" dim with an implied dimension
295
+ age = pm .Normal ("age" , agedata , dims = ("patient" ,))
296
+ assert "patient" in pmodel .dim_lengths
297
+ assert pmodel .dim_lengths ["patient" ].eval () == 3
298
+
299
+ # Use the dim to replicate a new RV
300
+ effect = pm .Normal ("effect" , 0 , dims = ("patient" ,))
301
+ assert effect .ndim == 1
302
+ assert effect .eval ().shape == (3 ,)
303
+
304
+ # Now change the length of the implied dimension
305
+ agedata .set_value ([1 , 2 , 3 , 4 ])
306
+ # The change should propagate all the way through
307
+ assert effect .eval ().shape == (4 ,)
308
+
309
+ @pytest .mark .xfail (reason = "Simultaneous use of size and dims is not implemented" )
310
+ def test_data_defined_size_dimension_can_register_dimname (self ):
311
+ with pm .Model () as pmodel :
312
+ x = pm .Data ("x" , [[1 , 2 , 3 , 4 ]], dims = ("first" , "second" ))
313
+ assert "first" in pmodel .dim_lengths
314
+ assert "second" in pmodel .dim_lengths
315
+ # two dimensions are implied; a "third" dimension is created
316
+ y = pm .Normal ("y" , mu = x , size = 2 , dims = ("third" , "first" , "second" ))
317
+ assert "third" in pmodel .dim_lengths
318
+ assert y .eval ().shape () == (2 , 1 , 4 )
319
+
320
+ def test_can_resize_data_defined_size (self ):
321
+ with pm .Model () as pmodel :
322
+ x = pm .Data ("x" , [[1 , 2 , 3 , 4 ]], dims = ("first" , "second" ))
323
+ y = pm .Normal ("y" , mu = 0 , dims = ("first" , "second" ))
324
+ z = pm .Normal ("z" , mu = y , observed = np .ones ((1 , 4 )))
325
+ assert x .eval ().shape == (1 , 4 )
326
+ assert y .eval ().shape == (1 , 4 )
327
+ assert z .eval ().shape == (1 , 4 )
328
+ assert "first" in pmodel .dim_lengths
329
+ assert "second" in pmodel .dim_lengths
330
+ pmodel .set_data ("x" , [[1 , 2 ], [3 , 4 ], [5 , 6 ]])
331
+ assert x .eval ().shape == (3 , 2 )
332
+ assert y .eval ().shape == (3 , 2 )
333
+ assert z .eval ().shape == (3 , 2 )
334
+
335
+ @pytest .mark .xfail (
336
+ condition = sys .platform == "win32" ,
337
+ reason = "See https://github.com/pymc-devs/pymc3/issues/4652." ,
338
+ )
339
+ def test_observed_with_column_vector (self ):
340
+ with pm .Model () as model :
341
+ pm .Normal ("x1" , mu = 0 , sd = 1 , observed = np .random .normal (size = (3 , 4 )))
342
+ model .logp ()
343
+ pm .Normal ("x2" , mu = 0 , sd = 1 , observed = np .random .normal (size = (3 , 1 )))
344
+ model .logp ()
345
+
346
+ def test_dist_api_works (self ):
347
+ mu = aesara .shared (np .array ([1 , 2 , 3 ]))
348
+ with pytest .raises (NotImplementedError , match = "API is not yet supported" ):
349
+ pm .Normal .dist (mu = mu , dims = ("town" ,))
350
+ assert pm .Normal .dist (mu = mu , shape = (3 ,)).eval ().shape == (3 ,)
351
+ assert pm .Normal .dist (mu = mu , shape = (5 , 3 )).eval ().shape == (5 , 3 )
352
+ assert pm .Normal .dist (mu = mu , shape = (7 , ...)).eval ().shape == (7 , 3 )
353
+ assert pm .Normal .dist (mu = mu , size = (4 ,)).eval ().shape == (4 , 3 )
354
+
355
+ def test_auto_assert_shape (self ):
356
+ with pytest .raises (AssertionError , match = "will never match" ):
357
+ pm .Normal .dist (mu = [1 , 2 ], shape = [])
358
+
359
+ mu = at .vector (name = "mu_input" )
360
+ rv = pm .Normal .dist (mu = mu , shape = [3 , 4 ])
361
+ f = aesara .function ([mu ], rv , mode = aesara .Mode ("py" ))
362
+ assert f ([1 , 2 , 3 , 4 ]).shape == (3 , 4 )
363
+
364
+ with pytest .raises (AssertionError , match = r"Got shape \(3, 2\), expected \(3, 4\)." ):
365
+ f ([1 , 2 ])
366
+
367
+ # The `shape` can be symbolic!
368
+ s = at .vector (dtype = "int32" )
369
+ rv = pm .Uniform .dist (2 , [4 , 5 ], shape = s )
370
+ f = aesara .function ([s ], rv , mode = aesara .Mode ("py" ))
371
+ f (
372
+ [
373
+ 2 ,
374
+ ]
375
+ )
376
+ with pytest .raises (
377
+ AssertionError ,
378
+ match = r"Got 1 dimensions \(shape \(2,\)\), expected 2 dimensions with shape \(3, 4\)." ,
379
+ ):
380
+ f ([3 , 4 ])
381
+ with pytest .raises (
382
+ AssertionError ,
383
+ match = r"Got 1 dimensions \(shape \(2,\)\), expected 0 dimensions with shape \(\)." ,
384
+ ):
385
+ f ([])
386
+ pass
387
+
388
+ def test_lazy_flavors (self ):
389
+
390
+ _validate_shape_dims_size (shape = 5 )
391
+ _validate_shape_dims_size (dims = "town" )
392
+ _validate_shape_dims_size (size = 7 )
393
+
394
+ assert pm .Uniform .dist (2 , [4 , 5 ], size = [3 , 4 ]).eval ().shape == (3 , 4 , 2 )
395
+ assert pm .Uniform .dist (2 , [4 , 5 ], shape = [3 , 2 ]).eval ().shape == (3 , 2 )
396
+ with pm .Model (coords = dict (town = ["Greifswald" , "Madrid" ])):
397
+ assert pm .Normal ("n2" , mu = [1 , 2 ], dims = ("town" ,)).eval ().shape == (2 ,)
398
+
399
+ def test_invalid_flavors (self ):
400
+ # redundant parametrizations
401
+ with pytest .raises (ValueError , match = "Passing both" ):
402
+ _validate_shape_dims_size (shape = (2 ,), dims = ("town" ,))
403
+ with pytest .raises (ValueError , match = "Passing both" ):
404
+ _validate_shape_dims_size (dims = ("town" ,), size = (2 ,))
405
+ with pytest .raises (ValueError , match = "Passing both" ):
406
+ _validate_shape_dims_size (shape = (3 ,), size = (3 ,))
407
+
408
+ # invalid, but not necessarly rare
409
+ with pytest .raises (ValueError , match = "must be an int, list or tuple" ):
410
+ _validate_shape_dims_size (size = "notasize" )
411
+
412
+ # invalid ellipsis positions
413
+ with pytest .raises (ValueError , match = "may only appear in the last position" ):
414
+ _validate_shape_dims_size (shape = (3 , ..., 2 ))
415
+ with pytest .raises (ValueError , match = "may only appear in the last position" ):
416
+ _validate_shape_dims_size (dims = (..., "town" ))
417
+ with pytest .raises (ValueError , match = "cannot contain" ):
418
+ _validate_shape_dims_size (size = (3 , ...))
0 commit comments