Skip to content

Commit

Permalink
chore: fix that sklearn will be imported when vizer hptuning is not used
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 566442578
  • Loading branch information
jaycee-li authored and copybara-github committed Sep 18, 2023
1 parent 1d15f82 commit a0fe340
Showing 1 changed file with 58 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,42 +69,32 @@
_RMSLE_METRIC_ID = "rmsle"
_MSE_METRIC_ID = "mse"

try: # Only used by local tuning loop
import sklearn.metrics
from sklearn.model_selection import train_test_split

_SUPPORTED_METRIC_FUNCTIONS = {
_ROC_AUC_METRIC_ID: sklearn.metrics.roc_auc_score,
_F1_METRIC_ID: sklearn.metrics.f1_score,
_PRECISION_METRIC_ID: sklearn.metrics.precision_score,
_RECALL_METRIC_ID: sklearn.metrics.recall_score,
_ACCURACY_METRIC_ID: sklearn.metrics.accuracy_score,
_MAE_METRIC_ID: sklearn.metrics.mean_absolute_error,
_MAPE_METRIC_ID: sklearn.metrics.mean_absolute_percentage_error,
_R2_METRIC_ID: sklearn.metrics.r2_score,
_RMSE_METRIC_ID: functools.partial(
sklearn.metrics.mean_squared_error, squared=False
),
_RMSLE_METRIC_ID: functools.partial(
sklearn.metrics.mean_squared_log_error, squared=False
),
_MSE_METRIC_ID: sklearn.metrics.mean_squared_error,
}
_SUPPORTED_METRIC_IDS = frozenset(_SUPPORTED_METRIC_FUNCTIONS.keys()).union(
frozenset([_CUSTOM_METRIC_ID])
)
_SUPPORTED_CLASSIFICATION_METRIC_IDS = frozenset(
[
_ROC_AUC_METRIC_ID,
_F1_METRIC_ID,
_PRECISION_METRIC_ID,
_RECALL_METRIC_ID,
_ACCURACY_METRIC_ID,
]
)
_SUPPORTED_METRIC_IDS = frozenset(
[
_CUSTOM_METRIC_ID,
_ROC_AUC_METRIC_ID,
_F1_METRIC_ID,
_PRECISION_METRIC_ID,
_RECALL_METRIC_ID,
_ACCURACY_METRIC_ID,
_MAE_METRIC_ID,
_MAPE_METRIC_ID,
_R2_METRIC_ID,
_RMSE_METRIC_ID,
_RMSLE_METRIC_ID,
_MSE_METRIC_ID,
]
)
_SUPPORTED_CLASSIFICATION_METRIC_IDS = frozenset(
[
_ROC_AUC_METRIC_ID,
_F1_METRIC_ID,
_PRECISION_METRIC_ID,
_RECALL_METRIC_ID,
_ACCURACY_METRIC_ID,
]
)

except ImportError:
pass

# Vizier client constnats
_STUDY_NAME_PREFIX = "vizier_hyperparameter_tuner_study"
Expand Down Expand Up @@ -366,6 +356,13 @@ def _create_train_and_test_splits(
"test_fraction must be greater than 0 and less than 1 but was "
f"{test_fraction}."
)
try:
from sklearn.model_selection import train_test_split
except ImportError:
raise ImportError(
"scikit-learn must be installed to create train and test splits. "
"Please call `pip install scikit-learn>=0.24`"
) from None

if isinstance(y, str):
try:
Expand Down Expand Up @@ -414,6 +411,32 @@ def score(x_test, y_test):
Returns:
A tuple containing the model and the corresponding metric value.
"""
try: # Only used by local tuning loop
import sklearn.metrics

_SUPPORTED_METRIC_FUNCTIONS = {
_ROC_AUC_METRIC_ID: sklearn.metrics.roc_auc_score,
_F1_METRIC_ID: sklearn.metrics.f1_score,
_PRECISION_METRIC_ID: sklearn.metrics.precision_score,
_RECALL_METRIC_ID: sklearn.metrics.recall_score,
_ACCURACY_METRIC_ID: sklearn.metrics.accuracy_score,
_MAE_METRIC_ID: sklearn.metrics.mean_absolute_error,
_MAPE_METRIC_ID: sklearn.metrics.mean_absolute_percentage_error,
_R2_METRIC_ID: sklearn.metrics.r2_score,
_RMSE_METRIC_ID: functools.partial(
sklearn.metrics.mean_squared_error, squared=False
),
_RMSLE_METRIC_ID: functools.partial(
sklearn.metrics.mean_squared_log_error, squared=False
),
_MSE_METRIC_ID: sklearn.metrics.mean_squared_error,
}
except Exception as e:
raise ImportError(
"scikit-learn must be installed to evaluate models. "
"Please call `pip install scikit-learn>=0.24`"
) from e

if self.metric_id == _CUSTOM_METRIC_ID:
metric_value = model.score(x_test, y_test)
else:
Expand Down

0 comments on commit a0fe340

Please sign in to comment.