@@ -379,14 +379,10 @@ def check_rv_size(self):
379
379
sizes_to_check = self .sizes_to_check or [None , (), 1 , (1 ,), 5 , (4 , 5 ), (2 , 4 , 2 )]
380
380
sizes_expected = self .sizes_expected or [(), (), (1 ,), (1 ,), (5 ,), (4 , 5 ), (2 , 4 , 2 )]
381
381
for size , expected in zip (sizes_to_check , sizes_expected ):
382
- actual = change_rv_size (self .pymc_rv , size ).eval ().shape
382
+ pymc_rv = self .pymc_dist .dist (** self .pymc_dist_params , size = size )
383
+ actual = tuple (pymc_rv .shape .eval ())
383
384
assert actual == expected , f"size={ size } , expected={ expected } , actual={ actual } "
384
385
385
- # test negative sizes raise
386
- for size in [- 2 , (3 , - 2 )]:
387
- with pytest .raises (ValueError ):
388
- change_rv_size (self .pymc_rv , size ).eval ()
389
-
390
386
# test multi-parameters sampling for univariate distributions (with univariate inputs)
391
387
if self .pymc_dist .rv_op .ndim_supp == 0 and sum (self .pymc_dist .rv_op .ndims_params ) == 0 :
392
388
params = {
@@ -400,7 +396,8 @@ def check_rv_size(self):
400
396
(5 , self .repeated_params_shape ),
401
397
]
402
398
for size , expected in zip (sizes_to_check , sizes_expected ):
403
- actual = change_rv_size (self .pymc_rv , size ).eval ().shape
399
+ pymc_rv = self .pymc_dist .dist (** params , size = size )
400
+ actual = tuple (pymc_rv .shape .eval ())
404
401
assert actual == expected
405
402
406
403
def validate_tests_list (self ):
0 commit comments