Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f212400

Browse files
committed
fix type_of_target for csr_matrices
1 parent 9d62e2b commit f212400

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

sklearn/utils/multiclass.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import numpy as np
1818

19-
from .validation import check_array, _assert_all_finite
19+
from .validation import check_array, _assert_all_finite, assert_all_finite
2020

2121

2222
def _unique_multiclass(y):
@@ -247,11 +247,14 @@ def type_of_target(y):
247247
if is_multilabel(y):
248248
return 'multilabel-indicator'
249249

250-
try:
251-
y = np.asarray(y)
252-
except ValueError:
253-
# Known to fail in numpy 1.3 for array of arrays
254-
return 'unknown'
250+
if not issparse(y):
251+
# calling np.asarray on sparse matrix has unexpected behavior
252+
# https://github.com/numpy/numpy/issues/14221
253+
try:
254+
y = np.asarray(y)
255+
except ValueError:
256+
# Known to fail in numpy 1.3 for array of arrays
257+
return 'unknown'
255258

256259
# The old sequence of sequences format
257260
try:
@@ -266,9 +269,13 @@ def type_of_target(y):
266269
pass
267270

268271
# Invalid inputs
269-
if y.ndim > 2 or (y.dtype == object and len(y) and
270-
not isinstance(y.flat[0], str)):
272+
if y.ndim > 2:
273+
return 'unknown'
274+
if not issparse(y) and y.dtype == object \
275+
and not isinstance(y.flat[0], str):
271276
return 'unknown' # [[[1, 2]]] or [obj_1] and not ["label_1"]
277+
if issparse(y) and y.dtype == object and not isinstance(y.data[0], str):
278+
return 'unknown' # [[[1, 2]]] or [obj_1] and not ["label_1"] (sparse)
272279

273280
if y.ndim == 2 and y.shape[1] == 0:
274281
return 'unknown' # [[]]
@@ -279,13 +286,22 @@ def type_of_target(y):
279286
suffix = "" # [1, 2, 3] or [[1], [2], [3]]
280287

281288
# check float and contains non-integer float values
282-
if y.dtype.kind == 'f' and np.any(y != y.astype(int)):
283-
# [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
284-
_assert_all_finite(y)
285-
return 'continuous' + suffix
286-
287-
if (len(np.unique(y)) > 2) or (y.ndim >= 2 and len(y[0]) > 1):
288-
return 'multiclass' + suffix # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
289+
if y.dtype.kind == 'f':
290+
if not issparse(y) and np.any(y != y.astype(int)):
291+
# [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
292+
_assert_all_finite(y)
293+
return 'continuous' + suffix
294+
if issparse(y) and np.any(y.data != y.data.astype(int)):
295+
# [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
296+
assert_all_finite(y)
297+
return 'continuous' + suffix
298+
299+
if len(np.unique(y)) > 2:
300+
return 'multiclass' + suffix # [1, 2, 3] or [[1., 2., 3]]
301+
if not issparse(y) and y.ndim >= 2 and len(y[0]) > 1:
302+
return 'multiclass' + suffix # [[1, 2]] or [[0],[1]]
303+
if issparse(y) and y.ndim >= 2:
304+
return 'multiclass' + suffix # [[1, 2]]
289305
else:
290306
return 'binary' # [1, 2] or [["a"], ["b"]]
291307

sklearn/utils/tests/test_multiclass.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def __array__(self, dtype=None):
8181
np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=np.uint8),
8282
np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=np.float),
8383
np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=np.float32),
84+
csr_matrix(np.array([[1, 0, 2, 2], [1, 4, 2, 4]])),
85+
csr_matrix(np.array([[1, 0, 2, 2], [1, 4, 2, 4]]), dtype=np.int8),
86+
csr_matrix(np.array([[1, 0, 2, 2], [1, 4, 2, 4]]), dtype=np.uint8),
87+
csr_matrix(np.array([[1, 0, 2, 2], [1, 4, 2, 4]]), dtype=np.float),
88+
csr_matrix(np.array([[1, 0, 2, 2], [1, 4, 2, 4]]), dtype=np.float32),
8489
np.array([['a', 'b'], ['c', 'd']]),
8590
np.array([['a', 'b'], ['c', 'd']]),
8691
np.array([['a', 'b'], ['c', 'd']], dtype=object),
@@ -119,6 +124,9 @@ def __array__(self, dtype=None):
119124
np.array([[0, .5], [.5, 0]]),
120125
np.array([[0, .5], [.5, 0]], dtype=np.float32),
121126
np.array([[0, .5]]),
127+
csr_matrix(np.array([[0, .5], [.5, 0]])),
128+
csr_matrix(np.array([[0, .5], [.5, 0]]), dtype=np.float32),
129+
csr_matrix(np.array([[0, .5]])),
122130
],
123131
'unknown': [
124132
[[]],

0 commit comments

Comments
 (0)