|
69 | 69 | _RMSLE_METRIC_ID = "rmsle"
|
70 | 70 | _MSE_METRIC_ID = "mse"
|
71 | 71 |
|
72 |
| -try: # Only used by local tuning loop |
73 |
| - import sklearn.metrics |
74 |
| - from sklearn.model_selection import train_test_split |
75 |
| - |
76 |
| - _SUPPORTED_METRIC_FUNCTIONS = { |
77 |
| - _ROC_AUC_METRIC_ID: sklearn.metrics.roc_auc_score, |
78 |
| - _F1_METRIC_ID: sklearn.metrics.f1_score, |
79 |
| - _PRECISION_METRIC_ID: sklearn.metrics.precision_score, |
80 |
| - _RECALL_METRIC_ID: sklearn.metrics.recall_score, |
81 |
| - _ACCURACY_METRIC_ID: sklearn.metrics.accuracy_score, |
82 |
| - _MAE_METRIC_ID: sklearn.metrics.mean_absolute_error, |
83 |
| - _MAPE_METRIC_ID: sklearn.metrics.mean_absolute_percentage_error, |
84 |
| - _R2_METRIC_ID: sklearn.metrics.r2_score, |
85 |
| - _RMSE_METRIC_ID: functools.partial( |
86 |
| - sklearn.metrics.mean_squared_error, squared=False |
87 |
| - ), |
88 |
| - _RMSLE_METRIC_ID: functools.partial( |
89 |
| - sklearn.metrics.mean_squared_log_error, squared=False |
90 |
| - ), |
91 |
| - _MSE_METRIC_ID: sklearn.metrics.mean_squared_error, |
92 |
| - } |
93 |
| - _SUPPORTED_METRIC_IDS = frozenset(_SUPPORTED_METRIC_FUNCTIONS.keys()).union( |
94 |
| - frozenset([_CUSTOM_METRIC_ID]) |
95 |
| - ) |
96 |
| - _SUPPORTED_CLASSIFICATION_METRIC_IDS = frozenset( |
97 |
| - [ |
98 |
| - _ROC_AUC_METRIC_ID, |
99 |
| - _F1_METRIC_ID, |
100 |
| - _PRECISION_METRIC_ID, |
101 |
| - _RECALL_METRIC_ID, |
102 |
| - _ACCURACY_METRIC_ID, |
103 |
| - ] |
104 |
| - ) |
| 72 | +_SUPPORTED_METRIC_IDS = frozenset( |
| 73 | + [ |
| 74 | + _CUSTOM_METRIC_ID, |
| 75 | + _ROC_AUC_METRIC_ID, |
| 76 | + _F1_METRIC_ID, |
| 77 | + _PRECISION_METRIC_ID, |
| 78 | + _RECALL_METRIC_ID, |
| 79 | + _ACCURACY_METRIC_ID, |
| 80 | + _MAE_METRIC_ID, |
| 81 | + _MAPE_METRIC_ID, |
| 82 | + _R2_METRIC_ID, |
| 83 | + _RMSE_METRIC_ID, |
| 84 | + _RMSLE_METRIC_ID, |
| 85 | + _MSE_METRIC_ID, |
| 86 | + ] |
| 87 | +) |
| 88 | +_SUPPORTED_CLASSIFICATION_METRIC_IDS = frozenset( |
| 89 | + [ |
| 90 | + _ROC_AUC_METRIC_ID, |
| 91 | + _F1_METRIC_ID, |
| 92 | + _PRECISION_METRIC_ID, |
| 93 | + _RECALL_METRIC_ID, |
| 94 | + _ACCURACY_METRIC_ID, |
| 95 | + ] |
| 96 | +) |
105 | 97 |
|
106 |
| -except ImportError: |
107 |
| - pass |
108 | 98 |
|
109 | 99 | # Vizier client constnats
|
110 | 100 | _STUDY_NAME_PREFIX = "vizier_hyperparameter_tuner_study"
|
@@ -366,6 +356,13 @@ def _create_train_and_test_splits(
|
366 | 356 | "test_fraction must be greater than 0 and less than 1 but was "
|
367 | 357 | f"{test_fraction}."
|
368 | 358 | )
|
| 359 | + try: |
| 360 | + from sklearn.model_selection import train_test_split |
| 361 | + except ImportError: |
| 362 | + raise ImportError( |
| 363 | + "scikit-learn must be installed to create train and test splits. " |
| 364 | + "Please call `pip install scikit-learn>=0.24`" |
| 365 | + ) from None |
369 | 366 |
|
370 | 367 | if isinstance(y, str):
|
371 | 368 | try:
|
@@ -414,6 +411,32 @@ def score(x_test, y_test):
|
414 | 411 | Returns:
|
415 | 412 | A tuple containing the model and the corresponding metric value.
|
416 | 413 | """
|
| 414 | + try: # Only used by local tuning loop |
| 415 | + import sklearn.metrics |
| 416 | + |
| 417 | + _SUPPORTED_METRIC_FUNCTIONS = { |
| 418 | + _ROC_AUC_METRIC_ID: sklearn.metrics.roc_auc_score, |
| 419 | + _F1_METRIC_ID: sklearn.metrics.f1_score, |
| 420 | + _PRECISION_METRIC_ID: sklearn.metrics.precision_score, |
| 421 | + _RECALL_METRIC_ID: sklearn.metrics.recall_score, |
| 422 | + _ACCURACY_METRIC_ID: sklearn.metrics.accuracy_score, |
| 423 | + _MAE_METRIC_ID: sklearn.metrics.mean_absolute_error, |
| 424 | + _MAPE_METRIC_ID: sklearn.metrics.mean_absolute_percentage_error, |
| 425 | + _R2_METRIC_ID: sklearn.metrics.r2_score, |
| 426 | + _RMSE_METRIC_ID: functools.partial( |
| 427 | + sklearn.metrics.mean_squared_error, squared=False |
| 428 | + ), |
| 429 | + _RMSLE_METRIC_ID: functools.partial( |
| 430 | + sklearn.metrics.mean_squared_log_error, squared=False |
| 431 | + ), |
| 432 | + _MSE_METRIC_ID: sklearn.metrics.mean_squared_error, |
| 433 | + } |
| 434 | + except Exception as e: |
| 435 | + raise ImportError( |
| 436 | + "scikit-learn must be installed to evaluate models. " |
| 437 | + "Please call `pip install scikit-learn>=0.24`" |
| 438 | + ) from e |
| 439 | + |
417 | 440 | if self.metric_id == _CUSTOM_METRIC_ID:
|
418 | 441 | metric_value = model.score(x_test, y_test)
|
419 | 442 | else:
|
|
0 commit comments