PytorchCNN项目搭建 7--- 深度学习模型评估指标 Accuracy,Precision,Recall,ROC曲线

以二分类为例,进行说明:

在这里插入图片描述

注:

  • 判别是否为正例只需要设一个概率阈值T,预测概率大于阈值T的为正类,小于阈值T的为负类,默认就是0.5。
  • 如果减小阀值T,更多的样本会被识别为正类,这样可以提高正类的召回率,但同时也会带来更多的负类被错分为正类;
  • 如果增加阈值T,则正类的召回率降低,精度增加。如果是多类,比如ImageNet1000分类比赛中的1000类,预测类别就是预测概率最大的那一类。

常用的几种评估指标:

1. 准确度: Accuracy = (TP + TN) / (TP + FN + FP + TN)

注:Top_1 Accuracy和Top_5 Accuracy,Top_1 Accuracy就是计算的Accuracy。而Top_5 Accuracy是给出概率最大的5个预测类别,只要包含了真实的类别,则判定预测正确。

2. 精确度:Precision = TP / (TP + FP)

3. 召回率:Recall = TP / (TP + FN)


多分类情况

4. 混淆矩阵

如果对于每一类,若想知道类别之间相互误分的情况,查看是否有特定的类别之间相互混淆,就可以用混淆矩阵画出分类的详细预测结果。对于包含多个类别的任务,混淆矩阵很清晰的反映出各类别之间的错分概率,如下图:

注: 横坐标表示预测分类,纵坐标表示标签分类,其中(i,j)表示第i类目标被分为第j类的概率,对角线的值越大越好。

在这里插入图片描述


代码实现:

实际图片的标签值labels, 预测的分类值predicted,二者转换为具体的标签值(而不是onehot值),矩阵具体的大小为:[number_total_pictures, 1], 然后将二者拼接为[number_total_pictures*2, 1], 得到的这个矩阵的每一行都代表一个(i,j)值,然后进行统计即可。代码如下:

def Confusion_mxtrix(labels, predicted, num_classes):
    """
    混淆矩阵的函数定义
    Args:
        labels: [number_total_pictures,1]
        predicted: [number_total_pictures,1] 
        num_classes: 分类数目

    Returns: Confusion_matrix
    """
    Cmatrixs = torch.zeros((num_classes,num_classes))
    stacked = torch.stack((labels, predicted), dim=1)
    for s in stacked:
        a, b = s.tolist()
        Cmatrixs[a, b] = Cmatrixs[a, b] + 1
    return Cmatrixs
def plot_confusion_matrix(cm, savename, title='Confusion Matrix'):
    classes = ('airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    plt.figure(figsize=(12, 8), dpi=100)
    np.set_printoptions(precision=2)

    # 在混淆矩阵中每格的概率值
    ind_array = np.arange(len(classes))
    x, y = np.meshgrid(ind_array, ind_array)
    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm[y_val][x_val]
        if c > 0.001:
            plt.text(x_val, y_val, "%0.2f" % (c,), color='red', fontsize=15, va='center', ha='center')

    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
    plt.title(title)
    plt.colorbar()
    xlocations = np.array(range(len(classes)))
    plt.xticks(xlocations, classes, rotation=90)
    plt.yticks(xlocations, classes)
    plt.ylabel('Actual label')
    plt.xlabel('Predict label')

    # offset the tick
    tick_marks = np.array(range(len(classes))) + 0.5
    plt.gca().set_xticks(tick_marks, minor=True)
    plt.gca().set_yticks(tick_marks, minor=True)
    plt.gca().xaxis.set_ticks_position('none')
    plt.gca().yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.15)

    # show confusion matrix
    plt.savefig(savename, format='png')
    plt.show()

根据混淆矩阵Cm计算多分类的Accuracy,Precision & Recall

混淆矩阵Cm的对角线的值就是每一类的TP值,而FN=sum_row(Cm)-TP, FP=sum_col(Cm)-TP, TN=sum(Cm)-TP-FN-FP

