55import sys
66import copy
77import warnings
8+ import pytest
89
910import numpy as np
1011
@@ -764,8 +765,10 @@ def test_gaussian_mixture_verbose():
764765 sys .stdout = old_stdout
765766
766767
767- def test_warm_start ():
768- random_state = 0
768+ @pytest .mark .filterwarnings ('ignore:.*did not converge.*' )
769+ @pytest .mark .parametrize ("seed" , (0 , 1 , 2 ))
770+ def test_warm_start (seed ):
771+ random_state = seed
769772 rng = np .random .RandomState (random_state )
770773 n_samples , n_features , n_components = 500 , 2 , 2
771774 X = rng .rand (n_samples , n_features )
@@ -778,16 +781,14 @@ def test_warm_start():
778781 reg_covar = 0 , random_state = random_state ,
779782 warm_start = True )
780783
781- with warnings .catch_warnings ():
782- warnings .simplefilter ("ignore" , ConvergenceWarning )
783- g .fit (X )
784- score1 = h .fit (X ).score (X )
785- score2 = h .fit (X ).score (X )
784+ g .fit (X )
785+ score1 = h .fit (X ).score (X )
786+ score2 = h .fit (X ).score (X )
786787
787788 assert_almost_equal (g .weights_ , h .weights_ )
788789 assert_almost_equal (g .means_ , h .means_ )
789790 assert_almost_equal (g .precisions_ , h .precisions_ )
790- assert_greater ( score2 , score1 )
791+ assert score2 > score1
791792
792793 # Assert that by using warm_start we can converge to a good solution
793794 g = GaussianMixture (n_components = n_components , n_init = 1 ,
@@ -797,13 +798,18 @@ def test_warm_start():
797798 max_iter = 5 , reg_covar = 0 , random_state = random_state ,
798799 warm_start = True , tol = 1e-6 )
799800
800- with warnings .catch_warnings ():
801- warnings .simplefilter ("ignore" , ConvergenceWarning )
802- g .fit (X )
803- h .fit (X ).fit (X )
804-
805- assert_true (not g .converged_ )
806- assert_true (h .converged_ )
801+ g .fit (X )
802+ assert not g .converged_
803+
804+ h .fit (X )
805+ # depending on the data there is large variability in the number of
806+ # refit necessary to converge due to the complete randomness of the
807+ # data
808+ for _ in range (1000 ):
809+ h .fit (X )
810+ if h .converged_ :
811+ break
812+ assert h .converged_
807813
808814
809815@ignore_warnings (category = ConvergenceWarning )
0 commit comments