Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0bdc754

Browse files
API Standardize X as inverse_transform input parameter (#28756)
Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent 19c068f commit 0bdc754

File tree

12 files changed

+172
-73
lines changed

12 files changed

+172
-73
lines changed

doc/whats_new/v1.5.rst

+11
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,17 @@ Changed models
5656
signs across all `PCA` solvers, including the new
5757
`svd_solver="covariance_eigh"` option introduced in this release.
5858

59+
Changes impacting many modules
60+
------------------------------
61+
62+
- |API| The name of the input of the `inverse_transform` method of estimators has been
63+
standardized to `X`. As a consequence, `Xt` is deprecated and will be removed in
64+
version 1.7 in the following estimators: :class:`cluster.FeatureAgglomeration`,
65+
:class:`decomposition.MiniBatchNMF`, :class:`decomposition.NMF`,
66+
:class:`model_selection.GridSearchCV`, :class:`model_selection.RandomizedSearchCV`,
67+
:class:`pipeline.Pipeline` and :class:`preprocessing.KBinsDiscretizer`.
68+
:pr:`28756` by :user:`Will Dean <wd60622>`.
69+
5970
Support for Array API
6071
---------------------
6172

sklearn/cluster/_feature_agglomeration.py

+12-25
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
# Author: V. Michel, A. Gramfort
77
# License: BSD 3 clause
88

9-
import warnings
109

1110
import numpy as np
1211
from scipy.sparse import issparse
1312

1413
from ..base import TransformerMixin
1514
from ..utils import metadata_routing
15+
from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
1616
from ..utils.validation import check_is_fitted
1717

1818
###############################################################################
@@ -25,9 +25,9 @@ class AgglomerationTransform(TransformerMixin):
2525
"""
2626

2727
# This prevents ``set_split_inverse_transform`` to be generated for the
28-
# non-standard ``Xred`` arg on ``inverse_transform``.
29-
# TODO(1.5): remove when Xred is removed for inverse_transform.
30-
__metadata_request__inverse_transform = {"Xred": metadata_routing.UNUSED}
28+
# non-standard ``Xt`` arg on ``inverse_transform``.
29+
# TODO(1.7): remove when Xt is removed for inverse_transform.
30+
__metadata_request__inverse_transform = {"Xt": metadata_routing.UNUSED}
3131

3232
def transform(self, X):
3333
"""
@@ -63,43 +63,30 @@ def transform(self, X):
6363
nX = np.array(nX).T
6464
return nX
6565

66-
def inverse_transform(self, Xt=None, Xred=None):
66+
def inverse_transform(self, X=None, *, Xt=None):
6767
"""
6868
Inverse the transformation and return a vector of size `n_features`.
6969
7070
Parameters
7171
----------
72-
Xt : array-like of shape (n_samples, n_clusters) or (n_clusters,)
72+
X : array-like of shape (n_samples, n_clusters) or (n_clusters,)
7373
The values to be assigned to each cluster of samples.
7474
75-
Xred : deprecated
76-
Use `Xt` instead.
75+
Xt : array-like of shape (n_samples, n_clusters) or (n_clusters,)
76+
The values to be assigned to each cluster of samples.
7777
78-
.. deprecated:: 1.3
78+
.. deprecated:: 1.5
79+
`Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
7980
8081
Returns
8182
-------
8283
X : ndarray of shape (n_samples, n_features) or (n_features,)
8384
A vector of size `n_samples` with the values of `Xred` assigned to
8485
each of the cluster of samples.
8586
"""
86-
if Xt is None and Xred is None:
87-
raise TypeError("Missing required positional argument: Xt")
88-
89-
if Xred is not None and Xt is not None:
90-
raise ValueError("Please provide only `Xt`, and not `Xred`.")
91-
92-
if Xred is not None:
93-
warnings.warn(
94-
(
95-
"Input argument `Xred` was renamed to `Xt` in v1.3 and will be"
96-
" removed in v1.5."
97-
),
98-
FutureWarning,
99-
)
100-
Xt = Xred
87+
X = _deprecate_Xt_in_inverse_transform(X, Xt)
10188

10289
check_is_fitted(self)
10390

10491
unil, inverse = np.unique(self.labels_, return_inverse=True)
105-
return Xt[..., inverse]
92+
return X[..., inverse]

sklearn/cluster/tests/test_feature_agglomeration.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,23 @@ def test_feature_agglomeration_feature_names_out():
5959
)
6060

6161

62-
# TODO(1.5): remove this test
63-
def test_inverse_transform_Xred_deprecation():
62+
# TODO(1.7): remove this test
63+
def test_inverse_transform_Xt_deprecation():
6464
X = np.array([0, 0, 1]).reshape(1, 3) # (n_samples, n_features)
6565

6666
est = FeatureAgglomeration(n_clusters=1, pooling_func=np.mean)
6767
est.fit(X)
68-
Xt = est.transform(X)
68+
X = est.transform(X)
6969

7070
with pytest.raises(TypeError, match="Missing required positional argument"):
7171
est.inverse_transform()
7272

