Skip to content

Commit 2a7194d

Browse files
NicolasHug authored and thomasjpfan committed
FIX Bin training and validation data separately in GBDTs (scikit-learn#13933)
1 parent 5772667 commit 2a7194d

File tree

4 files changed

+74
-21
lines changed

4 files changed

+74
-21
lines changed

doc/whats_new/v0.22.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ Changelog
3939
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
4040
where 123456 is the *pull request* number, not the issue number.
4141
42+
:mod:`sklearn.ensemble`
43+
.......................
44+
45+
- |Fix| :class:`ensemble.HistGradientBoostingClassifier` and
46+
:class:`ensemble.HistGradientBoostingRegressor` now bin the training and
47+
validation data separately to avoid any data leak. :pr:`13933` by
48+
`NicolasHug`_.
49+
4250
:mod:`sklearn.linear_model`
4351
..................
4452

sklearn/ensemble/_hist_gradient_boosting/binning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def transform(self, X):
140140
Returns
141141
-------
142142
X_binned : array-like, shape (n_samples, n_features)
143-
The binned data.
143+
The binned data (fortran-aligned).
144144
"""
145145
X = check_array(X, dtype=[X_DTYPE])
146146
check_is_fitted(self, ['bin_thresholds_', 'actual_n_bins_'])

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -112,17 +112,6 @@ def fit(self, X, y):
112112
# data.
113113
self._in_fit = True
114114

115-
# bin the data
116-
if self.verbose:
117-
print("Binning {:.3f} GB of data: ".format(X.nbytes / 1e9), end="",
118-
flush=True)
119-
tic = time()
120-
self.bin_mapper_ = _BinMapper(max_bins=self.max_bins, random_state=rng)
121-
X_binned = self.bin_mapper_.fit_transform(X)
122-
toc = time()
123-
if self.verbose:
124-
duration = toc - tic
125-
print("{:.3f} s".format(duration))
126115

127116
self.loss_ = self._get_loss()
128117

@@ -135,17 +124,20 @@ def fit(self, X, y):
135124
# stratify for classification
136125
stratify = y if hasattr(self.loss_, 'predict_proba') else None
137126

138-
X_binned_train, X_binned_val, y_train, y_val = train_test_split(
139-
X_binned, y, test_size=self.validation_fraction,
140-
stratify=stratify, random_state=rng)
127+
X_train, X_val, y_train, y_val = train_test_split(
128+
X, y, test_size=self.validation_fraction, stratify=stratify,
129+
random_state=rng)
130+
else:
131+
X_train, y_train = X, y
132+
X_val, y_val = None, None
141133

142-
# Predicting is faster of C-contiguous arrays, training is faster
143-
# on Fortran arrays.
144-
X_binned_val = np.ascontiguousarray(X_binned_val)
145-
X_binned_train = np.asfortranarray(X_binned_train)
134+
# Bin the data
135+
self.bin_mapper_ = _BinMapper(max_bins=self.max_bins, random_state=rng)
136+
X_binned_train = self._bin_data(X_train, rng, is_training_data=True)
137+
if X_val is not None:
138+
X_binned_val = self._bin_data(X_val, rng, is_training_data=False)
146139
else:
147-
X_binned_train, y_train = X_binned, y
148-
X_binned_val, y_val = None, None
140+
X_binned_val = None
149141

150142
if self.verbose:
151143
print("Fitting gradient boosted rounds:")
@@ -387,6 +379,32 @@ def _should_stop(self, scores):
387379
for score in recent_scores]
388380
return not any(recent_improvements)
389381

382+
def _bin_data(self, X, rng, is_training_data):
383+
"""Bin data X.
384+
385+
If is_training_data, then set the bin_mapper_ attribute.
386+
Else, the binned data is converted to a C-contiguous array.
387+
"""
388+
389+
description = 'training' if is_training_data else 'validation'
390+
if self.verbose:
391+
print("Binning {:.3f} GB of {} data: ".format(
392+
X.nbytes / 1e9, description), end="", flush=True)
393+
tic = time()
394+
if is_training_data:
395+
X_binned = self.bin_mapper_.fit_transform(X) # F-aligned array
396+
else:
397+
X_binned = self.bin_mapper_.transform(X) # F-aligned array
398+
# We convert the array to C-contiguous since predicting is faster
399+
# with this layout (training is faster on F-arrays though)
400+
X_binned = np.ascontiguousarray(X_binned)
401+
toc = time()
402+
if self.verbose:
403+
duration = toc - tic
404+
print("{:.3f} s".format(duration))
405+
406+
return X_binned
407+
390408
def _print_iteration_stats(self, iteration_start_time):
391409
"""Print info about the current fitting iteration."""
392410
log_msg = ''

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sklearn.experimental import enable_hist_gradient_boosting # noqa
77
from sklearn.ensemble import HistGradientBoostingRegressor
88
from sklearn.ensemble import HistGradientBoostingClassifier
9+
from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
910

1011

1112
X_classification, y_classification = make_classification(random_state=0)
@@ -145,3 +146,29 @@ def test_should_stop(scores, n_iter_no_change, tol, stopping):
145146
n_iter_no_change=n_iter_no_change, tol=tol
146147
)
147148
assert gbdt._should_stop(scores) == stopping
149+
150+
151+
def test_binning_train_validation_are_separated():
152+
# Make sure training and validation data are binned separately.
153+
# See issue 13926
154+
155+
rng = np.random.RandomState(0)
156+
validation_fraction = .2
157+
gb = HistGradientBoostingClassifier(
158+
n_iter_no_change=5,
159+
validation_fraction=validation_fraction,
160+
random_state=rng
161+
)
162+
gb.fit(X_classification, y_classification)
163+
mapper_training_data = gb.bin_mapper_
164+
165+
# Note that since the data is small there is no subsampling and the
166+
# random_state doesn't matter
167+
mapper_whole_data = _BinMapper(random_state=0)
168+
mapper_whole_data.fit(X_classification)
169+
170+
n_samples = X_classification.shape[0]
171+
assert np.all(mapper_training_data.actual_n_bins_ ==
172+
int((1 - validation_fraction) * n_samples))
173+
assert np.all(mapper_training_data.actual_n_bins_ !=
174+
mapper_whole_data.actual_n_bins_)

0 commit comments

Comments (0)