Python: generating a confusion matrix from your own model, dataset, and weights

A confusion matrix is a table used to evaluate the performance of a classification model. It shows the combinations of the model's predictions and the actual labels in a classification problem.

Confusion matrices are most commonly used for binary classification, but they extend naturally to multi-class problems. For binary classification, the matrix is built from four counts:

True Positive (TP): the number of samples predicted positive that are actually positive.
True Negative (TN): the number of samples predicted negative that are actually negative.
False Positive (FP): the number of samples predicted positive that are actually negative, also known as a "false alarm" (Type I error).
False Negative (FN): the number of samples predicted negative that are actually positive, also known as a "miss" (Type II error).

The general form of a binary confusion matrix is as follows:

                     Predicted Positive    Predicted Negative
Actual Positive             TP                    FN
Actual Negative             FP                    TN

The confusion matrix can be used to compute several metrics that measure classifier performance, such as accuracy, precision, recall (also known as sensitivity or the true positive rate), and the F1 score. Each is calculated from the individual elements of the confusion matrix, as the sketch after this list shows:

Accuracy: the proportion of all samples that the classifier predicts correctly, calculated as (TP + TN) / (TP + TN + FP + FN).
Precision: the proportion of positive predictions that are actually positive, calculated as TP / (TP + FP).
Recall: the proportion of actual positives that are correctly predicted positive, calculated as TP / (TP + FN).
F1 score: a metric that balances precision and recall, calculated as 2 × (Precision × Recall) / (Precision + Recall).
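
As a concrete illustration, here is a minimal sketch that derives all four metrics from scikit-learn's confusion_matrix (the binary labels below are made up for illustration; for binary inputs, sklearn returns the counts in the order [[TN, FP], [FN, TP]]):

import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical binary labels, for illustration only
y_true = [1, 0, 1, 1, 0, 1, 0, 0]
y_pred = [1, 0, 0, 1, 0, 1, 1, 0]

# For binary inputs, sklearn returns [[TN, FP], [FN, TP]]
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

print("Accuracy:", accuracy)    # 0.75
print("Precision:", precision)  # 0.75
print("Recall:", recall)        # 0.75
print("F1:", f1)                # 0.75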

A confusion matrix makes it possible to assess a classification model in more detail than a single score, because it shows exactly where the false positives and false negatives occur. By analyzing it we can see how the classifier performs on each individual class (a per-class sketch follows) and use that insight to improve the classification results.
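
For multi-class problems the same idea applies per class: the diagonal holds the correct predictions, column sums give the predicted totals, and row sums give the true totals. A minimal sketch, using a hypothetical 3-class matrix rather than data from this post:

import numpy as np

# Hypothetical 3-class confusion matrix; rows = true labels, columns = predictions
cm = np.array([[50,  2,  3],
               [ 4, 45,  6],
               [ 1,  2, 40]])

tp = np.diag(cm)                            # correct predictions per class
precision_per_class = tp / cm.sum(axis=0)   # column sums = predicted totals
recall_per_class = tp / cm.sum(axis=1)      # row sums = true totals

for k in range(len(tp)):
    print("class %d: precision=%.3f, recall=%.3f" % (k, precision_per_class[k], recall_per_class[k]))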

Without further ado, here is the code:

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import confusion_matrix


def draw_confusion_matrix(label_true, label_pred, label_name, normalize, title="Confusion Matrix", pdf_save_path=None, dpi=100):
    """

    @param label_true: ground-truth labels, e.g. [0,1,2,7,4,5,...]
    @param label_pred: predicted labels, e.g. [0,5,4,2,1,4,...]
    @param label_name: label names, e.g. ['cat','dog','flower',...]
    @param normalize: whether to show each cell as a percentage of its true-label row
    @param title: figure title
    @param pdf_save_path: save path if not None, e.g. pdf_save_path=xxx.png | xxx.pdf | any other format plt.savefig supports
    @param dpi: resolution of the saved figure; journals usually require at least 300 dpi
    @return:

    example:
            draw_confusion_matrix(label_true=y_gt,
                          label_pred=y_pred,
                          label_name=["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"],
                          normalize=True,
                          title="Confusion Matrix on Fer2013",
                          pdf_save_path="Confusion_Matrix_on_Fer2013.png",
                          dpi=300)

    """
    cm_raw = confusion_matrix(label_true, label_pred)  # raw counts: cm_raw[i, j] = samples of true class i predicted as class j
    cm = cm_raw.astype(float)
    if normalize:
        row_sums = np.sum(cm, axis=1)          # total samples per true class
        cm = cm / row_sums[:, np.newaxis]      # each row now sums to 1
    plt.imshow(cm, cmap='Blues')
    plt.title(title)
    plt.xlabel("Predict label")
    plt.ylabel("Truth label")
    plt.yticks(range(len(label_name)), label_name)
    plt.xticks(range(len(label_name)), label_name, rotation=45)

    plt.tight_layout()

    plt.colorbar()

    for i in range(len(label_name)):          # row i: true label
        for j in range(len(label_name)):      # column j: predicted label
            color = (1, 1, 1) if i == j else (0, 0, 0)  # white text on the diagonal, black elsewhere
            if normalize:
                text = '%.1f%%\n%d' % (cm[i, j] * 100, cm_raw[i, j])  # percentage plus raw count
            else:
                text = '%d' % cm_raw[i, j]
            # plt.text takes (x, y), so column j is the x coordinate and row i the y coordinate
            plt.text(j, i, text, verticalalignment='center', horizontalalignment='center', color=color)

    # plt.show()
    if pdf_save_path is not None:
        plt.savefig(pdf_save_path, bbox_inches='tight', dpi=dpi)



labels_name = ['bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

y_gt=[]
y_pred=[]

# Load the trained weights (Xception and test_dl come from the author's own training code)
model_weight_path = "./best_CBAM_model.pth"
model = Xception(num_classes=4)
model.load_state_dict(torch.load(model_weight_path, map_location='cpu'))  # map_location keeps this runnable without a GPU

model.eval()
with torch.no_grad():  # gradients are not needed during evaluation
    for imgs, labels in test_dl:
        labels_pd = model(imgs)
        predict_np = np.argmax(labels_pd.cpu().numpy(), axis=-1).tolist()
        labels_np = labels.numpy().tolist()

        y_pred.extend(predict_np)
        y_gt.extend(labels_np)
print("Predicted labels:", y_pred)
print("Ground-truth labels:", y_gt)



draw_confusion_matrix(label_true=y_gt,
                      label_pred=y_pred,
                      label_name=labels_name,
                      normalize=True,
                      title="Confusion Matrix",
                      pdf_save_path="Confusion_Matrix.jpg",
                      dpi=300)
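
As a quick sanity check on the same y_gt / y_pred lists, scikit-learn's classification_report prints per-class precision, recall, and F1 in one call (an optional addition, not part of the original script):

from sklearn.metrics import classification_report

print(classification_report(y_gt, y_pred, target_names=labels_name))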

The result is as follows:

[Figure: the normalized confusion matrix plotted for the four bird classes]
