Calculating TP, TN, FN, and FP from the confusion matrix in multi-class classification

For background on the confusion matrix itself, see the previous post: Detailed understanding of the confusion matrix (Summer is Iced Black Tea's blog, CSDN).

In the previous article we learned about the confusion matrix and defined a class for it. In this section we extend that class and show how to calculate TP, TN, FN, and FP in the multi-class case.

Principle derivation

Here we take three classes as an example and look at how TP, TN, FN, and FP are distributed. Throughout this post, the rows of the confusion matrix are the ground-truth classes and the columns are the predicted classes, matching the update method of the class defined below.

[Figure: TP/FN/FP/TN positions for class 1]

[Figure: TP/FN/FP/TN positions for class 2]

[Figure: TP/FN/FP/TN positions for class 3]
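The original figures are not reproduced here, but their content follows directly from the convention above. Taking class 1 as the positive class (the other two classes work symmetrically), the cells of the 3×3 matrix are tagged like this:

         pred 1  pred 2  pred 3
true 1     TP      FN      FN
true 2     FP      TN      TN
true 3     FP      TN      TN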

From this we can see that the diagonal of the confusion matrix holds the TP of each class:

TP = torch.diag(h)

False positives (FP) are negative samples that the model incorrectly classified as positive. For class i these are the off-diagonal entries of column i, so we subtract TP from the column sums:

FP = torch.sum(h, dim=0) - TP

False negatives (FN) are positive samples that the model incorrectly classified as negative. For class i these are the off-diagonal entries of row i, so we subtract TP from the row sums:

FN = torch.sum(h, dim=1) - TP

Finally, TN is everything that remains: subtract class i's entire row and column from the total, adding TP back once since it is counted in both:

TN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)
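Putting the four formulas together, with h[i][j] counting samples whose true class is i and whose predicted class is j:

TP_i = h[i][i]
FN_i = (sum of row i) - TP_i
FP_i = (sum of column i) - TP_i
TN_i = (sum of all entries) - TP_i - FN_i - FP_i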

Logical verification

Borrowing the example from the previous article, if our confusion matrix looks like this:

tensor([[2, 0, 0],
        [0, 1, 1],
        [0, 2, 0]])

For convenience of explanation, we number the positions 0 through 8:

0 1 2
3 4 5
6 7 8

torch.sum(h, dim=1) (the row sums) gives tensor([2., 2., 2.]), and torch.sum(h, dim=0) (the column sums) gives tensor([2., 3., 1.]).

  •  TP:   tensor([2., 1., 0.])
  •  FP:   tensor([0., 2., 1.])
  •  TN:   tensor([4., 2., 3.])
  •  FN:   tensor([0., 1., 2.])

Let's check each one against the numbered positions:

  •  TP is the diagonal, positions 0, 4, 8, which hold 2, 1, 0.
  •  FP sits at positions 3 and 6 for class 1, positions 1 and 7 for class 2, and positions 2 and 5 for class 3; these sum to 0, 2, 1.
  •  TN sits at positions 4, 5, 7, 8 for class 1, the corner positions 0, 2, 6, 8 for class 2, and positions 0, 1, 3, 4 for class 3; these sum to 4, 2, 3.
  •  FN sits at positions 1 and 2 for class 1, positions 3 and 5 for class 2, and positions 6 and 7 for class 3; these sum to 0, 1, 2.
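As a runnable check, here is a minimal sketch that reproduces the four tensors above, assuming only PyTorch:

import torch

# confusion matrix from the example: rows are ground truth, columns are predictions
h = torch.tensor([[2., 0., 0.],
                  [0., 1., 1.],
                  [0., 2., 0.]])

TP = torch.diag(h)             # diagonal: correctly classified samples per class
FN = torch.sum(h, dim=1) - TP  # row sums minus TP
FP = torch.sum(h, dim=0) - TP  # column sums minus TP
TN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)

print(TP)  # tensor([2., 1., 0.])
print(FP)  # tensor([0., 2., 1.])
print(TN)  # tensor([4., 2., 3.])
print(FN)  # tensor([0., 1., 2.])

# sanity check: each class's four counts add up to the total number of samples
assert torch.all(TP + FN + FP + TN == h.sum())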

Supplementary class definition

import torch

class ConfusionMatrix:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, t, p):
        n = self.num_classes
        if self.mat is None:
            # create the confusion matrix on first use, on the same device as the labels
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=t.device)
        with torch.no_grad():
            # keep only pixels whose ground-truth label is a valid class
            k = (t >= 0) & (t < n)
            # count how often true class t[k] was predicted as class p[k]
            inds = n * t[k].to(torch.int64) + p[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        if self.mat is not None:
            self.mat.zero_()

    @property
    def ravel(self):
        """
        Compute TP, FN, FP, TN from the confusion matrix.
        """
        h = self.mat.float()
        n = self.num_classes
        if n == 2:
            # binary case: class 0 is treated as the positive class,
            # so the flattened matrix reads [TP, FN, FP, TN]
            TP, FN, FP, TN = h.flatten()
            return TP, FN, FP, TN
        if n > 2:
            TP = h.diag()
            FN = h.sum(dim=1) - TP
            FP = h.sum(dim=0) - TP
            TN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)
            return TP, FN, FP, TN

    def compute(self):
        """
        Mainly used during eval; you can call ravel to get TN, FP, FN, TP
        and compute other metrics from them.
        Computes the global accuracy (the diagonal of the confusion matrix
        holds the correctly predicted counts), the per-class accuracy, and
        the per-class IoU, where IoU = TP / (TP + FP + FN).
        """
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'mean IoU: {:.1f}').format(
            acc_global.item() * 100,
            ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
            ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
            iu.mean().item() * 100)

I added a @property decorator to ravel so it can be accessed like an attribute, and the method distinguishes between the binary and multi-class cases.
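Here is a quick usage sketch; the labels and predictions below are made-up values for a three-class problem:

import torch

# hypothetical ground-truth labels t and predictions p for 3 classes
t = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2])
p = torch.tensor([0, 0, 1, 2, 1, 1, 1, 0])

cm = ConfusionMatrix(num_classes=3)
cm.update(t, p)
print(cm.mat)
# tensor([[2, 0, 0],
#         [0, 2, 1],
#         [1, 2, 0]])

TP, FN, FP, TN = cm.ravel   # per-class tensors, accessed like an attribute
print(cm)                   # global accuracy, per-class accuracy, IoU, mean IoU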

Performance metrics

There are many introductions to these metrics online, so I won't go into detail here.

class ModelIndex:
    def __init__(self, TP, FN, FP, TN, e=1e-5):
        self.TN = TN
        self.FP = FP
        self.FN = FN
        self.TP = TP
        self.e = e  # small constant to avoid division by zero

    def Precision(self):
        """Precision measures how accurate the positive predictions are."""
        return self.TP / (self.TP + self.FP + self.e)

    def Recall(self):
        """Recall measures the model's ability to identify positive samples."""
        return self.TP / (self.TP + self.FN + self.e)

    def IOU(self):
        """IoU measures the overlap between the predicted region and the true region."""
        return self.TP / (self.TP + self.FP + self.FN + self.e)

    def F1Score(self):
        """The F1 score is the harmonic mean of precision and recall."""
        p = self.Precision()
        r = self.Recall()
        return 2 * p * r / (p + r + self.e)

    def Specificity(self):
        """Specificity measures the model's ability to identify negative samples."""
        return self.TN / (self.TN + self.FP + self.e)

    def Accuracy(self):
        """Accuracy is the ratio of correctly classified samples to all samples."""
        return (self.TP + self.TN) / (self.TP + self.TN + self.FP + self.FN + self.e)

    def FP_rate(self):
        """False positive rate: the fraction of negative samples misclassified as positive."""
        return self.FP / (self.FP + self.TN + self.e)

    def FN_rate(self):
        """False negative rate: the fraction of positive samples misclassified as negative."""
        return self.FN / (self.FN + self.TP + self.e)

    def Qualityfactor(self):
        """The quality factor (Youden's J) combines recall and specificity."""
        r = self.Recall()
        s = self.Specificity()
        return r + s - 1
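Putting the two classes together, here is a minimal sketch that reuses the cm object from the usage example above:

TP, FN, FP, TN = cm.ravel
metrics = ModelIndex(TP, FN, FP, TN)

# each call returns a tensor of per-class values; .mean() gives the macro average
print(metrics.Precision())
print(metrics.Recall())
print(metrics.F1Score().mean())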

Reference: Calculation of TP/TN/FP/FN in multi-classification (Hello_Chan's blog, CSDN).

Original post: blog.csdn.net/m0_62919535/article/details/132926719