Skip to content

Commit 3f0b6c0

Browse files
authored
ENH Improve error message for sparse multilabel-indicator y in RandomForestClassifier (scikit-learn#15971)
1 parent 98b3c7c commit 3f0b6c0

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

sklearn/ensemble/_forest.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,16 @@ def fit(self, X, y, sample_weight=None):
292292
-------
293293
self : object
294294
"""
295-
# Validate or convert input data
295+
# Validate and convert input data
296+
if issparse(y):
297+
raise ValueError(
298+
"sparse multilabel-indicator for y is not supported."
299+
)
296300
X = check_array(X, accept_sparse="csc", dtype=DTYPE)
297-
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
301+
y = check_array(y, ensure_2d=False, dtype=None)
298302
if sample_weight is not None:
299303
sample_weight = _check_sample_weight(sample_weight, X)
304+
300305
if issparse(X):
301306
# Pre-sort indices to avoid that each individual tree of the
302307
# ensemble sorts the indices.

sklearn/ensemble/tests/test_forest.py

+9
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,15 @@ def test_max_samples_exceptions(name, max_samples, exc_type, exc_msg):
13361336
est.fit(X, y)
13371337

13381338

1339+
def test_forest_y_sparse():
1340+
X = [[1, 2, 3]]
1341+
y = csr_matrix([4, 5, 6])
1342+
est = RandomForestClassifier()
1343+
msg = "sparse multilabel-indicator for y is not supported."
1344+
with pytest.raises(ValueError, match=msg):
1345+
est.fit(X, y)
1346+
1347+
13391348
@pytest.mark.parametrize(
13401349
'ForestClass', [RandomForestClassifier, RandomForestRegressor]
13411350
)

0 commit comments

Comments
 (0)