pytorch 多分类问题,计算百分比

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
                    [2,3],
                    [1,0],
                    [3,4]])
cond = torch.Tensor([[1,0],
                     [0,1],
                     [1,0],
                     [1,0]])

persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
       [2, 3],
       [1, 0],
       [3, 4]]
cond = [[1, 0],
        [0, 1],
        [1, 0],
        [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

猜你喜欢

转载自blog.csdn.net/luolinll1212/article/details/83897047