Skip to content

Commit 79ec933

Browse files
committed
BUG: ShuffleSplit should give reproducible splits
when seeded with a controlled seed.
1 parent a734166 commit 79ec933

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

sklearn/cross_validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ def __init__(self, n, n_bootstraps=3, n_train=0.5, n_test=None,
632632
self.random_state = random_state
633633

634634
def __iter__(self):
635-
rng = self.random_state = check_random_state(self.random_state)
635+
rng = check_random_state(self.random_state)
636636
for i in range(self.n_bootstraps):
637637
# random partition
638638
permutation = rng.permutation(self.n)
@@ -743,7 +743,7 @@ def __init__(self, n, n_iterations=10, test_fraction=0.1,
743743
(train_fraction, test_fraction))
744744

745745
def __iter__(self):
746-
rng = self.random_state = check_random_state(self.random_state)
746+
rng = check_random_state(self.random_state)
747747
n_test = ceil(self.test_fraction * self.n)
748748
if self.train_fraction is None:
749749
n_train = self.n - n_test

sklearn/tests/test_cross_validation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from nose.tools import assert_true
77
from nose.tools import assert_raises
8+
from nose.tools import assert_equal
89

910
from ..base import BaseEstimator
1011
from ..datasets import make_regression
@@ -230,6 +231,13 @@ def test_shufflesplit_errors():
230231
test_fraction=0.1, train_fraction=0.95)
231232

232233

234+
def test_shufflesplit_reproducible():
235+
# Check that iterating twice on the ShuffleSplit gives the same
236+
# sequence of train-test when the random_state is given
237+
ss = cross_validation.ShuffleSplit(10, random_state=21)
238+
assert_array_equal(list(a for a, b in ss), list(a for a, b in ss))
239+
240+
233241
def test_cross_indices_exception():
234242
X = coo_matrix(np.array([[1, 2], [3, 4], [5, 6], [7, 8]]))
235243
y = np.array([1, 1, 2, 2])

0 commit comments

Comments
 (0)