Skip to content

Commit a5743ed

Browse files
NicolasHugthomasjpfan
authored andcommitted
TST add test for LAD-loss and quantile loss equivalence (old GBDT code) (scikit-learn#14086)
1 parent 76ce7c5 commit a5743ed

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numpy.testing import assert_almost_equal
77
from numpy.testing import assert_allclose
88
from numpy.testing import assert_equal
9+
import pytest
910

1011
from sklearn.utils import check_random_state
1112
from sklearn.utils.stats import _weighted_percentile
@@ -273,3 +274,24 @@ def test_init_raw_predictions_values():
273274
for k in range(n_classes):
274275
p = (y == k).mean()
275276
assert_almost_equal(raw_predictions[:, k], np.log(p))
277+
278+
279+
@pytest.mark.parametrize('seed', range(5))
280+
def test_lad_equals_quantile_50(seed):
281+
# Make sure quantile loss with alpha = .5 is equivalent to LAD
282+
lad = LeastAbsoluteError(n_classes=1)
283+
ql = QuantileLossFunction(n_classes=1, alpha=0.5)
284+
285+
n_samples = 50
286+
rng = np.random.RandomState(seed)
287+
raw_predictions = rng.normal(size=(n_samples))
288+
y_true = rng.normal(size=(n_samples))
289+
290+
lad_loss = lad(y_true, raw_predictions)
291+
ql_loss = ql(y_true, raw_predictions)
292+
assert_almost_equal(lad_loss, 2 * ql_loss)
293+
294+
weights = np.linspace(0, 1, n_samples) ** 2
295+
lad_weighted_loss = lad(y_true, raw_predictions, sample_weight=weights)
296+
ql_weighted_loss = ql(y_true, raw_predictions, sample_weight=weights)
297+
assert_almost_equal(lad_weighted_loss, 2 * ql_weighted_loss)

0 commit comments

Comments
 (0)