diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 0e4b8258476da..4a5b347e0f728 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -635,6 +635,35 @@ instantiated with an instance of ``LogisticRegression`` (or of these two models is somewhat idiosyncratic but both should provide robust closed-form solutions. +.. _developer_api_set_output: + +Developer API for `set_output` +============================== + +With +`SLEP018 <https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html>`__, +scikit-learn introduces the `set_output` API for configuring transformers to +output pandas DataFrames. The `set_output` API is automatically defined if the +transformer defines :term:`get_feature_names_out` and subclasses +:class:`base.TransformerMixin`. :term:`get_feature_names_out` is used to get the +column names of pandas output. You can opt out of the `set_output` API by +setting `auto_wrap_output_keys=None` when defining a custom subclass:: + + class MyTransformer(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None): + + def fit(self, X, y=None): + return self + def transform(self, X, y=None): + return X + def get_feature_names_out(self, input_features=None): + ... + +For transformers that return multiple arrays in `transform`, auto wrapping will +only wrap the first array and not alter the other arrays. + +See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py` +for an example on how to use the API. + .. _coding-guidelines: Coding guidelines diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 6a1b05badd4a8..aefc046186b9a 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -52,6 +52,13 @@ random sampling procedures. Changes impacting all modules ----------------------------- +- |MajorFeature| The `set_output` API has been adopted by all transformers. + Meta-estimators that contain transformers such as :class:`pipeline.Pipeline` + or :class:`compose.ColumnTransformer` also define a `set_output`. + For details, see + `SLEP018 <https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html>`__. 
+ :pr:`23734` by `Thomas Fan`_. + - |Enhancement| Finiteness checks (detection of NaN and infinite values) in all estimators are now significantly more efficient for float32 data by leveraging NumPy's SIMD optimized primitives. diff --git a/examples/miscellaneous/plot_set_output.py b/examples/miscellaneous/plot_set_output.py new file mode 100644 index 0000000000000..12f2c822753e8 --- /dev/null +++ b/examples/miscellaneous/plot_set_output.py @@ -0,0 +1,111 @@ +""" +================================ +Introducing the `set_output` API +================================ + +.. currentmodule:: sklearn + +This example will demonstrate the `set_output` API to configure transformers to +output pandas DataFrames. `set_output` can be configured per estimator by calling +the `set_output` method or globally by setting `set_config(transform_output="pandas")`. +For details, see +`SLEP018 <https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html>`__. +""" # noqa + +# %% +# First, we load the iris dataset as a DataFrame to demonstrate the `set_output` API. +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split + +X, y = load_iris(as_frame=True, return_X_y=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0) +X_train.head() + +# %% +# To configure an estimator such as :class:`preprocessing.StandardScaler` to return +# DataFrames, call `set_output`. This feature requires pandas to be installed. + +from sklearn.preprocessing import StandardScaler + +scaler = StandardScaler().set_output(transform="pandas") + +scaler.fit(X_train) +X_test_scaled = scaler.transform(X_test) +X_test_scaled.head() + +# %% +# `set_output` can be called after `fit` to configure `transform` after the fact. 
+scaler2 = StandardScaler() + +scaler2.fit(X_train) +X_test_np = scaler2.transform(X_test) +print(f"Default output type: {type(X_test_np).__name__}") + +scaler2.set_output(transform="pandas") +X_test_df = scaler2.transform(X_test) +print(f"Configured pandas output type: {type(X_test_df).__name__}") + +# %% +# In a :class:`pipeline.Pipeline`, `set_output` configures all steps to output +# DataFrames. +from sklearn.pipeline import make_pipeline +from sklearn.linear_model import LogisticRegression +from sklearn.feature_selection import SelectPercentile + +clf = make_pipeline( + StandardScaler(), SelectPercentile(percentile=75), LogisticRegression() +) +clf.set_output(transform="pandas") +clf.fit(X_train, y_train) + +# %% +# Each transformer in the pipeline is configured to return DataFrames. This +# means that the final logistic regression step contains the feature names. +clf[-1].feature_names_in_ + +# %% +# Next we load the titanic dataset to demonstrate `set_output` with +# :class:`compose.ColumnTransformer` and heterogeneous data. +from sklearn.datasets import fetch_openml + +X, y = fetch_openml( + "titanic", version=1, as_frame=True, return_X_y=True, parser="pandas" +) +X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y) + +# %% +# The `set_output` API can be configured globally by using :func:`set_config` and +# setting the `transform_output` to `"pandas"`. 
+from sklearn.compose import ColumnTransformer +from sklearn.preprocessing import OneHotEncoder, StandardScaler +from sklearn.impute import SimpleImputer +from sklearn import set_config + +set_config(transform_output="pandas") + +num_pipe = make_pipeline(SimpleImputer(), StandardScaler()) +ct = ColumnTransformer( + ( + ("numerical", num_pipe, ["age", "fare"]), + ( + "categorical", + OneHotEncoder( + sparse_output=False, drop="if_binary", handle_unknown="ignore" + ), + ["embarked", "sex", "pclass"], + ), + ), + verbose_feature_names_out=False, +) +clf = make_pipeline(ct, SelectPercentile(percentile=50), LogisticRegression()) +clf.fit(X_train, y_train) +clf.score(X_test, y_test) + +# %% +# With the global configuration, all transformers output DataFrames. This allows us to +# easily plot the logistic regression coefficients with the corresponding feature names. +import pandas as pd + +log_reg = clf[-1] +coef = pd.Series(log_reg.coef_.ravel(), index=log_reg.feature_names_in_) +_ = coef.sort_values().plot.barh() diff --git a/sklearn/_config.py b/sklearn/_config.py index c358b7ea38584..ea5c47499b5b4 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -14,6 +14,7 @@ ), "enable_cython_pairwise_dist": True, "array_api_dispatch": False, + "transform_output": "default", } _threadlocal = threading.local() @@ -52,6 +53,7 @@ def set_config( pairwise_dist_chunk_size=None, enable_cython_pairwise_dist=None, array_api_dispatch=None, + transform_output=None, ): """Set global scikit-learn configuration @@ -120,6 +122,11 @@ def set_config( .. versionadded:: 1.2 + transform_output : str, default=None + Configure the output container for transform. + + .. versionadded:: 1.2 + See Also -------- config_context : Context manager for global scikit-learn configuration. 
@@ -141,6 +148,8 @@ def set_config( local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist if array_api_dispatch is not None: local_config["array_api_dispatch"] = array_api_dispatch + if transform_output is not None: + local_config["transform_output"] = transform_output @contextmanager @@ -153,6 +162,7 @@ def config_context( pairwise_dist_chunk_size=None, enable_cython_pairwise_dist=None, array_api_dispatch=None, + transform_output=None, ): """Context manager for global scikit-learn configuration. @@ -220,6 +230,11 @@ def config_context( .. versionadded:: 1.2 + transform_output : str, default=None + Configure the output container for transform. + + .. versionadded:: 1.2 + Yields ------ None. @@ -256,6 +271,7 @@ def config_context( pairwise_dist_chunk_size=pairwise_dist_chunk_size, enable_cython_pairwise_dist=enable_cython_pairwise_dist, array_api_dispatch=array_api_dispatch, + transform_output=transform_output, ) try: diff --git a/sklearn/base.py b/sklearn/base.py index ca94a45dfcc87..0bd9327c38166 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -15,6 +15,7 @@ from . import __version__ from ._config import get_config from .utils import _IS_32BIT +from .utils._set_output import _SetOutputMixin from .utils._tags import ( _DEFAULT_TAGS, ) @@ -98,6 +99,13 @@ def clone(estimator, *, safe=True): "Cannot clone object %s, as the constructor " "either does not set or modifies parameter %s" % (estimator, name) ) + + # _sklearn_output_config is used by `set_output` to configure the output + # container of an estimator. + if hasattr(estimator, "_sklearn_output_config"): + new_object._sklearn_output_config = copy.deepcopy( + estimator._sklearn_output_config + ) return new_object @@ -798,8 +806,13 @@ def get_submatrix(self, i, data): return data[row_ind[:, np.newaxis], col_ind] -class TransformerMixin: - """Mixin class for all transformers in scikit-learn.""" +class TransformerMixin(_SetOutputMixin): + """Mixin class for all transformers in scikit-learn. 
+ + If :term:`get_feature_names_out` is defined and `auto_wrap_output_keys` is not + set to `None`, then `BaseEstimator` will automatically wrap `transform` and `fit_transform` to + follow the `set_output` API. See the :ref:`developer_api_set_output` for details. + """ def fit_transform(self, X, y=None, **fit_params): """ diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index f21616ac6684a..db7a7016c83ab 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -20,6 +20,8 @@ from ..utils import Bunch from ..utils import _safe_indexing from ..utils import _get_column_indices +from ..utils._set_output import _get_output_config, _safe_set_output +from ..utils import check_pandas_support from ..utils.metaestimators import _BaseComposition from ..utils.validation import check_array, check_is_fitted, _check_feature_names_in from ..utils.fixes import delayed @@ -252,6 +254,35 @@ def _transformers(self, value): except (TypeError, ValueError): self.transformers = value + def set_output(self, transform=None): + """Set the output container when `"transform"` and `"fit_transform"` are called. + + Calling `set_output` will set the output of all estimators in `transformers` + and `transformers_`. + + Parameters + ---------- + transform : {"default", "pandas"}, default=None + Configure output of `transform` and `fit_transform`. + + Returns + ------- + self : estimator instance + Estimator instance. + """ + super().set_output(transform=transform) + transformers = ( + trans + for _, trans, _ in chain( + self.transformers, getattr(self, "transformers_", []) + ) + if trans not in {"passthrough", "drop"} + ) + for trans in transformers: + _safe_set_output(trans, transform=transform) + + return self + def get_params(self, deep=True): """Get parameters for this estimator. 
@@ -302,7 +333,19 @@ def _iter(self, fitted=False, replace_strings=False, column_as_strings=False): """ if fitted: - transformers = self.transformers_ + if replace_strings: + # Replace "passthrough" with the fitted version in + # _name_to_fitted_passthrough + def replace_passthrough(name, trans, columns): + if name not in self._name_to_fitted_passthrough: + return name, trans, columns + return name, self._name_to_fitted_passthrough[name], columns + + transformers = [ + replace_passthrough(*trans) for trans in self.transformers_ + ] + else: + transformers = self.transformers_ else: # interleave the validated column specifiers transformers = [ @@ -314,12 +357,17 @@ def _iter(self, fitted=False, replace_strings=False, column_as_strings=False): transformers = chain(transformers, [self._remainder]) get_weight = (self.transformer_weights or {}).get + output_config = _get_output_config("transform", self) for name, trans, columns in transformers: if replace_strings: # replace 'passthrough' with identity transformer and # skip in case of 'drop' if trans == "passthrough": - trans = FunctionTransformer(accept_sparse=True, check_inverse=False) + trans = FunctionTransformer( + accept_sparse=True, + check_inverse=False, + feature_names_out="one-to-one", + ).set_output(transform=output_config["dense"]) elif trans == "drop": continue elif _is_empty_column_selection(columns): @@ -505,6 +553,7 @@ def _update_fitted_transformers(self, transformers): # transformers are fitted; excludes 'drop' cases fitted_transformers = iter(transformers) transformers_ = [] + self._name_to_fitted_passthrough = {} for name, old, column, _ in self._iter(): if old == "drop": @@ -512,8 +561,12 @@ def _update_fitted_transformers(self, transformers): elif old == "passthrough": # FunctionTransformer is present in list of transformers, # so get next transformer, but save original string - next(fitted_transformers) + func_transformer = next(fitted_transformers) trans = "passthrough" + + # The fitted 
FunctionTransformer is saved in another attribute, + # so it can be used during transform for set_output. + self._name_to_fitted_passthrough[name] = func_transformer elif _is_empty_column_selection(column): trans = old else: @@ -765,6 +818,10 @@ def _hstack(self, Xs): return sparse.hstack(converted_Xs).tocsr() else: Xs = [f.toarray() if sparse.issparse(f) else f for f in Xs] + config = _get_output_config("transform", self) + if config["dense"] == "pandas" and all(hasattr(X, "iloc") for X in Xs): + pd = check_pandas_support("transform") + return pd.concat(Xs, axis=1) return np.hstack(Xs) def _sk_visual_block_(self): diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index 803cc2d542011..bade093b89c1f 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -13,7 +13,7 @@ from sklearn.utils._testing import assert_allclose_dense_sparse from sklearn.utils._testing import assert_almost_equal -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, TransformerMixin from sklearn.compose import ( ColumnTransformer, make_column_transformer, @@ -24,7 +24,7 @@ from sklearn.preprocessing import StandardScaler, Normalizer, OneHotEncoder -class Trans(BaseEstimator): +class Trans(TransformerMixin, BaseEstimator): def fit(self, X, y=None): return self @@ -1940,3 +1940,116 @@ def test_verbose_feature_names_out_false_errors( ) with pytest.raises(ValueError, match=msg): ct.get_feature_names_out() + + +@pytest.mark.parametrize("verbose_feature_names_out", [True, False]) +@pytest.mark.parametrize("remainder", ["drop", "passthrough"]) +def test_column_transformer_set_output(verbose_feature_names_out, remainder): + """Check column transformer behavior with set_output.""" + pd = pytest.importorskip("pandas") + df = pd.DataFrame([[1, 2, 3, 4]], columns=["a", "b", "c", "d"], index=[10]) + ct = ColumnTransformer( + [("first", TransWithNames(), 
["a", "c"]), ("second", TransWithNames(), ["d"])], + remainder=remainder, + verbose_feature_names_out=verbose_feature_names_out, + ) + X_trans = ct.fit_transform(df) + assert isinstance(X_trans, np.ndarray) + + ct.set_output(transform="pandas") + + df_test = pd.DataFrame([[1, 2, 3, 4]], columns=df.columns, index=[20]) + X_trans = ct.transform(df_test) + assert isinstance(X_trans, pd.DataFrame) + + feature_names_out = ct.get_feature_names_out() + assert_array_equal(X_trans.columns, feature_names_out) + assert_array_equal(X_trans.index, df_test.index) + + +@pytest.mark.parametrize("remainder", ["drop", "passthrough"]) +@pytest.mark.parametrize("fit_transform", [True, False]) +def test_column_transform_set_output_mixed(remainder, fit_transform): + """Check ColumnTransformer outputs mixed types correctly.""" + pd = pytest.importorskip("pandas") + df = pd.DataFrame( + { + "pet": pd.Series(["dog", "cat", "snake"], dtype="category"), + "color": pd.Series(["green", "blue", "red"], dtype="object"), + "age": [1.4, 2.1, 4.4], + "height": [20, 40, 10], + "distance": pd.Series([20, pd.NA, 100], dtype="Int32"), + } + ) + ct = ColumnTransformer( + [ + ( + "color_encode", + OneHotEncoder(sparse_output=False, dtype="int8"), + ["color"], + ), + ("age", StandardScaler(), ["age"]), + ], + remainder=remainder, + verbose_feature_names_out=False, + ).set_output(transform="pandas") + if fit_transform: + X_trans = ct.fit_transform(df) + else: + X_trans = ct.fit(df).transform(df) + + assert isinstance(X_trans, pd.DataFrame) + assert_array_equal(X_trans.columns, ct.get_feature_names_out()) + + expected_dtypes = { + "color_blue": "int8", + "color_green": "int8", + "color_red": "int8", + "age": "float64", + "pet": "category", + "height": "int64", + "distance": "Int32", + } + for col, dtype in X_trans.dtypes.items(): + assert dtype == expected_dtypes[col] + + +@pytest.mark.parametrize("remainder", ["drop", "passthrough"]) +def test_column_transform_set_output_after_fitting(remainder): + pd = 
pytest.importorskip("pandas") + df = pd.DataFrame( + { + "pet": pd.Series(["dog", "cat", "snake"], dtype="category"), + "age": [1.4, 2.1, 4.4], + "height": [20, 40, 10], + } + ) + ct = ColumnTransformer( + [ + ( + "color_encode", + OneHotEncoder(sparse_output=False, dtype="int16"), + ["pet"], + ), + ("age", StandardScaler(), ["age"]), + ], + remainder=remainder, + verbose_feature_names_out=False, + ) + + # fit without calling set_output + X_trans = ct.fit_transform(df) + assert isinstance(X_trans, np.ndarray) + assert X_trans.dtype == "float64" + + ct.set_output(transform="pandas") + X_trans_df = ct.transform(df) + expected_dtypes = { + "pet_cat": "int16", + "pet_dog": "int16", + "pet_snake": "int16", + "height": "int64", + "age": "float64", + } + for col, dtype in X_trans_df.dtypes.items(): + assert dtype == expected_dtypes[col] diff --git a/sklearn/cross_decomposition/tests/test_pls.py b/sklearn/cross_decomposition/tests/test_pls.py index dc2cc64a54cf9..aff2b76034b0b 100644 --- a/sklearn/cross_decomposition/tests/test_pls.py +++ b/sklearn/cross_decomposition/tests/test_pls.py @@ -620,3 +620,16 @@ def test_pls_feature_names_out(Klass): dtype=object, ) assert_array_equal(names_out, expected_names_out) + + +@pytest.mark.parametrize("Klass", [CCA, PLSSVD, PLSRegression, PLSCanonical]) +def test_pls_set_output(Klass): + """Check `set_output` in cross_decomposition module.""" + pd = pytest.importorskip("pandas") + X, Y = load_linnerud(return_X_y=True, as_frame=True) + + est = Klass().set_output(transform="pandas").fit(X, Y) + X_trans, y_trans = est.transform(X, Y) + assert isinstance(y_trans, np.ndarray) + assert isinstance(X_trans, pd.DataFrame) + assert_array_equal(X_trans.columns, est.get_feature_names_out()) diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py index 18a91901251a2..70aa6e7714149 100644 --- a/sklearn/feature_extraction/tests/test_text.py +++ b/sklearn/feature_extraction/tests/test_text.py @@ 
-1636,3 +1636,12 @@ def test_nonnegative_hashing_vectorizer_result_indices(): hashing = HashingVectorizer(n_features=1000000, ngram_range=(2, 3)) indices = hashing.transform(["22pcs efuture"]).indices assert indices[0] >= 0 + + +@pytest.mark.parametrize( + "Estimator", [CountVectorizer, TfidfVectorizer, TfidfTransformer, HashingVectorizer] +) +def test_vectorizers_do_not_have_set_output(Estimator): + """Check that vectorizers do not define set_output.""" + est = Estimator() + assert not hasattr(est, "set_output") diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index 16c3999d771d6..9d0d847c1d35c 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -566,7 +566,9 @@ def _warn_for_unused_params(self): ) -class HashingVectorizer(TransformerMixin, _VectorizerMixin, BaseEstimator): +class HashingVectorizer( + TransformerMixin, _VectorizerMixin, BaseEstimator, auto_wrap_output_keys=None +): r"""Convert a collection of text documents to a matrix of token occurrences. It turns a collection of text documents into a scipy.sparse matrix holding @@ -1483,7 +1485,9 @@ def _make_int_array(): return array.array(str("i")) -class TfidfTransformer(_OneToOneFeatureMixin, TransformerMixin, BaseEstimator): +class TfidfTransformer( + _OneToOneFeatureMixin, TransformerMixin, BaseEstimator, auto_wrap_output_keys=None +): """Transform a count matrix to a normalized tf or tf-idf representation. 
Tf means term-frequency while tf-idf means term-frequency times inverse diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index a5f7ff503ec74..3f74acda1fc29 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -27,6 +27,8 @@ from .utils._tags import _safe_tags from .utils.validation import check_memory from .utils.validation import check_is_fitted +from .utils import check_pandas_support +from .utils._set_output import _safe_set_output, _get_output_config from .utils.fixes import delayed from .exceptions import NotFittedError @@ -146,6 +148,25 @@ def __init__(self, steps, *, memory=None, verbose=False): self.memory = memory self.verbose = verbose + def set_output(self, transform=None): + """Set the output container when `"transform"` and `"fit_transform"` are called. + + Calling `set_output` will set the output of all estimators in `steps`. + + Parameters + ---------- + transform : {"default", "pandas"}, default=None + Configure output of `transform` and `fit_transform`. + + Returns + ------- + self : estimator instance + Estimator instance. + """ + for _, _, step in self._iter(): + _safe_set_output(step, transform=transform) + return self + def get_params(self, deep=True): """Get parameters for this estimator. @@ -968,6 +989,26 @@ def __init__( self.transformer_weights = transformer_weights self.verbose = verbose + def set_output(self, transform=None): + """Set the output container when `"transform"` and `"fit_transform"` are called. + + `set_output` will set the output of all estimators in `transformer_list`. + + Parameters + ---------- + transform : {"default", "pandas"}, default=None + Configure output of `transform` and `fit_transform`. + + Returns + ------- + self : estimator instance + Estimator instance. + """ + super().set_output(transform=transform) + for _, step, _ in self._iter(): + _safe_set_output(step, transform=transform) + return self + def get_params(self, deep=True): """Get parameters for this estimator. 
@@ -1189,6 +1230,11 @@ def transform(self, X): return self._hstack(Xs) def _hstack(self, Xs): + config = _get_output_config("transform", self) + if config["dense"] == "pandas" and all(hasattr(X, "iloc") for X in Xs): + pd = check_pandas_support("transform") + return pd.concat(Xs, axis=1) + if any(sparse.issparse(f) for f in Xs): Xs = sparse.hstack(Xs).tocsr() else: diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py index 93e51ad017369..228304bb70091 100644 --- a/sklearn/preprocessing/_function_transformer.py +++ b/sklearn/preprocessing/_function_transformer.py @@ -306,3 +306,35 @@ def __sklearn_is_fitted__(self): def _more_tags(self): return {"no_validation": not self.validate, "stateless": True} + + def set_output(self, *, transform=None): + """Set output container. + + See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py` + for an example on how to use the API. + + Parameters + ---------- + transform : {"default", "pandas"}, default=None + Configure output of the following estimator's methods: + + - `"transform"` + - `"fit_transform"` + + If `None`, this operation is a no-op. + + Returns + ------- + self : estimator instance + Estimator instance. + """ + if hasattr(super(), "set_output"): + return super().set_output(transform=transform) + + if transform == "pandas" and self.feature_names_out is None: + warnings.warn( + 'With transform="pandas", `func` should return a DataFrame to follow' + " the set_output API." 
+ ) + + return self diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py index ccc2ca6d18b67..6395bc28c7d69 100644 --- a/sklearn/preprocessing/tests/test_encoders.py +++ b/sklearn/preprocessing/tests/test_encoders.py @@ -1900,3 +1900,42 @@ def test_ordinal_encoder_unknown_missing_interaction_both_nan( assert np.isnan(val) else: assert val == expected_val + + +def test_one_hot_encoder_set_output(): + """Check OneHotEncoder works with set_output.""" + pd = pytest.importorskip("pandas") + + X_df = pd.DataFrame({"A": ["a", "b"], "B": [1, 2]}) + ohe = OneHotEncoder() + + ohe.set_output(transform="pandas") + + match = "Pandas output does not support sparse data" + with pytest.raises(ValueError, match=match): + ohe.fit_transform(X_df) + + ohe_default = OneHotEncoder(sparse_output=False).set_output(transform="default") + ohe_pandas = OneHotEncoder(sparse_output=False).set_output(transform="pandas") + + X_default = ohe_default.fit_transform(X_df) + X_pandas = ohe_pandas.fit_transform(X_df) + + assert_allclose(X_pandas.to_numpy(), X_default) + assert_array_equal(ohe_pandas.get_feature_names_out(), X_pandas.columns) + + +def test_ordinal_set_output(): + """Check OrdinalEncoder works with set_output.""" + pd = pytest.importorskip("pandas") + + X_df = pd.DataFrame({"A": ["a", "b"], "B": [1, 2]}) + + ord_default = OrdinalEncoder().set_output(transform="default") + ord_pandas = OrdinalEncoder().set_output(transform="pandas") + + X_default = ord_default.fit_transform(X_df) + X_pandas = ord_pandas.fit_transform(X_df) + + assert_allclose(X_pandas.to_numpy(), X_default) + assert_array_equal(ord_pandas.get_feature_names_out(), X_pandas.columns) diff --git a/sklearn/preprocessing/tests/test_function_transformer.py b/sklearn/preprocessing/tests/test_function_transformer.py index 256eb729e019b..b10682922acd0 100644 --- a/sklearn/preprocessing/tests/test_function_transformer.py +++ 
b/sklearn/preprocessing/tests/test_function_transformer.py @@ -405,3 +405,32 @@ def test_get_feature_names_out_dataframe_with_string_data( assert isinstance(names, np.ndarray) assert names.dtype == object assert_array_equal(names, expected) + + +def test_set_output_func(): + """Check behavior of set_output with different settings.""" + pd = pytest.importorskip("pandas") + + X = pd.DataFrame({"a": [1, 2, 3], "b": [10, 20, 100]}) + + ft = FunctionTransformer(np.log, feature_names_out="one-to-one") + + # no warning is raised when feature_names_out is defined + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + ft.set_output(transform="pandas") + + X_trans = ft.fit_transform(X) + assert isinstance(X_trans, pd.DataFrame) + assert_array_equal(X_trans.columns, ["a", "b"]) + + # If feature_names_out is not defined, then a warning is raised in + # `set_output` + ft = FunctionTransformer(lambda x: 2 * x) + msg = "should return a DataFrame to follow the set_output API" + with pytest.warns(UserWarning, match=msg): + ft.set_output(transform="pandas") + + X_trans = ft.fit_transform(X) + assert isinstance(X_trans, pd.DataFrame) + assert_array_equal(X_trans.columns, ["a", "b"]) diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index 13a869c10f0f3..250b2302d9a21 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -14,6 +14,8 @@ from sklearn.base import BaseEstimator, clone, is_classifier from sklearn.svm import SVC +from sklearn.preprocessing import StandardScaler +from sklearn.utils._set_output import _get_output_config from sklearn.pipeline import Pipeline from sklearn.model_selection import GridSearchCV @@ -659,3 +661,14 @@ def transform(self, X): # transform on feature names that are mixed also warns: with pytest.raises(TypeError, match=msg): trans.transform(df_mixed) + + +def test_clone_keeps_output_config(): + """Check that clone keeps the set_output config.""" + + ss = 
StandardScaler().set_output(transform="pandas") + config = _get_output_config("transform", ss) + + ss_clone = clone(ss) + config_clone = _get_output_config("transform", ss_clone) + assert config == config_clone diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 15b82504552af..a5d91715cb56e 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -34,6 +34,7 @@ RadiusNeighborsClassifier, RadiusNeighborsRegressor, ) +from sklearn.preprocessing import FunctionTransformer from sklearn.semi_supervised import LabelPropagation, LabelSpreading from sklearn.utils import all_estimators @@ -45,6 +46,7 @@ import sklearn from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder from sklearn.linear_model._base import LinearClassifierMixin from sklearn.linear_model import LogisticRegression from sklearn.linear_model import Ridge @@ -72,6 +74,8 @@ check_param_validation, check_transformer_get_feature_names_out, check_transformer_get_feature_names_out_pandas, + check_set_output_transform, + check_set_output_transform_pandas, ) @@ -498,3 +502,45 @@ def test_f_contiguous_array_estimator(Estimator): if hasattr(est, "predict"): est.predict(X) + + +SET_OUTPUT_ESTIMATORS = list( + chain( + _tested_estimators("transformer"), + [ + make_pipeline(StandardScaler(), MinMaxScaler()), + OneHotEncoder(sparse_output=False), + FunctionTransformer(feature_names_out="one-to-one"), + ], + ) +) + + +@pytest.mark.parametrize( + "estimator", SET_OUTPUT_ESTIMATORS, ids=_get_check_estimator_ids +) +def test_set_output_transform(estimator): + name = estimator.__class__.__name__ + if not hasattr(estimator, "set_output"): + pytest.skip( + f"Skipping check_set_output_transform for {name}: Does not support" + " set_output API" + ) + _set_checking_parameters(estimator) + with ignore_warnings(category=(FutureWarning)): + check_set_output_transform(estimator.__class__.__name__, estimator) + + 
+@pytest.mark.parametrize(
+    "estimator", SET_OUTPUT_ESTIMATORS, ids=_get_check_estimator_ids
+)
+def test_set_output_transform_pandas(estimator):
+    name = estimator.__class__.__name__
+    if not hasattr(estimator, "set_output"):
+        pytest.skip(
+            f"Skipping check_set_output_transform_pandas for {name}: Does not support"
+            " set_output API yet"
+        )
+    _set_checking_parameters(estimator)
+    with ignore_warnings(category=(FutureWarning)):
+        check_set_output_transform_pandas(estimator.__class__.__name__, estimator)
diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py
index 51a5a80ebf5b4..a0b8f29662b69 100644
--- a/sklearn/tests/test_config.py
+++ b/sklearn/tests/test_config.py
@@ -17,6 +17,7 @@ def test_config_context():
         "array_api_dispatch": False,
         "pairwise_dist_chunk_size": 256,
         "enable_cython_pairwise_dist": True,
