Skip to content

Commit d63bdd6

Browse files
committed
Split Bound into ContinuousBound and DiscreteBound
1 parent 0a5e78a commit d63bdd6

File tree

1 file changed

+112
-71
lines changed

1 file changed

+112
-71
lines changed

pymc3/distributions/bound.py

+112-71
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,30 @@
1+
from numbers import Real
2+
13
import numpy as np
24
import theano.tensor as tt
5+
import theano
36

47
from pymc3.distributions.distribution import (
5-
Distribution, Discrete, draw_values, generate_samples)
8+
Distribution, Discrete, Continuous, draw_values, generate_samples)
69
from pymc3.distributions import transforms
710
from pymc3.distributions.dist_math import bound
811

912
__all__ = ['Bound']
1013

1114

1215
class _Bounded(Distribution):
13-
R"""
14-
An upper, lower or upper+lower bounded distribution
15-
16-
Parameters
17-
----------
18-
distribution : pymc3 distribution
19-
Distribution to be transformed into a bounded distribution
20-
lower : float (optional)
21-
Lower bound of the distribution, set to -inf to disable.
22-
upper : float (optional)
23-
Upper bound of the distribibution, set to inf to disable.
24-
tranform : 'infer' or object
25-
If 'infer', infers the right transform to apply from the supplied bounds.
26-
If transform object, has to supply .forward() and .backward() methods.
27-
See pymc3.distributions.transforms for more information.
28-
"""
29-
30-
def __init__(self, distribution, lower, upper,
31-
transform='infer', *args, **kwargs):
32-
if lower == -np.inf:
33-
lower = None
34-
if upper == np.inf:
35-
upper = None
36-
37-
if lower is not None:
38-
lower = tt.as_tensor_variable(lower)
39-
if upper is not None:
40-
upper = tt.as_tensor_variable(upper)
41-
16+
def __init__(self, distribution, lower, upper, default, *args, **kwargs):
4217
self.lower = lower
4318
self.upper = upper
44-
45-
if transform == 'infer':
46-
if lower is None and upper is None:
47-
transform = None
48-
default = None
49-
elif lower is not None and upper is not None:
50-
transform = transforms.interval(lower, upper)
51-
default = 0.5 * (lower + upper)
52-
elif upper is not None:
53-
transform = transforms.upperbound(upper)
54-
default = upper - 1
55-
else:
56-
transform = transforms.lowerbound(lower)
57-
default = lower + 1
58-
else:
59-
default = None
60-
61-
# We don't use transformations for dicrete variables
62-
if issubclass(distribution, Discrete):
63-
transform = None
64-
65-
kwargs['transform'] = transform
6619
self._wrapped = distribution.dist(*args, **kwargs)
67-
self._default = default
68-
69-
if issubclass(distribution, Discrete) and default is not None:
70-
default = default.astype(str(self._wrapped.default().dtype))
7120

7221
if default is None:
7322
defaults = self._wrapped.defaults
7423
for name in defaults:
7524
setattr(self, name, getattr(self._wrapped, name))
7625
else:
7726
defaults = ('_default',)
27+
self._default = default
7828

7929
super(_Bounded, self).__init__(
8030
shape=self._wrapped.shape,
@@ -83,6 +33,18 @@ def __init__(self, distribution, lower, upper,
8333
defaults=defaults,
8434
transform=self._wrapped.transform)
8535

36+
def logp(self, value):
37+
logp = self._wrapped.logp(value)
38+
bounds = []
39+
if self.lower is not None:
40+
bounds.append(value >= self.lower)
41+
if self.upper is not None:
42+
bounds.append(value <= self.upper)
43+
if len(bounds) > 0:
44+
return bound(logp, *bounds)
45+
else:
46+
return logp
47+
8648
def _random(self, lower, upper, point=None, size=None):
8749
lower = np.asarray(lower)
8850
upper = np.asarray(upper)
@@ -121,17 +83,75 @@ def random(self, point=None, size=None, repeat=None):
12183
dist_shape=self.shape,
12284
size=size)
12385

