Skip to content

Commit 9667ff2

Browse files
amuellerraghavrv
authored andcommitted
[MRG + 2] Fixed parameter setting in SelectFromModel (scikit-learn#7764)
* Fixed cloning ``estimator`` again when calling fit a second time in SelectFromModel * fix link in whatsnew
1 parent f32d257 commit 9667ff2

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ Bug fixes
141141
functions were not accepting multi-label targets. :issue:`7676`
142142
by `Mohammed Affan`_
143143

144+
- Fixed setting parameters when calling ``fit`` multiple times on
145+
:class:`feature_selection.SelectFromModel`. :issue:`7756` by `Andreas Müller`_
146+
144147
- Fixes issue in ``partial_fit`` method of
145148
:class:`multiclass.OneVsRestClassifier` when number of classes used in
146149
``partial_fit`` was less than the total number of classes in the

sklearn/feature_selection/from_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,7 @@ def fit(self, X, y=None, **fit_params):
232232
if self.prefit:
233233
raise NotFittedError(
234234
"Since 'prefit=True', call transform directly")
235-
if not hasattr(self, "estimator_"):
236-
self.estimator_ = clone(self.estimator)
235+
self.estimator_ = clone(self.estimator)
237236
self.estimator_.fit(X, y, **fit_params)
238237
return self
239238

sklearn/feature_selection/tests/test_from_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import scipy.sparse as sp
33

44
from sklearn.utils.testing import assert_true
5+
from sklearn.utils.testing import assert_equal
56
from sklearn.utils.testing import assert_less
67
from sklearn.utils.testing import assert_greater
78
from sklearn.utils.testing import assert_array_almost_equal
@@ -144,14 +145,13 @@ def test_partial_fit():
144145
assert_array_equal(X_transform, transformer.transform(data))
145146

146147

147-
def test_warm_start():
148-
est = PassiveAggressiveClassifier(warm_start=True, random_state=0)
148+
def test_calling_fit_reinitializes():
149+
est = LinearSVC(random_state=0)
149150
transformer = SelectFromModel(estimator=est)
150151
transformer.fit(data, y)
151-
old_model = transformer.estimator_
152+
transformer.set_params(estimator__C=100)
152153
transformer.fit(data, y)
153-
new_model = transformer.estimator_
154-
assert_true(old_model is new_model)
154+
assert_equal(transformer.estimator_.C, 100)
155155

156156

157157
def test_prefit():

0 commit comments

Comments
 (0)