Thanks to visit codestin.com
Credit goes to github.com

Skip to content

MNT more informative error message for UnsetMetadataPassedError #28517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions doc/metadata_routing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,10 @@ should be passed to the estimator's scorer or not::
... ).fit(X, y, sample_weight=my_weights)
... except ValueError as e:
... print(e)
[sample_weight] are passed but are not explicitly set as requested or not for
LogisticRegression.score
[sample_weight] are passed but are not explicitly set as requested or not
requested for LogisticRegression.score, which is used within GridSearchCV.fit.
Call `LogisticRegression.set_score_request({metadata}=True/False)` for each metadata
you want to request/ignore.

The issue can be fixed by explicitly setting the request value::

Expand Down
5 changes: 5 additions & 0 deletions examples/miscellaneous/plot_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,11 @@ def predict(self, X):
for w in record:
print(w.message)

# %%
# In the end, we disable the configuration flag for metadata routing:

set_config(enable_metadata_routing=False)

# %%
# Third Party Development and scikit-learn Dependency
# ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion sklearn/compose/tests/test_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2545,7 +2545,7 @@ def test_metadata_routing_error_for_column_transformer(method):

error_message = (
"[sample_weight, metadata] are passed but are not explicitly set as requested"
f" or not for ConsumingTransformer.{method}"
f" or not requested for ConsumingTransformer.{method}"
)
with pytest.raises(ValueError, match=re.escape(error_message)):
if method == "transform":
Expand Down
2 changes: 1 addition & 1 deletion sklearn/ensemble/tests/test_voting.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def test_metadata_routing_error_for_voting_estimators(Estimator, Child):

error_message = (
"[sample_weight, metadata] are passed but are not explicitly set as requested"
f" or not for {Child.__name__}.fit"
f" or not requested for {Child.__name__}.fit"
)

