Skip to content

Commit 32aa82d

Browse files
authored
MNT Clean-up deprecations for 1.7: old tags (scikit-learn#31134)
1 parent 9f3ca07 commit 32aa82d

File tree

5 files changed

+59
-941
lines changed

5 files changed

+59
-941
lines changed

sklearn/base.py

-59
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,11 @@
3030
)
3131
from .utils.fixes import _IS_32BIT
3232
from .utils.validation import (
33-
_check_feature_names,
3433
_check_feature_names_in,
35-
_check_n_features,
3634
_generate_get_feature_names_out,
3735
_is_fitted,
3836
check_array,
3937
check_is_fitted,
40-
validate_data,
4138
)
4239

4340

@@ -389,33 +386,6 @@ def __setstate__(self, state):
389386
except AttributeError:
390387
self.__dict__.update(state)
391388

392-
# TODO(1.7): Remove this method
393-
def _more_tags(self):
394-
"""This code should never be reached since our `get_tags` will fallback on
395-
`__sklearn_tags__` implemented below. We keep it for backward compatibility.
396-
It is tested in `test_base_estimator_more_tags` in
397-
`sklearn/utils/testing/test_tags.py`."""
398-
from sklearn.utils._tags import _to_old_tags, default_tags
399-
400-
warnings.warn(
401-
"The `_more_tags` method is deprecated in 1.6 and will be removed in "
402-
"1.7. Please implement the `__sklearn_tags__` method.",
403-
category=DeprecationWarning,
404-
)
405-
return _to_old_tags(default_tags(self))
406-
407-
# TODO(1.7): Remove this method
408-
def _get_tags(self):
409-
from sklearn.utils._tags import _to_old_tags, get_tags
410-
411-
warnings.warn(
412-
"The `_get_tags` method is deprecated in 1.6 and will be removed in "
413-
"1.7. Please implement the `__sklearn_tags__` method.",
414-
category=DeprecationWarning,
415-
)
416-
417-
return _to_old_tags(get_tags(self))
418-
419389
def __sklearn_tags__(self):
420390
return Tags(
421391
estimator_type=None,
@@ -469,35 +439,6 @@ def _repr_mimebundle_(self, **kwargs):
469439
output["text/html"] = estimator_html_repr(self)
470440
return output
471441

472-
# TODO(1.7): Remove this method
473-
def _validate_data(self, *args, **kwargs):
474-
warnings.warn(
475-
"`BaseEstimator._validate_data` is deprecated in 1.6 and will be removed "
476-
"in 1.7. Use `sklearn.utils.validation.validate_data` instead. This "
477-
"function becomes public and is part of the scikit-learn developer API.",
478-
FutureWarning,
479-
)
480-
return validate_data(self, *args, **kwargs)
481-
482-
# TODO(1.7): Remove this method
483-
def _check_n_features(self, *args, **kwargs):
484-
warnings.warn(
485-
"`BaseEstimator._check_n_features` is deprecated in 1.6 and will be "
486-
"removed in 1.7. Use `sklearn.utils.validation._check_n_features` instead.",
487-
FutureWarning,
488-
)
489-
_check_n_features(self, *args, **kwargs)
490-
491-
# TODO(1.7): Remove this method
492-
def _check_feature_names(self, *args, **kwargs):
493-
warnings.warn(
494-
"`BaseEstimator._check_feature_names` is deprecated in 1.6 and will be "
495-
"removed in 1.7. Use `sklearn.utils.validation._check_feature_names` "
496-
"instead.",
497-
FutureWarning,
498-
)
499-
_check_feature_names(self, *args, **kwargs)
500-
501442

502443
class ClassifierMixin:
503444
"""Mixin class for all classifiers in scikit-learn.

sklearn/pipeline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ def __sklearn_tags__(self):
12211221
tags.input_tags.sparse = all(
12221222
get_tags(step).input_tags.sparse
12231223
for name, step in self.steps
1224-
if step != "passthrough"
1224+
if step is not None and step != "passthrough"
12251225
)
12261226
except (ValueError, AttributeError, TypeError):
12271227
# This happens when the `steps` is not a list of (name, estimator)

sklearn/tests/test_common.py

-35
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import sklearn
2020
from sklearn.base import BaseEstimator
2121
from sklearn.compose import ColumnTransformer
22-
from sklearn.datasets import make_classification
2322
from sklearn.exceptions import ConvergenceWarning
2423

2524
# make it possible to discover experimental estimators when calling `all_estimators`
@@ -401,37 +400,3 @@ def test_check_inplace_ensure_writeable(estimator):
401400
estimator.set_params(kernel="precomputed")
402401

403402
check_inplace_ensure_writeable(name, estimator)
404-
405-
406-
# TODO(1.7): Remove this test when the deprecation cycle is over
407-
def test_transition_public_api_deprecations():
408-
"""This test checks that we raised deprecation warning explaining how to transition
409-
to the new developer public API from 1.5 to 1.6.
410-
"""
411-
412-
class OldEstimator(BaseEstimator):
413-
def fit(self, X, y=None):
414-
X = self._validate_data(X)
415-
self._check_n_features(X, reset=True)
416-
self._check_feature_names(X, reset=True)
417-
return self
418-
419-
def transform(self, X):
420-
return X # pragma: no cover
421-
422-
X, y = make_classification(n_samples=10, n_features=5, random_state=0)
423-
424-
old_estimator = OldEstimator()
425-
with pytest.warns(FutureWarning) as warning_list:
426-
old_estimator.fit(X)
427-
428-
assert len(warning_list) == 3
429-
assert str(warning_list[0].message).startswith(
430-
"`BaseEstimator._validate_data` is deprecated"
431-
)
432-
assert str(warning_list[1].message).startswith(
433-
"`BaseEstimator._check_n_features` is deprecated"
434-
)
435-
assert str(warning_list[2].message).startswith(
436-
"`BaseEstimator._check_feature_names` is deprecated"
437-
)

0 commit comments

Comments
 (0)