Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f734e11

Browse files
bharatr21, adrinjalali, and thomasjpfan
authored
ENH: Add regularization to the main NMF class (#17414)
* ENH: Add regularization to the main NMF class * Update _nmf with suggestions from code review Update `_nmf.py` with suggestion from review Co-authored-by: Adrin Jalali <[email protected]> * Refactor tests, fix linter errors * Change default value to None * Revert back to default value of "both" * Update default value documentation acc to @thomasjpfan Co-authored-by: Thomas J. Fan <[email protected]> * CLN Places regularization at the end * Add whatsnew entry * DOC Fix * DOC Fix Co-authored-by: Adrin Jalali <[email protected]> Co-authored-by: Thomas J. Fan <[email protected]>
1 parent 2b303b1 commit f734e11

File tree

4 files changed

+87
-48
lines changed

4 files changed

+87
-48
lines changed

doc/modules/decomposition.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -759,9 +759,9 @@ and the regularized objective function is:
759759
+ \frac{\alpha(1-\rho)}{2} ||W||_{\mathrm{Fro}} ^ 2
760760
+ \frac{\alpha(1-\rho)}{2} ||H||_{\mathrm{Fro}} ^ 2
761761
762-
:class:`NMF` regularizes both W and H. The public function
763-
:func:`non_negative_factorization` allows a finer control through the
764-
:attr:`regularization` attribute, and may regularize only W, only H, or both.
762+
:class:`NMF` regularizes both W and H by default. The :attr:`regularization`
763+
parameter allows for finer control, with which only W, only H,
764+
or both can be regularized.
765765

766766
NMF with a beta-divergence
767767
--------------------------

doc/whats_new/v0.24.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ Changelog
114114
argument `rotation`, which can take the value `None`, `'varimax'` or `'quartimax'`.
115115
:pr:`11064` by :user:`Jona Sassenhagen <jona-sassenhagen>`.
116116

117+
- |Enhancement| :class:`decomposition.NMF` now supports the optional parameter
118+
`regularization`, which can take the values `None`, `'components'`,
119+
`'transformation'` or `'both'`, in accordance with
120+
:func:`decomposition.non_negative_factorization`.
121+
:pr:`17414` by :user:`Bharat Raghunathan <Bharat123rox>`.
122+
117123
:mod:`sklearn.ensemble`
118124
.......................
119125

sklearn/decomposition/_nmf.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,7 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, *,
10811081

10821082

10831083
class NMF(TransformerMixin, BaseEstimator):
1084-
r"""Non-Negative Matrix Factorization (NMF)
1084+
"""Non-Negative Matrix Factorization (NMF)
10851085
10861086
Find two non-negative matrices (W, H) whose product approximates the non-
10871087
negative matrix X. This factorization can be used for example for
@@ -1097,8 +1097,8 @@ class NMF(TransformerMixin, BaseEstimator):
10971097
10981098
Where::
10991099
1100-
||A||_Fro^2 = \sum_{i,j} A_{ij}^2 (Frobenius norm)
1101-
||vec(A)||_1 = \sum_{i,j} abs(A_{ij}) (Elementwise L1 norm)
1100+
||A||_Fro^2 = \\sum_{i,j} A_{ij}^2 (Frobenius norm)
1101+
||vec(A)||_1 = \\sum_{i,j} abs(A_{ij}) (Elementwise L1 norm)
11021102
11031103
For multiplicative-update ('mu') solver, the Frobenius norm
11041104
(0.5 * ||X - WH||_Fro^2) can be changed into another beta-divergence loss,
@@ -1198,6 +1198,13 @@ class NMF(TransformerMixin, BaseEstimator):
11981198
.. versionadded:: 0.17
11991199
*shuffle* parameter used in the Coordinate Descent solver.
12001200
1201+
regularization : {'both', 'components', 'transformation', None}, \
1202+
default='both'
1203+
Select whether the regularization affects the components (H), the
1204+
transformation (W), both or none of them.
1205+
1206+
.. versionadded:: 0.24
1207+
12011208
Attributes
12021209
----------
12031210
components_ : array, [n_components, n_features]
@@ -1239,7 +1246,7 @@ class NMF(TransformerMixin, BaseEstimator):
12391246
def __init__(self, n_components=None, *, init=None, solver='cd',
12401247
beta_loss='frobenius', tol=1e-4, max_iter=200,
12411248
random_state=None, alpha=0., l1_ratio=0., verbose=0,
1242-
shuffle=False):
1249+
shuffle=False, regularization='both'):
12431250
self.n_components = n_components
12441251
self.init = init
12451252
self.solver = solver
@@ -1251,6 +1258,7 @@ def __init__(self, n_components=None, *, init=None, solver='cd',
12511258
self.l1_ratio = l1_ratio
12521259
self.verbose = verbose
12531260
self.shuffle = shuffle
1261+
self.regularization = regularization
12541262

12551263
def _more_tags(self):
12561264
return {'requires_positive_X': True}
@@ -1285,7 +1293,7 @@ def fit_transform(self, X, y=None, W=None, H=None):
12851293
X=X, W=W, H=H, n_components=self.n_components, init=self.init,
12861294
update_H=True, solver=self.solver, beta_loss=self.beta_loss,
12871295
tol=self.tol, max_iter=self.max_iter, alpha=self.alpha,
1288-
l1_ratio=self.l1_ratio, regularization='both',
1296+
l1_ratio=self.l1_ratio, regularization=self.regularization,
12891297
random_state=self.random_state, verbose=self.verbose,
12901298
shuffle=self.shuffle)
12911299

