1818from sklearn .decomposition import DictionaryLearning
1919from sklearn .decomposition import MiniBatchDictionaryLearning
2020from sklearn .decomposition import SparseCoder
21+ from sklearn .decomposition import dict_learning
2122from sklearn .decomposition import dict_learning_online
2223from sklearn .decomposition import sparse_encode
2324
@@ -56,29 +57,30 @@ def test_dict_learning_overcomplete():
5657 assert dico .components_ .shape == (n_components , n_features )
5758
5859
59- # positive lars deprecated 0.22
60- @pytest .mark .filterwarnings ('ignore::DeprecationWarning' )
60+ def test_dict_learning_lars_positive_parameter ():
61+ n_components = 5
62+ alpha = 1
63+ err_msg = "Positive constraint not supported for 'lars' coding method."
64+ with pytest .raises (ValueError , match = err_msg ):
65+ dict_learning (X , n_components , alpha , positive_code = True )
66+
67+
6168@pytest .mark .parametrize ("transform_algorithm" , [
6269 "lasso_lars" ,
6370 "lasso_cd" ,
64- "lars" ,
6571 "threshold" ,
6672])
67- @pytest .mark .parametrize ("positive_code" , [
68- False ,
69- True ,
70- ])
71- @pytest .mark .parametrize ("positive_dict" , [
72- False ,
73- True ,
74- ])
73+ @pytest .mark .parametrize ("positive_code" , [False , True ])
74+ @pytest .mark .parametrize ("positive_dict" , [False , True ])
7575def test_dict_learning_positivity (transform_algorithm ,
7676 positive_code ,
7777 positive_dict ):
7878 n_components = 5
7979 dico = DictionaryLearning (
8080 n_components , transform_algorithm = transform_algorithm , random_state = 0 ,
81- positive_code = positive_code , positive_dict = positive_dict ).fit (X )
81+ positive_code = positive_code , positive_dict = positive_dict ,
82+ fit_algorithm = "cd" ).fit (X )
83+
8284 code = dico .transform (X )
8385 if positive_dict :
8486 assert (dico .components_ >= 0 ).all ()
@@ -90,6 +92,31 @@ def test_dict_learning_positivity(transform_algorithm,
9092 assert (code < 0 ).any ()
9193
9294
95+ @pytest .mark .parametrize ("positive_dict" , [False , True ])
96+ def test_dict_learning_lars_dict_positivity (positive_dict ):
97+ n_components = 5
98+ dico = DictionaryLearning (
99+ n_components , transform_algorithm = "lars" , random_state = 0 ,
100+ positive_dict = positive_dict , fit_algorithm = "cd" ).fit (X )
101+
102+ if positive_dict :
103+ assert (dico .components_ >= 0 ).all ()
104+ else :
105+ assert (dico .components_ < 0 ).any ()
106+
107+
108+ def test_dict_learning_lars_code_positivity ():
109+ n_components = 5
110+ dico = DictionaryLearning (
111+ n_components , transform_algorithm = "lars" , random_state = 0 ,
112+ positive_code = True , fit_algorithm = "cd" ).fit (X )
113+
114+ err_msg = "Positive constraint not supported for '{}' coding method."
115+ err_msg = err_msg .format ("lars" )
116+ with pytest .raises (ValueError , match = err_msg ):
117+ dico .transform (X )
118+
119+
93120def test_dict_learning_reconstruction ():
94121 n_components = 12
95122 dico = DictionaryLearning (n_components , transform_algorithm = 'omp' ,
@@ -170,31 +197,29 @@ def test_dict_learning_online_shapes():
170197 assert_equal (np .dot (code , dictionary ).shape , X .shape )
171198
172199
173- # positive lars deprecated 0.22
174- @pytest .mark .filterwarnings ('ignore::DeprecationWarning' )
200+ def test_dict_learning_online_lars_positive_parameter ():
201+ alpha = 1
202+ err_msg = "Positive constraint not supported for 'lars' coding method."
203+ with pytest .raises (ValueError , match = err_msg ):
204+ dict_learning_online (X , alpha , positive_code = True )
205+
206+
175207@pytest .mark .parametrize ("transform_algorithm" , [
176208 "lasso_lars" ,
177209 "lasso_cd" ,
178- "lars" ,
179210 "threshold" ,
180211])
181- @pytest .mark .parametrize ("positive_code" , [
182- False ,
183- True ,
184- ])
185- @pytest .mark .parametrize ("positive_dict" , [
186- False ,
187- True ,
188- ])
189- def test_dict_learning_online_positivity (transform_algorithm ,
190- positive_code ,
191- positive_dict ):
192- rng = np .random .RandomState (0 )
212+ @pytest .mark .parametrize ("positive_code" , [False , True ])
213+ @pytest .mark .parametrize ("positive_dict" , [False , True ])
214+ def test_minibatch_dictionary_learning_positivity (transform_algorithm ,
215+ positive_code ,
216+ positive_dict ):
193217 n_components = 8
194-
195218 dico = MiniBatchDictionaryLearning (
196219 n_components , transform_algorithm = transform_algorithm , random_state = 0 ,
197- positive_code = positive_code , positive_dict = positive_dict ).fit (X )
220+ positive_code = positive_code , positive_dict = positive_dict ,
221+ fit_algorithm = 'cd' ).fit (X )
222+
198223 code = dico .transform (X )
199224 if positive_dict :
200225 assert (dico .components_ >= 0 ).all ()
@@ -205,7 +230,30 @@ def test_dict_learning_online_positivity(transform_algorithm,
205230 else :
206231 assert (code < 0 ).any ()
207232
233+
234+ @pytest .mark .parametrize ("positive_dict" , [False , True ])
235+ def test_minibatch_dictionary_learning_lars (positive_dict ):
236+ n_components = 8
237+
238+ dico = MiniBatchDictionaryLearning (
239+ n_components , transform_algorithm = "lars" , random_state = 0 ,
240+ positive_dict = positive_dict , fit_algorithm = 'cd' ).fit (X )
241+
242+ if positive_dict :
243+ assert (dico .components_ >= 0 ).all ()
244+ else :
245+ assert (dico .components_ < 0 ).any ()
246+
247+
248+ @pytest .mark .parametrize ("positive_code" , [False , True ])
249+ @pytest .mark .parametrize ("positive_dict" , [False , True ])
250+ def test_dict_learning_online_positivity (positive_code ,
251+ positive_dict ):
252+ rng = np .random .RandomState (0 )
253+ n_components = 8
254+
208255 code , dictionary = dict_learning_online (X , n_components = n_components ,
256+ method = "cd" ,
209257 alpha = 1 , random_state = rng ,
210258 positive_dict = positive_dict ,
211259 positive_code = positive_code )
@@ -307,29 +355,34 @@ def test_sparse_encode_shapes():
307355 assert_equal (code .shape , (n_samples , n_components ))
308356
309357
310- # positive lars deprecated 0.22
311- @pytest .mark .filterwarnings ('ignore::DeprecationWarning' )
312- @pytest .mark .parametrize ("positive" , [
313- False ,
314- True ,
358+ @pytest .mark .parametrize ("algo" , [
359+ 'lasso_lars' ,
360+ 'lasso_cd' ,
361+ 'threshold'
315362])
316- def test_sparse_encode_positivity (positive ):
363+ @pytest .mark .parametrize ("positive" , [False , True ])
364+ def test_sparse_encode_positivity (algo , positive ):
317365 n_components = 12
318366 rng = np .random .RandomState (0 )
319367 V = rng .randn (n_components , n_features ) # random init
320368 V /= np .sum (V ** 2 , axis = 1 )[:, np .newaxis ]
321- for algo in ('lasso_lars' , 'lasso_cd' , 'lars' , 'threshold' ):
322- code = sparse_encode (X , V , algorithm = algo , positive = positive )
323- if positive :
324- assert (code >= 0 ).all ()
325- else :
326- assert (code < 0 ).any ()
369+ code = sparse_encode (X , V , algorithm = algo , positive = positive )
370+ if positive :
371+ assert (code >= 0 ).all ()
372+ else :
373+ assert (code < 0 ).any ()
327374
328- try :
329- sparse_encode (X , V , algorithm = 'omp' , positive = positive )
330- except ValueError :
331- if not positive :
332- raise
375+
376+ @pytest .mark .parametrize ("algo" , ['lars' , 'omp' ])
377+ def test_sparse_encode_unavailable_positivity (algo ):
378+ n_components = 12
379+ rng = np .random .RandomState (0 )
380+ V = rng .randn (n_components , n_features ) # random init
381+ V /= np .sum (V ** 2 , axis = 1 )[:, np .newaxis ]
382+ err_msg = "Positive constraint not supported for '{}' coding method."
383+ err_msg = err_msg .format (algo )
384+ with pytest .raises (ValueError , match = err_msg ):
385+ sparse_encode (X , V , algorithm = algo , positive = True )
333386
334387
335388def test_sparse_encode_input ():
0 commit comments