Skip to content

Commit 4f90efc

Browse files
committed
Merge pull request scikit-learn#610 from amueller/svm_class_weights
MRG: moved class_weight parameter in svms from fit to ``__init__``.
2 parents 9c6e044 + 1e06904 commit 4f90efc

File tree

14 files changed

+169
-181
lines changed

14 files changed

+169
-181
lines changed

doc/modules/svm.rst

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,10 @@ training samples::
8282
>>> X = [[0, 0], [1, 1]]
8383
>>> Y = [0, 1]
8484
>>> clf = svm.SVC()
85-
>>> clf.fit(X, Y)
86-
SVC(C=1.0, cache_size=200, coef0=0.0, degree=3, gamma=0.5, kernel='rbf',
87-
probability=False, scale_C=True, shrinking=True, tol=0.001)
85+
>>> clf.fit(X, Y) # doctest: +NORMALIZE_WHITESPACE
86+
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
87+
gamma=0.5, kernel='rbf', probability=False, scale_C=True, shrinking=True,
88+
tol=0.001)
8889

8990
After being fitted, the model can then be used to predict new values::
9091

@@ -120,9 +121,10 @@ classifiers are constructed and each one trains data from two classes::
120121
>>> X = [[0], [1], [2], [3]]
121122
>>> Y = [0, 1, 2, 3]
122123
>>> clf = svm.SVC()
123-
>>> clf.fit(X, Y)
124-
SVC(C=1.0, cache_size=200, coef0=0.0, degree=3, gamma=1.0, kernel='rbf',
125-
probability=False, scale_C=True, shrinking=True, tol=0.001)
124+
>>> clf.fit(X, Y) # doctest: +NORMALIZE_WHITESPACE
125+
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
126+
gamma=1.0, kernel='rbf', probability=False, scale_C=True, shrinking=True,
127+
tol=0.001)
126128
>>> dec = clf.decision_function([[1]])
127129
>>> dec.shape[1] # 4 classes: 4*3/2 = 6
128130
6
@@ -132,9 +134,10 @@ multi-class strategy, thus training n_class models. If there are only
132134
two classes, only one model is trained::
133135

134136
>>> lin_clf = svm.LinearSVC()
135-
>>> lin_clf.fit(X, Y)
136-
LinearSVC(C=1.0, dual=True, fit_intercept=True, intercept_scaling=1,
137-
loss='l2', multi_class=False, penalty='l2', scale_C=True, tol=0.0001)
137+
>>> lin_clf.fit(X, Y) # doctest: +NORMALIZE_WHITESPACE
138+
LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,
139+
intercept_scaling=1, loss='l2', multi_class=False, penalty='l2',
140+
scale_C=True, tol=0.0001)
138141
>>> dec = lin_clf.decision_function([[1]])
139142
>>> dec.shape[1]
140143
4
@@ -258,9 +261,10 @@ floating point values instead of integer values::
258261
>>> X = [[0, 0], [2, 2]]
259262
>>> y = [0.5, 2.5]
260263
>>> clf = svm.SVR()
261-
>>> clf.fit(X, y)
262-
SVR(C=1.0, cache_size=200, coef0=0.0, degree=3, epsilon=0.1, gamma=0.5,
263-
kernel='rbf', probability=False, scale_C=True, shrinking=True, tol=0.001)
264+
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
265+
SVR(C=1.0, cache_size=200, coef0=0.0, degree=3,
266+
epsilon=0.1, gamma=0.5, kernel='rbf', probability=False, scale_C=True,
267+
shrinking=True, tol=0.001)
264268
>>> clf.predict([[1, 1]])
265269
array([ 1.5])
266270

@@ -451,10 +455,10 @@ vectors and the test vectors must be provided.
451455
>>> clf = svm.SVC(kernel='precomputed')
452456
>>> # linear kernel computation
453457
>>> gram = np.dot(X, X.T)
454-
>>> clf.fit(gram, y)
455-
SVC(C=1.0, cache_size=200, coef0=0.0, degree=3, gamma=0.0,
456-
kernel='precomputed', probability=False, scale_C=True, shrinking=True,
457-
tol=0.001)
458+
>>> clf.fit(gram, y) # doctest: +NORMALIZE_WHITESPACE
459+
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
460+
gamma=0.0, kernel='precomputed', probability=False, scale_C=True,
461+
shrinking=True, tol=0.001)
458462
>>> # predict on training examples
459463
>>> clf.predict(gram)
460464
array([ 0., 1.])

doc/tutorial.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,9 @@ set, let us use all the images of our dataset apart from the last
152152
one::
153153

154154
>>> clf.fit(digits.data[:-1], digits.target[:-1])
155-
SVC(C=100.0, cache_size=200, coef0=0.0, degree=3, gamma=0.001, kernel='rbf',
156-
probability=False, scale_C=True, shrinking=True, tol=0.001)
155+
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
156+
gamma=0.001, kernel='rbf', probability=False, scale_C=True,
157+
shrinking=True, tol=0.001)
157158

