Skip to content

Commit dc2b587

Browse files
maikia and agramfort authored
DEP deprecates 'normalize' in _omp.py (scikit-learn#17750)
Co-authored-by: Alexandre Gramfort <[email protected]>
1 parent c2b5c56 commit dc2b587

File tree

5 files changed

+77
-12
lines changed

5 files changed

+77
-12
lines changed

doc/whats_new/v1.0.rst

+8-2
Original file line numberDiff line numberDiff line change
@@ -358,11 +358,17 @@ Changelog
358358
LinearRegression was deprecated in:
359359
:pr:`17743` by :user:`Maria Telenczuk <maikia>` and
360360
:user:`Alexandre Gramfort <agramfort>`.
361-
Ridge, RidgeClassifier, RidgeCV or RidgeClassifierCV were deprecated in:
361+
The ``normalize`` parameter in Ridge, RidgeClassifier, RidgeCV or
362+
RidgeClassifierCV was deprecated and will be removed in 1.2.
362363
:pr:`17772` by :user:`Maria Telenczuk <maikia>` and
363364
:user:`Alexandre Gramfort <agramfort>`.
364-
BayesianRidge, ARDRegression were deprecated in:
365+
Same for BayesianRidge, ARDRegression in:
365366
:pr:`17746` by :user:`Maria Telenczuk <maikia>`.
367+
The ``normalize`` parameter of :class:`linear_model.OrthogonalMatchingPursuit` and
368+
:class:`linear_model.OrthogonalMatchingPursuitCV` will default to
369+
False in 1.2 and will be removed in 1.4.
370+
:pr:`17750` by :user:`Maria Telenczuk <maikia>` and
371+
:user:`Alexandre Gramfort <agramfort>`.
366372

367373
- |Fix| `sample_weight` are now fully taken into account in linear models
368374
when `normalize=True` for both feature centering and feature

examples/linear_model/plot_omp.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
plt.stem(idx, w[idx], use_line_collection=True)
4242

4343
# plot the noise-free reconstruction
44-
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_nonzero_coefs)
44+
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_nonzero_coefs,
45+
normalize=False)
4546
omp.fit(X, y)
4647
coef = omp.coef_
4748
idx_r, = coef.nonzero()
@@ -60,7 +61,7 @@
6061
plt.stem(idx_r, coef[idx_r], use_line_collection=True)
6162

6263
# plot the noisy reconstruction with number of non-zeros set by CV
63-
omp_cv = OrthogonalMatchingPursuitCV()
64+
omp_cv = OrthogonalMatchingPursuitCV(normalize=False)
6465
omp_cv.fit(X, y_noisy)
6566
coef = omp_cv.coef_
6667
idx_r, = coef.nonzero()

sklearn/linear_model/_omp.py

+29-8
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from scipy.linalg.lapack import get_lapack_funcs
1414
from joblib import Parallel
1515

