Skip to content

Commit 75e6be0

Browse files
lucianopaz, AlexAndorra, ricardoV94
authored
Constrain priors with symmetric mass distribution (#5981)
* `find_constrained_prior` now assumes symmetric probability mass below `lower` and above `upper` by default
* Fix typo in docstring

Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Alexandre Andorra <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent: e64942b; commit: 75e6be0

File tree

2 files changed

+76
-43
lines changed

2 files changed

+76
-43
lines changed

pymc/func_utils.py

+51 -21
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import warnings
15-
1614
from typing import Callable, Dict, Optional, Union
1715

1816
import aesara.tensor as aet
@@ -33,6 +31,8 @@ def find_constrained_prior(
3331
init_guess: Dict[str, float],
3432
mass: float = 0.95,
3533
fixed_params: Optional[Dict[str, float]] = None,
34+
mass_below_lower: Optional[float] = None,
35+
**kwargs,
3636
) -> Dict[str, float]:
3737
"""
3838
Find optimal parameters to get `mass` % of probability
@@ -64,6 +64,11 @@ def find_constrained_prior(
6464
Dictionary of fixed parameters, so that there are only 2 to optimize.
6565
For instance, for a StudentT, you fix nu to a constant and get the optimized
6666
mu and sigma.
67+
mass_below_lower : float, optional, default None
68+
The probability mass below the ``lower`` bound. If ``None``,
69+
defaults to ``(1 - mass) / 2``, which implies that the probability
70+
mass below the ``lower`` value will be equal to the probability
71+
mass above the ``upper`` value.
6772
6873
Returns
6974
-------
@@ -72,6 +77,11 @@ def find_constrained_prior(
7277
Dictionary keys are the parameter names and
7378
dictionary values are the optimized parameter values.
7479
80+
Notes
81+
-----
82+
Optional keyword arguments can be passed to ``find_constrained_prior``. These will be
83+
delivered to the underlying call to :external:py:func:`scipy.optimize.minimize`.
84+
7585
Examples
7686
--------
7787
.. code-block:: python
@@ -96,11 +106,31 @@ def find_constrained_prior(
96106
init_guess={"mu": 5, "sigma": 2},
97107
fixed_params={"nu": 7},
98108
)
109+
110+
Under some circumstances, you might not want to have the same cumulative
111+
probability below the ``lower`` threshold and above the ``upper`` threshold.
112+
For example, you might want to constrain an Exponential distribution to
113+
find the parameter that yields 90% of the mass below the ``upper`` bound,
114+
and have zero mass below ``lower``. You can do that with the following call
115+
to ``find_constrained_prior``
116+
117+
.. code-block:: python
118+
119+
opt_params = pm.find_constrained_prior(
120+
pm.Exponential,
121+
lower=0,
122+
upper=3.,
123+
mass=0.9,
124+
init_guess={"lam": 1},
125+
mass_below_lower=0,
126+
)
99127
"""
100128
assert 0.01 <= mass <= 0.99, (
101129
"This function optimizes the mass of the given distribution +/- "
102130
f"1%, so `mass` has to be between 0.01 and 0.99. You provided {mass}."
103131
)
132+
if mass_below_lower is None:
133+
mass_below_lower = (1 - mass) / 2
104134

105135
# exit when any parameter is not scalar:
106136
if np.any(np.asarray(distribution.rv_op.ndims_params) != 0):
@@ -129,39 +159,39 @@ def find_constrained_prior(
129159
"need it."
130160
)
131161

132-
cdf_error = (pm.math.exp(logcdf_upper) - pm.math.exp(logcdf_lower)) - mass
133-
cdf_error_fn = pm.aesaraf.compile_pymc([dist_params], cdf_error, allow_input_downcast=True)
162+
target = (aet.exp(logcdf_lower) - mass_below_lower) ** 2
163+
target_fn = pm.aesaraf.compile_pymc([dist_params], target, allow_input_downcast=True)
164+
165+
constraint = aet.exp(logcdf_upper) - aet.exp(logcdf_lower)
166+
constraint_fn = pm.aesaraf.compile_pymc([dist_params], constraint, allow_input_downcast=True)
134167

135168
jac: Union[str, Callable]
169+
constraint_jac: Union[str, Callable]
136170
try:
137-
aesara_jac = pm.gradient(cdf_error, [dist_params])
171+
aesara_jac = pm.gradient(target, [dist_params])
138172
jac = pm.aesaraf.compile_pymc([dist_params], aesara_jac, allow_input_downcast=True)
173+
aesara_constraint_jac = pm.gradient(constraint, [dist_params])
174+
constraint_jac = pm.aesaraf.compile_pymc(
175+
[dist_params], aesara_constraint_jac, allow_input_downcast=True
176+
)
139177
# when PyMC cannot compute the gradient
140178
except (NotImplementedError, NullTypeGradError):
141179
jac = "2-point"
180+
constraint_jac = "2-point"
181+
cons = optimize.NonlinearConstraint(constraint_fn, lb=mass, ub=mass, jac=constraint_jac)
142182

143-
opt = optimize.least_squares(cdf_error_fn, x0=list(init_guess.values()), jac=jac)
183+
opt = optimize.minimize(
184+
target_fn, x0=list(init_guess.values()), jac=jac, constraints=cons, **kwargs
185+
)
144186
if not opt.success:
145-
raise ValueError("Optimization of parameters failed.")
187+
raise ValueError(
188+
f"Optimization of parameters failed.\nOptimization termination details:\n{opt}"
189+
)
146190

147191
# save optimal parameters
148192
opt_params = {
149193
param_name: param_value for param_name, param_value in zip(init_guess.keys(), opt.x)
150194
}
151195
if fixed_params is not None:
152196
opt_params.update(fixed_params)
153-
154-
# check mass in interval is not too far from `mass`
155-
opt_dist = distribution.dist(**opt_params)
156-
mass_in_interval = (
157-
pm.math.exp(pm.logcdf(opt_dist, upper)) - pm.math.exp(pm.logcdf(opt_dist, lower))
158-
).eval()
159-
if (np.abs(mass_in_interval - mass)) > 0.01:
160-
warnings.warn(
161-
f"Final optimization has {(mass_in_interval if mass_in_interval.ndim < 1 else mass_in_interval[0])* 100:.0f}% of probability mass between "
162-
f"{lower} and {upper} instead of the requested {mass * 100:.0f}%.\n"
163-
"You may need to use a more flexible distribution, change the fixed parameters in the "
164-
"`fixed_params` dictionary, or provide better initial guesses."
165-
)
166-
167197
return opt_params

pymc/tests/test_func_utils.py

+25 -22
Original file line number | Diff line number | Diff line change
@@ -19,31 +19,32 @@
1919

2020

2121
@pytest.mark.parametrize(
22-
"distribution, lower, upper, init_guess, fixed_params",
22+
"distribution, lower, upper, init_guess, fixed_params, mass_below_lower",
2323
[
24-
(pm.Gamma, 0.1, 0.4, {"alpha": 1, "beta": 10}, {}),
25-
(pm.Normal, 155, 180, {"mu": 170, "sigma": 3}, {}),
26-
(pm.StudentT, 0.1, 0.4, {"mu": 10, "sigma": 3}, {"nu": 7}),
27-
(pm.StudentT, 0, 1, {"mu": 5, "sigma": 2, "nu": 7}, {}),
28-
(pm.Exponential, 0, 1, {"lam": 1}, {}),
29-
(pm.HalfNormal, 0, 1, {"sigma": 1}, {}),
30-
(pm.Binomial, 0, 8, {"p": 0.5}, {"n": 10}),
31-
(pm.Poisson, 1, 15, {"mu": 10}, {}),
32-
(pm.Poisson, 19, 41, {"mu": 30}, {}),
24+
(pm.Gamma, 0.1, 0.4, {"alpha": 1, "beta": 10}, {}, None),
25+
(pm.Normal, 155, 180, {"mu": 170, "sigma": 3}, {}, None),
26+
(pm.StudentT, 0.1, 0.4, {"mu": 10, "sigma": 3}, {"nu": 7}, None),
27+
(pm.StudentT, 0, 1, {"mu": 5, "sigma": 2, "nu": 7}, {}, None),
28+
(pm.Exponential, 0, 1, {"lam": 1}, {}, 0),
29+
(pm.HalfNormal, 0, 1, {"sigma": 1}, {}, 0),
30+
(pm.Binomial, 0, 8, {"p": 0.5}, {"n": 10}, None),
31+
(pm.Poisson, 1, 15, {"mu": 10}, {}, None),
32+
(pm.Poisson, 19, 41, {"mu": 30}, {}, None),
3333
],
3434
)
3535
@pytest.mark.parametrize("mass", [0.5, 0.75, 0.95])
36-
def test_find_constrained_prior(distribution, lower, upper, init_guess, fixed_params, mass):
37-
with pytest.warns(None) as record:
38-
opt_params = pm.find_constrained_prior(
39-
distribution,
40-
lower=lower,
41-
upper=upper,
42-
mass=mass,
43-
init_guess=init_guess,
44-
fixed_params=fixed_params,
45-
)
46-
assert len(record) == 0
36+
def test_find_constrained_prior(
37+
distribution, lower, upper, init_guess, fixed_params, mass, mass_below_lower
38+
):
39+
opt_params = pm.find_constrained_prior(
40+
distribution,
41+
lower=lower,
42+
upper=upper,
43+
mass=mass,
44+
init_guess=init_guess,
45+
fixed_params=fixed_params,
46+
mass_below_lower=mass_below_lower,
47+
)
4748

4849
opt_distribution = distribution.dist(**opt_params)
4950
mass_in_interval = (
@@ -64,7 +65,9 @@ def test_find_constrained_prior(distribution, lower, upper, init_guess, fixed_pa
6465
def test_find_constrained_prior_error_too_large(
6566
distribution, lower, upper, init_guess, fixed_params
6667
):
67-
with pytest.warns(UserWarning, match="instead of the requested 95%"):
68+
with pytest.raises(
69+
ValueError, match="Optimization of parameters failed.\nOptimization termination details:\n"
70+
):
6871
pm.find_constrained_prior(
6972
distribution,
7073
lower=lower,

0 commit comments

Comments
 (0)