+        "transform_output": "default",
     }
 
     # Not using as a context manager affects nothing
@@ -32,6 +33,7 @@ def test_config_context():
         "array_api_dispatch": False,
         "pairwise_dist_chunk_size": 256,
         "enable_cython_pairwise_dist": True,
+        "transform_output": "default",
     }
 
     assert get_config()["assume_finite"] is False
@@ -64,6 +66,7 @@ def test_config_context():
         "array_api_dispatch": False,
         "pairwise_dist_chunk_size": 256,
         "enable_cython_pairwise_dist": True,
+        "transform_output": "default",
     }
 
     # No positional arguments
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index 510ec968b8f03..d09acb8d3c6c8 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -21,6 +21,7 @@
     MinimalTransformer,
 )
 from sklearn.exceptions import NotFittedError
+from sklearn.model_selection import train_test_split
 from sklearn.utils.validation import check_is_fitted
 from sklearn.base import clone, is_classifier, BaseEstimator, TransformerMixin
 from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union
@@ -1613,3 +1614,35 @@ def get_feature_names_out(self, input_features=None):
     feature_names_out = pipe.get_feature_names_out(input_names)
 
     assert_array_equal(feature_names_out, [f"my_prefix_{name}" for name in input_names])
+
+
+def test_pipeline_set_output_integration():
+    """Test pipeline's set_output with feature names."""
+    pytest.importorskip("pandas")
+
+    X, y = load_iris(as_frame=True, return_X_y=True)
+
+    pipe = make_pipeline(StandardScaler(), LogisticRegression())
+    pipe.set_output(transform="pandas")
+    pipe.fit(X, y)
+
+    feature_names_in_ = pipe[:-1].get_feature_names_out()
+    log_reg_feature_names = pipe[-1].feature_names_in_
+
+    assert_array_equal(feature_names_in_, log_reg_feature_names)
+
+
+def test_feature_union_set_output():
+    """Test feature union with set_output API."""
+    pd = pytest.importorskip("pandas")
+
+    X, _ = load_iris(as_frame=True, return_X_y=True)
+    X_train, X_test = train_test_split(X, random_state=0)
+    union = FeatureUnion([("scalar", StandardScaler()), ("pca", PCA())])
+    union.set_output(transform="pandas")
+    union.fit(X_train)
+
+    X_trans = union.transform(X_test)
+    assert isinstance(X_trans, pd.DataFrame)
+    assert_array_equal(X_trans.columns, union.get_feature_names_out())
+    assert_array_equal(X_trans.index, X_test.index)
diff --git a/sklearn/utils/_set_output.py b/sklearn/utils/_set_output.py
new file mode 100644
index 0000000000000..525c6e0fe0118
--- /dev/null
+++ b/sklearn/utils/_set_output.py
@@ -0,0 +1,271 @@
+from functools import wraps
+
+from scipy.sparse import issparse
+
+from . import check_pandas_support
+from .._config import get_config
+from ._available_if import available_if
+
+
+def _wrap_in_pandas_container(
+    data_to_wrap,
+    *,
+    columns,
+    index=None,
+):
+    """Create a Pandas DataFrame.
+
+    If `data_to_wrap` is a DataFrame, then the `columns` and `index` will be changed
+    inplace. If `data_to_wrap` is a ndarray, then a new DataFrame is created with
+    `columns` and `index`.
+
+    Parameters
+    ----------
+    data_to_wrap : {ndarray, dataframe}
+        Data to be wrapped as pandas dataframe.
+
+    columns : callable, ndarray, or None
+        The column names or a callable that returns the column names. The
+        callable is useful if the column names require some computation.
+        If `None` and `data_to_wrap` is already a dataframe, then the column
+        names are not changed. If `None` and `data_to_wrap` is **not** a
+        dataframe, then columns are `range(n_features)`.
+
+    index : array-like, default=None
+        Index for data.
+
+    Returns
+    -------
+    dataframe : DataFrame
+        DataFrame holding `data_to_wrap` with the requested `columns` and `index`.
+    """
+    if issparse(data_to_wrap):
+        raise ValueError("Pandas output does not support sparse data.")
+
+    if callable(columns):
+        columns = columns()
+
+    pd = check_pandas_support("Setting output container to 'pandas'")
+
+    if isinstance(data_to_wrap, pd.DataFrame):
+        if columns is not None:
+            data_to_wrap.columns = columns
+        if index is not None:
+            data_to_wrap.index = index
+        return data_to_wrap
+
+    return pd.DataFrame(data_to_wrap, index=index, columns=columns)
+
+
+def _get_output_config(method, estimator=None):
+    """Get output config based on estimator and global configuration.
+
+    Parameters
+    ----------
+    method : {"transform"}
+        Estimator's method for which the output container is looked up.
+
+    estimator : estimator instance or None
+        Estimator to get the output configuration from. If `None`, the global
+        configuration is used.
+
+    Returns
+    -------
+    config : dict
+        Dictionary with keys:
+
+        - "dense": specifies the dense container for `method`. This can be
+          `"default"` or `"pandas"`.
+    """
+    est_sklearn_output_config = getattr(estimator, "_sklearn_output_config", {})
+    if method in est_sklearn_output_config:
+        dense_config = est_sklearn_output_config[method]
+    else:
+        dense_config = get_config()[f"{method}_output"]
+
+    if dense_config not in {"default", "pandas"}:
+        raise ValueError(
+            f"output config must be 'default' or 'pandas' got {dense_config}"
+        )
+
+    return {"dense": dense_config}
+
+
+def _wrap_data_with_container(method, data_to_wrap, original_input, estimator):
+    """Wrap output with container based on an estimator's or global config.
+
+    Parameters
+    ----------
+    method : {"transform"}
+        Estimator's method to get container output for.
+
+    data_to_wrap : {ndarray, dataframe}
+        Data to wrap with container.
+
+    original_input : {ndarray, dataframe}
+        Original input of function.
+
+    estimator : estimator instance
+        Estimator to get the output configuration from.
+
+    Returns
+    -------
+    output : {ndarray, dataframe}
+        If the output config is "default" or the estimator is not configured
+        for wrapping return `data_to_wrap` unchanged.
+        If the output config is "pandas", return `data_to_wrap` as a pandas
+        DataFrame.
+    """
+    output_config = _get_output_config(method, estimator)
+
+    if output_config["dense"] == "default" or not _auto_wrap_is_configured(estimator):
+        return data_to_wrap
+
+    # dense_config == "pandas"
+    return _wrap_in_pandas_container(
+        data_to_wrap=data_to_wrap,
+        index=getattr(original_input, "index", None),
+        columns=estimator.get_feature_names_out,
+    )
+
+
+def _wrap_method_output(f, method):
+    """Wrapper used by `_SetOutputMixin` to automatically wrap methods."""
+
+    @wraps(f)
+    def wrapped(self, X, *args, **kwargs):
+        data_to_wrap = f(self, X, *args, **kwargs)
+        if isinstance(data_to_wrap, tuple):
+            # only wrap the first output for cross decomposition
+            return (
+                _wrap_data_with_container(method, data_to_wrap[0], X, self),
+                *data_to_wrap[1:],
+            )
+
+        return _wrap_data_with_container(method, data_to_wrap, X, self)
+
+    return wrapped
+
+
+def _auto_wrap_is_configured(estimator):
+    """Return True if estimator is configured for auto-wrapping the transform method.
+
+    `_SetOutputMixin` sets `_sklearn_auto_wrap_output_keys` to `set()` if auto wrapping
+    is manually disabled.
+    """
+    auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set())
+    return (
+        hasattr(estimator, "get_feature_names_out")
+        and "transform" in auto_wrap_output_keys
+    )
+
+
+class _SetOutputMixin:
+    """Mixin that dynamically wraps methods to return container based on config.
+
+    Currently `_SetOutputMixin` wraps `transform` and `fit_transform` and configures
+    it based on `set_output` or the global configuration.
+
+    `set_output` is only defined if `get_feature_names_out` is defined and
+    `auto_wrap_output_keys` contains `"transform"`.
+    """
+
+    def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
+        # Cooperate with other classes in the MRO that define `__init_subclass__`
+        # by forwarding the keyword arguments that are not consumed here.
+        super().__init_subclass__(**kwargs)
+
+        # Dynamically wrap `transform` and `fit_transform` and configure their
+        # output based on `set_output`.
+        if not (
+            isinstance(auto_wrap_output_keys, tuple) or auto_wrap_output_keys is None
+        ):
+            raise ValueError("auto_wrap_output_keys must be None or a tuple of keys.")
+
+        if auto_wrap_output_keys is None:
+            cls._sklearn_auto_wrap_output_keys = set()
+            return
+
+        # Mapping from method to key in configurations
+        method_to_key = {
+            "transform": "transform",
+            "fit_transform": "transform",
+        }
+        cls._sklearn_auto_wrap_output_keys = set()
+
+        for method, key in method_to_key.items():
+            if not hasattr(cls, method) or key not in auto_wrap_output_keys:
+                continue
+            cls._sklearn_auto_wrap_output_keys.add(key)
+            wrapped_method = _wrap_method_output(getattr(cls, method), key)
+            setattr(cls, method, wrapped_method)
+
+    @available_if(_auto_wrap_is_configured)
+    def set_output(self, *, transform=None):
+        """Set output container.
+
+        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
+        for an example on how to use the API.
+
+        Parameters
+        ----------
+        transform : {"default", "pandas"}, default=None
+            Configure output of the following estimator's methods:
+
+            - `"transform"`
+            - `"fit_transform"`
+
+            If `None`, this operation is a no-op.
+
+        Returns
+        -------
+        self : estimator instance
+            Estimator instance.
+        """
+        if transform is None:
+            return self
+
+        if not hasattr(self, "_sklearn_output_config"):
+            self._sklearn_output_config = {}
+
+        self._sklearn_output_config["transform"] = transform
+        return self
+
+
+def _safe_set_output(estimator, *, transform=None):
+    """Safely call estimator.set_output and error if it is not available.
+
+    This is used by meta-estimators to set the output for child estimators.
+
+    Parameters
+    ----------
+    estimator : estimator instance
+        Estimator instance.
+
+    transform : {"default", "pandas"}, default=None
+        Configure output of the following estimator's methods:
+
+        - `"transform"`
+        - `"fit_transform"`
+
+        If `None`, this operation is a no-op.
+
+    Returns
+    -------
+    estimator : estimator instance or None
+        The configured estimator, or `None` when `set_output` was not called.
+    """
+    set_output_for_transform = transform is not None and (
+        hasattr(estimator, "transform") or hasattr(estimator, "fit_transform")
+    )
+    if not set_output_for_transform:
+        # `transform=None` is a documented no-op and an estimator that can not
+        # transform has nothing to configure, so `set_output` is not called.
+        return
+
+    if not hasattr(estimator, "set_output"):
+        raise ValueError(
+            f"Unable to configure output for {estimator} because `set_output` "
+            "is not available."
+        )
+    return estimator.set_output(transform=transform)
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index ed9346c0487e6..90efa06073487 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -4145,3 +4145,107 @@ def check_param_validation(name, estimator_orig):
                         getattr(estimator, method)(y)
                     else:
                         getattr(estimator, method)(X, y)
