Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit b149b0e

Browse files
ShehanAT authored and glemaitre committed
FIX Enables label_ranking_average_precision_score to support sparse y_true (scikit-learn#23442)
Allow y_true to be in CSR format.
1 parent 50ee3fa commit b149b0e

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

doc/whats_new/v1.2.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,10 @@ Changelog
265265
of a binary classification problem. :pr:`22518` by
266266
:user:`Arturo Amor <ArturoAmorQ>`.
267267

268+
- |Fix| Allows `csr_matrix` as input for parameter: `y_true` of
269+
the :func:`metrics.label_ranking_average_precision_score` metric.
270+
:pr:`23442` by :user:`Sean Atukorala <ShehanAT>`
271+
268272
- |Fix| :func:`metrics.ndcg_score` will now trigger a warning when the `y_true`
269273
value contains a negative value. Users may still use negative values, but the
270274
result may not be between 0 and 1. Starting in v1.4, passing in negative

sklearn/metrics/_ranking.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from functools import partial
2424

2525
import numpy as np
26-
from scipy.sparse import csr_matrix
26+
from scipy.sparse import csr_matrix, issparse
2727
from scipy.stats import rankdata
2828

2929
from ..utils import assert_all_finite
@@ -1070,7 +1070,7 @@ def label_ranking_average_precision_score(y_true, y_score, *, sample_weight=None
10701070
0.416...
10711071
"""
10721072
check_consistent_length(y_true, y_score, sample_weight)
1073-
y_true = check_array(y_true, ensure_2d=False)
1073+
y_true = check_array(y_true, ensure_2d=False, accept_sparse="csr")
10741074
y_score = check_array(y_score, ensure_2d=False)
10751075

10761076
if y_true.shape != y_score.shape:
@@ -1083,7 +1083,9 @@ def label_ranking_average_precision_score(y_true, y_score, *, sample_weight=None
10831083
):
10841084
raise ValueError("{0} format is not supported".format(y_type))
10851085

1086-
y_true = csr_matrix(y_true)
1086+
if not issparse(y_true):
1087+
y_true = csr_matrix(y_true)
1088+
10871089
y_score = -y_score
10881090

10891091
n_samples, n_labels = y_true.shape

sklearn/metrics/tests/test_ranking.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,3 +2042,12 @@ def test_top_k_accuracy_score_warning(y_true, k):
20422042
def test_top_k_accuracy_score_error(y_true, y_score, labels, msg):
20432043
with pytest.raises(ValueError, match=msg):
20442044
top_k_accuracy_score(y_true, y_score, k=2, labels=labels)
2045+
2046+
2047+
def test_label_ranking_avg_precision_score_should_allow_csr_matrix_for_y_true_input():
2048+
# Test that label_ranking_avg_precision_score accept sparse y_true.
2049+
# Non-regression test for #22575
2050+
y_true = csr_matrix([[1, 0, 0], [0, 0, 1]])
2051+
y_score = np.array([[0.5, 0.9, 0.6], [0, 0, 1]])
2052+
result = label_ranking_average_precision_score(y_true, y_score)
2053+
assert result == pytest.approx(2 / 3)

0 commit comments

Comments (0)