16-
from ._base import LinearModel, _pre_fit
16+
from ._base import LinearModel, _pre_fit, _deprecate_normalize
1717
from ..base import RegressorMixin, MultiOutputMixin
1818
from ..utils import as_float_array, check_array
1919
from ..utils.fixes import delayed
@@ -616,6 +616,10 @@ class OrthogonalMatchingPursuit(MultiOutputMixin, RegressorMixin, LinearModel):
616616
:class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
617617
on an estimator with ``normalize=False``.
618618
619+
.. deprecated:: 1.0
620+
``normalize`` was deprecated in version 1.0. It will default
621+
to False in 1.2 and be removed in 1.4.
622+
619623
precompute : 'auto' or bool, default='auto'
620624
Whether to use a precomputed Gram and Xy matrix to speed up
621625
calculations. Improves performance when :term:`n_targets` or
@@ -648,7 +652,7 @@ class OrthogonalMatchingPursuit(MultiOutputMixin, RegressorMixin, LinearModel):
648652
>>> from sklearn.linear_model import OrthogonalMatchingPursuit
649653
>>> from sklearn.datasets import make_regression
650654
>>> X, y = make_regression(noise=4, random_state=0)
651-
>>> reg = OrthogonalMatchingPursuit().fit(X, y)
655+
>>> reg = OrthogonalMatchingPursuit(normalize=False).fit(X, y)
652656
>>> reg.score(X, y)
653657
0.9991...
654658
>>> reg.predict(X[:1,])
@@ -683,7 +687,7 @@ def __init__(
683687
n_nonzero_coefs=None,
684688
tol=None,
685689
fit_intercept=True,
686-
normalize=True,
690+
normalize="deprecated",
687691
precompute="auto",
688692
):
689693
self.n_nonzero_coefs = n_nonzero_coefs
@@ -709,11 +713,15 @@ def fit(self, X, y):
709713
self : object
710714
returns an instance of self.
711715
"""
716+
_normalize = _deprecate_normalize(
717+
self.normalize, default=True, estimator_name=self.__class__.__name__
718+
)
719+
712720
X, y = self._validate_data(X, y, multi_output=True, y_numeric=True)
713721
n_features = X.shape[1]
714722

715723
X, y, X_offset, y_offset, X_scale, Gram, Xy = _pre_fit(
716-
X, y, None, self.precompute, self.normalize, self.fit_intercept, copy=True
724+
X, y, None, self.precompute, _normalize, self.fit_intercept, copy=True
717725
)
718726

719727
if y.ndim == 1:
@@ -797,6 +805,10 @@ def _omp_path_residues(
797805
:class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
798806
on an estimator with ``normalize=False``.
799807
808+
.. deprecated:: 1.0
809+
``normalize`` was deprecated in version 1.0. It will default
810+
to False in 1.2 and be removed in 1.4.
811+
800812
max_iter : int, default=100
801813
Maximum numbers of iterations to perform, therefore maximum features
802814
to include. 100 by default.
@@ -872,6 +884,10 @@ class OrthogonalMatchingPursuitCV(RegressorMixin, LinearModel):
872884
:class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
873885
on an estimator with ``normalize=False``.
874886
887+
.. deprecated:: 1.0
888+
``normalize`` was deprecated in version 1.0. It will default
889+
to False in 1.2 and be removed in 1.4.
890+
875891
max_iter : int, default=None
876892
Maximum number of iterations to perform, therefore maximum features
877893
to include. 10% of ``n_features`` but at least 5 if available.
@@ -929,7 +945,7 @@ class OrthogonalMatchingPursuitCV(RegressorMixin, LinearModel):
929945
>>> from sklearn.datasets import make_regression
930946
>>> X, y = make_regression(n_features=100, n_informative=10,
931947
... noise=4, random_state=0)
932-
>>> reg = OrthogonalMatchingPursuitCV(cv=5).fit(X, y)
948+
>>> reg = OrthogonalMatchingPursuitCV(cv=5, normalize=False).fit(X, y)
933949
>>> reg.score(X, y)
934950
0.9991...
935951
>>> reg.n_nonzero_coefs_
@@ -956,7 +972,7 @@ def __init__(
956972
*,
957973
copy=True,
958974
fit_intercept=True,
959-
normalize=True,
975+
normalize="deprecated",
960976
max_iter=None,
961977
cv=None,
962978
n_jobs=None,
@@ -986,6 +1002,11 @@ def fit(self, X, y):
9861002
self : object
9871003
returns an instance of self.
9881004
"""
1005+
1006+
_normalize = _deprecate_normalize(
1007+
self.normalize, default=True, estimator_name=self.__class__.__name__
1008+
)
1009+
9891010
X, y = self._validate_data(
9901011
X, y, y_numeric=True, ensure_min_features=2, estimator=self
9911012
)
@@ -1004,7 +1025,7 @@ def fit(self, X, y):
10041025
y[test],
10051026
self.copy,
10061027
self.fit_intercept,
1007-
self.normalize,
1028+
_normalize,
10081029
max_iter,
10091030
)
10101031
for train, test in cv.split(X)
@@ -1019,7 +1040,7 @@ def fit(self, X, y):
10191040
omp = OrthogonalMatchingPursuit(
10201041
n_nonzero_coefs=best_n_nonzero_coefs,
10211042
fit_intercept=self.fit_intercept,
1022-
normalize=self.normalize,
1043+
normalize=_normalize,
10231044
)
10241045
omp.fit(X, y)
10251046
self.coef_ = omp.coef_

