Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
868d0ff
FEAT allow metadata to be transformed in Pipeline
adrinjalali Apr 15, 2024
42dfe81
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Apr 26, 2024
94c8bd9
add tests
adrinjalali Apr 26, 2024
818da32
add fit_transform
adrinjalali Apr 26, 2024
067946c
fix pprint test
adrinjalali Apr 29, 2024
ed5edcd
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali May 7, 2024
85c10a4
add changelog
adrinjalali May 7, 2024
ad269ea
much more extensive tests
adrinjalali May 8, 2024
1622203
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali May 8, 2024
5268514
more fixes
adrinjalali May 24, 2024
1a4a428
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali May 25, 2024
052b13d
WIP tests improvements
adrinjalali May 26, 2024
278dc70
TST fix pipeline tests
adrinjalali May 26, 2024
75dbf5d
Christian's comments
adrinjalali Sep 2, 2024
ffcbca5
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Sep 2, 2024
07586f9
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Sep 5, 2024
52e0642
remove erronous arg passing
adrinjalali Sep 7, 2024
cdaf20f
support tupples to be transformed
adrinjalali Sep 7, 2024
24fa675
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Sep 7, 2024
399b5f1
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Oct 16, 2024
9dbb6b7
rename method
adrinjalali Oct 16, 2024
08e7415
address comments
adrinjalali Oct 18, 2024
1a7db0f
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Oct 18, 2024
486c116
changelog
adrinjalali Oct 18, 2024
06ed90b
remove TBD
adrinjalali Oct 21, 2024
0fc6800
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Oct 21, 2024
179dc88
fix tests
adrinjalali Oct 22, 2024
d1ec33c
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Oct 22, 2024
18530b9
remove debug message
adrinjalali Oct 31, 2024
c07e043
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Oct 31, 2024
f3ecedd
Update sklearn/pipeline.py
adrinjalali Nov 4, 2024
03f7dda
Merge branch 'main' into pipeline/transform
adrinjalali Nov 4, 2024
df36b33
...
adrinjalali Nov 7, 2024
7e049a6
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Nov 7, 2024
b703976
Merge branch 'pipeline/transform' of github.com:adrinjalali/scikit-le…
adrinjalali Nov 7, 2024
31a847a
Apply suggestions from code review
adrinjalali Nov 8, 2024
496d5f2
lint
adrinjalali Nov 8, 2024
568f37a
Merge remote-tracking branch 'upstream/main' into pipeline/transform
adrinjalali Nov 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- :class:`pipeline.Pipeline` can now transform metadata up to the step requiring the
metadata, which can be set using the `transform_input` parameter.
By `Adrin Jalali`_
195 changes: 186 additions & 9 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
MethodMapping,
_raise_for_params,
_routing_enabled,
get_routing_for_object,
process_routing,
)
from .utils.metaestimators import _BaseComposition, available_if
Expand Down Expand Up @@ -80,6 +81,46 @@ def check(self):
return check


def _cached_transform(
sub_pipeline, *, cache, param_name, param_value, transform_params
):
"""Transform a parameter value using a sub-pipeline and cache the result.

Parameters
----------
sub_pipeline : Pipeline
The sub-pipeline to be used for transformation.
cache : dict
The cache dictionary to store the transformed values.
param_name : str
The name of the parameter to be transformed.
param_value : object
The value of the parameter to be transformed.
transform_params : dict
The metadata to be used for transformation. This passed to the
`transform` method of the sub-pipeline.

Returns
-------
transformed_value : object
The transformed value of the parameter.
"""
if param_name not in cache:
# If the parameter is a tuple, transform each element of the
# tuple. This is needed to support the pattern present in
# `lightgbm` and `xgboost` where users can pass multiple
# validation sets.
if isinstance(param_value, tuple):
cache[param_name] = tuple(
sub_pipeline.transform(element, **transform_params)
for element in param_value
)
else:
cache[param_name] = sub_pipeline.transform(param_value, **transform_params)

return cache[param_name]


