Description
Describe the bug
This is related to #23734 (btw I love this enhancement!). When config_context is used, transformers created within the context do not register/memoize the transform output. This may be expected, though I could not find it stated explicitly in the documentation.
Could the solution be to add a default __init__ to _SetOutputMixin that captures transform_output if it is set?
# Proposed default __init__ for _SetOutputMixin (sketch)
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # get_config is sklearn.get_config
    transform_output = get_config()["transform_output"]
    if transform_output != "default":
        self.set_output(transform=transform_output)
Steps/Code to Reproduce
This works as expected:
import sklearn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

X, y = load_iris(as_frame=True, return_X_y=True)
with sklearn.config_context(transform_output="pandas"):
    scaler = StandardScaler()
    # fit_transform is called inside the context
    x = scaler.fit_transform(X)
print(type(x))  # <class 'pandas.core.frame.DataFrame'>
But:
import sklearn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

X, y = load_iris(as_frame=True, return_X_y=True)
with sklearn.config_context(transform_output="pandas"):
    scaler = StandardScaler()
# fit_transform is called outside the context
x = scaler.fit_transform(X)
print(type(x))  # <class 'numpy.ndarray'>
So when fit_transform is not called under config_context(transform_output="pandas"), the output falls back to a NumPy array (the default output). The StandardScaler() constructor doesn't register the config at construction time.
This is slightly confusing because an explicit set_output call does stick to the instance:
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

X, y = load_iris(as_frame=True, return_X_y=True)
# set_output stores the setting on the estimator itself
scaler = StandardScaler().set_output(transform="pandas")
x = scaler.fit_transform(X)
print(type(x))  # <class 'pandas.core.frame.DataFrame'>
Expected Results
As a user, I would expect config_context(transform_output="pandas") to be memoized/registered during transformer construction, similar to explicitly calling set_output on a transformer.
Actual Results
See above.
Versions
System:
    python: 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:23:14) [GCC 10.4.0]
    executable: ...
    machine: Linux-5.10.0-20-cloud-amd64-x86_64-with-glibc2.31

Python dependencies:
    sklearn: 1.2.0
    pip: 22.3.1
    setuptools: 65.6.3
    numpy: 1.23.5
    scipy: 1.9.3
    Cython: None
    pandas: 1.5.2
    matplotlib: 3.6.2
    joblib: 1.2.0
    threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
    user_api: openmp
    internal_api: openmp
    prefix: libgomp
    filepath: ...
    version: None
    num_threads: 16

    user_api: blas
    internal_api: openblas
    prefix: libopenblas
    filepath: ...
    version: 0.3.21
    threading_layer: pthreads
    architecture: SkylakeX
    num_threads: 16

    user_api: openmp
    internal_api: openmp
    prefix: libomp
    filepath: ...
    version: None
    num_threads: 16