Commit 7d72720

[MRG] Fix crash when using SGDClassifier with early stopping in a parallel grid search (scikit-learn#12122)
1 parent: 6463406
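
Before this fix, early stopping stored the held-out validation data on the
estimator itself and scored by temporarily swapping out self.coef_ and
self.intercept_, so concurrent fits touching the same estimator object could
crash. A minimal sketch of the kind of usage that was affected (the dataset
and parameter grid here are illustrative, not taken from the original report):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import SGDClassifier
    from sklearn.model_selection import GridSearchCV

    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               random_state=0)
    clf = SGDClassifier(early_stopping=True, validation_fraction=0.2,
                        tol=1e-3, max_iter=1000, random_state=0)
    # Multiclass fits run one-vs-all in parallel threads; before this commit
    # each thread mutated shared estimator state while scoring.
    grid = GridSearchCV(clf, {'alpha': np.logspace(-4, -2, 3)}, cv=3, n_jobs=2)
    grid.fit(X, y)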

File tree

3 files changed: 136 additions & 83 deletions


sklearn/linear_model/sgd_fast.pyx

Lines changed: 12 additions & 10 deletions
@@ -340,7 +340,7 @@ def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
               double l1_ratio,
               SequentialDataset dataset,
               np.ndarray[unsigned char, ndim=1, mode='c'] validation_mask,
-              bint early_stopping, estimator,
+              bint early_stopping, validation_score_cb,
               int n_iter_no_change,
               int max_iter, double tol, int fit_intercept,
               int verbose, bint shuffle, np.uint32_t seed,
@@ -374,8 +374,9 @@ def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
         Equal to True on the validation set.
     early_stopping : boolean
         Whether to use a stopping criterion based on the validation set.
-    estimator : BaseSGD
-        A concrete object inheriting from ``BaseSGD``.
+    validation_score_cb : callable
+        A callable to compute a validation score given the current
+        coefficients and intercept values.
         Used only if early_stopping is True.
     n_iter_no_change : int
         Number of iteration with no improvement to wait before stopping.
@@ -435,7 +436,7 @@ def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
                       dataset,
                       validation_mask,
                       early_stopping,
-                      estimator,
+                      validation_score_cb,
                       n_iter_no_change,
                       max_iter, tol, fit_intercept,
                       verbose, shuffle, seed,
@@ -458,7 +459,7 @@ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
                 double l1_ratio,
                 SequentialDataset dataset,
                 np.ndarray[unsigned char, ndim=1, mode='c'] validation_mask,
-                bint early_stopping, estimator,
+                bint early_stopping, validation_score_cb,
                 int n_iter_no_change,
                 int max_iter, double tol, int fit_intercept,
                 int verbose, bint shuffle, np.uint32_t seed,
@@ -497,8 +498,9 @@ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
         Equal to True on the validation set.
     early_stopping : boolean
         Whether to use a stopping criterion based on the validation set.
-    estimator : BaseSGD
-        A concrete object inheriting from ``BaseSGD``.
+    validation_score_cb : callable
+        A callable to compute a validation score given the current
+        coefficients and intercept values.
         Used only if early_stopping is True.
     n_iter_no_change : int
         Number of iteration with no improvement to wait before stopping.
@@ -562,7 +564,7 @@ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
                         dataset,
                         validation_mask,
                         early_stopping,
-                        estimator,
+                        validation_score_cb,
                         n_iter_no_change,
                         max_iter, tol, fit_intercept,
                         verbose, shuffle, seed,
@@ -584,7 +586,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
                double l1_ratio,
                SequentialDataset dataset,
                np.ndarray[unsigned char, ndim=1, mode='c'] validation_mask,
-               bint early_stopping, estimator,
+               bint early_stopping, validation_score_cb,
                int n_iter_no_change,
                int max_iter, double tol, int fit_intercept,
                int verbose, bint shuffle, np.uint32_t seed,
@@ -759,7 +761,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
             # evaluate the score on the validation set
             if early_stopping:
                 with gil:
-                    score = estimator._validation_score(weights, intercept)
+                    score = validation_score_cb(weights, intercept)
                 if tol > -INFINITY and score < best_score + tol:
                     no_improvement_count += 1
                 else:
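
The net effect in the Cython loop: the score now comes from a plain callable
invoked under the GIL, instead of a method on the shared estimator. The
surrounding early-stopping bookkeeping amounts to roughly the following
Python sketch (update_one_epoch is a stand-in for the Cython inner loop, not
a function in the code base; the real loop also lets tol=-inf disable the
check):

    import numpy as np

    def sgd_with_early_stopping(update_one_epoch, validation_score_cb,
                                max_iter=1000, tol=1e-3, n_iter_no_change=5):
        best_score = -np.inf
        no_improvement_count = 0
        for epoch in range(max_iter):
            weights, intercept = update_one_epoch()  # one SGD pass over data
            # The callback is the only thing that touches the validation set.
            score = validation_score_cb(weights, intercept)
            if score < best_score + tol:
                no_improvement_count += 1
            else:
                no_improvement_count = 0
            if score > best_score:
                best_score = score
            if no_improvement_count >= n_iter_no_change:
                break  # no improvement for n_iter_no_change epochs
        return weights, intercept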

sklearn/linear_model/stochastic_gradient.py

Lines changed: 72 additions & 58 deletions
@@ -11,6 +11,7 @@

 from ..utils import Parallel, delayed

+from ..base import clone, is_classifier
 from .base import LinearClassifierMixin, SparseCoefMixin
 from .base import make_dataset
 from ..base import BaseEstimator, RegressorMixin
@@ -20,7 +21,7 @@
 from ..utils.validation import check_is_fitted
 from ..exceptions import ConvergenceWarning
 from ..externals import six
-from ..model_selection import train_test_split
+from ..model_selection import StratifiedShuffleSplit, ShuffleSplit

 from .sgd_fast import plain_sgd, average_sgd
 from ..utils import compute_class_weight
@@ -43,6 +44,26 @@
 # Default value of ``epsilon`` parameter.


+class _ValidationScoreCallback(object):
+    """Callback for early stopping based on validation score"""
+
+    def __init__(self, estimator, X_val, y_val, sample_weight_val,
+                 classes=None):
+        self.estimator = clone(estimator)
+        self.estimator.t_ = 1  # to pass check_is_fitted
+        if classes is not None:
+            self.estimator.classes_ = classes
+        self.X_val = X_val
+        self.y_val = y_val
+        self.sample_weight_val = sample_weight_val
+
+    def __call__(self, coef, intercept):
+        est = self.estimator
+        est.coef_ = coef.reshape(1, -1)
+        est.intercept_ = np.atleast_1d(intercept)
+        return est.score(self.X_val, self.y_val, self.sample_weight_val)
+
+
 class BaseSGD(six.with_metaclass(ABCMeta, BaseEstimator, SparseCoefMixin)):
     """Base class for SGD classification and regression."""

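The design point of _ValidationScoreCallback is that it clones the estimator
once, up front, and only ever mutates that private clone while scoring, so
parallel binary fits no longer race on the parent's coef_ and intercept_. A
standalone sketch of the same pattern using only public API (the class name
and toy data are illustrative, not from the code base):

    import numpy as np
    from sklearn.base import clone
    from sklearn.linear_model import SGDClassifier

    class ValidationScorer(object):
        """Score candidate coefficients on a held-out set, thread-safely."""

        def __init__(self, estimator, X_val, y_val, classes):
            self.est = clone(estimator)  # private copy; parent never touched
            self.est.t_ = 1              # satisfy check_is_fitted
            self.est.classes_ = classes
            self.X_val, self.y_val = X_val, y_val

        def __call__(self, coef, intercept):
            self.est.coef_ = coef.reshape(1, -1)
            self.est.intercept_ = np.atleast_1d(intercept)
            return self.est.score(self.X_val, self.y_val)

    X_val = np.array([[0., 1.], [1., 0.], [1., 1.], [0., 0.]])
    y_val = np.array([1, -1, 1, -1])
    scorer = ValidationScorer(SGDClassifier(), X_val, y_val,
                              classes=np.array([-1, 1]))
    print(scorer(np.array([0.5, 0.5]), 0.0))  # accuracy of candidate weights
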
@@ -248,71 +269,52 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
                                           dtype=np.float64,
                                           order="C")

-    def _make_validation_split(self, X, y, sample_weight):
+    def _make_validation_split(self, y):
         """Split the dataset between training set and validation set.

         Parameters
         ----------
-        X : {array, sparse matrix}, shape (n_samples, n_features)
-            Training data.
-
         y : array, shape (n_samples, )
             Target values.

-        sample_weight : array, shape (n_samples, )
-            Weights applied to individual samples.
-
         Returns
         -------
         validation_mask : array, shape (n_samples, )
             Equal to 1 on the validation set, 0 on the training set.
         """
-        n_samples = X.shape[0]
+        n_samples = y.shape[0]
         validation_mask = np.zeros(n_samples, dtype=np.uint8)
         if not self.early_stopping:
             # use the full set for training, with an empty validation set
             return validation_mask

-        tmp = train_test_split(X, y, np.arange(n_samples), sample_weight,
-                               test_size=self.validation_fraction,
-                               random_state=self.random_state)
-        X_train, X_val, y_train, y_val = tmp[:4]
-        idx_train, idx_val, sample_weight_train, sample_weight_val = tmp[4:8]
-        if X_train.shape[0] == 0 or X_val.shape[0] == 0:
+        if is_classifier(self):
+            splitter_type = StratifiedShuffleSplit
+        else:
+            splitter_type = ShuffleSplit
+        cv = splitter_type(test_size=self.validation_fraction,
+                           random_state=self.random_state)
+        idx_train, idx_val = next(cv.split(np.zeros(shape=(y.shape[0], 1)), y))
+        if idx_train.shape[0] == 0 or idx_val.shape[0] == 0:
             raise ValueError(
                 "Splitting %d samples into a train set and a validation set "
                 "with validation_fraction=%r led to an empty set (%d and %d "
                 "samples). Please either change validation_fraction, increase "
                 "number of samples, or disable early_stopping."
-                % (n_samples, self.validation_fraction, X_train.shape[0],
-                   X_val.shape[0]))
+                % (n_samples, self.validation_fraction, idx_train.shape[0],
+                   idx_val.shape[0]))

-        self._X_val = X_val
-        self._y_val = y_val
-        self._sample_weight_val = sample_weight_val
         validation_mask[idx_val] = 1
         return validation_mask

-    def _delete_validation_split(self):
-        if self.early_stopping:
-            del self._X_val
-            del self._y_val
-            del self._sample_weight_val
-
-    def _validation_score(self, coef, intercept):
-        """Compute the score on the validation set. Used for early stopping."""
-        # store attributes
-        old_coefs, old_intercept = self.coef_, self.intercept_
-
-        # replace them with current coefficients for scoring
-        self.coef_ = coef.reshape(1, -1)
-        self.intercept_ = np.atleast_1d(intercept)
-        score = self.score(self._X_val, self._y_val, self._sample_weight_val)
-
-        # restore old attributes
-        self.coef_, self.intercept_ = old_coefs, old_intercept
+    def _make_validation_score_cb(self, validation_mask, X, y, sample_weight,
+                                  classes=None):
+        if not self.early_stopping:
+            return None

-        return score
+        return _ValidationScoreCallback(
+            self, X[validation_mask], y[validation_mask],
+            sample_weight[validation_mask], classes=classes)


 def _prepare_fit_binary(est, y, i):
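
Note that the split now needs only y: classifiers get a stratified split so
that every class is represented on both sides, regressors a plain shuffle
split, and only the index mask is kept rather than validation arrays stored
on the estimator. A standalone sketch of the mask construction (the function
name and toy labels are illustrative):

    import numpy as np
    from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit

    def make_validation_mask(y, is_clf, validation_fraction=0.2,
                             random_state=0):
        splitter_type = StratifiedShuffleSplit if is_clf else ShuffleSplit
        cv = splitter_type(test_size=validation_fraction,
                           random_state=random_state)
        # The splitter only needs the sample count (plus y for
        # stratification), so a dummy one-column X suffices.
        idx_train, idx_val = next(cv.split(np.zeros((y.shape[0], 1)), y))
        mask = np.zeros(y.shape[0], dtype=np.uint8)
        mask[idx_val] = 1
        return mask

    y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    print(make_validation_mask(y, is_clf=True))  # two 1s, one per class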
@@ -348,7 +350,7 @@ def _prepare_fit_binary(est, y, i):


 def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
-               pos_weight, neg_weight, sample_weight):
+               pos_weight, neg_weight, sample_weight, validation_mask=None):
     """Fit a single binary classifier.

     The i'th class is considered the "positive" class.
@@ -388,6 +390,10 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,

     sample_weight : numpy array of shape [n_samples, ]
         The weight of each sample
+
+    validation_mask : numpy array of shape [n_samples, ] or None
+        Precomputed validation mask in case _fit_binary is called in the
+        context of a one-vs-rest reduction.
     """
     # if average is not true, average_coef, and average_intercept will be
     # unused
@@ -399,7 +405,11 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
     penalty_type = est._get_penalty_type(est.penalty)
     learning_rate_type = est._get_learning_rate_type(learning_rate)

-    validation_mask = est._make_validation_split(X, y, sample_weight)
+    if validation_mask is None:
+        validation_mask = est._make_validation_split(y_i)
+    classes = np.array([-1, 1], dtype=y_i.dtype)
+    validation_score_cb = est._make_validation_score_cb(
+        validation_mask, X, y_i, sample_weight, classes=classes)

     # XXX should have random_state_!
     random_state = check_random_state(est.random_state)
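
Because each binary subproblem relabels its targets to +/-1, the callback's
clone has to be told classes_ = [-1, 1]; its score is then the binary
accuracy of that one-vs-rest subproblem, not the multiclass accuracy. The
relabeling is roughly what _prepare_fit_binary produces (toy labels for
illustration):

    import numpy as np

    y = np.array([0, 2, 1, 2, 0])
    i = 2                              # fit class 2 against the rest
    y_i = np.ones(y.shape, dtype=np.float64)
    y_i[y != i] = -1.0
    classes = np.array([-1, 1], dtype=y_i.dtype)
    print(y_i)                         # [-1.  1. -1.  1. -1.]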
@@ -412,8 +422,8 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
     if not est.average:
         result = plain_sgd(coef, intercept, est.loss_function_,
                            penalty_type, alpha, C, est.l1_ratio,
-                           dataset, validation_mask, est.early_stopping, est,
-                           int(est.n_iter_no_change),
+                           dataset, validation_mask, est.early_stopping,
+                           validation_score_cb, int(est.n_iter_no_change),
                            max_iter, tol, int(est.fit_intercept),
                            int(est.verbose), int(est.shuffle), seed,
                            pos_weight, neg_weight,
@@ -426,8 +436,8 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
                              average_intercept, est.loss_function_,
                              penalty_type, alpha, C, est.l1_ratio,
                              dataset, validation_mask, est.early_stopping,
-                             est, int(est.n_iter_no_change),
-                             max_iter, tol,
+                             validation_score_cb,
+                             int(est.n_iter_no_change), max_iter, tol,
                              int(est.fit_intercept), int(est.verbose),
                              int(est.shuffle), seed, pos_weight,
                              neg_weight, learning_rate_type, est.eta0,
@@ -441,7 +451,6 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,

         result = standard_coef, standard_intercept, n_iter_

-    est._delete_validation_split()
     return result

@@ -610,14 +619,19 @@ def _fit_multiclass(self, X, y, alpha, C, learning_rate,
         """Fit a multi-class classifier by combining binary classifiers

         Each binary classifier predicts one class versus all others. This
-        strategy is called OVA: One Versus All.
+        strategy is called OvA (One versus All) or OvR (One versus Rest).
         """
+        # Precompute the validation split using the multiclass labels
+        # to ensure proper balancing of the classes.
+        validation_mask = self._make_validation_split(y)
+
         # Use joblib to fit OvA in parallel.
         result = Parallel(n_jobs=self.n_jobs, prefer="threads",
                           verbose=self.verbose)(
             delayed(fit_binary)(self, i, X, y, alpha, C, learning_rate,
                                 max_iter, self._expanded_class_weight[i],
-                                1., sample_weight)
+                                1., sample_weight,
+                                validation_mask=validation_mask)
             for i in range(len(self.classes_)))

         # take the maximum of n_iter_ over every binary fit
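
Since the mask is computed once from the full multiclass y and is only read
by the workers, the threading backend stays safe: no worker mutates shared
estimator state anymore. A toy sketch of that sharing pattern with joblib
(fit_one is a stand-in, not the real fit_binary):

    import numpy as np
    from joblib import Parallel, delayed

    def fit_one(i, mask):
        # Workers only read the shared mask; nothing shared is written.
        return i, int(mask.sum())

    validation_mask = np.array([0, 1, 0, 0, 1], dtype=np.uint8)
    results = Parallel(n_jobs=2, prefer="threads")(
        delayed(fit_one)(i, validation_mask) for i in range(3))
    print(results)  # [(0, 2), (1, 2), (2, 2)]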
@@ -1115,18 +1129,16 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
         sample_weight = self._validate_sample_weight(sample_weight, n_samples)

         if getattr(self, "coef_", None) is None:
-            self._allocate_parameter_mem(1, n_features,
-                                         coef_init, intercept_init)
+            self._allocate_parameter_mem(1, n_features, coef_init,
+                                         intercept_init)
         elif n_features != self.coef_.shape[-1]:
             raise ValueError("Number of features %d does not match previous "
                              "data %d." % (n_features, self.coef_.shape[-1]))
         if self.average > 0 and getattr(self, "average_coef_", None) is None:
             self.average_coef_ = np.zeros(n_features,
                                           dtype=np.float64,
                                           order="C")
-            self.average_intercept_ = np.zeros(1,
-                                               dtype=np.float64,
-                                               order="C")
+            self.average_intercept_ = np.zeros(1, dtype=np.float64, order="C")

         self._fit_regressor(X, y, alpha, C, loss, learning_rate,
                             sample_weight, max_iter)
@@ -1269,7 +1281,9 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
         if not hasattr(self, "t_"):
             self.t_ = 1.0

-        validation_mask = self._make_validation_split(X, y, sample_weight)
+        validation_mask = self._make_validation_split(y)
+        validation_score_cb = self._make_validation_score_cb(
+            validation_mask, X, y, sample_weight)

         random_state = check_random_state(self.random_state)
         # numpy mtrand expects a C long which is a signed 32 bit integer under
@@ -1290,7 +1304,8 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
                           alpha, C,
                           self.l1_ratio,
                           dataset,
-                          validation_mask, self.early_stopping, self,
+                          validation_mask, self.early_stopping,
+                          validation_score_cb,
                           int(self.n_iter_no_change),
                           max_iter, tol,
                           int(self.fit_intercept),
@@ -1322,7 +1337,8 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
                           alpha, C,
                           self.l1_ratio,
                           dataset,
-                          validation_mask, self.early_stopping, self,
+                          validation_mask, self.early_stopping,
+                          validation_score_cb,
                           int(self.n_iter_no_change),
                           max_iter, tol,
                           int(self.fit_intercept),
@@ -1337,8 +1353,6 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
         self.t_ += self.n_iter_ * X.shape[0]
         self.intercept_ = np.atleast_1d(self.intercept_)

-        self._delete_validation_split()
-

 class SGDRegressor(BaseSGDRegressor):
     """Linear model fitted by minimizing a regularized empirical loss with SGD

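On the regressor side the callback is likewise built once per fit, and the
holdout comes from ShuffleSplit since stratification does not apply. A quick
usage check on toy data (parameters are illustrative):

    import numpy as np
    from sklearn.linear_model import SGDRegressor

    rng = np.random.RandomState(0)
    X = rng.randn(200, 3)
    y = X @ np.array([1.0, -2.0, 0.5]) + 0.01 * rng.randn(200)
    reg = SGDRegressor(early_stopping=True, validation_fraction=0.1,
                       n_iter_no_change=5, tol=1e-3, max_iter=1000,
                       random_state=0)
    reg.fit(X, y)       # internally: ShuffleSplit mask + score callback
    print(reg.n_iter_)  # stops well before max_iter on this easy problem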