@@ -66,6 +66,10 @@ def _calculate_threshold(estimator, importances, threshold):
6666 elif threshold == "mean" :
6767 threshold = np .mean (importances )
6868
69+ else :
70+ raise ValueError ("Expected threshold='mean' or threshold='median' "
71+ "got %s" % threshold )
72+
6973 else :
7074 threshold = float (threshold )
7175
@@ -144,10 +148,8 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
144148 ----------
145149 estimator : object
146150 The base estimator from which the transformer is built.
147- This can be both a fitted or a non-fitted estimator.
148- If it a fitted estimator, then ``transform`` can be called directly,
149- otherwise train the model using ``fit`` and then ``transform`` to do
150- feature selection.
151+ This can be both a fitted (if ``prefit`` is set to True)
152+ or a non-fitted estimator.
151153
152154 threshold : string, float, optional
153155 The threshold value to use for feature selection. Features whose
@@ -158,26 +160,39 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
158160 available, the object attribute ``threshold`` is used. Otherwise,
159161 "mean" is used by default.
160162
163+ prefit : bool, default True
164+ Whether a prefit model is expected to be passed into the constructor
165+ directly or not. If True, ``transform`` must be called directly
166+ and SelectFromModel cannot be used with ``cross_val_score``,
167+ ``GridSearchCV`` and similar utilities that clone the estimator.
168+ Otherwise train the model using ``fit`` and then ``transform`` to do
169+ feature selection.
170+
161171 Attributes
162172 ----------
163173 `estimator_`: an estimator
164174 The base estimator from which the transformer is built.
165175 This is stored only when a non-fitted estimator is passed to the
166- ``SelectFromModel``.
176+ ``SelectFromModel``, i.e when prefit is False .
167177
168178 `threshold_`: float
169179 The threshold value used for feature selection.
170180 """
171- def __init__ (self , estimator , threshold = None ):
181+ def __init__ (self , estimator , threshold = None , prefit = False ):
172182 self .estimator = estimator
173183 self .threshold = threshold
184+ self .prefit = prefit
174185
175186 def _get_support_mask (self ):
176187 # SelectFromModel can directly call on transform.
177- if hasattr (self , "estimator_" ):
188+ if self .prefit :
189+ estimator = self .estimator
190+ elif hasattr (self , 'estimator_' ):
178191 estimator = self .estimator_
179192 else :
180- estimator = self .estimator
193+ raise ValueError (
194+ 'Either fit the model before transform or set "prefit=True"'
195+ ' while passing the fitted estimator to the constructor.' )
181196 scores = _get_feature_importances (estimator )
182197 self .threshold_ = _calculate_threshold (estimator , scores ,
183198 self .threshold )
@@ -202,6 +217,10 @@ def fit(self, X, y=None, **fit_params):
202217 self : object
203218 Returns self.
204219 """
220+ if self .prefit :
221+ raise ValueError (
222+ 'Fitting will overwrite your already fitted model. Call '
223+ 'transform directly.' )
205224 if not hasattr (self , "estimator_" ):
206225 self .estimator_ = clone (self .estimator )
207226 self .estimator_ .fit (X , y , ** fit_params )
@@ -226,6 +245,10 @@ def partial_fit(self, X, y=None, **fit_params):
226245 self : object
227246 Returns self.
228247 """
248+ if self .prefit :
249+ raise ValueError (
250+ 'Fitting will overwrite your already fitted model. Call '
251+ 'transform directly.' )
229252 if not hasattr (self , "estimator_" ):
230253 self .estimator_ = clone (self .estimator )
231254 self .estimator_ .partial_fit (X , y , ** fit_params )
0 commit comments