Summary
For classification problems, we often analyze experimental results by visualizing the confusion matrix, which helps guide parameter tuning. This article shows how to draw a confusion matrix in Python (using `confusion_matrix`). It provides only the code, together with the necessary comments.
Code
# -*-coding:utf-8-*-
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
# `labels` holds the display name of each class; this demo has 13 classes.
labels = [
    'A', 'B', 'C', 'F', 'G', 'H', 'I',
    'J', 'K', 'L', 'M', 'N', 'O',
]

# y_true holds the ground-truth label values, y_pred the predicted ones.
y_true = np.loadtxt('../Data/re_label.txt')
y_pred = np.loadtxt('../Data/pr_label.txt')

# Minor-tick positions, shifted by half a cell from the integer cell indices.
tick_marks = np.arange(len(labels)) + 0.5
def plot_confusion_matrix(cm, title='Confusion Matrix', cmap=plt.cm.binary):
    """Draw ``cm`` as an image on the current matplotlib figure.

    Adds a title, a colorbar, and axis labels, and places one tick per class
    on both axes using the module-level ``labels`` list for the tick text.

    Args:
        cm: 2-D array passed to ``plt.imshow``.
        title: Text shown above the plot.
        cmap: Matplotlib colormap used to render the cells.
    """
    # One tick per class, at integer positions 0 .. len(labels) - 1.
    tick_positions = np.arange(len(labels))

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    # x tick labels are rotated by 90 degrees.
    plt.xticks(tick_positions, labels, rotation=90)
    plt.yticks(tick_positions, labels)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
# Build the confusion matrix from the true and predicted label values.
cm = confusion_matrix(y_true, y_pred)
np.set_printoptions(precision=2)

# Normalize each row by its sum so every entry is a fraction of that row.
# A row whose sum is 0 (a class that never occurs in y_true) is left as
# zeros instead of being divided by zero, which would fill it with NaN.
row_sums = cm.sum(axis=1)[:, np.newaxis]
cm_normalized = np.divide(
    cm.astype('float'),
    row_sums,
    out=np.zeros(cm.shape, dtype=float),
    where=row_sums != 0,
)
# print(...) with one argument is valid on both Python 2 and Python 3;
# the bare `print x` statement is a SyntaxError on Python 3.
print(cm_normalized)

plt.figure(figsize=(12, 8), dpi=120)

# Write the value into every cell whose normalized entry exceeds 0.01.
# Iterating over the matrix's own cells (rather than a len(labels) grid)
# avoids an IndexError when the matrix and `labels` differ in size.
for (row, col), c in np.ndenumerate(cm_normalized):
    if c > 0.01:
        plt.text(col, row, "%0.2f" % (c,), color='red', fontsize=7, va='center', ha='center')

# Offset the minor ticks by half a cell so the grid lines fall on cell
# borders, and hide the tick marks themselves.
axes = plt.gca()
axes.set_xticks(tick_marks, minor=True)
axes.set_yticks(tick_marks, minor=True)
axes.xaxis.set_ticks_position('none')
axes.yaxis.set_ticks_position('none')
plt.grid(True, which='minor', linestyle='-')

# Extra bottom margin for the figure.
plt.gcf().subplots_adjust(bottom=0.15)

plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

# Save the figure to disk first, then show it in a window.
plt.savefig('../Data/confusion_matrix.png', format='png')
plt.show()
Result
Instructions
Replace labels, y_true and y_pred with your own data according to the comments in the code.