Skip to content

Commit a0fe340

Browse files
jaycee-licopybara-github
authored andcommitted
chore: fix that sklearn will be imported when vizer hptuning is not used
PiperOrigin-RevId: 566442578
1 parent 1d15f82 commit a0fe340

File tree

1 file changed

+58
-35
lines changed

1 file changed

+58
-35
lines changed

vertexai/preview/hyperparameter_tuning/vizier_hyperparameter_tuner.py

Lines changed: 58 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -69,42 +69,32 @@
6969
_RMSLE_METRIC_ID = "rmsle"
7070
_MSE_METRIC_ID = "mse"
7171

72-
try: # Only used by local tuning loop
73-
import sklearn.metrics
74-
from sklearn.model_selection import train_test_split
75-
76-
_SUPPORTED_METRIC_FUNCTIONS = {
77-
_ROC_AUC_METRIC_ID: sklearn.metrics.roc_auc_score,
78-
_F1_METRIC_ID: sklearn.metrics.f1_score,
79-
_PRECISION_METRIC_ID: sklearn.metrics.precision_score,
80-
_RECALL_METRIC_ID: sklearn.metrics.recall_score,
81-
_ACCURACY_METRIC_ID: sklearn.metrics.accuracy_score,
82-
_MAE_METRIC_ID: sklearn.metrics.mean_absolute_error,
83-
_MAPE_METRIC_ID: sklearn.metrics.mean_absolute_percentage_error,
84-
_R2_METRIC_ID: sklearn.metrics.r2_score,
85-
_RMSE_METRIC_ID: functools.partial(
86-
sklearn.metrics.mean_squared_error, squared=False
87-
),
88-
_RMSLE_METRIC_ID: functools.partial(
89-
sklearn.metrics.mean_squared_log_error, squared=False
90-
),
91-
_MSE_METRIC_ID: sklearn.metrics.mean_squared_error,
92-
}
93-
_SUPPORTED_METRIC_IDS = frozenset(_SUPPORTED_METRIC_FUNCTIONS.keys()).union(
94-
frozenset([_CUSTOM_METRIC_ID])
95-
)
96-
_SUPPORTED_CLASSIFICATION_METRIC_IDS = frozenset(
97-
[
98-
_ROC_AUC_METRIC_ID,
99-
_F1_METRIC_ID,
100-
_PRECISION_METRIC_ID,
101-
_RECALL_METRIC_ID,
102-
_ACCURACY_METRIC_ID,
103-
]
104-
)
72+
_SUPPORTED_METRIC_IDS = frozenset(
73+
[
74+
_CUSTOM_METRIC_ID,
75+
_ROC_AUC_METRIC_ID,
76+
_F1_METRIC_ID,
77+
_PRECISION_METRIC_ID,
78+
_RECALL_METRIC_ID,
79+
_ACCURACY_METRIC_ID,
80+
_MAE_METRIC_ID,
81+
_MAPE_METRIC_ID,
82+
_R2_METRIC_ID,
83+
_RMSE_METRIC_ID,
84+
_RMSLE_METRIC_ID,
85+
_MSE_METRIC_ID,
86+
]
87+
)
88+
_SUPPORTED_CLASSIFICATION_METRIC_IDS = frozenset(
89+
[
90+
_ROC_AUC_METRIC_ID,
91+
_F1_METRIC_ID,
92+
_PRECISION_METRIC_ID,
93+
_RECALL_METRIC_ID,
94+
_ACCURACY_METRIC_ID,
95+
]
96+
)
10597

106-
except ImportError:
107-
pass
10898

10999
# Vizier client constnats
110100
_STUDY_NAME_PREFIX = "vizier_hyperparameter_tuner_study"
@@ -366,6 +356,13 @@ def _create_train_and_test_splits(
366356
"test_fraction must be greater than 0 and less than 1 but was "
367357
f"{test_fraction}."
368358
)
359+
try:
360+
from sklearn.model_selection import train_test_split
361+
except ImportError:
362+
raise ImportError(
363+
"scikit-learn must be installed to create train and test splits. "
364+
"Please call `pip install scikit-learn>=0.24`"
365+
) from None
369366

370367
if isinstance(y, str):
371368
try:
@@ -414,6 +411,32 @@ def score(x_test, y_test):
414411
Returns:
415412
A tuple containing the model and the corresponding metric value.
416413
"""
414+
try: # Only used by local tuning loop
415+
import sklearn.metrics
416+
417+
_SUPPORTED_METRIC_FUNCTIONS = {
418+
_ROC_AUC_METRIC_ID: sklearn.metrics.roc_auc_score,
419+
_F1_METRIC_ID: sklearn.metrics.f1_score,
420+
_PRECISION_METRIC_ID: sklearn.metrics.precision_score,
421+
_RECALL_METRIC_ID: sklearn.metrics.recall_score,
422+
_ACCURACY_METRIC_ID: sklearn.metrics.accuracy_score,
423+
_MAE_METRIC_ID: sklearn.metrics.mean_absolute_error,
424+
_MAPE_METRIC_ID: sklearn.metrics.mean_absolute_percentage_error,
425+
_R2_METRIC_ID: sklearn.metrics.r2_score,
426+
_RMSE_METRIC_ID: functools.partial(
427+
sklearn.metrics.mean_squared_error, squared=False
428+
),
429+
_RMSLE_METRIC_ID: functools.partial(
430+
sklearn.metrics.mean_squared_log_error, squared=False
431+
),
432+
_MSE_METRIC_ID: sklearn.metrics.mean_squared_error,
433+
}
434+
except Exception as e:
435+
raise ImportError(
436+
"scikit-learn must be installed to evaluate models. "
437+
"Please call `pip install scikit-learn>=0.24`"
438+
) from e
439+
417440
if self.metric_id == _CUSTOM_METRIC_ID:
418441
metric_value = model.score(x_test, y_test)
419442
else:

0 commit comments

Comments
 (0)