1
1
# coding: utf-8
2
2
"""Compatibility library."""
3
3
4
- from typing import Any , List
4
+ from typing import TYPE_CHECKING , Any , List
5
5
6
6
# scikit-learn is intentionally imported first here,
7
7
# see https://github.com/microsoft/LightGBM/issues/6509
8
8
"""sklearn"""
9
9
try :
10
+ from sklearn import __version__ as _sklearn_version
10
11
from sklearn .base import BaseEstimator , ClassifierMixin , RegressorMixin
11
12
from sklearn .preprocessing import LabelEncoder
12
13
from sklearn .utils .class_weight import compute_sample_weight
@@ -29,6 +30,74 @@ def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
29
30
check_consistent_length (sample_weight , X )
30
31
return sample_weight
31
32
33
+ try :
34
+ from sklearn .utils .validation import validate_data
35
+ except ImportError :
36
+ # validate_data() was added in scikit-learn 1.6, this function roughly imitates it for older versions.
37
+ # It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
38
+ def validate_data (
39
+ _estimator ,
40
+ X ,
41
+ y = "no_validation" ,
42
+ accept_sparse : bool = True ,
43
+ # 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
44
+ ensure_all_finite : bool = False ,
45
+ ensure_min_samples : int = 1 ,
46
+ # trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
47
+ ** ignored_kwargs ,
48
+ ):
49
+ # it's safe to import _num_features unconditionally because:
50
+ #
51
+ # * it was first added in scikit-learn 0.24.2
52
+ # * lightgbm cannot be used with scikit-learn versions older than that
53
+ # * this validate_data() re-implementation will not be called in scikit-learn>=1.6
54
+ #
55
+ from sklearn .utils .validation import _num_features
56
+
57
+ # _num_features() raises a TypeError on 1-dimensional input. That's a problem
58
+ # because scikit-learn's 'check_fit1d' estimator check sets that expectation that
59
+ # estimators must raise a ValueError when a 1-dimensional input is passed to fit().
60
+ #
61
+ # So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
62
+ if hasattr (X , "shape" ) and len (X .shape ) == 1 :
63
+ n_features_in_ = 1
64
+ else :
65
+ n_features_in_ = _num_features (X )
66
+
67
+ no_val_y = isinstance (y , str ) and y == "no_validation"
68
+
69
+ # NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
70
+ if no_val_y :
71
+ X = check_array (
72
+ X ,
73
+ accept_sparse = accept_sparse ,
74
+ force_all_finite = ensure_all_finite ,
75
+ ensure_min_samples = ensure_min_samples ,
76
+ )
77
+ else :
78
+ X , y = check_X_y (
79
+ X ,
80
+ y ,
81
+ accept_sparse = accept_sparse ,
82
+ force_all_finite = ensure_all_finite ,
83
+ ensure_min_samples = ensure_min_samples ,
84
+ )
85
+
86
+ # this only needs to be updated at fit() time
87
+ _estimator .n_features_in_ = n_features_in_
88
+
89
+ # raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
90
+ if _estimator .__sklearn_is_fitted__ () and _estimator ._n_features != n_features_in_ :
91
+ raise ValueError (
92
+ f"X has { n_features_in_ } features, but { _estimator .__class__ .__name__ } "
93
+ f"is expecting { _estimator ._n_features } features as input."
94
+ )
95
+
96
+ if no_val_y :
97
+ return X
98
+ else :
99
+ return X , y
100
+
32
101
SKLEARN_INSTALLED = True
33
102
_LGBMBaseCrossValidator = BaseCrossValidator
34
103
_LGBMModelBase = BaseEstimator
@@ -38,12 +107,11 @@ def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
38
107
LGBMNotFittedError = NotFittedError
39
108
_LGBMStratifiedKFold = StratifiedKFold
40
109
_LGBMGroupKFold = GroupKFold
41
- _LGBMCheckXY = check_X_y
42
- _LGBMCheckArray = check_array
43
110
_LGBMCheckSampleWeight = _check_sample_weight
44
111
_LGBMAssertAllFinite = assert_all_finite
45
112
_LGBMCheckClassificationTargets = check_classification_targets
46
113
_LGBMComputeSampleWeight = compute_sample_weight
114
+ _LGBMValidateData = validate_data
47
115
except ImportError :
48
116
SKLEARN_INSTALLED = False
49
117
@@ -67,12 +135,22 @@ class _LGBMRegressorBase: # type: ignore
67
135
LGBMNotFittedError = ValueError
68
136
_LGBMStratifiedKFold = None
69
137
_LGBMGroupKFold = None
70
- _LGBMCheckXY = None
71
- _LGBMCheckArray = None
72
138
_LGBMCheckSampleWeight = None
73
139
_LGBMAssertAllFinite = None
74
140
_LGBMCheckClassificationTargets = None
75
141
_LGBMComputeSampleWeight = None
142
+ _LGBMValidateData = None
143
+ _sklearn_version = None
144
+
145
+ # additional scikit-learn imports only for type hints
146
+ if TYPE_CHECKING :
147
+ # sklearn.utils.Tags can be imported unconditionally once
148
+ # lightgbm's minimum scikit-learn version is 1.6 or higher
149
+ try :
150
+ from sklearn .utils import Tags as _sklearn_Tags
151
+ except ImportError :
152
+ _sklearn_Tags = None
153
+
76
154
77
155
"""pandas"""
78
156
try :
0 commit comments