Plot the confusion matrix



1. Confusion matrix basics

Take binary classification as an example. Pairing each prediction with the actual class yields four possible outcomes: true positive, false positive, true negative and false negative, written TP, FP, TN and FN. T means the prediction was correct and F means it was wrong, while P and N give the predicted class. These four counts make up the confusion matrix.
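A minimal sketch of how the four counts are obtained from a list of actual labels and a list of predictions. It assumes 0/1 labels with 1 as the positive class; the function name and the toy labels are hypothetical.

```python
import numpy as np

def binary_confusion_counts(y_true, y_pred):
    """Count TP, FP, TN, FN for 0/1 labels, treating 1 as the positive class."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tp = int(np.sum(y_pred & y_true))    # predicted positive, actually positive
    fp = int(np.sum(y_pred & ~y_true))   # predicted positive, actually negative
    tn = int(np.sum(~y_pred & ~y_true))  # predicted negative, actually negative
    fn = int(np.sum(~y_pred & y_true))   # predicted negative, actually positive
    return tp, fp, tn, fn

y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 0]
tp, fp, tn, fn = binary_confusion_counts(y_true, y_pred)
print(tp, fp, tn, fn)  # → 3 1 3 1
```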

[Figure: the 2×2 confusion matrix of TP, FP, TN and FN]

As the figure shows, only the predictions on the diagonal (TP and TN) are correct; all other entries are errors.
The four counts add up to the total number of samples.
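The two facts above translate directly into array operations: the matrix sum is the sample count and the diagonal sum (the trace) is the number of correct predictions, so their ratio is the accuracy. The sketch below assumes rows are actual classes and columns are predicted classes, in the order [positive, negative]; the counts are made up. The same `trace / sum` works unchanged for a multi-class matrix.

```python
import numpy as np

# Hypothetical 2x2 counts: rows = actual class, columns = predicted class,
# class order [positive, negative], i.e.
#   [[TP, FN],
#    [FP, TN]]
cm = np.array([[3, 1],
               [1, 3]])

total = cm.sum()            # TP + FN + FP + TN = number of samples
correct = np.trace(cm)      # diagonal entries are the correct predictions
errors = total - correct    # everything off the diagonal is a mistake
accuracy = correct / total
print(total, correct, errors, accuracy)  # → 8 6 2 0.75
```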


Different applications place different demands on the model:

  • A disease-diagnosis model should lean toward finding all of the negative samples (sick patients);
  • A spam-detection model should lean toward picking out all of the positive samples (normal emails).
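In this framing sick patients are the negative class and normal emails the positive class, but either way "finding all samples of a class" corresponds to that class's recall: its diagonal entry divided by its row total, with rows as actual classes. The sketch below computes recall and its counterpart, precision (diagonal entry divided by column total), for every class at once. The function name and the screening counts are hypothetical, and mapping the two requirements to recall is an interpretation rather than something the text states.

```python
import numpy as np

def per_class_recall_precision(cm):
    """cm[i, j] = number of samples whose true class is i and predicted class is j."""
    cm = np.asarray(cm, dtype=np.float64)
    hits = np.diag(cm)
    recall = hits / cm.sum(axis=1)     # share of each true class that was found
    precision = hits / cm.sum(axis=0)  # share of each predicted class that is correct
    return recall, precision

# Hypothetical diagnosis counts, class order [sick, healthy]
cm = [[45, 5],     # 50 sick patients: 45 caught, 5 missed
      [20, 130]]   # 150 healthy people: 20 wrongly flagged as sick
recall, precision = per_class_recall_precision(cm)
print(recall[0], precision[0])  # recall on "sick" is 0.9, precision about 0.69
```

A model tuned to catch every sick patient pushes `recall[0]` toward 1, usually at the cost of more healthy people flagged, which shows up as lower `precision[0]`.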

2. Code implementation:

#confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
# classes = ['A','B','C','D','E']
# confusion_matrix = np.array([(9,1,3,4,0),(2,13,1,3,4),(1,4,10,0,13),(3,1,1,17,0),(0,0,0,1,14)],dtype=np.float64)


# Class labels (facial expressions)
classes = ['angry', 'disgust', 'scared', 'happy', 'sad', 'surprised', 'neutral']

# Number of classes (expressions)
num_classes = len(classes)

# Confusion matrix: row i is the true label, column j the predicted label.
# Values are per-row fractions (each row sums to roughly 1).
confusion_matrix = np.array([
    (0.70, 0,    0.07, 0.04, 0.09, 0.01, 0.09),
    (0.18, 0.75, 0,    0,    0.03, 0.02, 0.02),
    (0.09, 0,    0.51, 0.04, 0.17, 0.09, 0.10),
    (0.02, 0,    0.01, 0.91, 0.02, 0.01, 0.03),
    (0.10, 0,    0.11, 0.03, 0.57, 0.01, 0.17),
    (0.02, 0,    0.07, 0.04, 0,    0.84, 0),
    (0.04, 0,    0.03, 0.07, 0.12, 0.02, 0.72)
    ], dtype=np.float64)

plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Oranges)  # draw the matrix as an image, one colored cell per entry
plt.title('confusion_matrix')
plt.colorbar()
tick_marks = np.arange(num_classes)
plt.xticks(tick_marks, classes, rotation=-45)
plt.yticks(tick_marks, classes)

# Cells darker than this threshold get white text so the numbers stay readable
thresh = confusion_matrix.max() / 2.
# Visit every (row, column) index pair and write the value in its cell
for i, j in np.ndindex(confusion_matrix.shape):
    plt.text(j, i, format(confusion_matrix[i, j], '.2f'),
             va='center', ha='center',
             color='white' if confusion_matrix[i, j] > thresh else 'black')

plt.ylabel('Real label')
plt.xlabel('Prediction')
plt.tight_layout()
plt.show()


Result:

[Figure: the confusion-matrix plot produced by the code above]
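The plotting code hard-codes a matrix whose rows sum to roughly 1, i.e. each row is divided by the number of samples of that true class. As a minimal sketch of where such a matrix can come from, assuming integer labels 0 to n-1, the hypothetical helper below counts (true, predicted) pairs and optionally normalizes each row. If scikit-learn is available, `sklearn.metrics.confusion_matrix(y_true, y_pred, normalize='true')` produces the same row-normalized result.

```python
import numpy as np

def confusion_matrix_from_labels(y_true, y_pred, num_classes, normalize=True):
    """Hypothetical helper: rows are true labels, columns are predicted labels."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    cm = np.zeros((num_classes, num_classes), dtype=np.float64)
    # Add 1 to cm[t, p] for every (true, predicted) pair; repeated pairs accumulate
    np.add.at(cm, (y_true, y_pred), 1)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Divide each row by its total; rows with no samples stay all zero
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)
    return cm

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(confusion_matrix_from_labels(y_true, y_pred, num_classes=3))
```

Passing the returned array to the plotting code in place of the hard-coded `confusion_matrix` draws the same kind of chart from real predictions.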



Source: blog.csdn.net/zhaozhao236/article/details/109600252