Skip to content

Commit 6148d1b

Browse files
fonnesbeckColCarroll
authored andcommitted
Fix for shape issue in prior sampling for bounded distributions (#3451)
* Fix for shape issue in prior sampling for bounded distributions * Fixed unrelated test failure in test_normal_scalar * Fixed test_normal_scalar failure
1 parent 2295f0b commit 6148d1b

File tree

2 files changed

+338
-257
lines changed

2 files changed

+338
-257
lines changed

pymc3/distributions/bound.py

Lines changed: 81 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,16 @@
55
import theano
66

77
from pymc3.distributions.distribution import (
8-
Distribution, Discrete, Continuous, draw_values, generate_samples)
8+
Distribution,
9+
Discrete,
10+
Continuous,
11+
draw_values,
12+
generate_samples,
13+
)
914
from pymc3.distributions import transforms
1015
from pymc3.distributions.dist_math import bound
1116

12-
__all__ = ['Bound']
17+
__all__ = ["Bound"]
1318

1419

1520
class _Bounded(Distribution):
@@ -23,15 +28,16 @@ def __init__(self, distribution, lower, upper, default, *args, **kwargs):
2328
for name in defaults:
2429
setattr(self, name, getattr(self._wrapped, name))
2530
else:
26-
defaults = ('_default',)
31+
defaults = ("_default",)
2732
self._default = default
2833

2934
super().__init__(
3035
shape=self._wrapped.shape,
3136
dtype=self._wrapped.dtype,
3237
testval=self._wrapped.testval,
3338
defaults=defaults,
34-
transform=self._wrapped.transform)
39+
transform=self._wrapped.transform,
40+
)
3541

3642
def logp(self, value):
3743
logp = self._wrapped.logp(value)
@@ -49,49 +55,52 @@ def _random(self, lower, upper, point=None, size=None):
4955
lower = np.asarray(lower)
5056
upper = np.asarray(upper)
5157
if lower.size > 1 or upper.size > 1:
52-
raise ValueError('Drawing samples from distributions with '
53-
'array-valued bounds is not supported.')
54-
samples = np.zeros(size, dtype=self.dtype).flatten()
55-
i, n = 0, len(samples)
56-
while i < len(samples):
57-
sample = np.atleast_1d(self._wrapped.random(point=point, size=n))
58+
raise ValueError(
59+
"Drawing samples from distributions with "
60+
"array-valued bounds is not supported."
61+
)
62+
total_size = np.prod(size)
63+
samples = []
64+
s = 0
65+
while s < total_size:
66+
sample = np.atleast_1d(
67+
self._wrapped.random(point=point, size=total_size)
68+
).flatten()
5869

5970
select = sample[np.logical_and(sample >= lower, sample <= upper)]
60-
samples[i:(i + len(select))] = select[:]
61-
i += len(select)
62-
n -= len(select)
71+
samples.append(select)
72+
s += len(select)
6373
if size is not None:
64-
return np.reshape(samples, size)
74+
return np.reshape(np.concatenate(samples)[:total_size], size)
6575
else:
66-
return samples
76+
return samples[0]
6777

6878
def random(self, point=None, size=None):
6979
if self.lower is None and self.upper is None:
7080
return self._wrapped.random(point=point, size=size)
7181
elif self.lower is not None and self.upper is not None:
7282
lower, upper = draw_values([self.lower, self.upper], point=point, size=size)
73-
return generate_samples(self._random, lower, upper, point,
74-
dist_shape=self.shape,
75-
size=size)
83+
return generate_samples(
84+
self._random, lower, upper, point, dist_shape=self.shape, size=size
85+
)
7686
elif self.lower is not None:
7787
lower = draw_values([self.lower], point=point, size=size)
78-
return generate_samples(self._random, lower, np.inf, point,
79-
dist_shape=self.shape,
80-
size=size)
88+
return generate_samples(
89+
self._random, lower, np.inf, point, dist_shape=self.shape, size=size
90+
)
8191
else:
8292
upper = draw_values([self.upper], point=point, size=size)
83-
return generate_samples(self._random, -np.inf, upper, point,
84-
dist_shape=self.shape,
85-
size=size)
93+
return generate_samples(
94+
self._random, -np.inf, upper, point, dist_shape=self.shape, size=size
95+
)
8696

8797

8898
class _DiscreteBounded(_Bounded, Discrete):
89-
def __init__(self, distribution, lower, upper,
90-
transform='infer', *args, **kwargs):
91-
if transform == 'infer':
99+
def __init__(self, distribution, lower, upper, transform="infer", *args, **kwargs):
100+
if transform == "infer":
92101
transform = None
93102
if transform is not None:
94-
raise ValueError('Can not transform discrete variable.')
103+
raise ValueError("Can not transform discrete variable.")
95104