+
+
+def check_set_output_transform(name, transformer_orig):
+    # Check transformer.set_output with the default configuration does not
+    # change the transform output.
+    tags = transformer_orig._get_tags()
+    if "2darray" not in tags["X_types"] or tags["no_validation"]:
+        return
+
+    rng = np.random.RandomState(0)
+    transformer = clone(transformer_orig)
+
+    X = rng.uniform(size=(20, 5))
+    X = _pairwise_estimator_convert_X(X, transformer_orig)
+    y = rng.randint(0, 2, size=20)
+    y = _enforce_estimator_tags_y(transformer_orig, y)
+    set_random_state(transformer)
+
+    def fit_then_transform(est):
+        if name in CROSS_DECOMPOSITION:
+            return est.fit(X, y).transform(X, y)
+        return est.fit(X, y).transform(X)
+
+    def fit_transform(est):
+        return est.fit_transform(X, y)
+
+    transform_methods = [fit_then_transform, fit_transform]
+    for transform_method in transform_methods:
+        transformer = clone(transformer)
+        X_trans_no_setting = transform_method(transformer)
+
+        # Auto wrapping only wraps the first array
+        if name in CROSS_DECOMPOSITION:
+            X_trans_no_setting = X_trans_no_setting[0]
+
+        transformer.set_output(transform="default")
+        X_trans_default = transform_method(transformer)
+
+        if name in CROSS_DECOMPOSITION:
+            X_trans_default = X_trans_default[0]
+
+        # Default and no setting -> returns the same transformation
+        assert_allclose_dense_sparse(X_trans_no_setting, X_trans_default)
+
+
+def check_set_output_transform_pandas(name, transformer_orig):
+    # Check transformer.set_output configures the output of transform="pandas".
+    try:
+        import pandas as pd
+    except ImportError:
+        raise SkipTest(
+            "pandas is not installed: not checking column name consistency for pandas"
+        )
+
+    tags = transformer_orig._get_tags()
+    if "2darray" not in tags["X_types"] or tags["no_validation"]:
+        return
+
+    rng = np.random.RandomState(0)
+    transformer = clone(transformer_orig)
+
+    X = rng.uniform(size=(20, 5))
+    X = _pairwise_estimator_convert_X(X, transformer_orig)
+    y = rng.randint(0, 2, size=20)
+    y = _enforce_estimator_tags_y(transformer_orig, y)
+    set_random_state(transformer)
+
+    feature_names_in = [f"col{i}" for i in range(X.shape[1])]
+    df = pd.DataFrame(X, columns=feature_names_in)
+
+    def fit_then_transform(est):
+        if name in CROSS_DECOMPOSITION:
+            return est.fit(df, y).transform(df, y)
+        return est.fit(df, y).transform(df)
+
+    def fit_transform(est):
+        return est.fit_transform(df, y)
+
+    transform_methods = [fit_then_transform, fit_transform]
+
+    for transform_method in transform_methods:
+        transformer = clone(transformer).set_output(transform="default")
+        X_trans_no_setting = transform_method(transformer)
+
+        # Auto wrapping only wraps the first array
+        if name in CROSS_DECOMPOSITION:
+            X_trans_no_setting = X_trans_no_setting[0]
+
+        transformer.set_output(transform="pandas")
+        try:
+            X_trans_pandas = transform_method(transformer)
+        except ValueError as e:
+            # transformer does not support sparse data
+            assert str(e) == "Pandas output does not support sparse data.", e
+            return
+
+        if name in CROSS_DECOMPOSITION:
+            X_trans_pandas = X_trans_pandas[0]
+
+        assert isinstance(X_trans_pandas, pd.DataFrame)
+        expected_dataframe = pd.DataFrame(
+            X_trans_no_setting, columns=transformer.get_feature_names_out()
+        )
+        pd.testing.assert_frame_equal(X_trans_pandas, expected_dataframe)
diff --git a/sklearn/utils/tests/test_set_output.py b/sklearn/utils/tests/test_set_output.py
new file mode 100644
index 0000000000000..d20a4634f885d
--- /dev/null
+++ b/sklearn/utils/tests/test_set_output.py
@@ -0,0 +1,218 @@
+import pytest
+
+import numpy as np
+from scipy.sparse import csr_matrix
+from numpy.testing import assert_array_equal
+
+from sklearn._config import config_context, get_config
+from sklearn.utils._set_output import _wrap_in_pandas_container
+from sklearn.utils._set_output import _safe_set_output
+from sklearn.utils._set_output import _SetOutputMixin
+from sklearn.utils._set_output import _get_output_config
+
+
+def test__wrap_in_pandas_container_dense():
+    """Check _wrap_in_pandas_container for dense data."""
+    pd = pytest.importorskip("pandas")
+    X = np.asarray([[1, 0, 3], [0, 0, 1]])
+    columns = np.asarray(["f0", "f1", "f2"], dtype=object)
+    index = np.asarray([0, 1])
+
+    dense_named = _wrap_in_pandas_container(X, columns=lambda: columns, index=index)
+    assert isinstance(dense_named, pd.DataFrame)
+    assert_array_equal(dense_named.columns, columns)
+    assert_array_equal(dense_named.index, index)
+
+
+def test__wrap_in_pandas_container_dense_update_columns_and_index():
+    """Check that _wrap_in_pandas_container overrides columns and index."""
+    pd = pytest.importorskip("pandas")
+    X_df = pd.DataFrame([[1, 0, 3], [0, 0, 1]], columns=["a", "b", "c"])
+    new_columns = np.asarray(["f0", "f1", "f2"], dtype=object)
+    new_index = [10, 12]
+
+    new_df = _wrap_in_pandas_container(X_df, columns=new_columns, index=new_index)
+    assert_array_equal(new_df.columns, new_columns)
+    assert_array_equal(new_df.index, new_index)
+
+
+def test__wrap_in_pandas_container_error_validation():
+    """Check errors in _wrap_in_pandas_container."""
+    X = np.asarray([[1, 0, 3], [0, 0, 1]])
+    X_csr = csr_matrix(X)
+    match = "Pandas output does not support sparse data"
+    with pytest.raises(ValueError, match=match):
+        _wrap_in_pandas_container(X_csr, columns=["a", "b", "c"])
+
+
+class EstimatorWithoutSetOutputAndWithoutTransform:
+    pass
+
+
+class EstimatorNoSetOutputWithTransform:
+    def transform(self, X, y=None):
+        return X  # pragma: no cover
+
+
+class EstimatorWithSetOutput(_SetOutputMixin):
+    def fit(self, X, y=None):
+        self.n_features_in_ = X.shape[1]
+        return self
+
+    def transform(self, X, y=None):
+        return X
+
+    def get_feature_names_out(self, input_features=None):
+        return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)
+
+
+def test__safe_set_output():
+    """Check _safe_set_output works as expected."""
+
+    # Estimator without transform will not raise when setting set_output for transform.
+    est = EstimatorWithoutSetOutputAndWithoutTransform()
+    _safe_set_output(est, transform="pandas")
+
+    # Estimator with transform but without set_output will raise
+    est = EstimatorNoSetOutputWithTransform()
+    with pytest.raises(ValueError, match="Unable to configure output"):
+        _safe_set_output(est, transform="pandas")
+
+    # `transform=None` is a no-op, so the missing `set_output` is not an error
+    _safe_set_output(est, transform=None)
+
+    est = EstimatorWithSetOutput().fit(np.asarray([[1, 2, 3]]))
+    _safe_set_output(est, transform="pandas")
+    config = _get_output_config("transform", est)
+    assert config["dense"] == "pandas"
+
+    _safe_set_output(est, transform="default")
+    config = _get_output_config("transform", est)
+    assert config["dense"] == "default"
+
+    # transform is None is a no-op, so the config remains "default"
+    _safe_set_output(est, transform=None)
+    config = _get_output_config("transform", est)
+    assert config["dense"] == "default"
+
+
+class EstimatorNoSetOutputWithTransformNoFeatureNamesOut(_SetOutputMixin):
+    def transform(self, X, y=None):
+        return X  # pragma: no cover
+
+
+def test_set_output_mixin():
+    """Estimator without get_feature_names_out does not define `set_output`."""
+    est = EstimatorNoSetOutputWithTransformNoFeatureNamesOut()
+    assert not hasattr(est, "set_output")
+
+
+def test_set_output_mixin_forwards_init_subclass_kwargs():
+    """Check that unrelated `__init_subclass__` keywords reach other base classes."""
+
+    class Base:
+        def __init_subclass__(cls, custom_parameter=None, **kwargs):
+            super().__init_subclass__(**kwargs)
+            cls.custom_parameter = custom_parameter
+
+    class Estimator(_SetOutputMixin, Base, custom_parameter=123):
+        pass
+
+    assert Estimator.custom_parameter == 123
+
+
+def test__safe_set_output_error():
+    """Check transform with invalid config."""
+    X = np.asarray([[1, 0, 3], [0, 0, 1]])
+
+    est = EstimatorWithSetOutput()
+    _safe_set_output(est, transform="bad")
+
+    msg = "output config must be 'default'"
+    with pytest.raises(ValueError, match=msg):
+        est.transform(X)
+
+
+def test_set_output_method():
+    """Check that the output is pandas."""
+    pd = pytest.importorskip("pandas")
+
+    X = np.asarray([[1, 0, 3], [0, 0, 1]])
+    est = EstimatorWithSetOutput().fit(X)
+
+    # transform=None is a no-op
+    est2 = est.set_output(transform=None)
+    assert est2 is est
+    X_trans_np = est2.transform(X)
+    assert isinstance(X_trans_np, np.ndarray)
+
+    est.set_output(transform="pandas")
+
+    X_trans_pd = est.transform(X)
+    assert isinstance(X_trans_pd, pd.DataFrame)
+
+
+def test_set_output_method_error():
+    """Check transform fails with invalid transform."""
+
+    X = np.asarray([[1, 0, 3], [0, 0, 1]])
+    est = EstimatorWithSetOutput().fit(X)
+    est.set_output(transform="bad")
+
+    msg = "output config must be 'default'"
+    with pytest.raises(ValueError, match=msg):
+        est.transform(X)
+
+
+def test__get_output_config():
+    """Check _get_output_config works as expected."""
+
+    # Without a configuration set, the global config is used
+    global_config = get_config()["transform_output"]
+    config = _get_output_config("transform")
+    assert config["dense"] == global_config
+
+    with config_context(transform_output="pandas"):
+        # with estimator=None, the global config is used
+        config = _get_output_config("transform")
+        assert config["dense"] == "pandas"
+
+        est = EstimatorNoSetOutputWithTransform()
+        config = _get_output_config("transform", est)
+        assert config["dense"] == "pandas"
+
+        est = EstimatorWithSetOutput()
+        # If estimator has no config, use global config
+        config = _get_output_config("transform", est)
+        assert config["dense"] == "pandas"
+
+        # If estimator has a config, use local config
+        est.set_output(transform="default")
+        config = _get_output_config("transform", est)
+        assert config["dense"] == "default"
+
+    est.set_output(transform="pandas")
+    config = _get_output_config("transform", est)
+    assert config["dense"] == "pandas"
+
+
+class EstimatorWithSetOutputNoAutoWrap(_SetOutputMixin, auto_wrap_output_keys=None):
+    def transform(self, X, y=None):
+        return X
+
+
+def test_get_output_auto_wrap_false():
+    """Check that auto_wrap_output_keys=None does not wrap."""
+    est = EstimatorWithSetOutputNoAutoWrap()
+    assert not hasattr(est, "set_output")
+
+    X = np.asarray([[1, 0, 3], [0, 0, 1]])
+    assert X is est.transform(X)
+
+
+def test_auto_wrap_output_keys_errors_with_incorrect_input():
+    msg = "auto_wrap_output_keys must be None or a tuple of keys."
+    with pytest.raises(ValueError, match=msg):
+
+        class BadEstimator(_SetOutputMixin, auto_wrap_output_keys="bad_parameter"):
+            pass