pytorch十一:计算机视觉工具包:torchvision

计算机视觉是深度学习中最重要的一类应用,为了方便研究者应用,pytorch专门开发了一个视觉工具包torchvision。

可通过pip install torchvision安装。

torchvision主要包含以下三部分:

模型加载 

  • models:提供深度学习中各种经典网络结构及与训练好的模型,包括Alex-Net、VGG系列、ResNet系列、Inception系列等。
  • datasets:提供常用的数据集下载,设计上都是继承torch.utils.data.Dataset,主要包括MNIST、CIFAR10/100、ImageNet、COCO等。
  • transform:提供常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作
from torchvision import models
from torch import nn

#加载预训练模型,如果不存在会下载
#预训练的模型保存在~/.torch/models/下面
resnet34 = models.resnet34(pretrained=True,num_classes=1000)

#修改最后的全连接层为10分类问题(默认是ImageNet上的1000分类)
resnet34.fc = nn.Linear(512,10)
import torch as t
from torchvision import transforms as T

to_pil = T.ToPILImage()
to_pil(t.randn(3,128,128))
>>输出如下图所示

数据加载

from torchvision import transforms as T
transform = T.Compose(
[
    T.ToTensor(),
    T.Normalize(mean=[0.5],std=[0.5])
])

from torchvision import datasets
#指定数据集路径为data,如果数据集不存在则进行下载
#通过train = False获取测试集
dataset = datasets.MNIST('data/',download=True,train=False,transform=transform)

len(dataset)
>>10000

torchvision中还提供了两个常用的函数,一个是make_grid,它能将多张图片拼接在一个网格中;另一个是save_img,它能将Tensor保存成图片。

from torch.utils.data import DataLoader
from torchvision.utils import make_grid,save_image
from torchvision import transforms as T
to_img = T.ToPILImage()

dataloader = DataLoader(dataset,batch_size=16,shuffle=True)

dataiter = iter(dataloader)
imgs,label = (next(dataiter))
print(label)
img = make_grid(imgs,4)#拼成4*4网格图片
to_img(img)

>>tensor([9, 2, 8, 3, 5, 6, 0, 5, 6, 2, 2, 5, 6, 6, 7, 6])

to_img(imgs[4])

from PIL import Image
save_image(img,'a.png')
Image.open('a.png')

猜你喜欢

转载自blog.csdn.net/qq_24946843/article/details/89452118
今日推荐