@@ -1334,9 +1342,10 @@ def transform(self, X):
13341342
X=X, W=None, H=self.components_, n_components=self.n_components_,
13351343
init=self.init, update_H=False, solver=self.solver,
13361344
beta_loss=self.beta_loss, tol=self.tol, max_iter=self.max_iter,
1337-
alpha=self.alpha, l1_ratio=self.l1_ratio, regularization='both',
1338-
random_state=self.random_state, verbose=self.verbose,
1339-
shuffle=self.shuffle)
1345+
alpha=self.alpha, l1_ratio=self.l1_ratio,
1346+
regularization=self.regularization,
1347+
random_state=self.random_state,
1348+
verbose=self.verbose, shuffle=self.shuffle)
13401349

13411350
return W
13421351

sklearn/decomposition/tests/test_nmf.py

Lines changed: 61 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020

2121

2222
@pytest.mark.parametrize('solver', ['cd', 'mu'])
23-
def test_convergence_warning(solver):
23+
@pytest.mark.parametrize('regularization',
24+
[None, 'both', 'components', 'transformation'])
25+
def test_convergence_warning(solver, regularization):
2426
convergence_warning = ("Maximum number of iterations 1 reached. "
2527
"Increase it to improve convergence.")
2628
A = np.ones((2, 2))
2729
with pytest.warns(ConvergenceWarning, match=convergence_warning):
28-
NMF(solver=solver, max_iter=1).fit(A)
30+
NMF(solver=solver, regularization=regularization, max_iter=1).fit(A)
2931

3032

3133
def test_initialize_nn_output():
@@ -44,6 +46,8 @@ def test_parameter_checking():
4446
assert_raise_message(ValueError, msg, NMF(solver=name).fit, A)
4547
msg = "Invalid init parameter: got 'spam' instead of one of"
4648
assert_raise_message(ValueError, msg, NMF(init=name).fit, A)
49+
msg = "Invalid regularization parameter: got 'spam' instead of one of"
50+
assert_raise_message(ValueError, msg, NMF(regularization=name).fit, A)
4751
msg = "Invalid beta_loss parameter: got 'spam' instead of one"
4852
assert_raise_message(ValueError, msg, NMF(solver='mu',
4953
beta_loss=name).fit, A)
@@ -97,36 +101,43 @@ def test_initialize_variants():
97101

