Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit cdd619d

Browse files
author
shangwuyao
committed
improved speed, correctly handled sample_weight, dealt with sparse matrix corner cases
1 parent 6d15cfd commit cdd619d

File tree

1 file changed

+54
-47
lines changed

1 file changed

+54
-47
lines changed

sklearn/metrics/classification.py

Lines changed: 54 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
294294

295295

296296
def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
297-
labels=None, samplewise=False):
297+
labels=None):
298298
"""Returns a confusion matrix for each output of a multilabel problem
299299
300300
Multiclass tasks will be treated as if binarised under a one-vs-rest
@@ -309,8 +309,6 @@ def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
309309
labels : array-like
310310
A list of classes or column indices to select some (or to force
311311
inclusion of classes absent from the data)
312-
samplewise : bool, default=False
313-
In the multilabel case, this calculates a confusion matrix per sample
314312
315313
Returns
316314
-------
@@ -322,11 +320,14 @@ def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
322320
y_pred = check_array(y_pred, ensure_2d=False, dtype=None,
323321
accept_sparse=['csr', 'csc'])
324322
check_consistent_length(y_true, y_pred, sample_weight)
323+
if sample_weight is not None and sample_weight.ndim > 1:
324+
raise ValueError('sample_weight should be 1-d array. ')
325+
326+
y_type, _, _= _check_targets(y_true, y_pred)
327+
if y_type not in ("binary", "multiclass", "multilabel-indicator"):
328+
raise ValueError("%s is not supported" % y_type)
325329

326330
if y_true.ndim == 1:
327-
if samplewise:
328-
raise ValueError("Samplewise confusion is not useful outside of "
329-
"multilabel classification.")
330331
present_labels = unique_labels(y_true, y_pred)
331332
C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight,
332333
labels=present_labels)
@@ -345,11 +346,11 @@ def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
345346
fp[not_present_mask] = 0
346347
fn[not_present_mask] = 0
347348

349+
tn = y_true.shape[0] - tp - fp - fn
350+
348351
else:
349352
# check labels
350-
if labels is None:
351-
labels = slice(None)
352-
else:
353+
if labels is not None:
353354
max_label = y_true.shape[1] - 1
354355
if np.max(labels) > max_label:
355356
raise ValueError('All labels must be in [0, n labels). '
@@ -359,6 +360,10 @@ def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
359360
raise ValueError('All labels must be in [0, n labels). '
360361
'Got %d < 0' % np.min(labels))
361362

363+
n_labels = len(labels)
364+
y_true = y_true[:, labels[:n_labels]]
365+
y_pred = y_pred[:, labels[:n_labels]]
366+
362367
# make sure values are in (0, 1) (but avoid unnecessary copy)
363368
if y_true.max() != 1 or y_true.min() != 0:
364369
if issparse(y_true):
@@ -372,50 +377,35 @@ def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
372377
y_pred = (y_pred == 1)
373378

374379
# account for sample weight
375-
376-
def _sparse_row_multiply(A, w):
377-
# for scipy <= 0.16 (and maybe later), A.multiply(w) will densify A
378-
if A.format == 'csr':
379-
return A._with_data(A.data * np.repeat(w, np.diff(A.indptr)))
380-
elif A.format == 'csc':
381-
return A._with_data(A.data * np.take(w, A.indices,
382-
mode='clip'))
380+
def _normal_or_sparse_row_multiply(A, w):
381+
if issparse(A):
382+
return A.multiply(w)
383+
elif issparse(w):
384+
return w.multiply(A)
383385
else:
384-
raise ValueError
386+
return np.multiply(A, w)
387+
388+
true_and_pred = _normal_or_sparse_row_multiply(y_true, y_pred)
385389

386390
if sample_weight is not None:
387-
if issparse(y_true):
388-
y_true = _sparse_row_multiply(y_true, sample_weight).tocsc()
389-
else:
390-
y_true = np.multiply(sample_weight, y_true)
391-
if issparse(y_pred):
392-
y_pred = _sparse_row_multiply(y_pred, sample_weight).tocsc()
393-
else:
394-
y_pred = np.multiply(sample_weight, y_pred)
395-
396-
if samplewise:
397-
y_true = y_true[:, labels].T
398-
y_pred = y_pred[:, labels].T
399-
labels = slice(None)
400-
401-
n_outputs = y_true.shape[1]
402-
y_pred_rows, y_pred_cols = y_pred.nonzero()
403-
# ravel is needed for sparse matrices
404-
if not len(y_pred_rows):
405-
# the below doesn't work in some older scipy for empty y_pred_cols
406-
tp = np.zeros(y_pred.shape[1], dtype=int)[labels]
407-
else:
408-
tp = np.bincount(y_pred_cols,
409-
weights=np.ravel(y_true[y_pred_rows,
410-
y_pred_cols]),
411-
minlength=n_outputs)[labels]
391+
sample_weight = sample_weight.reshape((len(sample_weight), 1))
392+
true_and_pred = _normal_or_sparse_row_multiply(true_and_pred,
393+
sample_weight)
394+
y_true = _normal_or_sparse_row_multiply(y_true, sample_weight)
395+
y_pred = _normal_or_sparse_row_multiply(y_pred, sample_weight)
396+
397+
tp = np.sum(true_and_pred, axis=0)
398+
412399
if y_true.dtype.kind in {'i', 'u', 'b'}:
413-
# bincount returns floats if weights is provided
414400
tp = tp.astype(np.int64)
415-
fp = np.ravel(y_pred.sum(axis=0))[labels] - tp
416-
fn = np.ravel(y_true.sum(axis=0))[labels] - tp
401+
fp = y_pred.sum(axis=0) - tp
402+
fn = y_true.sum(axis=0) - tp
403+
404+
if sample_weight is not None:
405+
tn = sample_weight.sum() - tp - fp - fn
406+
else:
407+
tn = y_true.shape[0] - tp - fp - fn
417408

418-
tn = y_true.shape[0] - tp - fp - fn
419409
return np.array([tn, fp, fn, tp]).T.reshape(-1, 2, 2)
420410

421411

@@ -1302,6 +1292,23 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
13021292
return precision, recall, f_score, true_sum
13031293

13041294

1295+
if __name__ == '__main__':
1296+
from scipy.stats import bernoulli
1297+
1298+
n_samples = 30000
1299+
n_labels = 2000
1300+
1301+
y_true = bernoulli.rvs(np.ones((n_samples, n_labels)) / 2,
1302+
size=(n_samples, n_labels))
1303+
y_pred = bernoulli.rvs(np.ones((n_samples, n_labels)) / 2,
1304+
size=(n_samples, n_labels))
1305+
1306+
precision_recall_fscore_support_with_multilabel_confusion_matrix(y_true,
1307+
y_pred)
1308+
precision_recall_fscore_support(y_true,
1309+
y_pred)
1310+
multilabel_confusion_matrix(y_true, y_pred)
1311+
13051312
def precision_recall_fscore_support_with_multilabel_confusion_matrix(
13061313
y_true, y_pred,
13071314
beta=1.0, labels=None,

0 commit comments

Comments
 (0)