ENH Add metadata routing for FeatureUnion
#28205
Conversation
`set_fit_transform_request` isn't well-defined since it's a composite method, therefore it always has to be a combination of routing for `fit` and `transform`.
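For illustration, here is a minimal hedged sketch of the per-method alternative (the `MyTransformer` class is hypothetical and not part of the PR; `set_fit_request` / `set_transform_request` are the per-method request setters generated by scikit-learn):

import sklearn
from sklearn.base import BaseEstimator, TransformerMixin

sklearn.set_config(enable_metadata_routing=True)


# hypothetical consuming transformer, only for illustration
class MyTransformer(TransformerMixin, BaseEstimator):
    def fit(self, X, y=None, sample_weight=None):
        return self

    def transform(self, X, sample_weight=None):
        return X


trans = (
    MyTransformer()
    .set_fit_request(sample_weight=True)        # routing for fit
    .set_transform_request(sample_weight=True)  # routing for transform
)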
Also, add the test making sure `get_metadata_routing` works on an unfitted estimator here.
sklearn/pipeline.py
Outdated
method_mapping=MethodMapping()
.add(callee="fit", caller="fit")
.add(callee="fit_transform", caller="fit")
.add(callee="fit_transform", caller="fit_transform"),
We don't do explicit `fit_transform` here, since it's a compound method.
Now I see: from `FeatureUnion`'s `fit_transform` we need to prepare for routing in three directions: to `fit`, to `transform`, and to `fit_transform`. This is because of the way `_fit_transform_one` is designed:
if hasattr(transformer, "fit_transform"):
    res = transformer.fit_transform(X, y, **params.get("fit_transform", {}))
else:
    res = transformer.fit(X, y, **params.get("fit", {})).transform(
        X, **params.get("transform", {})
    )
It is designed this way because a transformer can use either `fit` and `transform` or `fit_transform`, and in fact metadata routing is only possible if it overrides `TransformerMixin`'s `fit_transform`.
But yes, we don't route from `FeatureUnion`'s `fit` to the estimator's `fit_transform`.
I've changed the routing according to my new understanding.
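For concreteness, a hedged sketch of what such a three-direction mapping could look like in `FeatureUnion.get_metadata_routing` (modeled on the discussion above; the exact mapping in the merged code may differ):

from sklearn.utils.metadata_routing import MetadataRouter, MethodMapping


def get_metadata_routing(self):
    router = MetadataRouter(owner=self.__class__.__name__)
    for name, transformer in self.transformer_list:
        router.add(
            **{name: transformer},
            method_mapping=MethodMapping()
            .add(caller="fit", callee="fit")
            # fit_transform fans out in three directions, mirroring
            # the branches of _fit_transform_one:
            .add(caller="fit_transform", callee="fit_transform")
            .add(caller="fit_transform", callee="fit")
            .add(caller="fit_transform", callee="transform")
            .add(caller="transform", callee="transform"),
        )
    return router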
We actually don't have a test for `transformer.fit().transform()`. I made one using a new class, `ConsumingNoFitTransformTransformer`, which is special because it doesn't inherit from `TransformerMixin`. It has a long name, but it's a descriptive one. :)
We need this new class to make `transformer.fit().transform()` possible.
`ConsumingClassifier`'s `fit` and `transform` would not be used, because it already has a `fit_transform` that would be used instead.
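A minimal sketch of the idea behind such a class (illustrative only; the actual test class in the PR may differ):

from sklearn.base import BaseEstimator


class ConsumingNoFitTransformTransformer(BaseEstimator):
    # no TransformerMixin, hence no fit_transform: FeatureUnion has to call
    # fit(...) followed by transform(...) and route metadata to both
    def fit(self, X, y=None, sample_weight=None, metadata=None):
        # the real test records the received metadata and asserts on it
        return self

    def transform(self, X, sample_weight=None, metadata=None):
        return X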
Co-authored-by: Adrin Jalali <[email protected]>
Cool!
sklearn/pipeline.py
Outdated
routed_fit_transform_params = Bunch()
for name, _ in self.transformer_list:
    routed_fit_transform_params[name] = Bunch(fit_transform={})
    routed_fit_transform_params[name].fit_transform = fit_params
routed_fit_transform_params[name].fit_transform = fit_params
routed_fit_transform_params[name].fit_transform = fit_params
routed_fit_transform_params[name].fit = fit
I believe it needs to be like this:

if hasattr(obj, "fit_transform"):
    routed_params[name] = Bunch(fit_transform={})
    routed_params[name].fit_transform = params
else:
    routed_params[name] = Bunch(fit={}, transform={})
    routed_params[name].fit = params
    routed_params[name].transform = params
It accounts for the metadata to be passed to the correct methods of the sub-transformer. I've also added a test for that (`test_feature_union_fit_params_without_fit_transform`).
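To illustrate the behavior that test exercises, a hedged end-to-end sketch (the transformer and metadata names are illustrative, not the PR's exact test code):

import numpy as np
import sklearn
from sklearn.base import BaseEstimator
from sklearn.pipeline import FeatureUnion

sklearn.set_config(enable_metadata_routing=True)


class NoFitTransformTransformer(BaseEstimator):
    # no fit_transform, so FeatureUnion falls back to fit(...).transform(...)
    def fit(self, X, y=None, metadata=None):
        print("fit received metadata:", metadata)
        return self

    def transform(self, X, metadata=None):
        print("transform received metadata:", metadata)
        return X


X = np.array([[0.0], [1.0], [2.0]])
y = np.array([0, 1, 1])

union = FeatureUnion([
    (
        "trs",
        NoFitTransformTransformer()
        .set_fit_request(metadata=True)
        .set_transform_request(metadata=True),
    ),
])
# metadata is routed to the sub-transformer's fit and to its transform
union.fit_transform(X, y, metadata="some metadata")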
I've addressed all your review comments, @adrinjalali.
Otherwise LGTM.
Co-authored-by: Adrin Jalali <[email protected]>
Okay, I've taken care of that, @adrinjalali.
It looks good. Just some questions regarding `transform` and thus the renaming from `fit_params` to `params`.
Co-authored-by: Guillaume Lemaitre <[email protected]>
Otherwise LGTM.
Co-authored-by: Adrin Jalali <[email protected]>
@StefanieSenger I think that we are missing forwarding the metadata in the …
Oh yes, @glemaitre and @adrinjalali, sorry for the oversight.
I've also simplified the metadata router of `FeatureUnion` (by aligning the if and else conditions), because they are almost identical and I think no harm is done by providing more routing possibilities than an object would use.
LGTM. Thanks @StefanieSenger
Reference Issues/PRs
Towards #22893
What does this implement/fix? Explain your changes.
This PR adds metadata routing to `FeatureUnion` and the corresponding tests. I was confused about why we cannot do `set_fit_transform_request`. I know we can work around it, but from a user's perspective it's not self-explanatory.