Default argument pos_label=1 is not ignored in f1_score metric for multiclass classification #29734

Open
@slimebob1975

Description

Describe the bug

I get a ValueError caused by the default value pos_label=1 when using the f1_score metric with average='micro' on the iris flower classification problem:

ValueError: pos_label=1 is not a valid label: It should be one of ['setosa' 'versicolor' 'virginica']

According to the documentation, the pos_label argument should be ignored for multiclass problems:

https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#f1-score

The class to report if average='binary' and the data is binary, otherwise this parameter is ignored.
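To see why pos_label has no role under average='micro', here is a from-scratch sketch in plain Python (not scikit-learn's implementation; the function name micro_f1 is illustrative). It pools true positives, false positives and false negatives over all classes before computing F1, so every class is treated the same way and no "positive" class has to be chosen. For single-label multiclass data such as iris, the result equals plain accuracy.

```python
def micro_f1(y_true, y_pred):
    """Micro-averaged F1 score for single-label multiclass data.

    True positives, false positives and false negatives are pooled over
    all classes before computing F1, so no positive label is needed.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    labels = set(y_true) | set(y_pred)
    tp = fp = fn = 0
    for label in labels:
        for t, p in zip(y_true, y_pred):
            if t == label and p == label:
                tp += 1  # correctly predicted as this class
            elif p == label:
                fp += 1  # predicted as this class, but belongs to another
            elif t == label:
                fn += 1  # belongs to this class, but predicted as another
    if tp == 0:
        return 0.0  # no true positives (all predictions wrong, or empty input)
    return 2 * tp / (2 * tp + fp + fn)


y_true = ["setosa", "versicolor", "virginica", "virginica"]
y_pred = ["setosa", "virginica", "virginica", "virginica"]
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(micro_f1(y_true, y_pred), accuracy)  # prints: 0.75 0.75
```

Each wrong prediction adds exactly one false positive (for the predicted class) and one false negative (for the true class), which is why micro-F1 collapses to accuracy here.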

Explicitly setting pos_label=None solves the problem and produces the output shown under Expected Results.
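A minimal sketch of the workaround, assuming scikit-learn 1.5 as listed under Versions: pass pos_label=None explicitly to make_scorer, or use the built-in "f1_micro" scoring string, which should also avoid the error. Variable names such as f1_micro_scorer are illustrative. Only sanity checks are shown rather than exact fold scores.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import cross_val_score

data = load_iris()
X = data.data
# String class labels ('setosa', 'versicolor', 'virginica'), as in the reproducer
y = data.target_names[data.target]

# Workaround 1: set pos_label=None explicitly so the scorer does not
# use f1_score's default pos_label=1.
f1_micro_scorer = make_scorer(f1_score, average="micro", pos_label=None)
scores_custom = cross_val_score(
    LinearDiscriminantAnalysis(), X, y, cv=5, scoring=f1_micro_scorer
)

# Workaround 2: the built-in "f1_micro" scoring string.
scores_builtin = cross_val_score(
    LinearDiscriminantAnalysis(), X, y, cv=5, scoring="f1_micro"
)

print(scores_custom, scores_custom.mean())
print(scores_builtin, scores_builtin.mean())
```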

Steps/Code to Reproduce

# Import necessary libraries
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import make_scorer, f1_score

# Load the Iris dataset
data = load_iris()
X = data.data  # Features

# Labels as strings ('setosa', 'versicolor', 'virginica')
y = np.array([data.target_names[label] for label in data.target])

# Initialize the Linear Discriminant Analysis classifier
classifier = LinearDiscriminantAnalysis()

# Define a custom scorer using F1 score with average='micro'
# (pos_label=1 is also f1_score's default value)
f1_scorer = make_scorer(f1_score, average='micro', pos_label=1)

# Perform cross-validation with cross_val_score
try:
    scores = cross_val_score(classifier, X, y, cv=5, scoring=f1_scorer)
    print(f"Cross-validated F1 Scores (micro average): {scores}")
    print(f"Mean F1 Score: {np.mean(scores)}")
except ValueError as e:
    print(f"Error: {e}")

Expected Results

Cross-validated F1 Scores (micro average): [1.         1.         0.96666667 0.93333333 1.        ]
Mean F1 Score: 0.9800000000000001

Actual Results

Cross-validated F1 Scores (micro average): [nan nan nan nan nan]
Mean F1 Score: nan
C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_validation.py:1000: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: 
Traceback (most recent call last):
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\metrics\_scorer.py", line 139, in __call__
    score = scorer._score(
            ^^^^^^^^^^^^^^
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\metrics\_scorer.py", line 371, in _score
    y_pred = method_caller(
             ^^^^^^^^^^^^^^
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\metrics\_scorer.py", line 89, in _cached_call
    result, _ = _get_response_values(
                ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\utils\_response.py", line 204, in _get_response_values
    raise ValueError(
ValueError: pos_label=1 is not a valid label: It should be one of ['setosa' 'versicolor' 'virginica']
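The traceback shows the error is raised in _get_response_values (sklearn/utils/_response.py) while the scorer computes predictions, before f1_score itself is called. One possible direction for a fix, sketched below in plain Python under the assumption that the check should only apply to binary targets, is to skip the pos_label validation for multiclass targets. validate_pos_label is a hypothetical helper, not scikit-learn's code; the "binary" and "multiclass" strings are assumed to follow the values returned by sklearn.utils.multiclass.type_of_target.

```python
def validate_pos_label(pos_label, classes, target_type):
    """Hypothetical guard for pos_label, applied only to binary targets.

    Returns the pos_label to use, or None when it is irrelevant
    (multiclass targets under an averaged metric such as average='micro').
    """
    if target_type != "binary":
        # pos_label is ignored for multiclass targets, so a default of 1
        # should not be checked against string class labels.
        return None
    if pos_label is not None and pos_label not in classes:
        raise ValueError(
            f"pos_label={pos_label!r} is not a valid label: "
            f"It should be one of {list(classes)}"
        )
    return pos_label


iris_classes = ["setosa", "versicolor", "virginica"]
print(validate_pos_label(1, iris_classes, "multiclass"))  # prints: None
```

With this kind of guard, binary problems still get the informative error for a wrong pos_label, while the iris case above would no longer fail.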

Versions

System:
    python: 3.11.5 | packaged by Anaconda, Inc. | (main, Sep 11 2023, 13:26:23) [MSC v.1916 64 bit (AMD64)]
executable: C:\Users\rgt0227\AppData\Local\anaconda3\python.exe
   machine: Windows-10-10.0.19045-SP0

Python dependencies:
      sklearn: 1.5.0
          pip: 23.2.1
   setuptools: 68.0.0
        numpy: 1.26.2
        scipy: 1.11.4
       Cython: None
       pandas: 2.1.1
   matplotlib: 3.8.0
       joblib: 1.2.0
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: mkl
    num_threads: 8
         prefix: mkl_rt
       filepath: C:\Users\rgt0227\AppData\Local\anaconda3\Library\bin\mkl_rt.2.dll
        version: 2023.1-Product
threading_layer: intel

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: vcomp
       filepath: C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\.libs\vcomp140.dll
        version: None

Labels: Bug, good first issue (Easy with clear instructions to resolve)
