-
-
Notifications
You must be signed in to change notification settings - Fork 26.3k
FEAT add SLEP006 with a feature flag #26103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5fb1f50
82589be
ee03818
8c818b7
2f6abcb
f589736
2396f31
61845f2
fb0fdf2
b7fb316
54deade
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
# | ||
# License: BSD 3 clause | ||
|
||
from inspect import signature | ||
from numbers import Integral | ||
import warnings | ||
from functools import partial | ||
|
@@ -48,7 +49,13 @@ | |
from .svm import LinearSVC | ||
from .model_selection import check_cv, cross_val_predict | ||
from .metrics._base import _check_pos_label_consistency | ||
from .utils.metadata_routing import MetadataRouter, MethodMapping, process_routing | ||
from sklearn.utils import Bunch | ||
from .utils.metadata_routing import ( | ||
MetadataRouter, | ||
MethodMapping, | ||
process_routing, | ||
_routing_enabled, | ||
) | ||
|
||
|
||
class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator): | ||
|
@@ -378,12 +385,34 @@ def fit(self, X, y, sample_weight=None, **fit_params): | |
self.classes_ = label_encoder_.classes_ | ||
n_classes = len(self.classes_) | ||
|
||
routed_params = process_routing( | ||
obj=self, | ||
method="fit", | ||
sample_weight=sample_weight, | ||
other_params=fit_params, | ||
) | ||
if _routing_enabled(): | ||
routed_params = process_routing( | ||
obj=self, | ||
method="fit", | ||
sample_weight=sample_weight, | ||
other_params=fit_params, | ||
) | ||
else: | ||
# sample_weight checks | ||
fit_parameters = signature(estimator.fit).parameters | ||
supports_sw = "sample_weight" in fit_parameters | ||
if sample_weight is not None and not supports_sw: | ||
estimator_name = type(estimator).__name__ | ||
warnings.warn( | ||
f"Since {estimator_name} does not appear to accept" | ||
" sample_weight, sample weights will only be used for the" | ||
" calibration itself. This can be caused by a limitation of" | ||
" the current scikit-learn API. See the following issue for" | ||
" more details:" | ||
" https://github.com/scikit-learn/scikit-learn/issues/21134." | ||
" Be warned that the result of the calibration is likely to be" | ||
" incorrect." | ||
) | ||
Comment on lines
+396
to
+410
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For reviewer: This code is part of |
||
routed_params = Bunch() | ||
routed_params.splitter = Bunch(split={}) # no routing for splitter | ||
routed_params.estimator = Bunch(fit=fit_params) | ||
if sample_weight is not None and supports_sw: | ||
routed_params.estimator.fit["sample_weight"] = sample_weight | ||
|
||
# Check that each cross-validation fold can have at least one | ||
# example per class | ||
|
@@ -526,11 +555,6 @@ def get_metadata_routing(self): | |
splitter=self.cv, | ||
method_mapping=MethodMapping().add(callee="split", caller="fit"), | ||
) | ||
# the fit method already accepts everything, therefore we don't | ||
# specify parameters. The value passed to ``child`` needs to be the | ||
# same as what's passed to ``add`` above, in this case | ||
# `"estimator"`. | ||
.warn_on(child="estimator", method="fit", params=None) | ||
Comment on lines
-529
to
-533
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't need to be backward compatible now when the feature flag is on, so these lines are removed. |
||
) | ||
return router | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,6 +63,7 @@ | |
from .cluster import normalized_mutual_info_score | ||
from .cluster import fowlkes_mallows_score | ||
|
||
from ..utils import Bunch | ||
from ..utils.multiclass import type_of_target | ||
from ..base import is_regressor | ||
from ..utils._param_validation import validate_params | ||
|
@@ -71,6 +72,7 @@ | |
from ..utils.metadata_routing import MetadataRouter | ||
from ..utils.metadata_routing import process_routing | ||
from ..utils.metadata_routing import get_routing_for_object | ||
from ..utils.metadata_routing import _routing_enabled | ||
|
||
|
||
def _cached_call(cache, estimator, method, *args, **kwargs): | ||
|
@@ -115,16 +117,22 @@ def __call__(self, estimator, *args, **kwargs): | |
cache = {} if self._use_cache(estimator) else None | ||
cached_call = partial(_cached_call, cache) | ||
|
||
params = process_routing(self, "score", kwargs) | ||
if _routing_enabled(): | ||
routed_params = process_routing(self, "score", kwargs) | ||
else: | ||
# they all get the same args, and they all get them all | ||
routed_params = Bunch( | ||
**{name: Bunch(score=kwargs) for name in self._scorers} | ||
) | ||
|
||
for name, scorer in self._scorers.items(): | ||
try: | ||
if isinstance(scorer, _BaseScorer): | ||
score = scorer._score( | ||
cached_call, estimator, *args, **params.get(name).score | ||
cached_call, estimator, *args, **routed_params.get(name).score | ||
) | ||
else: | ||
score = scorer(estimator, *args, **params.get(name).score) | ||
score = scorer(estimator, *args, **routed_params.get(name).score) | ||
scores[name] = score | ||
except Exception as e: | ||
if self._raise_exc: | ||
|
@@ -233,7 +241,7 @@ def __repr__(self): | |
kwargs_string, | ||
) | ||
|
||
def __call__(self, estimator, X, y_true, **kwargs): | ||
def __call__(self, estimator, X, y_true, sample_weight=None, **kwargs): | ||
"""Evaluate predicted target values for X relative to y_true. | ||
|
||
Parameters | ||
|
@@ -248,17 +256,32 @@ def __call__(self, estimator, X, y_true, **kwargs): | |
y_true : array-like | ||
Gold standard target values for X. | ||
|
||
sample_weight : array-like of shape (n_samples,), default=None | ||
Sample weights. | ||
|
||
**kwargs : dict | ||
Other parameters passed to the scorer, e.g. sample_weight. | ||
Refer to :func:`set_score_request` for more details. | ||
|
||
.. versionadded:: 2.0 | ||
Only available if `enable_metadata_routing=True`. See the | ||
:ref:`User Guide <metadata_routing>`. | ||
|
||
.. versionadded:: 1.4 | ||
|
||
Returns | ||
------- | ||
score : float | ||
Score function applied to prediction of estimator on X. | ||
""" | ||
if kwargs and not _routing_enabled(): | ||
raise ValueError( | ||
"kwargs is only supported if enable_metadata_routing=True. See" | ||
" the User Guide for more information." | ||
) | ||
|
||
if sample_weight is not None: | ||
kwargs["sample_weight"] = sample_weight | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was about to say that we modify a mutable here and it should be safer to copy. However, now I see that we pass |
||
|
||
return self._score(partial(_cached_call, None), estimator, X, y_true, **kwargs) | ||
|
||
def _factory_args(self): | ||
|
@@ -284,7 +307,7 @@ def set_score_request(self, **kwargs): | |
Please see :ref:`User Guide <metadata_routing>` on how the routing | ||
mechanism works. | ||
|
||
.. versionadded:: 2.0 | ||
.. versionadded:: 1.4 | ||
|
||
Parameters | ||
---------- | ||
|
@@ -332,7 +355,7 @@ def _score(self, method_caller, estimator, X, y_true, **kwargs): | |
Other parameters passed to the scorer, e.g. sample_weight. | ||
Refer to :func:`set_score_request` for more details. | ||
|
||
.. versionadded:: 2.0 | ||
.. versionadded:: 1.4 | ||
|
||
Returns | ||
------- | ||
|
@@ -377,7 +400,7 @@ def _score(self, method_caller, clf, X, y, **kwargs): | |
Other parameters passed to the scorer, e.g. sample_weight. | ||
Refer to :func:`set_score_request` for more details. | ||
|
||
.. versionadded:: 2.0 | ||
.. versionadded:: 1.4 | ||
|
||
Returns | ||
------- | ||
|
@@ -404,8 +427,11 @@ def _score(self, method_caller, clf, X, y, **kwargs): | |
scoring_kwargs = {**self._kwargs, **kwargs} | ||
# this is for backward compatibility to avoid passing sample_weight | ||
# to the scorer if it's None | ||
# TODO(1.3) Probably remove | ||
if scoring_kwargs.get("sample_weight", -1) is None: | ||
# TODO: Probably remove when deprecating enable_metadata_routing | ||
if ( | ||
"sample_weight" in scoring_kwargs | ||
and scoring_kwargs["sample_weight"] is None | ||
): | ||
del scoring_kwargs["sample_weight"] | ||
|
||
return self._sign * self._score_func(y, y_pred, **scoring_kwargs) | ||
|
@@ -441,7 +467,7 @@ def _score(self, method_caller, clf, X, y, **kwargs): | |
Other parameters passed to the scorer, e.g. sample_weight. | ||
Refer to :func:`set_score_request` for more details. | ||
|
||
.. versionadded:: 2.0 | ||
.. versionadded:: 1.4 | ||
|
||
Returns | ||
------- | ||
|
@@ -489,9 +515,13 @@ def _score(self, method_caller, clf, X, y, **kwargs): | |
scoring_kwargs = {**self._kwargs, **kwargs} | ||
# this is for backward compatibility to avoid passing sample_weight | ||
# to the scorer if it's None | ||
# TODO(1.3) Probably remove | ||
if scoring_kwargs.get("sample_weight", -1) is None: | ||
# TODO: Probably remove when deprecating enable_metadata_routing | ||
adrinjalali marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if ( | ||
"sample_weight" in scoring_kwargs | ||
and scoring_kwargs["sample_weight"] is None | ||
): | ||
del scoring_kwargs["sample_weight"] | ||
|
||
return self._sign * self._score_func(y, y_pred, **scoring_kwargs) | ||
|
||
def _factory_args(self): | ||
|
@@ -552,7 +582,7 @@ def __call__(self, estimator, *args, **kwargs): | |
def get_metadata_routing(self): | ||
"""Get requested data properties. | ||
|
||
.. versionadded:: 2.0 | ||
.. versionadded:: 1.4 | ||
|
||
Returns | ||
------- | ||
|
@@ -924,9 +954,9 @@ class _DeprecatedScorers(dict): | |
|
||
def __getitem__(self, item): | ||
warnings.warn( | ||
"sklearn.metrics.SCORERS is deprecated and will be removed in v1.3. " | ||
"Please use sklearn.metrics.get_scorer_names to get a list of available " | ||
"scorers and sklearn.metrics.get_metric to get scorer.", | ||
"sklearn.metrics.SCORERS is deprecated and will be removed in v1.3." | ||
" Please use sklearn.metrics.get_scorer_names to get a list of" | ||
" available scorers and sklearn.metrics.get_metric to get scorer.", | ||
FutureWarning, | ||
) | ||
return super().__getitem__(item) | ||
|
Uh oh!
There was an error while loading. Please reload this page.