-
-
Notifications
You must be signed in to change notification settings - Fork 26.3k
ENH Add routing to LogisticRegressionCV #26525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7952cce
66ad513
7e8b824
d7e50a6
3844706
0866c42
43f971b
db63769
a9b984f
637c18e
314bc83
9a8ef4e
5b723a0
c07a980
cc5ba48
915624a
6772e5b
49f7955
7efe941
a111bdd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,17 +27,28 @@ | |
from ..preprocessing import LabelBinarizer, LabelEncoder | ||
from ..svm._base import _fit_liblinear | ||
from ..utils import ( | ||
Bunch, | ||
check_array, | ||
check_consistent_length, | ||
check_random_state, | ||
compute_class_weight, | ||
) | ||
from ..utils._param_validation import Interval, StrOptions | ||
from ..utils.extmath import row_norms, softmax | ||
from ..utils.metadata_routing import ( | ||
MetadataRouter, | ||
MethodMapping, | ||
_routing_enabled, | ||
process_routing, | ||
) | ||
from ..utils.multiclass import check_classification_targets | ||
from ..utils.optimize import _check_optimize_result, _newton_cg | ||
from ..utils.parallel import Parallel, delayed | ||
from ..utils.validation import _check_sample_weight, check_is_fitted | ||
from ..utils.validation import ( | ||
_check_method_params, | ||
_check_sample_weight, | ||
check_is_fitted, | ||
) | ||
from ._base import BaseEstimator, LinearClassifierMixin, SparseCoefMixin | ||
from ._glm.glm import NewtonCholeskySolver | ||
from ._linear_loss import LinearModelLoss | ||
|
@@ -576,23 +587,25 @@ def _log_reg_scoring_path( | |
y, | ||
train, | ||
test, | ||
pos_class=None, | ||
Cs=10, | ||
scoring=None, | ||
fit_intercept=False, | ||
max_iter=100, | ||
tol=1e-4, | ||
class_weight=None, | ||
verbose=0, | ||
solver="lbfgs", | ||
penalty="l2", | ||
dual=False, | ||
intercept_scaling=1.0, | ||
multi_class="auto", | ||
random_state=None, | ||
max_squared_sum=None, | ||
sample_weight=None, | ||
l1_ratio=None, | ||
*, | ||
pos_class, | ||
Cs, | ||
scoring, | ||
fit_intercept, | ||
max_iter, | ||
tol, | ||
class_weight, | ||
verbose, | ||
solver, | ||
penalty, | ||
dual, | ||
intercept_scaling, | ||
multi_class, | ||
random_state, | ||
max_squared_sum, | ||
sample_weight, | ||
l1_ratio, | ||
score_params, | ||
): | ||
"""Computes scores across logistic_regression_path | ||
|
||
|
@@ -704,6 +717,9 @@ def _log_reg_scoring_path( | |
to using ``penalty='l1'``. For ``0 < l1_ratio <1``, the penalty is a | ||
combination of L1 and L2. | ||
|
||
score_params : dict | ||
Parameters to pass to the `score` method of the underlying scorer. | ||
|
||
Returns | ||
------- | ||
coefs : ndarray of shape (n_cs, n_features) or (n_cs, n_features + 1) | ||
|
@@ -784,7 +800,9 @@ def _log_reg_scoring_path( | |
if scoring is None: | ||
scores.append(log_reg.score(X_test, y_test)) | ||
else: | ||
scores.append(scoring(log_reg, X_test, y_test)) | ||
score_params = score_params or {} | ||
score_params = _check_method_params(X=X, params=score_params, indices=test) | ||
scores.append(scoring(log_reg, X_test, y_test, **score_params)) | ||
|
||
return coefs, Cs, np.array(scores), n_iter | ||
|
||
|
@@ -1747,7 +1765,7 @@ def __init__( | |
self.l1_ratios = l1_ratios | ||
|
||
@_fit_context(prefer_skip_nested_validation=True) | ||
def fit(self, X, y, sample_weight=None): | ||
def fit(self, X, y, sample_weight=None, **params): | ||
"""Fit the model according to the given training data. | ||
|
||
Parameters | ||
|
@@ -1763,11 +1781,22 @@ def fit(self, X, y, sample_weight=None): | |
Array of weights that are assigned to individual samples. | ||
If not provided, then each sample is given unit weight. | ||
|
||
**params : dict | ||
Parameters to pass to the underlying splitter and scorer. | ||
|
||
.. versionadded:: 1.4 | ||
|
||
Returns | ||
------- | ||
self : object | ||
Fitted LogisticRegressionCV estimator. | ||
""" | ||
if params and not _routing_enabled(): | ||
raise ValueError( | ||
"params is only supported if enable_metadata_routing=True." | ||
" See the User Guide for more information." | ||
Review comment: It would be useful to include the link to the user guide.
Review comment: I mean this page: https://scikit-learn.org/stable/metadata_routing.html
Review comment: I think it will definitely be useful. But if we add it here in one place while a number of other cases like this (this particular ValueError) don't contain the link, won't that be a bit inconsistent?
Review comment: We can open a subsequent PR to improve the consistency.
||
) | ||
|
||
solver = _check_solver(self.solver, self.penalty, self.dual) | ||
|
||
if self.penalty == "elasticnet": | ||
|
@@ -1829,9 +1858,23 @@ def fit(self, X, y, sample_weight=None): | |
else: | ||
max_squared_sum = None | ||
|
||
if _routing_enabled(): | ||
routed_params = process_routing( | ||
obj=self, | ||
method="fit", | ||
sample_weight=sample_weight, | ||
other_params=params, | ||
) | ||
else: | ||
routed_params = Bunch() | ||
routed_params.splitter = Bunch(split={}) | ||
routed_params.scorer = Bunch(score=params) | ||
if sample_weight is not None: | ||
routed_params.scorer.score["sample_weight"] = sample_weight | ||
|
||
# init cross-validation generator | ||
cv = check_cv(self.cv, y, classifier=True) | ||
folds = list(cv.split(X, y)) | ||
folds = list(cv.split(X, y, **routed_params.splitter.split)) | ||
|
||
# Use the label encoded classes | ||
n_classes = len(encoded_labels) | ||
|
@@ -1898,6 +1941,7 @@ def fit(self, X, y, sample_weight=None): | |
max_squared_sum=max_squared_sum, | ||
sample_weight=sample_weight, | ||
l1_ratio=l1_ratio, | ||
score_params=routed_params.scorer.score, | ||
) | ||
for label in iter_encoded_labels | ||
for train, test in folds | ||
|
@@ -2078,7 +2122,7 @@ def fit(self, X, y, sample_weight=None): | |
|
||
return self | ||
|
||
def score(self, X, y, sample_weight=None): | ||
def score(self, X, y, sample_weight=None, **score_params): | ||
"""Score using the `scoring` option on the given test data and labels. | ||
|
||
Parameters | ||
|
@@ -2092,15 +2136,74 @@ def score(self, X, y, sample_weight=None): | |
sample_weight : array-like of shape (n_samples,), default=None | ||
Sample weights. | ||
|
||
**score_params : dict | ||
Parameters to pass to the `score` method of the underlying scorer. | ||
|
||
.. versionadded:: 1.4 | ||
|
||
Returns | ||
------- | ||
score : float | ||
Score of self.predict(X) w.r.t. y. | ||
""" | ||
scoring = self.scoring or "accuracy" | ||
scoring = get_scorer(scoring) | ||
if score_params and not _routing_enabled(): | ||
raise ValueError( | ||
"score_params is only supported if enable_metadata_routing=True." | ||
" See the User Guide for more information." | ||
" https://scikit-learn.org/stable/metadata_routing.html" | ||
) | ||
|
||
scoring = self._get_scorer() | ||
adrinjalali marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if _routing_enabled(): | ||
routed_params = process_routing( | ||
obj=self, | ||
method="score", | ||
sample_weight=sample_weight, | ||
other_params=score_params, | ||
) | ||
else: | ||
routed_params = Bunch() | ||
routed_params.scorer = Bunch(score={}) | ||
if sample_weight is not None: | ||
routed_params.scorer.score["sample_weight"] = sample_weight | ||
|
||
return scoring( | ||
adrinjalali marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self, | ||
X, | ||
y, | ||
**routed_params.scorer.score, | ||
) | ||
|
||
return scoring(self, X, y, sample_weight=sample_weight) | ||
def get_metadata_routing(self):
    """Build the router describing how this estimator forwards metadata.

    See the :ref:`User Guide <metadata_routing>` for an explanation of the
    routing mechanism.

    .. versionadded:: 1.4

    Returns
    -------
    routing : MetadataRouter
        A :class:`~utils.metadata_routing.MetadataRouter` holding this
        estimator's own request, a ``fit`` -> ``split`` mapping for the
        splitter, and ``score`` -> ``score`` / ``fit`` -> ``score`` mappings
        for the scorer.
    """
    router = MetadataRouter(owner=self.__class__.__name__)
    router = router.add_self_request(self)

    # Metadata given to `fit` is routed to the `split` method of `self.cv`.
    router = router.add(
        splitter=self.cv,
        method_mapping=MethodMapping().add(callee="split", caller="fit"),
    )

    # Metadata given to either `score` or `fit` is routed to the scorer's
    # `score` method.
    router = router.add(
        scorer=self._get_scorer(),
        method_mapping=MethodMapping()
        .add(callee="score", caller="score")
        .add(callee="score", caller="fit"),
    )
    return router
|
||
def _more_tags(self): | ||
return { | ||
|
@@ -2110,3 +2213,10 @@ def _more_tags(self): | |
), | ||
} | ||
} | ||
|
||
def _get_scorer(self):
    """Return the scorer selected by ``self.scoring``.

    Falls back to ``"accuracy"`` when ``self.scoring`` is falsy
    (for example ``None``).
    """
    return get_scorer(self.scoring or "accuracy")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that the default parameter values have been removed from the signature, the corresponding defaults should be removed from the docstring too.