FEA add TunedThresholdClassifier meta-estimator to post-tune the cut-off threshold #26120
@@ -0,0 +1,156 @@

.. currentmodule:: sklearn.model_selection

.. _TunedThresholdClassifierCV:

==================================================
Tuning the decision threshold for class prediction
==================================================

Classification is best divided into two parts:

* the statistical problem of learning a model to predict, ideally, class probabilities;
* the decision problem of taking concrete action based on those probability predictions.

Let's take a straightforward example related to weather forecasting: the first point is
related to answering "what is the chance that it will rain tomorrow?" while the second
point is related to answering "should I take an umbrella tomorrow?".

When it comes to the scikit-learn API, the first point is addressed by providing scores
using :term:`predict_proba` or :term:`decision_function`. The former returns conditional
probability estimates :math:`P(y|X)` for each class, while the latter returns a decision
score for each class.

The decisions corresponding to the labels are obtained with :term:`predict`. In binary
classification, a decision rule or action is then defined by thresholding the scores,
leading to the prediction of a single class label for each sample. In scikit-learn,
binary class label predictions are obtained with a hard-coded cut-off rule: the positive
class is predicted when the conditional probability :math:`P(y|X)` is greater than 0.5
(obtained with :term:`predict_proba`) or when the decision score is greater than 0
(obtained with :term:`decision_function`).

Here, we show an example that illustrates the relation between conditional
probability estimates :math:`P(y|X)` and class labels::

  >>> from sklearn.datasets import make_classification
  >>> from sklearn.tree import DecisionTreeClassifier
  >>> X, y = make_classification(random_state=0)
  >>> classifier = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
  >>> classifier.predict_proba(X[:4])
  array([[0.94 , 0.06 ],
         [0.94 , 0.06 ],
         [0.0416..., 0.9583...],
         [0.0416..., 0.9583...]])
  >>> classifier.predict(X[:4])
  array([0, 0, 1, 1])

While these hard-coded rules might at first seem reasonable as default behavior, they
are certainly not ideal for most use cases. Let's illustrate with an example.

Consider a scenario where a predictive model is being deployed to assist physicians in
detecting tumors. In this setting, physicians will most likely be interested in
identifying all patients with cancer so that they can provide them with the right
treatment. In other words, physicians prioritize achieving a high recall rate. This
emphasis on recall comes, of course, with the trade-off of potentially more
false-positive predictions, reducing the precision of the model. That is a risk
physicians are willing to take because the cost of a missed cancer is much higher than
the cost of further diagnostic tests. Consequently, when it comes to deciding whether to
classify a patient as having cancer or not, it may be more beneficial to classify them
as positive for cancer when the conditional probability estimate is much lower than 0.5.

Post-tuning the decision threshold
==================================

One solution to address the problem stated in the introduction is to tune the decision
threshold of the classifier once the model has been trained. The
:class:`~sklearn.model_selection.TunedThresholdClassifierCV` tunes this threshold using
an internal cross-validation. The optimum threshold is chosen to maximize a given
metric.
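
Below is a minimal sketch of this workflow; the logistic regression base estimator, the
synthetic imbalanced dataset, and the choice of balanced accuracy as `scoring` are
illustrative assumptions rather than recommendations::

  >>> from sklearn.datasets import make_classification
  >>> from sklearn.linear_model import LogisticRegression
  >>> from sklearn.model_selection import TunedThresholdClassifierCV
  >>> X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
  >>> model = TunedThresholdClassifierCV(
  ...     LogisticRegression(), scoring="balanced_accuracy"
  ... ).fit(X, y)
  >>> # the tuned cut-off found by cross-validation, used by `predict` instead of 0.5
  >>> tuned_cut_off = model.best_threshold_
  >>> y_pred = model.predict(X[:4])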

The following image illustrates the tuning of the decision threshold for a gradient
boosting classifier. While the vanilla and tuned classifiers provide the same
:term:`predict_proba` outputs and thus the same Receiver Operating Characteristic (ROC)
and Precision-Recall curves, the class label predictions differ because of the tuned
decision threshold. The vanilla classifier predicts the class of interest for a
conditional probability greater than 0.5 while the tuned classifier predicts the class
of interest for a very low probability (around 0.02). This decision threshold optimizes
a utility metric defined by the business (in this case an insurance company).

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cost_sensitive_learning_002.png
   :target: ../auto_examples/model_selection/plot_cost_sensitive_learning.html
   :align: center

Options to tune the decision threshold
--------------------------------------

The decision threshold can be tuned through different strategies controlled by the
parameter `scoring`.

One way to tune the threshold is by maximizing a pre-defined scikit-learn metric. These
metrics can be found by calling the function :func:`~sklearn.metrics.get_scorer_names`.
By default, balanced accuracy is the metric used, but be aware that one should choose a
metric that is meaningful for their use case.

.. note::

  It is important to notice that these metrics come with default parameters, notably
  the label of the class of interest (i.e. `pos_label`). Thus, if this label is not
  the right one for your application, you need to define a scorer and pass the right
  `pos_label` (and additional parameters) using
  :func:`~sklearn.metrics.make_scorer`. Refer to :ref:`scoring` to get information on
  how to define your own scoring function. For instance, we show how to pass the
  information to the scorer that the label of interest is `0` when maximizing the
  :func:`~sklearn.metrics.f1_score`::

    >>> from sklearn.linear_model import LogisticRegression
    >>> from sklearn.model_selection import TunedThresholdClassifierCV
    >>> from sklearn.metrics import make_scorer, f1_score
    >>> X, y = make_classification(
    ...     n_samples=1_000, weights=[0.1, 0.9], random_state=0)
    >>> pos_label = 0
    >>> scorer = make_scorer(f1_score, pos_label=pos_label)
    >>> base_model = LogisticRegression()
    >>> model = TunedThresholdClassifierCV(base_model, scoring=scorer)
    >>> scorer(model.fit(X, y), X, y)
    0.88...
    >>> # compare it with the internal score found by cross-validation
    >>> model.best_score_
    0.86...

Important notes regarding the internal cross-validation
--------------------------------------------------------

By default, :class:`~sklearn.model_selection.TunedThresholdClassifierCV` uses a 5-fold
stratified cross-validation to tune the decision threshold. The parameter `cv` allows
controlling the cross-validation strategy. It is possible to bypass cross-validation by
setting `cv="prefit"` and providing a fitted classifier. In this case, the decision
threshold is tuned on the data provided to the `fit` method.

However, you should be extremely careful when using this option. You should never use
the same data for training the classifier and for tuning the decision threshold, due to
the risk of overfitting. Refer to the following example section for more details (cf.
:ref:`TunedThresholdClassifierCV_no_cv`). If you have limited resources, consider
setting `cv` to a float to limit the tuning to a single internal train-test split.

The option `cv="prefit"` should only be used when the provided classifier was already
trained, and you just want to find the best decision threshold using a new validation
set.
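
As a rough sketch of that workflow (the logistic regression model, the synthetic data,
and the hold-out split below are illustrative assumptions), the threshold is tuned on a
validation set that the prefitted classifier has never seen; `refit=False` keeps the
prefitted classifier untouched::

  >>> from sklearn.datasets import make_classification
  >>> from sklearn.linear_model import LogisticRegression
  >>> from sklearn.model_selection import TunedThresholdClassifierCV, train_test_split
  >>> X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
  >>> X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)
  >>> classifier = LogisticRegression().fit(X_train, y_train)  # already trained
  >>> model = TunedThresholdClassifierCV(
  ...     classifier, cv="prefit", refit=False
  ... ).fit(X_val, y_val)
  >>> # the threshold was tuned on (X_val, y_val) only; `classifier` was not refitted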

.. _FixedThresholdClassifier:

Manually setting the decision threshold
---------------------------------------

The previous sections discussed strategies to find an optimal decision threshold. It is
also possible to manually set the decision threshold using the class
:class:`~sklearn.model_selection.FixedThresholdClassifier`.
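
As a minimal sketch (the 0.9 cut-off, the logistic regression model, and the synthetic
data are arbitrary illustrative choices), a fixed threshold simply replaces the default
0.5 cut-off applied to :term:`predict_proba`::

  >>> from sklearn.datasets import make_classification
  >>> from sklearn.linear_model import LogisticRegression
  >>> from sklearn.model_selection import FixedThresholdClassifier
  >>> X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
  >>> model = FixedThresholdClassifier(LogisticRegression(), threshold=0.9).fit(X, y)
  >>> # `predict` now flags the positive class only for probability estimates above 0.9
  >>> y_pred = model.predict(X[:4])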

Examples
--------

- See the example entitled
  :ref:`sphx_glr_auto_examples_model_selection_plot_tuned_decision_threshold.py`,
  to get insights on the post-tuning of the decision threshold.
- See the example entitled
  :ref:`sphx_glr_auto_examples_model_selection_plot_cost_sensitive_learning.py`,
  to learn about cost-sensitive learning and decision threshold tuning.