PyTorch数据读取

torch.utils.data.DataLoader

torch.utils.data.DataLoader(dataset, batch_size, shuffle, num_workers, pin_memory),其中 dataset 是 torch.utils.data.Dataset 子类的实例

关键是这两个类:
torch.utils.data.DataLoader
torch.utils.data.Dataset

import torchvision.transforms as transforms

# Dataset of the training images named in opt.train_list, resolved under
# opt.root_path. ToTensor converts every loaded image into a Tensor -- this
# step is essential.
train_dataset = ImageList(
    root=opt.root_path,
    fileList=opt.train_list,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Loader that serves shuffled batches of the training dataset.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.workers,
    pin_memory=True,
)

写一个类作为数据读取器,继承 torch.utils.data.Dataset,并实现 __getitem__ 和 __len__ 两个方法

#load_imglist.py
import torch.utils.data 

from PIL import Image
import os



def default_list_reader(fileList):
    """Parse an image list file into ``(image_path, label)`` pairs.

    Each non-blank line holds an image path followed by an integer label,
    separated by whitespace. The label is taken to be the last
    whitespace-separated field, so paths that contain spaces and lines that
    use several separator characters (spaces or tabs) are accepted. Blank
    lines are skipped.

    Args:
        fileList: Path of the text file to read.

    Returns:
        A list of ``(imgPath, label)`` tuples in file order, with ``label``
        converted to ``int``.

    Raises:
        ValueError: If a non-blank line has no label field, or the label is
            not a valid integer.
    """
    imgList = []
    with open(fileList, 'r') as file:
        # Iterate the file object directly instead of readlines() so the
        # whole file is never held in memory at once.
        for lineno, line in enumerate(file, start=1):
            line = line.strip()
            if not line:
                # Tolerate empty lines (e.g. a trailing newline at EOF).
                continue
            # Split from the right so only the final field is the label.
            fields = line.rsplit(None, 1)
            if len(fields) != 2:
                raise ValueError(
                    "%s:%d: expected '<path> <label>', got %r"
                    % (fileList, lineno, line)
                )
            imgPath, label = fields
            imgList.append((imgPath, int(label)))
    return imgList


class ImageList(torch.utils.data.Dataset):
    """Dataset that serves grayscale images named in a list file.

    The list file is parsed by ``default_list_reader`` into
    ``(relative_path, label)`` pairs; each image path is resolved against
    ``root`` when the sample is fetched.

    Args:
        root: Directory that the image paths in the list file are joined to.
        fileList: Path of the list file handed to ``default_list_reader``.
        transform: Optional callable applied to the loaded PIL image; its
            result is returned in place of the image. ``None`` returns the
            PIL image unchanged.
    """

    def __init__(self, root, fileList, transform=None):
        self.root = root
        self.imgList = default_list_reader(fileList)
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        imgPath, target = self.imgList[index]
        img_loc = os.path.join(self.root, imgPath)

        # Mode 'L' is single-channel 8-bit grayscale (not RGB). convert()
        # returns a fully loaded copy, so the file can be closed right away.
        with Image.open(img_loc) as raw:
            img = raw.convert('L')

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Return the number of entries read from the list file."""
        return len(self.imgList)

使用

# Walk through the loader once, printing each batch index with its labels
# and the shape of the batched image tensor.
# The batch is named `images` rather than `input` so the builtin input()
# is not shadowed.
for i, (images, target) in enumerate(train_loader):
    print(i, target)
    print(images.shape)

输出的最后一个结果为

(1093, 
 928
[torch.LongTensor of size 1]
)
(1L, 1L, 64L, 64L)

输出的 Tensor 是 4 维的:DataLoader 会把一个 batch 内的样本堆叠起来,在最前面自动增加一个 batch 维度。这里最后一个 batch 只有 1 个样本,所以形状是 (1, 1, 64, 64)。

猜你喜欢

转载自blog.csdn.net/yskyskyer123/article/details/80669894