sklearn/linear_model/tests/test_omp.py

+30
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,30 @@
3535
# and y (n_samples, 3)
3636

3737

38+
# FIXME: 'normalize' to be set to False in 1.2 and removed in 1.4
39+
@pytest.mark.parametrize(
40+
"OmpModel", [OrthogonalMatchingPursuit, OrthogonalMatchingPursuitCV]
41+
)
42+
@pytest.mark.parametrize(
43+
"normalize, n_warnings", [(True, 0), (False, 0), ("deprecated", 1)]
44+
)
45+
def test_assure_warning_when_normalize(OmpModel, normalize, n_warnings):
46+
# check that a FutureWarning is issued only when normalize is left unset ("deprecated"), not when it is explicitly True or False
47+
rng = check_random_state(0)
48+
n_samples = 200
49+
n_features = 2
50+
X = rng.randn(n_samples, n_features)
51+
X[X < 0.1] = 0.0
52+
y = rng.rand(n_samples)
53+
54+
model = OmpModel(normalize=normalize)
55+
with pytest.warns(None) as record:
56+
model.fit(X, y)
57+
58+
record = [r for r in record if r.category == FutureWarning]
59+
assert len(record) == n_warnings
60+
61+
3862
def test_correct_shapes():
3963
assert orthogonal_mp(X, y[:, 0], n_nonzero_coefs=5).shape == (n_features,)
4064
assert orthogonal_mp(X, y, n_nonzero_coefs=5).shape == (n_features, 3)
@@ -125,6 +149,8 @@ def test_orthogonal_mp_gram_readonly():
125149
assert_array_almost_equal(gamma[:, 0], gamma_gram, decimal=2)
126150

127151

152+
# FIXME: 'normalize' to be removed in 1.4
153+
@pytest.mark.filterwarnings("ignore:The default of 'normalize'")
128154
def test_estimator():
129155
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_nonzero_coefs)
130156
omp.fit(X, y[:, 0])
@@ -211,6 +237,8 @@ def test_omp_return_path_prop_with_gram():
211237
assert_array_almost_equal(path[:, :, -1], last)
212238

213239

240+
# FIXME: 'normalize' to be removed in 1.4
241+
@pytest.mark.filterwarnings("ignore:The default of 'normalize'")
214242
def test_omp_cv():
215243
y_ = y[:, 0]
216244
gamma_ = gamma[:, 0]
@@ -227,6 +255,8 @@ def test_omp_cv():
227255
assert_array_almost_equal(ompcv.coef_, omp.coef_)
228256

229257

258+
# FIXME: 'normalize' to be removed in 1.4
259+
@pytest.mark.filterwarnings("ignore:The default of 'normalize'")
230260
def test_omp_reaches_least_squares():
231261
# Use small simple data; it's a sanity check but OMP can stop early
232262
rng = check_random_state(0)

sklearn/tests/test_docstring_parameters.py

+7
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,13 @@ def test_fit_docstring_attributes(name, Estimator):
241241
# default="auto" raises an error with the shape of `X`
242242
est.set_params(n_components=2)
243243

244+
# FIXME: TO BE REMOVED in 1.4 (avoid FutureWarning)
245+
if Estimator.__name__ in (
246+
"OrthogonalMatchingPursuit",
247+
"OrthogonalMatchingPursuitCV",
248+
):
249+
est.set_params(normalize=False)
250+
244251
# FIXME: TO BE REMOVED for 1.1 (avoid FutureWarning)
245252
if Estimator.__name__ == "NMF":
246253
est.set_params(init="nndsvda")

0 commit comments

Comments
 (0)