Skip to content

Commit 5a0db17

Browse files
committed
1. Added parameter prefit to pass in a fitted estimator.
2. Use assert_warns instead of catch_warnings 3. Remove depracation warnings in common tests.
1 parent acf5f16 commit 5a0db17

File tree

9 files changed

+80
-51
lines changed

9 files changed

+80
-51
lines changed

doc/modules/feature_selection.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ for classification::
173173
>>> X.shape
174174
(150, 4)
175175
>>> lsvc = LinearSVC(C=0.01, penalty="l1", dual=False).fit(X, y)
176-
>>> model = SelectFromModel(lsvc)
176+
>>> model = SelectFromModel(lsvc, prefit=True)
177177
>>> X_new = model.transform(X)
178178
>>> X_new.shape
179179
(150, 3)
@@ -277,7 +277,7 @@ meta-transformer)::
277277
>>> clf = clf.fit(X, y)
278278
>>> clf.feature_importances_ # doctest: +SKIP
279279
array([ 0.04..., 0.05..., 0.4..., 0.4...])
280-
>>> model = SelectFromModel(clf)
280+
>>> model = SelectFromModel(clf, prefit=True)
281281
>>> X_new = model.transform(X)
282282
>>> X_new.shape # doctest: +SKIP
283283
(150, 2)

examples/ensemble/plot_feature_transformation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@
5454
rt = RandomTreesEmbedding(max_depth=3, n_estimators=n_estimator)
5555
rt_lm = LogisticRegression()
5656
rt.fit(X_train, y_train)
57-
rt_lm.fit(SelectFromModel(rt).transform(X_train_lr), y_train_lr)
57+
rt_lm.fit(SelectFromModel(rt, prefit=True).transform(X_train_lr), y_train_lr)
5858

59-
y_pred_rt = rt_lm.predict_proba(SelectFromModel(rt).transform(X_test))[:, 1]
59+
y_pred_rt = rt_lm.predict_proba(
60+
SelectFromModel(rt, prefit=True).transform(X_test))[:, 1]
6061
fpr_rt_lm, tpr_rt_lm, _ = roc_curve(y_test, y_pred_rt)
6162

6263
# Supervised transformation based on random forests

examples/ensemble/plot_random_forest_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
# use RandomTreesEmbedding to transform data
4040
hasher = RandomTreesEmbedding(n_estimators=10, random_state=0, max_depth=3)
4141
hasher.fit(X)
42-
model = SelectFromModel(hasher)
42+
model = SelectFromModel(hasher, prefit=True)
4343
X_transformed = model.transform(X)
4444

4545
# Visualize result using PCA

sklearn/ensemble/tests/test_forest.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from sklearn.utils.testing import assert_greater_equal
3030
from sklearn.utils.testing import assert_raises
3131
from sklearn.utils.testing import assert_warns
32-
from sklearn.utils.testing import clean_warning_registry
3332
from sklearn.utils.testing import ignore_warnings
3433

3534
from sklearn import datasets
@@ -204,10 +203,11 @@ def check_importances(X, y, name, criterion):
204203
assert_equal(importances.shape[0], 10)
205204
assert_equal(n_important, 3)
206205

207-
clean_warning_registry()
208-
with warnings.catch_warnings(record=True) as record:
209-
X_new = est.transform(X, threshold="mean")
210-
assert_less(0 < X_new.shape[1], X.shape[1])
206+
# XXX: Remove this test in 0.19 after transform support to estimators
207+
# is removed.
208+
X_new = assert_warns(
209+
DeprecationWarning, est.transform, X, threshold="mean")
210+
assert_less(0 < X_new.shape[1], X.shape[1])
211211