158159
Now you can predict new values, in particular, we can ask to the
159160
classifier what is the digit of our last image in the `digits` dataset,
@@ -188,8 +189,8 @@ persistence model, namely `pickle <http://docs.python.org/library/pickle.html>`_
188189
>>> iris = datasets.load_iris()
189190
>>> X, y = iris.data, iris.target
190191
>>> clf.fit(X, y)
191-
SVC(C=1.0, cache_size=200, coef0=0.0, degree=3, gamma=0.25, kernel='rbf',
192-
probability=False, scale_C=True, shrinking=True, tol=0.001)
192+
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.25,
193+
kernel='rbf', probability=False, scale_C=True, shrinking=True, tol=0.001)
193194

194195
>>> import pickle
195196
>>> s = pickle.dumps(clf)

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ API changes summary
7070
objects are now deprecated.
7171
`scores_` or `pvalues_` should be used instead.
7272

73+
- In :class:`LogisticRegression`, :class:`LinearSVC`, :class:`SVC` and
74+
:class:`NuSVC`, the `class_weight` parameter is now an initialization
75+
parameter, not a parameter to fit. This makes grid searches
76+
over this parameter possible.
77+
7378
- LFW ``data`` is now always shape ``(n_samples, n_features)`` to be
7479
consistent with the Olivetti faces dataset. Use ``images`` and
7580
``pairs`` attribute to access the natural images shapes instead.

examples/applications/face_recognition.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,7 @@
110110
'C': [1, 5, 10, 50, 100],
111111
'gamma': [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1],
112112
}
113-
clf = GridSearchCV(SVC(kernel='rbf'), param_grid,
114-
fit_params={'class_weight': 'auto'})
113+
clf = GridSearchCV(SVC(kernel='rbf', class_weight='auto'), param_grid)
115114
clf = clf.fit(X_train_pca, y_train)
116115
print "done in %0.3fs" % (time() - t0)
117116
print "Best estimator found by grid search:"

examples/linear_model/plot_sgd_weighted_classes.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,19 @@
3939

4040

4141
# get the separating hyperplane using weighted classes
42-
wclf = SGDClassifier(n_iter=100, alpha=0.01)
43-
wclf.fit(X, y, class_weight={1: 10})
42+
wclf = SGDClassifier(n_iter=100, alpha=0.01, class_weight={1: 10})
43+
wclf.fit(X, y)
4444

4545
ww = wclf.coef_.ravel()
4646
wa = -ww[0] / ww[1]
4747
wyy = wa * xx - wclf.intercept_ / ww[1]
4848

4949
# plot separating hyperplanes and samples
5050
pl.set_cmap(pl.cm.Paired)
51-
h0 = pl.plot(xx, yy, 'k-')
52-
h1 = pl.plot(xx, wyy, 'k--')
51+
h0 = pl.plot(xx, yy, 'k-', label='no weights')
52+
h1 = pl.plot(xx, wyy, 'k--', label='with weights')
5353
pl.scatter(X[:, 0], X[:, 1], c=y)
54-
pl.legend((h0, h1), ('no weights', 'with weights'))
54+
pl.legend()
5555

5656
pl.axis('tight')
5757
pl.show()

examples/svm/plot_separating_hyperplane_unbalanced.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,19 @@
3535

3636

3737
# get the separating hyperplane using weighted classes
38-
wclf = svm.SVC(kernel='linear')
39-
wclf.fit(X, y, class_weight={1: 10})
38+
wclf = svm.SVC(kernel='linear', class_weight={1: 10})
39+
wclf.fit(X, y)
4040

4141
ww = wclf.coef_[0]
4242
wa = -ww[0] / ww[1]
4343
wyy = wa * xx - wclf.intercept_[0] / ww[1]
4444

4545
# plot separating hyperplanes and samples
4646
pl.set_cmap(pl.cm.Paired)
47-
h0 = pl.plot(xx, yy, 'k-')
48-
h1 = pl.plot(xx, wyy, 'k--')
47+
h0 = pl.plot(xx, yy, 'k-', label='no weights')
48+
h1 = pl.plot(xx, wyy, 'k--', label='with weights')
4949
pl.scatter(X[:, 0], X[:, 1], c=y)
50-
pl.legend((h0, h1), ('no weights', 'with weights'))
50+
pl.legend()
5151

5252
pl.axis('tight')
5353
pl.show()

sklearn/linear_model/logistic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,12 @@ class LogisticRegression(BaseLibLinear, ClassifierMixin, SelectorMixin):
9090

9191
def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0,
9292
fit_intercept=True, intercept_scaling=1,
93-
scale_C=True):
93+
scale_C=True, class_weight=None):
9494

