Skip to content

Commit 7eae66a

Browse files
vnherdeiro, jameslamb, and StrikerRUS
authored
[python-package] require scikit-learn>=0.24.2, make scikit-learn estimators compatible with scikit-learn>=1.6.0dev (#6651)
Co-authored-by: James Lamb <[email protected]> Co-authored-by: Nikita Titov <[email protected]>
1 parent 0643230 commit 7eae66a

File tree

7 files changed

+309
-29
lines changed

7 files changed

+309
-29
lines changed

Diff for: .ci/test-python-latest.sh

+1-1
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ python -m pip install \
2222
'numpy>=2.0.0.dev0' \
2323
'matplotlib>=3.10.0.dev0' \
2424
'pandas>=3.0.0.dev0' \
25-
'scikit-learn==1.5.*' \
25+
'scikit-learn>=1.6.dev0' \
2626
'scipy>=1.15.0.dev0'
2727

2828
python -m pip install \

Diff for: .ci/test-python-oldest.sh

+1-1
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ pip install \
1515
'numpy==1.19.0' \
1616
'pandas==1.1.3' \
1717
'pyarrow==6.0.1' \
18-
'scikit-learn==0.24.0' \
18+
'scikit-learn==0.24.2' \
1919
'scipy==1.6.0' \
2020
|| exit 1
2121
echo "done installing lightgbm's dependencies"

Diff for: .ci/test.sh

+1
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,7 @@ if [[ $TASK == "lint" ]]; then
103103
'mypy>=1.11.1' \
104104
'pre-commit>=3.8.0' \
105105
'pyarrow-core>=17.0' \
106+
'scikit-learn>=1.5.2' \
106107
'r-lintr>=3.1.2'
107108
source activate $CONDA_ENV
108109
echo "Linting Python code"

Diff for: python-package/lightgbm/compat.py

+83-5
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
# coding: utf-8
22
"""Compatibility library."""
33

4-
from typing import Any, List
4+
from typing import TYPE_CHECKING, Any, List
55

66
# scikit-learn is intentionally imported first here,
77
# see https://github.com/microsoft/LightGBM/issues/6509
88
"""sklearn"""
99
try:
10+
from sklearn import __version__ as _sklearn_version
1011
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
1112
from sklearn.preprocessing import LabelEncoder
1213
from sklearn.utils.class_weight import compute_sample_weight
@@ -29,6 +30,74 @@ def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
2930
check_consistent_length(sample_weight, X)
3031
return sample_weight
3132

33+
try:
34+
from sklearn.utils.validation import validate_data
35+
except ImportError:
36+
# validate_data() was added in scikit-learn 1.6, this function roughly imitates it for older versions.
37+
# It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
38+
def validate_data(
39+
_estimator,
40+
X,
41+
y="no_validation",
42+
accept_sparse: bool = True,
43+
# 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
44+
ensure_all_finite: bool = False,
45+
ensure_min_samples: int = 1,
46+
# trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
47+
**ignored_kwargs,
48+
):
49+
# it's safe to import _num_features unconditionally because:
50+
#
51+
# * it was first added in scikit-learn 0.24.2
52+
# * lightgbm cannot be used with scikit-learn versions older than that
53+
# * this validate_data() re-implementation will not be called in scikit-learn>=1.6
54+
#
55+
from sklearn.utils.validation import _num_features
56+
57+
# _num_features() raises a TypeError on 1-dimensional input. That's a problem
58+
# because scikit-learn's 'check_fit1d' estimator check sets the expectation that
59+
# estimators must raise a ValueError when a 1-dimensional input is passed to fit().
60+
#
61+
# So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
62+
if hasattr(X, "shape") and len(X.shape) == 1:
63+
n_features_in_ = 1
64+
else:
65+
n_features_in_ = _num_features(X)
66+
67+
no_val_y = isinstance(y, str) and y == "no_validation"
68+
69+
# NOTE: check_X_y() calls check_array() internally, so only one of them needs to be called here
70+
if no_val_y:
71+
X = check_array(
72+
X,
73+
accept_sparse=accept_sparse,
74+
force_all_finite=ensure_all_finite,
75+
ensure_min_samples=ensure_min_samples,
76+
)
77+
else:
78+
X, y = check_X_y(
79+
X,
80+
y,
81+
accept_sparse=accept_sparse,
82+
force_all_finite=ensure_all_finite,
83+
ensure_min_samples=ensure_min_samples,
84+
)
85+
86+
# this only needs to be updated at fit() time
87+
_estimator.n_features_in_ = n_features_in_
88+
89+
# raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
90+
if _estimator.__sklearn_is_fitted__() and _estimator._n_features != n_features_in_:
91+
raise ValueError(
92+
f"X has {n_features_in_} features, but {_estimator.__class__.__name__} "
93+
f"is expecting {_estimator._n_features} features as input."
94+
)
95+
96+
if no_val_y:
97+
return X
98+
else:
99+
return X, y
100+
32101
SKLEARN_INSTALLED = True
33102
_LGBMBaseCrossValidator = BaseCrossValidator
34103
_LGBMModelBase = BaseEstimator
@@ -38,12 +107,11 @@ def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
38107
LGBMNotFittedError = NotFittedError
39108
_LGBMStratifiedKFold = StratifiedKFold
40109
_LGBMGroupKFold = GroupKFold
41-
_LGBMCheckXY = check_X_y
42-
_LGBMCheckArray = check_array
43110
_LGBMCheckSampleWeight = _check_sample_weight
44111
_LGBMAssertAllFinite = assert_all_finite
45112
_LGBMCheckClassificationTargets = check_classification_targets
46113
_LGBMComputeSampleWeight = compute_sample_weight
114+
_LGBMValidateData = validate_data
47115
except ImportError:
48116
SKLEARN_INSTALLED = False
49117

@@ -67,12 +135,22 @@ class _LGBMRegressorBase: # type: ignore
67135
LGBMNotFittedError = ValueError
68136
_LGBMStratifiedKFold = None
69137
_LGBMGroupKFold = None
70-
_LGBMCheckXY = None
71-
_LGBMCheckArray = None
72138
_LGBMCheckSampleWeight = None
73139
_LGBMAssertAllFinite = None
74140
_LGBMCheckClassificationTargets = None
75141
_LGBMComputeSampleWeight = None
142+
_LGBMValidateData = None
143+
_sklearn_version = None
144+
145+
# additional scikit-learn imports only for type hints
146+
if TYPE_CHECKING:
147+
# sklearn.utils.Tags can be imported unconditionally once
148+
# lightgbm's minimum scikit-learn version is 1.6 or higher
149+
try:
150+
from sklearn.utils import Tags as _sklearn_Tags
151+
except ImportError:
152+
_sklearn_Tags = None
153+
76154

77155
"""pandas"""
78156
try:

0 commit comments

Comments
 (0)