diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst
index 0e84202f876e4..9f53afd433ffc 100644
--- a/doc/whats_new/v1.5.rst
+++ b/doc/whats_new/v1.5.rst
@@ -56,6 +56,17 @@ Changed models
   signs across all `PCA` solvers, including the new
   `svd_solver="covariance_eigh"` option introduced in this release.
 
+Changes impacting many modules
+------------------------------
+
+- |API| The name of the input of the `inverse_transform` method of estimators has been
+  standardized to `X`. As a consequence, `Xt` is deprecated and will be removed in
+  version 1.7 in the following estimators: :class:`cluster.FeatureAgglomeration`,
+  :class:`decomposition.MiniBatchNMF`, :class:`decomposition.NMF`,
+  :class:`model_selection.GridSearchCV`, :class:`model_selection.RandomizedSearchCV`,
+  :class:`pipeline.Pipeline` and :class:`preprocessing.KBinsDiscretizer`.
+  :pr:`28756` by :user:`Will Dean <wd60622>`.
+
 Support for Array API
 ---------------------
 
diff --git a/sklearn/cluster/_feature_agglomeration.py b/sklearn/cluster/_feature_agglomeration.py
index 218db48ad2331..c91952061a6f6 100644
--- a/sklearn/cluster/_feature_agglomeration.py
+++ b/sklearn/cluster/_feature_agglomeration.py
@@ -6,13 +6,13 @@
 # Author: V. Michel, A. Gramfort
 # License: BSD 3 clause
 
-import warnings
 
 import numpy as np
 from scipy.sparse import issparse
 
 from ..base import TransformerMixin
 from ..utils import metadata_routing
+from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
 from ..utils.validation import check_is_fitted
 
 ###############################################################################
@@ -25,9 +25,9 @@ class AgglomerationTransform(TransformerMixin):
     """
 
     # This prevents ``set_split_inverse_transform`` to be generated for the
-    # non-standard ``Xred`` arg on ``inverse_transform``.
-    # TODO(1.5): remove when Xred is removed for inverse_transform.
-    __metadata_request__inverse_transform = {"Xred": metadata_routing.UNUSED}
+    # non-standard ``Xt`` arg on ``inverse_transform``.
+    # TODO(1.7): remove when Xt is removed for inverse_transform.
+    __metadata_request__inverse_transform = {"Xt": metadata_routing.UNUSED}
 
     def transform(self, X):
         """
@@ -63,19 +63,20 @@ def transform(self, X):
             nX = np.array(nX).T
         return nX
 
-    def inverse_transform(self, Xt=None, Xred=None):
+    def inverse_transform(self, X=None, *, Xt=None):
         """
         Inverse the transformation and return a vector of size `n_features`.
 
         Parameters
         ----------
-        Xt : array-like of shape (n_samples, n_clusters) or (n_clusters,)
+        X : array-like of shape (n_samples, n_clusters) or (n_clusters,)
             The values to be assigned to each cluster of samples.
 
-        Xred : deprecated
-            Use `Xt` instead.
+        Xt : array-like of shape (n_samples, n_clusters) or (n_clusters,)
+            The values to be assigned to each cluster of samples.
 
-            .. deprecated:: 1.3
+            .. deprecated:: 1.5
+                `Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
 
         Returns
         -------
@@ -83,23 +84,9 @@ def inverse_transform(self, Xt=None, Xred=None):
-            A vector of size `n_samples` with the values of `Xred` assigned to
+            A vector of size `n_samples` with the values of `X` assigned to
             each of the cluster of samples.
         """
-        if Xt is None and Xred is None:
-            raise TypeError("Missing required positional argument: Xt")
-
-        if Xred is not None and Xt is not None:
-            raise ValueError("Please provide only `Xt`, and not `Xred`.")
-
-        if Xred is not None:
-            warnings.warn(
-                (
-                    "Input argument `Xred` was renamed to `Xt` in v1.3 and will be"
-                    " removed in v1.5."
-                ),
-                FutureWarning,
-            )
-            Xt = Xred
+        X = _deprecate_Xt_in_inverse_transform(X, Xt)
 
         check_is_fitted(self)
 
         unil, inverse = np.unique(self.labels_, return_inverse=True)
