@@ -451,7 +451,7 @@ def predict_list(self, lst: list) -> Tuple[list, list, list]:
451
451
with warnings .catch_warnings ():
452
452
warnings .simplefilter ("ignore" )
453
453
mu , std = self .__model .predict (lst , return_std = True )
454
- return mu , std
454
+ return list ( zip ( mu , std )), mu , std
455
455
456
456
def fit_observations_to_model (self ):
457
457
"""Update the model based on the current list of observations."""
@@ -540,7 +540,7 @@ def initial_sample(self):
540
540
if self .is_valid (observation ):
541
541
collected_samples += 1
542
542
self .fit_observations_to_model ()
543
- _ , std = self .predict_list (self .unvisited_cache )
543
+ _ , _ , std = self .predict_list (self .unvisited_cache )
544
544
self .initial_sample_mean = np .mean (self .__valid_observations )
545
545
# Alternatively:
546
546
# self.initial_sample_std = np.std(self.__valid_observations)
@@ -736,11 +736,11 @@ def __optimize_multi_advanced(self, max_fevals, increase_precision=False):
736
736
if self .__visited_num >= self .searchspace_size or self .fevals >= max_fevals :
737
737
break
738
738
if increase_precision is True :
739
- predictions , _ , std = self .predict_list (self .unvisited_cache )
739
+ predictions = self .predict_list (self .unvisited_cache )
740
740
hyperparam = self .contextual_variance (std )
741
741
list_of_acquisition_values = af (predictions , hyperparam )
742
742
best_af = self .argopt (list_of_acquisition_values )
743
- del predictions [best_af ] # to avoid going out of bounds
743
+ # del predictions[best_af] # to avoid going out of bounds
744
744
candidate_params = self .unvisited_cache [best_af ]
745
745
candidate_index = self .find_param_config_index (candidate_params )
746
746
observation = self .evaluate_objective_function (candidate_params )
@@ -855,13 +855,12 @@ def af_random(self, predictions=None, hyperparam=None) -> list:
855
855
def af_probability_of_improvement (self , predictions = None , hyperparam = None ) -> list :
856
856
"""Acquisition function Probability of Improvement (PI)."""
857
857
# prefetch required data
858
- x_mu , x_std = predictions
859
858
if hyperparam is None :
860
859
hyperparam = self .af_params ["explorationfactor" ]
861
860
fplus = self .current_optimum - hyperparam
862
861
863
862
# precompute difference of improvement
864
- list_diff_improvement = list (- ((fplus - x_mu ) / (x_std + 1e-9 )) for ( x_mu , x_std ) in predictions )
863
+ list_diff_improvement = list (- ((fplus - x_mu ) / (x_std + 1e-9 )) for x_mu , x_std in predictions [ 0 ] )
865
864
866
865
# compute probability of improvement with CDF in bulk
867
866
list_prob_improvement = norm .cdf (list_diff_improvement )
@@ -870,10 +869,15 @@ def af_probability_of_improvement(self, predictions=None, hyperparam=None) -> li
870
869
def af_expected_improvement (self , predictions = None , hyperparam = None ) -> list :
871
870
"""Acquisition function Expected Improvement (EI)."""
872
871
# prefetch required data
873
- x_mu , x_std = predictions
874
872
if hyperparam is None :
875
873
hyperparam = self .af_params ["explorationfactor" ]
876
874
fplus = self .current_optimum - hyperparam
875
+ if len (predictions ) == 3 :
876
+ predictions , x_mu , x_std = predictions
877
+ elif len (predictions ) == 2 :
878
+ x_mu , x_std = predictions
879
+ else :
880
+ raise ValueError (f"Invalid predictions size { len (predictions )} " )
877
881
878
882
# precompute difference of improvement, CDF and PDF in bulk
879
883
list_diff_improvement = list ((fplus - x_mu ) / (x_std + 1e-9 ) for (x_mu , x_std ) in predictions )
@@ -892,6 +896,7 @@ def af_lower_confidence_bound(self, predictions=None, hyperparam=None) -> list:
892
896
if hyperparam is None :
893
897
hyperparam = self .af_params ["explorationfactor" ]
894
898
beta = hyperparam
899
+ _ , x_mu , x_std = predictions
895
900
896
901
# compute LCB in bulk
897
902
list_lower_confidence_bound = (x_mu - beta * x_std )
@@ -900,7 +905,7 @@ def af_lower_confidence_bound(self, predictions=None, hyperparam=None) -> list:
900
905
def af_lower_confidence_bound_srinivas (self , predictions = None , hyperparam = None ) -> list :
901
906
"""Acquisition function Lower Confidence Bound (UCB-S) after Srinivas, 2010 / Brochu, 2010."""
902
907
# prefetch required data
903
- x_mu , x_std = predictions
908
+ _ , x_mu , x_std = predictions
904
909
if hyperparam is None :
905
910
hyperparam = self .af_params ["explorationfactor" ]
906
911
0 commit comments