@@ -60,8 +60,8 @@ def pre_process_bounds(
6060
6161
6262def _process_bounds_sequence (bounds : Sequence [tuple [float , float ]]) -> Bounds :
63- lower = np . full (len (bounds ), - np .inf )
64- upper = np . full (len (bounds ), np .inf )
63+ lower = _fast_full_array (len (bounds ), value = - np .inf )
64+ upper = _fast_full_array (len (bounds ), value = np .inf )
6565
6666 for i , (lb , ub ) in enumerate (bounds ):
6767 if lb is not None :
@@ -76,7 +76,8 @@ def get_internal_bounds(
7676 bounds : Bounds | None = None ,
7777 registry : PyTreeRegistry | None = None ,
7878 add_soft_bounds : bool = False ,
79- ) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ]]:
79+ propagate_none : bool = False ,
80+ ) -> tuple [NDArray [np .float64 ] | None , NDArray [np .float64 ] | None ]:
8081 """Create consolidated and flattened bounds for params.
8182
8283 If params is a DataFrame with value column, the user provided bounds are
@@ -95,6 +96,9 @@ def get_internal_bounds(
9596 add_soft_bounds: If True, the element-wise maximum (minimum) of the lower and
9697 soft_lower (upper and soft_upper) bounds are taken. If False, the lower
9798 (upper) bounds are returned.
99+ propagate_none: If True, None values in bounds are propagated to the output.
100+ If False, None values are replaced with -np.inf for the lower bound and
101+ np.inf for the upper bound.
98102
99103 Returns:
100104 Consolidated and flattened lower_bounds.
@@ -112,6 +116,7 @@ def get_internal_bounds(
112116 return _get_fast_path_bounds (
113117 params = params ,
114118 bounds = bounds ,
119+ propagate_none = propagate_none ,
115120 )
116121
117122 registry = get_registry (extended = True ) if registry is None else registry
@@ -213,7 +218,7 @@ def _is_fast_path(params: PyTree, bounds: Bounds, add_soft_bounds: bool) -> bool
213218 if not _is_1d_array (params ):
214219 out = False
215220
216- for bound in bounds .lower , bounds .upper :
221+ for bound in ( bounds .lower , bounds .upper ) :
217222 if not (_is_1d_array (bound ) or bound is None ):
218223 out = False
219224 return out
@@ -224,22 +229,37 @@ def _is_1d_array(candidate: Any) -> bool:
224229
225230
226231def _get_fast_path_bounds (
227- params : PyTree , bounds : Bounds
228- ) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ]]:
232+ params : NDArray [ np . float64 ] , bounds : Bounds , propagate_none : bool = False
233+ ) -> tuple [NDArray [np .float64 ] | None , NDArray [np .float64 ] | None ]:
229234 if bounds .lower is None :
230- # faster than np.full
231- lower_bounds = np .array ([- np .inf ] * len (params ))
235+ if propagate_none :
236+ lower_bounds = None
237+ else :
238+ lower_bounds = _fast_full_array (len (params ), value = - np .inf )
232239 else :
233240 lower_bounds = bounds .lower .astype (float )
234241
235242 if bounds .upper is None :
236- # faster than np.full
237- upper_bounds = np .array ([np .inf ] * len (params ))
243+ if propagate_none :
244+ upper_bounds = None
245+ else :
246+ upper_bounds = _fast_full_array (len (params ), value = np .inf )
238247 else :
239248 upper_bounds = bounds .upper .astype (float )
240249
241- if (lower_bounds > upper_bounds ).any ():
250+ if (
251+ lower_bounds is not None
252+ and upper_bounds is not None
253+ and (lower_bounds > upper_bounds ).any ()
254+ ):
242255 msg = "Invalid bounds. Some lower bounds are larger than upper bounds."
243256 raise InvalidBoundsError (msg )
244257
245258 return lower_bounds , upper_bounds
259+
260+
261+ def _fast_full_array (length : int , value : float ) -> NDArray [np .float64 ]:
262+ if length < 18 :
263+ return np .array ([value ] * length )
264+ else :
265+ return np .full (length , value )
0 commit comments