From 2479fda9670311899b4251a7eaccc614f2fbe9b2 Mon Sep 17 00:00:00 2001 From: Thomas Unterthiner Date: Mon, 21 Jul 2014 13:06:43 +0200 Subject: [PATCH 1/7] axis argument to sparsefuncs.mean_variance_axis. --- sklearn/cluster/k_means_.py | 4 +- sklearn/decomposition/truncated_svd.py | 4 +- sklearn/linear_model/base.py | 4 +- sklearn/preprocessing/data.py | 9 ++-- sklearn/utils/sparsefuncs.py | 26 +++++++++-- sklearn/utils/tests/test_sparsefuncs.py | 59 ++++++++++++++++++++++--- 6 files changed, 85 insertions(+), 21 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 3a8131070c9f7..2fe5b2e983e2d 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -20,7 +20,7 @@ from ..metrics.pairwise import euclidean_distances from ..utils.extmath import row_norms from ..utils.sparsefuncs_fast import assign_rows_csr -from ..utils.sparsefuncs import mean_variance_axis0 +from ..utils.sparsefuncs import mean_variance_axis from ..utils.fixes import astype from ..utils import check_array from ..utils import check_random_state @@ -141,7 +141,7 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None): def _tolerance(X, tol): """Return a tolerance which is independent of the dataset""" if sp.issparse(X): - variances = mean_variance_axis0(X)[1] + variances = mean_variance_axis(X, axis=0)[1] else: variances = np.var(X, axis=0) return np.mean(variances) * tol diff --git a/sklearn/decomposition/truncated_svd.py b/sklearn/decomposition/truncated_svd.py index 5e0d91dd04583..3b2033204e505 100644 --- a/sklearn/decomposition/truncated_svd.py +++ b/sklearn/decomposition/truncated_svd.py @@ -17,7 +17,7 @@ from ..base import BaseEstimator, TransformerMixin from ..utils import check_array, as_float_array, check_random_state from ..utils.extmath import randomized_svd, safe_sparse_dot, svd_flip -from ..utils.sparsefuncs import mean_variance_axis0 +from ..utils.sparsefuncs import mean_variance_axis __all__ = ["TruncatedSVD"] @@ -175,7 +175,7 @@ def fit_transform(self, X, y=None): X_transformed = np.dot(U, np.diag(Sigma)) self.explained_variance_ = exp_var = np.var(X_transformed, axis=0) if sp.issparse(X): - _, full_var = mean_variance_axis0(X) + _, full_var = mean_variance_axis(X, axis=0) full_var = full_var.sum() else: full_var = np.var(X, axis=0).sum() diff --git a/sklearn/linear_model/base.py b/sklearn/linear_model/base.py index fd82c9684e2f0..ffdb46a326af9 100644 --- a/sklearn/linear_model/base.py +++ b/sklearn/linear_model/base.py @@ -27,7 +27,7 @@ from ..base import BaseEstimator, ClassifierMixin, RegressorMixin from ..utils import as_float_array, check_array from ..utils.extmath import safe_sparse_dot -from ..utils.sparsefuncs import mean_variance_axis0, inplace_column_scale +from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale ### @@ -55,7 +55,7 @@ def sparse_center_data(X, y, fit_intercept, normalize=False): else: X = sp.csc_matrix(X, copy=normalize, dtype=np.float64) - X_mean, X_var = mean_variance_axis0(X) + X_mean, X_var = mean_variance_axis(X, axis=0) if normalize: # transform variance to std in-place # XXX: currently scaled to variance=n_samples to match center_data diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index e3c2ef34b5885..2c7bc8513c7c9 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -16,10 +16,11 @@ from ..utils import warn_if_not_float from ..utils.extmath import row_norms from ..utils.fixes import combinations_with_replacement as 
combinations_w_r
-from ..utils.sparsefuncs_fast import inplace_csr_row_normalize_l1
-from ..utils.sparsefuncs_fast import inplace_csr_row_normalize_l2
-from ..utils.sparsefuncs import inplace_column_scale
-from ..utils.sparsefuncs import mean_variance_axis0
+from ..utils import deprecated
+from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1,
+                                      inplace_csr_row_normalize_l2)
+from ..utils.sparsefuncs import (inplace_column_scale, inplace_row_scale,
+                                 mean_variance_axis, min_max_axis)
 
 zip = six.moves.zip
 map = six.moves.map
diff --git a/sklearn/utils/sparsefuncs.py b/sklearn/utils/sparsefuncs.py
index 04fd52a082ec7..1ea6938d900f2 100644
--- a/sklearn/utils/sparsefuncs.py
+++ b/sklearn/utils/sparsefuncs.py
@@ -53,7 +53,7 @@ def inplace_csr_row_scale(X, scale):
     X.data *= np.repeat(scale, np.diff(X.indptr))
 
 
-def mean_variance_axis0(X):
+def mean_variance_axis(X, axis):
     """Compute mean and variance along axis 0 on a CSR or CSC matrix
 
     Parameters
@@ -61,6 +61,9 @@ def mean_variance_axis0(X):
     X: CSR or CSC sparse matrix, shape (n_samples, n_features)
         Input data.
 
+    axis: int (either 0 or 1)
+        Axis along which the means and variances are computed.
+
 
     Returns
     -------
 
     means: float array with shape (n_features,)
         Feature-wise means
 
     variances: float array with shape (n_features,)
         Feature-wise variances
 
     """
+    if axis < 0:
+        axis += 2
+    if (axis != 0) and (axis != 1):
+        raise ValueError("Invalid axis, use 0 for rows, or 1 for columns")
+
     if isinstance(X, sp.csr_matrix):
-        return csr_mean_variance_axis0(X)
+        if axis == 0:
+            return csr_mean_variance_axis0(X)
+        else:
+            return csc_mean_variance_axis0(X.T)
     elif isinstance(X, sp.csc_matrix):
-        return csc_mean_variance_axis0(X)
+        if axis == 0:
+            return csc_mean_variance_axis0(X)
+        else:
+            return csr_mean_variance_axis0(X.T)
     else:
         _raise_typeerror(X)
 
+
 def inplace_column_scale(X, scale):
     """Inplace column scaling of a CSC/CSR matrix.
 
@@ -258,13 +273,16 @@ def inplace_swap_column(X, m, n):
 
 
 def min_max_axis(X, axis):
-    """Compute minimum and maximum along axis 0 on a CSR or CSC matrix
+    """Compute minimum and maximum along an axis on a CSR or CSC matrix
 
     Parameters
     ----------
     X: CSR or CSC sparse matrix, shape (n_samples, n_features)
         Input data.
 
+    axis: int (either 0 or 1)
+        Axis along which the minima and maxima are computed.
+ Returns ------- diff --git a/sklearn/utils/tests/test_sparsefuncs.py b/sklearn/utils/tests/test_sparsefuncs.py index 686df79aff76d..3b49c6ac2b841 100644 --- a/sklearn/utils/tests/test_sparsefuncs.py +++ b/sklearn/utils/tests/test_sparsefuncs.py @@ -5,7 +5,7 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal from sklearn.datasets import make_classification -from sklearn.utils.sparsefuncs import (mean_variance_axis0, +from sklearn.utils.sparsefuncs import (mean_variance_axis, inplace_column_scale, inplace_row_scale, inplace_swap_row, inplace_swap_column, @@ -25,27 +25,72 @@ def test_mean_variance_axis0(): X[1, 0] = 0 X_csr = sp.csr_matrix(X_lil) - X_means, X_vars = mean_variance_axis0(X_csr) + X_means, X_vars = mean_variance_axis(X_csr, axis=0) assert_array_almost_equal(X_means, np.mean(X, axis=0)) assert_array_almost_equal(X_vars, np.var(X, axis=0)) X_csc = sp.csc_matrix(X_lil) - X_means, X_vars = mean_variance_axis0(X_csc) + X_means, X_vars = mean_variance_axis(X_csc, axis=0) assert_array_almost_equal(X_means, np.mean(X, axis=0)) assert_array_almost_equal(X_vars, np.var(X, axis=0)) - assert_raises(TypeError, mean_variance_axis0, X_lil) + assert_raises(TypeError, mean_variance_axis, X_lil, axis=0) X = X.astype(np.float32) X_csr = X_csr.astype(np.float32) X_csc = X_csr.astype(np.float32) - X_means, X_vars = mean_variance_axis0(X_csr) + X_means, X_vars = mean_variance_axis(X_csr, axis=0) assert_array_almost_equal(X_means, np.mean(X, axis=0)) assert_array_almost_equal(X_vars, np.var(X, axis=0)) - X_means, X_vars = mean_variance_axis0(X_csc) + X_means, X_vars = mean_variance_axis(X_csc, axis=0) assert_array_almost_equal(X_means, np.mean(X, axis=0)) assert_array_almost_equal(X_vars, np.var(X, axis=0)) - assert_raises(TypeError, mean_variance_axis0, X_lil) + assert_raises(TypeError, mean_variance_axis, X_lil, axis=0) + + +def test_mean_variance_illegal_axis(): + X, _ = make_classification(5, 4, random_state=0) + # Sparsify the array a little bit + X[0, 0] = 0 + X[2, 1] = 0 + X[4, 3] = 0 + X_csr = sp.csr_matrix(X) + assert_raises(ValueError, mean_variance_axis, X_csr, axis=-3) + assert_raises(ValueError, mean_variance_axis, X_csr, axis=2) + + +def test_mean_variance_axis1(): + X, _ = make_classification(5, 4, random_state=0) + # Sparsify the array a little bit + X[0, 0] = 0 + X[2, 1] = 0 + X[4, 3] = 0 + X_lil = sp.lil_matrix(X) + X_lil[1, 0] = 0 + X[1, 0] = 0 + X_csr = sp.csr_matrix(X_lil) + + X_means, X_vars = mean_variance_axis(X_csr, axis=1) + assert_array_almost_equal(X_means, np.mean(X, axis=1)) + assert_array_almost_equal(X_vars, np.var(X, axis=1)) + + X_csc = sp.csc_matrix(X_lil) + X_means, X_vars = mean_variance_axis(X_csc, axis=1) + + assert_array_almost_equal(X_means, np.mean(X, axis=1)) + assert_array_almost_equal(X_vars, np.var(X, axis=1)) + assert_raises(TypeError, mean_variance_axis, X_lil, axis=1) + + X = X.astype(np.float32) + X_csr = X_csr.astype(np.float32) + X_csc = X_csr.astype(np.float32) + X_means, X_vars = mean_variance_axis(X_csr, axis=1) + assert_array_almost_equal(X_means, np.mean(X, axis=1)) + assert_array_almost_equal(X_vars, np.var(X, axis=1)) + X_means, X_vars = mean_variance_axis(X_csc, axis=1) + assert_array_almost_equal(X_means, np.mean(X, axis=1)) + assert_array_almost_equal(X_vars, np.var(X, axis=1)) + assert_raises(TypeError, mean_variance_axis, X_lil, axis=1) def test_densify_rows(): From 06ecf4868cff141eacc91132f2754534f6f3743e Mon Sep 17 00:00:00 2001 From: Thomas Unterthiner Date: Mon, 21 Jul 2014 13:27:48 +0200 Subject: [PATCH 2/7] 
BaseScaler abstraction. --- doc/modules/preprocessing.rst | 11 +- sklearn/pipeline.py | 7 +- sklearn/preprocessing/data.py | 471 ++++++++++++++++++++-------------- 3 files changed, 287 insertions(+), 202 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 4d3b04ade3c7b..1627d46295f5a 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -72,13 +72,14 @@ This class is hence suitable for use in the early steps of a :class:`sklearn.pipeline.Pipeline`:: >>> scaler = preprocessing.StandardScaler().fit(X) - >>> scaler - StandardScaler(copy=True, with_mean=True, with_std=True) + >>> scaler # doctest: +NORMALIZE_WHITESPACE + StandardScaler(axis=0, copy=True, with_centering=True, with_mean=None, + with_scaling=True, with_std=None) - >>> scaler.mean_ # doctest: +ELLIPSIS + >>> scaler.center_ # doctest: +ELLIPSIS array([ 1. ..., 0. ..., 0.33...]) - >>> scaler.std_ # doctest: +ELLIPSIS + >>> scaler.scale_ # doctest: +ELLIPSIS array([ 0.81..., 0.81..., 1.24...]) >>> scaler.transform(X) # doctest: +ELLIPSIS @@ -94,7 +95,7 @@ same way it did on the training set:: array([[-2.44..., 1.22..., -0.26...]]) It is possible to disable either centering or scaling by either -passing ``with_mean=False`` or ``with_std=False`` to the constructor +passing ``with_centering=False`` or ``with_scaling=False`` to the constructor of :class:`StandardScaler`. diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 8c6483c8ac14d..193398627c1c4 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -237,10 +237,11 @@ def make_pipeline(*steps): -------- >>> from sklearn.naive_bayes import GaussianNB >>> from sklearn.preprocessing import StandardScaler + >>> from sklearn.pipeline import make_pipeline >>> make_pipeline(StandardScaler(), GaussianNB()) # doctest: +NORMALIZE_WHITESPACE - Pipeline(steps=[('standardscaler', - StandardScaler(copy=True, with_mean=True, with_std=True)), - ('gaussiannb', GaussianNB())]) + Pipeline(steps=[('standardscaler', StandardScaler(axis=0, copy=True, + with_centering=True, with_mean=None, with_scaling=True, + with_std=None)), ('gaussiannb', GaussianNB())]) Returns ------- diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 2c7bc8513c7c9..373dee8f6a946 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -6,15 +6,19 @@ from itertools import chain, combinations import numbers +import warnings +from abc import ABCMeta, abstractmethod import numpy as np from scipy import sparse from ..base import BaseEstimator, TransformerMixin from ..externals import six from ..utils import check_array +from ..utils import as_float_array from ..utils import warn_if_not_float from ..utils.extmath import row_norms + from ..utils.fixes import combinations_with_replacement as combinations_w_r from ..utils import deprecated from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1, @@ -65,91 +69,120 @@ def _mean_and_std(X, axis=0, with_mean=True, with_std=True): return mean_, std_ -def scale(X, axis=0, with_mean=True, with_std=True, copy=True): - """Standardize a dataset along any axis +class BaseScaler(six.with_metaclass(ABCMeta, BaseEstimator, TransformerMixin)): + """Base class for all Scale transformers.""" - Center to the mean and component wise scale to unit variance. 
+ def __init__(self, copy=True, with_centering=True, with_scaling=True, + axis=0): + self.with_centering = with_centering + self.with_scaling = with_scaling + self.axis = axis + self.copy = copy - Parameters - ---------- - X : array-like or CSR matrix. - The data to center and scale. + def _check_array(self, X, copy): + """Makes sure centering is not enabled for sparse matrices.""" + X = check_array(X, accept_sparse=['csr', 'csc'], + copy=copy, ensure_2d=False) + if warn_if_not_float(X, estimator=self): + X = X.astype(np.float) + if sparse.issparse(X): + if not (sparse.isspmatrix_csc(X) or sparse.isspmatrix_csr(X)): + raise TypeError("Scaling only supports CSC and CSR " + "sparse matrix formats.") + if self.with_centering: + raise ValueError( + "Cannot center sparse matrices: use `with_centering=False`" + " instead. See docstring for motivation and alternatives.") + return X - axis : int (0 by default) - axis used to compute the means and standard deviations along. If 0, - independently standardize each feature, otherwise (if 1) standardize - each sample. + def _handle_zeros_in_scale(self, scale): + ''' Makes sure that whenever scale is zero, we handle it correctly. - with_mean : boolean, True by default - If True, center the data before scaling. + This happens in most scalers when we have constant features.''' + # if we are fitting on 1D arrays, scale might be a scalar + if np.isscalar(scale): + if scale == 0: + scale = 1. + elif isinstance(scale, np.ndarray): + scale[scale == 0.0] = 1.0 + scale[-np.isfinite(scale)] = 1.0 + return scale - with_std : boolean, True by default - If True, scale the data to unit variance (or equivalently, - unit standard deviation). + @abstractmethod + def fit(self, X, y=None): + """Compute the statistics to be used for later scaling. - copy : boolean, optional, default is True - set to False to perform inplace row normalization and avoid a - copy (if the input is already a numpy array or a scipy.sparse - CSR matrix and if axis is 1). + Parameters + ---------- + X : array-like or CSR matrix. + The data used to compute the mean and standard deviation + used for later scaling along the features axis. + """ - Notes - ----- - This implementation will refuse to center scipy.sparse matrices - since it would make them non-sparse and would potentially crash the - program with memory exhaustion problems. + def transform(self, X, y=None, copy=None): + """Perform standardization by centering and scaling - Instead the caller is expected to either set explicitly - `with_mean=False` (in that case, only variance scaling will be - performed on the features of the CSR matrix) or to call `X.toarray()` - if he/she expects the materialized dense array to fit in memory. + Parameters + ---------- + X : array-like or CSR matrix. + The data used to scale along the specified axis. + """ + if copy is None: + copy = self.copy + X = self._check_array(X, copy) + if sparse.issparse(X): + if self.with_scaling: + if self.axis == 1 or X.shape[0] == 1: + inplace_row_scale(X, 1.0 / self.scale_) + elif self.axis == 0: + inplace_column_scale(X, 1.0 / self.scale_) + else: + if copy: + X = X.copy() + # Xr is a view on the original array that enables easy use of + # broadcasting on the axis in which we are interested in + Xr = np.rollaxis(X, self.axis) + if self.with_centering: + Xr -= self.center_ + if self.with_scaling: + Xr /= self.scale_ + return X - To avoid memory copy the caller should pass a CSR matrix. 
+ def inverse_transform(self, X, copy=None): + """Scale back the data to the original representation - See also - -------- - :class:`sklearn.preprocessing.StandardScaler` to perform centering and - scaling using the ``Transformer`` API (e.g. as part of a preprocessing - :class:`sklearn.pipeline.Pipeline`) - """ - if sparse.issparse(X): - if with_mean: - raise ValueError( - "Cannot center sparse matrices: pass `with_mean=False` instead" - " See docstring for motivation and alternatives.") - if axis != 0: - raise ValueError("Can only scale sparse matrix on axis=0, " - " got axis=%d" % axis) - warn_if_not_float(X, estimator='The scale function') - if not sparse.isspmatrix_csr(X): - X = X.tocsr() - copy = False - if copy: - X = X.copy() - _, var = mean_variance_axis0(X) - var[var == 0.0] = 1.0 - inplace_column_scale(X, 1 / np.sqrt(var)) - else: - X = np.asarray(X) - warn_if_not_float(X, estimator='The scale function') - mean_, std_ = _mean_and_std( - X, axis, with_mean=with_mean, with_std=with_std) - if copy: - X = X.copy() - # Xr is a view on the original array that enables easy use of - # broadcasting on the axis in which we are interested in - Xr = np.rollaxis(X, axis) - if with_mean: - Xr -= mean_ - if with_std: - Xr /= std_ - return X + Parameters + ---------- + X : array-like or CSR matrix. + The data used to scale along the specified axis. + """ + if copy is None: + copy = self.copy + X = self._check_array(X, copy) + if sparse.issparse(X): + if self.with_scaling: + if self.axis == 1 or X.shape[0] == 1: + inplace_row_scale(X, self.scale_) + elif self.axis == 0: + inplace_column_scale(X, self.scale_) + else: + if copy: + X = X.copy() + # Xr is a view on the original array that enables easy use of + # broadcasting on the axis in which we are interested in + Xr = np.rollaxis(X, self.axis) + if self.with_scaling: + Xr *= self.scale_ + if self.with_centering: + Xr += self.center_ + return X -class MinMaxScaler(BaseEstimator, TransformerMixin): +class MinMaxScaler(BaseScaler): """Standardizes features by scaling each feature to a given range. This estimator scales and translates each feature individually such - that it is in the given range on the training set, i.e. between + that it is in the given range on the training set, e.g. between zero and one. The standardization is given by:: @@ -161,6 +194,10 @@ class MinMaxScaler(BaseEstimator, TransformerMixin): This standardization is often used as an alternative to zero mean, unit variance scaling. + Note that if future input exceeds the maximal/minimal values seen + during `fit`, the return values of `transform` might lie outside + of the specified `feature_range`. + Parameters ---------- feature_range: tuple (min, max), default=(0, 1) @@ -170,20 +207,28 @@ class MinMaxScaler(BaseEstimator, TransformerMixin): Set to False to perform inplace row normalization and avoid a copy (if the input is already a numpy array). + axis : int (0 by default) + axis used to compute the scaling statistics along. If 0, + independently scale each feature, otherwise (if 1) scale + each sample. + Attributes ---------- - min_ : ndarray, shape (n_features,) - Per feature adjustment for minimum. + `center_` : ndarray, shape (n_features,) + Per feature center. - scale_ : ndarray, shape (n_features,) + `scale_` : ndarray, shape (n_features,) Per feature relative scaling of the data. 
""" - def __init__(self, feature_range=(0, 1), copy=True): + def __init__(self, feature_range=(0, 1), copy=True, axis=0): + super(MinMaxScaler, self).__init__(with_centering=True, + with_scaling=True, + copy=copy, axis=axis) self.feature_range = feature_range self.copy = copy - def fit(self, X, y=None): + def fit(self, X, y=None, copy=None): """Compute the minimum and maximum to be used for later scaling. Parameters @@ -192,57 +237,45 @@ def fit(self, X, y=None): The data used to compute the per-feature minimum and maximum used for later scaling along the features axis. """ - X = check_array(X, copy=self.copy, ensure_2d=False) - warn_if_not_float(X, estimator=self) + + if sparse.issparse(X): + raise TypeError("MinMaxScaler cannot be fitted on sparse inputs") + + if copy is None: + copy = self.copy + + X = self._check_array(X, copy) + feature_range = self.feature_range if feature_range[0] >= feature_range[1]: raise ValueError("Minimum of desired feature range must be smaller" " than maximum. Got %s." % str(feature_range)) - data_min = np.min(X, axis=0) - data_range = np.max(X, axis=0) - data_min - # Do not scale constant features - if isinstance(data_range, np.ndarray): - data_range[data_range == 0.0] = 1.0 - elif data_range == 0.: - data_range = 1. - self.scale_ = (feature_range[1] - feature_range[0]) / data_range - self.min_ = feature_range[0] - data_min * self.scale_ - self.data_range = data_range - self.data_min = data_min + data_min = np.min(X, axis=self.axis) + data_range = np.max(X, axis=self.axis) - data_min + data_range = self._handle_zeros_in_scale(data_range) + self.scale_ = data_range / (feature_range[1] - feature_range[0]) + self.center_ = data_min - feature_range[0] * self.scale_ return self - def transform(self, X): - """Scaling features of X according to feature_range. + @property + @deprecated("Attribute min_ is deprecated and " + "will be removed in 0.17. Use 'center_' instead") + def min_(self): + return self.center_ - Parameters - ---------- - X : array-like with shape [n_samples, n_features] - Input data that will be transformed. - """ - X = check_array(X, copy=self.copy, ensure_2d=False) - X *= self.scale_ - X += self.min_ - return X - def inverse_transform(self, X): - """Undo the scaling of X according to feature_range. Parameters ---------- - X : array-like with shape [n_samples, n_features] - Input data that will be transformed. """ - X = check_array(X, copy=self.copy, ensure_2d=False) - X -= self.min_ - X /= self.scale_ - return X -class StandardScaler(BaseEstimator, TransformerMixin): +class StandardScaler(BaseScaler): """Standardize features by removing the mean and scaling to unit variance - Centering and scaling happen independently on each feature by computing - the relevant statistics on the samples in the training set. Mean and + Centering and scaling happen independently on each feature (or each + sample, depending on the `axis` argument) by computing the relevant + statistics on the samples in the training set. Mean and standard deviation are then stored to be used on later data using the `transform` method. @@ -261,14 +294,14 @@ class StandardScaler(BaseEstimator, TransformerMixin): Parameters ---------- - with_mean : boolean, True by default + with_centering : boolean, True by default If True, center the data before scaling. This does not work (and will raise an exception) when attempted on sparse matrices, because centering them entails building a dense matrix which in common use cases is likely to be too large to fit in memory. 
-    with_std : boolean, True by default
+    with_scaling : boolean, True by default
         If True, scale the data to unit variance (or equivalently,
         unit standard deviation).
 
@@ -278,12 +311,25 @@ class StandardScaler(BaseEstimator, TransformerMixin):
         not a NumPy array or scipy.sparse CSR matrix, a copy may
         still be returned.
 
+    axis : int (0 by default)
+        axis used to compute the scaling statistics along. If 0,
+        independently standardize each feature, otherwise (if 1) standardize
+        each sample.
+
+    with_mean : boolean
+        Old name for parameter `with_centering`.
+        WARNING : will be deprecated in 0.17
+
+    with_std : boolean
+        Old name for parameter `with_scaling`.
+        WARNING : will be deprecated in 0.17
+
     Attributes
     ----------
-    mean_ : array of floats with shape [n_features]
+    `center_` : array of floats with shape [n_features]
         The mean value for each feature in the training set.
 
-    std_ : array of floats with shape [n_features]
+    `scale_` : array of floats with shape [n_features]
         The standard deviation for each feature in the training set.
 
     See also
@@ -295,98 +341,65 @@ class StandardScaler(BaseEstimator, TransformerMixin):
         to further remove the linear correlation across features.
     """
 
-    def __init__(self, copy=True, with_mean=True, with_std=True):
-        self.with_mean = with_mean
-        self.with_std = with_std
-        self.copy = copy
+    def __init__(self, copy=True, with_centering=True, with_scaling=True,
+                 axis=0, with_mean=None, with_std=None):
+        if with_mean is not None:
+            with_centering = with_mean
+            warnings.warn("with_mean was renamed to with_centering and will be"
+                          " removed in 0.17", DeprecationWarning)
 
-    def fit(self, X, y=None):
+        if with_std is not None:
+            with_scaling = with_std
+            warnings.warn("with_std was renamed to with_scaling and will be"
+                          " removed in 0.17", DeprecationWarning)
+
+        super(StandardScaler, self).__init__(with_centering=with_centering,
+                                             with_scaling=with_scaling,
+                                             copy=copy, axis=axis)
+
+    def fit(self, X, y=None, copy=None):
         """Compute the mean and std to be used for later scaling.
 
         Parameters
         ----------
-        X : array-like or CSR matrix with shape [n_samples, n_features]
+        X : array-like or CSR matrix
             The data used to compute the mean and standard deviation
-            used for later scaling along the features axis.
+            used for later scaling along the specified axis.
         """
-        X = check_array(X, accept_sparse='csr', copy=self.copy, ensure_2d=False)
-        if warn_if_not_float(X, estimator=self):
-            X = X.astype(np.float)
+        self.center_ = None
+        self.scale_ = None
+        if copy is None:
+            copy = self.copy
+        X = self._check_array(X, copy)
         if sparse.issparse(X):
-            if self.with_mean:
-                raise ValueError(
-                    "Cannot center sparse matrices: pass `with_mean=False` "
-                    "instead. See docstring for motivation and alternatives.")
-            self.mean_ = None
-
-            if self.with_std:
-                var = mean_variance_axis0(X)[1]
-                self.std_ = np.sqrt(var)
-                self.std_[var == 0.0] = 1.0
-            else:
-                self.std_ = None
-            return self
+            if self.with_scaling:
+                var = mean_variance_axis(X, axis=self.axis)[1]
+                self.scale_ = np.sqrt(var)
         else:
-            self.mean_, self.std_ = _mean_and_std(
-                X, axis=0, with_mean=self.with_mean, with_std=self.with_std)
-            return self
+            self.center_, self.scale_ = _mean_and_std(
+                X, axis=self.axis, with_mean=self.with_centering,
+                with_std=self.with_scaling)
+        self.scale_ = self._handle_zeros_in_scale(self.scale_)
+        return self
 
-    def transform(self, X, y=None, copy=None):
-        """Perform standardization by centering and scaling
+    @property
+    @deprecated("Attribute mean_ is deprecated and "
+                "will be removed in 0.17. 
Use 'center_' instead") + def mean_(self): + return self.center_ - Parameters - ---------- - X : array-like with shape [n_samples, n_features] - The data used to scale along the features axis. - """ - copy = copy if copy is not None else self.copy - X = check_array(X, accept_sparse='csr', copy=copy, ensure_2d=False) - if warn_if_not_float(X, estimator=self): - X = X.astype(np.float) - if sparse.issparse(X): - if self.with_mean: - raise ValueError( - "Cannot center sparse matrices: pass `with_mean=False` " - "instead. See docstring for motivation and alternatives.") - if self.std_ is not None: - inplace_column_scale(X, 1 / self.std_) - else: - if self.with_mean: - X -= self.mean_ - if self.with_std: - X /= self.std_ - return X + @property + @deprecated("Attribute std_ is deprecated and " + "will be removed in 0.17. Use 'scale_' instead") + def std_(self): + return self.scale_ - def inverse_transform(self, X, copy=None): - """Scale back the data to the original representation Parameters ---------- - X : array-like with shape [n_samples, n_features] - The data used to scale along the features axis. """ - copy = copy if copy is not None else self.copy if sparse.issparse(X): - if self.with_mean: - raise ValueError( - "Cannot uncenter sparse matrices: pass `with_mean=False` " - "instead See docstring for motivation and alternatives.") - if not sparse.isspmatrix_csr(X): - X = X.tocsr() - copy = False - if copy: - X = X.copy() - if self.std_ is not None: - inplace_column_scale(X, self.std_) else: - X = np.asarray(X) - if copy: - X = X.copy() - if self.with_std: - X *= self.std_ - if self.with_mean: - X += self.mean_ - return X class PolynomialFeatures(BaseEstimator, TransformerMixin): @@ -433,7 +446,7 @@ class PolynomialFeatures(BaseEstimator, TransformerMixin): Attributes ---------- - powers_ : + `powers_`: powers_[i, j] is the exponent of the jth input in the ith output. Notes @@ -493,6 +506,74 @@ def transform(self, X, y=None): return (X[:, None, :] ** self.powers_).prod(-1) +def scale(X, axis=0, with_centering=True, with_scaling=True, copy=True, + with_mean=None, with_std=None): + """Standardize a dataset along any axis + + Center to the mean and component wise scale to unit variance. + + Parameters + ---------- + X : array-like or CSR matrix. + The data to center and scale. + + axis : int (0 by default) + axis used to compute the means and standard deviations along. If 0, + independently standardize each feature, otherwise (if 1) standardize + each sample. + + with_centering : boolean, True by default + If True, center the data before scaling. + + with_scaling : boolean, True by default + If True, scale the data to unit variance (or equivalently, + unit standard deviation). + + copy : boolean, optional, default is True + set to False to perform inplace row normalization and avoid a + copy (if the input is already a numpy array or a scipy.sparse + CSR matrix and if axis is 1). + + with_mean : boolean + Old name for parameter `with_centering`. + WARNING : will be deprecated in 0.17 + + with_std : boolean + Old name for parameter `with_scaling`. + WARNING : will be deprecated in 0.17 + + Notes + ----- + This implementation will refuse to center scipy.sparse matrices + since it would make them non-sparse and would potentially crash the + program with memory exhaustion problems. 
+
+    Instead the caller is expected to either set explicitly
+    `with_centering=False` (in that case, only variance scaling will be
+    performed on the features of the CSR matrix) or to call `X.toarray()`
+    if he/she expects the materialized dense array to fit in memory.
+
+    To avoid memory copy the caller should pass a CSR matrix.
+
+    See also
+    --------
+    :class:`sklearn.preprocessing.StandardScaler` to perform centering and
+    scaling using the ``Transformer`` API (e.g. as part of a preprocessing
+    :class:`sklearn.pipeline.Pipeline`)
+    """
+    if with_mean is not None:
+        with_centering = with_mean
+        warnings.warn("with_mean was renamed to with_centering and will be"
+                      " removed in 0.17", DeprecationWarning)
+
+    if with_std is not None:
+        with_scaling = with_std
+        warnings.warn("with_std was renamed to with_scaling and will be"
+                      " removed in 0.17", DeprecationWarning)
+
+    s = StandardScaler(with_centering=with_centering,
+                       with_scaling=with_scaling, copy=copy, axis=axis)
+    return s.fit_transform(X)
 
 
 def normalize(X, norm='l2', axis=1, copy=True):
     """Scale input vectors individually to unit norm (vector length).
@@ -617,7 +698,8 @@ def transform(self, X, y=None, copy=None):
             The data to normalize, row by row. scipy.sparse matrices should be
             in CSR format to avoid an un-necessary copy.
         """
-        copy = copy if copy is not None else self.copy
+        if copy is None:
+            copy = self.copy
         X = check_array(X, accept_sparse='csr')
         return normalize(X, norm=self.norm, axis=1, copy=copy)
 
@@ -722,7 +804,8 @@ def transform(self, X, y=None, copy=None):
             scipy.sparse matrices should be in CSR format to avoid an
             un-necessary copy.
         """
-        copy = copy if copy is not None else self.copy
+        if copy is None:
+            copy = self.copy
         return binarize(X, threshold=self.threshold, copy=copy)
 
 
@@ -929,17 +1012,17 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
 
     Attributes
     ----------
-    active_features_ : array
+    `active_features_` : array
        Indices for active features, meaning
        values that actually occur in the training set. Only available
        when n_values is ``'auto'``.
 
-    feature_indices_ : array of shape (n_features,)
+    `feature_indices_` : array of shape (n_features,)
        Indices to feature ranges.
        Feature ``i`` in the original data is mapped to features
        from ``feature_indices_[i]`` to ``feature_indices_[i+1]``
        (and then potentially masked by `active_features_` afterwards)
 
-    n_values_ : array of shape (n_features,)
+    `n_values_` : array of shape (n_features,)
        Maximum number of values per feature.
Examples From a5d24a80bf5225b8e3265fc6e6994b99e4972ab7 Mon Sep 17 00:00:00 2001 From: Thomas Unterthiner Date: Mon, 21 Jul 2014 13:29:53 +0200 Subject: [PATCH 3/7] minmax_scale --- sklearn/preprocessing/__init__.py | 2 ++ sklearn/preprocessing/data.py | 48 +++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/sklearn/preprocessing/__init__.py b/sklearn/preprocessing/__init__.py index c8c71570dfee1..f3a644b9546d6 100644 --- a/sklearn/preprocessing/__init__.py +++ b/sklearn/preprocessing/__init__.py @@ -12,6 +12,7 @@ from .data import binarize from .data import normalize from .data import scale +from .data import minmax_scale from .data import OneHotEncoder from .data import PolynomialFeatures @@ -39,5 +40,6 @@ 'binarize', 'normalize', 'scale', + 'minmax_scale', 'label_binarize', ] diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 373dee8f6a946..74ef1a0a7a8d7 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -41,6 +41,7 @@ 'binarize', 'normalize', 'scale', + 'minmax_scale' ] @@ -574,6 +575,53 @@ def scale(X, axis=0, with_centering=True, with_scaling=True, copy=True, s = StandardScaler(with_centering=with_centering, with_scaling=with_scaling, copy=copy, axis=axis) return s.fit_transform(X) + + +def minmax_scale(X, feature_range=(0, 1), axis=0, with_centering=True, + with_scaling=True, copy=True): + """Standardizes features by scaling each feature to a given range. + + This estimator scales and translates each feature individually such + that it is in the given range on the training set, i.e. between + zero and one. + + The standardization is given by:: + X_std = (X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0)) + X_scaled = X_std * (max - min) + min + + where min, max = feature_range. + + This standardization is often used as an alternative to zero mean, + unit variance scaling. + + Note that if future input exceeds the maximal/minimal values seen + during `fit`, the return values of `transform` might lie outside + of the specified `feature_range`. + + Parameters + ---------- + feature_range: tuple (min, max), default=(0, 1) + Desired range of transformed data. + + copy : boolean, optional, default is True + Set to False to perform inplace row normalization and avoid a + copy (if the input is already a numpy array). + + axis : int (0 by default) + axis used to compute the scaling statistics along. If 0, + independently scale each feature, otherwise (if 1) scale + each sample. + + Attributes + ---------- + `center_` : ndarray, shape (n_features,) + Per feature adjustment for minimum. + + `scale_` : ndarray, shape (n_features,) + Per feature relative scaling of the data. + """ + s = MinMaxScaler(feature_range=feature_range, copy=copy, axis=axis) + return s.fit_transform(X) def normalize(X, norm='l2', axis=1, copy=True): """Scale input vectors individually to unit norm (vector length). 
From 8bdc83e2749481453796842346d7e1a7bd617f3e Mon Sep 17 00:00:00 2001 From: Thomas Unterthiner Date: Mon, 21 Jul 2014 13:35:44 +0200 Subject: [PATCH 4/7] Refactor scaling tests --- sklearn/preprocessing/tests/test_data.py | 627 +++++++++++++---------- 1 file changed, 367 insertions(+), 260 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 55e85a86dce8c..08ea2614e0a9a 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2,6 +2,9 @@ import numpy as np import numpy.linalg as la from scipy import sparse +from functools import partial + +from scipy.stats.mstats import mquantiles from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_array_almost_equal @@ -13,8 +16,9 @@ from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_false from sklearn.utils.testing import assert_warns +from sklearn.utils.testing import ignore_warnings -from sklearn.utils.sparsefuncs import mean_variance_axis0 +from sklearn.utils.sparsefuncs import mean_variance_axis from sklearn.preprocessing.data import _transform_selected from sklearn.preprocessing.data import Binarizer from sklearn.preprocessing.data import KernelCenterer @@ -24,83 +28,306 @@ from sklearn.preprocessing.data import StandardScaler from sklearn.preprocessing.data import scale from sklearn.preprocessing.data import MinMaxScaler +from sklearn.preprocessing.data import minmax_scale from sklearn.preprocessing.data import add_dummy_feature from sklearn.preprocessing.data import PolynomialFeatures from sklearn import datasets + iris = datasets.load_iris() -def toarray(a): - if hasattr(a, "toarray"): - a = a.toarray() - return a +SPARSE_SCALERS = { # support sparse or dense input + 'uncentered StandardScaler': partial(StandardScaler, with_centering=False), +} +NON_SPARSE_SCALERS = { # only support dense input + 'centered StandardScaler': partial(StandardScaler, with_centering=True), + 'MinMaxScaler[0,1]': partial(MinMaxScaler, feature_range=(0, 1)), + 'MinMaxScaler[-1,1]': partial(MinMaxScaler, feature_range=(-1, 1)), + 'MinMaxScaler[-3,5]': partial(MinMaxScaler, feature_range=(-3, 5)) +} -def test_polynomial_features(): - """Test Polynomial Features""" - X1 = np.arange(6)[:, np.newaxis] - P1 = np.hstack([np.ones_like(X1), - X1, X1 ** 2, X1 ** 3]) - deg1 = 3 - X2 = np.arange(6).reshape((3, 2)) - x1 = X2[:, :1] - x2 = X2[:, 1:] - P2 = np.hstack([x1 ** 0 * x2 ** 0, - x1 ** 1 * x2 ** 0, - x1 ** 0 * x2 ** 1, - x1 ** 2 * x2 ** 0, - x1 ** 1 * x2 ** 1, - x1 ** 0 * x2 ** 2]) - deg2 = 2 +ALL_SCALERS = {} +ALL_SCALERS.update(SPARSE_SCALERS) +ALL_SCALERS.update(NON_SPARSE_SCALERS) - for (deg, X, P) in [(deg1, X1, P1), (deg2, X2, P2)]: - P_test = PolynomialFeatures(deg, include_bias=True).fit_transform(X) - assert_array_almost_equal(P_test, P) - P_test = PolynomialFeatures(deg, include_bias=False).fit_transform(X) - assert_array_almost_equal(P_test, P[:, 1:]) +SCALER_FUNCTIONS = { + 'scale': scale, + 'minmax_scale': minmax_scale, +} - interact = PolynomialFeatures(2, interaction_only=True, include_bias=True) - X_poly = interact.fit_transform(X) - assert_array_almost_equal(X_poly, P2[:, [0, 1, 2, 4]]) - assert_raises(ValueError, interact.transform, X[:, 1:]) +def test_scaler_2d_axis0(): + """Test robust scaling of 2d array along axis0""" + rng = np.random.RandomState(0) + X = rng.randn(4, 5) + X[:, 0] = 0.0 # first feature is always of zero + for fn, ScalerClass in ALL_SCALERS.items(): 
+ scaler = ScalerClass(axis=0) + X_scaled = scaler.fit(X).transform(X, copy=True) + assert_false(np.any(np.isnan(X_scaled))) -def test_scaler_1d(): - """Test scaling of dataset along single axis""" - rng = np.random.RandomState(0) - X = rng.randn(5) - X_orig_copy = X.copy() + assert_array_almost_equal(X_scaled.std(axis=0)[0], 0) + # Check that X has been copied + assert_true(X_scaled is not X) - scaler = StandardScaler() - X_scaled = scaler.fit(X).transform(X, copy=False) - assert_array_almost_equal(X_scaled.mean(axis=0), 0.0) - assert_array_almost_equal(X_scaled.std(axis=0), 1.0) + # check inverse transform + X_scaled_back = scaler.inverse_transform(X_scaled) + assert_true(X_scaled_back is not X) + assert_true(X_scaled_back is not X_scaled) + assert_array_almost_equal(X_scaled_back, X) - # check inverse transform - X_scaled_back = scaler.inverse_transform(X_scaled) - assert_array_almost_equal(X_scaled_back, X_orig_copy) + assert_false(np.any(np.isnan(X_scaled))) + # Check that the data hasn't been modified + assert_true(X_scaled is not X) - # Test with 1D list - X = [0., 1., 2, 0.4, 1.] - scaler = StandardScaler() - X_scaled = scaler.fit(X).transform(X, copy=False) - assert_array_almost_equal(X_scaled.mean(axis=0), 0.0) - assert_array_almost_equal(X_scaled.std(axis=0), 1.0) + X_scaled = scaler.fit(X).transform(X, copy=False) + # Check that X has not been copied + assert_true(X_scaled is X) + + +def test_scaler_2d_axis1(): + '''Check that scalers work on 2D arrays on axis=1''' + rng = np.random.RandomState(42) + X = rng.randn(4, 5) + for fn, ScalerClass in ALL_SCALERS.items(): + scaler = ScalerClass().fit(X) + scaler_trans = ScalerClass(axis=1).fit(X.T) + X_scaled = scaler.transform(X) + X_scaled_trans = scaler_trans.transform(X.T) + assert_array_almost_equal(X_scaled.T, X_scaled_trans) + X_inv = scaler.inverse_transform(X_scaled) + X_trans_inv = scaler_trans.inverse_transform(X_scaled_trans) + assert_array_almost_equal(X_inv.T, X_trans_inv) + + for fn, scale_func in SCALER_FUNCTIONS.items(): + X_scaled = scale_func(X) + X_scaled_trans = scale_func(X.T, axis=1) + assert_array_almost_equal(X_scaled.T, X_scaled_trans) + + +def test_scaler_1D(): + '''Check that scalers accept 1D input''' + X = np.array([-1, 0.0, 1.6]) + Xl = [-1, 0.0, 1.6] # 1D list + for fn, ScalerClass in ALL_SCALERS.items(): + scaler = ScalerClass() + X_trans = scaler.fit_transform(X) + X_inv = scaler.inverse_transform(X_trans) + assert_array_almost_equal(X_inv, X) + + Xl_trans = scaler.fit_transform(Xl) + assert_array_almost_equal(X_trans, Xl_trans) + X_trans2 = np.squeeze(scaler.fit_transform(np.transpose([X]))) + assert_array_almost_equal(Xl_trans, X_trans2) + X_inv = scaler.inverse_transform(Xl_trans) + assert_array_almost_equal(X_inv, Xl) + Xl_inv = scaler.inverse_transform(Xl_trans.tolist()) + assert_array_almost_equal(X_inv, Xl_inv) + + for fn, scale_func in SCALER_FUNCTIONS.items(): + scale_func(X) + + +def test_scale_allzeros(): + X = np.array([0.0, 0.0, 0.0, 0.0]) + for fn, ScalerClass in ALL_SCALERS.items(): + scaler = ScalerClass().fit(X) + X_scaled = scaler.transform(X) + if "MinMaxScaler" in fn: + continue + assert_array_almost_equal(X_scaled, X) + X_inv = scaler.inverse_transform(X_scaled) + assert_array_almost_equal(X_inv, X) + + for fn, scale_func in SCALER_FUNCTIONS.items(): + X_scaled = scale_func(X) + assert_array_almost_equal(X_scaled, X) + + +def test_scaler_sparse_data(): + """Check that the scalers works with sparse inputs.""" + X = [[0., 1., +0.5, -1], + [0., 1., -0.3, -0.5], + [0., 1., -1.5, 0], + 
[0., 0., +0.0, -2]] + + X_csr = sparse.csr_matrix(X) + X_csc = sparse.csc_matrix(X) + + for axis in (0, 1): + for fn, ScalerClass in SPARSE_SCALERS.items(): + scaler = ScalerClass(axis=axis) + scaler_csr = ScalerClass(axis=axis) + scaler_csc = ScalerClass(axis=axis) + X_trans = scaler.fit_transform(X) + X_trans_csr = scaler_csr.fit_transform(X_csr) + X_trans_csc = scaler_csc.fit_transform(X_csc) + assert_false(np.any(np.isnan(X_trans_csr.data))) + assert_false(np.any(np.isnan(X_trans_csc.data))) + + assert_array_almost_equal(X_trans, X_trans_csr.toarray()) + assert_array_almost_equal(X_trans, X_trans_csc.toarray()) + X_trans_inv = scaler.inverse_transform(X_trans) + X_trans_inv_csr = scaler_csr.inverse_transform(X_trans_csr) + X_trans_inv_csc = scaler_csc.inverse_transform(X_trans_csc) + assert_false(np.any(np.isnan(X_trans_inv_csr.data))) + assert_false(np.any(np.isnan(X_trans_inv_csc.data))) + assert_array_almost_equal(X_trans_inv, X_trans_inv_csr.toarray()) + assert_array_almost_equal(X_trans_inv, X_trans_inv_csc.toarray()) + + +def test_scaler_center_property(): + """Check that the center_ attribute of the Scalers is accessible""" + X = [[0., 2.0, +0.5], + [0., 0.0, -0.3]] + for fn, ScalerClass in ALL_SCALERS.items(): + scaler = ScalerClass() + if not scaler.with_centering: + continue + scaler.fit_transform(X) + assert(len(scaler.center_) == 3) + + scaler = StandardScaler().fit(X) + with warnings.catch_warnings(record=True): + assert(len(scaler.mean_) == 3) # deprecated parameter + + +def test_scaler_scale_property(): + """Check that the scale_ attribute of Scalers is accessible""" + X = [[0., 2.0, +0.5], + [0., 0.0, -0.3]] - X_scaled = scale(X) - assert_array_almost_equal(X_scaled.mean(axis=0), 0.0) - assert_array_almost_equal(X_scaled.std(axis=0), 1.0) + for fn, ScalerClass in ALL_SCALERS.items(): + scaler = ScalerClass() + scaler.fit_transform(X) + assert(len(scaler.scale_) == 3) - X = np.ones(5) - assert_array_equal(scale(X, with_mean=False), X) + scaler = StandardScaler().fit(X) + with warnings.catch_warnings(record=True): + assert(len(scaler.std_) == 3) # deprecated parameter + + +def test_scaler_matrix_copy_argument(): + '''Make sure the scalers respect the 'copy' argument on inputs.''' + + def test_impl(X, ScalerClass): + scaler = ScalerClass(copy=True).fit(X) + X_scaled = scaler.transform(X) + assert (X_scaled is not X) + X2 = scaler.inverse_transform(X_scaled) + assert (X_scaled is not X2) + + scaler = ScalerClass(copy=False).fit(X) + X_scaled = scaler.transform(X) + assert (X_scaled is X) + X2 = scaler.inverse_transform(X_scaled) + assert (X_scaled is X2) + + scaler = ScalerClass(copy=False).fit(X) + X_scaled = scaler.transform(X, copy=True) + assert (X_scaled is not X) + X2 = scaler.inverse_transform(X_scaled, copy=True) + assert (X_scaled is not X2) + + scaler = ScalerClass(copy=False).fit(X) + X_scaled = scaler.transform(X) + assert (X_scaled is X) + + rng = np.random.RandomState(42) + X = rng.randn(4, 5) + rng = np.random.RandomState(42) + for fn, ScalerClass in ALL_SCALERS.items(): + test_impl(X, ScalerClass) + + X = rng.randn(4, 5) + X[0, 0] = 0 + X_csr = sparse.csr_matrix(X) + for fn, ScalerClass in SPARSE_SCALERS.items(): + test_impl(X, ScalerClass) + + +@ignore_warnings +def test_scaler_int(): + '''Test that scaler converts integer input to float + (or at least that transform->inverse_transform works as it should)''' + rng = np.random.RandomState(42) + X = rng.randint(20, size=(4, 5)) + X[:, 0] = 0 # first feature is always of zero + X_csr = sparse.csr_matrix(X) + 
X_csc = sparse.csc_matrix(X) + + for fn, ScalerClass in ALL_SCALERS.items(): + try: + null_transform = ScalerClass(with_centering=False, + with_scaling=False, copy=True) + X_null = null_transform.fit_transform(X) + assert_array_equal(X_null, X) + except TypeError: + pass # some classes can't be initialized with above args + + with warnings.catch_warnings(record=True): + scaler = ScalerClass().fit(X) + X_scaled = scaler.transform(X, copy=True) + assert_false(np.any(np.isnan(X_scaled))) + + X_scaled_back = scaler.inverse_transform(X_scaled) + assert_array_almost_equal(X_scaled_back, X) + + for fn, ScalerClass in SPARSE_SCALERS.items(): + with warnings.catch_warnings(record=True): + scaler_csr = ScalerClass().fit(X_csr) + X_csr_scaled = scaler_csr.transform(X_csr, copy=True) + assert_false(np.any(np.isnan(X_csr_scaled.data))) + + with warnings.catch_warnings(record=True): + scaler_csc = ScalerClass().fit(X_csc) + X_csc_scaled = scaler_csc.transform(X_csc, copy=True) + assert_false(np.any(np.isnan(X_csc_scaled.data))) + X_csr_scaled_back = scaler_csr.inverse_transform(X_csr_scaled) + assert_array_almost_equal(X_csr_scaled_back.toarray(), X) -def test_scaler_2d_arrays(): + X_csc_scaled_back = scaler_csc.inverse_transform(X_csc_scaled.tocsc()) + assert_array_almost_equal(X_csc_scaled_back.toarray(), X) + + +def test_warning_scaling_integers(): + """Check warning when scaling integer data""" + X = np.array([[1, 2, 0], + [0, 0, 0]], dtype=np.uint8) + + for fn, ScalerClass in ALL_SCALERS.items(): + assert_warns(UserWarning, ScalerClass().fit, X) + + +def test_nonsparse_scaler_raise_exception_on_sparse(): + rng = np.random.RandomState(42) + X = rng.randn(4, 5) + X_csr = sparse.csr_matrix(X) + + nonsparse = set(ALL_SCALERS.keys()) - set(SPARSE_SCALERS.keys()) + for fn in nonsparse: + scaler = ALL_SCALERS[fn]() + + # some scalers don't except sparse matrices at all and will throw + # a TypeError, while others except them only under certain conditions + # and will throw a ValueError if that happens. 
Thus we test for + # BaseException + assert_raises(BaseException, scaler.fit, X_csr) + scaler.fit(X) + assert_raises(BaseException, scaler.transform, X_csr) + X_transformed_csr = sparse.csr_matrix(scaler.transform(X)) + assert_raises(BaseException, scaler.inverse_transform, + X_transformed_csr) + + +def test_standardscaler_2d_arrays(): """Test scaling of 2d array along first axis""" rng = np.random.RandomState(0) X = rng.randn(4, 5) @@ -108,45 +335,40 @@ def test_scaler_2d_arrays(): scaler = StandardScaler() X_scaled = scaler.fit(X).transform(X, copy=True) - assert_false(np.any(np.isnan(X_scaled))) assert_array_almost_equal(X_scaled.mean(axis=0), 5 * [0.0]) assert_array_almost_equal(X_scaled.std(axis=0), [0., 1., 1., 1., 1.]) - # Check that X has been copied - assert_true(X_scaled is not X) - # check inverse transform - X_scaled_back = scaler.inverse_transform(X_scaled) - assert_true(X_scaled_back is not X) - assert_true(X_scaled_back is not X_scaled) - assert_array_almost_equal(X_scaled_back, X) - - X_scaled = scale(X, axis=1, with_std=False) - assert_false(np.any(np.isnan(X_scaled))) + X_scaled = scale(X, axis=1, with_scaling=False) assert_array_almost_equal(X_scaled.mean(axis=1), 4 * [0.0]) - X_scaled = scale(X, axis=1, with_std=True) - assert_false(np.any(np.isnan(X_scaled))) + X_scaled = scale(X, axis=1, with_scaling=True) assert_array_almost_equal(X_scaled.mean(axis=1), 4 * [0.0]) assert_array_almost_equal(X_scaled.std(axis=1), 4 * [1.0]) - # Check that the data hasn't been modified - assert_true(X_scaled is not X) - X_scaled = scaler.fit(X).transform(X, copy=False) - assert_false(np.any(np.isnan(X_scaled))) - assert_array_almost_equal(X_scaled.mean(axis=0), 5 * [0.0]) - assert_array_almost_equal(X_scaled.std(axis=0), [0., 1., 1., 1., 1.]) - # Check that X has not been copied - assert_true(X_scaled is X) - X = rng.randn(4, 5) - X[:, 0] = 1.0 # first feature is a constant, non zero feature +def test_standard_scaler_zero_variance_features(): + """Check standard scaler on toy data with zero variance features""" + X = [[0., 1., +0.5], + [0., 1., -0.1], + [0., 1., +1.1]] scaler = StandardScaler() - X_scaled = scaler.fit(X).transform(X, copy=True) - assert_false(np.any(np.isnan(X_scaled))) - assert_array_almost_equal(X_scaled.mean(axis=0), 5 * [0.0]) - assert_array_almost_equal(X_scaled.std(axis=0), [0., 1., 1., 1., 1.]) - # Check that X has not been copied - assert_true(X_scaled is not X) + X_trans = scaler.fit_transform(X) + X_expected = [[0., 0., 0.], + [0., 0., -1.22474487], + [0., 0., 1.22474487]] + assert_array_almost_equal(X_trans, X_expected) + X_trans_inv = scaler.inverse_transform(X_trans) + assert_array_almost_equal(X, X_trans_inv, decimal=4) + + # make sure new data gets transformed correctly + X_new = [[+0., 2., 0.5], + [-1., 1., 0.0], + [+0., 1., 1.5]] + X_trans_new = scaler.transform(X_new) + X_expected_new = [[+0., 1., 0.], + [-1., 0., -1.02062073], + [+0., 0., 2.04124145]] + assert_array_almost_equal(X_trans_new, X_expected_new, decimal=4) def test_min_max_scaler_iris(): @@ -155,42 +377,54 @@ def test_min_max_scaler_iris(): # default params X_trans = scaler.fit_transform(X) assert_array_almost_equal(X_trans.min(axis=0), 0) - assert_array_almost_equal(X_trans.min(axis=0), 0) assert_array_almost_equal(X_trans.max(axis=0), 1) - X_trans_inv = scaler.inverse_transform(X_trans) - assert_array_almost_equal(X, X_trans_inv) # not default params: min=1, max=2 scaler = MinMaxScaler(feature_range=(1, 2)) X_trans = scaler.fit_transform(X) assert_array_almost_equal(X_trans.min(axis=0), 
1) assert_array_almost_equal(X_trans.max(axis=0), 2) - X_trans_inv = scaler.inverse_transform(X_trans) - assert_array_almost_equal(X, X_trans_inv) # min=-.5, max=.6 scaler = MinMaxScaler(feature_range=(-.5, .6)) X_trans = scaler.fit_transform(X) assert_array_almost_equal(X_trans.min(axis=0), -.5) assert_array_almost_equal(X_trans.max(axis=0), .6) - X_trans_inv = scaler.inverse_transform(X_trans) - assert_array_almost_equal(X, X_trans_inv) - # raises on invalid range + # minmax_scale function + X_trans = minmax_scale(X) + assert_array_almost_equal(X_trans.min(axis=0), 0) + assert_array_almost_equal(X_trans.max(axis=0), 1) + X_trans = minmax_scale(X, feature_range=(1, 2)) + assert_array_almost_equal(X_trans.min(axis=0), 1) + assert_array_almost_equal(X_trans.max(axis=0), 2) + X_trans = minmax_scale(X, feature_range=(-0.5, 0.6)) + assert_array_almost_equal(X_trans.min(axis=0), -0.5) + assert_array_almost_equal(X_trans.max(axis=0), 0.6) + + +def test_min_max_scaler_raise_invalid_range(): + '''Check if MinMaxScaler raises an error if range is invalid''' + X = [[0., 1., +0.5], + [0., 1., -0.1], + [0., 1., +1.1]] scaler = MinMaxScaler(feature_range=(2, 1)) assert_raises(ValueError, scaler.fit, X) + # TODO: for some reason assert_raise doesn't test this correctly + did_raise = False + try: + minmax_scale(X, feature_range=(2, 1)) + except ValueError: + did_raise = True + assert_true(did_raise) def test_min_max_scaler_zero_variance_features(): - """Check min max scaler on toy data with zero variance features""" + """Check MinMaxScaler on toy data with zero variance features""" X = [[0., 1., +0.5], [0., 1., -0.1], [0., 1., +1.1]] - X_new = [[+0., 2., 0.5], - [-1., 1., 0.0], - [+0., 1., 1.5]] - # default params scaler = MinMaxScaler() X_trans = scaler.fit_transform(X) @@ -201,6 +435,10 @@ def test_min_max_scaler_zero_variance_features(): X_trans_inv = scaler.inverse_transform(X_trans) assert_array_almost_equal(X, X_trans_inv) + # make sure new data gets transformed correctly + X_new = [[+0., 2., 0.5], + [-1., 1., 0.0], + [+0., 1., 1.5]] X_trans_new = scaler.transform(X_new) X_expected_0_1_new = [[+0., 1., 0.500], [-1., 0., 0.083], @@ -246,209 +484,78 @@ def test_min_max_scaler_1d(): assert_less_equal(X_scaled.max(), 1.) 
-def test_scaler_without_centering(): +def test_standardscaler_nulltransform(): rng = np.random.RandomState(42) X = rng.randn(4, 5) X[:, 0] = 0.0 # first feature is always of zero X_csr = sparse.csr_matrix(X) - X_csc = sparse.csc_matrix(X) - - assert_raises(ValueError, StandardScaler().fit, X_csr) - - null_transform = StandardScaler(with_mean=False, with_std=False, copy=True) + null_transform = StandardScaler(with_centering=False, + with_scaling=False, copy=True) X_null = null_transform.fit_transform(X_csr) assert_array_equal(X_null.data, X_csr.data) X_orig = null_transform.inverse_transform(X_null) assert_array_equal(X_orig.data, X_csr.data) - scaler = StandardScaler(with_mean=False).fit(X) - X_scaled = scaler.transform(X, copy=True) - assert_false(np.any(np.isnan(X_scaled))) - - scaler_csr = StandardScaler(with_mean=False).fit(X_csr) - X_csr_scaled = scaler_csr.transform(X_csr, copy=True) - assert_false(np.any(np.isnan(X_csr_scaled.data))) - - scaler_csc = StandardScaler(with_mean=False).fit(X_csc) - X_csc_scaled = scaler_csr.transform(X_csc, copy=True) - assert_false(np.any(np.isnan(X_csc_scaled.data))) - assert_equal(scaler.mean_, scaler_csr.mean_) - assert_array_almost_equal(scaler.std_, scaler_csr.std_) - assert_equal(scaler.mean_, scaler_csc.mean_) - assert_array_almost_equal(scaler.std_, scaler_csc.std_) - assert_array_almost_equal( - X_scaled.mean(axis=0), [0., -0.01, 2.24, -0.35, -0.78], 2) - assert_array_almost_equal(X_scaled.std(axis=0), [0., 1., 1., 1., 1.]) - - X_csr_scaled_mean, X_csr_scaled_std = mean_variance_axis0(X_csr_scaled) - assert_array_almost_equal(X_csr_scaled_mean, X_scaled.mean(axis=0)) - assert_array_almost_equal(X_csr_scaled_std, X_scaled.std(axis=0)) - # Check that X has not been modified (copy) - assert_true(X_scaled is not X) - assert_true(X_csr_scaled is not X_csr) - - X_scaled_back = scaler.inverse_transform(X_scaled) - assert_true(X_scaled_back is not X) - assert_true(X_scaled_back is not X_scaled) - assert_array_almost_equal(X_scaled_back, X) - X_csr_scaled_back = scaler_csr.inverse_transform(X_csr_scaled) - assert_true(X_csr_scaled_back is not X_csr) - assert_true(X_csr_scaled_back is not X_csr_scaled) - assert_array_almost_equal(X_csr_scaled_back.toarray(), X) - X_csc_scaled_back = scaler_csr.inverse_transform(X_csc_scaled.tocsc()) - assert_true(X_csc_scaled_back is not X_csc) - assert_true(X_csc_scaled_back is not X_csc_scaled) - assert_array_almost_equal(X_csc_scaled_back.toarray(), X) -def test_scaler_int(): - # test that scaler converts integer input to floating - # for both sparse and dense matrices rng = np.random.RandomState(42) - X = rng.randint(20, size=(4, 5)) - X[:, 0] = 0 # first feature is always of zero - X_csr = sparse.csr_matrix(X) - X_csc = sparse.csc_matrix(X) - - null_transform = StandardScaler(with_mean=False, with_std=False, copy=True) - with warnings.catch_warnings(record=True): - X_null = null_transform.fit_transform(X_csr) - assert_array_equal(X_null.data, X_csr.data) - X_orig = null_transform.inverse_transform(X_null) - assert_array_equal(X_orig.data, X_csr.data) - - with warnings.catch_warnings(record=True): - scaler = StandardScaler(with_mean=False).fit(X) - X_scaled = scaler.transform(X, copy=True) - assert_false(np.any(np.isnan(X_scaled))) - - with warnings.catch_warnings(record=True): - scaler_csr = StandardScaler(with_mean=False).fit(X_csr) - X_csr_scaled = scaler_csr.transform(X_csr, copy=True) - assert_false(np.any(np.isnan(X_csr_scaled.data))) - - with warnings.catch_warnings(record=True): - scaler_csc = 
StandardScaler(with_mean=False).fit(X_csc) - X_csc_scaled = scaler_csr.transform(X_csc, copy=True) - assert_false(np.any(np.isnan(X_csc_scaled.data))) - assert_equal(scaler.mean_, scaler_csr.mean_) - assert_array_almost_equal(scaler.std_, scaler_csr.std_) - assert_equal(scaler.mean_, scaler_csc.mean_) - assert_array_almost_equal(scaler.std_, scaler_csc.std_) - - assert_array_almost_equal( - X_scaled.mean(axis=0), - [0., 1.109, 1.856, 21., 1.559], 2) - assert_array_almost_equal(X_scaled.std(axis=0), [0., 1., 1., 1., 1.]) - - X_csr_scaled_mean, X_csr_scaled_std = mean_variance_axis0( - X_csr_scaled.astype(np.float)) - assert_array_almost_equal(X_csr_scaled_mean, X_scaled.mean(axis=0)) - assert_array_almost_equal(X_csr_scaled_std, X_scaled.std(axis=0)) - - # Check that X has not been modified (copy) - assert_true(X_scaled is not X) - assert_true(X_csr_scaled is not X_csr) - - X_scaled_back = scaler.inverse_transform(X_scaled) - assert_true(X_scaled_back is not X) - assert_true(X_scaled_back is not X_scaled) - assert_array_almost_equal(X_scaled_back, X) - - X_csr_scaled_back = scaler_csr.inverse_transform(X_csr_scaled) - assert_true(X_csr_scaled_back is not X_csr) - assert_true(X_csr_scaled_back is not X_csr_scaled) - assert_array_almost_equal(X_csr_scaled_back.toarray(), X) - - X_csc_scaled_back = scaler_csr.inverse_transform(X_csc_scaled.tocsc()) - assert_true(X_csc_scaled_back is not X_csc) - assert_true(X_csc_scaled_back is not X_csc_scaled) - assert_array_almost_equal(X_csc_scaled_back.toarray(), X) - - -def test_scaler_without_copy(): - """Check that StandardScaler.fit does not change input""" rng = np.random.RandomState(42) X = rng.randn(4, 5) - X[:, 0] = 0.0 # first feature is always of zero - X_csr = sparse.csr_matrix(X) - - X_copy = X.copy() - StandardScaler(copy=False).fit(X) - assert_array_equal(X, X_copy) - X_csr_copy = X_csr.copy() - StandardScaler(with_mean=False, copy=False).fit(X_csr) - assert_array_equal(X_csr.toarray(), X_csr_copy.toarray()) -def test_scale_sparse_with_mean_raise_exception(): - rng = np.random.RandomState(42) - X = rng.randn(4, 5) - X_csr = sparse.csr_matrix(X) - # check scaling and fit with direct calls on sparse data - assert_raises(ValueError, scale, X_csr, with_mean=True) - assert_raises(ValueError, StandardScaler(with_mean=True).fit, X_csr) - # check transform and inverse_transform after a fit on a dense array - scaler = StandardScaler(with_mean=True).fit(X) - assert_raises(ValueError, scaler.transform, X_csr) - X_transformed_csr = sparse.csr_matrix(scaler.transform(X)) - assert_raises(ValueError, scaler.inverse_transform, X_transformed_csr) -def test_scale_function_without_centering(): - rng = np.random.RandomState(42) - X = rng.randn(4, 5) - X[:, 0] = 0.0 # first feature is always of zero - X_csr = sparse.csr_matrix(X) - X_scaled = scale(X, with_mean=False) - assert_false(np.any(np.isnan(X_scaled))) - X_csr_scaled = scale(X_csr, with_mean=False) - assert_false(np.any(np.isnan(X_csr_scaled.data))) - # test csc has same outcome - X_csc_scaled = scale(X_csr.tocsc(), with_mean=False) - assert_array_almost_equal(X_scaled, X_csc_scaled.toarray()) +def toarray(a): + if hasattr(a, "toarray"): + a = a.toarray() + return a - # raises value error on axis != 0 - assert_raises(ValueError, scale, X_csr, with_mean=False, axis=1) - assert_array_almost_equal(X_scaled.mean(axis=0), - [0., -0.01, 2.24, -0.35, -0.78], 2) - assert_array_almost_equal(X_scaled.std(axis=0), [0., 1., 1., 1., 1.]) - # Check that X has not been copied - assert_true(X_scaled is not X) +def 
test_polynomial_features(): + """Test Polynomial Features""" + X1 = np.arange(6)[:, np.newaxis] + P1 = np.hstack([np.ones_like(X1), + X1, X1 ** 2, X1 ** 3]) + deg1 = 3 - X_csr_scaled_mean, X_csr_scaled_std = mean_variance_axis0(X_csr_scaled) - assert_array_almost_equal(X_csr_scaled_mean, X_scaled.mean(axis=0)) - assert_array_almost_equal(X_csr_scaled_std, X_scaled.std(axis=0)) + X2 = np.arange(6).reshape((3, 2)) + x1 = X2[:, :1] + x2 = X2[:, 1:] + P2 = np.hstack([x1 ** 0 * x2 ** 0, + x1 ** 1 * x2 ** 0, + x1 ** 0 * x2 ** 1, + x1 ** 2 * x2 ** 0, + x1 ** 1 * x2 ** 1, + x1 ** 0 * x2 ** 2]) + deg2 = 2 + for (deg, X, P) in [(deg1, X1, P1), (deg2, X2, P2)]: + P_test = PolynomialFeatures(deg, include_bias=True).fit_transform(X) + assert_array_almost_equal(P_test, P) -def test_warning_scaling_integers(): - """Check warning when scaling integer data""" - X = np.array([[1, 2, 0], - [0, 0, 0]], dtype=np.uint8) + P_test = PolynomialFeatures(deg, include_bias=False).fit_transform(X) + assert_array_almost_equal(P_test, P[:, 1:]) - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - assert_warns(UserWarning, StandardScaler().fit, X) + interact = PolynomialFeatures(2, interaction_only=True, include_bias=True) + X_poly = interact.fit_transform(X) + assert_array_almost_equal(X_poly, P2[:, [0, 1, 2, 4]]) - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - assert_warns(UserWarning, MinMaxScaler().fit, X) + assert_raises(ValueError, interact.transform, X[:, 1:]) def test_normalizer_l1(): @@ -611,7 +718,7 @@ def test_center_kernel(): in feature space""" rng = np.random.RandomState(0) X_fit = rng.random_sample((5, 4)) - scaler = StandardScaler(with_std=False) + scaler = StandardScaler(with_scaling=False) scaler.fit(X_fit) X_fit_centered = scaler.transform(X_fit) K_fit = np.dot(X_fit, X_fit.T) From d3903c33a931b8edbd1f2831db629674d811c7a0 Mon Sep 17 00:00:00 2001 From: Thomas Unterthiner Date: Mon, 21 Jul 2014 13:37:04 +0200 Subject: [PATCH 5/7] DOC MinMaxScaler / minmax_scale --- doc/modules/preprocessing.rst | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 1627d46295f5a..252a92fd45107 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -133,11 +133,8 @@ applied to be consistent with the transformation performed on the train data:: It is possible to introspect the scaler attributes to find about the exact nature of the transformation learned on the training data:: - >>> min_max_scaler.scale_ # doctest: +ELLIPSIS - array([ 0.5 , 0.5 , 0.33...]) - - >>> min_max_scaler.min_ # doctest: +ELLIPSIS - array([ 0. , 0.5 , 0.33...]) + >>> min_max_scaler.scale_ # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + array([ 2., 2., 3.]) If :class:`MinMaxScaler` is given an explicit ``feature_range=(min, max)`` the full formula is:: @@ -146,6 +143,10 @@ full formula is:: X_scaled = X_std / (max - min) + min +As with :func:`scale`, the ``preprocessing`` module further provides a +convenience function function :func:`minmax_scale` if you don't want to use +the `Transformer` API. + .. 
topic:: References: Further discussion on the importance of centering and scaling data is From 28df99fbf49b16da70fe0db50a81f23141ff61d2 Mon Sep 17 00:00:00 2001 From: Thomas Unterthiner Date: Mon, 21 Jul 2014 13:39:55 +0200 Subject: [PATCH 6/7] RobustScaler --- doc/modules/preprocessing.rst | 14 +- sklearn/preprocessing/__init__.py | 4 + sklearn/preprocessing/data.py | 169 +++++++++++++++++++++++ sklearn/preprocessing/tests/test_data.py | 72 ++++++++++ 4 files changed, 256 insertions(+), 3 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 252a92fd45107..9e008437e225d 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -105,7 +105,7 @@ An alternative standardization is scaling features to lie between a given minimum and maximum value, often between zero and one. This can be achieved using :class:`MinMaxScaler`. -The motivation to use this scaling include robustness to very small +The motivation to use such a scaling include robustness to very small standard deviations of features and preserving zero entries in sparse data. Here is an example to scale a toy data matrix to the ``[0, 1]`` range:: @@ -147,6 +147,15 @@ As with :func:`scale`, the ``preprocessing`` module further provides a convenience function function :func:`minmax_scale` if you don't want to use the `Transformer` API. +Scaling data with outliers +-------------------------- +If your data contains many outliers, scaling using the mean and variance +of the data does sometimes not work very well. In these cases, you can use +:func:`robust_scale` and :class:`RobustScaler` as drop-in replacements +instead, which use more robust estimates for the center and range of your +data. + + .. topic:: References: Further discussion on the importance of centering and scaling data is @@ -528,6 +537,5 @@ similarly. Note that if features have very different scaling or statistical properties, :class:`cluster.FeatureAgglomeration` maye not be able to - capture the links between related features. Using a + capture the links between related features. Using a :class:`preprocessing.StandardScaler` can be useful in these settings. 
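To make the new "Scaling data with outliers" section concrete, a minimal sketch of RobustScaler and robust_scale as defined later in this commit (assuming the patched tree; the values in the comments are approximate):

    import numpy as np
    from sklearn.preprocessing import RobustScaler, robust_scale

    rng = np.random.RandomState(0)
    X = rng.randn(200, 2)
    X[0] = [50., -50.]              # one gross outlier per feature

    scaler = RobustScaler()         # centers on the median, scales by the IQR
    X_t = scaler.fit_transform(X)
    print(scaler.center_)           # per-feature medians, barely moved by the outlier
    print(np.median(X_t, axis=0))   # ~[0. 0.]: the transformed data is centred on the median

    # robust_scale is the matching convenience function; interquartile_scale=1.0
    # divides by the raw interquartile range instead of the "normal" factor.
    X_t2 = robust_scale(X, interquartile_scale=1.0)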
- diff --git a/sklearn/preprocessing/__init__.py b/sklearn/preprocessing/__init__.py index f3a644b9546d6..e817232a87308 100644 --- a/sklearn/preprocessing/__init__.py +++ b/sklearn/preprocessing/__init__.py @@ -8,10 +8,12 @@ from .data import MinMaxScaler from .data import Normalizer from .data import StandardScaler +from .data import RobustScaler from .data import add_dummy_feature from .data import binarize from .data import normalize from .data import scale +from .data import robust_scale from .data import minmax_scale from .data import OneHotEncoder @@ -35,11 +37,13 @@ 'Normalizer', 'OneHotEncoder', 'StandardScaler', + 'RobustScaler', 'add_dummy_feature', 'PolynomialFeatures', 'binarize', 'normalize', 'scale', + 'robust_scale', 'minmax_scale', 'label_binarize', ] diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 74ef1a0a7a8d7..c3af495fa2925 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -11,6 +11,7 @@ from abc import ABCMeta, abstractmethod import numpy as np from scipy import sparse +from scipy.stats.mstats import mquantiles from ..base import BaseEstimator, TransformerMixin from ..externals import six @@ -37,10 +38,12 @@ 'Normalizer', 'OneHotEncoder', 'StandardScaler', + 'RobustScaler', 'add_dummy_feature', 'binarize', 'normalize', 'scale', + 'robust_scale', 'minmax_scale' ] @@ -396,11 +399,117 @@ def std_(self): return self.scale_ +class RobustScaler(BaseScaler): + """Standardize features by removing the median and scaling to IQR. + + Centering and scaling happen independently on each feature (or each + sample, depending on the `axis` argument) by computing the relevant + statistics on the samples in the training set. Median and interquartile + range are then stored to be used on later data using the `transform` + method. + + Standardization of a dataset is a common requirement for many + machine learning estimators. Typically this is done by removing the mean + and scaling to unit variance. However, outliers can often influence the + sample mean / variance in a negative way. In such cases, the median and + the interquartile range often give better results. + + This scaler uses `scipy.stats.mstats.mquantiles` with default parameters + to calculate the interquartile range. + + Parameters + ---------- + interquartile_scale: float or string in ["normal" (default), ], + The interquartile range is divided by this factor. If + `interquartile_scale` is "normal", the data is scaled so it + approximately reaches unit variance. This converge assumes Gaussian + input data and will need a large number of samples. + + with_centering : boolean, True by default + If True, center the data before scaling. + This does not work (and will raise an exception) when attempted on + sparse matrices, because centering them entails building a dense + matrix which in common use cases is likely to be too large to fit in + memory. + + with_scaling : boolean, True by default + If True, scale the data to interquartile range. + + copy : boolean, optional, default is True + If False, try to avoid a copy and do inplace scaling instead. + This is not guaranteed to always work inplace; e.g. if the data is + not a NumPy array or scipy.sparse CSR matrix, a copy may still be + returned. + + axis : int (0 by default) + axis used to compute the scaling statistics along. If 0, + independently scale each feature, otherwise (if 1) scale + each sample. 
+ + Attributes + ---------- + `center_` : array of floats + The median value for each feature in the training set, unless axis=1, + in which case it contains the median value for each sample + + `scale_` : array of floats + The (scaled) interquartile range for each feature in the training set, + unless axis=1, in which case it contains the median value for each + sample. + + See also + -------- + :class:`sklearn.preprocessing.StandardScaler` to perform centering + and scaling using mean and variance. + + :class:`sklearn.decomposition.RandomizedPCA` with `whiten=True` + to further remove the linear correlation across features. + """ + + def __init__(self, interquartile_scale="normal", with_centering=True, + with_scaling=True, copy=True, axis=0): + super(RobustScaler, self).__init__(with_centering=with_centering, + with_scaling=with_scaling, + copy=copy, axis=axis) + self.interquartile_scale = interquartile_scale + + def fit(self, X, y=None, copy=None): + """Compute the mean and std to be used for later scaling. + Parameters ---------- + X : array-like or CSR matrix with shape [n_samples, n_features] + The data used to compute the mean and standard deviation + used for later scaling along the features axis. """ if sparse.issparse(X): + raise TypeError("RobustScaler cannot be fitted on sparse inputs") + + if not np.isreal(self.interquartile_scale): + if self.interquartile_scale != "normal": + raise ValueError("Unknown interquartile_scale value.") + else: + iqr_scale = 1.34898 else: + iqr_scale = self.interquartile_scale + + if copy is None: + copy = self.copy + + self.center_ = None + self.scale_ = None + X = self._check_array(X, copy) + Xr = np.rollaxis(X, self.axis) + if self.with_centering: + self.center_ = np.median(Xr, axis=0) + + if self.with_scaling: + q = as_float_array(mquantiles(Xr, prob=(0.25, 0.75), axis=0)) + if len(q.shape) == 1: + q = q.reshape(-1, 1) + self.scale_ = (q[1, :] - q[0, :]) / iqr_scale + self.scale_ = self._handle_zeros_in_scale(self.scale_) + return self class PolynomialFeatures(BaseEstimator, TransformerMixin): @@ -577,6 +686,66 @@ def scale(X, axis=0, with_centering=True, with_scaling=True, copy=True, return s.fit_transform(X) +def robust_scale(X, interquartile_scale="normal", axis=0, with_centering=True, + with_scaling=True, copy=True): + """Standardize a dataset along any axis + + Center to the median and component wise scale + according to the interquartile range. + + Parameters + ---------- + X : array-like or CSR matrix. + The data to center and scale. + + interquartile_scale: float or string in ["normal" (default), ], + The interquartile range is divided by this factor. If + `interquartile_scale` is "normal", the data is scaled so it + approximately reaches unit variance. This converge assumes Gaussian + input data and will need a large number of samples. + + axis : int (0 by default) + axis used to compute the medians and IQR along. If 0, + independently scale each feature, otherwise (if 1) scale + each sample. + + with_centering : boolean, True by default + If True, center the data before scaling. + + with_scaling : boolean, True by default + If True, scale the data to unit variance (or equivalently, + unit standard deviation). + + copy : boolean, optional, default is True + set to False to perform inplace row normalization and avoid a + copy (if the input is already a numpy array or a scipy.sparse + CSR matrix and if axis is 1). 
+ + Notes + ----- + This implementation will refuse to center scipy.sparse matrices + since it would make them non-sparse and would potentially crash the + program with memory exhaustion problems. + + Instead the caller is expected to either set explicitly + `with_centering=False` (in that case, only variance scaling will be + performed on the features of the CSR matrix) or to call `X.toarray()` + if he/she expects the materialized dense array to fit in memory. + + To avoid memory copy the caller should pass a CSR matrix. + + See also + -------- + :class:`sklearn.preprocessing.RobustScaler` to perform centering and + scaling using the ``Transformer`` API (e.g. as part of a preprocessing + :class:`sklearn.pipeline.Pipeline`) + """ + s = RobustScaler(interquartile_scale=interquartile_scale, + with_centering=with_centering, with_scaling=with_scaling, + copy=copy, axis=axis) + return s.fit_transform(X) + + def minmax_scale(X, feature_range=(0, 1), axis=0, with_centering=True, with_scaling=True, copy=True): """Standardizes features by scaling each feature to a given range. diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 08ea2614e0a9a..7e118e7603eb8 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -27,6 +27,8 @@ from sklearn.preprocessing.data import OneHotEncoder from sklearn.preprocessing.data import StandardScaler from sklearn.preprocessing.data import scale +from sklearn.preprocessing.data import RobustScaler +from sklearn.preprocessing.data import robust_scale from sklearn.preprocessing.data import MinMaxScaler from sklearn.preprocessing.data import minmax_scale from sklearn.preprocessing.data import add_dummy_feature @@ -44,6 +46,7 @@ NON_SPARSE_SCALERS = { # only support dense input 'centered StandardScaler': partial(StandardScaler, with_centering=True), + 'RobustScaler': RobustScaler, 'MinMaxScaler[0,1]': partial(MinMaxScaler, feature_range=(0, 1)), 'MinMaxScaler[-1,1]': partial(MinMaxScaler, feature_range=(-1, 1)), 'MinMaxScaler[-3,5]': partial(MinMaxScaler, feature_range=(-3, 5)) @@ -57,6 +60,7 @@ SCALER_FUNCTIONS = { 'scale': scale, + 'robust_scale': robust_scale, 'minmax_scale': minmax_scale, } @@ -497,19 +501,87 @@ def test_standardscaler_nulltransform(): assert_array_equal(X_orig.data, X_csr.data) +def test_robust_scaler_2d_arrays(): + """Test robust scaling of 2d array along first axis""" + rng = np.random.RandomState(0) + X = rng.randn(4, 5) + X[:, 0] = 0.0 # first feature is always of zero + scaler = RobustScaler() + X_scaled = scaler.fit(X).transform(X, copy=True) + assert_array_almost_equal(np.median(X_scaled, axis=0), 5 * [0.0]) + assert_array_almost_equal(X_scaled.std(axis=0)[0], 0) +def test_robust_scaler_iris(): + X = iris.data + scaler = RobustScaler(interquartile_scale=1.0) + X_trans = scaler.fit_transform(X) + assert_array_almost_equal(np.median(X_trans, axis=0), 0) + X_trans_inv = scaler.inverse_transform(X_trans) + assert_array_almost_equal(X, X_trans_inv) + # make sure iqr is 1 + q = mquantiles(X_trans, prob=(0.25, 0.75), axis=0) + iqr = q[1, :] - q[0, :] + assert_array_almost_equal(iqr, 1) +def test_robust_scaler_iqr_scale(): + """Does iqr scaling achieve approximately std= 1 on Gaussian data?""" rng = np.random.RandomState(42) + X = rng.randn(10000, 4) # need lots of samples + scaler = RobustScaler() + X_trans = scaler.fit_transform(X) + assert_array_almost_equal(X_trans.std(axis=0), 1, decimal=2) +def test_robust_scale_iqr_errors(): + """Check that invalid 
arguments yield ValueError""" rng = np.random.RandomState(42) X = rng.randn(4, 5) + assert_raises(ValueError, RobustScaler(interquartile_scale="foo").fit, X) + # TODO: for some reason assert_raise doesn't test this correctly + did_raise = False + try: + robust_scale(X, interquartile_scale="foo") + except ValueError: + did_raise = True + assert(did_raise) + +def test_robust_scaler_zero_variance_features(): + """Check min max scaler on toy data with zero variance features""" + X = [[0., 1., +0.5], + [0., 1., -0.1], + [0., 1., +1.1]] + + scaler = RobustScaler(interquartile_scale=1.0) + X_trans = scaler.fit_transform(X) + + # NOTE: what we expect in the third column depends HEAVILY on the method + # used to calculate quantiles. The values here were calculated + # to fit the quantiles produces by scipy.stats.mstats.mquantiles' default + # quantile-method. Calculating quantiles with + # scipy.stats.mstats.scoreatquantile + # would yield very different results! + X_expected = [[0., 0., +0.0], + [0., 0., -0.625], + [0., 0., +0.625]] + assert_array_almost_equal(X_trans, X_expected) + X_trans_inv = scaler.inverse_transform(X_trans) + assert_array_almost_equal(X, X_trans_inv) + + # make sure new data gets transformed correctly + X_new = [[+0., 2., 0.5], + [-1., 1., 0.0], + [+0., 1., 1.5]] + X_trans_new = scaler.transform(X_new) + X_expected_new = [[+0., 1., +0.], + [-1., 0., -0.52083], + [+0., 0., +1.04166]] + assert_array_almost_equal(X_trans_new, X_expected_new, decimal=3) From b918e6627ec2329a4c4fd053ecd48af401ef025b Mon Sep 17 00:00:00 2001 From: Thomas Unterthiner Date: Mon, 21 Jul 2014 13:40:23 +0200 Subject: [PATCH 7/7] MAxAbsScaler --- doc/modules/preprocessing.rst | 82 ++++++++++++++++++------ sklearn/preprocessing/__init__.py | 4 ++ sklearn/preprocessing/data.py | 80 ++++++++++++++++++++++- sklearn/preprocessing/tests/test_data.py | 44 +++++++++++++ 4 files changed, 189 insertions(+), 21 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 9e008437e225d..b0a924d1268ed 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -102,8 +102,10 @@ of :class:`StandardScaler`. Scaling features to a range --------------------------- An alternative standardization is scaling features to -lie between a given minimum and maximum value, often between zero and one. -This can be achieved using :class:`MinMaxScaler`. +lie between a given minimum and maximum value, such as between zero and one, or +so that the maximum value of each feature is scaled to unit size. +This can be achieved using :class:`MinMaxScaler` or :class:`MaxAbsScaler`, +respectively. The motivation to use such a scaling include robustness to very small standard deviations of features and preserving zero entries in sparse data. @@ -147,6 +149,63 @@ As with :func:`scale`, the ``preprocessing`` module further provides a convenience function function :func:`minmax_scale` if you don't want to use the `Transformer` API. +:class:`MaxAbsScaler` works in a very similar fashion, but scales data so +it lies within the range ``[-1, 1]``, and is meant for data +that is already centered at zero. In particular, this scaler is very well +suited for sparse data. + +Here is how to use the toy data from the previous example with this scaler:: + + >>> X_train = np.array([[ 1., -1., 2.], + ... [ 2., 0., 0.], + ... [ 0., 1., -1.]]) + ... 
+ >>> max_abs_scaler = preprocessing.MaxAbsScaler() + >>> X_train_maxabs = max_abs_scaler.fit_transform(X_train) + >>> X_train_maxabs #doctest +NORMALIZE_WHITESPACE^ + array([[ 0.5, -1. , 1. ], + [ 1. , 0. , 0. ], + [ 0. , 1. , -0.5]]) + >>> X_test = np.array([[ -3., -1., 4.]]) + >>> X_test_maxabs = max_abs_scaler.transform(X_test) + >>> X_test_maxabs # doctest: +NORMALIZE_WHITESPACE + array([[-1.5, -1. , 2. ]]) + >>> max_abs_scaler.scale_ # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + array([ 2., 1., 2.]) + + +As with :func:`scale`, the ``preprocessing`` module further provides a +convenience function function :func:`maxabs_scale` if you don't want to use +the `Transformer` API. + + +Scaling sparse data +------------------- +Centering sparse data would destroy the sparseness structure in the data, and +thus rarely is a sensible thing to do. However, it can make sense to scale +sparse inputs, especially if features are on different scales. + +:class:`MaxAbsScaler` and :func:`maxabs_scale` were specifically designed +for scaling sparse data, and are the recommended way to go about this. +However, :func:`scale` and :class:`StandardScaler` can accept ``scipy.sparse`` +matrices as input, as long as ``with_centering=False`` is explicitly passed +to the constructor. Otherwise a ``ValueError`` will be raised as +silently centering would break the sparsity and would often crash the +execution by allocating excessive amounts of memory unintentionally. +:class:`RobustScaler` cannot be `fit`ted to sparse inputs, but you can use the +`transform` method on sparse inputs. + +Note that the scalers accept both Compressed Sparse Rows and Compressed +Sparse Columns format (see ``scipy.sparse.csr_matrix`` and +``scipy.sparse.csc_matrix``). Any other sparse input will be **converted to +the Compressed Sparse Rows representation**. To avoid unnecessary memory +copies, it is recommended to choose the CSR or CSC representation upstream. + +Finally, if the centered data is expected to be small enough, explicitly +converting the input to an array using the ``toarray`` method of sparse matrices +is another option. + + Scaling data with outliers -------------------------- If your data contains many outliers, scaling using the mean and variance @@ -172,26 +231,9 @@ data. or :class:`sklearn.decomposition.RandomizedPCA` with ``whiten=True`` to further remove the linear correlation across features. -.. topic:: Sparse input - - :func:`scale` and :class:`StandardScaler` accept ``scipy.sparse`` matrices - as input **only when with_mean=False is explicitly passed to the - constructor**. Otherwise a ``ValueError`` will be raised as - silently centering would break the sparsity and would often crash the - execution by allocating excessive amounts of memory unintentionally. - - If the centered data is expected to be small enough, explicitly convert - the input to an array using the ``toarray`` method of sparse matrices - instead. - - For sparse input the data is **converted to the Compressed Sparse Rows - representation** (see ``scipy.sparse.csr_matrix``). - To avoid unnecessary memory copies, it is recommended to choose the CSR - representation upstream. - .. topic:: Scaling target variables in regression - :func:`scale` and :class:`StandardScaler` work out-of-the-box with 1d arrays. + All scaling functions and classes work out-of-the-box with 1d arrays. This is very useful for scaling the target / response variables used for regression. 
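As a companion to the sparse-data discussion above, a short sketch of MaxAbsScaler applied to a CSR matrix (illustrative only, assuming the patched tree; the per-column scales match the doctest values in the section above):

    import numpy as np
    import scipy.sparse as sp
    from sklearn.preprocessing import MaxAbsScaler

    X = sp.csr_matrix(np.array([[ 1., -1.,  2.],
                                [ 2.,  0.,  0.],
                                [ 0.,  1., -1.]]))

    scaler = MaxAbsScaler()
    X_scaled = scaler.fit_transform(X)   # stays sparse; zero entries are untouched
    print(scaler.scale_)                 # -> [ 2.  1.  2.], the per-column max |x|
    print(X_scaled.toarray())            # every entry now lies in [-1, 1]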
diff --git a/sklearn/preprocessing/__init__.py b/sklearn/preprocessing/__init__.py index e817232a87308..faa11c9425e0e 100644 --- a/sklearn/preprocessing/__init__.py +++ b/sklearn/preprocessing/__init__.py @@ -6,6 +6,7 @@ from .data import Binarizer from .data import KernelCenterer from .data import MinMaxScaler +from .data import MaxAbsScaler from .data import Normalizer from .data import StandardScaler from .data import RobustScaler @@ -15,6 +16,7 @@ from .data import scale from .data import robust_scale from .data import minmax_scale +from .data import maxabs_scale from .data import OneHotEncoder from .data import PolynomialFeatures @@ -34,6 +36,7 @@ 'LabelEncoder', 'MultiLabelBinarizer', 'MinMaxScaler', + 'MaxAbsScaler', 'Normalizer', 'OneHotEncoder', 'StandardScaler', @@ -45,5 +48,6 @@ 'scale', 'robust_scale', 'minmax_scale', + 'maxabs_scale', 'label_binarize', ] diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index c3af495fa2925..4f63a30f53863 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -35,6 +35,7 @@ 'Binarizer', 'KernelCenterer', 'MinMaxScaler', + 'MaxAbsScaler', 'Normalizer', 'OneHotEncoder', 'StandardScaler', @@ -44,7 +45,8 @@ 'normalize', 'scale', 'robust_scale', - 'minmax_scale' + 'minmax_scale', + 'maxabs_scale' ] @@ -268,10 +270,60 @@ def min_(self): return self.center_ +class MaxAbsScaler(BaseScaler): + """Scale each feature to the [-1, 1] range without breaking the sparsity. + + This estimator scales and translates each feature individually such + that the maximal absolute value of each feature in the + training set will be 1.0. + + This scaler can also be applied to sparse CSR or CSC matrices. + + Parameters + ---------- + copy : boolean, optional, default is True + Set to False to perform inplace scaling and avoid a copy (if the input + is already a numpy array). + + axis : int (0 by default) + axis used to compute the scaling statistics along. If 0, + independently scale each feature, otherwise (if 1) scale + each sample. + + Attributes + ---------- + `scale_` : ndarray, shape (n_features,) + Per feature relative scaling of the data. + """ + + def __init__(self, copy=True, axis=0): + super(MaxAbsScaler, self).__init__(with_centering=False, + with_scaling=True, + copy=copy, axis=axis) + + def fit(self, X, y=None, copy=None): + """Compute the minimum and maximum to be used for later scaling. Parameters ---------- + X : array-like, shape [n_samples, n_features] + The data used to compute the per-feature minimum and maximum + used for later scaling along the features axis. """ + if copy is None: + copy = self.copy + + X = self._check_array(X, copy) + if sparse.issparse(X): + mins, maxs = min_max_axis(X, axis=self.axis) + scales = np.maximum(np.abs(mins), np.abs(maxs)) + else: + scales = np.abs(X).max(axis=self.axis) + scales = np.array(scales) + scales = scales.reshape(-1) + self.scale_ = self._handle_zeros_in_scale(scales) + self.center_ = np.zeros((len(self.scale_), ), dtype=self.scale_.dtype) + return self class StandardScaler(BaseScaler): @@ -791,6 +843,32 @@ def minmax_scale(X, feature_range=(0, 1), axis=0, with_centering=True, """ s = MinMaxScaler(feature_range=feature_range, copy=copy, axis=axis) return s.fit_transform(X) + + +def maxabs_scale(X, axis=0, copy=True): + """Standardizes features by scaling each feature. + + This estimator scales and translates each feature individually such + that the maximal absoulte value of each feature in the training set + will have be 1. 
+ + This function can also be applied to sparse CSR or CSC matrices. + + Parameters + ---------- + axis : int (0 by default) + axis used to compute the scaling statistics along. If 0, + independently scale each feature, otherwise (if 1) scale + each sample. + + copy : boolean, optional, default is True + Set to False to perform inplace row normalization and avoid a + copy (if the input is already a numpy array). + """ + s = MaxAbsScaler(copy=copy, axis=axis) + return s.fit_transform(X) + + def normalize(X, norm='l2', axis=1, copy=True): """Scale input vectors individually to unit norm (vector length). diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 7e118e7603eb8..e439d26a3a4c3 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -30,7 +30,9 @@ from sklearn.preprocessing.data import RobustScaler from sklearn.preprocessing.data import robust_scale from sklearn.preprocessing.data import MinMaxScaler +from sklearn.preprocessing.data import MaxAbsScaler from sklearn.preprocessing.data import minmax_scale +from sklearn.preprocessing.data import maxabs_scale from sklearn.preprocessing.data import add_dummy_feature from sklearn.preprocessing.data import PolynomialFeatures @@ -42,10 +44,12 @@ SPARSE_SCALERS = { # support sparse or dense input 'uncentered StandardScaler': partial(StandardScaler, with_centering=False), + 'MaxAbsScaler': MaxAbsScaler } NON_SPARSE_SCALERS = { # only support dense input 'centered StandardScaler': partial(StandardScaler, with_centering=True), + 'MaxAbsScaler': MaxAbsScaler, 'RobustScaler': RobustScaler, 'MinMaxScaler[0,1]': partial(MinMaxScaler, feature_range=(0, 1)), 'MinMaxScaler[-1,1]': partial(MinMaxScaler, feature_range=(-1, 1)), @@ -62,6 +66,7 @@ 'scale': scale, 'robust_scale': robust_scale, 'minmax_scale': minmax_scale, + 'maxabs_scale': maxabs_scale } @@ -584,12 +589,50 @@ def test_robust_scaler_zero_variance_features(): assert_array_almost_equal(X_trans_new, X_expected_new, decimal=3) +def test_maxabs_scaler_zero_variance_features(): + """Check MaxAbsScaler on toy data with zero variance features""" + X = [[0., 1., +0.5], + [0., 1., -0.3], + [0., 1., +1.5], + [0., 0., +0.0]] + # default params + scaler = MaxAbsScaler() + X_trans = scaler.fit_transform(X) + X_expected = [[0., 1., 1.0 / 3.0], + [0., 1., -0.2], + [0., 1., 1.0], + [0., 0., 0.0]] + assert_array_almost_equal(X_trans, X_expected) + X_trans_inv = scaler.inverse_transform(X_trans) + assert_array_almost_equal(X, X_trans_inv) + # make sure new data gets transformed correctly + X_new = [[+0., 2., 0.5], + [-1., 1., 0.0], + [+0., 1., 1.5]] + X_trans_new = scaler.transform(X_new) + X_expected_new = [[+0., 2.0, 1.0 / 3.0], + [-1., 1.0, 0.0], + [+0., 1.0, 1.0]] + assert_array_almost_equal(X_trans_new, X_expected_new, decimal=2) +def test_maxabs_scaler_large_negative_value(): + """Check MaxAbsScaler on toy data with a large negative value""" + X = [[0., 1., +0.5, -1.0], + [0., 1., -0.3, -0.5], + [0., 1., -100.0, 0.0], + [0., 0., +0.0, -2.0]] + scaler = MaxAbsScaler() + X_trans = scaler.fit_transform(X) + X_expected = [[0., 1., 0.005, -0.5], + [0., 1., -0.003, -0.25], + [0., 1., -1.0, 0.0], + [0., 0., 0.0, -1.0]] + assert_array_almost_equal(X_trans, X_expected) def toarray(a): @@ -899,6 +942,7 @@ def test_one_hot_encoder_sparse(): enc.fit([[0], [1]]) assert_raises(ValueError, enc.transform, [[0], [-1]]) + def test_one_hot_encoder_dense(): """check for sparse=False""" X = [[3, 2, 1], [0, 1, 1]]
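Finally, a hedged sketch of the maxabs_scale convenience function on the zero-variance toy data used in the new test above (assuming the patched sklearn.preprocessing exports it; the expected output is taken from test_maxabs_scaler_zero_variance_features):

    import numpy as np
    from sklearn.preprocessing import maxabs_scale

    X = np.array([[0., 1.,  0.5],
                  [0., 1., -0.3],
                  [0., 1.,  1.5],
                  [0., 0.,  0.0]])

    X_t = maxabs_scale(X)
    print(X_t)
    # Expected, per the test above:
    #   [[ 0.     1.     0.333]
    #    [ 0.     1.    -0.2  ]
    #    [ 0.     1.     1.   ]
    #    [ 0.     0.     0.   ]]
    # The all-zero first column is left untouched instead of producing NaNs.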