Thanks to visit codestin.com
Credit goes to github.com

Skip to content

MNT (SLEP6) remove other_params from process_routing #26909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ Changelog
- |Enhancement| :func:`base.clone` now supports `dict` as input and creates a
copy. :pr:`26786` by `Adrin Jalali`_.

- |API| :func:`~utils.metadata_routing.process_routing` now has a different
signature. The first two (the object and the method) are positional only,
and all metadata are passed as keyword arguments. :pr:`26909` by `Adrin
Jalali`_.

:mod:`sklearn.cross_decomposition`
..................................

Expand Down
8 changes: 4 additions & 4 deletions examples/miscellaneous/plot_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def get_metadata_routing(self):
return router

def fit(self, X, y, **fit_params):
params = process_routing(self, "fit", fit_params)
params = process_routing(self, "fit", **fit_params)

self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit)
X_transformed = self.transformer_.transform(X, **params.transformer.transform)
Expand All @@ -458,7 +458,7 @@ def fit(self, X, y, **fit_params):
return self

def predict(self, X, **predict_params):
params = process_routing(self, "predict", predict_params)
params = process_routing(self, "predict", **predict_params)

X_transformed = self.transformer_.transform(X, **params.transformer.transform)
return self.classifier_.predict(X_transformed, **params.classifier.predict)
Expand Down Expand Up @@ -543,7 +543,7 @@ def __init__(self, estimator):
self.estimator = estimator

def fit(self, X, y, **fit_params):
params = process_routing(self, "fit", fit_params)
params = process_routing(self, "fit", **fit_params)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)

def get_metadata_routing(self):
Expand Down Expand Up @@ -572,7 +572,7 @@ def __init__(self, estimator):
self.estimator = estimator

def fit(self, X, y, sample_weight=None, **fit_params):
params = process_routing(self, "fit", fit_params, sample_weight=sample_weight)
params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
check_metadata(self, sample_weight=sample_weight)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)

Expand Down
6 changes: 3 additions & 3 deletions sklearn/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,10 @@ def fit(self, X, y, sample_weight=None, **fit_params):

if _routing_enabled():
routed_params = process_routing(
obj=self,
method="fit",
self,
"fit",
sample_weight=sample_weight,
other_params=fit_params,
**fit_params,
)
else:
# sample_weight checks
Expand Down
12 changes: 6 additions & 6 deletions sklearn/linear_model/_logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1859,10 +1859,10 @@ def fit(self, X, y, sample_weight=None, **params):

if _routing_enabled():
routed_params = process_routing(
obj=self,
method="fit",
self,
"fit",
sample_weight=sample_weight,
other_params=params,
**params,
)
else:
routed_params = Bunch()
Expand Down Expand Up @@ -2150,10 +2150,10 @@ def score(self, X, y, sample_weight=None, **score_params):
scoring = self._get_scorer()
if _routing_enabled():
routed_params = process_routing(
obj=self,
method="score",
self,
"score",
sample_weight=sample_weight,
other_params=score_params,
**score_params,
)
else:
routed_params = Bunch()
Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __call__(self, estimator, *args, **kwargs):
cached_call = partial(_cached_call, cache)

