diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index f217d7848c9b2..5a1d5fcda6d6e 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -656,7 +656,8 @@ Kernels: impute.SimpleImputer impute.ChainedImputer - + impute.MissingIndicator + .. _kernel_approximation_ref: :mod:`sklearn.kernel_approximation` Kernel Approximation diff --git a/doc/modules/impute.rst b/doc/modules/impute.rst index 0f9089c981782..84c8538f1f05b 100644 --- a/doc/modules/impute.rst +++ b/doc/modules/impute.rst @@ -121,7 +121,6 @@ Both :class:`SimpleImputer` and :class:`ChainedImputer` can be used in a Pipelin as a way to build a composite estimator that supports imputation. See :ref:`sphx_glr_auto_examples_plot_missing_values.py`. - .. _multiple_imputation: Multiple vs. Single Imputation @@ -142,3 +141,49 @@ random seeds with the ``n_imputations`` parameter set to 1. Note that a call to the ``transform`` method of :class:`ChainedImputer` is not allowed to change the number of samples. Therefore multiple imputations cannot be achieved by a single call to ``transform``. + +.. _missing_indicator: + +Marking imputed values +====================== + +The :class:`MissingIndicator` transformer is useful to transform a dataset into +a corresponding binary matrix indicating the presence of missing values in the +dataset. This transformation is useful in conjunction with imputation. When +using imputation, preserving the information about which values had been +missing can be informative. + +``NaN`` is usually used as the placeholder for missing values. However, it +enforces the data type to be float. The parameter ``missing_values`` allows to +specify other placeholders such as integers. In the following example, we will +use ``-1`` as missing values:: + + >>> from sklearn.impute import MissingIndicator + >>> X = np.array([[-1, -1, 1, 3], + ... [4, -1, 0, -1], + ...
[8, -1, 1, 0]]) + >>> indicator = MissingIndicator(missing_values=-1) + >>> mask_missing_values_only = indicator.fit_transform(X) + >>> mask_missing_values_only + array([[ True, True, False], + [False, True, True], + [False, True, False]]) + +The ``features`` parameter is used to choose the features for which the mask is +constructed. By default, it is ``'missing-only'`` which returns the imputer +mask of the features containing missing values at ``fit`` time:: + + >>> indicator.features_ + array([0, 1, 3]) + +The ``features`` parameter can be set to ``'all'`` to return all features +whether or not they contain missing values:: + + >>> indicator = MissingIndicator(missing_values=-1, features="all") + >>> mask_all = indicator.fit_transform(X) + >>> mask_all + array([[ True, True, False, False], + [False, True, False, True], + [False, True, False, False]]) + >>> indicator.features_ + array([0, 1, 2, 3]) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index d9aa052d740db..cfb7835397e38 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -149,6 +149,10 @@ Preprocessing back to the original space via an inverse transform. :issue:`9041` by `Andreas Müller`_ and :user:`Guillaume Lemaitre <glemaitre>`. +- Added :class:`impute.MissingIndicator` which generates a binary indicator for + missing values. :issue:`8075` by :user:`Maniteja Nandana <maniteja123>` and + :user:`Guillaume Lemaitre <glemaitre>`. + - Added :class:`impute.ChainedImputer`, which is a strategy for imputing missing values by modeling each feature with missing values as a function of other features in a round-robin fashion.
:issue:`8478` by diff --git a/examples/plot_missing_values.py b/examples/plot_missing_values.py index d238a16592edb..8cd20087dfb0f 100644 --- a/examples/plot_missing_values.py +++ b/examples/plot_missing_values.py @@ -4,15 +4,19 @@ ==================================================== Missing values can be replaced by the mean, the median or the most frequent -value using the basic ``SimpleImputer``. +value using the basic :class:`sklearn.impute.SimpleImputer`. The median is a more robust estimator for data with high magnitude variables which could dominate results (otherwise known as a 'long tail'). -Another option is the ``ChainedImputer``. This uses round-robin linear -regression, treating every variable as an output in turn. The version -implemented assumes Gaussian (output) variables. If your features are obviously -non-Normal, consider transforming them to look more Normal so as to improve -performance. +Another option is the :class:`sklearn.impute.ChainedImputer`. This uses +round-robin linear regression, treating every variable as an output in +turn. The version implemented assumes Gaussian (output) variables. If your +features are obviously non-Normal, consider transforming them to look more +Normal so as to improve performance. + +In addition to using an imputing method, we can also keep an indication of the +missing information using :class:`sklearn.impute.MissingIndicator` which might +carry some information.
""" import numpy as np @@ -21,8 +25,8 @@ from sklearn.datasets import load_diabetes from sklearn.datasets import load_boston from sklearn.ensemble import RandomForestRegressor -from sklearn.pipeline import Pipeline -from sklearn.impute import SimpleImputer, ChainedImputer +from sklearn.pipeline import make_pipeline, make_union +from sklearn.impute import SimpleImputer, ChainedImputer, MissingIndicator from sklearn.model_selection import cross_val_score rng = np.random.RandomState(0) @@ -60,18 +64,18 @@ def get_results(dataset): X_missing = X_full.copy() X_missing[np.where(missing_samples)[0], missing_features] = 0 y_missing = y_full.copy() - estimator = Pipeline([("imputer", SimpleImputer(missing_values=0, - strategy="mean")), - ("forest", RandomForestRegressor(random_state=0, - n_estimators=100))]) + estimator = make_pipeline( + make_union(SimpleImputer(missing_values=0, strategy="mean"), + MissingIndicator(missing_values=0)), + RandomForestRegressor(random_state=0, n_estimators=100)) mean_impute_scores = cross_val_score(estimator, X_missing, y_missing, scoring='neg_mean_squared_error') # Estimate the score after chained imputation of the missing values - estimator = Pipeline([("imputer", ChainedImputer(missing_values=0, - random_state=0)), - ("forest", RandomForestRegressor(random_state=0, - n_estimators=100))]) + estimator = make_pipeline( + make_union(ChainedImputer(missing_values=0, random_state=0), + MissingIndicator(missing_values=0)), + RandomForestRegressor(random_state=0, n_estimators=100)) chained_impute_scores = cross_val_score(estimator, X_missing, y_missing, scoring='neg_mean_squared_error') diff --git a/sklearn/impute.py b/sklearn/impute.py index 72dd1ac5c24ca..fec9d8b0d7a8d 100644 --- a/sklearn/impute.py +++ b/sklearn/impute.py @@ -35,6 +35,7 @@ 'predictor']) __all__ = [ + 'MissingIndicator', 'SimpleImputer', 'ChainedImputer', ] @@ -975,3 +976,225 @@ def fit(self, X, y=None): """ self.fit_transform(X) return self + + +class 
MissingIndicator(BaseEstimator, TransformerMixin): + """Binary indicators for missing values. + + Parameters + ---------- + missing_values : number, string, np.nan (default) or None + The placeholder for the missing values. All occurrences of + `missing_values` will be indicated. + + features : str, optional + Whether the imputer mask should represent all or a subset of + features. + + - If "missing-only" (default), the imputer mask will only represent + features containing missing values during fit time. + - If "all", the imputer mask will represent all features. + + sparse : boolean or "auto", optional + Whether the imputer mask format should be sparse or dense. + + - If "auto" (default), the imputer mask will be of same type as + input. + - If True, the imputer mask will be a sparse matrix. + - If False, the imputer mask will be a numpy array. + + error_on_new : boolean, optional + If True (default), transform will raise an error when there are + features with missing values in transform that have no missing values + in fit. This is applicable only when ``features="missing-only"``. + + Attributes + ---------- + features_ : ndarray, shape (n_missing_features,) or (n_features,) + The features indices which will be returned when calling ``transform``. + They are computed during ``fit``. For ``features='all'``, it is + equal to ``range(n_features)``. + + Examples + -------- + >>> import numpy as np + >>> from sklearn.impute import MissingIndicator + >>> X1 = np.array([[np.nan, 1, 3], + ... [4, 0, np.nan], + ... [8, 1, 0]]) + >>> X2 = np.array([[5, 1, np.nan], + ... [np.nan, 2, 3], + ...
[2, 4, 0]]) + >>> indicator = MissingIndicator() + >>> indicator.fit(X1) + MissingIndicator(error_on_new=True, features='missing-only', + missing_values=nan, sparse='auto') + >>> X2_tr = indicator.transform(X2) + >>> X2_tr + array([[False, True], + [ True, False], + [False, False]]) + + """ + + def __init__(self, missing_values=np.nan, features="missing-only", + sparse="auto", error_on_new=True): + self.missing_values = missing_values + self.features = features + self.sparse = sparse + self.error_on_new = error_on_new + + def _get_missing_features_info(self, X): + """Compute the imputer mask and the indices of the features + containing missing values. + + Parameters + ---------- + X : {ndarray or sparse matrix}, shape (n_samples, n_features) + The input data with missing values. Note that ``X`` has been + checked in ``fit`` and ``transform`` before calling this function. + + Returns + ------- + imputer_mask : {ndarray or sparse matrix}, shape \ +(n_samples, n_features) or (n_samples, n_features_with_missing) + The imputer mask of the original data. + + features_with_missing : ndarray, shape (n_features_with_missing) + The features containing missing values. + + """ + if sparse.issparse(X) and self.missing_values != 0: + mask = _get_mask(X.data, self.missing_values) + + # The imputer mask will be constructed with the same sparse format + # as X.
+ sparse_constructor = (sparse.csr_matrix if X.format == 'csr' + else sparse.csc_matrix) + imputer_mask = sparse_constructor( + (mask, X.indices.copy(), X.indptr.copy()), + shape=X.shape, dtype=bool) + + missing_values_mask = imputer_mask.copy() + missing_values_mask.eliminate_zeros() + features_with_missing = ( + np.flatnonzero(np.diff(missing_values_mask.indptr)) + if missing_values_mask.format == 'csc' + else np.unique(missing_values_mask.indices)) + + if self.sparse is False: + imputer_mask = imputer_mask.toarray() + elif imputer_mask.format == 'csr': + imputer_mask = imputer_mask.tocsc() + else: + if sparse.issparse(X): + # case of sparse matrix with 0 as missing values. Implicit and + # explicit zeros are considered as missing values. + X = X.toarray() + imputer_mask = _get_mask(X, self.missing_values) + features_with_missing = np.flatnonzero(imputer_mask.sum(axis=0)) + + if self.sparse is True: + imputer_mask = sparse.csc_matrix(imputer_mask) + + return imputer_mask, features_with_missing + + def fit(self, X, y=None): + """Fit the transformer on X. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Input data, where ``n_samples`` is the number of samples and + ``n_features`` is the number of features. + + Returns + ------- + self : object + Returns self. + """ + if not is_scalar_nan(self.missing_values): + force_all_finite = True + else: + force_all_finite = "allow-nan" + X = check_array(X, accept_sparse=('csc', 'csr'), + force_all_finite=force_all_finite) + _check_inputs_dtype(X, self.missing_values) + + self._n_features = X.shape[1] + + if self.features not in ('missing-only', 'all'): + raise ValueError("'features' has to be either 'missing-only' or " + "'all'. Got {} instead.".format(self.features)) + + if not ((isinstance(self.sparse, six.string_types) and + self.sparse == "auto") or isinstance(self.sparse, bool)): + raise ValueError("'sparse' has to be a boolean or 'auto'. 
" + "Got {!r} instead.".format(self.sparse)) + + self.features_ = (self._get_missing_features_info(X)[1] + if self.features == 'missing-only' + else np.arange(self._n_features)) + + return self + + def transform(self, X): + """Generate missing values indicator for X. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + The input data to complete. + + Returns + ------- + Xt : {ndarray or sparse matrix}, shape (n_samples, n_features) + The missing indicator for input data. The data type of ``Xt`` + will be boolean. + + """ + check_is_fitted(self, "features_") + + if not is_scalar_nan(self.missing_values): + force_all_finite = True + else: + force_all_finite = "allow-nan" + X = check_array(X, accept_sparse=('csc', 'csr'), + force_all_finite=force_all_finite) + _check_inputs_dtype(X, self.missing_values) + + if X.shape[1] != self._n_features: + raise ValueError("X has a different number of features " + "than during fitting.") + + imputer_mask, features = self._get_missing_features_info(X) + + if self.features == "missing-only": + features_diff_fit_trans = np.setdiff1d(features, self.features_) + if (self.error_on_new and features_diff_fit_trans.size > 0): + raise ValueError("The features {} have missing values " + "in transform but have no missing values " + "in fit.".format(features_diff_fit_trans)) + + if (self.features_.size > 0 and + self.features_.size < self._n_features): + imputer_mask = imputer_mask[:, self.features_] + + return imputer_mask + + def fit_transform(self, X, y=None): + """Generate missing values indicator for X. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + The input data to complete. + + Returns + ------- + Xt : {ndarray or sparse matrix}, shape (n_samples, n_features) + The missing indicator for input data. The data type of ``Xt`` + will be boolean. 
+ + """ + return self.fit(X, y).transform(X) diff --git a/sklearn/tests/test_impute.py b/sklearn/tests/test_impute.py index b286c5006d431..7fb1b0ac3280b 100644 --- a/sklearn/tests/test_impute.py +++ b/sklearn/tests/test_impute.py @@ -13,6 +13,7 @@ from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_false +from sklearn.impute import MissingIndicator from sklearn.impute import SimpleImputer, ChainedImputer from sklearn.dummy import DummyRegressor from sklearn.linear_model import BayesianRidge, ARDRegression @@ -707,6 +708,121 @@ def test_chained_imputer_additive_matrix(): assert_allclose(X_test_filled, X_test_est, atol=0.01) +@pytest.mark.parametrize( + "X_fit, X_trans, params, msg_err", + [(np.array([[-1, 1], [1, 2]]), np.array([[-1, 1], [1, -1]]), + {'features': 'missing-only', 'sparse': 'auto'}, + 'have missing values in transform but have no missing values in fit'), + (np.array([[-1, 1], [1, 2]]), np.array([[-1, 1], [1, 2]]), + {'features': 'random', 'sparse': 'auto'}, + "'features' has to be either 'missing-only' or 'all'"), + (np.array([[-1, 1], [1, 2]]), np.array([[-1, 1], [1, 2]]), + {'features': 'all', 'sparse': 'random'}, + "'sparse' has to be a boolean or 'auto'")] +) +def test_missing_indicator_error(X_fit, X_trans, params, msg_err): + indicator = MissingIndicator(missing_values=-1) + indicator.set_params(**params) + with pytest.raises(ValueError, match=msg_err): + indicator.fit(X_fit).transform(X_trans) + + +@pytest.mark.parametrize( + "missing_values, dtype", + [(np.nan, np.float64), + (0, np.int32), + (-1, np.int32)]) +@pytest.mark.parametrize( + "arr_type", + [np.array, sparse.csc_matrix, sparse.csr_matrix, sparse.coo_matrix, + sparse.lil_matrix, sparse.bsr_matrix]) +@pytest.mark.parametrize( + "param_features, n_features, features_indices", + [('missing-only', 2, np.array([0, 1])), + ('all', 3, np.array([0, 1, 2]))]) +def test_missing_indicator_new(missing_values, arr_type, dtype, param_features, + 
n_features, features_indices): + X_fit = np.array([[missing_values, missing_values, 1], + [4, missing_values, 2]]) + X_trans = np.array([[missing_values, missing_values, 1], + [4, 12, 10]]) + X_fit_expected = np.array([[1, 1, 0], [0, 1, 0]]) + X_trans_expected = np.array([[1, 1, 0], [0, 0, 0]]) + + # convert the input to the right array format and right dtype + X_fit = arr_type(X_fit).astype(dtype) + X_trans = arr_type(X_trans).astype(dtype) + X_fit_expected = X_fit_expected.astype(dtype) + X_trans_expected = X_trans_expected.astype(dtype) + + indicator = MissingIndicator(missing_values=missing_values, + features=param_features, + sparse=False) + X_fit_mask = indicator.fit_transform(X_fit) + X_trans_mask = indicator.transform(X_trans) + + assert X_fit_mask.shape[1] == n_features + assert X_trans_mask.shape[1] == n_features + + assert_array_equal(indicator.features_, features_indices) + assert_allclose(X_fit_mask, X_fit_expected[:, features_indices]) + assert_allclose(X_trans_mask, X_trans_expected[:, features_indices]) + + assert X_fit_mask.dtype == bool + assert X_trans_mask.dtype == bool + assert isinstance(X_fit_mask, np.ndarray) + assert isinstance(X_trans_mask, np.ndarray) + + indicator.set_params(sparse=True) + X_fit_mask_sparse = indicator.fit_transform(X_fit) + X_trans_mask_sparse = indicator.transform(X_trans) + + assert X_fit_mask_sparse.dtype == bool + assert X_trans_mask_sparse.dtype == bool + assert X_fit_mask_sparse.format == 'csc' + assert X_trans_mask_sparse.format == 'csc' + assert_allclose(X_fit_mask_sparse.toarray(), X_fit_mask) + assert_allclose(X_trans_mask_sparse.toarray(), X_trans_mask) + + +@pytest.mark.parametrize("param_sparse", [True, False, 'auto']) +@pytest.mark.parametrize("missing_values", [np.nan, 0]) +@pytest.mark.parametrize( + "arr_type", + [np.array, sparse.csc_matrix, sparse.csr_matrix, sparse.coo_matrix]) +def test_missing_indicator_sparse_param(arr_type, missing_values, + param_sparse): + # check the format of the output with 
different sparse parameter + X_fit = np.array([[missing_values, missing_values, 1], + [4, missing_values, 2]]) + X_trans = np.array([[missing_values, missing_values, 1], + [4, 12, 10]]) + X_fit = arr_type(X_fit).astype(np.float64) + X_trans = arr_type(X_trans).astype(np.float64) + + indicator = MissingIndicator(missing_values=missing_values, + sparse=param_sparse) + X_fit_mask = indicator.fit_transform(X_fit) + X_trans_mask = indicator.transform(X_trans) + + if param_sparse is True: + assert X_fit_mask.format == 'csc' + assert X_trans_mask.format == 'csc' + elif param_sparse == 'auto' and missing_values == 0: + assert isinstance(X_fit_mask, np.ndarray) + assert isinstance(X_trans_mask, np.ndarray) + elif param_sparse is False: + assert isinstance(X_fit_mask, np.ndarray) + assert isinstance(X_trans_mask, np.ndarray) + else: + if sparse.issparse(X_fit): + assert X_fit_mask.format == 'csc' + assert X_trans_mask.format == 'csc' + else: + assert isinstance(X_fit_mask, np.ndarray) + assert isinstance(X_trans_mask, np.ndarray) + + @pytest.mark.parametrize("imputer_constructor", [SimpleImputer, ChainedImputer]) @pytest.mark.parametrize( diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index d25abbe6377db..65112ad9d382e 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -77,7 +77,7 @@ 'RANSACRegressor', 'RadiusNeighborsRegressor', 'RandomForestRegressor', 'Ridge', 'RidgeCV'] -ALLOW_NAN = ['Imputer', 'SimpleImputer', 'ChainedImputer', +ALLOW_NAN = ['Imputer', 'SimpleImputer', 'ChainedImputer', 'MissingIndicator', 'MaxAbsScaler', 'MinMaxScaler', 'RobustScaler', 'StandardScaler', 'PowerTransformer', 'QuantileTransformer']