Description
Describe the bug
Related to #27977
This also applies to the pr_auc metric.
When defining multi-metric scoring as a dictionary and passing it to cross_validate():
scoring = {
"accuracy": make_scorer(metrics.accuracy_score),
"sensitivity": make_scorer(metrics.recall_score),
"specificity": make_scorer(metrics.recall_score, pos_label=0),
"f1": make_scorer(metrics.f1_score),
"roc_auc": make_scorer(metrics.roc_auc_score),
"pr_auc": make_scorer(metrics.average_precision_score),
"precision": make_scorer(metrics.precision_score),
}
the roc_auc score is computed from class labels (predict()) rather than from scores (decision_function() or predict_proba()).
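To illustrate why this matters, here is a minimal pure-Python sketch (no scikit-learn involved) that computes ROC AUC as the fraction of correctly ranked (positive, negative) pairs. The helper `pairwise_auc`, the toy data, and the 0.5 cut-off are all made up for illustration; they only mimic the difference between decision_function() output and predict() output.

```python
def pairwise_auc(y_true, y_score):
    """ROC AUC as the share of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative sample count as half a win.
    """
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in pos
        for n in neg
    )
    return wins / (len(pos) * len(neg))


y_true = [0, 0, 1, 1, 1]
scores = [0.2, 0.6, 0.4, 0.7, 0.9]        # continuous output, like decision_function()
labels = [int(s >= 0.5) for s in scores]  # hard labels, like predict()

auc_from_scores = pairwise_auc(y_true, scores)
auc_from_labels = pairwise_auc(y_true, labels)
print(auc_from_scores, auc_from_labels)
```

On this toy data the score-based AUC is 5/6 while the label-based AUC is 3.5/6, because thresholding collapses the ranking into ties.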
Trying to set response_method in make_scorer doesn't work:
scoring = {
"accuracy": make_scorer(metrics.accuracy_score),
"sensitivity": make_scorer(metrics.recall_score),
"specificity": make_scorer(metrics.recall_score, pos_label=0),
"f1": make_scorer(metrics.f1_score),
"roc_auc": make_scorer(metrics.roc_auc_score, response_method="decision_function"),
"pr_auc": make_scorer(metrics.average_precision_score, response_method="decision_function"),
"precision": make_scorer(metrics.precision_score),
}
because roc_auc is still a _PredictScorer object.
Passing roc_auc as a string works, though:
scoring = {
"accuracy": make_scorer(metrics.accuracy_score),
"sensitivity": make_scorer(metrics.recall_score),
"specificity": make_scorer(metrics.recall_score, pos_label=0),
"f1": make_scorer(metrics.f1_score),
"roc_auc": 'roc_auc',
"pr_auc": make_scorer(metrics.average_precision_score),
"precision": make_scorer(metrics.precision_score),
}
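A quick way to confirm which model output each scorer consumes, without touching private classes, is to compare the scorer's value against roc_auc_score computed by hand. This is only a sketch: the synthetic dataset and the LogisticRegression model are assumptions chosen for illustration and are not part of the original report.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import get_scorer, make_scorer, roc_auc_score

# Synthetic data and model: illustrative assumptions, not from the report.
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
clf = LogisticRegression().fit(X, y)

label_scorer = make_scorer(roc_auc_score)  # plain make_scorer, no extra options
string_scorer = get_scorer("roc_auc")      # what the string "roc_auc" resolves to

# Reference values computed directly from the two kinds of model output.
auc_labels = roc_auc_score(y, clf.predict(X))
auc_scores = roc_auc_score(y, clf.decision_function(X))

print("make_scorer:", label_scorer(clf, X, y), "predict():", auc_labels)
print("'roc_auc':  ", string_scorer(clf, X, y), "decision_function():", auc_scores)
```

If the first pair of numbers match each other and the second pair match each other, the plain make_scorer version is scoring hard labels while the string version is scoring continuous outputs.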
The following code also works:
roc_auc_score_dec_fnc = _ThresholdScorer(metrics.roc_auc_score, 1, {})
pr_auc_score_dec_fnc = _ThresholdScorer(metrics.average_precision_score, 1, {})
scoring = {
"accuracy": make_scorer(metrics.accuracy_score),
"sensitivity": make_scorer(metrics.recall_score),
"specificity": make_scorer(metrics.recall_score, pos_label=0),
"f1": make_scorer(metrics.f1_score),
"roc_auc": roc_auc_score_dec_fnc,
"pr_auc": pr_auc_score_dec_fnc,
"precision": make_scorer(metrics.precision_score),
}
but _ThresholdScorer is an undocumented class.
I think make_scorer() should return a _ThresholdScorer or _ProbaScorer object instead of a _PredictScorer object when the response_method parameter is set, because this bug is very easy to miss and most users will assume that make_scorer(metrics.roc_auc_score) and passing the string 'roc_auc' behave the same.
Update: setting the needs_threshold parameter in make_scorer() works properly:
roc_auc_score_dec_fnc = make_scorer(metrics.roc_auc_score, needs_threshold=True)
pr_auc_score_dec_fnc = make_scorer(metrics.average_precision_score, needs_threshold=True)
scoring = {
"accuracy": make_scorer(metrics.accuracy_score),
"sensitivity": make_scorer(metrics.recall_score),
"specificity": make_scorer(metrics.recall_score, pos_label=0),
"f1": make_scorer(metrics.f1_score),
"roc_auc": roc_auc_score_dec_fnc,
"pr_auc": pr_auc_score_dec_fnc,
"precision": make_scorer(metrics.precision_score),
}
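For code that has to run on more than one scikit-learn release, one option is to check the signature of make_scorer before choosing the keyword. This is a sketch under an assumption that goes beyond this report: that newer releases expose a real response_method parameter on make_scorer while 1.3.2 only has needs_threshold. The helper name `threshold_scorer`, the synthetic data, and the model are hypothetical.

```python
import inspect

from sklearn import metrics
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_validate


def threshold_scorer(score_func):
    """Hypothetical helper: build a scorer that uses continuous scores."""
    params = inspect.signature(make_scorer).parameters
    if "response_method" in params:
        # Assumed newer API: response_method is an explicit parameter.
        return make_scorer(score_func, response_method="decision_function")
    # Older API, as in the 1.3.2 release from this report.
    return make_scorer(score_func, needs_threshold=True)


scoring = {
    "accuracy": make_scorer(metrics.accuracy_score),
    "roc_auc": threshold_scorer(metrics.roc_auc_score),
    "pr_auc": threshold_scorer(metrics.average_precision_score),
}

# Synthetic data and model: illustrative assumptions, not from the report.
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
results = cross_validate(LogisticRegression(), X, y, scoring=scoring, cv=3)
print(sorted(results))
```

Checking the signature rather than catching an exception matters here because on 1.3.2 an unknown keyword such as response_method is accepted silently, which is what makes the bug easy to miss.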
Steps/Code to Reproduce
from sklearn import metrics
from sklearn.metrics import make_scorer
from sklearn.metrics._scorer import _MultimetricScorer
from sklearn.metrics._scorer import _PredictScorer
from sklearn.metrics._scorer import _ThresholdScorer
# error_score was undefined in the original snippet; "raise" is a placeholder value
error_score = "raise"
scoring = {
"accuracy": make_scorer(metrics.accuracy_score),
"sensitivity": make_scorer(metrics.recall_score),
"specificity": make_scorer(metrics.recall_score, pos_label=0),
"f1": make_scorer(metrics.f1_score),
"roc_auc": make_scorer(metrics.roc_auc_score),
"pr_auc": make_scorer(metrics.average_precision_score),
"precision": make_scorer(metrics.precision_score),
}
scorer = _MultimetricScorer(scorers=scoring, raise_exc=(error_score == "raise"))
# Prints True: the roc_auc scorer is a _PredictScorer, so it scores predict() output
print(isinstance(scorer._scorers['roc_auc'], _PredictScorer))
# Prints False: the roc_auc scorer is not a _ThresholdScorer
print(isinstance(scorer._scorers['roc_auc'], _ThresholdScorer))
Expected Results
n/a
Actual Results
n/a
Versions
System:
python: 3.11.6 | packaged by conda-forge | (main, Oct 3 2023, 10:29:11) [MSC v.1935 64 bit (AMD64)]
executable: C:\Users\mning\AppData\Local\miniforge3\envs\mne\python.exe
machine: Windows-10-10.0.19045-SP0
Python dependencies:
sklearn: 1.3.2
pip: 23.3.1
setuptools: 68.2.2
numpy: 1.26.0
scipy: 1.11.3
Cython: None
pandas: 2.1.3
matplotlib: 3.8.1
joblib: 1.3.2
threadpoolctl: 3.2.0
Built with OpenMP: True
threadpoolctl info:
user_api: openmp
internal_api: openmp
num_threads: 8
prefix: libomp
filepath: C:\Users\mning\AppData\Local\miniforge3\envs\mne\Library\bin\libomp.dll
version: None
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libblas
filepath: C:\Users\mning\AppData\Local\miniforge3\envs\mne\Library\bin\libblas.dll
version: 0.3.24
threading_layer: pthreads
architecture: Haswell
user_api: openmp
internal_api: openmp
num_threads: 8
prefix: vcomp
filepath: C:\Users\mning\AppData\Local\miniforge3\envs\mne\vcomp140.dll
version: None