if _routing_enabled():
routed_params = process_routing(self, "score", kwargs)
routed_params = process_routing(self, "score", **kwargs)
else:
# they all get the same args, and they all get them all
routed_params = Bunch(
Expand Down
16 changes: 7 additions & 9 deletions sklearn/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, **partial_fit_para

if _routing_enabled():
routed_params = process_routing(
obj=self,
method="partial_fit",
other_params=partial_fit_params,
self,
"partial_fit",
sample_weight=sample_weight,
**partial_fit_params,
)
else:
if sample_weight is not None and not has_fit_parameter(
Expand Down Expand Up @@ -249,10 +249,10 @@ def fit(self, X, y, sample_weight=None, **fit_params):

if _routing_enabled():
routed_params = process_routing(
obj=self,
method="fit",
other_params=fit_params,
self,
"fit",
sample_weight=sample_weight,
**fit_params,
)
else:
if sample_weight is not None and not has_fit_parameter(
Expand Down Expand Up @@ -706,9 +706,7 @@ def fit(self, X, Y, **fit_params):
del Y_pred_chain

if _routing_enabled():
routed_params = process_routing(
obj=self, method="fit", other_params=fit_params
)
routed_params = process_routing(self, "fit", **fit_params)
else:
routed_params = Bunch(estimator=Bunch(fit=fit_params))

Expand Down
18 changes: 8 additions & 10 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,7 @@ def _log_message(self, step_idx):

def _check_method_params(self, method, props, **kwargs):
if _routing_enabled():
routed_params = process_routing(
self, method=method, other_params=props, **kwargs
)
routed_params = process_routing(self, method, **props, **kwargs)
return routed_params
else:
fit_params_steps = Bunch(
Expand Down Expand Up @@ -586,7 +584,7 @@ def predict(self, X, **params):
return self.steps[-1][1].predict(Xt, **params)

# metadata routing enabled
routed_params = process_routing(self, "predict", other_params=params)
routed_params = process_routing(self, "predict", **params)
for _, name, transform in self._iter(with_final=False):
Xt = transform.transform(Xt, **routed_params[name].transform)
return self.steps[-1][1].predict(Xt, **routed_params[self.steps[-1][0]].predict)
Expand Down Expand Up @@ -706,7 +704,7 @@ def predict_proba(self, X, **params):
return self.steps[-1][1].predict_proba(Xt, **params)

# metadata routing enabled
routed_params = process_routing(self, "predict_proba", other_params=params)
routed_params = process_routing(self, "predict_proba", **params)
for _, name, transform in self._iter(with_final=False):
Xt = transform.transform(Xt, **routed_params[name].transform)
return self.steps[-1][1].predict_proba(
Expand Down Expand Up @@ -747,7 +745,7 @@ def decision_function(self, X, **params):

# not branching here since params is only available if
# enable_metadata_routing=True
routed_params = process_routing(self, "decision_function", other_params=params)
routed_params = process_routing(self, "decision_function", **params)

Xt = X
for _, name, transform in self._iter(with_final=False):
Expand Down Expand Up @@ -833,7 +831,7 @@ def predict_log_proba(self, X, **params):
return self.steps[-1][1].predict_log_proba(Xt, **params)

# metadata routing enabled
routed_params = process_routing(self, "predict_log_proba", other_params=params)
routed_params = process_routing(self, "predict_log_proba", **params)
for _, name, transform in self._iter(with_final=False):
Xt = transform.transform(Xt, **routed_params[name].transform)
return self.steps[-1][1].predict_log_proba(
Expand Down Expand Up @@ -882,7 +880,7 @@ def transform(self, X, **params):

# not branching here since params is only available if
# enable_metadata_routing=True
routed_params = process_routing(self, "transform", other_params=params)
routed_params = process_routing(self, "transform", **params)
Xt = X
for _, name, transform in self._iter():
Xt = transform.transform(Xt, **routed_params[name].transform)
Expand Down Expand Up @@ -925,7 +923,7 @@ def inverse_transform(self, Xt, **params):

# we don't have to branch here, since params is only non-empty if
# enable_metadata_routing=True.
routed_params = process_routing(self, "inverse_transform", other_params=params)
routed_params = process_routing(self, "inverse_transform", **params)
reverse_iter = reversed(list(self._iter()))
for _, name, transform in reverse_iter:
Xt = transform.inverse_transform(
Expand Down Expand Up @@ -981,7 +979,7 @@ def score(self, X, y=None, sample_weight=None, **params):

# metadata routing is enabled.
routed_params = process_routing(
self, "score", sample_weight=sample_weight, other_params=params
self, "score", sample_weight=sample_weight, **params
)

Xt = X
Expand Down
12 changes: 6 additions & 6 deletions sklearn/tests/metadata_routing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def __init__(self, estimator):
self.estimator = estimator

def fit(self, X, y, **fit_params):
params = process_routing(self, "fit", fit_params)
params = process_routing(self, "fit", **fit_params)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)

def get_metadata_routing(self):
Expand All @@ -345,12 +345,12 @@ def fit(self, X, y, sample_weight=None, **fit_params):
self.registry.append(self)

record_metadata(self, "fit", sample_weight=sample_weight)
params = process_routing(self, "fit", fit_params, sample_weight=sample_weight)
params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
return self

def predict(self, X, **predict_params):
params = process_routing(self, "predict", predict_params)
params = process_routing(self, "predict", **predict_params)
return self.estimator_.predict(X, **params.estimator.predict)

def get_metadata_routing(self):
Expand All @@ -374,7 +374,7 @@ def fit(self, X, y, sample_weight=None, **kwargs):
self.registry.append(self)

record_metadata(self, "fit", sample_weight=sample_weight)
params = process_routing(self, "fit", kwargs, sample_weight=sample_weight)
params = process_routing(self, "fit", sample_weight=sample_weight, **kwargs)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
return self

Expand All @@ -394,12 +394,12 @@ def __init__(self, transformer):
self.transformer = transformer

def fit(self, X, y=None, **fit_params):
params = process_routing(self, "fit", fit_params)
params = process_routing(self, "fit", **fit_params)
self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit)
return self

def transform(self, X, y=None, **transform_params):
params = process_routing(self, "transform", transform_params)
params = process_routing(self, "transform", **transform_params)
return self.transformer_.transform(X, **params.transformer.transform)

def get_metadata_routing(self):
Expand Down
8 changes: 4 additions & 4 deletions sklearn/tests/test_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, steps):

def fit(self, X, y, **fit_params):
self.steps_ = []
params = process_routing(self, "fit", fit_params)
params = process_routing(self, "fit", **fit_params)
X_transformed = X
for i, step in enumerate(self.steps[:-1]):
transformer = clone(step).fit(
Expand All @@ -93,7 +93,7 @@ def fit(self, X, y, **fit_params):
def predict(self, X, **predict_params):
check_is_fitted(self)
X_transformed = X
params = process_routing(self, "predict", predict_params)
params = process_routing(self, "predict", **predict_params)
for i, step in enumerate(self.steps_[:-1]):
X_transformed = step.transform(X, **params.get(f"step_{i}").transform)

Expand Down Expand Up @@ -230,15 +230,15 @@ class OddEstimator(BaseEstimator):

def test_process_routing_invalid_method():
with pytest.raises(TypeError, match="Can only route and process input"):
process_routing(ConsumingClassifier(), "invalid_method", {})
process_routing(ConsumingClassifier(), "invalid_method", **{})


def test_process_routing_invalid_object():
class InvalidObject:
pass

with pytest.raises(AttributeError, match="has not implemented the routing"):
process_routing(InvalidObject(), "fit", {})
process_routing(InvalidObject(), "fit", **{})


def test_simple_metadata_routing():
Expand Down
41 changes: 16 additions & 25 deletions sklearn/utils/_metadata_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,34 +1412,33 @@ def get_metadata_routing(self):
# given metadata. This is to minimize the boilerplate required in routers.


def process_routing(obj, method, other_params, **kwargs):
# Here the first two arguments are positional only which makes everything
# passed as keyword argument a metadata. The first two args also have an `_`
# prefix to reduce the chances of name collisions with the passed metadata, and
# since they're positional only, users will never type those underscores.
def process_routing(_obj, _method, /, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this is not the norm, may you add a comment here that explains the purpose of _obj and _method based on #26909 (comment) ?

"""Validate and route input parameters.

This function is used inside a router's method, e.g. :term:`fit`,
to validate the metadata and handle the routing.

Assuming this signature: ``fit(self, X, y, sample_weight=None, **fit_params)``,
a call to this function would be:
``process_routing(self, "fit", fit_params, sample_weight=sample_weight)``.
``process_routing(self, sample_weight=sample_weight, **fit_params)``.

.. versionadded:: 1.3

Parameters
----------
obj : object
_obj : object
An object implementing ``get_metadata_routing``. Typically a
meta-estimator.

method : str
_method : str
The name of the router's method in which this function is called.

other_params : dict
A dictionary of extra parameters passed to the router's method,
e.g. ``**fit_params`` passed to a meta-estimator's :term:`fit`.

**kwargs : dict
Parameters explicitly accepted and included in the router's method
signature.
Metadata to be routed.

Returns
-------
Expand All @@ -1449,27 +1448,19 @@ def process_routing(obj, method, other_params, **kwargs):
corresponding methods or corresponding child objects. The object names
are those defined in `obj.get_metadata_routing()`.
"""
if not hasattr(obj, "get_metadata_routing"):
if not hasattr(_obj, "get_metadata_routing"):
raise AttributeError(
f"This {repr(obj.__class__.__name__)} has not implemented the routing"
f"This {repr(_obj.__class__.__name__)} has not implemented the routing"
" method `get_metadata_routing`."
)
if method not in METHODS:
if _method not in METHODS:
raise TypeError(
f"Can only route and process input on these methods: {METHODS}, "
f"while the passed method is: {method}."
f"while the passed method is: {_method}."
)

# We take the extra params (**fit_params) which is passed as `other_params`
# and add the explicitly passed parameters (passed as **kwargs) to it. This
# is equivalent to a code such as this in a router:
# if sample_weight is not None:
# fit_params["sample_weight"] = sample_weight
all_params = other_params if other_params is not None else dict()
all_params.update(kwargs)

request_routing = get_routing_for_object(obj)
request_routing.validate_metadata(params=all_params, method=method)
routed_params = request_routing.route_params(params=all_params, caller=method)
request_routing = get_routing_for_object(_obj)
request_routing.validate_metadata(params=kwargs, method=_method)
routed_params = request_routing.route_params(params=kwargs, caller=_method)

return routed_params