【PyTorch】使用 EfficientNet 进行图像分类预测

import json
from PIL import Image
import torch
from torchvision import transforms
# Classify a single image with an ImageNet-pretrained EfficientNet-B0 and
# print the TOP_K most likely classes with their softmax probabilities.
#
# Expects two files in the working directory:
#   panda.jpg       - the image to classify
#   labels_map.txt  - JSON object mapping class-index strings ("0", "1", ...)
#                     to human-readable label names
IMAGE_PATH = 'panda.jpg'
LABELS_PATH = 'labels_map.txt'
NUM_CLASSES = 1000  # number of entries read from the labels file
TOP_K = 5

# torch.hub fetches the model code and pretrained weights (network access is
# needed on first run; later runs use the local hub cache).
model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True)
# Switch to inference mode (affects layers such as dropout / batch-norm).
model.eval()

# convert('RGB') normalises grayscale / RGBA / palette / CMYK images to the
# 3 channels expected by Normalize below; the context manager closes the file.
with Image.open(IMAGE_PATH) as src:
    img = src.convert('RGB')

# Read the index -> label mapping and close the file afterwards.
with open(LABELS_PATH, encoding='utf-8') as f:
    labels_map = json.load(f)
labels_map = [labels_map[str(i)] for i in range(NUM_CLASSES)]

# Standard ImageNet evaluation preprocessing: resize the short side to 256,
# centre-crop to 224x224, then normalise with the ImageNet channel mean/std.
tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img = tfms(img).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)

with torch.no_grad():
    logits = model(img)

# Compute the softmax once; it is monotonic, so top-k over probabilities
# selects the same classes (in the same order) as top-k over logits.
probs = torch.softmax(logits, dim=1)[0]
top_probs, top_idxs = torch.topk(probs, k=TOP_K)
preds = top_idxs.tolist()
for idx, prob in zip(preds, top_probs.tolist()):
    label = labels_map[idx]
    print('{:<75} ({:.2f}%)'.format(label, prob * 100))

发布了 437 篇原创文章 · 获赞 590 · 访问量 61万+

猜你喜欢

转载自：https://blog.csdn.net/heiheiya/article/details/103009126