212212
# Check with parallel
213213
importances = est.feature_importances_

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sklearn.ensemble import GradientBoostingRegressor
1717
from sklearn.ensemble.gradient_boosting import ZeroEstimator
1818
from sklearn.metrics import mean_squared_error
19-
from sklearn.utils import check_random_state, tosequence, warnings
19+
from sklearn.utils import check_random_state, tosequence
2020
from sklearn.utils.testing import assert_almost_equal
2121
from sklearn.utils.testing import assert_array_almost_equal
2222
from sklearn.utils.testing import assert_array_equal
@@ -26,7 +26,6 @@
2626
from sklearn.utils.testing import assert_raises
2727
from sklearn.utils.testing import assert_true
2828
from sklearn.utils.testing import assert_warns
29-
from sklearn.utils.testing import clean_warning_registry
3029
from sklearn.utils.testing import ignore_warnings
3130
from sklearn.utils.validation import DataConversionWarning
3231
from sklearn.utils.validation import NotFittedError
@@ -297,16 +296,15 @@ def test_feature_importances():
297296
presort=presort)
298297
clf.fit(X, y)
299298
assert_true(hasattr(clf, 'feature_importances_'))
300-
clean_warning_registry()
301-
with warnings.catch_warnings(record=True) as record:
302-
X_new = clf.transform(X, threshold="mean")
303-
assert_less(X_new.shape[1], X.shape[1])
304299

305-
X_new = clf.transform(X, threshold="mean")
306-
assert_less(X_new.shape[1], X.shape[1])
307-
308-
feature_mask = clf.feature_importances_ > clf.feature_importances_.mean()
309-
assert_array_almost_equal(X_new, X[:, feature_mask])
300+
# XXX: Remove this test in 0.19 after transform support to estimators
301+
# is removed.
302+
X_new = assert_warns(
303+
DeprecationWarning, clf.transform, X, threshold="mean")
304+
assert_less(X_new.shape[1], X.shape[1])
305+
feature_mask = (
306+
clf.feature_importances_ > clf.feature_importances_.mean())
307+
assert_array_almost_equal(X_new, X[:, feature_mask])
310308

311309

312310
def test_probability_log():

sklearn/feature_selection/from_model.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ def _calculate_threshold(estimator, importances, threshold):
6666
elif threshold == "mean":
6767
threshold = np.mean(importances)
6868

69+
else:
70+
raise ValueError("Expected threshold='mean' or threshold='median' "
71+
"got %s" % threshold)
72+
6973
else:
7074
threshold = float(threshold)
7175

