Skip to content

Commit db48ebc

Browse files
fdas3213jnothman
authored andcommitted
ENH add n_components kwarg to SpectralClustering. See scikit-learn#13698 (scikit-learn#13726)
1 parent f3a6a1a commit db48ebc

File tree

3 files changed

+42
-8
lines changed

3 files changed

+42
-8
lines changed

doc/whats_new/v0.22.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ Changelog
5757
``decision_function_shape='ovr'``, and the number of target classes > 2.
5858
:pr:`12557` by `Adrin Jalali`_.
5959

60+
61+
:mod:`sklearn.cluster`
62+
..................
63+
64+
- |Enhancement| :class:`cluster.SpectralClustering` now accepts a ``n_components``
65+
parameter. This parameter extends `SpectralClustering` class functionality to
66+
match `spectral_clustering`.
67+
:pr:`13726` by :user:`Shuzhe Xiao <fdas3213>`.
68+
69+
6070
Miscellaneous
6171
.............
6272

sklearn/cluster/spectral.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,9 @@ class SpectralClustering(BaseEstimator, ClusterMixin):
307307
to be installed. It can be faster on very large, sparse problems,
308308
but may also lead to instabilities.
309309
310+
n_components : integer, optional, default=n_clusters
311+
Number of eigen vectors to use for the spectral embedding
312+
310313
random_state : int, RandomState instance or None (default)
311314
A pseudo random number generator used for the initialization of the
312315
lobpcg eigen vectors decomposition when ``eigen_solver='amg'`` and by
@@ -387,8 +390,8 @@ class SpectralClustering(BaseEstimator, ClusterMixin):
387390
>>> clustering # doctest: +NORMALIZE_WHITESPACE
388391
SpectralClustering(affinity='rbf', assign_labels='discretize', coef0=1,
389392
degree=3, eigen_solver=None, eigen_tol=0.0, gamma=1.0,
390-
kernel_params=None, n_clusters=2, n_init=10, n_jobs=None,
391-
n_neighbors=10, random_state=0)
393+
kernel_params=None, n_clusters=2, n_components=None, n_init=10,
394+
n_jobs=None, n_neighbors=10, random_state=0)
392395
393396
Notes
394397
-----
@@ -425,12 +428,13 @@ class SpectralClustering(BaseEstimator, ClusterMixin):
425428
https://www1.icsi.berkeley.edu/~stellayu/publication/doc/2003kwayICCV.pdf
426429
"""
427430

428-
def __init__(self, n_clusters=8, eigen_solver=None, random_state=None,
429-
n_init=10, gamma=1., affinity='rbf', n_neighbors=10,
430-
eigen_tol=0.0, assign_labels='kmeans', degree=3, coef0=1,
431-
kernel_params=None, n_jobs=None):
431+
def __init__(self, n_clusters=8, eigen_solver=None, n_components=None,
432+
random_state=None, n_init=10, gamma=1., affinity='rbf',
433+
n_neighbors=10, eigen_tol=0.0, assign_labels='kmeans',
434+
degree=3, coef0=1, kernel_params=None, n_jobs=None):
432435
self.n_clusters = n_clusters
433436
self.eigen_solver = eigen_solver
437+
self.n_components = n_components
434438
self.random_state = random_state
435439
self.n_init = n_init
436440
self.gamma = gamma
@@ -486,6 +490,7 @@ def fit(self, X, y=None):
486490
random_state = check_random_state(self.random_state)
487491
self.labels_ = spectral_clustering(self.affinity_matrix_,
488492
n_clusters=self.n_clusters,
493+
n_components=self.n_components,
489494
eigen_solver=self.eigen_solver,
490495
random_state=random_state,
491496
n_init=self.n_init,

sklearn/cluster/tests/test_spectral.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ def test_affinities():
107107
# a dataset that yields a stable eigen decomposition both when built
108108
# on OSX and Linux
109109
X, y = make_blobs(n_samples=20, random_state=0,
110-
centers=[[1, 1], [-1, -1]], cluster_std=0.01
111-
)
110+
centers=[[1, 1], [-1, -1]], cluster_std=0.01)
112111
# nearest neighbors affinity
113112
sp = SpectralClustering(n_clusters=2, affinity='nearest_neighbors',
114113
random_state=0)
@@ -204,3 +203,23 @@ def test_spectral_clustering_with_arpack_amg_solvers():
204203
assert_raises(
205204
ValueError, spectral_clustering,
206205
graph, n_clusters=2, eigen_solver='amg', random_state=0)
206+
207+
208+
def test_n_components():
209+
# Test that after adding n_components, result is different and
210+
# n_components = n_clusters by default
211+
X, y = make_blobs(n_samples=20, random_state=0,
212+
centers=[[1, 1], [-1, -1]], cluster_std=0.01)
213+
sp = SpectralClustering(n_clusters=2, random_state=0)
214+
labels = sp.fit(X).labels_
215+
# set n_components = n_cluster and test if result is the same
216+
labels_same_ncomp = SpectralClustering(n_clusters=2, n_components=2,
217+
random_state=0).fit(X).labels_
218+
# test that n_components=n_clusters by default
219+
assert_array_equal(labels, labels_same_ncomp)
220+
221+
# test that n_components affect result
222+
# n_clusters=8 by default, and set n_components=2
223+
labels_diff_ncomp = SpectralClustering(n_components=2,
224+
random_state=0).fit(X).labels_
225+
assert not np.array_equal(labels, labels_diff_ncomp)

0 commit comments

Comments
 (0)