Skip to content

Commit 8424c48

Browse files
committed
Merge pull request scikit-learn#5262 from andylamb/lamb-fix-class-weight-check
[MRG + 1] Fix check in `compute_class_weight`.
2 parents af56154 + 9ec03c2 commit 8424c48

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

sklearn/utils/class_weight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def compute_class_weight(class_weight, classes, y):
7171
" got: %r" % class_weight)
7272
for c in class_weight:
7373
i = np.searchsorted(classes, c)
74-
if classes[i] != c:
74+
if i >= len(classes) or classes[i] != c:
7575
raise ValueError("Class label %d not present." % c)
7676
else:
7777
weight[i] = class_weight[c]

sklearn/utils/tests/test_class_weight.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.utils.testing import assert_array_almost_equal
1010
from sklearn.utils.testing import assert_almost_equal
1111
from sklearn.utils.testing import assert_raises
12+
from sklearn.utils.testing import assert_raise_message
1213
from sklearn.utils.testing import assert_true
1314
from sklearn.utils.testing import assert_equal
1415
from sklearn.utils.testing import assert_warns
@@ -38,6 +39,28 @@ def test_compute_class_weight_not_present():
3839
assert_raises(ValueError, compute_class_weight, "balanced", classes, y)
3940

4041

42+
def test_compute_class_weight_dict():
    """Exercise compute_class_weight with a dict-valued class_weight."""
    labels = np.arange(3)
    y = np.asarray([0, 0, 1, 2])

    # Explicit per-class weights whose keys are all present in `classes`
    # should come back unchanged from compute_class_weight.
    expected = np.asarray([1.0, 2.0, 3.0])
    computed = compute_class_weight({0: 1.0, 1: 2.0, 2: 3.0}, labels, y)
    assert_array_almost_equal(expected, computed)

    # A dict key that is not one of `classes` must raise a ValueError that
    # names the offending label: 4 lies past the largest class, -1 lies
    # before the smallest one.
    failing_cases = [
        ('Class label 4 not present.',
         {0: 1.0, 1: 2.0, 2: 3.0, 4: 1.5}),
        ('Class label -1 not present.',
         {-1: 5.0, 0: 1.0, 1: 2.0, 2: 3.0}),
    ]
    for msg, class_weights in failing_cases:
        assert_raise_message(ValueError, msg, compute_class_weight,
                             class_weights, labels, y)
62+
63+
4164
def test_compute_class_weight_invariance():
4265
# Test that the result with class_weight="balanced" is invariant wrt
4366
# class imbalance if the number of samples is identical.

0 commit comments

Comments
 (0)