11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- import warnings
15
-
16
14
from typing import Callable , Dict , Optional , Union
17
15
18
16
import aesara .tensor as aet
@@ -33,6 +31,8 @@ def find_constrained_prior(
33
31
init_guess : Dict [str , float ],
34
32
mass : float = 0.95 ,
35
33
fixed_params : Optional [Dict [str , float ]] = None ,
34
+ mass_below_lower : Optional [float ] = None ,
35
+ ** kwargs ,
36
36
) -> Dict [str , float ]:
37
37
"""
38
38
Find optimal parameters to get `mass` % of probability
@@ -64,6 +64,11 @@ def find_constrained_prior(
64
64
Dictionary of fixed parameters, so that there are only 2 to optimize.
65
65
For instance, for a StudentT, you fix nu to a constant and get the optimized
66
66
mu and sigma.
67
+ mass_below_lower : float, optional, default None
68
+ The probability mass below the ``lower`` bound. If ``None``,
69
+ defaults to ``(1 - mass) / 2``, which implies that the probability
70
+ mass below the ``lower`` value will be equal to the probability
71
+ mass above the ``upper`` value.
67
72
68
73
Returns
69
74
-------
@@ -72,6 +77,11 @@ def find_constrained_prior(
72
77
Dictionary keys are the parameter names and
73
78
dictionary values are the optimized parameter values.
74
79
80
+ Notes
81
+ -----
82
+ Optional keyword arguments can be passed to ``find_constrained_prior``. These will be
83
+ delivered to the underlying call to :external:py:func:`scipy.optimize.minimize`.
84
+
75
85
Examples
76
86
--------
77
87
.. code-block:: python
@@ -96,11 +106,31 @@ def find_constrained_prior(
96
106
init_guess={"mu": 5, "sigma": 2},
97
107
fixed_params={"nu": 7},
98
108
)
109
+
110
+ Under some circumstances, you might not want to have the same cumulative
111
+ probability below the ``lower`` threshold and above the ``upper`` threshold.
112
+ For example, you might want to constrain an Exponential distribution to
113
+ find the parameter that yields 90% of the mass below the ``upper`` bound,
114
+ and have zero mass below ``lower``. You can do that with the following call
115
+ to ``find_constrained_prior``
116
+
117
+ .. code-block:: python
118
+
119
+ opt_params = pm.find_constrained_prior(
120
+ pm.Exponential,
121
+ lower=0,
122
+ upper=3.,
123
+ mass=0.9,
124
+ init_guess={"lam": 1},
125
+ mass_below_lower=0,
126
+ )
99
127
"""
100
128
assert 0.01 <= mass <= 0.99 , (
101
129
"This function optimizes the mass of the given distribution +/- "
102
130
f"1%, so `mass` has to be between 0.01 and 0.99. You provided { mass } ."
103
131
)
132
+ if mass_below_lower is None :
133
+ mass_below_lower = (1 - mass ) / 2
104
134
105
135
# exit when any parameter is not scalar:
106
136
if np .any (np .asarray (distribution .rv_op .ndims_params ) != 0 ):
@@ -129,39 +159,39 @@ def find_constrained_prior(
129
159
"need it."
130
160
)
131
161
132
- cdf_error = (pm .math .exp (logcdf_upper ) - pm .math .exp (logcdf_lower )) - mass
133
- cdf_error_fn = pm .aesaraf .compile_pymc ([dist_params ], cdf_error , allow_input_downcast = True )
162
+ target = (aet .exp (logcdf_lower ) - mass_below_lower ) ** 2
163
+ target_fn = pm .aesaraf .compile_pymc ([dist_params ], target , allow_input_downcast = True )
164
+
165
+ constraint = aet .exp (logcdf_upper ) - aet .exp (logcdf_lower )
166
+ constraint_fn = pm .aesaraf .compile_pymc ([dist_params ], constraint , allow_input_downcast = True )
134
167
135
168
jac : Union [str , Callable ]
169
+ constraint_jac : Union [str , Callable ]
136
170
try :
137
- aesara_jac = pm .gradient (cdf_error , [dist_params ])
171
+ aesara_jac = pm .gradient (target , [dist_params ])
138
172
jac = pm .aesaraf .compile_pymc ([dist_params ], aesara_jac , allow_input_downcast = True )
173
+ aesara_constraint_jac = pm .gradient (constraint , [dist_params ])
174
+ constraint_jac = pm .aesaraf .compile_pymc (
175
+ [dist_params ], aesara_constraint_jac , allow_input_downcast = True
176
+ )
139
177
# when PyMC cannot compute the gradient
140
178
except (NotImplementedError , NullTypeGradError ):
141
179
jac = "2-point"
180
+ constraint_jac = "2-point"
181
+ cons = optimize .NonlinearConstraint (constraint_fn , lb = mass , ub = mass , jac = constraint_jac )
142
182
143
- opt = optimize .least_squares (cdf_error_fn , x0 = list (init_guess .values ()), jac = jac )
183
+ opt = optimize .minimize (
184
+ target_fn , x0 = list (init_guess .values ()), jac = jac , constraints = cons , ** kwargs
185
+ )
144
186
if not opt .success :
145
- raise ValueError ("Optimization of parameters failed." )
187
+ raise ValueError (
188
+ f"Optimization of parameters failed.\n Optimization termination details:\n { opt } "
189
+ )
146
190
147
191
# save optimal parameters
148
192
opt_params = {
149
193
param_name : param_value for param_name , param_value in zip (init_guess .keys (), opt .x )
150
194
}
151
195
if fixed_params is not None :
152
196
opt_params .update (fixed_params )
153
-
154
- # check mass in interval is not too far from `mass`
155
- opt_dist = distribution .dist (** opt_params )
156
- mass_in_interval = (
157
- pm .math .exp (pm .logcdf (opt_dist , upper )) - pm .math .exp (pm .logcdf (opt_dist , lower ))
158
- ).eval ()
159
- if (np .abs (mass_in_interval - mass )) > 0.01 :
160
- warnings .warn (
161
- f"Final optimization has { (mass_in_interval if mass_in_interval .ndim < 1 else mass_in_interval [0 ])* 100 :.0f} % of probability mass between "
162
- f"{ lower } and { upper } instead of the requested { mass * 100 :.0f} %.\n "
163
- "You may need to use a more flexible distribution, change the fixed parameters in the "
164
- "`fixed_params` dictionary, or provide better initial guesses."
165
- )
166
-
167
197
return opt_params
0 commit comments