|
9 | 9 | from sklearn.utils.testing import assert_array_almost_equal |
10 | 10 | from sklearn.utils.testing import assert_almost_equal |
11 | 11 | from sklearn.utils.testing import assert_raises |
| 12 | +from sklearn.utils.testing import assert_raise_message |
12 | 13 | from sklearn.utils.testing import assert_true |
13 | 14 | from sklearn.utils.testing import assert_equal |
14 | 15 | from sklearn.utils.testing import assert_warns |
@@ -38,6 +39,28 @@ def test_compute_class_weight_not_present(): |
38 | 39 | assert_raises(ValueError, compute_class_weight, "balanced", classes, y) |
39 | 40 |
|
40 | 41 |
|
| 42 | +def test_compute_class_weight_dict(): |
| 43 | + classes = np.arange(3) |
| 44 | + class_weights = {0: 1.0, 1: 2.0, 2: 3.0} |
| 45 | + y = np.asarray([0, 0, 1, 2]) |
| 46 | + cw = compute_class_weight(class_weights, classes, y) |
| 47 | + |
| 48 | + # When the user specifies class weights, compute_class_weights should just |
| 49 | + # return them. |
| 50 | + assert_array_almost_equal(np.asarray([1.0, 2.0, 3.0]), cw) |
| 51 | + |
| 52 | + # When a class weight is specified that isn't in classes, a ValueError |
| 53 | + # should get raised |
| 54 | + msg = 'Class label 4 not present.' |
| 55 | + class_weights = {0: 1.0, 1: 2.0, 2: 3.0, 4: 1.5} |
| 56 | + assert_raise_message(ValueError, msg, compute_class_weight, class_weights, |
| 57 | + classes, y) |
| 58 | + msg = 'Class label -1 not present.' |
| 59 | + class_weights = {-1: 5.0, 0: 1.0, 1: 2.0, 2: 3.0} |
| 60 | + assert_raise_message(ValueError, msg, compute_class_weight, class_weights, |
| 61 | + classes, y) |
| 62 | + |
| 63 | + |
41 | 64 | def test_compute_class_weight_invariance(): |
42 | 65 | # Test that results with class_weight="balanced" is invariant wrt |
43 | 66 | # class imbalance if the number of samples is identical. |
|
0 commit comments