-        return Xt[..., inverse]
+        return X[..., inverse]
diff --git a/sklearn/cluster/tests/test_feature_agglomeration.py b/sklearn/cluster/tests/test_feature_agglomeration.py
index abeb81dca50aa..488dd638ad125 100644
--- a/sklearn/cluster/tests/test_feature_agglomeration.py
+++ b/sklearn/cluster/tests/test_feature_agglomeration.py
@@ -59,23 +59,23 @@ def test_feature_agglomeration_feature_names_out():
     )
 
 
-# TODO(1.5): remove this test
-def test_inverse_transform_Xred_deprecation():
+# TODO(1.7): remove this test
+def test_inverse_transform_Xt_deprecation():
     X = np.array([0, 0, 1]).reshape(1, 3)  # (n_samples, n_features)
 
     est = FeatureAgglomeration(n_clusters=1, pooling_func=np.mean)
     est.fit(X)
-    Xt = est.transform(X)
+    X = est.transform(X)
 
     with pytest.raises(TypeError, match="Missing required positional argument"):
         est.inverse_transform()
 
-    with pytest.raises(ValueError, match="Please provide only"):
-        est.inverse_transform(Xt=Xt, Xred=Xt)
+    with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only."):
+        est.inverse_transform(X=X, Xt=X)
 
     with warnings.catch_warnings(record=True):
         warnings.simplefilter("error")
-        est.inverse_transform(Xt)
+        est.inverse_transform(X)
 
-    with pytest.warns(FutureWarning, match="Input argument `Xred` was renamed to `Xt`"):
-        est.inverse_transform(Xred=Xt)
+    with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
+        est.inverse_transform(Xt=X)
diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py
index 75266c5f64b2b..30725c33f4df3 100644
--- a/sklearn/decomposition/_nmf.py
+++ b/sklearn/decomposition/_nmf.py
@@ -32,6 +32,7 @@
     StrOptions,
     validate_params,
 )
