Top-K准确率的概念与实现(源码讲解)

一. 概念

分类任务中常见的四种指标:准确率、精确率、召回率和F值。不过那什么又是Top-K准确率呢?简单一句话概括:Top-K准确率就是用来计算预测结果中概率最大的前K个结果包含正确标签的占比。换句话说,平常我们所说的准确率其实就是Top-1准确率。下面我们通过一个例子来进行说明:

假如现在有一个用于手写体识别的分类器(10分类),现将一张正确标签为3的图片输入到分类器中且得到了如下所示的一个概率分布:

p=[0.1,0.05,0.1,0.2,0.35,0.01,0.03,0.05,0.01,0.1]

显然,根据预测的结果来看,其最大概率0.35所对应的标签为4,这也就代表着如果按照以往的标准(Top-1准确率)来看,分类器对于这张图片的预测结果就是错误的。但如果我们以Top-2的标准来看的话,分类器对于这个图片的预测结果就是正确的,因为p中概率值最大的前两个中包含有真实的标签。也就是说,虽然0.35对应的标签是错的,但是排名第二的概率值0.2所对应的标签是正确的,所以我们在计算Top-2准确率的时候也将上述结果当作是预测正确的。

因此我们可以看出,Top-K准确率考虑的是预测结果中最有可能的K个结果是否包含有真实标签,如果包含则算预测正确,如果不包含则算预测错误。所以在这里我们能够知道,K值取得越大计算得到的Top-K准确率就会越高,极端情况下如果取K值为分类数,那么得到的准确率就肯定是1。但通常情况下我们只会看模型的Top-1、Top-3和Top-5准确率。

二. 实现

下面我们来看实现:

函数的输入:

output:模型的输出,即模型对不同类别的评分。shape: [batch_size, num_classes]

target:真实的类别标签。shape: [batch_size, ]

topk:需要计算top_k准确率中的k值,元组类型。默认为(1, 5),即函数返回top1和top5的分类准确率

下面我们还是先举一个例子:

import torch
output=torch.Tensor([[0.1,0.05,0.1,0.2,0.35,0.01,0.03,0.05,0.01,0.1],
        [0.2,0.05,0.1,0.35,0.2,0.01,0.02,0.04,0.01,0.0],
        [0.1,0.05,0.1,0.15,0.05,0.01,0.03,0.4,0.01,0.1],
        [0.1,0.05,0.1,0.15,0.05,0.01,0.08,0.1,0.01,0.35]])# 模型预测的概率分布
target=torch.Tensor([[4],[3],[7],[3]]) # 实际的类别索引
# output预测的值:每一行最大值对应的索引,为torch.Tensor([[4],[3],[7],[9]]) 
topk=(1,3)# 这里预测top-1和top-3
maxk = max(topk) # 按topk最大值构建张量
batch_size = target.size(0) # 这里批量数等于样本数4
_, pred = output.topk(maxk, 1, True, True) # topk返回两个张量:values和indices,分别对应前k大值的数值和索引
print(_)
print(pred) # size:batch_size*maxk=4*3

topk的输出:

tensor([[0.3500, 0.2000, 0.1000],
        [0.3500, 0.2000, 0.2000],
        [0.4000, 0.1500, 0.1000],
        [0.3500, 0.1500, 0.1000]])
tensor([[4, 3, 0],
        [3, 0, 4],
        [7, 3, 0],
        [9, 3, 0]])

pred存储了每个样本预测概率前三位的索引值。下面把target的维度改变一下进行比较:

pred = pred.t() # 转置,size:maxk*batch_size=3*4
correct = pred.eq(target.view(1, -1).expand_as(pred))
# eq输出元素相等的布尔值
# expand_as将张量扩展为pred的大小
# view()的作用相当于numpy中的reshape,重新定义矩阵的形状
print(pred) # size:maxk*batch_size=3*4
print(target.view(1, -1).expand_as(pred)) # 扩展维度和pred一样
print(correct)

correct的输出:size:max(topk)*batch_size,行数代表第几大概率,所以correct前n行就代表了前n大概率的预测情况。再看列:列中的前n行有True就表示topn预测正确。比如第4列代表的第4个样本,真实的标签是第二行,但是模型预测的标签最大是第一行,次大是第二行。那么top-1正确率就是False,预测失败,不计入top-1准确率。但是top-3正确率就是True,预测成功,计入top-3准确率。

# pred转置
tensor([[4, 3, 7, 9],
        [3, 0, 3, 3],
        [0, 4, 0, 0]])
# target转置并改变维度
tensor([[4., 3., 7., 3.],
        [4., 3., 7., 3.],
        [4., 3., 7., 3.]])
# correct输出:比较pred和target
tensor([[ True,  True,  True, False],
        [False, False, False,  True],
        [False, False, False, False]])

然后我们输出topk准确率:

res = []
for k in topk:
    correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
    res.append(correct_k.mul_(100.0 / batch_size)) # 以百分比形式输出
print(res)

输出:

[tensor([75.]), tensor([100.])]

最后打包成函数,以后使用直接复制即可:

def accuracy(output, target, topk=(1,5)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    # 根据指定值k,计算top-k准确度
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True) 
        # topk取一个tensor的topk元素(降序后的前k个大小的元素值及索引)
        # 返回两个张量:values和indices,分别对应前k大值的数值和索引
        pred = pred.t() # 转置
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        # eq输出元素相等的布尔值,expand_as将张量扩展为参数tensor的大小,view()的作用相当于numpy中的reshape,重新定义矩阵的形状。

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


 

猜你喜欢

转载自blog.csdn.net/qq_54708219/article/details/129428423