9595
super(LogisticRegression, self).__init__(penalty=penalty,
9696
dual=dual, loss='lr', tol=tol, C=C,
9797
fit_intercept=fit_intercept, intercept_scaling=intercept_scaling,
98-
scale_C=scale_C)
98+
scale_C=scale_C, class_weight=class_weight)
9999

100100
def predict_proba(self, X):
101101
"""Probability estimates.
@@ -118,8 +118,8 @@ def predict_proba(self, X):
118118
prob_wrap = (csr_predict_prob_wrap if self._sparse else
119119
predict_prob_wrap)
120120
probas = prob_wrap(X, self.raw_coef_, self._get_solver_type(),
121-
self.tol, self.C, self.class_weight_label,
122-
self.class_weight, self.label_, self._get_bias())
121+
self.tol, self.C, self.class_weight_label_,
122+
self.class_weight_, self.label_, self._get_bias())
123123
return probas[:, np.argsort(self.label_)]
124124

125125
def predict_log_proba(self, X):

sklearn/linear_model/stochastic_gradient.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
import scipy.sparse as sp
9+
import warnings
910

1011
from ..externals.joblib import Parallel, delayed
1112

@@ -222,8 +223,7 @@ def _set_class_weight(self, class_weight, classes, y):
222223

223224
self._expanded_class_weight = weight
224225

225-
def _partial_fit(self, X, y, n_iter, classes=None,
226-
class_weight=None, sample_weight=None):
226+
def _partial_fit(self, X, y, n_iter, classes=None, sample_weight=None):
227227
X = safe_asarray(X, dtype=np.float64, order="C")
228228
y = np.asarray(y)
229229

@@ -243,7 +243,7 @@ def _partial_fit(self, X, y, n_iter, classes=None,
243243
n_classes = self.classes_.shape[0]
244244

245245
# Allocate datastructures from input arguments
246-
self._set_class_weight(class_weight, self.classes_, y)
246+
self._set_class_weight(self.class_weight, self.classes_, y)
247247
sample_weight = self._validate_sample_weight(sample_weight, n_samples)
248248

249249
if self.coef_ is None:
@@ -283,16 +283,6 @@ def partial_fit(self, X, y, classes=None,
283283
and can be omitted in the subsequent calls.
284284
Note that y doesn't need to contain all labels in `classes`.
285285
286-
class_weight : dict, {class_label : weight} or "auto"
287-
Weights associated with classes.
288-
289-
The "auto" mode uses the values of y to automatically adjust
290-
weights inversely proportional to class frequencies.
291-
292-
If None, values defined in the previous call to partial_fit
293-
will be used. If partial_fit was never called before,
294-
uniform weights are assumed.
295-
296286
sample_weight : array-like, shape = [n_samples], optional
297287
Weights applied to individual samples.
298288
If not provided, uniform weights are assumed.
@@ -301,8 +291,12 @@ def partial_fit(self, X, y, classes=None,
301291
-------
302292
self : returns an instance of self.
303293
"""
294+
if class_weight != None:
295+
warnings.warn("Using 'class_weight' as a parameter to the 'fit'"
296+
"method is deprecated. Set it on initialization instead.",
297+
DeprecationWarning)
298+
self.class_weight = class_weight
304299
return self._partial_fit(X, y, n_iter=1, classes=classes,
305-
class_weight=class_weight,
306300
sample_weight=sample_weight)
307301

308302
def fit(self, X, y, coef_init=None, intercept_init=None,
@@ -323,13 +317,6 @@ def fit(self, X, y, coef_init=None, intercept_init=None,
323317
intercept_init : array, shape = [n_classes]
324318
The initial intercept to warm-start the optimization.
325319
326-
class_weight : dict, {class_label : weight} or "auto"
327-
Weights associated with classes. If not given, all classes
328-
are supposed to have weight one.
329-
330-
The "auto" mode uses the values of y to automatically adjust
331-
weights inversely proportional to class frequencies.
332-
333320
sample_weight : array-like, shape = [n_samples], optional
334321
Weights applied to individual samples.
335322
If not provided, uniform weights are assumed.
@@ -338,6 +325,11 @@ def fit(self, X, y, coef_init=None, intercept_init=None,
338325
-------
339326
self : returns an instance of self.
340327
"""
328+
if class_weight != None:
329+
warnings.warn("Using 'class_weight' as a parameter to the 'fit'"
330+
"method is deprecated. Set it on initialization instead.",
331+
DeprecationWarning)
332+
self.class_weight = class_weight
341333
X = safe_asarray(X, dtype=np.float64, order="C")
342334
y = np.asarray(y)
343335

@@ -363,8 +355,7 @@ def fit(self, X, y, coef_init=None, intercept_init=None,
363355

364356
self._partial_fit(X, y, self.n_iter,
365357
classes=classes,
366-
sample_weight=sample_weight,
367-
class_weight=class_weight)
358+
sample_weight=sample_weight)
368359

369360
# fitting is over, we can now transform coef_ to fortran order
370361
# for faster predictions

0 commit comments

Comments
 (0)