Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 16adc8f

Browse files
skailasaglemaitre
authored andcommitted
ENH Add lapack driver argument to randomized svd function (scikit-learn#20617)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 4671a14 commit 16adc8f

File tree

3 files changed

+55
-14
lines changed

3 files changed

+55
-14
lines changed

doc/whats_new/v1.2.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,11 @@ Changelog
9494
:pr:`23139` by `Tom Dupre la Tour`_.
9595

9696
:mod:`sklearn.preprocessing`
97-
.............................
97+
............................
9898

9999
- |Fix| :class:`preprocessing.PolynomialFeatures` with ``degree`` equal to 0 will raise
100100
error when ``include_bias`` is set to False, and outputs a single constant array when
101-
``include_bias`` is set to True. :pr:`23370` by :user:`Zhehao Liu <MaxwellLZH>`.
101+
``include_bias`` is set to True. :pr:`23370` by :user:`Zhehao Liu <MaxwellLZH>`.
102102

103103
:mod:`sklearn.tree`
104104
...................
@@ -107,6 +107,14 @@ Changelog
107107
:class:`tree.DecisionTreeRegressor` and :class:`tree.DecisionTreeClassifier`.
108108
:pr:`23273` by `Thomas Fan`_.
109109

110+
:mod:`sklearn.utils`
111+
....................
112+
113+
- |Enhancement| :func:`utils.extmath.randomized_svd` now accepts an argument,
114+
`lapack_svd_driver`, to specify the lapack driver used in the internal
115+
deterministic SVD used by the randomized SVD algorithm.
116+
:pr:`20617` by :user:`Srinath Kailasa <skailasa>`
117+
110118
Code and Documentation Contributors
111119
-----------------------------------
112120

sklearn/utils/extmath.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -256,11 +256,12 @@ def randomized_svd(
256256
transpose="auto",
257257
flip_sign=True,
258258
random_state="warn",
259+
svd_lapack_driver="gesdd",
259260
):
260261
"""Computes a truncated randomized SVD.
261262
262-
This method solves the fixed-rank approximation problem described in the
263-
Halko et al paper (problem (1.5), p5).
263+
This method solves the fixed-rank approximation problem described in [1]_
264+
(problem (1.5), p5).
264265
265266
Parameters
266267
----------
@@ -278,8 +279,8 @@ def randomized_svd(
278279
approximation of singular vectors and singular values. Users might wish
279280
to increase this parameter up to `2*k - n_components` where k is the
280281
effective rank, for large matrices, noisy problems, matrices with
281-
slowly decaying spectrums, or to increase precision accuracy. See Halko
282-
et al (pages 5, 23 and 26).
282+
slowly decaying spectrums, or to increase precision accuracy. See [1]_
283+
(pages 5, 23 and 26).
283284
284285
n_iter : int or 'auto', default='auto'
285286
Number of power iterations. It can be used to deal with very noisy
@@ -291,7 +292,7 @@ def randomized_svd(
291292
more costly power iterations steps. When `n_components` is equal
292293
or greater to the effective matrix rank and the spectrum does not
293294
present a slow decay, `n_iter=0` or `1` should even work fine in theory
294-
(see Halko et al paper, page 9).
295+
(see [1]_ page 9).
295296
296297
.. versionchanged:: 0.18
297298
@@ -332,6 +333,14 @@ def randomized_svd(
332333
the value of `random_state` explicitly to suppress the deprecation
333334
warning.
334335
336+
svd_lapack_driver : {"gesdd", "gesvd"}, default="gesdd"
337+
Whether to use the more efficient divide-and-conquer approach
338+
(`"gesdd"`) or more general rectangular approach (`"gesvd"`) to compute
339+
the SVD of the matrix B, which is the projection of M into a low
340+
dimensional subspace, as described in [1]_.
341+
342+
.. versionadded:: 1.2
343+
335344
Notes
336345
-----
337346
This algorithm finds a (usually very good) approximate truncated
@@ -346,17 +355,16 @@ def randomized_svd(
346355
347356
References
348357
----------
349-
* :arxiv:`"Finding structure with randomness:
358+
.. [1] :arxiv:`"Finding structure with randomness:
350359
Stochastic algorithms for constructing approximate matrix decompositions"
351360
<0909.4061>`
352361
Halko, et al. (2009)
353362
354-
* A randomized algorithm for the decomposition of matrices
363+
.. [2] A randomized algorithm for the decomposition of matrices
355364
Per-Gunnar Martinsson, Vladimir Rokhlin and Mark Tygert
356365
357-
* An implementation of a randomized algorithm for principal component
358-
analysis
359-
A. Szlam et al. 2014
366+
.. [3] An implementation of a randomized algorithm for principal component
367+
analysis A. Szlam et al. 2014
360368
"""
361369
if isinstance(M, (sparse.lil_matrix, sparse.dok_matrix)):
362370
warnings.warn(
@@ -405,7 +413,7 @@ def randomized_svd(
405413
B = safe_sparse_dot(Q.T, M)
406414

407415
# compute the SVD on the thin matrix: (k + p) wide
408-
Uhat, s, Vt = linalg.svd(B, full_matrices=False)
416+
Uhat, s, Vt = linalg.svd(B, full_matrices=False, lapack_driver=svd_lapack_driver)
409417

410418
del B
411419
U = np.dot(Q, Uhat)

sklearn/utils/tests/test_extmath.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# Denis Engemann <[email protected]>
44
#
55
# License: BSD 3 clause
6-
76
import numpy as np
87
from scipy import sparse
98
from scipy import linalg
@@ -568,6 +567,32 @@ def max_loading_is_positive(u, v):
568567
assert not v_based
569568

570569

570+
@pytest.mark.parametrize("n", [50, 100, 300])
571+
@pytest.mark.parametrize("m", [50, 100, 300])
572+
@pytest.mark.parametrize("k", [10, 20, 50])
573+
@pytest.mark.parametrize("seed", range(5))
574+
def test_randomized_svd_lapack_driver(n, m, k, seed):
575+
# Check that different SVD drivers provide consistent results
576+
577+
# Matrix being compressed
578+
rng = np.random.RandomState(seed)
579+
X = rng.rand(n, m)
580+
581+
# Number of components
582+
u1, s1, vt1 = randomized_svd(X, k, svd_lapack_driver="gesdd", random_state=0)
583+
u2, s2, vt2 = randomized_svd(X, k, svd_lapack_driver="gesvd", random_state=0)
584+
585+
# Check shape and contents
586+
assert u1.shape == u2.shape
587+
assert_allclose(u1, u2, atol=0, rtol=1e-3)
588+
589+
assert s1.shape == s2.shape
590+
assert_allclose(s1, s2, atol=0, rtol=1e-3)
591+
592+
assert vt1.shape == vt2.shape
593+
assert_allclose(vt1, vt2, atol=0, rtol=1e-3)
594+
595+
571596
def test_cartesian():
572597
# Check if cartesian product delivers the right results
573598

0 commit comments

Comments
 (0)