20
20
21
21
22
22
@pytest .mark .parametrize ('solver' , ['cd' , 'mu' ])
23
- def test_convergence_warning (solver ):
23
+ @pytest .mark .parametrize ('regularization' ,
24
+ [None , 'both' , 'components' , 'transformation' ])
25
+ def test_convergence_warning (solver , regularization ):
24
26
convergence_warning = ("Maximum number of iterations 1 reached. "
25
27
"Increase it to improve convergence." )
26
28
A = np .ones ((2 , 2 ))
27
29
with pytest .warns (ConvergenceWarning , match = convergence_warning ):
28
- NMF (solver = solver , max_iter = 1 ).fit (A )
30
+ NMF (solver = solver , regularization = regularization , max_iter = 1 ).fit (A )
29
31
30
32
31
33
def test_initialize_nn_output ():
@@ -44,6 +46,8 @@ def test_parameter_checking():
44
46
assert_raise_message (ValueError , msg , NMF (solver = name ).fit , A )
45
47
msg = "Invalid init parameter: got 'spam' instead of one of"
46
48
assert_raise_message (ValueError , msg , NMF (init = name ).fit , A )
49
+ msg = "Invalid regularization parameter: got 'spam' instead of one of"
50
+ assert_raise_message (ValueError , msg , NMF (regularization = name ).fit , A )
47
51
msg = "Invalid beta_loss parameter: got 'spam' instead of one"
48
52
assert_raise_message (ValueError , msg , NMF (solver = 'mu' ,
49
53
beta_loss = name ).fit , A )
@@ -97,36 +101,43 @@ def test_initialize_variants():
97
101
98
102
# ignore UserWarning raised when both solver='mu' and init='nndsvd'
99
103
@ignore_warnings (category = UserWarning )
100
- def test_nmf_fit_nn_output ():
104
+ @pytest .mark .parametrize ('solver' , ('cd' , 'mu' ))
105
+ @pytest .mark .parametrize ('init' ,
106
+ (None , 'nndsvd' , 'nndsvda' , 'nndsvdar' , 'random' ))
107
+ @pytest .mark .parametrize ('regularization' ,
108
+ (None , 'both' , 'components' , 'transformation' ))
109
+ def test_nmf_fit_nn_output (solver , init , regularization ):
101
110
# Test that the decomposition does not contain negative values
102
111
A = np .c_ [5. - np .arange (1 , 6 ),
103
112
5. + np .arange (1 , 6 )]
104
- for solver in ('cd' , 'mu' ):
105
- for init in (None , 'nndsvd' , 'nndsvda' , 'nndsvdar' , 'random' ):
106
- model = NMF (n_components = 2 , solver = solver , init = init ,
107
- random_state = 0 )
108
- transf = model .fit_transform (A )
109
- assert not ((model .components_ < 0 ).any () or
110
- (transf < 0 ).any ())
113
+ model = NMF (n_components = 2 , solver = solver , init = init ,
114
+ regularization = regularization , random_state = 0 )
115
+ transf = model .fit_transform (A )
116
+ assert not ((model .components_ < 0 ).any () or
117
+ (transf < 0 ).any ())
111
118
112
119
113
120
@pytest .mark .parametrize ('solver' , ('cd' , 'mu' ))
114
- def test_nmf_fit_close (solver ):
121
+ @pytest .mark .parametrize ('regularization' ,
122
+ (None , 'both' , 'components' , 'transformation' ))
123
+ def test_nmf_fit_close (solver , regularization ):
115
124
rng = np .random .mtrand .RandomState (42 )
116
125
# Test that the fit is not too far away
117
126
pnmf = NMF (5 , solver = solver , init = 'nndsvdar' , random_state = 0 ,
118
- max_iter = 600 )
127
+ regularization = regularization , max_iter = 600 )
119
128
X = np .abs (rng .randn (6 , 5 ))
120
129
assert pnmf .fit (X ).reconstruction_err_ < 0.1
121
130
122
131
123
132
@pytest .mark .parametrize ('solver' , ('cd' , 'mu' ))
124
- def test_nmf_transform (solver ):
133
+ @pytest .mark .parametrize ('regularization' ,
134
+ (None , 'both' , 'components' , 'transformation' ))
135
+ def test_nmf_transform (solver , regularization ):
125
136
# Test that NMF.transform returns close values
126
137
rng = np .random .mtrand .RandomState (42 )
127
138
A = np .abs (rng .randn (6 , 5 ))
128
139
m = NMF (solver = solver , n_components = 3 , init = 'random' ,
129
- random_state = 0 , tol = 1e-5 )
140
+ regularization = regularization , random_state = 0 , tol = 1e-5 )
130
141
ft = m .fit_transform (A )
131
142
t = m .transform (A )
132
143
assert_array_almost_equal (ft , t , decimal = 2 )
@@ -148,12 +159,14 @@ def test_nmf_transform_custom_init():
148
159
149
160
150
161
@pytest .mark .parametrize ('solver' , ('cd' , 'mu' ))
151
- def test_nmf_inverse_transform (solver ):
162
+ @pytest .mark .parametrize ('regularization' ,
163
+ (None , 'both' , 'components' , 'transformation' ))
164
+ def test_nmf_inverse_transform (solver , regularization ):
152
165
# Test that NMF.inverse_transform returns close values
153
166
random_state = np .random .RandomState (0 )
154
167
A = np .abs (random_state .randn (6 , 4 ))
155
168
m = NMF (solver = solver , n_components = 4 , init = 'random' , random_state = 0 ,
156
- max_iter = 1000 )
169
+ regularization = regularization , max_iter = 1000 )
157
170
ft = m .fit_transform (A )
158
171
A_new = m .inverse_transform (ft )
159
172
assert_array_almost_equal (A , A_new , decimal = 2 )
@@ -167,7 +180,9 @@ def test_n_components_greater_n_features():
167
180
168
181
169
182
@pytest .mark .parametrize ('solver' , ['cd' , 'mu' ])
170
- def test_nmf_sparse_input (solver ):
183
+ @pytest .mark .parametrize ('regularization' ,
184
+ [None , 'both' , 'components' , 'transformation' ])
185
+ def test_nmf_sparse_input (solver , regularization ):
171
186
# Test that sparse matrices are accepted as input
172
187
from scipy .sparse import csc_matrix
173
188
@@ -177,7 +192,8 @@ def test_nmf_sparse_input(solver):
177
192
A_sparse = csc_matrix (A )
178
193
179
194
est1 = NMF (solver = solver , n_components = 5 , init = 'random' ,
180
- random_state = 0 , tol = 1e-2 )
195
+ regularization = regularization , random_state = 0 ,
196
+ tol = 1e-2 )
181
197
est2 = clone (est1 )
182
198
183
199
W1 = est1 .fit_transform (A )
@@ -204,28 +220,32 @@ def test_nmf_sparse_transform():
204
220
assert_array_almost_equal (A_fit_tr , A_tr , decimal = 1 )
205
221
206
222
207
- def test_non_negative_factorization_consistency ():
223
+ @pytest .mark .parametrize ('init' , ['random' , 'nndsvd' ])
224
+ @pytest .mark .parametrize ('solver' , ('cd' , 'mu' ))
225
+ @pytest .mark .parametrize ('regularization' ,
226
+ (None , 'both' , 'components' , 'transformation' ))
227
+ def test_non_negative_factorization_consistency (init , solver , regularization ):
208
228
# Test that the function is called in the same way, either directly
209
229
# or through the NMF class
210
230
rng = np .random .mtrand .RandomState (42 )
211
231
A = np .abs (rng .randn (10 , 10 ))
212
232
A [:, 2 * np .arange (5 )] = 0
213
233
214
- for init in ['random' , 'nndsvd' ]:
215
- for solver in ('cd' , 'mu' ):
216
- W_nmf , H , _ = non_negative_factorization (
217
- A , init = init , solver = solver , random_state = 1 , tol = 1e-2 )
218
- W_nmf_2 , _ , _ = non_negative_factorization (
219
- A , H = H , update_H = False , init = init , solver = solver ,
220
- random_state = 1 , tol = 1e-2 )
234
+ W_nmf , H , _ = non_negative_factorization (
235
+ A , init = init , solver = solver ,
236
+ regularization = regularization , random_state = 1 , tol = 1e-2 )
237
+ W_nmf_2 , _ , _ = non_negative_factorization (
238
+ A , H = H , update_H = False , init = init , solver = solver ,
239
+ regularization = regularization , random_state = 1 , tol = 1e-2 )
221
240
222
- model_class = NMF (init = init , solver = solver , random_state = 1 ,
223
- tol = 1e-2 )
224
- W_cls = model_class .fit_transform (A )
225
- W_cls_2 = model_class .transform (A )
241
+ model_class = NMF (init = init , solver = solver ,
242
+ regularization = regularization ,
243
+ random_state = 1 , tol = 1e-2 )
244
+ W_cls = model_class .fit_transform (A )
245
+ W_cls_2 = model_class .transform (A )
226
246
227
- assert_array_almost_equal (W_nmf , W_cls , decimal = 10 )
228
- assert_array_almost_equal (W_nmf_2 , W_cls_2 , decimal = 10 )
247
+ assert_array_almost_equal (W_nmf , W_cls , decimal = 10 )
248
+ assert_array_almost_equal (W_nmf_2 , W_cls_2 , decimal = 10 )
229
249
230
250
231
251
def test_non_negative_factorization_checking ():
@@ -515,25 +535,29 @@ def test_nmf_underflow():
515
535
(np .int32 , np .float64 ),
516
536
(np .int64 , np .float64 )])
517
537
@pytest .mark .parametrize ("solver" , ["cd" , "mu" ])
518
- def test_nmf_dtype_match (dtype_in , dtype_out , solver ):
538
+ @pytest .mark .parametrize ("regularization" ,
539
+ (None , "both" , "components" , "transformation" ))
540
+ def test_nmf_dtype_match (dtype_in , dtype_out , solver , regularization ):
519
541
# Check that NMF preserves dtype (float32 and float64)
520
542
X = np .random .RandomState (0 ).randn (20 , 15 ).astype (dtype_in , copy = False )
521
543
np .abs (X , out = X )
522
- nmf = NMF (solver = solver )
544
+ nmf = NMF (solver = solver , regularization = regularization )
523
545
524
546
assert nmf .fit (X ).transform (X ).dtype == dtype_out
525
547
assert nmf .fit_transform (X ).dtype == dtype_out
526
548
assert nmf .components_ .dtype == dtype_out
527
549
528
550
529
551
@pytest .mark .parametrize ("solver" , ["cd" , "mu" ])
530
- def test_nmf_float32_float64_consistency (solver ):
552
+ @pytest .mark .parametrize ("regularization" ,
553
+ (None , "both" , "components" , "transformation" ))
554
+ def test_nmf_float32_float64_consistency (solver , regularization ):
531
555
# Check that the result of NMF is the same between float32 and float64
532
556
X = np .random .RandomState (0 ).randn (50 , 7 )
533
557
np .abs (X , out = X )
534
- nmf32 = NMF (solver = solver , random_state = 0 )
558
+ nmf32 = NMF (solver = solver , regularization = regularization , random_state = 0 )
535
559
W32 = nmf32 .fit_transform (X .astype (np .float32 ))
536
- nmf64 = NMF (solver = solver , random_state = 0 )
560
+ nmf64 = NMF (solver = solver , regularization = regularization , random_state = 0 )
537
561
W64 = nmf64 .fit_transform (X )
538
562
539
563
assert_allclose (W32 , W64 , rtol = 1e-6 , atol = 1e-5 )
0 commit comments