class Pipeline(_BaseComposition):
"""
A sequence of data transformers with an optional final predictor.
Expand Down Expand Up @@ -119,6 +160,20 @@ class Pipeline(_BaseComposition):
must define `fit`. All non-last steps must also define `transform`. See
:ref:`Combining Estimators <combining_estimators>` for more details.

transform_input : list of str, default=None
The names of the :term:`metadata` parameters that should be transformed by the
pipeline before passing them to the step that consumes them.

This enables some input arguments to ``fit`` (other than ``X``)
to be transformed by the steps of the pipeline up to the step which requires
them. Requirement is defined via :ref:`metadata routing <metadata_routing>`.
For instance, this can be used to pass a validation set through the pipeline.

You can only set this if metadata routing is enabled, which you
can enable using ``sklearn.set_config(enable_metadata_routing=True)``.

.. versionadded:: 1.6

memory : str or object with the joblib.Memory interface, default=None
Used to cache the fitted transformers of the pipeline. The last step
will never be cached, even if it is a transformer. By default, no
Expand Down Expand Up @@ -184,12 +239,14 @@ class Pipeline(_BaseComposition):
# BaseEstimator interface
_parameter_constraints: dict = {
"steps": [list, Hidden(tuple)],
"transform_input": [list, None],
"memory": [None, str, HasMethods(["cache"])],
"verbose": ["boolean"],
}

def __init__(self, steps, *, memory=None, verbose=False):
def __init__(self, steps, *, transform_input=None, memory=None, verbose=False):
self.steps = steps
self.transform_input = transform_input
self.memory = memory
self.verbose = verbose

Expand Down Expand Up @@ -412,9 +469,92 @@ def _check_method_params(self, method, props, **kwargs):
fit_params_steps[step]["fit_predict"][param] = pval
return fit_params_steps

def _get_metadata_for_step(self, *, step_idx, step_params, all_params):
    """Get params (metadata) for step `name`.

    This transforms the metadata up to this step if required, which is
    indicated by the `transform_input` parameter.

    If a param in `step_params` is included in the `transform_input` list,
    it will be transformed.

    Parameters
    ----------
    step_idx : int
        Index of the step in the pipeline.

    step_params : dict
        Parameters specific to the step. These are routed parameters, e.g.
        `routed_params[name]`. If a parameter name here is included in the
        `pipeline.transform_input`, then it will be transformed. Note that
        these parameters are *after* routing, so the aliases are already
        resolved.

    all_params : dict
        All parameters passed by the user. Here this is used to call
        `transform` on the slice of the pipeline itself.

    Returns
    -------
    dict
        Parameters to be passed to the step. The ones which should be
        transformed are transformed.
    """
    if (
        self.transform_input is None
        or not all_params
        or not step_params
        or step_idx == 0
    ):
        # we only need to process step_params if transform_input is set
        # and metadata is given by the user.
        return step_params

    sub_pipeline = self[:step_idx]
    sub_metadata_routing = get_routing_for_object(sub_pipeline)
    # Compute once which of the user's params `sub_pipeline.transform`
    # consumes; this is loop-invariant, so hoist it out of the dict
    # comprehension instead of recomputing it for every parameter.
    consumed = sub_metadata_routing.consumes(
        method="transform", params=all_params.keys()
    )
    # here we get the metadata required by sub_pipeline.transform
    transform_params = {
        key: value for key, value in all_params.items() if key in consumed
    }
    transformed_params = dict()  # this is to be returned
    transformed_cache = dict()  # used to transform each param once
    # `step_params` is the output of `process_routing`, so it has a dict for each
    # method (e.g. fit, transform, predict), which are the args to be passed to
    # those methods. We need to transform the parameters which are in the
    # `transform_input`, before returning these dicts.
    for method, method_params in step_params.items():
        transformed_params[method] = Bunch()
        for param_name, param_value in method_params.items():
            # An example of `(param_name, param_value)` is
            # `('sample_weight', array([0.5, 0.5, ...]))`
            if param_name in self.transform_input:
                # This parameter now needs to be transformed by the sub_pipeline, to
                # this step. We cache these computations to avoid repeating them.
                transformed_params[method][param_name] = _cached_transform(
                    sub_pipeline,
                    cache=transformed_cache,
                    param_name=param_name,
                    param_value=param_value,
                    transform_params=transform_params,
                )
            else:
                transformed_params[method][param_name] = param_value
    return transformed_params

# Estimator interface

def _fit(self, X, y=None, routed_params=None):
def _fit(self, X, y=None, routed_params=None, raw_params=None):
"""Fit the pipeline except the last step.

