Skip to content

Commit 41cbfde

Browse files
committed
Add check for sample_weights
1 parent 80e22b3 commit 41cbfde

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

sklearn/linear_model/tests/test_logistic.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,8 @@ def test_logistic_regression_sample_weights():
586586
clf_sw_none.fit(X, y)
587587
clf_sw_ones = LR(solver=solver, fit_intercept=False)
588588
clf_sw_ones.fit(X, y, sample_weight=np.ones(y.shape[0]))
589-
assert_array_almost_equal(clf_sw_none.coef_, clf_sw_ones.coef_, decimal=4)
589+
assert_array_almost_equal(
590+
clf_sw_none.coef_, clf_sw_ones.coef_, decimal=4)
590591

591592
# Test that sample weights work the same with the lbfgs,
592593
# newton-cg, and 'sag' solvers
@@ -597,8 +598,15 @@ def test_logistic_regression_sample_weights():
597598
clf_sw_sag = LR(solver='sag', fit_intercept=False,
598599
max_iter=2000, tol=1e-7)
599600
clf_sw_sag.fit(X, y, sample_weight=y + 1)
600-
assert_array_almost_equal(clf_sw_lbfgs.coef_, clf_sw_n.coef_, decimal=4)
601-
assert_array_almost_equal(clf_sw_lbfgs.coef_, clf_sw_sag.coef_, decimal=4)
601+
clf_sw_liblinear = LR(solver='liblinear', fit_intercept=False,
602+
max_iter=50, tol=1e-7)
603+
clf_sw_liblinear.fit(X, y, sample_weight=y + 1)
604+
assert_array_almost_equal(
605+
clf_sw_lbfgs.coef_, clf_sw_n.coef_, decimal=4)
606+
assert_array_almost_equal(
607+
clf_sw_lbfgs.coef_, clf_sw_sag.coef_, decimal=4)
608+
assert_array_almost_equal(
609+
clf_sw_lbfgs.coef_, clf_sw_liblinear.coef_, decimal=4)
602610

603611
# Test that passing class_weight as [1,2] is the same as
604612
# passing class weight = [1,1] but adjusting sample weights
@@ -609,12 +617,13 @@ def test_logistic_regression_sample_weights():
609617
clf_cw_12.fit(X, y)
610618
clf_sw_12 = LR(solver=solver, fit_intercept=False)
611619
clf_sw_12.fit(X, y, sample_weight=sample_weight)
612-
assert_array_almost_equal(clf_cw_12.coef_, clf_sw_12.coef_, decimal=4)
620+
assert_array_almost_equal(
621+
clf_cw_12.coef_, clf_sw_12.coef_, decimal=4)
613622

614623
# Test the above for l1 penalty and l2 penalty with dual=True.
615624
# since the patched liblinear code is different.
616625
clf_cw = LogisticRegression(
617-
solver="liblinear", fit_intercept=False, class_weight={0:1, 1:2},
626+
solver="liblinear", fit_intercept=False, class_weight={0: 1, 1: 2},
618627
penalty="l1")
619628
clf_cw.fit(X, y)
620629
clf_sw = LogisticRegression(
@@ -623,7 +632,7 @@ def test_logistic_regression_sample_weights():
623632
assert_array_almost_equal(clf_cw.coef_, clf_sw.coef_, decimal=4)
624633

625634
clf_cw = LogisticRegression(
626-
solver="liblinear", fit_intercept=False, class_weight={0:1, 1:2},
635+
solver="liblinear", fit_intercept=False, class_weight={0: 1, 1: 2},
627636
penalty="l2", dual=True)
628637
clf_cw.fit(X, y)
629638
clf_sw = LogisticRegression(

sklearn/svm/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from ..base import BaseEstimator, ClassifierMixin
1111
from ..preprocessing import LabelEncoder
1212
from ..multiclass import _ovr_decision_function
13-
from ..utils import check_array, check_random_state, column_or_1d, check_X_y
13+
from ..utils import check_array, check_consistent_length, check_random_state
14+
from ..utils import column_or_1d, check_X_y
1415
from ..utils import compute_class_weight, deprecated
1516
from ..utils.extmath import safe_sparse_dot
1617
from ..utils.validation import check_is_fitted
@@ -891,6 +892,10 @@ def _fit_liblinear(X, y, C, fit_intercept, intercept_scaling, class_weight,
891892
y_ind = np.asarray(y_ind, dtype=np.float64).ravel()
892893
if sample_weight is None:
893894
sample_weight = np.ones(X.shape[0])
895+
else:
896+
sample_weight = np.array(sample_weight, dtype=np.float64, order='C')
897+
check_consistent_length(sample_weight, X)
898+
894899
solver_type = _get_liblinear_solver_type(multi_class, penalty, loss, dual)
895900
raw_coef_, n_iter_ = liblinear.train_wrap(
896901
X, y_ind, sp.isspmatrix(X), solver_type, tol, bias, C,

0 commit comments

Comments
 (0)