
Commit 9a6e90a

MarcBresson, jeremiedbb, and OmarManzoor authored
ENH: improve validation for SGD models to accept l1_ratio=None when penalty is not elasticnet (scikit-learn#30730)
Co-authored-by: Jérémie du Boisberranger <[email protected]>
Co-authored-by: Omar Salman <[email protected]>
Parent: 4af26a7 · Commit: 9a6e90a

File tree: 3 files changed (+47, −6 lines)
Changelog fragment (new file, +3):

@@ -0,0 +1,3 @@
+- :class:`linear_model.SGDClassifier` and :class:`linear_model.SGDRegressor` now accept
+  `l1_ratio=None` when `penalty` is not `"elasticnet"`.
+  By :user:`Marc Bresson <MarcBresson>`.
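In practice the change looks like this; a minimal sketch of the documented behavior (the toy data is illustrative, not part of the commit):

import numpy as np
from sklearn.linear_model import SGDClassifier

X = np.array([[0.0, 0.0], [1.0, 1.0], [0.0, 1.0], [1.0, 0.0]])
y = np.array([0, 1, 0, 1])

# Now accepted: l1_ratio is irrelevant when the penalty is pure L2.
SGDClassifier(penalty="l2", l1_ratio=None).fit(X, y)

# Still rejected at fit time: elasticnet needs a concrete mixing ratio.
try:
    SGDClassifier(penalty="elasticnet", l1_ratio=None).fit(X, y)
except ValueError as exc:
    print(exc)  # l1_ratio must be set when penalty is 'elasticnet'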

sklearn/linear_model/_stochastic_gradient.py (+23, −6):
@@ -154,11 +154,20 @@ def _more_validate_params(self, for_partial_fit=False):
                 "learning_rate is 'optimal'. alpha is used "
                 "to compute the optimal learning rate."
             )
+        if self.penalty == "elasticnet" and self.l1_ratio is None:
+            raise ValueError("l1_ratio must be set when penalty is 'elasticnet'")
 
         # raises ValueError if not registered
         self._get_penalty_type(self.penalty)
         self._get_learning_rate_type(self.learning_rate)
 
+    def _get_l1_ratio(self):
+        if self.l1_ratio is None:
+            # plain_sgd expects a float. Any value is fine since at this point
+            # penalty can't be "elasticnet" so l1_ratio is not used.
+            return 0.0
+        return self.l1_ratio
+
     def _get_loss_function(self, loss):
         """Get concrete ``LossFunction`` object for str ``loss``."""
         loss_ = self.loss_functions[loss]
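The new helper is a sentinel pattern: the low-level solver takes a plain float for l1_ratio, so None has to be mapped to some numeric value before the call, and 0.0 is safe because _more_validate_params has already rejected the one combination (penalty="elasticnet" with l1_ratio=None) in which the value would be read. A standalone sketch of the same logic (hypothetical names, not the sklearn implementation):

def resolve_l1_ratio(penalty, l1_ratio):
    if penalty == "elasticnet" and l1_ratio is None:
        raise ValueError("l1_ratio must be set when penalty is 'elasticnet'")
    if l1_ratio is None:
        return 0.0  # dead value: only reached when l1_ratio is never used
    return l1_ratio

assert resolve_l1_ratio("l2", None) == 0.0
assert resolve_l1_ratio("elasticnet", 0.15) == 0.15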
@@ -462,7 +471,7 @@ def fit_binary(
         penalty_type,
         alpha,
         C,
-        est.l1_ratio,
+        est._get_l1_ratio(),
         dataset,
         validation_mask,
         est.early_stopping,
@@ -993,7 +1002,11 @@ class SGDClassifier(BaseSGDClassifier):
         The Elastic Net mixing parameter, with 0 <= l1_ratio <= 1.
         l1_ratio=0 corresponds to L2 penalty, l1_ratio=1 to L1.
         Only used if `penalty` is 'elasticnet'.
-        Values must be in the range `[0.0, 1.0]`.
+        Values must be in the range `[0.0, 1.0]` or can be `None` if
+        `penalty` is not `elasticnet`.
+
+        .. versionchanged:: 1.7
+            `l1_ratio` can be `None` when `penalty` is not "elasticnet".
 
     fit_intercept : bool, default=True
         Whether the intercept should be estimated or not. If False, the
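As the docstring says, l1_ratio interpolates between the two regularizers: 0 gives pure L2, 1 gives pure L1, and values in between mix them. A quick, illustrative comparison (synthetic data, results indicative only):

import numpy as np
from sklearn.linear_model import SGDClassifier

rng = np.random.RandomState(0)
X = rng.rand(50, 5)
y = (X[:, 0] > 0.5).astype(int)

for ratio in (0.0, 0.15, 1.0):  # pure L2, the default mix, pure L1
    clf = SGDClassifier(penalty="elasticnet", l1_ratio=ratio, random_state=0)
    clf.fit(X, y)
    # A larger L1 share tends to drive more coefficients to exactly zero.
    print(ratio, np.count_nonzero(clf.coef_))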
@@ -1194,7 +1207,7 @@ class SGDClassifier(BaseSGDClassifier):
         **BaseSGDClassifier._parameter_constraints,
         "penalty": [StrOptions({"l2", "l1", "elasticnet"}), None],
         "alpha": [Interval(Real, 0, None, closed="left")],
-        "l1_ratio": [Interval(Real, 0, 1, closed="both")],
+        "l1_ratio": [Interval(Real, 0, 1, closed="both"), None],
         "power_t": [Interval(Real, None, None, closed="neither")],
         "epsilon": [Interval(Real, 0, None, closed="left")],
         "learning_rate": [
@@ -1695,7 +1708,7 @@ def _fit_regressor(
             penalty_type,
             alpha,
             C,
-            self.l1_ratio,
+            self._get_l1_ratio(),
             dataset,
             validation_mask,
             self.early_stopping,
@@ -1796,7 +1809,11 @@ class SGDRegressor(BaseSGDRegressor):
         The Elastic Net mixing parameter, with 0 <= l1_ratio <= 1.
         l1_ratio=0 corresponds to L2 penalty, l1_ratio=1 to L1.
         Only used if `penalty` is 'elasticnet'.
-        Values must be in the range `[0.0, 1.0]`.
+        Values must be in the range `[0.0, 1.0]` or can be `None` if
+        `penalty` is not `elasticnet`.
+
+        .. versionchanged:: 1.7
+            `l1_ratio` can be `None` when `penalty` is not "elasticnet".
 
     fit_intercept : bool, default=True
         Whether the intercept should be estimated or not. If False, the
@@ -1976,7 +1993,7 @@ class SGDRegressor(BaseSGDRegressor):
         **BaseSGDRegressor._parameter_constraints,
         "penalty": [StrOptions({"l2", "l1", "elasticnet"}), None],
         "alpha": [Interval(Real, 0, None, closed="left")],
-        "l1_ratio": [Interval(Real, 0, 1, closed="both")],
+        "l1_ratio": [Interval(Real, 0, 1, closed="both"), None],
         "power_t": [Interval(Real, None, None, closed="neither")],
         "learning_rate": [
             StrOptions({"constant", "optimal", "invscaling", "adaptive"}),

sklearn/linear_model/tests/test_sgd.py (+21):
@@ -486,6 +486,27 @@ def test_not_enough_sample_for_early_stopping(klass):
     clf.fit(X3, Y3)
 
 
[email protected]("Estimator", [SGDClassifier, SGDRegressor])
[email protected]("l1_ratio", [0, 0.7, 1])
+def test_sgd_l1_ratio_not_used(Estimator, l1_ratio):
+    """Check that l1_ratio is not used when penalty is not 'elasticnet'"""
+    clf1 = Estimator(penalty="l1", l1_ratio=None, random_state=0).fit(X, Y)
+    clf2 = Estimator(penalty="l1", l1_ratio=l1_ratio, random_state=0).fit(X, Y)
+
+    assert_allclose(clf1.coef_, clf2.coef_)
+
+
[email protected](
+    "Estimator", [SGDClassifier, SparseSGDClassifier, SGDRegressor, SparseSGDRegressor]
+)
+def test_sgd_failing_penalty_validation(Estimator):
+    clf = Estimator(penalty="elasticnet", l1_ratio=None)
+    with pytest.raises(
+        ValueError, match="l1_ratio must be set when penalty is 'elasticnet'"
+    ):
+        clf.fit(X, Y)
+
+
 ###############################################################################
 # Classification Test Case
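The first test leans on SGD being deterministic for a fixed random_state: with penalty="l1" the solver never reads l1_ratio, so the fitted coef_ must match (up to assert_allclose tolerance) whether a number or None is passed. The same check sketched outside the test suite (synthetic data; assumes a build that includes this change):

import numpy as np
from numpy.testing import assert_allclose
from sklearn.linear_model import SGDRegressor

rng = np.random.RandomState(0)
X = rng.rand(20, 3)
y = X @ np.array([1.0, -2.0, 0.5])

a = SGDRegressor(penalty="l1", l1_ratio=None, random_state=0).fit(X, y)
b = SGDRegressor(penalty="l1", l1_ratio=0.7, random_state=0).fit(X, y)
assert_allclose(a.coef_, b.coef_)  # identical fits: l1_ratio never used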
