Skip to content

Commit b580ad5

Browse files
NicolasHug authored and thomasjpfan committed
BUG Fix zero division error in GBDTs (scikit-learn#14024)
1 parent a5743ed commit b580ad5

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

sklearn/ensemble/_hist_gradient_boosting/grower.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
from .predictor import TreePredictor
1717
from .utils import sum_parallel
1818
from .types import PREDICTOR_RECORD_DTYPE
19+
from .types import Y_DTYPE
20+
21+
22+
EPS = np.finfo(Y_DTYPE).eps # to avoid zero division errors
1923

2024

2125
class TreeNode:
@@ -398,7 +402,7 @@ def _finalize_leaf(self, node):
398402
https://arxiv.org/abs/1603.02754
399403
"""
400404
node.value = -self.shrinkage * node.sum_gradients / (
401-
node.sum_hessians + self.splitter.l2_regularization)
405+
node.sum_hessians + self.splitter.l2_regularization + EPS)
402406
self.finalized_leaves.append(node)
403407

404408
def _finalize_splittable_nodes(self):

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,20 @@ def test_binning_train_validation_are_separated():
172172
int((1 - validation_fraction) * n_samples))
173173
assert np.all(mapper_training_data.actual_n_bins_ !=
174174
mapper_whole_data.actual_n_bins_)
175+
176+
177+
@pytest.mark.parametrize('data', [
178+
make_classification(random_state=0, n_classes=2),
179+
make_classification(random_state=0, n_classes=3, n_informative=3)
180+
], ids=['binary_crossentropy', 'categorical_crossentropy'])
181+
def test_zero_division_hessians(data):
182+
# non regression test for issue #14018
183+
# make sure we avoid zero division errors when computing the leaves values.
184+
185+
# If the learning rate is too high, the raw predictions are bad and will
186+
# saturate the softmax (or sigmoid in binary classif). This leads to
187+
# probabilities being exactly 0 or 1, gradients being constant, and
188+
# hessians being zero.
189+
X, y = data
190+
gb = HistGradientBoostingClassifier(learning_rate=100, max_iter=10)
191+
gb.fit(X, y)

sklearn/utils/estimator_checks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,8 +2401,11 @@ def check_decision_proba_consistency(name, estimator_orig):
24012401
hasattr(estimator, "predict_proba")):
24022402

24032403
estimator.fit(X, y)
2404-
a = estimator.predict_proba(X_test)[:, 1]
2405-
b = estimator.decision_function(X_test)
2404+
# Since the link function from decision_function() to predict_proba()
2405+
# is sometimes not precise enough (typically expit), we round to the
2406+
# 10th decimal to avoid numerical issues.
2407+
a = estimator.predict_proba(X_test)[:, 1].round(decimals=10)
2408+
b = estimator.decision_function(X_test).round(decimals=10)
24062409
assert_array_equal(rankdata(a), rankdata(b))
24072410

24082411

0 commit comments

Comments (0)