argmax对softmax的输出返回概率最大的类别

假设数据如下,每行是softmax输出得到的概率,我需要找到最大的概率返回类别,可以使用argmax函数

(1)注意使用argmax函数时,需要将数据转换为tensor类型,否则报错
argmax(): argument 'input' (position 1) must be Tensor, not numpy.ndarray
(2)torch.argmax函数需要传递dim参数,dim=1就是在行上求
index = torch.argmax(data_pre, dim=1)
import numpy as np
import pandas as pd
import torch


data_pre = np.loadtxt('./pred.txt')
data_pre = torch.tensor(data_pre)
index = torch.argmax(data_pre, dim=1)
index = np.array(index)

np.savetxt('./cluster.txt', (index))
 

猜你喜欢

转载自blog.csdn.net/ziqingnian/article/details/111272798