Skip to content

Commit e67c476

Browse files
lucianopaztwiecki
authored and committed
Fix for #3310 and some more (#3319)
* Fixed #3310. Added broadcast_distribution_samples, which helps broadcasting multiple rvs calls with different size and distribution parameter shapes. Added shape guards to other continuous distributions. * Fixed broken continuous distributions. Did not notice that _random got a parameter shape aware size input thanks to generate_samples. * Fixed lint error * Addressed comments
1 parent 01f2444 commit e67c476

File tree

5 files changed

+60
-5
lines changed

5 files changed

+60
-5
lines changed

Diff for: RELEASE-NOTES.md

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
### Maintenance
88

9+
- Added the `broadcast_distribution_samples` function, which helps broadcast arrays of drawn samples, taking into account the requested `size` and the inferred distribution shape. This is sometimes needed by distributions that call several `rvs` separately within their `random` method, such as the `ZeroInflatedPoisson` (fixes issue #3310).
10+
- The `random` methods of the `Wald`, `Kumaraswamy`, `LogNormal`, `Pareto`, `Cauchy`, `HalfCauchy`, `Weibull` and `ExGaussian` distributions used a hidden `_random` function that was written with scalars in mind. This could potentially lead to artificial correlations between random draws. Added shape guards and broadcasting of the distribution samples to prevent this (similar to issue #3310).
11+
912
### Deprecations
1013

1114
## PyMC3 3.6 (Dec 21 2018)

Diff for: pymc3/distributions/continuous.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
alltrue_elemwise, betaln, bound, gammaln, i0e, incomplete_beta, logpow,
2020
normal_lccdf, normal_lcdf, SplineWrapper, std_cdf, zvalue,
2121
)
22-
from .distribution import Continuous, draw_values, generate_samples
22+
from .distribution import (Continuous, draw_values, generate_samples,
23+
broadcast_distribution_samples)
2324

2425
__all__ = ['Uniform', 'Flat', 'HalfFlat', 'Normal', 'TruncatedNormal', 'Beta',
2526
'Kumaraswamy', 'Exponential', 'Laplace', 'StudentT', 'Cauchy',
@@ -957,6 +958,8 @@ def random(self, point=None, size=None):
957958
"""
958959
mu, lam, alpha = draw_values([self.mu, self.lam, self.alpha],
959960
point=point, size=size)
961+
mu, lam, alpha = broadcast_distribution_samples([mu, lam, alpha],
962+
size=size)
960963
return generate_samples(self._random,
961964
mu, lam, alpha,
962965
dist_shape=self.shape,
@@ -1285,6 +1288,7 @@ def random(self, point=None, size=None):
12851288
"""
12861289
a, b = draw_values([self.a, self.b],
12871290
point=point, size=size)
1291+
a, b = broadcast_distribution_samples([a, b], size=size)
12881292
return generate_samples(self._random, a, b,
12891293
dist_shape=self.shape,
12901294
size=size)
@@ -1658,6 +1662,7 @@ def random(self, point=None, size=None):
16581662
array
16591663
"""
16601664
mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
1665+
mu, tau = broadcast_distribution_samples([mu, tau], size=size)
16611666
return generate_samples(self._random, mu, tau,
16621667
dist_shape=self.shape,
16631668
size=size)
@@ -1945,6 +1950,7 @@ def random(self, point=None, size=None):
19451950
"""
19461951
alpha, m = draw_values([self.alpha, self.m],
19471952
point=point, size=size)
1953+
alpha, m = broadcast_distribution_samples([alpha, m], size=size)
19481954
return generate_samples(self._random, alpha, m,
19491955
dist_shape=self.shape,
19501956
size=size)
@@ -2069,6 +2075,7 @@ def random(self, point=None, size=None):
20692075
"""
20702076
alpha, beta = draw_values([self.alpha, self.beta],
20712077
point=point, size=size)
2078+
alpha, beta = broadcast_distribution_samples([alpha, beta], size=size)
20722079
return generate_samples(self._random, alpha, beta,
20732080
dist_shape=self.shape,
20742081
size=size)
@@ -2629,6 +2636,7 @@ def random(self, point=None, size=None):
26292636
"""
26302637
alpha, beta = draw_values([self.alpha, self.beta],
26312638
point=point, size=size)
2639+
alpha, beta = broadcast_distribution_samples([alpha, beta], size=size)
26322640

26332641
def _random(a, b, size=None):
26342642
return b * (-np.log(np.random.uniform(size=size)))**(1 / a)
@@ -2913,6 +2921,8 @@ def random(self, point=None, size=None):
29132921
"""
29142922
mu, sigma, nu = draw_values([self.mu, self.sigma, self.nu],
29152923
point=point, size=size)
2924+
mu, sigma, nu = broadcast_distribution_samples([mu, sigma, nu],
2925+
size=size)
29162926

29172927
def _random(mu, sigma, nu, size=None):
29182928
return (np.random.normal(mu, sigma, size=size)

Diff for: pymc3/distributions/discrete.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from pymc3.util import get_variable_name
88
from .dist_math import bound, factln, binomln, betaln, logpow, random_choice
9-
from .distribution import Discrete, draw_values, generate_samples
9+
from .distribution import (Discrete, draw_values, generate_samples,
10+
broadcast_distribution_samples)
1011
from pymc3.math import tround, sigmoid, logaddexp, logit, log1pexp
1112

1213

@@ -345,6 +346,7 @@ def _ppf(self, p):
345346

346347
def _random(self, q, beta, size=None):
    """Turn uniform draws into samples for the given ``q`` and ``beta``.

    Uniform draws ``u`` of the requested ``size`` are broadcast together
    with ``q`` and ``beta`` and then mapped through
    ``ceil((log(1 - u) / log(q)) ** (1 / beta)) - 1``.
    """
    uniform_draws = np.random.uniform(size=size)
    # Align the draws and both parameters to a common shape before the
    # element-wise transform below.
    uniform_draws, q, beta = broadcast_distribution_samples(
        [uniform_draws, q, beta], size=size)

    log_ratio = np.log(1 - uniform_draws) / np.log(q)
    return np.ceil(np.power(log_ratio, 1. / beta)) - 1
350352

@@ -847,7 +849,8 @@ def random(self, point=None, size=None):
847849
g = generate_samples(stats.poisson.rvs, theta,
848850
dist_shape=self.shape,
849851
size=size)
850-
return g * (np.random.random(np.squeeze(g.shape)) < psi)
852+
g, psi = broadcast_distribution_samples([g, psi], size=size)
853+
return g * (np.random.random(g.shape) < psi)
851854

852855
def logp(self, value):
853856
psi = self.psi
@@ -939,7 +942,8 @@ def random(self, point=None, size=None):
939942
g = generate_samples(stats.binom.rvs, n, p,
940943
dist_shape=self.shape,
941944
size=size)
942-
return g * (np.random.random(np.squeeze(g.shape)) < psi)
945+
g, psi = broadcast_distribution_samples([g, psi], size=size)
946+
return g * (np.random.random(g.shape) < psi)
943947

944948
def logp(self, value):
945949
psi = self.psi
@@ -1057,7 +1061,8 @@ def random(self, point=None, size=None):
10571061
dist_shape=self.shape,
10581062
size=size)
10591063
g[g == 0] = np.finfo(float).eps # Just in case
1060-
return stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
1064+
g, psi = broadcast_distribution_samples([g, psi], size=size)
1065+
return stats.poisson.rvs(g) * (np.random.random(g.shape) < psi)
10611066

10621067
def logp(self, value):
10631068
alpha = self.alpha

Diff for: pymc3/distributions/distribution.py

+27
Original file line numberDiff line numberDiff line change
@@ -636,3 +636,30 @@ def generate_samples(generator, *args, **kwargs):
636636
if one_d and samples.shape[-1] == 1:
637637
samples = samples.reshape(samples.shape[:-1])
638638
return np.asarray(samples)
639+
640+
641+
def broadcast_distribution_samples(samples, size=None):
    """Broadcast several arrays of drawn samples against each other.

    Plain numpy broadcasting is tried first. If it fails and ``size`` was
    given, every sample whose leading dimensions equal ``size`` is treated
    as carrying a "size prefix". The remaining (core) shapes are broadcast
    among themselves, and samples lacking the prefix get singleton axes
    inserted in its place so that the final broadcast succeeds.

    Parameters
    ----------
    samples : sequence of array-like
        The drawn values to broadcast together.
    size : optional
        Requested number/shape of draws; converted with ``to_tuple``.
        When ``None`` only plain ``np.broadcast_arrays`` is applied.

    Returns
    -------
    The result of ``np.broadcast_arrays`` on the (possibly re-aligned)
    samples, one array per input, in the same order.

    Raises
    ------
    ValueError
        If the samples cannot be broadcast, even after aligning the size
        prefix.
    """
    if size is None:
        return np.broadcast_arrays(*samples)
    _size = to_tuple(size)
    try:
        broadcasted_samples = np.broadcast_arrays(*samples)
    except ValueError:
        n_size = len(_size)
        # Coerce to ndarrays so that lists and Python scalars, which the
        # fast path above accepts, also expose ``.shape`` and support
        # ``np.newaxis`` indexing here.
        arrays = [np.asarray(sample) for sample in samples]
        # Whether each sample's shape starts with the requested size.
        has_size_prefix = [arr.shape[:n_size] == _size for arr in arrays]
        # Sample shapes without the size prepend.
        core_shapes = [arr.shape[n_size:] if prefixed else arr.shape
                       for arr, prefixed in zip(arrays, has_size_prefix)]
        broadcast_shape = np.broadcast(
            *[np.empty(shape) for shape in core_shapes]
        ).shape
        aligned = []
        for arr, prefixed, core_shape in zip(arrays, has_size_prefix,
                                             core_shapes):
            # Keep an existing size prefix; otherwise insert singleton axes
            # where the prefix would be.
            if prefixed:
                slicer_head = [slice(None)] * n_size
            else:
                slicer_head = [np.newaxis] * n_size
            # Left-pad the core dimensions up to the broadcast core rank.
            n_pad = len(broadcast_shape) - len(core_shape)
            slicer_tail = ([np.newaxis] * n_pad +
                           [slice(None)] * len(core_shape))
            aligned.append(arr[tuple(slicer_head + slicer_tail)])
        broadcasted_samples = np.broadcast_arrays(*aligned)
    return broadcasted_samples

Diff for: pymc3/tests/test_sampling.py

+10
Original file line numberDiff line numberDiff line change
@@ -467,3 +467,13 @@ def test_shape_edgecase(self):
467467
x = pm.Normal('x', mu=mu, sd=sd, shape=5)
468468
prior = pm.sample_prior_predictive(10)
469469
assert prior['mu'].shape == (10, 5)
470+
471+
def test_zeroinflatedpoisson(self):
    """Check the shapes of prior predictive draws for a ZeroInflatedPoisson.

    The scalar parents ``theta`` and ``psi`` must come back with shape
    ``(samples,)`` and the shape-20 ``suppliers`` variable with shape
    ``(samples, 20)``.
    """
    n_draws = 5000
    n_suppliers = 20
    with pm.Model():
        theta = pm.Beta('theta', alpha=1, beta=1)
        psi = pm.HalfNormal('psi', sd=1)
        pm.ZeroInflatedPoisson('suppliers', psi=psi, theta=theta,
                               shape=n_suppliers)
        gen_data = pm.sample_prior_predictive(samples=n_draws)
    expected_shapes = (
        ('theta', (n_draws,)),
        ('psi', (n_draws,)),
        ('suppliers', (n_draws, n_suppliers)),
    )
    for var_name, expected_shape in expected_shapes:
        assert gen_data[var_name].shape == expected_shape

0 commit comments

Comments
 (0)