diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 3186ed60faa08..d60ecf3cadb71 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -458,6 +458,9 @@ Changelog deterministic SVD used by the randomized SVD algorithm. :pr:`20617` by :user:`Srinath Kailasa ` +- |FIX| :func:`utils.multiclass.type_of_target` now properly handles sparse matrices. + :pr:`14862` by :user:`Léonard Binet `. + Code and Documentation Contributors ----------------------------------- diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py index d3bb22da4fb02..4792ed32661df 100644 --- a/sklearn/utils/multiclass.py +++ b/sklearn/utils/multiclass.py @@ -170,13 +170,11 @@ def is_multilabel(y): if issparse(y): if isinstance(y, (dok_matrix, lil_matrix)): y = y.tocsr() + labels = xp.unique_values(y.data) return ( len(y.data) == 0 - or xp.unique_values(y.data).size == 1 - and ( - y.dtype.kind in "biu" - or _is_integral_float(xp.unique_values(y.data)) # bool, int, uint - ) + or (labels.size == 1 or (labels.size == 2) and (0 in labels)) + and (y.dtype.kind in "biu" or _is_integral_float(labels)) # bool, int, uint ) else: labels = xp.unique_values(y) @@ -223,8 +221,9 @@ def type_of_target(y, input_name=""): Parameters ---------- - y : array-like - Target values. + y : {array-like, sparse matrix} + Target values. If a sparse matrix, `y` is expected to be a + CSR/CSC matrix. input_name : str, default="" The data name used to construct the error message. @@ -303,12 +302,13 @@ def type_of_target(y, input_name=""): # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html with warnings.catch_warnings(): warnings.simplefilter("error", np.VisibleDeprecationWarning) - try: - y = xp.asarray(y) - except (np.VisibleDeprecationWarning, ValueError): - # dtype=object should be provided explicitly for ragged arrays, - # see NEP 34 - y = xp.asarray(y, dtype=object) + if not issparse(y): + try: + y = xp.asarray(y) + except np.VisibleDeprecationWarning: + # dtype=object should be provided explicitly for ragged arrays, + # see NEP 34 + y = xp.asarray(y, dtype=object) # The old sequence of sequences format try: @@ -328,25 +328,39 @@ def type_of_target(y, input_name=""): pass # Invalid inputs - if y.ndim > 2 or (y.dtype == object and len(y) and not isinstance(y.flat[0], str)): - return "unknown" # [[[1, 2]]] or [obj_1] and not ["label_1"] - - if y.ndim == 2 and y.shape[1] == 0: - return "unknown" # [[]] - + if y.ndim not in (1, 2): + # Number of dimension greater than 2: [[[1, 2]]] + return "unknown" + if not min(y.shape): + # Empty ndarray: []/[[]] + if y.ndim == 1: + # 1-D empty array: [] + return "binary" # [] + # 2-D empty array: [[]] + return "unknown" + if not issparse(y) and y.dtype == object and not isinstance(y.flat[0], str): + # [obj_1] and not ["label_1"] + return "unknown" + + # Check if multioutput if y.ndim == 2 and y.shape[1] > 1: suffix = "-multioutput" # [[1, 2], [1, 2]] else: suffix = "" # [1, 2, 3] or [[1], [2], [3]] - # check float and contains non-integer float values - if y.dtype.kind == "f" and xp.any(y != y.astype(int)): + # Check float and contains non-integer float values + if y.dtype.kind == "f": # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.] - _assert_all_finite(y, input_name=input_name) - return "continuous" + suffix - - if (xp.unique_values(y).shape[0] > 2) or (y.ndim >= 2 and len(y[0]) > 1): - return "multiclass" + suffix # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]] + data = y.data if issparse(y) else y + if xp.any(data != data.astype(int)): + _assert_all_finite(data, input_name=input_name) + return "continuous" + suffix + + # Check multiclass + first_row = y[0] if not issparse(y) else y.getrow(0).data + if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1): + # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]] + return "multiclass" + suffix else: return "binary" # [1, 2] or [["a"], ["b"]] diff --git a/sklearn/utils/tests/test_multiclass.py b/sklearn/utils/tests/test_multiclass.py index 996a25bc3a42b..cf5858d0f52f9 100644 --- a/sklearn/utils/tests/test_multiclass.py +++ b/sklearn/utils/tests/test_multiclass.py @@ -27,6 +27,21 @@ from sklearn.svm import SVC from sklearn import datasets +sparse_multilable_explicit_zero = csc_matrix(np.array([[0, 1], [1, 0]])) +sparse_multilable_explicit_zero[:, 0] = 0 + + +def _generate_sparse( + matrix, + matrix_types=(csr_matrix, csc_matrix, coo_matrix, dok_matrix, lil_matrix), + dtypes=(bool, int, np.int8, np.uint8, float, np.float32), +): + return [ + matrix_type(matrix, dtype=dtype) + for matrix_type in matrix_types + for dtype in dtypes + ] + EXAMPLES = { "multilabel-indicator": [ @@ -35,14 +50,10 @@ csr_matrix(np.random.RandomState(42).randint(2, size=(10, 10))), [[0, 1], [1, 0]], [[0, 1]], - csr_matrix(np.array([[0, 1], [1, 0]])), - csr_matrix(np.array([[0, 1], [1, 0]], dtype=bool)), - csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.int8)), - csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.uint8)), - csr_matrix(np.array([[0, 1], [1, 0]], dtype=float)), - csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.float32)), - csr_matrix(np.array([[0, 0], [0, 0]])), - csr_matrix(np.array([[0, 1]])), + sparse_multilable_explicit_zero, + *_generate_sparse([[0, 1], [1, 0]]), + *_generate_sparse([[0, 0], [0, 0]]), + *_generate_sparse([[0, 1]]), # Only valid when data is dense [[-1, 1], [1, -1]], np.array([[-1, 1], [1, -1]]), @@ -72,6 +83,11 @@ np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=np.uint8), np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=float), np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=np.float32), + *_generate_sparse( + [[1, 0, 2, 2], [1, 4, 2, 4]], + matrix_types=(csr_matrix, csc_matrix), + dtypes=(int, np.int8, np.uint8, float, np.float32), + ), np.array([["a", "b"], ["c", "d"]]), np.array([["a", "b"], ["c", "d"]]), np.array([["a", "b"], ["c", "d"]], dtype=object), @@ -110,9 +126,20 @@ np.array([[0, 0.5], [0.5, 0]]), np.array([[0, 0.5], [0.5, 0]], dtype=np.float32), np.array([[0, 0.5]]), + *_generate_sparse( + [[0, 0.5], [0.5, 0]], + matrix_types=(csr_matrix, csc_matrix), + dtypes=(float, np.float32), + ), + *_generate_sparse( + [[0, 0.5]], + matrix_types=(csr_matrix, csc_matrix), + dtypes=(float, np.float32), + ), ], "unknown": [ [[]], + np.array([[]], dtype=object), [()], # sequence of sequences that weren't supported even before deprecation np.array([np.array([]), np.array([1, 2, 3])], dtype=object), @@ -121,6 +148,8 @@ [frozenset([1, 2, 3]), frozenset([1, 2])], # and also confusable as sequences of sequences [{0: "a", 1: "b"}, {0: "a"}], + # ndim 0 + np.array(0), # empty second dimension np.array([[], []]), # 3d