1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- import warnings
15-
1614from typing import Callable , Dict , Optional , Union
1715
1816import aesara .tensor as aet
@@ -33,6 +31,8 @@ def find_constrained_prior(
3331 init_guess : Dict [str , float ],
3432 mass : float = 0.95 ,
3533 fixed_params : Optional [Dict [str , float ]] = None ,
34+ mass_below_lower : Optional [float ] = None ,
35+ ** kwargs ,
3636) -> Dict [str , float ]:
3737 """
3838 Find optimal parameters to get `mass` % of probability
@@ -64,6 +64,11 @@ def find_constrained_prior(
6464 Dictionary of fixed parameters, so that there are only 2 to optimize.
6565 For instance, for a StudentT, you fix nu to a constant and get the optimized
6666 mu and sigma.
67+ mass_below_lower : float, optional, default None
68+ The probability mass below the ``lower`` bound. If ``None``,
69+ defaults to ``(1 - mass) / 2``, which implies that the probability
70+ mass below the ``lower`` value will be equal to the probability
71+ mass above the ``upper`` value.
6772
6873 Returns
6974 -------
@@ -72,6 +77,11 @@ def find_constrained_prior(
7277 Dictionary keys are the parameter names and
7378 dictionary values are the optimized parameter values.
7479
80+ Notes
81+ -----
82+ Optional keyword arguments can be passed to ``find_constrained_prior``. These will be
83+ delivered to the underlying call to :external:py:func:`scipy.optimize.minimize`.
84+
7585 Examples
7686 --------
7787 .. code-block:: python
@@ -96,11 +106,31 @@ def find_constrained_prior(
96106 init_guess={"mu": 5, "sigma": 2},
97107 fixed_params={"nu": 7},
98108 )
109+
110+ Under some circumstances, you might not want to have the same cumulative
111+ probability below the ``lower`` threshold and above the ``upper`` threshold.
112+ For example, you might want to constrain an Exponential distribution to
113+ find the parameter that yields 90% of the mass below the ``upper`` bound,
114+ and have zero mass below ``lower``. You can do that with the following call
115+ to ``find_constrained_prior``
116+
117+ .. code-block:: python
118+
119+ opt_params = pm.find_constrained_prior(
120+ pm.Exponential,
121+ lower=0,
122+ upper=3.,
123+ mass=0.9,
124+ init_guess={"lam": 1},
125+ mass_below_lower=0,
126+ )
99127 """
100128 assert 0.01 <= mass <= 0.99 , (
101129 "This function optimizes the mass of the given distribution +/- "
102130 f"1%, so `mass` has to be between 0.01 and 0.99. You provided { mass } ."
103131 )
132+ if mass_below_lower is None :
133+ mass_below_lower = (1 - mass ) / 2
104134
105135 # exit when any parameter is not scalar:
106136 if np .any (np .asarray (distribution .rv_op .ndims_params ) != 0 ):
@@ -129,39 +159,39 @@ def find_constrained_prior(
129159 "need it."
130160 )
131161
132- cdf_error = (pm .math .exp (logcdf_upper ) - pm .math .exp (logcdf_lower )) - mass
133- cdf_error_fn = pm .aesaraf .compile_pymc ([dist_params ], cdf_error , allow_input_downcast = True )
162+ target = (aet .exp (logcdf_lower ) - mass_below_lower ) ** 2
163+ target_fn = pm .aesaraf .compile_pymc ([dist_params ], target , allow_input_downcast = True )
164+
165+ constraint = aet .exp (logcdf_upper ) - aet .exp (logcdf_lower )
166+ constraint_fn = pm .aesaraf .compile_pymc ([dist_params ], constraint , allow_input_downcast = True )
134167
135168 jac : Union [str , Callable ]
169+ constraint_jac : Union [str , Callable ]
136170 try :
137- aesara_jac = pm .gradient (cdf_error , [dist_params ])
171+ aesara_jac = pm .gradient (target , [dist_params ])
138172 jac = pm .aesaraf .compile_pymc ([dist_params ], aesara_jac , allow_input_downcast = True )
173+ aesara_constraint_jac = pm .gradient (constraint , [dist_params ])
174+ constraint_jac = pm .aesaraf .compile_pymc (
175+ [dist_params ], aesara_constraint_jac , allow_input_downcast = True
176+ )
139177 # when PyMC cannot compute the gradient
140178 except (NotImplementedError , NullTypeGradError ):
141179 jac = "2-point"
180+ constraint_jac = "2-point"
181+ cons = optimize .NonlinearConstraint (constraint_fn , lb = mass , ub = mass , jac = constraint_jac )
142182
143- opt = optimize .least_squares (cdf_error_fn , x0 = list (init_guess .values ()), jac = jac )
183+ opt = optimize .minimize (
184+ target_fn , x0 = list (init_guess .values ()), jac = jac , constraints = cons , ** kwargs
185+ )
144186 if not opt .success :
145- raise ValueError ("Optimization of parameters failed." )
187+ raise ValueError (
188+ f"Optimization of parameters failed.\n Optimization termination details:\n { opt } "
189+ )
146190
147191 # save optimal parameters
148192 opt_params = {
149193 param_name : param_value for param_name , param_value in zip (init_guess .keys (), opt .x )
150194 }
151195 if fixed_params is not None :
152196 opt_params .update (fixed_params )
153-
154- # check mass in interval is not too far from `mass`
155- opt_dist = distribution .dist (** opt_params )
156- mass_in_interval = (
157- pm .math .exp (pm .logcdf (opt_dist , upper )) - pm .math .exp (pm .logcdf (opt_dist , lower ))
158- ).eval ()
159- if (np .abs (mass_in_interval - mass )) > 0.01 :
160- warnings .warn (
161- f"Final optimization has { (mass_in_interval if mass_in_interval .ndim < 1 else mass_in_interval [0 ])* 100 :.0f} % of probability mass between "
162- f"{ lower } and { upper } instead of the requested { mass * 100 :.0f} %.\n "
163- "You may need to use a more flexible distribution, change the fixed parameters in the "
164- "`fixed_params` dictionary, or provide better initial guesses."
165- )
166-
167197 return opt_params