@@ -49,7 +49,7 @@ def test_histogram_split(n_bins):
4949
5050 histograms = builder .compute_histograms_brute (sample_indices )
5151 split_info = splitter .find_node_split (
52- sample_indices , histograms , sum_gradients ,
52+ sample_indices . shape [ 0 ] , histograms , sum_gradients ,
5353 sum_hessians )
5454
5555 assert split_info .bin_idx == true_bin
@@ -103,17 +103,17 @@ def test_gradient_and_hessian_sanity(constant_hessian):
103103 min_samples_leaf , min_gain_to_split , constant_hessian )
104104
105105 hists_parent = builder .compute_histograms_brute (sample_indices )
106- si_parent = splitter .find_node_split (sample_indices , hists_parent ,
106+ si_parent = splitter .find_node_split (n_samples , hists_parent ,
107107 sum_gradients , sum_hessians )
108108 sample_indices_left , sample_indices_right , _ = splitter .split_indices (
109109 si_parent , sample_indices )
110110
111111 hists_left = builder .compute_histograms_brute (sample_indices_left )
112112 hists_right = builder .compute_histograms_brute (sample_indices_right )
113- si_left = splitter .find_node_split (sample_indices_left , hists_left ,
113+ si_left = splitter .find_node_split (n_samples , hists_left ,
114114 si_parent .sum_gradient_left ,
115115 si_parent .sum_hessian_left )
116- si_right = splitter .find_node_split (sample_indices_right , hists_right ,
116+ si_right = splitter .find_node_split (n_samples , hists_right ,
117117 si_parent .sum_gradient_right ,
118118 si_parent .sum_hessian_right )
119119
@@ -203,7 +203,7 @@ def test_split_indices():
203203 assert np .all (sample_indices == splitter .partition )
204204
205205 histograms = builder .compute_histograms_brute (sample_indices )
206- si_root = splitter .find_node_split (sample_indices , histograms ,
206+ si_root = splitter .find_node_split (n_samples , histograms ,
207207 sum_gradients , sum_hessians )
208208
209209 # sanity checks for best split
@@ -256,6 +256,6 @@ def test_min_gain_to_split():
256256 hessians_are_constant )
257257
258258 histograms = builder .compute_histograms_brute (sample_indices )
259- split_info = splitter .find_node_split (sample_indices , histograms ,
259+ split_info = splitter .find_node_split (n_samples , histograms ,
260260 sum_gradients , sum_hessians )
261261 assert split_info .gain == - 1
0 commit comments