Skip to content

Commit 6675c9e

Browse files
NicolasHugglemaitre
authored andcommitted
MAINT pass n_samples instead of sample_indices in GBDT (scikit-learn#14017)
1 parent ccd3331 commit 6675c9e

File tree

3 files changed

+10
-12
lines changed

3 files changed

+10
-12
lines changed

sklearn/ensemble/_hist_gradient_boosting/grower.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def _compute_best_split_and_push(self, node):
276276
"""
277277

278278
node.split_info = self.splitter.find_node_split(
279-
node.sample_indices, node.histograms, node.sum_gradients,
279+
node.n_samples, node.histograms, node.sum_gradients,
280280
node.sum_hessians)
281281

282282
if node.split_info.gain <= 0: # no valid split

sklearn/ensemble/_hist_gradient_boosting/splitting.pyx

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ cdef class Splitter:
319319

320320
def find_node_split(
321321
Splitter self,
322-
const unsigned int [::1] sample_indices, # IN
322+
unsigned int n_samples,
323323
hist_struct [:, ::1] histograms, # IN
324324
const Y_DTYPE_C sum_gradients,
325325
const Y_DTYPE_C sum_hessians):
@@ -329,8 +329,8 @@ cdef class Splitter:
329329
330330
Parameters
331331
----------
332-
sample_indices : ndarray of unsigned int, shape (n_samples_at_node,)
333-
The indices of the samples at the node to split.
332+
n_samples : int
333+
The number of samples at the node.
334334
histograms : ndarray of HISTOGRAM_DTYPE of \
335335
shape (n_features, max_bins)
336336
The histograms of the current node.
@@ -345,15 +345,13 @@ cdef class Splitter:
345345
The info about the best possible split among all features.
346346
"""
347347
cdef:
348-
int n_samples
349348
int feature_idx
350349
int best_feature_idx
351350
int n_features = self.n_features
352351
split_info_struct split_info
353352
split_info_struct * split_infos
354353

355354
with nogil:
356-
n_samples = sample_indices.shape[0]
357355

358356
split_infos = <split_info_struct *> malloc(
359357
self.n_features * sizeof(split_info_struct))

sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_histogram_split(n_bins):
4949

5050
histograms = builder.compute_histograms_brute(sample_indices)
5151
split_info = splitter.find_node_split(
52-
sample_indices, histograms, sum_gradients,
52+
sample_indices.shape[0], histograms, sum_gradients,
5353
sum_hessians)
5454

5555
assert split_info.bin_idx == true_bin
@@ -103,17 +103,17 @@ def test_gradient_and_hessian_sanity(constant_hessian):
103103
min_samples_leaf, min_gain_to_split, constant_hessian)
104104

105105
hists_parent = builder.compute_histograms_brute(sample_indices)
106-
si_parent = splitter.find_node_split(sample_indices, hists_parent,
106+
si_parent = splitter.find_node_split(n_samples, hists_parent,
107107
sum_gradients, sum_hessians)
108108
sample_indices_left, sample_indices_right, _ = splitter.split_indices(
109109
si_parent, sample_indices)
110110

111111
hists_left = builder.compute_histograms_brute(sample_indices_left)
112112
hists_right = builder.compute_histograms_brute(sample_indices_right)
113-
si_left = splitter.find_node_split(sample_indices_left, hists_left,
113+
si_left = splitter.find_node_split(n_samples, hists_left,
114114
si_parent.sum_gradient_left,
115115
si_parent.sum_hessian_left)
116-
si_right = splitter.find_node_split(sample_indices_right, hists_right,
116+
si_right = splitter.find_node_split(n_samples, hists_right,
117117
si_parent.sum_gradient_right,
118118
si_parent.sum_hessian_right)
119119

@@ -203,7 +203,7 @@ def test_split_indices():
203203
assert np.all(sample_indices == splitter.partition)
204204

205205
histograms = builder.compute_histograms_brute(sample_indices)
206-
si_root = splitter.find_node_split(sample_indices, histograms,
206+
si_root = splitter.find_node_split(n_samples, histograms,
207207
sum_gradients, sum_hessians)
208208

209209
# sanity checks for best split
@@ -256,6 +256,6 @@ def test_min_gain_to_split():
256256
hessians_are_constant)
257257

258258
histograms = builder.compute_histograms_brute(sample_indices)
259-
split_info = splitter.find_node_split(sample_indices, histograms,
259+
split_info = splitter.find_node_split(n_samples, histograms,
260260
sum_gradients, sum_hessians)
261261
assert split_info.gain == -1

0 commit comments

Comments
 (0)