124-
def logp(self, value):
125-
logp = self._wrapped.logp(value)
126-
bounds = []
127-
if self.lower is not None:
128-
bounds.append(value >= self.lower)
129-
if self.upper is not None:
130-
bounds.append(value <= self.upper)
131-
if len(bounds) > 0:
132-
return bound(logp, *bounds)
86+
87+
class _DiscreteBounded(_Bounded, Discrete):
88+
def __init__(self, distribution, lower, upper,
89+
transform='infer', *args, **kwargs):
90+
if transform == 'infer':
91+
transform = None
92+
if transform is not None:
93+
raise ValueError('Can not transform discrete variable.')
94+
95+
if lower is None and upper is None:
96+
default = None
97+
elif lower is not None and upper is not None:
98+
default = (lower + upper) // 2
99+
if upper is not None:
100+
default = upper - 1
101+
if lower is not None:
102+
default = lower + 1
103+
104+
super(_DiscreteBounded, self).__init__(
105+
distribution=distribution, lower=lower, upper=upper,
106+
default=default, *args, **kwargs)
107+
108+
109+
class _ContinuousBounded(_Bounded, Continuous):
110+
R"""
111+
An upper, lower or upper+lower bounded distribution
112+
113+
Parameters
114+
----------
115+
distribution : pymc3 distribution
116+
Distribution to be transformed into a bounded distribution
117+
lower : float (optional)
118+
Lower bound of the distribution, set to -inf to disable.
119+
upper : float (optional)
120+
Upper bound of the distribibution, set to inf to disable.
121+
tranform : 'infer' or object
122+
If 'infer', infers the right transform to apply from the supplied bounds.
123+
If transform object, has to supply .forward() and .backward() methods.
124+
See pymc3.distributions.transforms for more information.
125+
"""
126+
127+
def __init__(self, distribution, lower, upper,
128+
transform='infer', *args, **kwargs):
129+
dtype = kwargs.get('dtype', theano.config.floatX)
130+
131+
if lower is not None:
132+
lower = tt.as_tensor_variable(lower).astype(dtype)
133+
if upper is not None:
134+
upper = tt.as_tensor_variable(upper).astype(dtype)
135+
136+
if transform == 'infer':
137+
if lower is None and upper is None:
138+
transform = None
139+
default = None
140+
elif lower is not None and upper is not None:
141+
transform = transforms.interval(lower, upper)
142+
default = 0.5 * (lower + upper)
143+
elif upper is not None:
144+
transform = transforms.upperbound(upper)
145+
default = upper - 1
146+
else:
147+
transform = transforms.lowerbound(lower)
148+
default = lower + 1
133149
else:
134-
return logp
150+
default = None
151+
152+
super(_ContinuousBounded, self).__init__(
153+
distribution=distribution, lower=lower, upper=upper,
154+
transform=transform, default=default, *args, **kwargs)
135155

136156

137157
class Bound(object):
@@ -170,22 +190,43 @@ class Bound(object):
170190
"""
171191

172192
def __init__(self, distribution, lower=None, upper=None):
193+
if isinstance(lower, Real) and lower == -np.inf:
194+
lower = None
195+
if isinstance(upper, Real) and upper == np.inf:
196+
upper = None
197+
198+
if not issubclass(distribution, Distribution):
199+
raise ValueError('"distribution" must be a Distribution subclass.')
200+
173201
self.distribution = distribution
174202
self.lower = lower
175203
self.upper = upper
176204

177205
def __call__(self, *args, **kwargs):
178206
if 'observed' in kwargs:
179-
raise ValueError('Observed Bound distributions are not allowed. '
207+
raise ValueError('Observed Bound distributions are not supported. '
180208
'If you want to model truncated data '
181209
'you can use a pm.Potential in combination '
182210
'with the cumulative probability function. See '
183211
'pymc3/examples/censored_data.py for an example.')
184212
first, args = args[0], args[1:]
185213

186-
return _Bounded(first, self.distribution, self.lower, self.upper,
187-
*args, **kwargs)
214+
if issubclass(self.distribution, Continuous):
215+
return _ContinuousBounded(first, self.distribution,
216+
self.lower, self.upper, *args, **kwargs)
217+
elif issubclass(self.distribution, Discrete):
218+
return _DiscreteBounded(first, self.distribution,
219+
self.lower, self.upper, *args, **kwargs)
220+
else:
221+
raise ValueError('Distribution is neither continuous nor discrete.')
188222

189223
def dist(self, *args, **kwargs):
190-
return _Bounded.dist(self.distribution, self.lower, self.upper,
191-
*args, **kwargs)
224+
if issubclass(self.distribution, Continuous):
225+
return _ContinuousBounded.dist(
226+
self.distribution, self.lower, self.upper, *args, **kwargs)
227+
228+
elif issubclass(self.distribution, Discrete):
229+
return _DiscreteBounded.dist(
230+
self.distribution, self.lower, self.upper, *args, **kwargs)
231+
else:
232+
raise ValueError('Distribution is neither continuous nor discrete.')

0 commit comments

Comments
 (0)