Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit ae943bd

Browse files
leonardbinet, alk-lbinet, glemaitre, btel, Ilia Ivanov
authored
FIX properly support sparse matrices in type_of_target (#14862)
Co-authored-by: Leonard Binet <[email protected]> Co-authored-by: Guillaume Lemaitre <[email protected]> Co-authored-by: Bartosz Telenczuk <[email protected]> Co-authored-by: Ilia Ivanov <[email protected]> Co-authored-by: jeremie du boisberranger <[email protected]>
1 parent 2710a9e commit ae943bd

File tree

3 files changed

+80
-34
lines changed

3 files changed

+80
-34
lines changed

doc/whats_new/v1.2.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,9 @@ Changelog
458458
deterministic SVD used by the randomized SVD algorithm.
459459
:pr:`20617` by :user:`Srinath Kailasa <skailasa>`
460460

461+
- |FIX| :func:`utils.multiclass.type_of_target` now properly handles sparse matrices.
462+
:pr:`14862` by :user:`Léonard Binet <leonardbinet>`.
463+
461464
Code and Documentation Contributors
462465
-----------------------------------
463466

sklearn/utils/multiclass.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,11 @@ def is_multilabel(y):
170170
if issparse(y):
171171
if isinstance(y, (dok_matrix, lil_matrix)):
172172
y = y.tocsr()
173+
labels = xp.unique_values(y.data)
173174
return (
174175
len(y.data) == 0
175-
or xp.unique_values(y.data).size == 1
176-
and (
177-
y.dtype.kind in "biu"
178-
or _is_integral_float(xp.unique_values(y.data)) # bool, int, uint
179-
)
176+
or (labels.size == 1 or (labels.size == 2) and (0 in labels))
177+
and (y.dtype.kind in "biu" or _is_integral_float(labels)) # bool, int, uint
180178
)
181179
else:
182180
labels = xp.unique_values(y)
@@ -223,8 +221,9 @@ def type_of_target(y, input_name=""):
223221
224222
Parameters
225223
----------
226-
y : array-like
227-
Target values.
224+
y : {array-like, sparse matrix}
225+
Target values. If a sparse matrix, `y` is expected to be a
226+
CSR/CSC matrix.
228227
229228
input_name : str, default=""
230229
The data name used to construct the error message.
@@ -303,12 +302,13 @@ def type_of_target(y, input_name=""):
303302
# https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
304303
with warnings.catch_warnings():
305304
warnings.simplefilter("error", np.VisibleDeprecationWarning)
306-
try:
307-
y = xp.asarray(y)
308-
except (np.VisibleDeprecationWarning, ValueError):
309-
# dtype=object should be provided explicitly for ragged arrays,
310-
# see NEP 34
311-
y = xp.asarray(y, dtype=object)
305+
if not issparse(y):
306+
try:
307+
y = xp.asarray(y)
308+
except np.VisibleDeprecationWarning:
309+
# dtype=object should be provided explicitly for ragged arrays,
310+
# see NEP 34
311+
y = xp.asarray(y, dtype=object)
312312

313313
# The old sequence of sequences format
314314
try:
@@ -328,25 +328,39 @@ def type_of_target(y, input_name=""):
328328
pass
329329

330330
# Invalid inputs
331-
if y.ndim > 2 or (y.dtype == object and len(y) and not isinstance(y.flat[0], str)):
332-
return "unknown" # [[[1, 2]]] or [obj_1] and not ["label_1"]
333-
334-
if y.ndim == 2 and y.shape[1] == 0:
335-
return "unknown" # [[]]
336-
331+
if y.ndim not in (1, 2):
332+
# Number of dimensions is 0 or greater than 2: 0 or [[[1, 2]]]
333+
return "unknown"
334+
if not min(y.shape):
335+
# Empty ndarray: []/[[]]
336+
if y.ndim == 1:
337+
# 1-D empty array: []
338+
return "binary" # []
339+
# 2-D empty array: [[]]
340+
return "unknown"
341+
if not issparse(y) and y.dtype == object and not isinstance(y.flat[0], str):
342+
# [obj_1] and not ["label_1"]
343+
return "unknown"
344+
345+
# Check if multioutput
337346
if y.ndim == 2 and y.shape[1] > 1:
338347
suffix = "-multioutput" # [[1, 2], [1, 2]]
339348
else:
340349
suffix = "" # [1, 2, 3] or [[1], [2], [3]]
341350

342-
# check float and contains non-integer float values
343-
if y.dtype.kind == "f" and xp.any(y != y.astype(int)):
351+
# Check float and contains non-integer float values
352+
if y.dtype.kind == "f":
344353
# [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
345-
_assert_all_finite(y, input_name=input_name)
346-
return "continuous" + suffix
347-
348-
if (xp.unique_values(y).shape[0] > 2) or (y.ndim >= 2 and len(y[0]) > 1):
349-
return "multiclass" + suffix # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
354+
data = y.data if issparse(y) else y
355+
if xp.any(data != data.astype(int)):
356+
_assert_all_finite(data, input_name=input_name)
357+
return "continuous" + suffix
358+
359+
# Check multiclass
360+
first_row = y[0] if not issparse(y) else y.getrow(0).data
361+
if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
362+
# [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
363+
return "multiclass" + suffix
350364
else:
351365
return "binary" # [1, 2] or [["a"], ["b"]]
352366

sklearn/utils/tests/test_multiclass.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@
2727
from sklearn.svm import SVC
2828
from sklearn import datasets
2929

30+
sparse_multilable_explicit_zero = csc_matrix(np.array([[0, 1], [1, 0]]))
31+
sparse_multilable_explicit_zero[:, 0] = 0
32+
33+
34+
def _generate_sparse(
35+
matrix,
36+
matrix_types=(csr_matrix, csc_matrix, coo_matrix, dok_matrix, lil_matrix),
37+
dtypes=(bool, int, np.int8, np.uint8, float, np.float32),
38+
):
39+
return [
40+
matrix_type(matrix, dtype=dtype)
41+
for matrix_type in matrix_types
42+
for dtype in dtypes
43+
]
44+
3045

3146
EXAMPLES = {
3247
"multilabel-indicator": [
@@ -35,14 +50,10 @@
3550
csr_matrix(np.random.RandomState(42).randint(2, size=(10, 10))),
3651
[[0, 1], [1, 0]],
3752
[[0, 1]],
38-
csr_matrix(np.array([[0, 1], [1, 0]])),
39-
csr_matrix(np.array([[0, 1], [1, 0]], dtype=bool)),
40-
csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.int8)),
41-
csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.uint8)),
42-
csr_matrix(np.array([[0, 1], [1, 0]], dtype=float)),
43-
csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.float32)),
44-
csr_matrix(np.array([[0, 0], [0, 0]])),
45-
csr_matrix(np.array([[0, 1]])),
53+
sparse_multilable_explicit_zero,
54+
*_generate_sparse([[0, 1], [1, 0]]),
55+
*_generate_sparse([[0, 0], [0, 0]]),
56+
*_generate_sparse([[0, 1]]),
4657
# Only valid when data is dense
4758
[[-1, 1], [1, -1]],
4859
np.array([[-1, 1], [1, -1]]),
@@ -72,6 +83,11 @@
7283
np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=np.uint8),
7384
np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=float),
7485
np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=np.float32),
86+
*_generate_sparse(
87+
[[1, 0, 2, 2], [1, 4, 2, 4]],
88+
matrix_types=(csr_matrix, csc_matrix),
89+
dtypes=(int, np.int8, np.uint8, float, np.float32),
90+
),
7591
np.array([["a", "b"], ["c", "d"]]),
7692
np.array([["a", "b"], ["c", "d"]]),
7793
np.array([["a", "b"], ["c", "d"]], dtype=object),
@@ -110,9 +126,20 @@
110126
np.array([[0, 0.5], [0.5, 0]]),
111127
np.array([[0, 0.5], [0.5, 0]], dtype=np.float32),
112128
np.array([[0, 0.5]]),
129+
*_generate_sparse(
130+
[[0, 0.5], [0.5, 0]],
131+
matrix_types=(csr_matrix, csc_matrix),
132+
dtypes=(float, np.float32),
133+
),
134+
*_generate_sparse(
135+
[[0, 0.5]],
136+
matrix_types=(csr_matrix, csc_matrix),
137+
dtypes=(float, np.float32),
138+
),
113139
],
114140
"unknown": [
115141
[[]],
142+
np.array([[]], dtype=object),
116143
[()],
117144
# sequence of sequences that weren't supported even before deprecation
118145
np.array([np.array([]), np.array([1, 2, 3])], dtype=object),
@@ -121,6 +148,8 @@
121148
[frozenset([1, 2, 3]), frozenset([1, 2])],
122149
# and also confusable as sequences of sequences
123150
[{0: "a", 1: "b"}, {0: "a"}],
151+
# ndim 0
152+
np.array(0),
124153
# empty second dimension
125154
np.array([[], []]),
126155
# 3d

0 commit comments

Comments
 (0)