Skip to content

Commit e407a84

Browse files
committed
Compatibility with optional dependencies
1 parent 5ab70df commit e407a84

File tree

2 files changed

+59
-58
lines changed

2 files changed

+59
-58
lines changed

kernel_tuner/strategies/bayes_opt.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def predict_list(self, lst: list) -> Tuple[list, list, list]:
451451
with warnings.catch_warnings():
452452
warnings.simplefilter("ignore")
453453
mu, std = self.__model.predict(lst, return_std=True)
454-
return mu, std
454+
return list(zip(mu, std)), mu, std
455455

456456
def fit_observations_to_model(self):
457457
"""Update the model based on the current list of observations."""
@@ -540,7 +540,7 @@ def initial_sample(self):
540540
if self.is_valid(observation):
541541
collected_samples += 1
542542
self.fit_observations_to_model()
543-
_, std = self.predict_list(self.unvisited_cache)
543+
_, _, std = self.predict_list(self.unvisited_cache)
544544
self.initial_sample_mean = np.mean(self.__valid_observations)
545545
# Alternatively:
546546
# self.initial_sample_std = np.std(self.__valid_observations)
@@ -736,11 +736,11 @@ def __optimize_multi_advanced(self, max_fevals, increase_precision=False):
736736
if self.__visited_num >= self.searchspace_size or self.fevals >= max_fevals:
737737
break
738738
if increase_precision is True:
739-
predictions, _, std = self.predict_list(self.unvisited_cache)
739+
predictions = self.predict_list(self.unvisited_cache)
740740
hyperparam = self.contextual_variance(std)
741741
list_of_acquisition_values = af(predictions, hyperparam)
742742
best_af = self.argopt(list_of_acquisition_values)
743-
del predictions[best_af] # to avoid going out of bounds
743+
# del predictions[best_af] # to avoid going out of bounds
744744
candidate_params = self.unvisited_cache[best_af]
745745
candidate_index = self.find_param_config_index(candidate_params)
746746
observation = self.evaluate_objective_function(candidate_params)
@@ -855,13 +855,12 @@ def af_random(self, predictions=None, hyperparam=None) -> list:
855855
def af_probability_of_improvement(self, predictions=None, hyperparam=None) -> list:
856856
"""Acquisition function Probability of Improvement (PI)."""
857857
# prefetch required data
858-
x_mu, x_std = predictions
859858
if hyperparam is None:
860859
hyperparam = self.af_params["explorationfactor"]
861860
fplus = self.current_optimum - hyperparam
862861

863862
# precompute difference of improvement
864-
list_diff_improvement = list(-((fplus - x_mu) / (x_std + 1e-9)) for (x_mu, x_std) in predictions)
863+
list_diff_improvement = list(-((fplus - x_mu) / (x_std + 1e-9)) for x_mu, x_std in predictions[0])
865864

866865
# compute probability of improvement with CDF in bulk
867866
list_prob_improvement = norm.cdf(list_diff_improvement)
@@ -870,10 +869,15 @@ def af_probability_of_improvement(self, predictions=None, hyperparam=None) -> li
870869
def af_expected_improvement(self, predictions=None, hyperparam=None) -> list:
871870
"""Acquisition function Expected Improvement (EI)."""
872871
# prefetch required data
873-
x_mu, x_std = predictions
874872
if hyperparam is None:
875873
hyperparam = self.af_params["explorationfactor"]
876874
fplus = self.current_optimum - hyperparam
875+
if len(predictions) == 3:
876+
predictions, x_mu, x_std = predictions
877+
elif len(predictions) == 2:
878+
x_mu, x_std = predictions
879+
else:
880+
raise ValueError(f"Invalid predictions size {len(predictions)}")
877881

878882
# precompute difference of improvement, CDF and PDF in bulk
879883
list_diff_improvement = list((fplus - x_mu) / (x_std + 1e-9) for (x_mu, x_std) in predictions)
@@ -892,6 +896,7 @@ def af_lower_confidence_bound(self, predictions=None, hyperparam=None) -> list:
892896
if hyperparam is None:
893897
hyperparam = self.af_params["explorationfactor"]
894898
beta = hyperparam
899+
_, x_mu, x_std = predictions
895900

896901
# compute LCB in bulk
897902
list_lower_confidence_bound = (x_mu - beta * x_std)
@@ -900,7 +905,7 @@ def af_lower_confidence_bound(self, predictions=None, hyperparam=None) -> list:
900905
def af_lower_confidence_bound_srinivas(self, predictions=None, hyperparam=None) -> list:
901906
"""Acquisition function Lower Confidence Bound (UCB-S) after Srinivas, 2010 / Brochu, 2010."""
902907
# prefetch required data
903-
x_mu, x_std = predictions
908+
_, x_mu, x_std = predictions
904909
if hyperparam is None:
905910
hyperparam = self.af_params["explorationfactor"]
906911

0 commit comments

Comments
 (0)