@@ -144,10 +148,8 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
144148
----------
145149
estimator : object
146150
The base estimator from which the transformer is built.
147-
This can be both a fitted or a non-fitted estimator.
148-
If it a fitted estimator, then ``transform`` can be called directly,
149-
otherwise train the model using ``fit`` and then ``transform`` to do
150-
feature selection.
151+
This can be both a fitted (if ``prefit`` is set to True)
152+
or a non-fitted estimator.
151153
152154
threshold : string, float, optional
153155
The threshold value to use for feature selection. Features whose
@@ -158,26 +160,39 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
158160
available, the object attribute ``threshold`` is used. Otherwise,
159161
"mean" is used by default.
160162
163+
prefit : bool, default True
164+
Whether a prefit model is expected to be passed into the constructor
165+
directly or not. If True, ``transform`` must be called directly
166+
and SelectFromModel cannot be used with ``cross_val_score``,
167+
``GridSearchCV`` and similar utilities that clone the estimator.
168+
Otherwise train the model using ``fit`` and then ``transform`` to do
169+
feature selection.
170+
161171
Attributes
162172
----------
163173
`estimator_`: an estimator
164174
The base estimator from which the transformer is built.
165175
This is stored only when a non-fitted estimator is passed to the
166-
``SelectFromModel``.
176+
``SelectFromModel``, i.e when prefit is False.
167177
168178
`threshold_`: float
169179
The threshold value used for feature selection.
170180
"""
171-
def __init__(self, estimator, threshold=None):
181+
def __init__(self, estimator, threshold=None, prefit=False):
172182
self.estimator = estimator
173183
self.threshold = threshold
184+
self.prefit = prefit
174185

175186
def _get_support_mask(self):
176187
# SelectFromModel can directly call on transform.
177-
if hasattr(self, "estimator_"):
188+
if self.prefit:
189+
estimator = self.estimator
190+
elif hasattr(self, 'estimator_'):
178191
estimator = self.estimator_
179192
else:
180-
estimator = self.estimator
193+
raise ValueError(
194+
'Either fit the model before transform or set "prefit=True"'
195+
' while passing the fitted estimator to the constructor.')
181196
scores = _get_feature_importances(estimator)
182197
self.threshold_ = _calculate_threshold(estimator, scores,
183198
self.threshold)
@@ -202,6 +217,10 @@ def fit(self, X, y=None, **fit_params):
202217
self : object
203218
Returns self.
204219
"""
220+
if self.prefit:
221+
raise ValueError(
222+
'Fitting will overwrite your already fitted model. Call '
223+
'transform directly.')
205224
if not hasattr(self, "estimator_"):
206225
self.estimator_ = clone(self.estimator)
207226
self.estimator_.fit(X, y, **fit_params)
@@ -226,6 +245,10 @@ def partial_fit(self, X, y=None, **fit_params):
226245
self : object
227246
Returns self.
228247
"""
248+
if self.prefit:
249+
raise ValueError(
250+
'Fitting will overwrite your already fitted model. Call '
251+
'transform directly.')
229252
if not hasattr(self, "estimator_"):
230253
self.estimator_ = clone(self.estimator)
231254
self.estimator_.partial_fit(X, y, **fit_params)

sklearn/feature_selection/tests/test_from_model.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33

44
from nose.tools import assert_raises, assert_true
55

6-
from sklearn.utils import warnings
76
from sklearn.utils.testing import assert_less
87
from sklearn.utils.testing import assert_greater
98
from sklearn.utils.testing import assert_equal
109
from sklearn.utils.testing import assert_array_almost_equal
1110
from sklearn.utils.testing import assert_array_equal
1211
from sklearn.utils.testing import assert_almost_equal
13-
from sklearn.utils.testing import clean_warning_registry
12+
from sklearn.utils.testing import assert_warns
1413

1514
from sklearn import datasets
1615
from sklearn.linear_model import LogisticRegression
@@ -33,9 +32,8 @@ def test_transform_linear_model():
3332
X = func(iris.data)
3433
clf.set_params(penalty="l1")
3534
clf.fit(X, iris.target)
36-
clean_warning_registry()
37-
with warnings.catch_warnings(record=True) as record:
38-
X_new = clf.transform(X, thresh)
35+
X_new = assert_warns(
36+
DeprecationWarning, clf.transform, X, thresh)
3937
if isinstance(clf, SGDClassifier):
4038
assert_true(X_new.shape[1] <= X.shape[1])
4139
else:
@@ -48,10 +46,10 @@ def test_transform_linear_model():
4846

4947
def test_invalid_input():
5048
clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=None)
51-
52-
clf.fit(iris.data, iris.target)
53-
assert_raises(ValueError, clf.transform, iris.data, "gobbledigook")
54-
assert_raises(ValueError, clf.transform, iris.data, ".5 * gobbledigook")
49+
for threshold in ["gobbledigook", ".5 * gobbledigook"]:
50+
model = SelectFromModel(clf, threshold=threshold)
51+
model.fit(iris.data, iris.target)
52+
assert_raises(ValueError, model.transform, iris.data)
5553

5654

5755
def test_validate_estimator():
@@ -133,7 +131,7 @@ def test_fitted_estimator():
133131
X_transform = model.transform(iris.data)
134132

135133
clf.fit(iris.data, iris.target)
136-
model = SelectFromModel(clf)
134+
model = SelectFromModel(clf, prefit=True)
137135
assert_array_equal(model.transform(iris.data), X_transform)
138136

139137

@@ -146,7 +144,7 @@ def test_threshold_string():
146144
# Calculate the threshold from the estimator directly.
147145
est.fit(iris.data, iris.target)
148146
threshold = 0.5 * np.mean(est.feature_importances_)
149-
model = SelectFromModel(est, threshold=threshold)
147+
model = SelectFromModel(est, threshold=threshold, prefit=True)
150148
assert_array_equal(X_transform, model.transform(iris.data))
151149

152150

sklearn/tree/tests/test_tree.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from sklearn.metrics import accuracy_score
1717
from sklearn.metrics import mean_squared_error
1818

19-
from sklearn.utils import warnings
2019
from sklearn.utils.testing import assert_array_equal
2120
from sklearn.utils.testing import assert_array_almost_equal
2221
from sklearn.utils.testing import assert_almost_equal
@@ -27,7 +26,7 @@
2726
from sklearn.utils.testing import assert_greater_equal
2827
from sklearn.utils.testing import assert_less
2928
from sklearn.utils.testing import assert_true
30-
from sklearn.utils.testing import clean_warning_registry
29+
from sklearn.utils.testing import assert_warns
3130
from sklearn.utils.testing import raises
3231

3332
from sklearn.utils.validation import check_random_state
@@ -380,12 +379,10 @@ def test_importances():
380379
assert_equal(importances.shape[0], 10, "Failed with {0}".format(name))
381380
assert_equal(n_important, 3, "Failed with {0}".format(name))
382381

383-
384-
clean_warning_registry()
385-
with warnings.catch_warnings(record=True) as record:
386-
X_new = clf.transform(X, threshold="mean")
387-
assert_less(0, X_new.shape[1], "Failed with {0}".format(name))
388-
assert_less(X_new.shape[1], X.shape[1], "Failed with {0}".format(name))
382+
X_new = assert_warns(
383+
DeprecationWarning, clf.transform, X, threshold="mean")
384+
assert_less(0, X_new.shape[1], "Failed with {0}".format(name))
385+
assert_less(X_new.shape[1], X.shape[1], "Failed with {0}".format(name))
389386

390387
# Check on iris that importances are the same for all builders
391388
clf = DecisionTreeClassifier(random_state=0)

sklearn/utils/estimator_checks.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@
6161
'RANSACRegressor', 'RadiusNeighborsRegressor',
6262
'RandomForestRegressor', 'Ridge', 'RidgeCV']
6363

64+
# Estimators with deprecated transform methods. Can be removed in 0.19 when
65+
# _LearntSelectorMixin is removed.
66+
DEPRECATED_TRANSFORM = [
67+
"RandomForestClassifier", "RandomForestRegressor", "ExtraTreesClassifier",
68+
"ExtraTreesRegressor", "RandomTreesEmbedding", "DecisionTreeClassifier",
69+
"DecisionTreeRegressor", "ExtraTreeClassifier", "ExtraTreeRegressor",
70+
"LinearSVC", "SGDClassifier", "SGDRegressor", "Perceptron",
71+
"LogisticRegression", "LogisticRegressionCV",
72+
"GradientBoostingClassifier", "GradientBoostingRegressor"]
73+
6474

6575
def _yield_non_meta_checks(name, Estimator):
6676
yield check_estimators_dtypes
@@ -168,8 +178,9 @@ def _yield_all_checks(name, Estimator):
168178
for check in _yield_regressor_checks(name, Estimator):
169179
yield check
170180
if issubclass(Estimator, TransformerMixin):
171-
for check in _yield_transformer_checks(name, Estimator):
172-
yield check
181+
if name not in DEPRECATED_TRANSFORM:
182+
for check in _yield_transformer_checks(name, Estimator):
183+
yield check
173184
if issubclass(Estimator, ClusterMixin):
174185
for check in _yield_clustering_checks(name, Estimator):
175186
yield check
@@ -329,7 +340,8 @@ def check_dtype_object(name, Estimator):
329340
if hasattr(estimator, "predict"):
330341
estimator.predict(X)
331342

332-
if hasattr(estimator, "transform"):
343+
if (hasattr(estimator, "transform") and
344+
name not in DEPRECATED_TRANSFORM):
333345
estimator.transform(X)
334346

335347
try:

0 commit comments

Comments
 (0)