routed_params is the output of `process_routing`
raw_params is the parameters passed by the user, used when `transform_input`
is set by the user, to transform metadata using a sub-pipeline.
"""
# shallow copy of steps - this should really be steps_
self.steps = list(self.steps)
self._validate_steps()
Expand All @@ -437,14 +577,20 @@ def _fit(self, X, y=None, routed_params=None):
else:
cloned_transformer = clone(transformer)
# Fit or load from cache the current transformer
step_params = self._get_metadata_for_step(
step_idx=step_idx,
step_params=routed_params[name],
all_params=raw_params,
)

X, fitted_transformer = fit_transform_one_cached(
cloned_transformer,
X,
y,
None,
weight=None,
message_clsname="Pipeline",
message=self._log_message(step_idx),
params=routed_params[name],
params=step_params,
)
# Replace the transformer of the step with the fitted
# transformer. This is necessary when loading the transformer
Expand Down Expand Up @@ -495,11 +641,22 @@ def fit(self, X, y=None, **params):
self : object
Pipeline with fitted steps.
"""
if not _routing_enabled() and self.transform_input is not None:
raise ValueError(
"The `transform_input` parameter can only be set if metadata "
"routing is enabled. You can enable metadata routing using "
"`sklearn.set_config(enable_metadata_routing=True)`."
)

routed_params = self._check_method_params(method="fit", props=params)
Xt = self._fit(X, y, routed_params)
Xt = self._fit(X, y, routed_params, raw_params=params)
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if self._final_estimator != "passthrough":
last_step_params = routed_params[self.steps[-1][0]]
last_step_params = self._get_metadata_for_step(
step_idx=len(self) - 1,
step_params=routed_params[self.steps[-1][0]],
all_params=params,
)
self._final_estimator.fit(Xt, y, **last_step_params["fit"])

return self
Expand Down Expand Up @@ -562,7 +719,11 @@ def fit_transform(self, X, y=None, **params):
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if last_step == "passthrough":
return Xt
last_step_params = routed_params[self.steps[-1][0]]
last_step_params = self._get_metadata_for_step(
step_idx=len(self) - 1,
step_params=routed_params[self.steps[-1][0]],
all_params=params,
)
if hasattr(last_step, "fit_transform"):
return last_step.fit_transform(
Xt, y, **last_step_params["fit_transform"]
Expand Down Expand Up @@ -1270,7 +1431,7 @@ def _name_estimators(estimators):
return list(zip(names, estimators))


def make_pipeline(*steps, memory=None, verbose=False):
def make_pipeline(*steps, memory=None, transform_input=None, verbose=False):
"""Construct a :class:`Pipeline` from the given estimators.

This is a shorthand for the :class:`Pipeline` constructor; it does not
Expand All @@ -1292,6 +1453,17 @@ def make_pipeline(*steps, memory=None, verbose=False):
or ``steps`` to inspect estimators within the pipeline. Caching the
transformers is advantageous when fitting is time consuming.

transform_input : list of str, default=None
This enables some input arguments to ``fit`` (other than ``X``)
to be transformed by the steps of the pipeline up to the step which requires
them. Requirement is defined via :ref:`metadata routing <metadata_routing>`.
This can be used to pass a validation set through the pipeline for instance.

You can only set this if metadata routing is enabled, which you
can enable using ``sklearn.set_config(enable_metadata_routing=True)``.

.. versionadded:: 1.6

verbose : bool, default=False
If True, the time elapsed while fitting each step will be printed as it
is completed.
Expand All @@ -1315,7 +1487,12 @@ def make_pipeline(*steps, memory=None, verbose=False):
Pipeline(steps=[('standardscaler', StandardScaler()),
('gaussiannb', GaussianNB())])
"""
return Pipeline(_name_estimators(steps), memory=memory, verbose=verbose)
return Pipeline(
_name_estimators(steps),
transform_input=transform_input,
memory=memory,
verbose=verbose,
)


def _transform_one(transformer, X, y, weight, params=None):
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/metadata_routing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def fit(self, X, y=None, sample_weight="default", metadata="default"):
record_metadata_not_default(
self, sample_weight=sample_weight, metadata=metadata
)
self.fitted_ = True
return self

def transform(self, X, sample_weight="default", metadata="default"):
Expand Down
Loading