1
+ import itertools
2
+ import matplotlib .pyplot as plt
3
+ import numpy as np
4
+ from sklearn .metrics import confusion_matrix
5
+
6
+ def make_confusion_matrix (y_true , y_pred , classes = None , figsize = (10 , 10 ), text_size = 15 , norm = False , savefig = False ):
7
+ # Create the confustion matrix
8
+ cm = confusion_matrix (y_true , y_pred )
9
+ cm_norm = cm .astype ("float" ) / cm .sum (axis = 1 )[:, np .newaxis ] # normalize it
10
+ n_classes = cm .shape [0 ] # find the number of classes we're dealing with
11
+
12
+ # Plot the figure and make it pretty
13
+ fig , ax = plt .subplots (figsize = figsize )
14
+ cax = ax .matshow (cm , cmap = plt .cm .Blues ) # colors will represent how 'correct' a class is, darker == better
15
+ fig .colorbar (cax )
16
+
17
+ # Are there a list of classes?
18
+ if classes :
19
+ labels = classes
20
+ else :
21
+ labels = np .arange (cm .shape [0 ])
22
+
23
+ # Label the axes
24
+ ax .set (title = "Confusion Matrix" ,
25
+ xlabel = "Predicted label" ,
26
+ ylabel = "True label" ,
27
+ xticks = np .arange (n_classes ), # create enough axis slots for each class
28
+ yticks = np .arange (n_classes ),
29
+ xticklabels = labels , # axes will labeled with class names (if they exist) or ints
30
+ yticklabels = labels )
31
+
32
+ # Make x-axis labels appear on bottom
33
+ ax .xaxis .set_label_position ("bottom" )
34
+ ax .xaxis .tick_bottom ()
35
+
36
+ # Set the threshold for different colors
37
+ threshold = (cm .max () + cm .min ()) / 2.
38
+
39
+ # Plot the text on each cell
40
+ for i , j in itertools .product (range (cm .shape [0 ]), range (cm .shape [1 ])):
41
+ if norm :
42
+ plt .text (j , i , f"{ cm [i , j ]} ({ cm_norm [i , j ]* 100 :.1f} %)" ,
43
+ horizontalalignment = "center" ,
44
+ color = "white" if cm [i , j ] > threshold else "black" ,
45
+ size = text_size )
46
+ else :
47
+ plt .text (j , i , f"{ cm [i , j ]} " ,
48
+ horizontalalignment = "center" ,
49
+ color = "white" if cm [i , j ] > threshold else "black" ,
50
+ size = text_size )
51
+
52
+ # Save the figure to the current working directory
53
+ if savefig :
54
+ fig .savefig ("confusion_matrix.png" )
55
+
56
+ y_true = [0 , 0 , 0 , 0 , 1 , 1 , 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , 1 , 1 , 1 , 0 ]
57
+ y_pred = [0 , 0 , 1 , 0 , 1 , 1 , 0 , 0 , 1 , 1 , 0 , 0 , 1 , 0 , 1 , 1 , 1 , 0 , 1 , 0 ]
58
+
59
+ make_confusion_matrix (y_true , y_pred )
0 commit comments