版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Mundane_World/article/details/80873074
由Tensorflow转Pytorch,慢慢开始吧
Pytorch的torchvision虽然包含了几个数据集的读取API,例如CAFRI10,Imagenet,MNIST等,但这远远不够。实际应用中,我们需要从各种不同的数据集中读取图片。
Pytorch自定义读取数据的方式,主要用到两个类:
torch.utils.data.Dataset
和torch.utils.data.DataLoader
为了自由读取数据集中的数据(图片),必须写一个Dataset的子类,该子类中必须overrider两个Dataset中的方法:__getitem__(self, index)
和__len__(self)
。
前者是一个通过index索引来读取数据的方法,在这个方法中,可以利用torchvision.transforms
来对数据做一些预处理。需要注意的是,该方法中产生的数据就是直接传递到DataLoader中的数据,因此,数据格式一定要是Tensor!
后者返回数据集的长度。
最后调用DataLoader来形成batch,并可以进行shuffle等操作。
具体代码如下:
import cv2
import os
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
name = 'C:\\Users\\Administrator\\Desktop\\history\\HED-BSDS\\test1.lst'
base_path = 'C:\\Users\\Administrator\\Desktop\\history\\HED-BSDS'
#首先定义一个Dataset的子类->myDataset
class myDataset(Dataset):
def __init__(self, name, base_path):
f = open(name)
self.filenames = f.readlines()
f.close()
#override这两个方法
def __getitem__(self, index):
path = self.filenames[index]
print(os.path.join(base_path, path))
img = cv2.imread(os.path.join(base_path, path).strip())
img = torch.Tensor(img)
return img
def __len__(self):
return len(self.filenames)
dataset = myDataset
train_loader = DataLoader(dataset(name=name, base_path=base_path),
batch_size=4, shuffle=True)
for img in train_loader:
print(img.size())
cv2.imshow('we', np.uint8(img.numpy()[0]))
cv2.waitKey()
最初级的读取方式就是这样。
另外,torchvision.data.ImageFolder
也可以方便读取已经分类好的,放在不同文件夹中的图片。