Python se genera en base a su propio conjunto de datos modelo y matriz de confusión de peso

Confusion Matrix es un formato tabular utilizado para evaluar el rendimiento de los modelos de clasificación. Muestra varias combinaciones entre las predicciones del modelo y las etiquetas reales en un problema de clasificación.

Las matrices de confusión se usan comúnmente para problemas de clasificación binaria, pero también se pueden extender a problemas de clasificación de clases múltiples. Para problemas de clasificación binaria, consta de cuatro indicadores importantes:

Verdadero Positivo (TP): El número de casos positivos predichos por el modelo y realmente positivos.
True Negative (TN): el número de ejemplos negativos predichos por el modelo y realmente negativos.
Falso positivo (FP): el número de ejemplos positivos predichos por el modelo, pero en realidad ejemplos negativos. También conocido como "falso positivo".
Falso Negativo (FN): El número de ejemplos negativos predichos por el modelo pero realmente positivos. También conocido como "falso negativo".

La forma general de una matriz de confusión es la siguiente:
inserte la descripción de la imagen aquí

La matriz de confusión se puede utilizar para calcular múltiples indicadores para medir el rendimiento del clasificador, como exactitud (Exactitud), precisión (Precisión), recuperación (Recuperación, también conocida como sensibilidad o tasa de casos reales) y valor F1. Estas métricas se pueden calcular a partir de los elementos individuales en la matriz de confusión:

Precisión (Precisión): La proporción de muestras correctas predichas por el clasificador al número total de muestras, la fórmula de cálculo es (TP + TN) / (TP + TN + FP + FN).
Precisión (Precisión): La proporción de predicciones positivas que son correctas, la fórmula de cálculo es TP / (TP + FP).
Recall (Recall): La proporción de ejemplos positivos que se predicen correctamente como ejemplos positivos, la fórmula de cálculo es TP/(TP+FN).
Valor F1: un indicador que tiene en cuenta la tasa de precisión y la tasa de recuperación, y la fórmula de cálculo es 2 (Precisión de recuperación) / (Precisión + recuperación).

Las matrices de confusión brindan la capacidad de evaluar el rendimiento del modelo de clasificación de manera más detallada y completa, lo que nos ayuda a comprender los falsos positivos y los falsos negativos en las predicciones. Al analizar la matriz de confusión, podemos obtener información valiosa sobre el rendimiento del clasificador en cada clase y optimizar los resultados de la clasificación.

No hay muchas tonterías, el código anterior:

def draw_confusion_matrix(label_true, label_pred, label_name, normlize, title="Confusion Matrix", pdf_save_path=None, dpi=100):
    """

    @param label_true: 真实标签,比如[0,1,2,7,4,5,...]
    @param label_pred: 预测标签,比如[0,5,4,2,1,4,...]
    @param label_name: 标签名字,比如['cat','dog','flower',...]
    @param normlize: 是否设元素为百分比形式
    @param title: 图标题
    @param pdf_save_path: 是否保存,是则为保存路径pdf_save_path=xxx.png | xxx.pdf | ...等其他plt.savefig支持的保存格式
    @param dpi: 保存到文件的分辨率,论文一般要求至少300dpi
    @return:

    example:
            draw_confusion_matrix(label_true=y_gt,
                          label_pred=y_pred,
                          label_name=["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"],
                          normlize=True,
                          title="Confusion Matrix on Fer2013",
                          pdf_save_path="Confusion_Matrix_on_Fer2013.png",
                          dpi=300)

    """
    cm1=confusion_matrix(label_true, label_pred)
    cm = confusion_matrix(label_true, label_pred)
    if normlize:
        row_sums = np.sum(cm, axis=1)
        cm = cm / row_sums[:, np.newaxis]
    cm=cm.T
    cm1=cm1.T
    plt.imshow(cm, cmap='Blues')
    plt.title(title)
    plt.xlabel("Predict label")
    plt.ylabel("Truth label")
    plt.yticks(range(label_name.__len__()), label_name)
    plt.xticks(range(label_name.__len__()), label_name, rotation=45)

    plt.tight_layout()

    plt.colorbar()

    for i in range(label_name.__len__()):
        for j in range(label_name.__len__()):
            color = (1, 1, 1) if i == j else (0, 0, 0)	# 对角线字体白色,其他黑色
            value = float(format('%.1f' % (cm[i, j]*100)))
            value1=str(value)+'%\n'+str(cm1[i, j])
            plt.text(i, j, value1, verticalalignment='center', horizontalalignment='center', color=color)

    # plt.show()
    if not pdf_save_path is None:
        plt.savefig(pdf_save_path, bbox_inches='tight',dpi=dpi)



labels_name = ['bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

y_gt=[]
y_pred=[]

model_weight_path = "./best_CBAM_model.pth"
models = Xception(num_classes = 4)
models.load_state_dict(torch.load(model_weight_path))




models.eval()
for index, (imgs, labels) in enumerate(test_dl):
    labels_pd = models(imgs)
    predict_np = np.argmax(labels_pd.cpu().detach().numpy(), axis=-1).tolist()
    labels_np = labels.numpy().tolist()

    y_pred.extend(predict_np)
    y_gt.extend(labels_np)
print("预测标签为:", y_pred)
print("真实标签为", y_gt)



draw_confusion_matrix(label_true=y_gt,
                      label_pred=y_pred,
                      label_name=labels_name,
                      normlize=True,
                      title="Confusion Matrix",
                      pdf_save_path="Confusion_Matrix.jpg",
                      dpi=300)

El resultado es el siguiente:
inserte la descripción de la imagen aquí

Supongo que te gusta

Origin blog.csdn.net/m0_63007797/article/details/132161666
Recomendado
Clasificación