Skip to content

Commit 0226047

Browse files
brandonwillardmichaelosthege
authored andcommitted
Fix mistaken use of change_rv_size with new_size None
1 parent 08f5847 commit 0226047

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

pymc3/tests/test_distributions_random.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -379,14 +379,10 @@ def check_rv_size(self):
379379
sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
380380
sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
381381
for size, expected in zip(sizes_to_check, sizes_expected):
382-
actual = change_rv_size(self.pymc_rv, size).eval().shape
382+
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
383+
actual = tuple(pymc_rv.shape.eval())
383384
assert actual == expected, f"size={size}, expected={expected}, actual={actual}"
384385

385-
# test negative sizes raise
386-
for size in [-2, (3, -2)]:
387-
with pytest.raises(ValueError):
388-
change_rv_size(self.pymc_rv, size).eval()
389-
390386
# test multi-parameters sampling for univariate distributions (with univariate inputs)
391387
if self.pymc_dist.rv_op.ndim_supp == 0 and sum(self.pymc_dist.rv_op.ndims_params) == 0:
392388
params = {
@@ -400,7 +396,8 @@ def check_rv_size(self):
400396
(5, self.repeated_params_shape),
401397
]
402398
for size, expected in zip(sizes_to_check, sizes_expected):
403-
actual = change_rv_size(self.pymc_rv, size).eval().shape
399+
pymc_rv = self.pymc_dist.dist(**params, size=size)
400+
actual = tuple(pymc_rv.shape.eval())
404401
assert actual == expected
405402

406403
def validate_tests_list(self):

0 commit comments

Comments
 (0)