2121from sklearn .utils .testing import assert_array_almost_equal
2222from sklearn .utils .testing import assert_array_equal
2323from sklearn .utils .testing import assert_warns_message
24+ from sklearn .utils .testing import assert_raise_message
2425from sklearn .utils .testing import ignore_warnings
2526from sklearn .utils .validation import _num_samples
2627from sklearn .utils .mocking import MockDataFrame
@@ -206,7 +207,7 @@ def test_kfold_valueerrors():
206207
207208 # Check that a warning is raised if the least populated class has too few
208209 # members.
209- y = np .array ([3 , 3 , - 1 , - 1 , 2 ])
210+ y = np .array ([3 , 3 , - 1 , - 1 , 3 ])
210211
211212 skf_3 = StratifiedKFold (3 )
212213 assert_warns_message (Warning , "The least populated class" ,
@@ -219,11 +220,21 @@ def test_kfold_valueerrors():
219220 warnings .simplefilter ("ignore" )
220221 check_cv_coverage (skf_3 , X2 , y , labels = None , expected_n_iter = 3 )
221222
223+ # Check that errors are raised if all n_labels for individual
224+ # classes are less than n_folds.
225+ y = np .array ([3 , 3 , - 1 , - 1 , 2 ])
226+
227+ assert_raises (ValueError , next , skf_3 .split (X2 , y ))
228+
222229 # Error when number of folds is <= 1
223230 assert_raises (ValueError , KFold , 0 )
224231 assert_raises (ValueError , KFold , 1 )
225- assert_raises (ValueError , StratifiedKFold , 0 )
226- assert_raises (ValueError , StratifiedKFold , 1 )
232+ error_string = ("k-fold cross-validation requires at least one"
233+ " train/test split" )
234+ assert_raise_message (ValueError , error_string ,
235+ StratifiedKFold , 0 )
236+ assert_raise_message (ValueError , error_string ,
237+ StratifiedKFold , 1 )
227238
228239 # When n_folds is not integer:
229240 assert_raises (ValueError , KFold , 1.5 )
0 commit comments