Pytorch读取图片

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Mundane_World/article/details/80873074

    由Tensorflow转Pytorch,慢慢开始吧
    Pytorch的torchvision虽然包含了几个数据集的读取API,例如CAFRI10,Imagenet,MNIST等,但这远远不够。实际应用中,我们需要从各种不同的数据集中读取图片。
    Pytorch自定义读取数据的方式,主要用到两个类:
torch.utils.data.Datasettorch.utils.data.DataLoader
    为了自由读取数据集中的数据(图片),必须写一个Dataset的子类,该子类中必须overrider两个Dataset中的方法:__getitem__(self, index)__len__(self)
    前者是一个通过index索引来读取数据的方法,在这个方法中,可以利用torchvision.transforms来对数据做一些预处理。需要注意的是,该方法中产生的数据就是直接传递到DataLoader中的数据,因此,数据格式一定要是Tensor!
后者返回数据集的长度。
    最后调用DataLoader来形成batch,并可以进行shuffle等操作。
具体代码如下:

import cv2
import os
import numpy as np
import torch 
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader


name = 'C:\\Users\\Administrator\\Desktop\\history\\HED-BSDS\\test1.lst'
base_path = 'C:\\Users\\Administrator\\Desktop\\history\\HED-BSDS'

#首先定义一个Dataset的子类->myDataset
class myDataset(Dataset):
    def __init__(self, name, base_path):
        f = open(name)
        self.filenames = f.readlines()
        f.close()
 #override这两个方法
    def __getitem__(self, index):
        path = self.filenames[index]
        print(os.path.join(base_path, path))
        img = cv2.imread(os.path.join(base_path, path).strip())
        img = torch.Tensor(img)
        return img

    def __len__(self):
        return len(self.filenames)

dataset = myDataset

train_loader = DataLoader(dataset(name=name, base_path=base_path), 
    batch_size=4, shuffle=True)
for img in train_loader:
    print(img.size())
    cv2.imshow('we', np.uint8(img.numpy()[0]))
    cv2.waitKey()

最初级的读取方式就是这样。
另外,torchvision.data.ImageFolder也可以方便读取已经分类好的,放在不同文件夹中的图片。

猜你喜欢

转载自blog.csdn.net/Mundane_World/article/details/80873074
今日推荐