 
 from ..utils import Parallel, delayed
 
+from ..base import clone, is_classifier
 from .base import LinearClassifierMixin, SparseCoefMixin
 from .base import make_dataset
 from ..base import BaseEstimator, RegressorMixin
 from ..utils.validation import check_is_fitted
 from ..exceptions import ConvergenceWarning
 from ..externals import six
-from ..model_selection import train_test_split
+from ..model_selection import StratifiedShuffleSplit, ShuffleSplit
 
 from .sgd_fast import plain_sgd, average_sgd
 from ..utils import compute_class_weight
@@ -43,6 +44,26 @@
 # Default value of ``epsilon`` parameter.
 
 
+class _ValidationScoreCallback(object):
+    """Callback for early stopping based on validation score"""
+
+    def __init__(self, estimator, X_val, y_val, sample_weight_val,
+                 classes=None):
+        self.estimator = clone(estimator)
+        self.estimator.t_ = 1  # to pass check_is_fitted
+        if classes is not None:
+            self.estimator.classes_ = classes
+        self.X_val = X_val
+        self.y_val = y_val
+        self.sample_weight_val = sample_weight_val
+
+    def __call__(self, coef, intercept):
+        est = self.estimator
+        est.coef_ = coef.reshape(1, -1)
+        est.intercept_ = np.atleast_1d(intercept)
+        return est.score(self.X_val, self.y_val, self.sample_weight_val)
+
+
 class BaseSGD(six.with_metaclass(ABCMeta, BaseEstimator, SparseCoefMixin)):
     """Base class for SGD classification and regression."""
 
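
For context, the pattern `_ValidationScoreCallback` implements can be sketched
standalone: score a candidate (coef, intercept) on held-out data through a cloned
estimator, so the estimator being trained is never touched (the removed
`_validation_score` further down had to swap and restore `self.coef_` instead).
A minimal sketch, not part of this commit; data and values are illustrative:

    import numpy as np
    from sklearn.base import clone
    from sklearn.linear_model import SGDClassifier

    X_val = np.array([[-1.0], [2.0]])
    y_val = np.array([0, 1])

    scorer = clone(SGDClassifier())         # detached copy; no shared training state
    scorer.t_ = 1                           # to pass check_is_fitted, as in the diff
    scorer.classes_ = np.array([0, 1])
    scorer.coef_ = np.array([[1.0]])        # candidate weights from the SGD loop
    scorer.intercept_ = np.atleast_1d(0.0)
    print(scorer.score(X_val, y_val))       # validation accuracy of this candidate
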
@@ -248,71 +269,52 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
                                        dtype=np.float64,
                                        order="C")
 
-    def _make_validation_split(self, X, y, sample_weight):
+    def _make_validation_split(self, y):
         """Split the dataset between training set and validation set.
 
         Parameters
         ----------
-        X : {array, sparse matrix}, shape (n_samples, n_features)
-            Training data.
-
         y : array, shape (n_samples, )
             Target values.
 
-        sample_weight : array, shape (n_samples, )
-            Weights applied to individual samples.
-
         Returns
         -------
         validation_mask : array, shape (n_samples, )
             Equal to 1 on the validation set, 0 on the training set.
         """
-        n_samples = X.shape[0]
+        n_samples = y.shape[0]
         validation_mask = np.zeros(n_samples, dtype=np.uint8)
         if not self.early_stopping:
             # use the full set for training, with an empty validation set
             return validation_mask
 
-        tmp = train_test_split(X, y, np.arange(n_samples), sample_weight,
-                               test_size=self.validation_fraction,
-                               random_state=self.random_state)
-        X_train, X_val, y_train, y_val = tmp[:4]
-        idx_train, idx_val, sample_weight_train, sample_weight_val = tmp[4:8]
-        if X_train.shape[0] == 0 or X_val.shape[0] == 0:
+        if is_classifier(self):
+            splitter_type = StratifiedShuffleSplit
+        else:
+            splitter_type = ShuffleSplit
+        cv = splitter_type(test_size=self.validation_fraction,
+                           random_state=self.random_state)
+        idx_train, idx_val = next(cv.split(np.zeros(shape=(y.shape[0], 1)), y))
+        if idx_train.shape[0] == 0 or idx_val.shape[0] == 0:
             raise ValueError(
                 "Splitting %d samples into a train set and a validation set "
                 "with validation_fraction=%r led to an empty set (%d and %d "
                 "samples). Please either change validation_fraction, increase "
                 "number of samples, or disable early_stopping."
-                % (n_samples, self.validation_fraction, X_train.shape[0],
-                   X_val.shape[0]))
+                % (n_samples, self.validation_fraction, idx_train.shape[0],
+                   idx_val.shape[0]))
 
-        self._X_val = X_val
-        self._y_val = y_val
-        self._sample_weight_val = sample_weight_val
         validation_mask[idx_val] = 1
         return validation_mask
 
-    def _delete_validation_split(self):
-        if self.early_stopping:
-            del self._X_val
-            del self._y_val
-            del self._sample_weight_val
-
-    def _validation_score(self, coef, intercept):
-        """Compute the score on the validation set. Used for early stopping."""
-        # store attributes
-        old_coefs, old_intercept = self.coef_, self.intercept_
-
-        # replace them with current coefficients for scoring
-        self.coef_ = coef.reshape(1, -1)
-        self.intercept_ = np.atleast_1d(intercept)
-        score = self.score(self._X_val, self._y_val, self._sample_weight_val)
-
-        # restore old attributes
-        self.coef_, self.intercept_ = old_coefs, old_intercept
+    def _make_validation_score_cb(self, validation_mask, X, y, sample_weight,
+                                  classes=None):
+        if not self.early_stopping:
+            return None
 
-        return score
+        return _ValidationScoreCallback(
+            self, X[validation_mask], y[validation_mask],
+            sample_weight[validation_mask], classes=classes)
 
 
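
For reference, the index-only split performed by the new `_make_validation_split`
can be reproduced standalone. A sketch, not part of this commit; the labels are
made up:

    import numpy as np
    from sklearn.model_selection import StratifiedShuffleSplit

    y = np.array([0, 0, 0, 1, 1, 1, 1, 1])
    cv = StratifiedShuffleSplit(test_size=0.25, random_state=0)
    # only indices are needed, so a dummy X of the right length suffices
    idx_train, idx_val = next(cv.split(np.zeros((y.shape[0], 1)), y))

    validation_mask = np.zeros(y.shape[0], dtype=np.uint8)
    validation_mask[idx_val] = 1   # 1 marks validation rows, 0 training rows
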
 def _prepare_fit_binary(est, y, i):
@@ -348,7 +350,7 @@ def _prepare_fit_binary(est, y, i):
 
 
 def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
-               pos_weight, neg_weight, sample_weight):
+               pos_weight, neg_weight, sample_weight, validation_mask=None):
     """Fit a single binary classifier.
 
     The i'th class is considered the "positive" class.
@@ -388,6 +390,10 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
 
     sample_weight : numpy array of shape [n_samples, ]
         The weight of each sample
+
+    validation_mask : numpy array of shape [n_samples, ] or None
+        Precomputed validation mask in case _fit_binary is called in the
+        context of a one-vs-rest reduction.
     """
     # if average is not true, average_coef, and average_intercept will be
     # unused
@@ -399,7 +405,11 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
 
     penalty_type = est._get_penalty_type(est.penalty)
     learning_rate_type = est._get_learning_rate_type(learning_rate)
 
-    validation_mask = est._make_validation_split(X, y, sample_weight)
+    if validation_mask is None:
+        validation_mask = est._make_validation_split(y_i)
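+    # Note: _prepare_fit_binary has already encoded the labels y_i as
+    # {-1, +1}, so the validation callback below must score in that binary
+    # label space rather than in the original multiclass labels.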
+    classes = np.array([-1, 1], dtype=y_i.dtype)
+    validation_score_cb = est._make_validation_score_cb(
+        validation_mask, X, y_i, sample_weight, classes=classes)
 
     # XXX should have random_state_!
     random_state = check_random_state(est.random_state)
@@ -412,8 +422,8 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
     if not est.average:
         result = plain_sgd(coef, intercept, est.loss_function_,
                            penalty_type, alpha, C, est.l1_ratio,
-                           dataset, validation_mask, est.early_stopping, est,
-                           int(est.n_iter_no_change),
+                           dataset, validation_mask, est.early_stopping,
+                           validation_score_cb, int(est.n_iter_no_change),
                            max_iter, tol, int(est.fit_intercept),
                            int(est.verbose), int(est.shuffle), seed,
                            pos_weight, neg_weight,
@@ -426,8 +436,8 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
                                   average_intercept, est.loss_function_,
                                   penalty_type, alpha, C, est.l1_ratio,
                                   dataset, validation_mask, est.early_stopping,
-                                  est, int(est.n_iter_no_change),
-                                  max_iter, tol,
+                                  validation_score_cb,
+                                  int(est.n_iter_no_change), max_iter, tol,
                                   int(est.fit_intercept), int(est.verbose),
                                   int(est.shuffle), seed, pos_weight,
                                   neg_weight, learning_rate_type, est.eta0,
@@ -441,7 +451,6 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
 
         result = standard_coef, standard_intercept, n_iter_
 
-    est._delete_validation_split()
     return result
 
 
@@ -610,14 +619,19 @@ def _fit_multiclass(self, X, y, alpha, C, learning_rate,
         """Fit a multi-class classifier by combining binary classifiers
 
         Each binary classifier predicts one class versus all others. This
-        strategy is called OVA: One Versus All.
+        strategy is called OvA (One versus All) or OvR (One versus Rest).
         """
+        # Precompute the validation split using the multiclass labels
+        # to ensure proper balancing of the classes.
+        validation_mask = self._make_validation_split(y)
+
         # Use joblib to fit OvA in parallel.
         result = Parallel(n_jobs=self.n_jobs, prefer="threads",
                           verbose=self.verbose)(
             delayed(fit_binary)(self, i, X, y, alpha, C, learning_rate,
                                 max_iter, self._expanded_class_weight[i],
-                                1., sample_weight)
+                                1., sample_weight,
+                                validation_mask=validation_mask)
             for i in range(len(self.classes_)))
 
         # take the maximum of n_iter_ over every binary fit
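
To illustrate why the split above is precomputed from the multiclass labels rather
than from each binary {-1, +1} view (a sketch, not part of this commit):

    import numpy as np
    from sklearn.model_selection import StratifiedShuffleSplit

    y = np.array([0, 0, 1, 1, 2, 2, 2, 2])          # multiclass labels
    cv = StratifiedShuffleSplit(test_size=0.5, random_state=0)
    _, idx_val = next(cv.split(np.zeros((len(y), 1)), y))
    print(np.unique(y[idx_val]))                     # every class represented: [0 1 2]
    # stratifying on a per-classifier {-1, +1} encoding would only balance one
    # class against the rest, and each binary fit would also redo the split
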
@@ -1115,18 +1129,16 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
         sample_weight = self._validate_sample_weight(sample_weight, n_samples)
 
         if getattr(self, "coef_", None) is None:
-            self._allocate_parameter_mem(1, n_features,
-                                         coef_init, intercept_init)
+            self._allocate_parameter_mem(1, n_features, coef_init,
+                                         intercept_init)
         elif n_features != self.coef_.shape[-1]:
             raise ValueError("Number of features %d does not match previous "
                              "data %d." % (n_features, self.coef_.shape[-1]))
         if self.average > 0 and getattr(self, "average_coef_", None) is None:
             self.average_coef_ = np.zeros(n_features,
                                           dtype=np.float64,
                                           order="C")
-            self.average_intercept_ = np.zeros(1,
-                                               dtype=np.float64,
-                                               order="C")
+            self.average_intercept_ = np.zeros(1, dtype=np.float64, order="C")
 
         self._fit_regressor(X, y, alpha, C, loss, learning_rate,
                             sample_weight, max_iter)
@@ -1269,7 +1281,9 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
         if not hasattr(self, "t_"):
             self.t_ = 1.0
 
-        validation_mask = self._make_validation_split(X, y, sample_weight)
+        validation_mask = self._make_validation_split(y)
+        validation_score_cb = self._make_validation_score_cb(
+            validation_mask, X, y, sample_weight)
 
         random_state = check_random_state(self.random_state)
         # numpy mtrand expects a C long which is a signed 32 bit integer under
@@ -1290,7 +1304,8 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
                             alpha, C,
                             self.l1_ratio,
                             dataset,
-                            validation_mask, self.early_stopping, self,
+                            validation_mask, self.early_stopping,
+                            validation_score_cb,
                             int(self.n_iter_no_change),
                             max_iter, tol,
                             int(self.fit_intercept),
@@ -1322,7 +1337,8 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
                           alpha, C,
                           self.l1_ratio,
                           dataset,
-                          validation_mask, self.early_stopping, self,
+                          validation_mask, self.early_stopping,
+                          validation_score_cb,
                           int(self.n_iter_no_change),
                           max_iter, tol,
                           int(self.fit_intercept),
@@ -1337,8 +1353,6 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
 
         self.t_ += self.n_iter_ * X.shape[0]
         self.intercept_ = np.atleast_1d(self.intercept_)
-        self._delete_validation_split()
-
 
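
Taken together, these changes pass the validation callback, rather than the
estimator itself, into the Cython loops. A quick end-to-end check of the
early-stopping behaviour they support (a sketch, not part of this commit; data is
synthetic):

    import numpy as np
    from sklearn.linear_model import SGDClassifier

    rng = np.random.RandomState(0)
    X = rng.randn(200, 5)
    y = (X[:, 0] > 0).astype(int)

    clf = SGDClassifier(early_stopping=True, validation_fraction=0.2,
                        n_iter_no_change=3, tol=1e-3, max_iter=1000,
                        random_state=0)
    clf.fit(X, y)
    # stops well before max_iter once the validation score has not
    # improved for n_iter_no_change consecutive epochs
    print(clf.n_iter_)
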
 class SGDRegressor(BaseSGDRegressor):
     """Linear model fitted by minimizing a regularized empirical loss with SGD