Skip to content

Commit 114d8b1

Browse files
raghavrv, jnothman
authored and committed
ENH Better error message when refit=True. (scikit-learn#7234)
1 parent 234d256 commit 114d8b1

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

sklearn/model_selection/_search.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ..base import MetaEstimatorMixin
2525
from ._split import check_cv
2626
from ._validation import _fit_and_score
27+
from ..exceptions import NotFittedError
2728
from ..externals.joblib import Parallel, delayed
2829
from ..externals import six
2930
from ..utils import check_random_state
@@ -414,6 +415,15 @@ def score(self, X, y=None):
414415
% self.best_estimator_)
415416
return self.scorer_(self.best_estimator_, X, y)
416417

418+
def _check_is_fitted(self, method_name):
    """Ensure a refitted best estimator exists before delegating to it.

    Parameters
    ----------
    method_name : str
        Name of the delegating method (e.g. ``'predict'``); it is
        interpolated into the error message.

    Raises
    ------
    NotFittedError
        If this instance was configured with ``refit=False``. When
        ``refit`` is truthy, the check is delegated to
        ``check_is_fitted(self, 'best_estimator_')``.
    """
    if not self.refit:
        # Report the concrete class name instead of a hard-coded
        # 'GridSearchCV': this helper only relies on ``self``, so the
        # message stays accurate for whichever class it is called on.
        raise NotFittedError(('This %s instance was initialized '
                              'with refit=False. %s is '
                              'available only after refitting on the best '
                              'parameters. ')
                             % (type(self).__name__, method_name))
    check_is_fitted(self, 'best_estimator_')
426+
417427
@if_delegate_has_method(delegate='estimator')
418428
def predict(self, X):
419429
"""Call predict on the estimator with the best found parameters.
@@ -428,6 +438,7 @@ def predict(self, X):
428438
underlying estimator.
429439
430440
"""
441+
self._check_is_fitted('predict')
431442
return self.best_estimator_.predict(X)
432443

433444
@if_delegate_has_method(delegate='estimator')
@@ -444,6 +455,7 @@ def predict_proba(self, X):
444455
underlying estimator.
445456
446457
"""
458+
self._check_is_fitted('predict_proba')
447459
return self.best_estimator_.predict_proba(X)
448460

449461
@if_delegate_has_method(delegate='estimator')
@@ -460,6 +472,7 @@ def predict_log_proba(self, X):
460472
underlying estimator.
461473
462474
"""
475+
self._check_is_fitted('predict_log_proba')
463476
return self.best_estimator_.predict_log_proba(X)
464477

465478
@if_delegate_has_method(delegate='estimator')
@@ -476,6 +489,7 @@ def decision_function(self, X):
476489
underlying estimator.
477490
478491
"""
492+
self._check_is_fitted('decision_function')
479493
return self.best_estimator_.decision_function(X)
480494

481495
@if_delegate_has_method(delegate='estimator')
@@ -492,11 +506,12 @@ def transform(self, X):
492506
underlying estimator.
493507
494508
"""
509+
self._check_is_fitted('transform')
495510
return self.best_estimator_.transform(X)
496511

497512
@if_delegate_has_method(delegate='estimator')
def inverse_transform(self, Xt):
    """Call inverse_transform on the estimator with the best found params.

    Only available if the underlying estimator implements
    ``inverse_transform`` and ``refit=True``.

    Parameters
    ----------
    Xt : indexable
        Must fulfill the input assumptions of the
        underlying estimator.

    """
    self._check_is_fitted('inverse_transform')
    # Delegate to the estimator's *inverse* transform; the previous body
    # called ``transform`` here, applying the forward mapping instead.
    return self.best_estimator_.inverse_transform(Xt)
512528

513529
def _fit(self, X, y, labels, parameter_iterable):

sklearn/model_selection/tests/test_search.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from sklearn.externals.six.moves import zip
3030
from sklearn.base import BaseEstimator
31+
from sklearn.exceptions import NotFittedError
3132
from sklearn.datasets import make_classification
3233
from sklearn.datasets import make_blobs
3334
from sklearn.datasets import make_multilabel_classification
@@ -73,8 +74,10 @@ def predict(self, T):
7374
return T.shape[0]
7475

7576
predict_proba = predict
77+
predict_log_proba = predict
7678
decision_function = predict
7779
transform = predict
80+
inverse_transform = predict
7881

7982
def score(self, X=None, Y=None):
8083
if self.foo_param > 1:
@@ -268,6 +271,14 @@ def test_no_refit():
268271
hasattr(grid_search, "best_index_") and
269272
hasattr(grid_search, "best_params_"))
270273

274+
# Make sure the predict/transform etc fns raise meaningful error msg
275+
for fn_name in ('predict', 'predict_proba', 'predict_log_proba',
276+
'transform', 'inverse_transform'):
277+
assert_raise_message(NotFittedError,
278+
('refit=False. %s is available only after '
279+
'refitting on the best parameters' % fn_name),
280+
getattr(grid_search, fn_name), X)
281+
271282

272283
def test_grid_search_error():
273284
# Test that grid search will capture errors on data with different length

0 commit comments

Comments
 (0)