73-
with pytest.raises(ValueError, match="Please provide only"):
74-
est.inverse_transform(Xt=Xt, Xred=Xt)
73+
with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only."):
74+
est.inverse_transform(X=X, Xt=X)
7575

7676
with warnings.catch_warnings(record=True):
7777
warnings.simplefilter("error")
78-
est.inverse_transform(Xt)
78+
est.inverse_transform(X)
7979

80-
with pytest.warns(FutureWarning, match="Input argument `Xred` was renamed to `Xt`"):
81-
est.inverse_transform(Xred=Xt)
80+
with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
81+
est.inverse_transform(Xt=X)

sklearn/decomposition/_nmf.py

+9-20
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
StrOptions,
3333
validate_params,
3434
)
35+
from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
3536
from ..utils.extmath import randomized_svd, safe_sparse_dot, squared_norm
3637
from ..utils.validation import (
3738
check_is_fitted,
@@ -1310,44 +1311,32 @@ def fit(self, X, y=None, **params):
13101311
self.fit_transform(X, **params)
13111312
return self
13121313

1313-
def inverse_transform(self, Xt=None, W=None):
1314+
def inverse_transform(self, X=None, *, Xt=None):
13141315
"""Transform data back to its original space.
13151316
13161317
.. versionadded:: 0.18
13171318
13181319
Parameters
13191320
----------
1320-
Xt : {ndarray, sparse matrix} of shape (n_samples, n_components)
1321+
X : {ndarray, sparse matrix} of shape (n_samples, n_components)
13211322
Transformed data matrix.
13221323
1323-
W : deprecated
1324-
Use `Xt` instead.
1324+
Xt : {ndarray, sparse matrix} of shape (n_samples, n_components)
1325+
Transformed data matrix.
13251326
1326-
.. deprecated:: 1.3
1327+
.. deprecated:: 1.5
1328+
`Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
13271329
13281330
Returns
13291331
-------
13301332
X : ndarray of shape (n_samples, n_features)
13311333
Returns a data matrix of the original shape.
13321334
"""
1333-
if Xt is None and W is None:
1334-
raise TypeError("Missing required positional argument: Xt")
13351335

1336-
if W is not None and Xt is not None:
1337-
raise ValueError("Please provide only `Xt`, and not `W`.")
1338-
1339-
if W is not None:
1340-
warnings.warn(
1341-
(
1342-
"Input argument `W` was renamed to `Xt` in v1.3 and will be removed"
1343-
" in v1.5."
1344-
),
1345-
FutureWarning,
1346-
)
1347-
Xt = W
1336+
X = _deprecate_Xt_in_inverse_transform(X, Xt)
13481337

13491338
check_is_fitted(self)
1350-
return Xt @ self.components_
1339+
return X @ self.components_
13511340

13521341
@property
13531342
def _n_features_out(self):

sklearn/decomposition/tests/test_nmf.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -933,30 +933,31 @@ def test_minibatch_nmf_verbose():
933933
sys.stdout = old_stdout
934934

935935

936-
# TODO(1.5): remove this test
937-
def test_NMF_inverse_transform_W_deprecation():
938-
rng = np.random.mtrand.RandomState(42)
936+
# TODO(1.7): remove this test
937+
@pytest.mark.parametrize("Estimator", [NMF, MiniBatchNMF])
938+
def test_NMF_inverse_transform_Xt_deprecation(Estimator):
939+
rng = np.random.RandomState(42)
939940
A = np.abs(rng.randn(6, 5))
940-
est = NMF(
941+
est = Estimator(
941942
n_components=3,
942943
init="random",
943944
random_state=0,
944945
tol=1e-6,
945946
)
946-
Xt = est.fit_transform(A)
947+
X = est.fit_transform(A)
947948

948949
with pytest.raises(TypeError, match="Missing required positional argument"):
949950
est.inverse_transform()
950951

951-
with pytest.raises(ValueError, match="Please provide only"):
952-
est.inverse_transform(Xt=Xt, W=Xt)
952+
with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only"):
953+
est.inverse_transform(X=X, Xt=X)
953954

954955
with warnings.catch_warnings(record=True):
955956
warnings.simplefilter("error")
956-
est.inverse_transform(Xt)
957+
est.inverse_transform(X)
957958

958-
with pytest.warns(FutureWarning, match="Input argument `W` was renamed to `Xt`"):
959-
est.inverse_transform(W=Xt)
959+
with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
960+
est.inverse_transform(Xt=X)
960961

961962

962963
@pytest.mark.parametrize("Estimator", [NMF, MiniBatchNMF])

sklearn/model_selection/_search.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from ..utils._estimator_html_repr import _VisualBlock
3737
from ..utils._param_validation import HasMethods, Interval, StrOptions
3838
from ..utils._tags import _safe_tags
39+
from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
3940
from ..utils.metadata_routing import (
4041
MetadataRouter,
4142
MethodMapping,
@@ -637,26 +638,34 @@ def transform(self, X):
637638
return self.best_estimator_.transform(X)
638639

639640
@available_if(_estimator_has("inverse_transform"))
640-
def inverse_transform(self, Xt):
641+
def inverse_transform(self, X=None, Xt=None):
641642
"""Call inverse_transform on the estimator with the best found params.
642643
643644
Only available if the underlying estimator implements
644645
``inverse_transform`` and ``refit=True``.
645646
646647
Parameters
647648
----------
649+
X : indexable, length n_samples
650+
Must fulfill the input assumptions of the
651+
underlying estimator.
652+
648653
Xt : indexable, length n_samples
649654
Must fulfill the input assumptions of the
650655
underlying estimator.
651656
657+
.. deprecated:: 1.5
658+
`Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
659+
652660
Returns
653661
-------
654662
X : {ndarray, sparse matrix} of shape (n_samples, n_features)
655663
Result of the `inverse_transform` function for `Xt` based on the
656664
estimator with the best found parameters.
657665
"""
666+
X = _deprecate_Xt_in_inverse_transform(X, Xt)
658667
check_is_fitted(self)
659-
return self.best_estimator_.inverse_transform(Xt)
668+
return self.best_estimator_.inverse_transform(X)
660669

661670
@property
662671
def n_features_in_(self):

sklearn/model_selection/tests/test_search.py

+23
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pickle
44
import re
55
import sys
6+
import warnings
67
from collections.abc import Iterable, Sized
78
from functools import partial
89
from io import StringIO
@@ -2553,6 +2554,28 @@ def test_search_html_repr():
25532554
assert "<pre>LogisticRegression()</pre>" in repr_html
25542555

25552556

2557+
# TODO(1.7): remove this test
2558+
@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
2559+
def test_inverse_transform_Xt_deprecation(SearchCV):
2560+
clf = MockClassifier()
2561+
search = SearchCV(clf, {"foo_param": [1, 2, 3]}, cv=3, verbose=3)
2562+
2563+
X2 = search.fit(X, y).transform(X)
2564+
2565+
with pytest.raises(TypeError, match="Missing required positional argument"):
2566+
search.inverse_transform()
2567+
2568+
with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only"):
2569+
search.inverse_transform(X=X2, Xt=X2)
2570+
2571+
with warnings.catch_warnings(record=True):
2572+
warnings.simplefilter("error")
2573+
search.inverse_transform(X2)
2574+
2575+
with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
2576+
search.inverse_transform(Xt=X2)
2577+
2578+
25562579
# Metadata Routing Tests
25572580
# ======================
25582581

sklearn/pipeline.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from .utils._tags import _safe_tags
3131
from .utils._user_interface import _print_elapsed_time
32+
from .utils.deprecation import _deprecate_Xt_in_inverse_transform
3233
from .utils.metadata_routing import (
3334
MetadataRouter,
3435
MethodMapping,
@@ -909,19 +910,28 @@ def _can_inverse_transform(self):
909910
return all(hasattr(t, "inverse_transform") for _, _, t in self._iter())
910911

911912
@available_if(_can_inverse_transform)
912-
def inverse_transform(self, Xt, **params):
913+
def inverse_transform(self, X=None, *, Xt=None, **params):
913914
"""Apply `inverse_transform` for each step in a reverse order.
914915
915916
All estimators in the pipeline must support `inverse_transform`.
916917
917918
Parameters
918919
----------
920+
X : array-like of shape (n_samples, n_transformed_features)
921+
Data samples, where ``n_samples`` is the number of samples and
922+
``n_features`` is the number of features. Must fulfill
923+
input requirements of last step of pipeline's
924+
``inverse_transform`` method.
925+
919926
Xt : array-like of shape (n_samples, n_transformed_features)
920927
Data samples, where ``n_samples`` is the number of samples and
921928
``n_features`` is the number of features. Must fulfill
922929
input requirements of last step of pipeline's
923930
``inverse_transform`` method.
924931
932+
.. deprecated:: 1.5
933+
`Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
934+
925935
**params : dict of str -> object
926936
Parameters requested and accepted by steps. Each step must have
927937
requested certain metadata for these parameters to be forwarded to
@@ -940,15 +950,15 @@ def inverse_transform(self, Xt, **params):
940950
"""
941951
_raise_for_params(params, self, "inverse_transform")
942952

953+
X = _deprecate_Xt_in_inverse_transform(X, Xt)
954+
943955
# we don't have to branch here, since params is only non-empty if
944956
# enable_metadata_routing=True.
945957
routed_params = process_routing(self, "inverse_transform", **params)
946958
reverse_iter = reversed(list(self._iter()))
947959
for _, name, transform in reverse_iter:
948-
Xt = transform.inverse_transform(
949-
Xt, **routed_params[name].inverse_transform
950-
)
951-
return Xt
960+
X = transform.inverse_transform(X, **routed_params[name].inverse_transform)
961+
return X
952962

953963
@available_if(_final_estimator_has("score"))
954964
def score(self, X, y=None, sample_weight=None, **params):

0 commit comments

Comments
 (0)