Merged
9 changes: 9 additions & 0 deletions doc/metadata_routing.rst
@@ -6,6 +6,15 @@
Metadata Routing
================

.. note::
This feature, default metadata routing, and the API related to it are **all
experimental**, and might change without the usual deprecation cycle. By
default this feature is not enabled. You can enable this feature by setting
the ``enable_metadata_routing`` flag to ``True``:

>>> import sklearn
>>> sklearn.set_config(enable_metadata_routing=True)

This guide demonstrates how metadata such as ``sample_weight`` can be routed
and passed along to estimators, scorers, and CV splitters through
meta-estimators such as ``Pipeline`` and ``GridSearchCV``. In order to pass
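A quick illustration of the flag in practice (a minimal sketch, not part of the diff; ``set_fit_request`` is the request-marking method used throughout the routing examples and may evolve while the feature is experimental):

import sklearn
from sklearn.linear_model import LogisticRegression

sklearn.set_config(enable_metadata_routing=True)

# With routing enabled, estimators expose set_{method}_request; here we ask
# that sample_weight be routed to LogisticRegression.fit by any
# routing-aware meta-estimator.
est = LogisticRegression().set_fit_request(sample_weight=True)
print(est.get_metadata_routing())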
36 changes: 1 addition & 35 deletions examples/plot_metadata_routing.py
@@ -37,7 +37,7 @@
from sklearn.utils.metadata_routing import MethodMapping
from sklearn.utils.metadata_routing import process_routing
from sklearn.utils.validation import check_is_fitted
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.linear_model import LinearRegression

N, M = 100, 4
X = np.random.rand(N, M)
@@ -617,40 +617,6 @@ def predict(self, X):
except Exception as e:
print(e)

# %%
# You might want to give your users a period during which they see a
# ``FutureWarning`` instead in order to have time to adapt to the new API. For
# this, the :class:`~sklearn.utils.metadata_routing.MetadataRouter` provides a
# `warn_on` method:


class WarningMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
def __init__(self, estimator):
self.estimator = estimator

def fit(self, X, y, **fit_params):
params = process_routing(self, "fit", fit_params)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)

def get_metadata_routing(self):
router = (
MetadataRouter(owner=self.__class__.__name__)
.add(estimator=self.estimator, method_mapping="one-to-one")
.warn_on(child="estimator", method="fit", params=None)
)
return router


with warnings.catch_warnings(record=True) as record:
WarningMetaRegressor(estimator=LogisticRegression()).fit(
X, y, sample_weight=my_weights
)
for w in record:
print(w.message)

# %%
# Note that in the above implementation, the value passed to ``child`` is the
# same as the key passed to the ``add`` method, in this case ``"estimator"``.

# %%
# Third Party Development and scikit-learn Dependency
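With ``warn_on`` removed, the example's meta-estimator reduces to the pattern below (a sketch assembled from the surrounding example code; the class name is illustrative):

from sklearn.base import BaseEstimator, MetaEstimatorMixin, RegressorMixin, clone
from sklearn.utils.metadata_routing import MetadataRouter, process_routing


class SimpleMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        # Route whatever metadata the inner estimator requested for fit.
        params = process_routing(self, "fit", fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        return self

    def get_metadata_routing(self):
        # No warn_on anymore: routing is declared via add() alone.
        return MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator, method_mapping="one-to-one"
        )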
30 changes: 30 additions & 0 deletions sklearn/_config.py
@@ -15,6 +15,7 @@
"enable_cython_pairwise_dist": True,
"array_api_dispatch": False,
"transform_output": "default",
"enable_metadata_routing": False,
}
_threadlocal = threading.local()

@@ -54,6 +55,7 @@ def set_config(
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
transform_output=None,
enable_metadata_routing=None,
):
"""Set global scikit-learn configuration

@@ -134,6 +136,18 @@

.. versionadded:: 1.2

enable_metadata_routing : bool, default=None
Enable metadata routing. By default this feature is disabled.

Refer to the :ref:`metadata routing user guide <metadata_routing>` for
more details.

- `True`: Metadata routing is enabled.
- `False`: Metadata routing is disabled; use the old syntax.
- `None`: Configuration is unchanged.

.. versionadded:: 1.4

See Also
--------
config_context : Context manager for global scikit-learn configuration.
@@ -157,6 +171,8 @@
local_config["array_api_dispatch"] = array_api_dispatch
if transform_output is not None:
local_config["transform_output"] = transform_output
if enable_metadata_routing is not None:
local_config["enable_metadata_routing"] = enable_metadata_routing


@contextmanager
@@ -170,6 +186,7 @@ def config_context(
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
transform_output=None,
enable_metadata_routing=None,
):
"""Context manager for global scikit-learn configuration.

@@ -249,6 +266,18 @@ def config_context(

.. versionadded:: 1.2

enable_metadata_routing : bool, default=None
Enable metadata routing. By default this feature is disabled.

Refer to the :ref:`metadata routing user guide <metadata_routing>` for
more details.

- `True`: Metadata routing is enabled.
- `False`: Metadata routing is disabled; use the old syntax.
- `None`: Configuration is unchanged.

.. versionadded:: 1.4

Yields
------
None.
@@ -286,6 +315,7 @@ def config_context(
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
array_api_dispatch=array_api_dispatch,
transform_output=transform_output,
enable_metadata_routing=enable_metadata_routing,
)

try:
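And the context-manager form, which restores the previous value on exit (a sketch; assumes the default of ``False`` outside the block):

import sklearn

with sklearn.config_context(enable_metadata_routing=True):
    assert sklearn.get_config()["enable_metadata_routing"] is True

# Outside the context the previous value (False by default) is restored.
assert sklearn.get_config()["enable_metadata_routing"] is False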
48 changes: 36 additions & 12 deletions sklearn/calibration.py
@@ -7,6 +7,7 @@
#
# License: BSD 3 clause

from inspect import signature
from numbers import Integral
import warnings
from functools import partial
@@ -48,7 +49,13 @@
from .svm import LinearSVC
from .model_selection import check_cv, cross_val_predict
from .metrics._base import _check_pos_label_consistency
from .utils.metadata_routing import MetadataRouter, MethodMapping, process_routing
from sklearn.utils import Bunch
from .utils.metadata_routing import (
MetadataRouter,
MethodMapping,
process_routing,
_routing_enabled,
)


class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
@@ -378,12 +385,34 @@ def fit(self, X, y, sample_weight=None, **fit_params):
self.classes_ = label_encoder_.classes_
n_classes = len(self.classes_)

routed_params = process_routing(
obj=self,
method="fit",
sample_weight=sample_weight,
other_params=fit_params,
)
if _routing_enabled():
routed_params = process_routing(
obj=self,
method="fit",
sample_weight=sample_weight,
other_params=fit_params,
)
else:
# sample_weight checks
fit_parameters = signature(estimator.fit).parameters
supports_sw = "sample_weight" in fit_parameters
if sample_weight is not None and not supports_sw:
estimator_name = type(estimator).__name__
warnings.warn(
f"Since {estimator_name} does not appear to accept"
" sample_weight, sample weights will only be used for the"
" calibration itself. This can be caused by a limitation of"
" the current scikit-learn API. See the following issue for"
" more details:"
" https://github.com/scikit-learn/scikit-learn/issues/21134."
" Be warned that the result of the calibration is likely to be"
" incorrect."
)
Comment on lines +396 to +410 (Member):
For reviewer: This code is part of main.

routed_params = Bunch()
routed_params.splitter = Bunch(split={}) # no routing for splitter
routed_params.estimator = Bunch(fit=fit_params)
if sample_weight is not None and supports_sw:
routed_params.estimator.fit["sample_weight"] = sample_weight

# Check that each cross-validation fold can have at least one
# example per class
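
The legacy branch's capability check in isolation (a sketch mirroring the ``signature`` inspection above):

from inspect import signature
from sklearn.linear_model import LogisticRegression

fit_parameters = signature(LogisticRegression().fit).parameters
supports_sw = "sample_weight" in fit_parameters
assert supports_sw  # LogisticRegression.fit accepts sample_weight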
@@ -526,11 +555,6 @@ def get_metadata_routing(self):
splitter=self.cv,
method_mapping=MethodMapping().add(callee="split", caller="fit"),
)
# the fit method already accepts everything, therefore we don't
# specify parameters. The value passed to ``child`` needs to be the
# same as what's passed to ``add`` above, in this case
# `"estimator"`.
.warn_on(child="estimator", method="fit", params=None)
Comment on lines -529 to -533 (Member, Author):
We don't need to be backward compatible now when the feature flag is on, so these lines are removed.

)
return router

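A caller's-eye sketch of the routed branch (hypothetical data; with the flag off, the legacy branch above falls back to signature inspection instead):

import numpy as np
import sklearn
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression

sklearn.set_config(enable_metadata_routing=True)

X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)
w = np.random.rand(100)

# With routing enabled, the inner estimator must request sample_weight
# explicitly for CalibratedClassifierCV to forward it.
base = LogisticRegression().set_fit_request(sample_weight=True)
CalibratedClassifierCV(base).fit(X, y, sample_weight=w)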
64 changes: 47 additions & 17 deletions sklearn/metrics/_scorer.py
@@ -63,6 +63,7 @@
from .cluster import normalized_mutual_info_score
from .cluster import fowlkes_mallows_score

from ..utils import Bunch
from ..utils.multiclass import type_of_target
from ..base import is_regressor
from ..utils._param_validation import validate_params
@@ -71,6 +72,7 @@
from ..utils.metadata_routing import MetadataRouter
from ..utils.metadata_routing import process_routing
from ..utils.metadata_routing import get_routing_for_object
from ..utils.metadata_routing import _routing_enabled


def _cached_call(cache, estimator, method, *args, **kwargs):
@@ -115,16 +117,22 @@ def __call__(self, estimator, *args, **kwargs):
cache = {} if self._use_cache(estimator) else None
cached_call = partial(_cached_call, cache)

params = process_routing(self, "score", kwargs)
if _routing_enabled():
routed_params = process_routing(self, "score", kwargs)
else:
# they all get the same args, and they all get them all
routed_params = Bunch(
**{name: Bunch(score=kwargs) for name in self._scorers}
)

for name, scorer in self._scorers.items():
try:
if isinstance(scorer, _BaseScorer):
score = scorer._score(
cached_call, estimator, *args, **params.get(name).score
cached_call, estimator, *args, **routed_params.get(name).score
)
else:
score = scorer(estimator, *args, **params.get(name).score)
score = scorer(estimator, *args, **routed_params.get(name).score)
scores[name] = score
except Exception as e:
if self._raise_exc:
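
Between hunks, a tiny sketch of what the non-routing fallback above builds (scorer names are hypothetical): every scorer receives the full kwargs under its ``score`` entry.

from sklearn.utils import Bunch

kwargs = {"sample_weight": [1.0, 2.0, 3.0]}
scorer_names = ["accuracy", "f1"]  # stand-ins for self._scorers keys
routed_params = Bunch(**{name: Bunch(score=kwargs) for name in scorer_names})

# Every scorer sees the same, complete kwargs dict.
assert routed_params.get("accuracy").score == kwargs
assert routed_params.get("f1").score == kwargs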
@@ -233,7 +241,7 @@ def __repr__(self):
kwargs_string,
)

def __call__(self, estimator, X, y_true, **kwargs):
def __call__(self, estimator, X, y_true, sample_weight=None, **kwargs):
"""Evaluate predicted target values for X relative to y_true.

Parameters
@@ -248,17 +256,32 @@ def __call__(self, estimator, X, y_true, sample_weight=None, **kwargs):
y_true : array-like
Gold standard target values for X.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

**kwargs : dict
Other parameters passed to the scorer, e.g. sample_weight.
Refer to :func:`set_score_request` for more details.

.. versionadded:: 2.0
Only available if `enable_metadata_routing=True`. See the
:ref:`User Guide <metadata_routing>`.

.. versionadded:: 1.4

Returns
-------
score : float
Score function applied to prediction of estimator on X.
"""
if kwargs and not _routing_enabled():
raise ValueError(
"kwargs is only supported if enable_metadata_routing=True. See"
" the User Guide for more information."
)

if sample_weight is not None:
kwargs["sample_weight"] = sample_weight
Comment (Member):
I was about to say that we modify a mutable here and it should be safer to copy. However, now I see that we pass **kwargs and not kwargs, so it should be fine.


return self._score(partial(_cached_call, None), estimator, X, y_true, **kwargs)

def _factory_args(self):
@@ -284,7 +307,7 @@ def set_score_request(self, **kwargs):
Please see :ref:`User Guide <metadata_routing>` on how the routing
mechanism works.

.. versionadded:: 2.0
.. versionadded:: 1.4

Parameters
----------
@@ -332,7 +355,7 @@ def _score(self, method_caller, estimator, X, y_true, **kwargs):
Other parameters passed to the scorer, e.g. sample_weight.
Refer to :func:`set_score_request` for more details.

.. versionadded:: 2.0
.. versionadded:: 1.4

Returns
-------
@@ -377,7 +400,7 @@ def _score(self, method_caller, clf, X, y, **kwargs):
Other parameters passed to the scorer, e.g. sample_weight.
Refer to :func:`set_score_request` for more details.

.. versionadded:: 2.0
.. versionadded:: 1.4

Returns
-------
@@ -404,8 +427,11 @@ def _score(self, method_caller, clf, X, y, **kwargs):
scoring_kwargs = {**self._kwargs, **kwargs}
# this is for backward compatibility to avoid passing sample_weight
# to the scorer if it's None
# TODO(1.3) Probably remove
if scoring_kwargs.get("sample_weight", -1) is None:
# TODO: Probably remove when deprecating enable_metadata_routing
if (
"sample_weight" in scoring_kwargs
and scoring_kwargs["sample_weight"] is None
):
del scoring_kwargs["sample_weight"]

return self._sign * self._score_func(y, y_pred, **scoring_kwargs)
@@ -441,7 +467,7 @@ def _score(self, method_caller, clf, X, y, **kwargs):
Other parameters passed to the scorer, e.g. sample_weight.
Refer to :func:`set_score_request` for more details.

.. versionadded:: 2.0
.. versionadded:: 1.4

Returns
-------
@@ -489,9 +515,13 @@ def _score(self, method_caller, clf, X, y, **kwargs):
scoring_kwargs = {**self._kwargs, **kwargs}
# this is for backward compatibility to avoid passing sample_weight
# to the scorer if it's None
# TODO(1.3) Probably remove
if scoring_kwargs.get("sample_weight", -1) is None:
# TODO: Probably remove when deprecating enable_metadata_routing
if (
"sample_weight" in scoring_kwargs
and scoring_kwargs["sample_weight"] is None
):
del scoring_kwargs["sample_weight"]

return self._sign * self._score_func(y, y_pred, **scoring_kwargs)

def _factory_args(self):
@@ -552,7 +582,7 @@ def __call__(self, estimator, *args, **kwargs):
def get_metadata_routing(self):
"""Get requested data properties.

.. versionadded:: 2.0
.. versionadded:: 1.4

Returns
-------
@@ -924,9 +954,9 @@ class _DeprecatedScorers(dict):

def __getitem__(self, item):
warnings.warn(
"sklearn.metrics.SCORERS is deprecated and will be removed in v1.3. "
"Please use sklearn.metrics.get_scorer_names to get a list of available "
"scorers and sklearn.metrics.get_metric to get scorer.",
"sklearn.metrics.SCORERS is deprecated and will be removed in v1.3."
" Please use sklearn.metrics.get_scorer_names to get a list of"
" available scorers and sklearn.metrics.get_metric to get scorer.",
FutureWarning,
)
return super().__getitem__(item)
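
A usage sketch of the new ``__call__`` contract (hypothetical data): ``sample_weight`` works without the flag because it is now an explicit parameter, while any other metadata kwarg requires ``enable_metadata_routing=True``.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import get_scorer

X, y = make_classification(random_state=0)
w = np.ones(len(y))

est = LogisticRegression().fit(X, y)
scorer = get_scorer("accuracy")

# Explicit parameter: allowed with or without the routing flag.
print(scorer(est, X, y, sample_weight=w))

# Any other kwarg (e.g. a custom metadata key) raises a ValueError unless
# sklearn.set_config(enable_metadata_routing=True) has been called first.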