PyTorch 之测试过程

PyTorch 之测试过程：test()

import torch
import torchvision
from PIL import Image
from torch import nn

# Run a previously trained model on a single image and print its prediction.

IMG_PATH = "./image/dog.webp"
CHECKPOINT_PATH = "firstmodel_0.pth"


class FirstModel(nn.Module):
    """Small CNN: three (5x5 conv, padding 2) + 2x max-pool stages, then two linear layers.

    Expects input of shape (N, 3, 32, 32). The three 2x pools reduce 32 -> 4,
    which matches the ``64 * 4 * 4`` flatten size. Returns unnormalized scores
    of shape (N, 10).

    This class must be defined (with the same name and attribute layout) for
    ``torch.load`` to unpickle a checkpoint saved as a whole model object.
    The checkpoint is presumably saved from a script where this class lived in
    ``__main__``. TODO confirm against the training script.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        return self.model(x)


def load_image(img_path, size=(32, 32)):
    """Open an image file and return it as a float tensor of shape (3, H, W).

    The image is converted to RGB first. RGBA, palette or grayscale files
    (e.g. WebP/PNG with an alpha channel) would otherwise produce a tensor
    whose channel count is not 3. ``ToTensor`` scales pixel values to [0, 1].

    Args:
        img_path: Path to the image file.
        size: ``(height, width)`` passed to ``Resize``. Defaults to 32x32,
            which is what ``FirstModel`` expects.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size),
        torchvision.transforms.ToTensor(),
    ])
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(img_path) as img:
        print(img)
        return transform(img.convert("RGB"))


def predict(model, image):
    """Run ``model`` in eval mode without gradient tracking.

    Args:
        model: An ``nn.Module`` taking (N, 3, 32, 32) input.
        image: A single image (3, H, W) or an already-batched tensor (N, 3, H, W).

    Returns:
        The model output for the batch. A single image is returned as a batch of one.
    """
    if image.dim() == 3:
        # Add the batch dimension the conv layers require.
        image = image.unsqueeze(0)
    model.eval()
    with torch.no_grad():
        return model(image)


def main():
    image = load_image(IMG_PATH)
    print(image.shape)

    # Load the previously saved model (a whole pickled nn.Module, not a state_dict).
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # weights_only=False is required on PyTorch >= 2.6 for full-model pickles
    # (the keyword exists since PyTorch 1.13). Unpickling can execute arbitrary
    # code, so only load checkpoints you trust.
    model = torch.load(
        CHECKPOINT_PATH, map_location=torch.device("cpu"), weights_only=False
    )
    print(model)

    output = predict(model, image)
    print(output)
    # Index of the highest-scoring class for each image in the batch.
    print(output.argmax(1))


if __name__ == "__main__":
    main()

测试结果：

执行结果

猜你喜欢

转载自blog.csdn.net/weixin_49005845/article/details/125642830