From cb5a185cf9ff45969f981da0a53221bb2afb28bf Mon Sep 17 00:00:00 2001
From: Scott Gigante
Date: Mon, 27 May 2019 19:12:42 -0400
Subject: [PATCH] allow sparse input to incremental PCA

---
 doc/modules/decomposition.rst            |  3 +-
 doc/whats_new/v0.22.rst                  |  8 +++
 sklearn/decomposition/incremental_pca.py | 70 ++++++++++++++++---
 .../tests/test_incremental_pca.py        | 47 +++++++++++--
 4 files changed, 113 insertions(+), 15 deletions(-)

diff --git a/doc/modules/decomposition.rst b/doc/modules/decomposition.rst
index 445117220c12c..313229fa326a9 100644
--- a/doc/modules/decomposition.rst
+++ b/doc/modules/decomposition.rst
@@ -74,7 +74,8 @@ out-of-core Principal Component Analysis either by:
  * Using its ``partial_fit`` method on chunks of data fetched sequentially
    from the local hard drive or a network database.
 
- * Calling its fit method on a memory mapped file using ``numpy.memmap``.
+ * Calling its fit method on a sparse matrix or a memory mapped file using
+   ``numpy.memmap``.
 
 :class:`IncrementalPCA` only stores estimates of component and noise variances,
 in order update ``explained_variance_ratio_`` incrementally. This is why
diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst
index c29eabd7df9f2..4d13721a8f0c7 100644
--- a/doc/whats_new/v0.22.rst
+++ b/doc/whats_new/v0.22.rst
@@ -39,6 +39,14 @@ Changelog
    :pr:`123456` by :user:`Joe Bloggs `.
    where 123456 is the *pull request* number, not the issue number.
 
+:mod:`sklearn.decomposition`
+............................
+
+- |Enhancement| :class:`decomposition.IncrementalPCA` now accepts sparse
+  matrices as input, converting them to dense in batches thereby avoiding the
+  need to store the entire dense matrix at once.
+  :pr:`13960` by :user:`Scott Gigante `.
+
 :mod:`sklearn.ensemble`
 .......................
 
diff --git a/sklearn/decomposition/incremental_pca.py b/sklearn/decomposition/incremental_pca.py
index 33dc7217ee858..c6d611dcd5fea 100644
--- a/sklearn/decomposition/incremental_pca.py
+++ b/sklearn/decomposition/incremental_pca.py
@@ -5,7 +5,7 @@
 # License: BSD 3 clause
 
 import numpy as np
-from scipy import linalg
+from scipy import linalg, sparse
 
 from .base import _BasePCA
 from ..utils import check_array, gen_batches
@@ -21,11 +21,13 @@ class IncrementalPCA(_BasePCA):
     but not scaled for each feature before applying the SVD.
 
     Depending on the size of the input data, this algorithm can be much more
-    memory efficient than a PCA.
+    memory efficient than a PCA, and allows sparse input.
 
     This algorithm has constant memory complexity, on the order
-    of ``batch_size``, enabling use of np.memmap files without loading the
-    entire file into memory.
+    of ``batch_size * n_features``, enabling use of np.memmap files without
+    loading the entire file into memory. For sparse matrices, the input
+    is converted to dense in batches (in order to be able to subtract the
+    mean) which avoids storing the entire dense matrix at any one time.
 
     The computational overhead of each SVD is
     ``O(batch_size * n_features ** 2)``, but only 2 * batch_size samples
@@ -104,13 +106,15 @@ class IncrementalPCA(_BasePCA):
     --------
     >>> from sklearn.datasets import load_digits
     >>> from sklearn.decomposition import IncrementalPCA
+    >>> from scipy import sparse
     >>> X, _ = load_digits(return_X_y=True)
    >>> transformer = IncrementalPCA(n_components=7, batch_size=200)
     >>> # either partially fit on smaller batches of data
     >>> transformer.partial_fit(X[:100, :])
     IncrementalPCA(batch_size=200, n_components=7)
     >>> # or let the fit function itself divide the data into batches
-    >>> X_transformed = transformer.fit_transform(X)
+    >>> X_sparse = sparse.csr_matrix(X)
+    >>> X_transformed = transformer.fit_transform(X_sparse)
     >>> X_transformed.shape
     (1797, 7)
 
@@ -167,7 +171,7 @@ def fit(self, X, y=None):
 
         Parameters
         ----------
-        X : array-like, shape (n_samples, n_features)
+        X : array-like or sparse matrix, shape (n_samples, n_features)
             Training data, where n_samples is the number of samples and
             n_features is the number of features.
 
@@ -188,7 +192,8 @@
         self.singular_values_ = None
         self.noise_variance_ = None
 
-        X = check_array(X, copy=self.copy, dtype=[np.float64, np.float32])
+        X = check_array(X, accept_sparse=['csr', 'csc', 'lil'],
+                        copy=self.copy, dtype=[np.float64, np.float32])
         n_samples, n_features = X.shape
 
         if self.batch_size is None:
@@ -198,7 +203,10 @@
 
         for batch in gen_batches(n_samples, self.batch_size_,
                                  min_batch_size=self.n_components or 0):
-            self.partial_fit(X[batch], check_input=False)
+            X_batch = X[batch]
+            if sparse.issparse(X_batch):
+                X_batch = X_batch.toarray()
+            self.partial_fit(X_batch, check_input=False)
 
         return self
 
@@ -221,6 +229,11 @@
         Returns the instance itself.
         """
         if check_input:
+            if sparse.issparse(X):
+                raise TypeError(
+                    "IncrementalPCA.partial_fit does not support "
+                    "sparse input. Either convert data to dense "
+                    "or use IncrementalPCA.fit to do so in batches.")
             X = check_array(X, copy=self.copy, dtype=[np.float64, np.float32])
         n_samples, n_features = X.shape
         if not hasattr(self, 'components_'):
@@ -274,7 +287,7 @@
                 np.sqrt((self.n_samples_seen_ * n_samples) /
                         n_total_samples) * (self.mean_ - col_batch_mean)
             X = np.vstack((self.singular_values_.reshape((-1, 1)) *
-                          self.components_, X, mean_correction))
+                           self.components_, X, mean_correction))
 
         U, S, V = linalg.svd(X, full_matrices=False)
         U, V = svd_flip(U, V, u_based_decision=False)
@@ -295,3 +308,42 @@
         else:
             self.noise_variance_ = 0.
         return self
+
+    def transform(self, X):
+        """Apply dimensionality reduction to X.
+
+        X is projected on the first principal components previously extracted
+        from a training set, using minibatches of size batch_size if X is
+        sparse.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            New data, where n_samples is the number of samples
+            and n_features is the number of features.
+
+        Returns
+        -------
+        X_new : array-like, shape (n_samples, n_components)
+
+        Examples
+        --------
+
+        >>> import numpy as np
+        >>> from sklearn.decomposition import IncrementalPCA
+        >>> X = np.array([[-1, -1], [-2, -1], [-3, -2],
+        ...               [1, 1], [2, 1], [3, 2]])
+        >>> ipca = IncrementalPCA(n_components=2, batch_size=3)
+        >>> ipca.fit(X)
+        IncrementalPCA(batch_size=3, n_components=2)
+        >>> ipca.transform(X) # doctest: +SKIP
+        """
+        if sparse.issparse(X):
+            n_samples = X.shape[0]
+            output = []
+            for batch in gen_batches(n_samples, self.batch_size_,
+                                     min_batch_size=self.n_components or 0):
+                output.append(super().transform(X[batch].toarray()))
+            return np.vstack(output)
+        else:
+            return super().transform(X)
diff --git a/sklearn/decomposition/tests/test_incremental_pca.py b/sklearn/decomposition/tests/test_incremental_pca.py
index 01fe7c8889a1f..0d649ea7d75b9 100644
--- a/sklearn/decomposition/tests/test_incremental_pca.py
+++ b/sklearn/decomposition/tests/test_incremental_pca.py
@@ -1,5 +1,6 @@
 """Tests for Incremental PCA."""
 import numpy as np
+import pytest
 
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_almost_equal
@@ -10,6 +11,8 @@
 from sklearn import datasets
 from sklearn.decomposition import PCA, IncrementalPCA
 
+from scipy import sparse
+
 iris = datasets.load_iris()
 
 
@@ -23,17 +26,51 @@ def test_incremental_pca():
 
     X_transformed = ipca.fit_transform(X)
 
-    np.testing.assert_equal(X_transformed.shape, (X.shape[0], 2))
-    assert_almost_equal(ipca.explained_variance_ratio_.sum(),
-                        pca.explained_variance_ratio_.sum(), 1)
+    assert X_transformed.shape == (X.shape[0], 2)
+    np.testing.assert_allclose(ipca.explained_variance_ratio_.sum(),
+                               pca.explained_variance_ratio_.sum(), rtol=1e-3)
 
     for n_components in [1, 2, X.shape[1]]:
         ipca = IncrementalPCA(n_components, batch_size=batch_size)
         ipca.fit(X)
         cov = ipca.get_covariance()
         precision = ipca.get_precision()
-        assert_array_almost_equal(np.dot(cov, precision),
-                                  np.eye(X.shape[1]))
+        np.testing.assert_allclose(np.dot(cov, precision),
+                                   np.eye(X.shape[1]), atol=1e-13)
+
+
+@pytest.mark.parametrize(
+    "matrix_class",
+    [sparse.csc_matrix, sparse.csr_matrix, sparse.lil_matrix])
+def test_incremental_pca_sparse(matrix_class):
+    # Incremental PCA on sparse arrays.
+    X = iris.data
+    pca = PCA(n_components=2)
+    pca.fit_transform(X)
+    X_sparse = matrix_class(X)
+    batch_size = X_sparse.shape[0] // 3
+    ipca = IncrementalPCA(n_components=2, batch_size=batch_size)
+
+    X_transformed = ipca.fit_transform(X_sparse)
+
+    assert X_transformed.shape == (X_sparse.shape[0], 2)
+    np.testing.assert_allclose(ipca.explained_variance_ratio_.sum(),
+                               pca.explained_variance_ratio_.sum(), rtol=1e-3)
+
+    for n_components in [1, 2, X.shape[1]]:
+        ipca = IncrementalPCA(n_components, batch_size=batch_size)
+        ipca.fit(X_sparse)
+        cov = ipca.get_covariance()
+        precision = ipca.get_precision()
+        np.testing.assert_allclose(np.dot(cov, precision),
+                                   np.eye(X_sparse.shape[1]), atol=1e-13)
+
+    with pytest.raises(
+            TypeError,
+            match="IncrementalPCA.partial_fit does not support "
+            "sparse input. Either convert data to dense "
+            "or use IncrementalPCA.fit to do so in batches."):
+        ipca.partial_fit(X_sparse)
 
 
 def test_incremental_pca_check_projection():
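
A note on the design (illustrative, not part of the diff): IncrementalPCA centres
every batch by subtracting the running column mean, and subtracting a generally
non-zero mean from every row fills in essentially all entries, so the centred
batch is dense regardless of the input format. That is the point of the
parenthetical "(in order to be able to subtract the mean)" in the docstring
above. A minimal NumPy/SciPy sketch with made-up shapes:

    import numpy as np
    from scipy import sparse

    # A small sparse matrix standing in for one batch (shapes are arbitrary).
    X = sparse.random(5, 4, density=0.25, format='csr', random_state=0)

    # The column mean is a dense vector as soon as any feature has a
    # non-zero mean.
    col_mean = np.asarray(X.mean(axis=0)).ravel()

    # Centring makes (almost) every entry non-zero, so the patch converts
    # only batch_size rows at a time before subtracting the mean.
    X_centred = X.toarray() - col_mean
    print(np.count_nonzero(X_centred), "of", X_centred.size, "entries non-zero")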
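
An end-to-end usage sketch of what the patch enables, assuming it is applied to
a scikit-learn checkout; the random data, shapes and batch size below are
illustrative only:

    import numpy as np
    from scipy import sparse
    from sklearn.decomposition import IncrementalPCA

    rng = np.random.RandomState(0)
    X_sparse = sparse.random(1000, 50, density=0.05, format='csr',
                             random_state=rng)

    ipca = IncrementalPCA(n_components=10, batch_size=100)

    # fit/fit_transform now accept sparse input; each batch is densified
    # with .toarray() before the incremental update, so only roughly
    # batch_size * n_features dense values are materialised at once.
    X_reduced = ipca.fit_transform(X_sparse)
    print(X_reduced.shape)  # (1000, 10)

    # transform likewise densifies batch by batch when X is sparse.
    X_reduced_again = ipca.transform(X_sparse)

    # partial_fit still rejects sparse input and points users at fit.
    try:
        ipca.partial_fit(X_sparse)
    except TypeError as exc:
        print(exc)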