零基础学Pytorch03_模型的调用预测

图像分类任务_模型的调用预测

predict.py

import torch
import torchvision.transforms as transforms
from PIL import Image

from LetNet import LetNet


def main():
    transform = transforms.Compose(
        [transforms.Resize((32, 32)), # 将图片强制转换成32*32的,这个跟网络结构定义有关,必须转换         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LetNet() # 实例化
    net.load_state_dict(torch.load('Lenet.pth')) # 载入权重文件

    im = Image.open('my.png')
    im = transform(im)  # [C, H, W]
    im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]

    with torch.no_grad():
        outputs = net(im)  # 将图像传到网络中
        predict = torch.max(outputs, dim=1)[1].numpy() #dim代表维度


    print(classes[int(predict)])

    with torch.no_grad():
        outputs = net(im)  # 将图像传到网络中
        predict = torch.softmax(outputs,dim=1)
    print(predict)


if __name__ == '__main__':
    main()

猜你喜欢

转载自blog.csdn.net/weixin_49321128/article/details/125453935