96105
if lower is None and upper is None:
97106
default = None
@@ -103,12 +112,12 @@ def __init__(self, distribution, lower, upper,
103112
default = lower + 1
104113

105114
super().__init__(
106-
distribution, lower, upper,
107-
default, *args, transform=transform, **kwargs)
115+
distribution, lower, upper, default, *args, transform=transform, **kwargs
116+
)
108117

109118

110119
class _ContinuousBounded(_Bounded, Continuous):
111-
R"""
120+
r"""
112121
An upper, lower or upper+lower bounded distribution
113122
114123
Parameters
@@ -125,16 +134,15 @@ class _ContinuousBounded(_Bounded, Continuous):
125134
See pymc3.distributions.transforms for more information.
126135
"""
127136

128-
def __init__(self, distribution, lower, upper,
129-
transform='infer', *args, **kwargs):
130-
dtype = kwargs.get('dtype', theano.config.floatX)
137+
def __init__(self, distribution, lower, upper, transform="infer", *args, **kwargs):
138+
dtype = kwargs.get("dtype", theano.config.floatX)
131139

132140
if lower is not None:
133141
lower = tt.as_tensor_variable(lower).astype(dtype)
134142
if upper is not None:
135143
upper = tt.as_tensor_variable(upper).astype(dtype)
136144

137-
if transform == 'infer':
145+
if transform == "infer":
138146
if lower is None and upper is None:
139147
transform = None
140148
default = None
@@ -151,12 +159,12 @@ def __init__(self, distribution, lower, upper,
151159
default = None
152160

153161
super().__init__(
154-
distribution, lower, upper,
155-
default, *args, transform=transform, **kwargs)
162+
distribution, lower, upper, default, *args, transform=transform, **kwargs
163+
)
156164

157165

158166
class Bound:
159-
R"""
167+
r"""
160168
Create a Bound variable object that can be applied to create
161169
a new upper, lower, or upper and lower bounded distribution.
162170
@@ -207,31 +215,49 @@ def __init__(self, distribution, lower=None, upper=None):
207215
self.upper = upper
208216

209217
def __call__(self, name, *args, **kwargs):
210-
if 'observed' in kwargs:
211-
raise ValueError('Observed Bound distributions are not supported. '
212-
'If you want to model truncated data '
213-
'you can use a pm.Potential in combination '
214-
'with the cumulative probability function. See '
215-
'pymc3/examples/censored_data.py for an example.')
216-
217-
transform = kwargs.pop('transform', 'infer')
218+
if "observed" in kwargs:
219+
raise ValueError(
220+
"Observed Bound distributions are not supported. "
221+
"If you want to model truncated data "
222+
"you can use a pm.Potential in combination "
223+
"with the cumulative probability function. See "
224+
"pymc3/examples/censored_data.py for an example."
225+
)
226+
227+
transform = kwargs.pop("transform", "infer")
218228
if issubclass(self.distribution, Continuous):
219-
return _ContinuousBounded(name, self.distribution, self.lower,
220-
self.upper, transform, *args, **kwargs)
229+
return _ContinuousBounded(
230+
name,
231+
self.distribution,
232+
self.lower,
233+
self.upper,
234+
transform,
235+
*args,
236+
**kwargs
237+
)
221238
elif issubclass(self.distribution, Discrete):
222-
return _DiscreteBounded(name, self.distribution, self.lower,
223-
self.upper, transform, *args, **kwargs)
239+
return _DiscreteBounded(
240+
name,
241+
self.distribution,
242+
self.lower,
243+
self.upper,
244+
transform,
245+
*args,
246+
**kwargs
247+
)
224248
else:
225-
raise ValueError(
226-
'Distribution is neither continuous nor discrete.')
249+
raise ValueError("Distribution is neither continuous nor discrete.")
227250

228251
def dist(self, *args, **kwargs):
229252
if issubclass(self.distribution, Continuous):
230253
return _ContinuousBounded.dist(
231-
self.distribution, self.lower, self.upper, *args, **kwargs)
254+
self.distribution, self.lower, self.upper, *args, **kwargs
255+
)
232256

233257
elif issubclass(self.distribution, Discrete):
234258
return _DiscreteBounded.dist(
235-
self.distribution, self.lower, self.upper, *args, **kwargs)
259+
self.distribution, self.lower, self.upper, *args, **kwargs
260+
)
236261
else:
237-
raise ValueError('Distribution is neither continuous nor discrete.')
262+
raise ValueError("Distribution is neither continuous nor discrete.")
263+

0 commit comments

Comments
 (0)