Pytorch实现Top1准确率和Top5准确率

之前一直不清楚Top1和Top5是什么,其实搞清楚了很简单,就是两种衡量指标,其中,Top1就是普通的Accuracy,Top5比Top1衡量标准更“严格”,

具体来讲,比如一共需要分10类,每次分类器的输出结果都是10个相加为1的概率值,Top1就是这十个值中最大的那个概率值对应的分类恰好正确的频率,而Top5则是在十个概率值中从大到小排序出前五个,然后看看这前五个分类中是否存在那个正确分类,再计算频率。Pytorch实现如下:

def evaluteTop1(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
        #correct += torch.eq(pred, y).sum().item()
    return correct / total

def evaluteTop5(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x,y = x.to(device),y.to(device)
        with torch.no_grad():
            logits = model(x)
            maxk = max((1,5))
            _, pred = logits.topk(maxk, 1, True, True)
            correct += torch.eq(pred, y).sum().float().item()
    return correct / total

topk函数的具体用法参见https://blog.csdn.net/u014264373/article/details/86525621

猜你喜欢

转载自www.cnblogs.com/yqpy/p/11391972.html