def Evaluate(Cmatrixs):
    """for Precision & Recall"""
    classes = ('airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    n_classes = Cmatrixs.size(0)
    Prec, Rec = torch.zeros(n_classes+1), torch.zeros(n_classes+1)

    sum_cmt_row = torch.sum(Cmatrixs,dim=1)#行的和
    sum_cmt_col = torch.sum(Cmatrixs,dim=0)#列的和
    print("----------------------------------------")
    for i in range(n_classes):
        TP = Cmatrixs[i,i]
        FN = sum_cmt_row[i] - TP
        FP = sum_cmt_col[i] - TP
        # TN = torch.sum(Cmatrixs) - sum_cmt_row[i] - FP
        Prec[i] = TP / (TP + FP)
        Rec[i]  = TP / (TP + FN)
        print("%s"%(classes[i]).ljust(10," "),"Presion=%.3f%%,     Recall=%.3f%%"%(Prec[i],Rec[i]))

    Prec[-1] = torch.mean(Prec[0:-1])
    Rec[-1] = torch.mean(Rec[0:-1])
    print("ALL".ljust(10," "),"Presion=%.3f%%,     Recall=%.3f%%" % (Prec[i], Rec[i]))
    print("----------------------------------------")
    # return Prec,Rec

5. ROC曲线

  • Receiver Operating Characteristic (ROC)曲线,评价一个分类器在不同阈值T下的表现情况。曲线横坐标False Positive rate(FPR), 纵坐标 True positive rate(TPR) ,描述True positive和False Positive之间的平衡。所以,绘制ROC曲线的一个重要的事情是要自己定义阈值T。

  • TPR = TP / (TP + FN) 分类器预测正类中实际正实例占所有正实例的比例

  • FPR = FP / (FP + TN) 分类器预测的正类中实际负实类占所有负实例的比例。

  • ROC曲线有4个关键的点:

    • 点(0,0):FPR=TPR=0,分类器预测所有的样本都为负样本;
    • 点(1,1):FPR=TPR=1,分类器预测所有的样本都为正样本;
    • 点(0,1):FPR=0, TPR=1,此时FN=0且FP=0,所有的样本都正确分类;
    • 点(1,0):FPR=1,TPR=0,此时TP=0且TN=0,最差分类器,避开了所有正确答案
def MyROC_i(outputs, labels, n=20):
    '''
    ROC曲线计算 绘制每一类的
    Args:
        outputs: [num_labels,num_classes]
        labels: 标签值
        n: 得到 n 个点之后绘图
    Returns:plot_roc
    '''

    n_total, n_classes = outputs.size()
    labels = labels.reshape(-1,1) # 行向量转为列向量
    T = torch.linspace(0, 1, n)
    TPR, FPR = torch.zeros(n, n_classes+1), torch.zeros(n, n_classes+1)

    for i in range(n_classes):
        for j in range(n):
            mask_1 = outputs[:, i] > T[j]
            TP_FP = torch.sum(mask_1)
            mask_2 = (labels[:, -1] == i)
            TP = torch.sum(mask_1 & mask_2)
            FN = n_total / n_classes - TP
            FP = TP_FP - TP
            TN = n_total - n_total / n_classes - FP

            TPR[j,i] = TP / (TP + FN)
            FPR[j,i] = FP / (FP + TN)

    TPR[:,-1] = torch.mean(TPR[:,0:-1],dim=1)
    FPR[:, -1] = torch.mean(FPR[:, 0:-1], dim=1)

    return TPR,FPR

def Plot_ROC_i(TPR, FPR, args, cfg):
    for i in range(10+1):
        if i==10: width=2
        else: width=1
        plt.plot(FPR[:,i],TPR[:,i],linewidth=width,label='classes_%d'%i)
    plt.legend()
    plt.title("ROC")
    plt.grid(True)
    plt.xlim(0,1)
    plt.ylim(0,1)
    plt.savefig(cfg.PARA.utils_paths.visual_path + args.net + '_ROC_i.png')

在这里插入图片描述


6. 主要的测试函数,输出所有的评估值

def test(net, epoch, test_loader, log, args, cfg):
    with torch.no_grad():
        labels_value, predicted_value, outputs_value = [],[],[]
        correct = 0
        total = 0
        net.eval()
        for i, data in enumerate(test_loader, 0):
            images, labels = data
            images = images.cuda()
            labels_onehot = labels.cuda()
            _, labels = torch.max(labels_onehot, 1)
            outputs = net(images) #outputs:[100,10]

            _, predicted = torch.max(outputs.data, 1)
            # predicted = ToOnehots(predicted,cfg.PARA.train.num_classes)
            total += labels.size(0)
            correct += (predicted == labels).sum()#.item()

            # Ready for matrixs
            if i==0:
                labels_value = labels
                predicted_value = predicted
                outputs_value = F.softmax(outputs.data,dim=1)
            else:
                labels_value = torch.cat((labels_value,labels),0)
                predicted_value = torch.cat((predicted_value,predicted),0)
                outputs_value = torch.cat((outputs_value,F.softmax(outputs.data,dim=1)),0)

        correct = correct.cpu().numpy()

        log.logger.info('epoch=%d,acc=%.5f%%' % (epoch, 100 * correct / total))
        f = open("./cache/visual/"+args.net+"_test.txt", "a")
        f.write("epoch=%d,acc=%.5f%%" % (epoch, 100 * correct / total))
        f.write('\n')

        log.logger.info("==> Get Confusion_Matrixs <==")
        Cmatrixs = Confusion_mxtrix(labels_value,predicted_value,cfg.PARA.train.num_classes)
        # print(Cmatrixs)

        log.logger.info("==> Precision & Recall <==")
        Evaluate(Cmatrixs) #get_Precision & Recall

        log.logger.info("==> Plot_ROC <==")
        TPR_i, FPR_i = MyROC_i(outputs_value, labels_value)
        Plot_ROC_i(TPR_i, FPR_i,args,cfg)

    f.close()

def main():
    args = parser()
    cfg = Config.fromfile(args.config)
    log = Logger('./cache/log/' + args.net + '_testlog.txt', level='info')
    log.logger.info('==> Preparing data <==')
    test_loader = dataLoad(cfg)
    log.logger.info('==> Loading model <==')
    net = get_network(args,cfg).cuda()
    # net = torch.nn.DataParallel(net, device_ids=cfg.PARA.train.device_ids)
    log.logger.info("==> Waiting Test <==")
    for epoch in range(100, 101):
        # log.logger.info("==> Epoch:%d <=="%epoch)
        checkpoint = torch.load('./cache/checkpoint/'+args.net+'/'+ str(epoch) +'ckpt.pth')
        # checkpoint = torch.load('./cache/checkpoint/' + args.net + '/' + str(60) + 'ckpt.pth')
        net.load_state_dict(checkpoint['net'])
        test(net, epoch, test_loader, log, args, cfg)

    log.logger.info('*'*25)

if __name__ == '__main__':
    main()

参考文献

  1. 深度学习模型评估指标

猜你喜欢

转载自blog.csdn.net/qq_44783177/article/details/113895012