98102
# ignore UserWarning raised when both solver='mu' and init='nndsvd'
99103
@ignore_warnings(category=UserWarning)
100-
def test_nmf_fit_nn_output():
104+
@pytest.mark.parametrize('solver', ('cd', 'mu'))
105+
@pytest.mark.parametrize('init',
106+
(None, 'nndsvd', 'nndsvda', 'nndsvdar', 'random'))
107+
@pytest.mark.parametrize('regularization',
108+
(None, 'both', 'components', 'transformation'))
109+
def test_nmf_fit_nn_output(solver, init, regularization):
101110
# Test that the decomposition does not contain negative values
102111
A = np.c_[5. - np.arange(1, 6),
103112
5. + np.arange(1, 6)]
104-
for solver in ('cd', 'mu'):
105-
for init in (None, 'nndsvd', 'nndsvda', 'nndsvdar', 'random'):
106-
model = NMF(n_components=2, solver=solver, init=init,
107-
random_state=0)
108-
transf = model.fit_transform(A)
109-
assert not((model.components_ < 0).any() or
110-
(transf < 0).any())
113+
model = NMF(n_components=2, solver=solver, init=init,
114+
regularization=regularization, random_state=0)
115+
transf = model.fit_transform(A)
116+
assert not((model.components_ < 0).any() or
117+
(transf < 0).any())
111118

112119

113120
@pytest.mark.parametrize('solver', ('cd', 'mu'))
114-
def test_nmf_fit_close(solver):
121+
@pytest.mark.parametrize('regularization',
122+
(None, 'both', 'components', 'transformation'))
123+
def test_nmf_fit_close(solver, regularization):
115124
rng = np.random.mtrand.RandomState(42)
116125
# Test that the fit is not too far away
117126
pnmf = NMF(5, solver=solver, init='nndsvdar', random_state=0,
118-
max_iter=600)
127+
regularization=regularization, max_iter=600)
119128
X = np.abs(rng.randn(6, 5))
120129
assert pnmf.fit(X).reconstruction_err_ < 0.1
121130

122131

123132
@pytest.mark.parametrize('solver', ('cd', 'mu'))
124-
def test_nmf_transform(solver):
133+
@pytest.mark.parametrize('regularization',
134+
(None, 'both', 'components', 'transformation'))
135+
def test_nmf_transform(solver, regularization):
125136
# Test that NMF.transform returns close values
126137
rng = np.random.mtrand.RandomState(42)
127138
A = np.abs(rng.randn(6, 5))
128139
m = NMF(solver=solver, n_components=3, init='random',
129-
random_state=0, tol=1e-5)
140+
regularization=regularization, random_state=0, tol=1e-5)
130141
ft = m.fit_transform(A)
131142
t = m.transform(A)
132143
assert_array_almost_equal(ft, t, decimal=2)
@@ -148,12 +159,14 @@ def test_nmf_transform_custom_init():
148159

149160

150161
@pytest.mark.parametrize('solver', ('cd', 'mu'))
151-
def test_nmf_inverse_transform(solver):
162+
@pytest.mark.parametrize('regularization',
163+
(None, 'both', 'components', 'transformation'))
164+
def test_nmf_inverse_transform(solver, regularization):
152165
# Test that NMF.inverse_transform returns close values
153166
random_state = np.random.RandomState(0)
154167
A = np.abs(random_state.randn(6, 4))
155168
m = NMF(solver=solver, n_components=4, init='random', random_state=0,
156-
max_iter=1000)
169+
regularization=regularization, max_iter=1000)
157170
ft = m.fit_transform(A)
158171
A_new = m.inverse_transform(ft)
159172
assert_array_almost_equal(A, A_new, decimal=2)
@@ -167,7 +180,9 @@ def test_n_components_greater_n_features():
167180

168181

169182
@pytest.mark.parametrize('solver', ['cd', 'mu'])
170-
def test_nmf_sparse_input(solver):
183+
@pytest.mark.parametrize('regularization',
184+
[None, 'both', 'components', 'transformation'])
185+
def test_nmf_sparse_input(solver, regularization):
171186
# Test that sparse matrices are accepted as input
172187
from scipy.sparse import csc_matrix
173188

@@ -177,7 +192,8 @@ def test_nmf_sparse_input(solver):
177192
A_sparse = csc_matrix(A)
178193

179194
est1 = NMF(solver=solver, n_components=5, init='random',
180-
random_state=0, tol=1e-2)
195+
regularization=regularization, random_state=0,
196+
tol=1e-2)
181197
est2 = clone(est1)
182198

183199
W1 = est1.fit_transform(A)
@@ -204,28 +220,32 @@ def test_nmf_sparse_transform():
204220
assert_array_almost_equal(A_fit_tr, A_tr, decimal=1)
205221

206222

