ROC曲线、PR曲线(笔记整理)

简介:

ROC曲线和PR曲线是机器学习中两个常见的评估指标(对于二分器而言),做个笔记…

原理:

在二分类问题中,分类器将一个实例的分类标记为是或否,这可以用一个混淆矩阵来表示。混淆矩阵有四个分类,如下表:
在这里插入图片描述
TP(True Positive):指正确分类的正样本数,即预测为正样本,实际也是正样本。
FP(False Positive):指被错误的标记为正样本的负样本数,即实际为负样本而被预测为正样本,所以是False。
TN(True Negative):指正确分类的负样本数,即预测为负样本,实际也是负样本。
FN(False Negative):指被错误的标记为负样本的正样本数,即实际为正样本而被预测为负样本,所以是False。

Precision=TP/(TP+FP) —— (正确分类的正样本)/ (预测为正的总样本)
Recall=TP/(TP+FN) —— (正确分类的正样本)/ (实际为正的总样本)
TPR=TP/(TP+FN)=Recall # 真正例率
FPR=FP/(TN+FP) #假正例率

  • ROC曲线常用于二分类问题中的模型比较,主要表现为一种真正例率 (TPR) 和假正例率 (FPR) 的权衡。具体方法是在不同的分类阈值 (threshold) 设定下分别以TPR和FPR为纵、横轴作图。曲线越靠近左上角,意味着越多的正例优先于负例,模型的整体表现也就越好。AUC:ROC曲线下面积,越大越好。 在这里插入图片描述
  • PR曲线实则是以precision(精准率)和recall(召回率)这两个为变量而做出的曲线,其中recall为横坐标,precision为纵坐标。一条PR曲线要对应一个阈值。通过选择合适的阈值,比如50%,对样本进行划分,概率大于50%的就认为是正例,小于50%的就是负例,从而计算相应的精准率和召回率。如果一个学习器的P-R曲线被另一个学习器的P-R曲线完全包住,则可断言后者的性能优于前者。我们还可以根据曲线下方的面积大小来进行比较,但更常用的是平衡点或者是F1值。平衡点(BEP)是P=R时的取值,如果这个值较大,则说明学习器的性能较好。而F1=2×P×R/(P+R),同样,F1值越大,我们可以认为该学习器的性能较好。在这里插入图片描述

使用场景:

1、 ROC曲线由于兼顾正例与负例,所以适用于评估分类器的整体性能,相比而言PR曲线完全聚焦于正例。
2、如果有多份数据且存在不同的类别分布,比如信用卡欺诈问题中每个月正例和负例的比例可能都不相同,这时候如果只想单纯地比较分类器的性能且剔除类别分布改变的影响,则ROC曲线比较适合,因为类别分布改变可能使得PR曲线发生变化时好时坏,这种时候难以进行模型比较;反之,如果想测试不同类别分布下对分类器的性能的影响,则PR曲线比较适合。
3、如果想要评估在相同的类别分布下正例的预测情况,则宜选PR曲线。
4、类别不平衡问题中,ROC曲线通常会给出一个乐观的效果估计,所以大部分时候还是PR曲线更好。
5、最后,可以根据具体的应用,在曲线上找到最优的点,得到相对应的precision,recall,f1 score等指标,去调整模型的阈值,从而得到一个符合具体应用的模型

python代码:

matplotlib>=2.0.2
numpy>=1.13.0
opencv-python>=3.3.1+contrib
tqdm>=4.19.4

# -*- coding:utf-8 -*-
import glob
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt


class CollectData:
    def __init__(self):
        self.TP = []
        self.FP = []
        self.FN = []
        self.TN = []

    def reload(self,groundtruth,probgraph):
        """
        :param groundtruth:  list,groundtruth image list
        :param probgraph:    list,prob image list
        :return:  None
        """
        self.groundtruth = groundtruth
        self.probgraph = probgraph
        self.TP = []
        self.FP = []
        self.FN = []
        self.TN = []

    def statistics(self):
        """
        calculate FPR TPR Precision Recall IoU
        :return: (FPR,TPR,AUC),(Precision,Recall,MAP),IoU
        """
        for threshold in tqdm(range(0,255)):
            temp_TP=0.0
            temp_FP=0.0
            temp_FN=0.0
            temp_TN=0.0
            assert(len(self.groundtruth)==len(self.probgraph))

            for index in range(len(self.groundtruth)):
                gt_img=cv2.imread(self.groundtruth[index])[:,:,0]
                prob_img=cv2.imread(self.probgraph[index])[:,:,0]

                gt_img=(gt_img>0)*1
                prob_img=(prob_img>=threshold)*1

                temp_TP = temp_TP + (np.sum(prob_img * gt_img))
                temp_FP = temp_FP + np.sum(prob_img * ((1 - gt_img)))
                temp_FN = temp_FN + np.sum(((1 - prob_img)) * ((gt_img)))
                temp_TN = temp_TN + np.sum(((1 - prob_img)) * (1 - gt_img))

            self.TP.append(temp_TP)
            self.FP.append(temp_FP)
            self.FN.append(temp_FN)
            self.TN.append(temp_TN)

        self.TP = np.asarray(self.TP).astype('float32')
        self.FP = np.asarray(self.FP).astype('float32')
        self.FN = np.asarray(self.FN).astype('float32')
        self.TN = np.asarray(self.TN).astype('float32')

        FPR = (self.FP) / (self.FP + self.TN)
        TPR = (self.TP) / (self.TP + self.FN)
        AUC = np.round(np.sum((TPR[1:] + TPR[:-1]) * (FPR[:-1] - FPR[1:])) / 2., 4)

        Precision = self.TP / (self.TP + self.FP)
        Recall = self.TP / (self.TP + self.FN)
        MAP = np.round(np.sum((Precision[1:] + Precision[:-1]) * (Recall[:-1] - Recall[1:])) / 2.,4)

        iou=self.IOU()

        return (FPR,TPR,AUC),(Precision,Recall,MAP),iou

    def IoU(self,threshold=128):
        """
        to calculate IoU
        :param threshold: numerical,a threshold for gray image to binary image
        :return:  IoU
        """
        intersection=0.0
        union=0.0

        for index in range(len(self.groundtruth)):
            gt_img = cv2.imread(self.groundtruth[index])[:, :, 0]
            prob_img = cv2.imread(self.probgraph[index])[:, :, 0]

            gt_img = (gt_img > 0) * 1
            prob_img = (prob_img >= threshold) * 1

            intersection=intersection+np.sum(gt_img*prob_img)
            union=union+np.sum(gt_img)+np.sum(prob_img)-np.sum(gt_img*prob_img)
        iou=np.round(intersection/union,4)
        return iou

    def debug(self):
        """
        show debug info
        :return: None
        """
        print("Now enter debug mode....\nPlease check the info bellow:")
        print("total groundtruth: %d   total probgraph: %d\n"%(len(self.groundtruth),len(self.probgraph)))
        for index in range(len(self.groundtruth)):
            print(self.groundtruth[index],self.probgraph[index])
        print("Please confirm the groundtruth and probgraph name is opposite")


