Skip to content

Commit 346daa1

Browse files
committed
Test expected (inferred) and actual shape of draws in TestBaseDistributionRandom
* Fixes bug in returned samples from `Wishart` when `size=1`
1 parent f534a7f commit 346daa1

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

pymc/distributions/multivariate.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -891,9 +891,13 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
891891
return dist_params[1].shape
892892

893893
@classmethod
894-
def rng_fn(cls, rng, nu, V, size=None):
895-
size = size if size else 1 # Default size for Scipy's wishart.rvs is 1
896-
return stats.wishart.rvs(np.int(nu), V, size=size, random_state=rng)
894+
def rng_fn(cls, rng, nu, V, size):
895+
scipy_size = size if size else 1 # Default size for Scipy's wishart.rvs is 1
896+
result = stats.wishart.rvs(np.int(nu), V, size=scipy_size, random_state=rng)
897+
if size == (1,):
898+
return result[np.newaxis, ...]
899+
else:
900+
return result
897901

898902

899903
wishart = WishartRV()

pymc/tests/test_distributions_random.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,9 @@ def check_rv_size(self):
369369
sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
370370
for size, expected in zip(sizes_to_check, sizes_expected):
371371
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
372-
actual = tuple(pymc_rv.shape.eval())
373-
assert actual == expected, f"size={size}, expected={expected}, actual={actual}"
372+
expected_symbolic = tuple(pymc_rv.shape.eval())
373+
actual = pymc_rv.eval().shape
374+
assert actual == expected_symbolic == expected
374375

375376
# test multi-parameters sampling for univariate distributions (with univariate inputs)
376377
if (
@@ -390,8 +391,9 @@ def check_rv_size(self):
390391
]
391392
for size, expected in zip(sizes_to_check, sizes_expected):
392393
pymc_rv = self.pymc_dist.dist(**params, size=size)
393-
actual = tuple(pymc_rv.shape.eval())
394-
assert actual == expected
394+
expected_symbolic = tuple(pymc_rv.shape.eval())
395+
actual = pymc_rv.eval().shape
396+
assert actual == expected_symbolic == expected
395397

396398
def validate_tests_list(self):
397399
assert len(self.checks_to_run) == len(
@@ -417,10 +419,18 @@ class TestFlat(BaseTestDistributionRandom):
417419
expected_rv_op_params = {}
418420
checks_to_run = [
419421
"check_pymc_params_match_rv_op",
420-
"check_rv_size",
422+
"check_rv_inferred_size",
421423
"check_not_implemented",
422424
]
423425

426+
def check_rv_inferred_size(self):
427+
sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
428+
sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
429+
for size, expected in zip(sizes_to_check, sizes_expected):
430+
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
431+
expected_symbolic = tuple(pymc_rv.shape.eval())
432+
assert expected_symbolic == expected
433+
424434
def check_not_implemented(self):
425435
with pytest.raises(NotImplementedError):
426436
self.pymc_rv.eval()
@@ -432,10 +442,18 @@ class TestHalfFlat(BaseTestDistributionRandom):
432442
expected_rv_op_params = {}
433443
checks_to_run = [
434444
"check_pymc_params_match_rv_op",
435-
"check_rv_size",
445+
"check_rv_inferred_size",
436446
"check_not_implemented",
437447
]
438448

449+
def check_rv_inferred_size(self):
450+
sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
451+
sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
452+
for size, expected in zip(sizes_to_check, sizes_expected):
453+
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
454+
expected_symbolic = tuple(pymc_rv.shape.eval())
455+
assert expected_symbolic == expected
456+
439457
def check_not_implemented(self):
440458
with pytest.raises(NotImplementedError):
441459
self.pymc_rv.eval()

0 commit comments

Comments
 (0)