
Commit e50fc2a

jeremiedbb authored and TomDLT committed
TST Fix test gaussian mixture warm start (scikit-learn#12452)
1 parent 4e2da4a commit e50fc2a

1 file changed

sklearn/mixture/tests/test_gaussian_mixture.py

Lines changed: 21 additions & 15 deletions
@@ -5,6 +5,7 @@
 import sys
 import copy
 import warnings
+import pytest
 
 import numpy as np
 
@@ -764,8 +765,10 @@ def test_gaussian_mixture_verbose():
     sys.stdout = old_stdout
 
 
-def test_warm_start():
-    random_state = 0
+@pytest.mark.filterwarnings('ignore:.*did not converge.*')
+@pytest.mark.parametrize("seed", (0, 1, 2))
+def test_warm_start(seed):
+    random_state = seed
     rng = np.random.RandomState(random_state)
     n_samples, n_features, n_components = 500, 2, 2
     X = rng.rand(n_samples, n_features)
@@ -778,16 +781,14 @@ def test_warm_start():
                         reg_covar=0, random_state=random_state,
                         warm_start=True)
 
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore", ConvergenceWarning)
-        g.fit(X)
-        score1 = h.fit(X).score(X)
-        score2 = h.fit(X).score(X)
+    g.fit(X)
+    score1 = h.fit(X).score(X)
+    score2 = h.fit(X).score(X)
 
     assert_almost_equal(g.weights_, h.weights_)
     assert_almost_equal(g.means_, h.means_)
     assert_almost_equal(g.precisions_, h.precisions_)
-    assert_greater(score2, score1)
+    assert score2 > score1
 
     # Assert that by using warm_start we can converge to a good solution
     g = GaussianMixture(n_components=n_components, n_init=1,
@@ -797,13 +798,18 @@ def test_warm_start():
                         max_iter=5, reg_covar=0, random_state=random_state,
                         warm_start=True, tol=1e-6)
 
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore", ConvergenceWarning)
-        g.fit(X)
-        h.fit(X).fit(X)
-
-    assert_true(not g.converged_)
-    assert_true(h.converged_)
+    g.fit(X)
+    assert not g.converged_
+
+    h.fit(X)
+    # depending on the data there is large variability in the number of
+    # refit necessary to converge due to the complete randomness of the
+    # data
+    for _ in range(1000):
+        h.fit(X)
+        if h.converged_:
+            break
+    assert h.converged_
 
 
 @ignore_warnings(category=ConvergenceWarning)
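
For context on what the test exercises: with warm_start=True, every call to fit resumes EM from the parameters found by the previous fit instead of re-initializing, so a model capped at a few iterations per call can still converge over repeated fits. The following is a minimal sketch of that behaviour; the data shape, component count, and refit budget are illustrative choices, not values taken from the test above.

import warnings

import numpy as np

from sklearn.exceptions import ConvergenceWarning
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
X = rng.rand(300, 2)  # illustrative uniform random data

# One EM iteration per call to fit; warm_start=True makes each fit resume
# from the parameters found by the previous one.
gm = GaussianMixture(n_components=2, max_iter=1, warm_start=True,
                     random_state=0)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", ConvergenceWarning)
    scores = []
    for _ in range(200):
        gm.fit(X)  # continues from the previous solution
        scores.append(gm.score(X))
        if gm.converged_:
            break

# EM never decreases the likelihood, so resuming should not hurt the score.
print("converged:", gm.converged_, "first/last score:", scores[0], scores[-1])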

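The cleanup itself trades the per-test warnings.catch_warnings() blocks for pytest markers. Below is a standalone sketch of the two markers used above, with a toy function standing in for the estimator; the function name and warning message are made up for illustration.

import warnings

import pytest


def toy_fit(seed):
    # Stand-in for an estimator whose fit may emit a convergence warning.
    warnings.warn("toy_fit did not converge. Try more iterations.",
                  UserWarning)
    return seed


# filterwarnings ignores, for this test only, any warning whose message
# matches the regex; parametrize runs the test once per listed seed.
@pytest.mark.filterwarnings('ignore:.*did not converge.*')
@pytest.mark.parametrize("seed", (0, 1, 2))
def test_toy_fit(seed):
    assert toy_fit(seed) == seed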