with pytest.raises(ValueError, match=re.escape(error_message)):
Expand Down
46 changes: 44 additions & 2 deletions sklearn/tests/test_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
BaseEstimator,
clone,
)
from sklearn.exceptions import UnsetMetadataPassedError
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.tests.metadata_routing_common import (
ConsumingClassifier,
ConsumingRegressor,
Expand Down Expand Up @@ -68,7 +70,13 @@ def enable_slep006():


class SimplePipeline(BaseEstimator):
"""A very simple pipeline, assuming the last step is always a predictor."""
"""A very simple pipeline, assuming the last step is always a predictor.

Parameters
----------
steps : iterable of objects
An iterable of transformers with the last step being a predictor.
"""

def __init__(self, steps):
self.steps = steps
Expand Down Expand Up @@ -295,7 +303,7 @@ def test_simple_metadata_routing():
clf = WeightedMetaClassifier(estimator=ConsumingClassifier())
err_message = (
"[sample_weight] are passed but are not explicitly set as requested or"
" not for ConsumingClassifier.fit"
" not requested for ConsumingClassifier.fit"
)
with pytest.raises(ValueError, match=re.escape(err_message)):
clf.fit(X, y, sample_weight=my_weights)
Expand Down Expand Up @@ -1033,6 +1041,40 @@ def fit(self, X, y, metadata=None):
MetaRegressor(estimator=Estimator()).fit(X, y, metadata=my_groups)


def test_unsetmetadatapassederror_correct():
"""Test that UnsetMetadataPassedError raises the correct error message when
set_{method}_request is not set in nested cases."""
weighted_meta = WeightedMetaClassifier(estimator=ConsumingClassifier())
pipe = SimplePipeline([weighted_meta])
msg = re.escape(
"[metadata] are passed but are not explicitly set as requested or not requested"
" for ConsumingClassifier.fit, which is used within WeightedMetaClassifier.fit."
" Call `ConsumingClassifier.set_fit_request({metadata}=True/False)` for each"
" metadata you want to request/ignore."
)

with pytest.raises(UnsetMetadataPassedError, match=msg):
pipe.fit(X, y, metadata="blah")


def test_unsetmetadatapassederror_correct_for_composite_methods():
"""Test that UnsetMetadataPassedError raises the correct error message when
composite metadata request methods are not set in nested cases."""
consuming_transformer = ConsumingTransformer()
pipe = Pipeline([("consuming_transformer", consuming_transformer)])

msg = re.escape(
"[metadata] are passed but are not explicitly set as requested or not requested"
" for ConsumingTransformer.fit_transform, which is used within"
" Pipeline.fit_transform. Call"
" `ConsumingTransformer.set_fit_request({metadata}=True/False)"
".set_transform_request({metadata}=True/False)`"
" for each metadata you want to request/ignore."
)
with pytest.raises(UnsetMetadataPassedError, match=msg):
pipe.fit_transform(X, y, metadata="blah")


def test_unbound_set_methods_work():
"""Tests that if the set_{method}_request is unbound, it still works.

Expand Down
2 changes: 1 addition & 1 deletion sklearn/tests/test_metaestimators_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def test_error_on_missing_requests_for_sub_estimator(metaestimator):
instance = cls(**kwargs)
msg = (
f"[{key}] are passed but are not explicitly set as requested or not"
f" for {estimator.__class__.__name__}.{method_name}"
f" requested for {estimator.__class__.__name__}.{method_name}"
)
with pytest.raises(UnsetMetadataPassedError, match=re.escape(msg)):
method = getattr(instance, method_name)
Expand Down
6 changes: 3 additions & 3 deletions sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1915,7 +1915,7 @@ def test_metadata_routing_error_for_pipeline(method):
pipeline = Pipeline([("estimator", est)])
error_message = (
"[sample_weight, prop] are passed but are not explicitly set as requested"
f" or not for SimpleEstimator.{method}"
f" or not requested for SimpleEstimator.{method}"
)
with pytest.raises(ValueError, match=re.escape(error_message)):
try:
Expand Down Expand Up @@ -1975,7 +1975,7 @@ def test_feature_union_metadata_routing_error():

error_message = (
"[sample_weight, metadata] are passed but are not explicitly set as requested"
f" or not for {ConsumingTransformer.__name__}.fit"
f" or not requested for {ConsumingTransformer.__name__}.fit"
)

with pytest.raises(UnsetMetadataPassedError, match=re.escape(error_message)):
Expand All @@ -1995,7 +1995,7 @@ def test_feature_union_metadata_routing_error():

error_message = (
"[sample_weight, metadata] are passed but are not explicitly set as requested "
f"or not for {ConsumingTransformer.__name__}.transform"
f"or not requested for {ConsumingTransformer.__name__}.transform"
)

with pytest.raises(UnsetMetadataPassedError, match=re.escape(error_message)):
Expand Down
76 changes: 60 additions & 16 deletions sklearn/utils/_metadata_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def _check_warnings(self, *, params):
"warning, or to True to consume and use the metadata."
)

def _route_params(self, params):
def _route_params(self, params, parent, caller):
"""Prepare the given parameters to be passed to the method.

The output of this method can be used directly as the input to the
Expand All @@ -414,6 +414,12 @@ def _route_params(self, params):
params : dict
A dictionary of provided metadata.

parent : object
Parent class object, that routes the metadata.

caller : str
Method from the parent class object, where the metadata is routed from.

Returns
-------
params : Bunch
Expand All @@ -434,12 +440,26 @@ def _route_params(self, params):
elif alias in args:
res[prop] = args[alias]
if unrequested:
if self.method in COMPOSITE_METHODS:
callee_methods = COMPOSITE_METHODS[self.method]
else:
callee_methods = [self.method]
set_requests_on = "".join(
[
f".set_{method}_request({{metadata}}=True/False)"
for method in callee_methods
]
)
message = (
f"[{', '.join([key for key in unrequested])}] are passed but are not"
" explicitly set as requested or not requested for"
f" {self.owner}.{self.method}, which is used within"
f" {parent}.{caller}. Call `{self.owner}"
+ set_requests_on
+ "` for each metadata you want to request/ignore."
)
raise UnsetMetadataPassedError(
message=(
f"[{', '.join([key for key in unrequested])}] are passed but are"
" not explicitly set as requested or not for"
f" {self.owner}.{self.method}"
),
message=message,
unrequested_params=unrequested,
routed_params=res,
)
Expand Down Expand Up @@ -591,28 +611,36 @@ def _get_param_names(self, method, return_alias, ignore_self_request=None):
"""
return getattr(self, method)._get_param_names(return_alias=return_alias)

def _route_params(self, *, method, params):
def _route_params(self, *, params, method, parent, caller):
"""Prepare the given parameters to be passed to the method.

The output of this method can be used directly as the input to the
corresponding method as extra keyword arguments to pass metadata.

Parameters
----------
params : dict
A dictionary of provided metadata.

method : str
The name of the method for which the parameters are requested and
routed.

params : dict
A dictionary of provided metadata.
parent : object
Parent class object, that routes the metadata.

caller : str
Method from the parent class object, where the metadata is routed from.

Returns
-------
params : Bunch
A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to the
corresponding method.
"""
return getattr(self, method)._route_params(params=params)
return getattr(self, method)._route_params(
params=params, parent=parent, caller=caller
)

def _check_warnings(self, *, method, params):
"""Check whether metadata is passed which is marked as WARN.
Expand Down Expand Up @@ -938,7 +966,7 @@ def _get_param_names(self, *, method, return_alias, ignore_self_request):
)
return res

def _route_params(self, *, params, method):
def _route_params(self, *, params, method, parent, caller):
"""Prepare the given parameters to be passed to the method.

This is used when a router is used as a child object of another router.
Expand All @@ -950,12 +978,18 @@ def _route_params(self, *, params, method):

Parameters
----------
params : dict
A dictionary of provided metadata.

method : str
The name of the method for which the parameters are requested and
routed.

params : dict
A dictionary of provided metadata.
parent : object
Parent class object, that routes the metadata.

caller : str
Method from the parent class object, where the metadata is routed from.

Returns
-------
Expand All @@ -965,7 +999,14 @@ def _route_params(self, *, params, method):
"""
res = Bunch()
if self._self_request:
res.update(self._self_request._route_params(params=params, method=method))
res.update(
self._self_request._route_params(
params=params,
method=method,
parent=parent,
caller=caller,
)
)

param_names = self._get_param_names(
method=method, return_alias=True, ignore_self_request=True
Expand Down Expand Up @@ -1026,7 +1067,10 @@ def route_params(self, *, caller, params):
for _callee, _caller in mapping:
if _caller == caller:
res[name][_callee] = router._route_params(
params=params, method=_callee
params=params,
method=_callee,
parent=self.owner,
caller=caller,
)
return res

Expand Down Expand Up @@ -1059,7 +1103,7 @@ def validate_metadata(self, *, method, params):
if extra_keys:
raise TypeError(
f"{self.owner}.{method} got unexpected argument(s) {extra_keys}, which"
" are not requested metadata in any object."
" are not routed to any object."
)

def _serialize(self):
Expand Down