Skip to content

Commit d77ea12

Browse files
committed
Merge pull request scikit-learn#4707 from amueller/k_means_init_mismatch
[MGR] Raise error when init shape doesn't match n_clusters in KMeans
2 parents bdc39e7 + ec1bba7 commit d77ea12

File tree

2 files changed

+46
-16
lines changed

2 files changed

+46
-16
lines changed

sklearn/cluster/k_means_.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,18 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
140140
###############################################################################
141141
# K-means batch estimation by EM (expectation maximization)
142142

143+
def _validate_center_shape(X, n_centers, centers):
144+
"""Check if centers is compatible with X and n_centers"""
145+
if len(centers) != n_centers:
146+
raise ValueError('The shape of the initial centers (%s) '
147+
'does not match the number of clusters %i'
148+
% (centers.shape, n_centers))
149+
if centers.shape[1] != X.shape[1]:
150+
raise ValueError(
151+
"The number of features of the initial centers %s "
152+
"does not match the number of features of the data %s."
153+
% (centers.shape[1], X.shape[1]))
154+
143155

144156
def _tolerance(X, tol):
145157
"""Return a tolerance which is independent of the dataset"""
@@ -285,7 +297,9 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto',
285297
X -= X_mean
286298

287299
if hasattr(init, '__array__'):
288-
init = np.asarray(init).copy()
300+
init = check_array(init, dtype=np.float64, copy=True)
301+
_validate_center_shape(X, n_clusters, init)
302+
289303
init -= X_mean
290304
if n_init != 1:
291305
warnings.warn(
@@ -638,11 +652,7 @@ def _init_centroids(X, k, init, random_state=None, x_squared_norms=None,
638652
if sp.issparse(centers):
639653
centers = centers.toarray()
640654

641-
if len(centers) != k:
642-
raise ValueError('The shape of the initial centers (%s) '
643-
'does not match the number of clusters %i'
644-
% (centers.shape, k))
645-
655+
_validate_center_shape(X, k, centers)
646656
return centers
647657

648658

@@ -759,10 +769,6 @@ def __init__(self, n_clusters=8, init='k-means++', n_init=10, max_iter=300,
759769
tol=1e-4, precompute_distances='auto',
760770
verbose=0, random_state=None, copy_x=True, n_jobs=1):
761771

762-
if hasattr(init, '__array__'):
763-
n_clusters = init.shape[0]
764-
init = np.asarray(init, dtype=np.float64)
765-
766772
self.n_clusters = n_clusters
767773
self.init = init
768774
self.max_iter = max_iter

sklearn/cluster/tests/test_k_means.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sklearn.utils.testing import SkipTest
1111
from sklearn.utils.testing import assert_almost_equal
1212
from sklearn.utils.testing import assert_raises
13-
from sklearn.utils.testing import assert_raises_regexp
13+
from sklearn.utils.testing import assert_raises_regex
1414
from sklearn.utils.testing import assert_true
1515
from sklearn.utils.testing import assert_greater
1616
from sklearn.utils.testing import assert_less
@@ -257,8 +257,30 @@ def test_k_means_n_init():
257257

258258
# two regression tests on bad n_init argument
259259
# previous bug: n_init <= 0 threw non-informative TypeError (#3858)
260-
assert_raises_regexp(ValueError, "n_init", KMeans(n_init=0).fit, X)
261-
assert_raises_regexp(ValueError, "n_init", KMeans(n_init=-1).fit, X)
260+
assert_raises_regex(ValueError, "n_init", KMeans(n_init=0).fit, X)
261+
assert_raises_regex(ValueError, "n_init", KMeans(n_init=-1).fit, X)
262+
263+
264+
def test_k_means_explicit_init_shape():
265+
# test for sensible errors when giving explicit init
266+
# with wrong number of features or clusters
267+
rnd = np.random.RandomState(0)
268+
X = rnd.normal(size=(40, 3))
269+
for Class in [KMeans, MiniBatchKMeans]:
270+
# mismatch of number of features
271+
km = Class(n_init=1, init=X[:, :2], n_clusters=len(X))
272+
msg = "does not match the number of features of the data"
273+
assert_raises_regex(ValueError, msg, km.fit, X)
274+
# for callable init
275+
km = Class(n_init=1, init=lambda X_, k, random_state: X_[:, :2], n_clusters=len(X))
276+
assert_raises_regex(ValueError, msg, km.fit, X)
277+
# mismatch of number of clusters
278+
msg = "does not match the number of clusters"
279+
km = Class(n_init=1, init=X[:2, :], n_clusters=3)
280+
assert_raises_regex(ValueError, msg, km.fit, X)
281+
# for callable init
282+
km = Class(n_init=1, init=lambda X_, k, random_state: X_[:2, :], n_clusters=3)
283+
assert_raises_regex(ValueError, msg, km.fit, X)
262284

263285

264286
def test_k_means_fortran_aligned_data():
@@ -267,7 +289,7 @@ def test_k_means_fortran_aligned_data():
267289
centers = np.array([[0, 0], [0, 1]])
268290
labels = np.array([0, 1, 1])
269291
km = KMeans(n_init=1, init=centers, precompute_distances=False,
270-
random_state=42)
292+
random_state=42, n_clusters=2)
271293
km.fit(X)
272294
assert_array_equal(km.cluster_centers_, centers)
273295
assert_array_equal(km.labels_, labels)
@@ -437,8 +459,10 @@ def test_init(X, k, random_state):
437459

438460
# Small test to check that giving the wrong number of centers
439461
# raises a meaningful error
440-
assert_raises(ValueError,
441-
MiniBatchKMeans(init=test_init, random_state=42).fit, X_csr)
462+
msg = "does not match the number of clusters"
463+
assert_raises_regex(ValueError, msg, MiniBatchKMeans(init=test_init,
464+
random_state=42).fit,
465+
X_csr)
442466

443467
# Now check that the fit actually works
444468
mb_k_means = MiniBatchKMeans(n_clusters=3, init=test_init,

0 commit comments

Comments
 (0)