class DrawCurve:
    """
    draw ROC/PR curve
    """
    def __init__(self,savepath):
        self.savepath=savepath
        self.colorbar=['red','green','blue','black']
        self.linestyle=['-','-.','--',':','-*']

    def reload(self,xdata,ydata,auc,dataName,modelName):
        """
        this function is to update data for Function roc/pr to draw
        :param xdata:  list,x-coord of roc(pr)
        :param ydata:  list,y-coord of roc(pr)
        :param auc:    numerical,area under curve
        :param dataName: string,name of dataset
        :param modelName: string,name of test model
        :return:  None
        """
        self.xdata.append(xdata)
        self.ydata.append(ydata)
        self.modelName.append(modelName)
        self.auc.append(auc)
        self.dataName=dataName

    def newly(self,modelnum):
        """
        renew all the data
        :param modelnum:  numerical,number of models to draw
        :return:  None
        """
        self.modelnum = modelnum
        self.xdata = []
        self.ydata = []
        self.modelName = []
        self.auc = []

    def roc(self):
        """
        draw ROC curve,save the curve graph to  savepath
        :return: None
        """
        plt.figure(1)
        plt.title('ROC Curve of %s'%self.dataName, fontsize=15)
        plt.xlabel("False Positive Rate", fontsize=15)
        plt.ylabel("True Positive Rate", fontsize=15)
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        for i in range(self.modelnum):
            plt.plot(self.xdata[i], self.ydata[i], color=self.colorbar[i%len(self.colorbar)], linewidth=2.0, linestyle=self.linestyle[i%len(self.linestyle)], label=self.modelName[i]+',AUC:' + str(self.auc[i]))
        plt.legend()
        plt.savefig(self.savepath+'%s_ROC.png'%self.dataName, dpi=800)
        #plt.show()


    def pr(self):
        """
        draw PR curve,save the curve to  savepath
        :return: None
        """
        plt.figure(2)
        plt.title('PR Curve of %s'%self.dataName, fontsize=15)
        plt.xlabel("Recall", fontsize=15)
        plt.ylabel("Precision", fontsize=15)
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        for i in range(self.modelnum):
            plt.plot(self.xdata[i], self.ydata[i], color=self.colorbar[i%len(self.colorbar)], linewidth=2.0, linestyle=self.linestyle[i%len(self.linestyle)],label=self.modelName[i]+',MAP:' + str(self.auc[i]))
        plt.legend()
        plt.savefig(self.savepath+'%s_PR.png'%self.dataName, dpi=800)
        #plt.show()


def fileList(imgpath,filetype):
    return glob.glob(imgpath+filetype)


def drawCurve(gtlist,problist,modelName,dataset,savepath='./'):
    """
    draw ROC PR curve,calculate AUC MAP IoU
    :param gtlist:  list,groundtruth list
    :param problist: list,list of probgraph list
    :param modelName:  list,name of test,model
    :param dataset: string,name of dataset
    :param savepath: string,path to save curve
    :return:
    """
    assert(len(problist)==len(modelName))

    process = CollectData()
    painter_roc = DrawCurve(savepath)
    painter_pr = DrawCurve(savepath)
    modelNum=len(problist)
    painter_roc.newly(modelNum)
    painter_pr.newly(modelNum)

    # calculate param
    for index in range(modelNum):
        process.reload(gtlist,problist[index])
        (FPR, TPR, AUC), (Precision, Recall, MAP),IoU = process.statistics()
        painter_roc.reload(FPR, TPR, AUC,dataset, modelName[index])
        painter_pr.reload(Precision, Recall, MAP, dataset, modelName[index])

    # draw curve and save
    painter_roc.roc()
    painter_pr.pr()

if __name__=="__main__":
    gtlist = fileList('./gt/', '*.png')
    problist1 = fileList('./pre1/', '*.png')
    problist2 = fileList('./pre2/', '*.png')
    modelName=["fcn","unet"]

    drawCurve(gtlist,[problist1,problist2],modelName,'kaggle')
print('--------------------------------------')

猜你喜欢

转载自blog.csdn.net/qq_42823043/article/details/109853058
今日推荐