Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 239793a

Browse files
committed
TST check underlying response method for TNR/TPR
1 parent 6985ae9 commit 239793a

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

sklearn/model_selection/_prediction.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,8 @@ def fit(self, X, y, sample_weight=None, **fit_params):
394394
self._response_method == "predict_proba"
395395
or self._response_method[0] == "predict_proba"
396396
):
397+
# `needs_proba=True` will first try to use `predict_proba` and then
398+
# `decision_function`
397399
params_scorer = {"needs_proba": True, "pos_label": self.pos_label}
398400
else:
399401
params_scorer = {"needs_threshold": True, "pos_label": self.pos_label}

sklearn/model_selection/tests/test_prediction.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,3 +475,47 @@ def test_cutoffclassifier_fit_params(objective_metric, fit_params_type):
475475
classifier, objective_metric=objective_metric, objective_value=0.5
476476
)
477477
model.fit(X, y, **fit_params)
478+
479+
480+
@pytest.mark.parametrize(
481+
"objective_metric, objective_value", [("tpr", 0.5), ("tnr", 0.5)]
482+
)
483+
@pytest.mark.parametrize(
484+
"response_method", ["auto", "decision_function", "predict_proba"]
485+
)
486+
def test_cutoffclassifier_response_method_scorer_tnr_tpr(
487+
objective_metric, objective_value, response_method, global_random_seed
488+
):
489+
"""Check that we use the proper scorer and forwarding the requested response method
490+
for `tnr` and `tpr`.
491+
"""
492+
X, y = make_classification(n_samples=100, random_state=global_random_seed)
493+
classifier = LogisticRegression()
494+
495+
model = CutOffClassifier(
496+
classifier,
497+
objective_metric=objective_metric,
498+
objective_value=objective_value,
499+
response_method=response_method,
500+
)
501+
model.fit(X, y)
502+
503+
# Note that optimizing TPR will increase the decision threshold while optimizing
504+
# TNR will decrease it. We therefore use the centered threshold (i.e. 0.5 for
505+
# probabilities and 0.0 for decision function) to check that the decision threshold
506+
# is properly set.
507+
if response_method in ("auto", "predict_proba"):
508+
# "auto" will fall back in priority on `predict_proba` if `estimator`
509+
# supports it.
510+
# we expect the decision threshold to be in [0, 1]
511+
if objective_metric == "tpr":
512+
assert 0.5 < model.decision_threshold_ < 1
513+
else: # "tnr"
514+
assert 0 < model.decision_threshold_ < 0.5
515+
else: # "decision_function"
516+
# we expect the decision function to be centered in 0.0 and to be larger than
517+
# -1 and 1.
518+
if objective_metric == "tpr":
519+
assert 0 < model.decision_threshold_ < 20
520+
else: # "tnr"
521+
assert -20 < model.decision_threshold_ < 0

0 commit comments

Comments
 (0)