图像分类任务_模型的调用预测
predict.py
import torch
import torchvision.transforms as transforms
from PIL import Image
from LetNet import LetNet
def main():
transform = transforms.Compose(
[transforms.Resize((32, 32)), # 将图片强制转换成32*32的,这个跟网络结构定义有关,必须转换 transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
net = LetNet() # 实例化
net.load_state_dict(torch.load('Lenet.pth')) # 载入权重文件
im = Image.open('my.png')
im = transform(im) # [C, H, W]
im = torch.unsqueeze(im, dim=0) # [N, C, H, W]
with torch.no_grad():
outputs = net(im) # 将图像传到网络中
predict = torch.max(outputs, dim=1)[1].numpy() #dim代表维度
print(classes[int(predict)])
with torch.no_grad():
outputs = net(im) # 将图像传到网络中
predict = torch.softmax(outputs,dim=1)
print(predict)
if __name__ == '__main__':
main()