Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 06834fc

Browse files
glemaitre, jeremiedb, and ogrisel
authored
DEPR deprecate n_iter in MiniBatchSparsePCA (#23726)
Co-authored-by: Jérémie du Boisberranger <[email protected]>
Co-authored-by: jeremie du boisberranger <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
1 parent ae7712a commit 06834fc

File tree

5 files changed

+144
-55
lines changed

5 files changed

+144
-55
lines changed

doc/whats_new/v1.2.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,16 @@ Changelog
110110
its memory footprint and runtime.
111111
:pr:`22268` by :user:`MohamedBsh <Bsh>`.
112112

113+
:mod:`sklearn.decomposition`
114+
............................
115+
116+
- |API| The `n_iter` parameter of :class:`decomposition.MiniBatchSparsePCA` is
117+
deprecated and replaced by the parameters `max_iter`, `tol`, and
118+
`max_no_improvement` to be consistent with
119+
:class:`decomposition.MiniBatchDictionaryLearning`. `n_iter` will be removed
120+
in version 1.4. :pr:`23726` by :user:`Guillaume Lemaitre <glemaitre>`.
121+
122+
113123
:mod:`sklearn.ensemble`
114124
.......................
115125

sklearn/decomposition/_dict_learning.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ def dict_learning_online(
794794
Number of mini-batch iterations to perform.
795795
796796
.. deprecated:: 1.1
797-
`n_iter` is deprecated in 1.1 and will be removed in 1.3. Use
797+
`n_iter` is deprecated in 1.1 and will be removed in 1.4. Use
798798
`max_iter` instead.
799799
800800
max_iter : int, default=None
@@ -1758,7 +1758,7 @@ class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
17581758
Total number of iterations over data batches to perform.
17591759
17601760
.. deprecated:: 1.1
1761-
``n_iter`` is deprecated in 1.1 and will be removed in 1.3. Use
1761+
``n_iter`` is deprecated in 1.1 and will be removed in 1.4. Use
17621762
``max_iter`` instead.
17631763
17641764
max_iter : int, default=None
@@ -2251,6 +2251,17 @@ def fit(self, X, y=None):
22512251
)
22522252

22532253
self._check_params(X)
2254+
2255+
if self.n_iter != "deprecated":
2256+
warnings.warn(
2257+
"'n_iter' is deprecated in version 1.1 and will be removed "
2258+
"in version 1.4. Use 'max_iter' and let 'n_iter' to its default "
2259+
"value instead. 'n_iter' is also ignored if 'max_iter' is "
2260+
"specified.",
2261+
FutureWarning,
2262+
)
2263+
n_iter = self.n_iter
2264+
22542265
self._random_state = check_random_state(self.random_state)
22552266

22562267
dictionary = self._initialize_dict(X, self._random_state)
@@ -2310,15 +2321,7 @@ def fit(self, X, y=None):
23102321
self.n_iter_ = np.ceil(self.n_steps_ / n_steps_per_iter)
23112322
else:
23122323
# TODO remove this branch in 1.3
2313-
if self.n_iter != "deprecated":
2314-
warnings.warn(
2315-
"'n_iter' is deprecated in version 1.1 and will be removed"
2316-
" in version 1.3. Use 'max_iter' instead.",
2317-
FutureWarning,
2318-
)
2319-
n_iter = self.n_iter
2320-
else:
2321-
n_iter = 1000
2324+
n_iter = 1000 if self.n_iter == "deprecated" else self.n_iter
23222325

23232326
batches = gen_batches(n_samples, self._batch_size)
23242327
batches = itertools.cycle(batches)

sklearn/decomposition/_sparse_pca.py

Lines changed: 60 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,13 @@
22
# Author: Vlad Niculae, Gael Varoquaux, Alexandre Gramfort
33
# License: BSD 3 clause
44

5-
import warnings
6-
75
import numpy as np
86

97
from ..utils import check_random_state
108
from ..utils.validation import check_is_fitted
119
from ..linear_model import ridge_regression
1210
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
13-
from ._dict_learning import dict_learning, dict_learning_online
11+
from ._dict_learning import dict_learning, MiniBatchDictionaryLearning
1412

1513

1614
class SparsePCA(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
@@ -275,6 +273,17 @@ class MiniBatchSparsePCA(SparsePCA):
275273
n_iter : int, default=100
276274
Number of iterations to perform for each mini batch.
277275
276+
.. deprecated:: 1.2
277+
`n_iter` is deprecated in 1.2 and will be removed in 1.4. Use
278+
`max_iter` instead.
279+
280+
max_iter : int, default=None
281+
Maximum number of iterations over the complete dataset before
282+
stopping independently of any early stopping criterion heuristics.
283+
If `max_iter` is not `None`, `n_iter` is ignored.
284+
285+
.. versionadded:: 1.2
286+
278287
callback : callable, default=None
279288
Callable that gets invoked every five iterations.
280289
@@ -307,6 +316,25 @@ class MiniBatchSparsePCA(SparsePCA):
307316
across multiple function calls.
308317
See :term:`Glossary <random_state>`.
309318
319+
tol : float, default=1e-3
320+
Control early stopping based on the norm of the differences in the
321+
dictionary between 2 steps. Used only if `max_iter` is not None.
322+
323+
To disable early stopping based on changes in the dictionary, set
324+
`tol` to 0.0.
325+
326+
.. versionadded:: 1.1
327+
328+
max_no_improvement : int, default=10
329+
Control early stopping based on the consecutive number of mini batches
330+
that does not yield an improvement on the smoothed cost function. Used only if
331+
`max_iter` is not None.
332+
333+
To disable convergence detection based on cost function, set
334+
`max_no_improvement` to `None`.
335+
336+
.. versionadded:: 1.1
337+
310338
Attributes
311339
----------
312340
components_ : ndarray of shape (n_components, n_features)
@@ -350,15 +378,15 @@ class MiniBatchSparsePCA(SparsePCA):
350378
>>> from sklearn.decomposition import MiniBatchSparsePCA
351379
>>> X, _ = make_friedman1(n_samples=200, n_features=30, random_state=0)
352380
>>> transformer = MiniBatchSparsePCA(n_components=5, batch_size=50,
353-
... random_state=0)
381+
... max_iter=10, random_state=0)
354382
>>> transformer.fit(X)
355383
MiniBatchSparsePCA(...)
356384
>>> X_transformed = transformer.transform(X)
357385
>>> X_transformed.shape
358386
(200, 5)
359387
>>> # most values in the components_ are zero (sparsity)
360388
>>> np.mean(transformer.components_ == 0)
361-
0.94
389+
0.9...
362390
"""
363391

364392
def __init__(
@@ -367,14 +395,17 @@ def __init__(
367395
*,
368396
alpha=1,
369397
ridge_alpha=0.01,
370-
n_iter=100,
398+
n_iter="deprecated",
399+
max_iter=None,
371400
callback=None,
372401
batch_size=3,
373402
verbose=False,
374403
shuffle=True,
375404
n_jobs=None,
376405
method="lars",
377406
random_state=None,
407+
tol=1e-3,
408+
max_no_improvement=10,
378409
):
379410
super().__init__(
380411
n_components=n_components,
@@ -386,9 +417,12 @@ def __init__(
386417
random_state=random_state,
387418
)
388419
self.n_iter = n_iter
420+
self.max_iter = max_iter
389421
self.callback = callback
390422
self.batch_size = batch_size
391423
self.shuffle = shuffle
424+
self.tol = tol
425+
self.max_no_improvement = max_no_improvement
392426

393427
def fit(self, X, y=None):
394428
"""Fit the model from data in X.
@@ -418,44 +452,27 @@ def fit(self, X, y=None):
418452
else:
419453
n_components = self.n_components
420454

421-
with warnings.catch_warnings():
422-
# return_n_iter and n_iter are deprecated. TODO Remove in 1.3
423-
warnings.filterwarnings(
424-
"ignore",
425-
message=(
426-
"'return_n_iter' is deprecated in version 1.1 and will be "
427-
"removed in version 1.3. From 1.3 'n_iter' will never be "
428-
"returned. Refer to the 'n_iter_' and 'n_steps_' attributes "
429-
"of the MiniBatchDictionaryLearning object instead."
430-
),
431-
category=FutureWarning,
432-
)
433-
warnings.filterwarnings(
434-
"ignore",
435-
message=(
436-
"'n_iter' is deprecated in version 1.1 and will be removed in "
437-
"version 1.3. Use 'max_iter' instead."
438-
),
439-
category=FutureWarning,
440-
)
441-
Vt, _, self.n_iter_ = dict_learning_online(
442-
X.T,
443-
n_components,
444-
alpha=self.alpha,
445-
n_iter=self.n_iter,
446-
return_code=True,
447-
dict_init=None,
448-
verbose=self.verbose,
449-
callback=self.callback,
450-
batch_size=self.batch_size,
451-
shuffle=self.shuffle,
452-
n_jobs=self.n_jobs,
453-
method=self.method,
454-
random_state=random_state,
455-
return_n_iter=True,
456-
)
455+
transform_algorithm = "lasso_" + self.method
456+
est = MiniBatchDictionaryLearning(
457+
n_components=n_components,
458+
alpha=self.alpha,
459+
n_iter=self.n_iter,
460+
max_iter=self.max_iter,
461+
dict_init=None,
462+
batch_size=self.batch_size,
463+
shuffle=self.shuffle,
464+
n_jobs=self.n_jobs,
465+
fit_algorithm=self.method,
466+
random_state=random_state,
467+
transform_algorithm=transform_algorithm,
468+
transform_alpha=self.alpha,
469+
verbose=self.verbose,
470+
callback=self.callback,
471+
tol=self.tol,
472+
max_no_improvement=self.max_no_improvement,
473+
).fit(X.T)
457474

458-
self.components_ = Vt.T
475+
self.components_, self.n_iter_ = est.transform(X.T).T, est.n_iter_
459476

460477
components_norm = np.linalg.norm(self.components_, axis=1)[:, np.newaxis]
461478
components_norm[components_norm == 0] = 1

sklearn/decomposition/tests/test_dict_learning.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ def test_minibatch_dict_learning_n_iter_deprecated():
721721
# check the deprecation warning of n_iter
722722
# FIXME: remove in 1.3
723723
depr_msg = (
724-
"'n_iter' is deprecated in version 1.1 and will be removed in version 1.3"
724+
"'n_iter' is deprecated in version 1.1 and will be removed in version 1.4"
725725
)
726726
est = MiniBatchDictionaryLearning(
727727
n_components=2, batch_size=4, n_iter=5, random_state=0
@@ -1072,3 +1072,14 @@ def test_get_feature_names_out(estimator):
10721072
feature_names_out,
10731073
[f"{estimator_name}{i}" for i in range(n_components)],
10741074
)
1075+
1076+
1077+
# TODO(1.4) remove
1078+
def test_minibatch_dictionary_learning_warns_and_ignore_n_iter():
1079+
"""Check that we always raise a warning when `n_iter` is set even if it is
1080+
ignored if `max_iter` is set.
1081+
"""
1082+
warn_msg = "'n_iter' is deprecated in version 1.1"
1083+
with pytest.warns(FutureWarning, match=warn_msg):
1084+
model = MiniBatchDictionaryLearning(batch_size=256, n_iter=2, max_iter=2).fit(X)
1085+
assert model.n_iter_ == 2

sklearn/decomposition/tests/test_sparse_pca.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,51 @@ def test_spca_feature_names_out(SPCA):
265265

266266
estimator_name = SPCA.__name__.lower()
267267
assert_array_equal([f"{estimator_name}{i}" for i in range(4)], names)
268+
269+
270+
# TODO (1.4): remove this test
271+
def test_spca_n_iter_deprecation():
272+
"""Check that we raise a warning for the deprecation of `n_iter` and it is ignored
273+
when `max_iter` is specified.
274+
"""
275+
rng = np.random.RandomState(0)
276+
n_samples, n_features = 12, 10
277+
X = rng.randn(n_samples, n_features)
278+
279+
warn_msg = "'n_iter' is deprecated in version 1.1 and will be removed"
280+
with pytest.warns(FutureWarning, match=warn_msg):
281+
MiniBatchSparsePCA(n_iter=2).fit(X)
282+
283+
n_iter, max_iter = 1, 100
284+
with pytest.warns(FutureWarning, match=warn_msg):
285+
model = MiniBatchSparsePCA(
286+
n_iter=n_iter, max_iter=max_iter, random_state=0
287+
).fit(X)
288+
assert model.n_iter_ > 1
289+
assert model.n_iter_ <= max_iter
290+
291+
292+
def test_spca_early_stopping(global_random_seed):
293+
"""Check that `tol` and `max_no_improvement` act as early stopping."""
294+
rng = np.random.RandomState(global_random_seed)
295+
n_samples, n_features = 50, 10
296+
X = rng.randn(n_samples, n_features)
297+
298+
# vary the tolerance to force the early stopping of one of the model
299+
model_early_stopped = MiniBatchSparsePCA(
300+
max_iter=100, tol=0.5, random_state=global_random_seed
301+
).fit(X)
302+
model_not_early_stopped = MiniBatchSparsePCA(
303+
max_iter=100, tol=1e-3, random_state=global_random_seed
304+
).fit(X)
305+
assert model_early_stopped.n_iter_ < model_not_early_stopped.n_iter_
306+
307+
# force the max number of no improvement to a large value to check that
308+
# it does help to early stop
309+
model_early_stopped = MiniBatchSparsePCA(
310+
max_iter=100, tol=1e-6, max_no_improvement=2, random_state=global_random_seed
311+
).fit(X)
312+
model_not_early_stopped = MiniBatchSparsePCA(
313+
max_iter=100, tol=1e-6, max_no_improvement=100, random_state=global_random_seed
314+
).fit(X)
315+
assert model_early_stopped.n_iter_ < model_not_early_stopped.n_iter_

0 commit comments

Comments (0)