@@ -369,8 +369,9 @@ def check_rv_size(self):
369
369
sizes_expected = self .sizes_expected or [(), (), (1 ,), (1 ,), (5 ,), (4 , 5 ), (2 , 4 , 2 )]
370
370
for size , expected in zip (sizes_to_check , sizes_expected ):
371
371
pymc_rv = self .pymc_dist .dist (** self .pymc_dist_params , size = size )
372
- actual = tuple (pymc_rv .shape .eval ())
373
- assert actual == expected , f"size={ size } , expected={ expected } , actual={ actual } "
372
+ expected_symbolic = tuple (pymc_rv .shape .eval ())
373
+ actual = pymc_rv .eval ().shape
374
+ assert actual == expected_symbolic == expected
374
375
375
376
# test multi-parameters sampling for univariate distributions (with univariate inputs)
376
377
if (
@@ -390,8 +391,9 @@ def check_rv_size(self):
390
391
]
391
392
for size , expected in zip (sizes_to_check , sizes_expected ):
392
393
pymc_rv = self .pymc_dist .dist (** params , size = size )
393
- actual = tuple (pymc_rv .shape .eval ())
394
- assert actual == expected
394
+ expected_symbolic = tuple (pymc_rv .shape .eval ())
395
+ actual = pymc_rv .eval ().shape
396
+ assert actual == expected_symbolic == expected
395
397
396
398
def validate_tests_list (self ):
397
399
assert len (self .checks_to_run ) == len (
@@ -417,10 +419,18 @@ class TestFlat(BaseTestDistributionRandom):
417
419
expected_rv_op_params = {}
418
420
checks_to_run = [
419
421
"check_pymc_params_match_rv_op" ,
420
- "check_rv_size " ,
422
+ "check_rv_inferred_size " ,
421
423
"check_not_implemented" ,
422
424
]
423
425
426
+ def check_rv_inferred_size (self ):
427
+ sizes_to_check = self .sizes_to_check or [None , (), 1 , (1 ,), 5 , (4 , 5 ), (2 , 4 , 2 )]
428
+ sizes_expected = self .sizes_expected or [(), (), (1 ,), (1 ,), (5 ,), (4 , 5 ), (2 , 4 , 2 )]
429
+ for size , expected in zip (sizes_to_check , sizes_expected ):
430
+ pymc_rv = self .pymc_dist .dist (** self .pymc_dist_params , size = size )
431
+ expected_symbolic = tuple (pymc_rv .shape .eval ())
432
+ assert expected_symbolic == expected
433
+
424
434
def check_not_implemented (self ):
425
435
with pytest .raises (NotImplementedError ):
426
436
self .pymc_rv .eval ()
@@ -432,10 +442,18 @@ class TestHalfFlat(BaseTestDistributionRandom):
432
442
expected_rv_op_params = {}
433
443
checks_to_run = [
434
444
"check_pymc_params_match_rv_op" ,
435
- "check_rv_size " ,
445
+ "check_rv_inferred_size " ,
436
446
"check_not_implemented" ,
437
447
]
438
448
449
+ def check_rv_inferred_size (self ):
450
+ sizes_to_check = self .sizes_to_check or [None , (), 1 , (1 ,), 5 , (4 , 5 ), (2 , 4 , 2 )]
451
+ sizes_expected = self .sizes_expected or [(), (), (1 ,), (1 ,), (5 ,), (4 , 5 ), (2 , 4 , 2 )]
452
+ for size , expected in zip (sizes_to_check , sizes_expected ):
453
+ pymc_rv = self .pymc_dist .dist (** self .pymc_dist_params , size = size )
454
+ expected_symbolic = tuple (pymc_rv .shape .eval ())
455
+ assert expected_symbolic == expected
456
+
439
457
def check_not_implemented (self ):
440
458
with pytest .raises (NotImplementedError ):
441
459
self .pymc_rv .eval ()
0 commit comments