Describe the bug
I get a ValueError for the default pos_label=1 argument of the f1_score metric when using average='micro' on the iris flower classification problem:
ValueError: pos_label=1 is not a valid label: It should be one of ['setosa' 'versicolor' 'virginica']
According to the documentation, the pos_label argument should be ignored for multiclass problems:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#f1-score
The class to report if average='binary' and the data is binary, otherwise this parameter is ignored.
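To illustrate why a positive label has no role under micro averaging, here is a minimal pure-Python sketch. The helper name micro_f1 is hypothetical and this is not scikit-learn's implementation; it only assumes single-label multiclass targets and pools true positives, false positives, and false negatives over every class, so no single class is treated as "positive".

```python
from collections import Counter


def micro_f1(y_true, y_pred):
    """Hypothetical helper: micro-averaged F1 for single-label targets."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for true_label, pred_label in zip(y_true, y_pred):
        if true_label == pred_label:
            tp[true_label] += 1
        else:
            fp[pred_label] += 1  # the predicted class gains a false positive
            fn[true_label] += 1  # the true class gains a false negative
    # Pool the counts over all classes; no class is singled out.
    tp_sum = sum(tp.values())
    fp_sum = sum(fp.values())
    fn_sum = sum(fn.values())
    denom = 2 * tp_sum + fp_sum + fn_sum
    return 2 * tp_sum / denom if denom else 0.0


y_true = ["setosa", "versicolor", "virginica", "virginica"]
y_pred = ["setosa", "versicolor", "versicolor", "virginica"]
score = micro_f1(y_true, y_pred)
print(score)
```

With three of four predictions correct, the pooled counts are TP=3, FP=1, FN=1, which gives the same value as plain accuracy for this single-label case.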
Setting pos_label explicitly to None solves the problem and produces the expected output, see below.
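A minimal sketch of that workaround, assuming the same iris setup as the reproduction in this report, with only pos_label changed to None in the make_scorer call:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import cross_val_score

data = load_iris()
X = data.data
# String labels, as in the reproduction
y = np.array([data.target_names[label] for label in data.target])

# Workaround: pass pos_label=None explicitly so the scorer does not
# check the label 1 against the string classes
f1_scorer = make_scorer(f1_score, average="micro", pos_label=None)

scores = cross_val_score(
    LinearDiscriminantAnalysis(), X, y, cv=5, scoring=f1_scorer
)
print(f"Cross-validated F1 Scores (micro average): {scores}")
print(f"Mean F1 Score: {np.mean(scores)}")
```

This only sidesteps the check for the multiclass case; it does not change how f1_score itself is computed.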
Steps/Code to Reproduce
# Import necessary libraries
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import make_scorer, f1_score
# Load the Iris dataset
data = load_iris()
X = data.data # Features
y = data.target # Labels
# Convert labels to string type
y = np.array([data.target_names[label] for label in data.target])
# Initialize the Linear Discriminant Analysis classifier
classifier = LinearDiscriminantAnalysis()
# Define a custom scorer using F1 score with average='micro'
f1_scorer = make_scorer(f1_score, average='micro', pos_label=1)
# Perform cross-validation with cross_val_score
try:
    scores = cross_val_score(classifier, X, y, cv=5, scoring=f1_scorer)
    print(f"Cross-validated F1 Scores (micro average): {scores}")
    print(f"Mean F1 Score: {np.mean(scores)}")
except ValueError as e:
    print(f"Error: {e}")
Expected Results
Cross-validated F1 Scores (micro average): [1. 1. 0.96666667 0.93333333 1. ]
Mean F1 Score: 0.9800000000000001
Actual Results
Cross-validated F1 Scores (micro average): [nan nan nan nan nan]
Mean F1 Score: nan
C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_validation.py:1000: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details:
Traceback (most recent call last):
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\metrics\_scorer.py", line 139, in __call__
    score = scorer._score(
            ^^^^^^^^^^^^^^
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\metrics\_scorer.py", line 371, in _score
    y_pred = method_caller(
             ^^^^^^^^^^^^^^
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\metrics\_scorer.py", line 89, in _cached_call
    result, _ = _get_response_values(
                ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\utils\_response.py", line 204, in _get_response_values
    raise ValueError(
ValueError: pos_label=1 is not a valid label: It should be one of ['setosa' 'versicolor' 'virginica']
Versions
System:
python: 3.11.5 | packaged by Anaconda, Inc. | (main, Sep 11 2023, 13:26:23) [MSC v.1916 64 bit (AMD64)]
executable: C:\Users\rgt0227\AppData\Local\anaconda3\python.exe
machine: Windows-10-10.0.19045-SP0
Python dependencies:
sklearn: 1.5.0
pip: 23.2.1
setuptools: 68.0.0
numpy: 1.26.2
scipy: 1.11.4
Cython: None
pandas: 2.1.1
matplotlib: 3.8.0
joblib: 1.2.0
threadpoolctl: 3.5.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: mkl
num_threads: 8
prefix: mkl_rt
filepath: C:\Users\rgt0227\AppData\Local\anaconda3\Library\bin\mkl_rt.2.dll
version: 2023.1-Product
threading_layer: intel
user_api: openmp
internal_api: openmp
num_threads: 8
prefix: vcomp
filepath: C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\.libs\vcomp140.dll
version: None