Skip to content

Commit 676e863

Browse files
Akshay0724lesteve
authored andcommitted
[MRG+1] Fix ug in BaseSearchCV.inverse_transform (scikit-learn#8348)
1 parent 8694278 commit 676e863

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

doc/whats_new.rst

+5
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ Enhancements
153153
Bug fixes
154154
.........
155155

156+
- Fixed a bug where :func:`sklearn.model_selection.BaseSearchCV.inverse_transform`
157+
returns self.best_estimator_.transform() instead of self.best_estimator_.inverse_transform()
158+
:issue:`8344` by :user:`Akshay Gupta <Akshay0724>`
159+
160+
156161
- Fixed a bug where :class:`sklearn.linear_model.RandomizedLasso` and
157162
:class:`sklearn.linear_model.RandomizedLogisticRegression` breaks for
158163
sparse input.

sklearn/model_selection/_search.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def inverse_transform(self, Xt):
530530
531531
"""
532532
self._check_is_fitted('inverse_transform')
533-
return self.best_estimator_.transform(Xt)
533+
return self.best_estimator_.inverse_transform(Xt)
534534

535535
@property
536536
def classes_(self):

sklearn/model_selection/tests/test_search.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,15 @@ def fit(self, X, Y):
7979
def predict(self, T):
8080
return T.shape[0]
8181

82+
def transform(self, X):
83+
return X + self.foo_param
84+
85+
def inverse_transform(self, X):
86+
return X - self.foo_param
87+
8288
predict_proba = predict
8389
predict_log_proba = predict
8490
decision_function = predict
85-
transform = predict
86-
inverse_transform = predict
8791

8892
def score(self, X=None, Y=None):
8993
if self.foo_param > 1:
@@ -1305,3 +1309,12 @@ def _pop_time_keys(cv_results):
13051309
per_param_scores[1])
13061310
assert_array_almost_equal(per_param_scores[2],
13071311
per_param_scores[3])
1312+
1313+
1314+
def test_transform_inverse_transform_round_trip():
1315+
clf = MockClassifier()
1316+
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, verbose=3)
1317+
1318+
grid_search.fit(X, y)
1319+
X_round_trip = grid_search.inverse_transform(grid_search.transform(X))
1320+
assert_array_equal(X, X_round_trip)

0 commit comments

Comments
 (0)