@@ -294,7 +294,7 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
294
294
295
295
296
296
def multilabel_confusion_matrix (y_true , y_pred , sample_weight = None ,
297
- labels = None , samplewise = False ):
297
+ labels = None ):
298
298
"""Returns a confusion matrix for each output of a multilabel problem
299
299
300
300
Multiclass tasks will be treated as if binarised under a one-vs-rest
@@ -309,8 +309,6 @@ def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
309
309
labels : array-like
310
310
A list of classes or column indices to select some (or to force
311
311
inclusion of classes absent from the data)
312
- samplewise : bool, default=False
313
- In the multilabel case, this calculates a confusion matrix per sample
314
312
315
313
Returns
316
314
-------
@@ -322,11 +320,14 @@ def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
322
320
y_pred = check_array (y_pred , ensure_2d = False , dtype = None ,
323
321
accept_sparse = ['csr' , 'csc' ])
324
322
check_consistent_length (y_true , y_pred , sample_weight )
323
+ if sample_weight is not None and sample_weight .ndim > 1 :
324
+ raise ValueError ('sample_weight should be 1-d array. ' )
325
+
326
+ y_type , _ , _ = _check_targets (y_true , y_pred )
327
+ if y_type not in ("binary" , "multiclass" , "multilabel-indicator" ):
328
+ raise ValueError ("%s is not supported" % y_type )
325
329
326
330
if y_true .ndim == 1 :
327
- if samplewise :
328
- raise ValueError ("Samplewise confusion is not useful outside of "
329
- "multilabel classification." )
330
331
present_labels = unique_labels (y_true , y_pred )
331
332
C = confusion_matrix (y_true , y_pred , sample_weight = sample_weight ,
332
333
labels = present_labels )
@@ -345,11 +346,11 @@ def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
345
346
fp [not_present_mask ] = 0
346
347
fn [not_present_mask ] = 0
347
348
349
+ tn = y_true .shape [0 ] - tp - fp - fn
350
+
348
351
else :
349
352
# check labels
350
- if labels is None :
351
- labels = slice (None )
352
- else :
353
+ if labels is not None :
353
354
max_label = y_true .shape [1 ] - 1
354
355
if np .max (labels ) > max_label :
355
356
raise ValueError ('All labels must be in [0, n labels). '
@@ -359,6 +360,10 @@ def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
359
360
raise ValueError ('All labels must be in [0, n labels). '
360
361
'Got %d < 0' % np .min (labels ))
361
362
363
+ n_labels = len (labels )
364
+ y_true = y_true [:, labels [:n_labels ]]
365
+ y_pred = y_pred [:, labels [:n_labels ]]
366
+
362
367
# make sure values are in (0, 1) (but avoid unnecessary copy)
363
368
if y_true .max () != 1 or y_true .min () != 0 :
364
369
if issparse (y_true ):
@@ -372,50 +377,35 @@ def multilabel_confusion_matrix(y_true, y_pred, sample_weight=None,
372
377
y_pred = (y_pred == 1 )
373
378
374
379
# account for sample weight
375
-
376
- def _sparse_row_multiply (A , w ):
377
- # for scipy <= 0.16 (and maybe later), A.multiply(w) will densify A
378
- if A .format == 'csr' :
379
- return A ._with_data (A .data * np .repeat (w , np .diff (A .indptr )))
380
- elif A .format == 'csc' :
381
- return A ._with_data (A .data * np .take (w , A .indices ,
382
- mode = 'clip' ))
380
+ def _normal_or_sparse_row_multiply (A , w ):
381
+ if issparse (A ):
382
+ return A .multiply (w )
383
+ elif issparse (w ):
384
+ return w .multiply (A )
383
385
else :
384
- raise ValueError
386
+ return np .multiply (A , w )
387
+
388
+ true_and_pred = _normal_or_sparse_row_multiply (y_true , y_pred )
385
389
386
390
if sample_weight is not None :
387
- if issparse (y_true ):
388
- y_true = _sparse_row_multiply (y_true , sample_weight ).tocsc ()
389
- else :
390
- y_true = np .multiply (sample_weight , y_true )
391
- if issparse (y_pred ):
392
- y_pred = _sparse_row_multiply (y_pred , sample_weight ).tocsc ()
393
- else :
394
- y_pred = np .multiply (sample_weight , y_pred )
395
-
396
- if samplewise :
397
- y_true = y_true [:, labels ].T
398
- y_pred = y_pred [:, labels ].T
399
- labels = slice (None )
400
-
401
- n_outputs = y_true .shape [1 ]
402
- y_pred_rows , y_pred_cols = y_pred .nonzero ()
403
- # ravel is needed for sparse matrices
404
- if not len (y_pred_rows ):
405
- # the below doesn't work in some older scipy for empty y_pred_cols
406
- tp = np .zeros (y_pred .shape [1 ], dtype = int )[labels ]
407
- else :
408
- tp = np .bincount (y_pred_cols ,
409
- weights = np .ravel (y_true [y_pred_rows ,
410
- y_pred_cols ]),
411
- minlength = n_outputs )[labels ]
391
+ sample_weight = sample_weight .reshape ((len (sample_weight ), 1 ))
392
+ true_and_pred = _normal_or_sparse_row_multiply (true_and_pred ,
393
+ sample_weight )
394
+ y_true = _normal_or_sparse_row_multiply (y_true , sample_weight )
395
+ y_pred = _normal_or_sparse_row_multiply (y_pred , sample_weight )
396
+
397
+ tp = np .sum (true_and_pred , axis = 0 )
398
+
412
399
if y_true .dtype .kind in {'i' , 'u' , 'b' }:
413
- # bincount returns floats if weights is provided
414
400
tp = tp .astype (np .int64 )
415
- fp = np .ravel (y_pred .sum (axis = 0 ))[labels ] - tp
416
- fn = np .ravel (y_true .sum (axis = 0 ))[labels ] - tp
401
+ fp = y_pred .sum (axis = 0 ) - tp
402
+ fn = y_true .sum (axis = 0 ) - tp
403
+
404
+ if sample_weight is not None :
405
+ tn = sample_weight .sum () - tp - fp - fn
406
+ else :
407
+ tn = y_true .shape [0 ] - tp - fp - fn
417
408
418
- tn = y_true .shape [0 ] - tp - fp - fn
419
409
return np .array ([tn , fp , fn , tp ]).T .reshape (- 1 , 2 , 2 )
420
410
421
411
@@ -1302,6 +1292,23 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
1302
1292
return precision , recall , f_score , true_sum
1303
1293
1304
1294
1295
+ if __name__ == '__main__' :
1296
+ from scipy .stats import bernoulli
1297
+
1298
+ n_samples = 30000
1299
+ n_labels = 2000
1300
+
1301
+ y_true = bernoulli .rvs (np .ones ((n_samples , n_labels )) / 2 ,
1302
+ size = (n_samples , n_labels ))
1303
+ y_pred = bernoulli .rvs (np .ones ((n_samples , n_labels )) / 2 ,
1304
+ size = (n_samples , n_labels ))
1305
+
1306
+ precision_recall_fscore_support_with_multilabel_confusion_matrix (y_true ,
1307
+ y_pred )
1308
+ precision_recall_fscore_support (y_true ,
1309
+ y_pred )
1310
+ multilabel_confusion_matrix (y_true , y_pred )
1311
+
1305
1312
def precision_recall_fscore_support_with_multilabel_confusion_matrix (
1306
1313
y_true , y_pred ,
1307
1314
beta = 1.0 , labels = None ,
0 commit comments