207-
def test_non_negative_factorization_consistency():
223+
@pytest.mark.parametrize('init', ['random', 'nndsvd'])
224+
@pytest.mark.parametrize('solver', ('cd', 'mu'))
225+
@pytest.mark.parametrize('regularization',
226+
(None, 'both', 'components', 'transformation'))
227+
def test_non_negative_factorization_consistency(init, solver, regularization):
208228
# Test that the function is called in the same way, either directly
209229
# or through the NMF class
210230
rng = np.random.mtrand.RandomState(42)
211231
A = np.abs(rng.randn(10, 10))
212232
A[:, 2 * np.arange(5)] = 0
213233

214-
for init in ['random', 'nndsvd']:
215-
for solver in ('cd', 'mu'):
216-
W_nmf, H, _ = non_negative_factorization(
217-
A, init=init, solver=solver, random_state=1, tol=1e-2)
218-
W_nmf_2, _, _ = non_negative_factorization(
219-
A, H=H, update_H=False, init=init, solver=solver,
220-
random_state=1, tol=1e-2)
234+
W_nmf, H, _ = non_negative_factorization(
235+
A, init=init, solver=solver,
236+
regularization=regularization, random_state=1, tol=1e-2)
237+
W_nmf_2, _, _ = non_negative_factorization(
238+
A, H=H, update_H=False, init=init, solver=solver,
239+
regularization=regularization, random_state=1, tol=1e-2)
221240

222-
model_class = NMF(init=init, solver=solver, random_state=1,
223-
tol=1e-2)
224-
W_cls = model_class.fit_transform(A)
225-
W_cls_2 = model_class.transform(A)
241+
model_class = NMF(init=init, solver=solver,
242+
regularization=regularization,
243+
random_state=1, tol=1e-2)
244+
W_cls = model_class.fit_transform(A)
245+
W_cls_2 = model_class.transform(A)
226246

227-
assert_array_almost_equal(W_nmf, W_cls, decimal=10)
228-
assert_array_almost_equal(W_nmf_2, W_cls_2, decimal=10)
247+
assert_array_almost_equal(W_nmf, W_cls, decimal=10)
248+
assert_array_almost_equal(W_nmf_2, W_cls_2, decimal=10)
229249

230250

231251
def test_non_negative_factorization_checking():
@@ -515,25 +535,29 @@ def test_nmf_underflow():
515535
(np.int32, np.float64),
516536
(np.int64, np.float64)])
517537
@pytest.mark.parametrize("solver", ["cd", "mu"])
518-
def test_nmf_dtype_match(dtype_in, dtype_out, solver):
538+
@pytest.mark.parametrize("regularization",
539+
(None, "both", "components", "transformation"))
540+
def test_nmf_dtype_match(dtype_in, dtype_out, solver, regularization):
519541
# Check that NMF preserves dtype (float32 and float64)
520542
X = np.random.RandomState(0).randn(20, 15).astype(dtype_in, copy=False)
521543
np.abs(X, out=X)
522-
nmf = NMF(solver=solver)
544+
nmf = NMF(solver=solver, regularization=regularization)
523545

524546
assert nmf.fit(X).transform(X).dtype == dtype_out
525547
assert nmf.fit_transform(X).dtype == dtype_out
526548
assert nmf.components_.dtype == dtype_out
527549

528550

529551
@pytest.mark.parametrize("solver", ["cd", "mu"])
530-
def test_nmf_float32_float64_consistency(solver):
552+
@pytest.mark.parametrize("regularization",
553+
(None, "both", "components", "transformation"))
554+
def test_nmf_float32_float64_consistency(solver, regularization):
531555
# Check that the result of NMF is the same between float32 and float64
532556
X = np.random.RandomState(0).randn(50, 7)
533557
np.abs(X, out=X)
534-
nmf32 = NMF(solver=solver, random_state=0)
558+
nmf32 = NMF(solver=solver, regularization=regularization, random_state=0)
535559
W32 = nmf32.fit_transform(X.astype(np.float32))
536-
nmf64 = NMF(solver=solver, random_state=0)
560+
nmf64 = NMF(solver=solver, regularization=regularization, random_state=0)
537561
W64 = nmf64.fit_transform(X)
538562

539563
assert_allclose(W32, W64, rtol=1e-6, atol=1e-5)

0 commit comments

Comments
 (0)