+from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
 from ..utils.extmath import randomized_svd, safe_sparse_dot, squared_norm
 from ..utils.validation import (
     check_is_fitted,
@@ -1310,44 +1311,32 @@ def fit(self, X, y=None, **params):
         self.fit_transform(X, **params)
         return self
 
-    def inverse_transform(self, Xt=None, W=None):
+    def inverse_transform(self, X=None, *, Xt=None):
         """Transform data back to its original space.
 
         .. versionadded:: 0.18
 
         Parameters
         ----------
-        Xt : {ndarray, sparse matrix} of shape (n_samples, n_components)
+        X : {ndarray, sparse matrix} of shape (n_samples, n_components)
             Transformed data matrix.
 
-        W : deprecated
-            Use `Xt` instead.
+        Xt : {ndarray, sparse matrix} of shape (n_samples, n_components)
+            Transformed data matrix.
 
-            .. deprecated:: 1.3
+            .. deprecated:: 1.5
+                `Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
 
         Returns
         -------
         X : ndarray of shape (n_samples, n_features)
             Returns a data matrix of the original shape.
         """
-        if Xt is None and W is None:
-            raise TypeError("Missing required positional argument: Xt")
 
-        if W is not None and Xt is not None:
-            raise ValueError("Please provide only `Xt`, and not `W`.")
-
-        if W is not None:
-            warnings.warn(
-                (
-                    "Input argument `W` was renamed to `Xt` in v1.3 and will be removed"
-                    " in v1.5."
-                ),
-                FutureWarning,
-            )
-            Xt = W
+        X = _deprecate_Xt_in_inverse_transform(X, Xt)
 
         check_is_fitted(self)
-        return Xt @ self.components_
+        return X @ self.components_
 
     @property
     def _n_features_out(self):
diff --git a/sklearn/decomposition/tests/test_nmf.py b/sklearn/decomposition/tests/test_nmf.py
index 2112b59129e25..b6eb4f9b1becc 100644
--- a/sklearn/decomposition/tests/test_nmf.py
+++ b/sklearn/decomposition/tests/test_nmf.py
@@ -933,30 +933,31 @@ def test_minibatch_nmf_verbose():
         sys.stdout = old_stdout
 
 
-# TODO(1.5): remove this test
-def test_NMF_inverse_transform_W_deprecation():
-    rng = np.random.mtrand.RandomState(42)
+# TODO(1.7): remove this test
+@pytest.mark.parametrize("Estimator", [NMF, MiniBatchNMF])
+def test_NMF_inverse_transform_Xt_deprecation(Estimator):
+    rng = np.random.RandomState(42)
     A = np.abs(rng.randn(6, 5))
-    est = NMF(
+    est = Estimator(
         n_components=3,
         init="random",
         random_state=0,
         tol=1e-6,
     )
-    Xt = est.fit_transform(A)
+    X = est.fit_transform(A)
 
     with pytest.raises(TypeError, match="Missing required positional argument"):
         est.inverse_transform()
 
-    with pytest.raises(ValueError, match="Please provide only"):
-        est.inverse_transform(Xt=Xt, W=Xt)
+    with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only"):
+        est.inverse_transform(X=X, Xt=X)
 
     with warnings.catch_warnings(record=True):
         warnings.simplefilter("error")
-        est.inverse_transform(Xt)
+        est.inverse_transform(X)
 
-    with pytest.warns(FutureWarning, match="Input argument `W` was renamed to `Xt`"):
-        est.inverse_transform(W=Xt)
+    with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
+        est.inverse_transform(Xt=X)
 
 
 @pytest.mark.parametrize("Estimator", [NMF, MiniBatchNMF])
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 42fde09c16bce..a26ec0786849d 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -36,6 +36,7 @@
 from ..utils._estimator_html_repr import _VisualBlock
 from ..utils._param_validation import HasMethods, Interval, StrOptions
 from ..utils._tags import _safe_tags
+from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
 from ..utils.metadata_routing import (
     MetadataRouter,
     MethodMapping,
@@ -637,7 +638,7 @@ def transform(self, X):
         return self.best_estimator_.transform(X)
 
     @available_if(_estimator_has("inverse_transform"))
-    def inverse_transform(self, Xt):
+    def inverse_transform(self, X=None, *, Xt=None):
         """Call inverse_transform on the estimator with the best found params.
 
         Only available if the underlying estimator implements
@@ -645,18 +646,26 @@ def inverse_transform(self, Xt):
 
         Parameters
         ----------
+        X : indexable, length n_samples
+            Must fulfill the input assumptions of the
+            underlying estimator.
+
         Xt : indexable, length n_samples
             Must fulfill the input assumptions of the
             underlying estimator.
 
+            .. deprecated:: 1.5
+                `Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
+
         Returns
         -------
         X : {ndarray, sparse matrix} of shape (n_samples, n_features)
-            Result of the `inverse_transform` function for `Xt` based on the
+            Result of the `inverse_transform` function for `X` based on the
             estimator with the best found parameters.
         """
+        X = _deprecate_Xt_in_inverse_transform(X, Xt)
         check_is_fitted(self)
-        return self.best_estimator_.inverse_transform(Xt)
+        return self.best_estimator_.inverse_transform(X)
 
     @property
     def n_features_in_(self):
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 9eb647df887c0..b59ed7168ff10 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -3,6 +3,7 @@
 import pickle
 import re
 import sys
+import warnings
 from collections.abc import Iterable, Sized
 from functools import partial
 from io import StringIO
@@ -2553,6 +2554,28 @@ def test_search_html_repr():
         assert "<pre>LogisticRegression()</pre>" in repr_html
 
 
+# TODO(1.7): remove this test
+@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
+def test_inverse_transform_Xt_deprecation(SearchCV):
+    clf = MockClassifier()
+    search = SearchCV(clf, {"foo_param": [1, 2, 3]}, cv=3, verbose=3)
+
+    X2 = search.fit(X, y).transform(X)
+
+    with pytest.raises(TypeError, match="Missing required positional argument"):
+        search.inverse_transform()
+
+    with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only"):
+        search.inverse_transform(X=X2, Xt=X2)
+
+    with warnings.catch_warnings(record=True):
+        warnings.simplefilter("error")
+        search.inverse_transform(X2)
+
+    with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
+        search.inverse_transform(Xt=X2)
+
+
 # Metadata Routing Tests
 # ======================
 
diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index 93f9ef09fc40a..b200177b8606f 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -29,6 +29,7 @@
 )
 from .utils._tags import _safe_tags
 from .utils._user_interface import _print_elapsed_time
+from .utils.deprecation import _deprecate_Xt_in_inverse_transform
 from .utils.metadata_routing import (
     MetadataRouter,
     MethodMapping,
@@ -909,19 +910,28 @@ def _can_inverse_transform(self):
         return all(hasattr(t, "inverse_transform") for _, _, t in self._iter())
 
     @available_if(_can_inverse_transform)
-    def inverse_transform(self, Xt, **params):
+    def inverse_transform(self, X=None, *, Xt=None, **params):
         """Apply `inverse_transform` for each step in a reverse order.
 
         All estimators in the pipeline must support `inverse_transform`.
 
         Parameters
         ----------
+        X : array-like of shape (n_samples, n_transformed_features)
+            Data samples, where ``n_samples`` is the number of samples and
+            ``n_features`` is the number of features. Must fulfill
+            input requirements of last step of pipeline's
+            ``inverse_transform`` method.
+
         Xt : array-like of shape (n_samples, n_transformed_features)
             Data samples, where ``n_samples`` is the number of samples and
             ``n_features`` is the number of features. Must fulfill
             input requirements of last step of pipeline's
             ``inverse_transform`` method.
 
+            .. deprecated:: 1.5
+                `Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
+
         **params : dict of str -> object
             Parameters requested and accepted by steps. Each step must have
             requested certain metadata for these parameters to be forwarded to
@@ -940,15 +950,15 @@ def inverse_transform(self, Xt, **params):
         """
         _raise_for_params(params, self, "inverse_transform")
 
+        X = _deprecate_Xt_in_inverse_transform(X, Xt)
+
         # we don't have to branch here, since params is only non-empty if
         # enable_metadata_routing=True.
         routed_params = process_routing(self, "inverse_transform", **params)
         reverse_iter = reversed(list(self._iter()))
         for _, name, transform in reverse_iter:
-            Xt = transform.inverse_transform(
-                Xt, **routed_params[name].inverse_transform
-            )
-        return Xt
+            X = transform.inverse_transform(X, **routed_params[name].inverse_transform)
+        return X
 
     @available_if(_final_estimator_has("score"))
     def score(self, X, y=None, sample_weight=None, **params):
diff --git a/sklearn/preprocessing/_discretization.py b/sklearn/preprocessing/_discretization.py
index 02d144b87f798..ee8a336a75453 100644
--- a/sklearn/preprocessing/_discretization.py
+++ b/sklearn/preprocessing/_discretization.py
@@ -12,6 +12,7 @@
 from ..base import BaseEstimator, TransformerMixin, _fit_context
 from ..utils import resample
 from ..utils._param_validation import Interval, Options, StrOptions
+from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
 from ..utils.stats import _weighted_percentile
 from ..utils.validation import (
     _check_feature_names_in,
@@ -389,7 +390,7 @@ def transform(self, X):
             self._encoder.dtype = dtype_init
         return Xt_enc
 
-    def inverse_transform(self, Xt):
+    def inverse_transform(self, X=None, *, Xt=None):
         """
         Transform discretized data back to original feature space.
 
@@ -398,20 +399,28 @@ def inverse_transform(self, Xt):
 
         Parameters
         ----------
+        X : array-like of shape (n_samples, n_features)
+            Transformed data in the binned space.
+
         Xt : array-like of shape (n_samples, n_features)
             Transformed data in the binned space.
 
+            .. deprecated:: 1.5
+                `Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
+
         Returns
         -------
         Xinv : ndarray, dtype={np.float32, np.float64}
             Data in the original feature space.
         """
+        X = _deprecate_Xt_in_inverse_transform(X, Xt)
+
         check_is_fitted(self)
 
         if "onehot" in self.encode:
-            Xt = self._encoder.inverse_transform(Xt)
+            X = self._encoder.inverse_transform(X)
 
-        Xinv = check_array(Xt, copy=True, dtype=(np.float64, np.float32))
+        Xinv = check_array(X, copy=True, dtype=(np.float64, np.float32))
         n_features = self.n_bins_.shape[0]
         if Xinv.shape[1] != n_features:
             raise ValueError(
diff --git a/sklearn/preprocessing/tests/test_discretization.py b/sklearn/preprocessing/tests/test_discretization.py
index 19aaa5bdba850..fd16a3db3efac 100644
--- a/sklearn/preprocessing/tests/test_discretization.py
+++ b/sklearn/preprocessing/tests/test_discretization.py
@@ -478,3 +478,23 @@ def test_kbinsdiscretizer_subsample(strategy, global_random_seed):
     assert_allclose(
         kbd_subsampling.bin_edges_[0], kbd_no_subsampling.bin_edges_[0], rtol=1e-2
     )
+
+
+# TODO(1.7): remove this test
+def test_KBD_inverse_transform_Xt_deprecation():
+    X = np.arange(10)[:, None]
+    kbd = KBinsDiscretizer()
+    X = kbd.fit_transform(X)
+
+    with pytest.raises(TypeError, match="Missing required positional argument"):
+        kbd.inverse_transform()
+
+    with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only"):
+        kbd.inverse_transform(X=X, Xt=X)
+
+    with warnings.catch_warnings(record=True):
+        warnings.simplefilter("error")
+        kbd.inverse_transform(X)
+
+    with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
+        kbd.inverse_transform(Xt=X)
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index 1d4cfb3dd6e2b..c7f0afe642a65 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -6,6 +6,7 @@
 import re
 import shutil
 import time
+import warnings
 from tempfile import mkdtemp
 
 import joblib
@@ -1792,6 +1793,26 @@ def test_feature_union_feature_names_in_():
     assert not hasattr(union, "feature_names_in_")
 
 
+# TODO(1.7): remove this test
+def test_pipeline_inverse_transform_Xt_deprecation():
+    X = np.random.RandomState(0).normal(size=(10, 5))
+    pipe = Pipeline([("pca", PCA(n_components=2))])
+    X = pipe.fit_transform(X)
+
+    with pytest.raises(TypeError, match="Missing required positional argument"):
+        pipe.inverse_transform()
+
+    with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only"):
+        pipe.inverse_transform(X=X, Xt=X)
+
+    with warnings.catch_warnings(record=True):
+        warnings.simplefilter("error")
+        pipe.inverse_transform(X)
+
+    with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
+        pipe.inverse_transform(Xt=X)
+
+
 # Test that metadata is routed correctly for pipelines and FeatureUnion
 # =====================================================================
 
diff --git a/sklearn/utils/deprecation.py b/sklearn/utils/deprecation.py
index c46149d943431..a3225597701c7 100644
--- a/sklearn/utils/deprecation.py
+++ b/sklearn/utils/deprecation.py
@@ -114,3 +114,41 @@ def _is_deprecated(func):
         [c.cell_contents for c in closures if isinstance(c.cell_contents, str)]
     )
     return is_deprecated
+
+
+# TODO(1.7): remove this helper
+def _deprecate_Xt_in_inverse_transform(X, Xt):
+    """Resolve the deprecated `Xt` argument of `inverse_transform` in favor of `X`.
+
+    Parameters
+    ----------
+    X : object or None
+        Value received through the `X` parameter.
+
+    Xt : object or None
+        Value received through the deprecated `Xt` parameter.
+
+    Returns
+    -------
+    X : object
+        `Xt` if it was provided (after emitting a `FutureWarning`), else `X`.
+
+    Raises
+    ------
+    TypeError
+        If both `X` and `Xt` are provided, or if both are `None`.
+    """
+    if X is not None and Xt is not None:
+        raise TypeError("Cannot use both X and Xt. Use X only.")
+
+    if X is None and Xt is None:
+        raise TypeError("Missing required positional argument: X.")
+
+    if Xt is not None:
+        warnings.warn(
+            "Xt was renamed X in version 1.5 and will be removed in 1.7.",
+            FutureWarning,
+        )
+        return Xt
+
+    return X