Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f9e93bf

Browse files
author
lucas de oliveira carvalho
committed
confusion matrix python script added
1 parent a1dde3b commit f9e93bf

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

scripts/Confusion_Matrix/README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# CODE BY ZeroToMastery TensorFlow course.
2+
# The function makes a labelled confusion matrix comparing predictions and ground truth labels.
3+
4+
# If classes is passed, confusion matrix will be labelled, if not, integer class values
5+
# will be used.
6+
7+
## Args:
8+
9+
`y_true`: Array of truth labels (must be same shape as y_pred).
10+
`y_pred`: Array of predicted labels (must be same shape as y_true).
11+
`classes`: Array of class labels (e.g. string form). If `None`, integer labels are used.
12+
`figsize`: Size of output figure (default=(10, 10)).
13+
`text_size`: Size of output figure text (default=15).
14+
`norm`: normalize values or not (default=False).
15+
`savefig`: save confusion matrix to file (default=False).
16+
17+
## Returns: A labelled confusion matrix plot comparing y_true and y_pred.
18+
### Example usage:
19+
20+
"""make_confusion_matrix(y_true=test_labels, # ground truth test labels
21+
y_pred=y_preds, # predicted labels
22+
classes=class_names, # array of class label names
23+
figsize=(15, 15),
24+
text_size=10)"""
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import itertools
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
from sklearn.metrics import confusion_matrix
5+
6+
def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False):
7+
# Create the confustion matrix
8+
cm = confusion_matrix(y_true, y_pred)
9+
cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] # normalize it
10+
n_classes = cm.shape[0] # find the number of classes we're dealing with
11+
12+
# Plot the figure and make it pretty
13+
fig, ax = plt.subplots(figsize=figsize)
14+
cax = ax.matshow(cm, cmap=plt.cm.Blues) # colors will represent how 'correct' a class is, darker == better
15+
fig.colorbar(cax)
16+
17+
# Are there a list of classes?
18+
if classes:
19+
labels = classes
20+
else:
21+
labels = np.arange(cm.shape[0])
22+
23+
# Label the axes
24+
ax.set(title="Confusion Matrix",
25+
xlabel="Predicted label",
26+
ylabel="True label",
27+
xticks=np.arange(n_classes), # create enough axis slots for each class
28+
yticks=np.arange(n_classes),
29+
xticklabels=labels, # axes will labeled with class names (if they exist) or ints
30+
yticklabels=labels)
31+
32+
# Make x-axis labels appear on bottom
33+
ax.xaxis.set_label_position("bottom")
34+
ax.xaxis.tick_bottom()
35+
36+
# Set the threshold for different colors
37+
threshold = (cm.max() + cm.min()) / 2.
38+
39+
# Plot the text on each cell
40+
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
41+
if norm:
42+
plt.text(j, i, f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)",
43+
horizontalalignment="center",
44+
color="white" if cm[i, j] > threshold else "black",
45+
size=text_size)
46+
else:
47+
plt.text(j, i, f"{cm[i, j]}",
48+
horizontalalignment="center",
49+
color="white" if cm[i, j] > threshold else "black",
50+
size=text_size)
51+
52+
# Save the figure to the current working directory
53+
if savefig:
54+
fig.savefig("confusion_matrix.png")
55+
56+
y_true = [0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0]
57+
y_pred = [0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0]
58+
59+
make_confusion_matrix(y_true, y_pred)

0 commit comments

Comments
 (0)