# torch之测试过程 test() — PyTorch model test / inference walkthrough
import torch
import torchvision
from PIL import Image
from torch import nn
# Network definition. The checkpoint below stores a whole pickled nn.Module,
# so unpickling must be able to resolve this class by name at load time.
# Assumed to match the class used when "firstmodel_0.pth" was saved — TODO confirm.
class FirstModel(nn.Module):
    """Small CNN: three conv + max-pool stages, then two linear layers.

    Each Conv2d(k=5, stride=1, padding=2) keeps the spatial size and each
    MaxPool2d(2) halves it, so a 32x32 input reaches the Flatten layer as
    64 channels x 4 x 4 — hence Linear(64*4*4, 64). The final layer emits
    10 scores per sample.
    """

    def __init__(self):
        super(FirstModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        """Run the input batch through the sequential stack and return the scores."""
        x = self.model(x)
        return x


# Image to classify (relative path: run from the directory containing ./image).
img_path = "./image/dog.webp"

# Open inside a context manager so the file handle is closed. Force RGB because
# WebP/PNG files may decode as RGBA or palette mode, which would give ToTensor()
# a channel count other than the 3 the first Conv2d expects.
with Image.open(img_path) as img:
    print(img)
    image = img.convert("RGB")

# Resize to the 32x32 input size the network was designed for, then convert
# the PIL image to a tensor of shape (C, H, W).
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transform(image)
print(image.shape)

# Load the previously saved model onto the CPU.
# weights_only=False is required on PyTorch >= 2.6 (whose default refuses to
# unpickle arbitrary classes); the keyword exists since PyTorch 1.13.
# SECURITY: this unpickles the file (can execute code) — only load trusted checkpoints.
model = torch.load(
    "firstmodel_0.pth",
    map_location=torch.device("cpu"),
    weights_only=False,
)
print(model)

# Add a leading batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
image = image.unsqueeze(0)

# Inference mode: eval() switches train/eval-dependent layers; no_grad() skips
# building the autograd graph since no backward pass is needed.
model.eval()
with torch.no_grad():
    output = model(image)
print(output)
# Index of the highest-scoring output for each sample in the batch
# (presumably the predicted class id — label mapping not shown here).
print(output.argmax(1))