Skip to content

Commit 75ea2a8

Browse files
Add function to find closest prior within lower/upper bounds (#5231)
Co-authored-by: Ricardo <[email protected]>
1 parent 600fe90 commit 75ea2a8

File tree

4 files changed

+287
-1
lines changed

4 files changed

+287
-1
lines changed

Diff for: RELEASE-NOTES.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,10 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
113113
- Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
114114
- Modify how particle weights are computed. This improves accuracy of the modeled function (see [5177](https://github.com/pymc-devs/pymc3/pull/5177)).
115115
- Improve sampling, increase default number of particles [5229](https://github.com/pymc-devs/pymc3/pull/5229).
116+
- The new `pm.find_constrained_prior` function can be used to find optimized prior parameters of a distribution under some
117+
constraints (e.g lower and upper bound). See [#5231](https://github.com/pymc-devs/pymc/pull/5231).
116118
- New features for `pm.Data` containers:
117-
- With `pm.Data(..., mutable=True/False)`, or by using `pm.MutableData` vs. `pm.ConstantData` one can now create `TensorConstant` data variables. They can be more performant and compatible in situtations where a variable doesn't need to be changed via `pm.set_data()`. See [#5295](https://github.com/pymc-devs/pymc/pull/5295).
119+
- With `pm.Data(..., mutable=True/False)`, or by using `pm.MutableData` vs. `pm.ConstantData` one can now create `TensorConstant` data variables. They can be more performant and compatible in situations where a variable doesn't need to be changed via `pm.set_data()`. See [#5295](https://github.com/pymc-devs/pymc/pull/5295).
118120
- New named dimensions can be introduced to the model via `pm.Data(..., dims=...)`. For mutable data variables (see above) the lengths of these dimensions are symbolic, so they can be re-sized via `pm.set_data()`.
119121
- `pm.Data` now passes additional kwargs to `aesara.shared`/`at.as_tensor`. [#5098](https://github.com/pymc-devs/pymc/pull/5098).
120122
- ...

Diff for: pymc/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __set_compiler_flags():
8181
from pymc.distributions import *
8282
from pymc.distributions import transforms
8383
from pymc.exceptions import *
84+
from pymc.func_utils import find_constrained_prior
8485
from pymc.math import (
8586
expand_packed_triangular,
8687
invlogit,

Diff for: pymc/func_utils.py

+163
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Copyright 2021 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import warnings
15+
16+
from typing import Dict, Optional
17+
18+
import aesara.tensor as aet
19+
import numpy as np
20+
21+
from aesara.gradient import NullTypeGradError
22+
from scipy import optimize
23+
24+
import pymc as pm
25+
26+
__all__ = ["find_constrained_prior"]
27+
28+
29+
def find_constrained_prior(
30+
distribution: pm.Distribution,
31+
lower: float,
32+
upper: float,
33+
init_guess: Dict[str, float],
34+
mass: float = 0.95,
35+
fixed_params: Optional[Dict[str, float]] = None,
36+
) -> Dict[str, float]:
37+
"""
38+
Find optimal parameters to get `mass` % of probability
39+
of `pm_dist` between `lower` and `upper`.
40+
Note: only works for one- and two-parameter distributions, as there
41+
are exactly two constraints. Fix some combination of parameters
42+
if you want to use it on >=3-parameter distributions.
43+
44+
Parameters
45+
----------
46+
distribution : pm.Distribution
47+
PyMC distribution you want to set a prior on.
48+
Needs to have a ``logcdf`` method implemented in PyMC.
49+
lower : float
50+
Lower bound to get `mass` % of probability of `pm_dist`.
51+
upper : float
52+
Upper bound to get `mass` % of probability of `pm_dist`.
53+
init_guess: Dict[str, float]
54+
Initial guess for ``scipy.optimize.least_squares`` to find the
55+
optimal parameters of `pm_dist` fitting the interval constraint.
56+
Must be a dictionary with the name of the PyMC distribution's
57+
parameter as keys and the initial guess as values.
58+
mass: float, default to 0.95
59+
Share of the probability mass we want between ``lower`` and ``upper``.
60+
Defaults to 95%.
61+
fixed_params: Dict[str, float], Optional, default None
62+
Only used when `pm_dist` has at least three parameters.
63+
Dictionary of fixed parameters, so that there are only 2 to optimize.
64+
For instance, for a StudenT, you fix nu to a constant and get the optimized
65+
mu and sigma.
66+
67+
Returns
68+
-------
69+
The optimized distribution parameters as a dictionary with the parameters'
70+
name as key and the optimized value as value.
71+
72+
Examples
73+
--------
74+
.. code-block:: python
75+
76+
# get parameters obeying constraints
77+
opt_params = pm.find_constrained_prior(
78+
pm.Gamma, lower=0.1, upper=0.4, mass=0.75, init_guess={"alpha": 1, "beta": 10}
79+
)
80+
81+
# use these parameters to draw random samples
82+
samples = pm.Gamma.dist(**opt_params, size=100).eval()
83+
84+
# use these parameters in a model
85+
with pm.Model():
86+
x = pm.Gamma('x', **opt_params)
87+
88+
# specify fixed values before optimization
89+
opt_params = pm.find_constrained_prior(
90+
pm.StudentT,
91+
lower=0,
92+
upper=1,
93+
init_guess={"mu": 5, "sigma": 2},
94+
fixed_params={"nu": 7},
95+
)
96+
"""
97+
assert 0.01 <= mass <= 0.99, (
98+
"This function optimizes the mass of the given distribution +/- "
99+
f"1%, so `mass` has to be between 0.01 and 0.99. You provided {mass}."
100+
)
101+
102+
# exit when any parameter is not scalar:
103+
if np.any(np.asarray(distribution.rv_op.ndims_params) != 0):
104+
raise NotImplementedError(
105+
"`pm.find_constrained_prior` does not work with non-scalar parameters yet.\n"
106+
"Feel free to open a pull request on PyMC repo if you really need this feature."
107+
)
108+
109+
dist_params = aet.vector("dist_params")
110+
params_to_optim = {
111+
arg_name: dist_params[i] for arg_name, i in zip(init_guess.keys(), range(len(init_guess)))
112+
}
113+
114+
if fixed_params is not None:
115+
params_to_optim.update(fixed_params)
116+
117+
dist = distribution.dist(**params_to_optim)
118+
119+
try:
120+
logcdf_lower = pm.logcdf(dist, pm.floatX(lower))
121+
logcdf_upper = pm.logcdf(dist, pm.floatX(upper))
122+
except AttributeError:
123+
raise AttributeError(
124+
f"You cannot use `find_constrained_prior` with {distribution} -- it doesn't have a logcdf "
125+
"method yet.\nOpen an issue or, even better, a pull request on PyMC repo if you really "
126+
"need it."
127+
)
128+
129+
cdf_error = (pm.math.exp(logcdf_upper) - pm.math.exp(logcdf_lower)) - mass
130+
cdf_error_fn = pm.aesaraf.compile_pymc([dist_params], cdf_error, allow_input_downcast=True)
131+
132+
try:
133+
aesara_jac = pm.gradient(cdf_error, [dist_params])
134+
jac = pm.aesaraf.compile_pymc([dist_params], aesara_jac, allow_input_downcast=True)
135+
# when PyMC cannot compute the gradient
136+
except (NotImplementedError, NullTypeGradError):
137+
jac = "2-point"
138+
139+
opt = optimize.least_squares(cdf_error_fn, x0=list(init_guess.values()), jac=jac)
140+
if not opt.success:
141+
raise ValueError("Optimization of parameters failed.")
142+
143+
# save optimal parameters
144+
opt_params = {
145+
param_name: param_value for param_name, param_value in zip(init_guess.keys(), opt.x)
146+
}
147+
if fixed_params is not None:
148+
opt_params.update(fixed_params)
149+
150+
# check mass in interval is not too far from `mass`
151+
opt_dist = distribution.dist(**opt_params)
152+
mass_in_interval = (
153+
pm.math.exp(pm.logcdf(opt_dist, upper)) - pm.math.exp(pm.logcdf(opt_dist, lower))
154+
).eval()
155+
if (np.abs(mass_in_interval - mass)) > 0.01:
156+
warnings.warn(
157+
f"Final optimization has {(mass_in_interval if mass_in_interval.ndim < 1 else mass_in_interval[0])* 100:.0f}% of probability mass between "
158+
f"{lower} and {upper} instead of the requested {mass * 100:.0f}%.\n"
159+
"You may need to use a more flexible distribution, change the fixed parameters in the "
160+
"`fixed_params` dictionary, or provide better initial guesses."
161+
)
162+
163+
return opt_params

Diff for: pymc/tests/test_func_utils.py

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytest
17+
18+
import pymc as pm
19+
20+
21+
@pytest.mark.parametrize(
22+
"distribution, lower, upper, init_guess, fixed_params",
23+
[
24+
(pm.Gamma, 0.1, 0.4, {"alpha": 1, "beta": 10}, {}),
25+
(pm.Normal, 155, 180, {"mu": 170, "sigma": 3}, {}),
26+
(pm.StudentT, 0.1, 0.4, {"mu": 10, "sigma": 3}, {"nu": 7}),
27+
(pm.StudentT, 0, 1, {"mu": 5, "sigma": 2, "nu": 7}, {}),
28+
# (pm.Exponential, 0, 1, {"lam": 1}, {}), PyMC Exponential gradient is failing miserably,
29+
# need to figure out why
30+
(pm.HalfNormal, 0, 1, {"sigma": 1}, {}),
31+
(pm.Binomial, 0, 8, {"p": 0.5}, {"n": 10}),
32+
(pm.Poisson, 1, 15, {"mu": 10}, {}),
33+
(pm.Poisson, 19, 41, {"mu": 30}, {}),
34+
],
35+
)
36+
@pytest.mark.parametrize("mass", [0.5, 0.75, 0.95])
37+
def test_find_constrained_prior(distribution, lower, upper, init_guess, fixed_params, mass):
38+
with pytest.warns(None) as record:
39+
opt_params = pm.find_constrained_prior(
40+
distribution,
41+
lower=lower,
42+
upper=upper,
43+
mass=mass,
44+
init_guess=init_guess,
45+
fixed_params=fixed_params,
46+
)
47+
assert len(record) == 0
48+
49+
opt_distribution = distribution.dist(**opt_params)
50+
mass_in_interval = (
51+
pm.math.exp(pm.logcdf(opt_distribution, upper))
52+
- pm.math.exp(pm.logcdf(opt_distribution, lower))
53+
).eval()
54+
assert np.abs(mass_in_interval - mass) <= 1e-5
55+
56+
57+
@pytest.mark.parametrize(
58+
"distribution, lower, upper, init_guess, fixed_params",
59+
[
60+
(pm.Gamma, 0.1, 0.4, {"alpha": 1}, {"beta": 10}),
61+
(pm.Exponential, 0.1, 1, {"lam": 1}, {}),
62+
(pm.Binomial, 0, 2, {"p": 0.8}, {"n": 10}),
63+
],
64+
)
65+
def test_find_constrained_prior_error_too_large(
66+
distribution, lower, upper, init_guess, fixed_params
67+
):
68+
with pytest.warns(UserWarning, match="instead of the requested 95%"):
69+
pm.find_constrained_prior(
70+
distribution,
71+
lower=lower,
72+
upper=upper,
73+
mass=0.95,
74+
init_guess=init_guess,
75+
fixed_params=fixed_params,
76+
)
77+
78+
79+
def test_find_constrained_prior_input_errors():
80+
# missing param
81+
with pytest.raises(TypeError, match="required positional argument"):
82+
pm.find_constrained_prior(
83+
pm.StudentT,
84+
lower=0.1,
85+
upper=0.4,
86+
mass=0.95,
87+
init_guess={"mu": 170, "sigma": 3},
88+
)
89+
90+
# mass too high
91+
with pytest.raises(AssertionError, match="has to be between 0.01 and 0.99"):
92+
pm.find_constrained_prior(
93+
pm.StudentT,
94+
lower=0.1,
95+
upper=0.4,
96+
mass=0.995,
97+
init_guess={"mu": 170, "sigma": 3},
98+
fixed_params={"nu": 7},
99+
)
100+
101+
# mass too low
102+
with pytest.raises(AssertionError, match="has to be between 0.01 and 0.99"):
103+
pm.find_constrained_prior(
104+
pm.StudentT,
105+
lower=0.1,
106+
upper=0.4,
107+
mass=0.005,
108+
init_guess={"mu": 170, "sigma": 3},
109+
fixed_params={"nu": 7},
110+
)
111+
112+
# non-scalar params
113+
with pytest.raises(NotImplementedError, match="does not work with non-scalar parameters yet"):
114+
pm.find_constrained_prior(
115+
pm.MvNormal,
116+
lower=0,
117+
upper=1,
118+
mass=0.95,
119+
init_guess={"mu": 5, "cov": np.asarray([[1, 0.2], [0.2, 1]])},
120+
)

0 commit comments

Comments
 (0)