PyTorch: Dataset and DataLoader

When building a deep-learning project with PyTorch, you typically go through [model definition] - [loss function] - [data setup] - [training loop] - [logging, validation, visualization, and checkpoints]. Among these, the data setup usually requires a custom DataLoader, because every project/task loads data differently.

This article introduces the basic usage of Dataset and DataLoader in torch.utils.data, using the loading of unpaired image data for the Unpaired Image-to-Image Translation (UIT) task as an example, and explains how to write a custom data loader in PyTorch.

All of the following code lives in the file dataset.py.

(1) Import necessary packages

# -*- coding:utf-8 -*-

import torch.utils.data as data
import torchvision.transforms as transforms
import os
from PIL import Image
import random
import torch
import numpy as np

(2) Custom Dataset

#### 01. Create a dataset
## BaseDataset
class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return "BaseDataset"

    def initialize(self, opt):
        self.opt = opt
'''
Define some shared attributes/functions. In general, torch.utils.data.Dataset already provides many special methods, such as __len__, __getitem__, etc.

We usually add two member functions, name and initialize:
1) name: has no real purpose, it's purely for show
2) initialize: in PyTorch we often use a parser, i.e., a helper that assigns hyperparameter values from the command line; in the code we instantiate it as an object named "opt". Options such as opt.img_size and opt.batch_size are data-related, so we usually pass opt into this function and store it as an attribute self.opt, so that all hyperparameters are accessible at any time.
'''
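To make the Dataset contract concrete, here is a minimal, self-contained sketch (a toy example for illustration only, not part of this project) showing the two special methods every Dataset must implement:

## Toy example (illustration only): the minimal Dataset contract
class ToyDataset(data.Dataset):
    def __init__(self, n=10):
        super(ToyDataset, self).__init__()
        self.samples = list(range(n))

    def __len__(self):                  # called via len(dataset)
        return len(self.samples)

    def __getitem__(self, index):       # called via dataset[index]
        return torch.tensor(self.samples[index], dtype=torch.float32)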

Next, we define our custom dataset, UnAlignedDataset.

First, let's look at what our dataset looks like:

[Figure: directory layout of the dataset, __data__/horse2zebra/ with the subfolders trainA/, trainB/, testA/, testB/]

This is the typical dataset structure for a UIT model, and you can see that it involves dual (two-domain) training. Each subfolder contains a series of images, and the images of the two domains are not aligned.

Let's explain a few of the opt parameters:


opt.dataroot = '__data__/horse2zebra'
opt.mode = 'train'            # 'train' during training, 'test' during testing; used to branch between the two cases

opt.trainA = 'trainA'
opt.trainB = 'trainB'
opt.testA  = 'testA'
opt.testB  = 'testB'

opt.load_size = 288           # size the image is resized to when it is loaded
opt.crop_size = 256           # size of the patch randomly cropped out of the loaded image
opt.input_nc  = 3             # number of input channels: RGB = 3, grayscale = 1, CMYK = 4, etc.; usually one of the first two

Our plan is: (1) in initialize, collect the paths of all the images so that we can access them later; (2) also in initialize, define the basic image-processing pipeline; (3) in __getitem__, define how a single sample is returned.
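In skeleton form, that plan maps onto the class like this (just a sketch; the full implementation follows):

class UnAlignedDataset(BaseDataset):
    def initialize(self, opt):
        # (1) collect the paths of all images
        # (2) define the transform pipeline
        ...

    def __getitem__(self, index):
        # (3) load one image from each domain, process it, and return it
        ...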

 

## UnAlignedDataset
class UnAlignedDataset(BaseDataset):
    ## Override name: returns the name of the dataset; rarely used in practice
    def name(self):
        return "UnAlignedDataset"

    ## Override initialize
    '''
    Here we use the opt that is passed in to gather the basic information about the dataset.
    '''
    def initialize(self, opt):
        self.opt = opt                                     #-> keep a reference to opt

        ## get dir 
        self.dataroot = opt.dataroot                       #-> opt.dataroot tells us where the dataset lives

        ## get images                                      #-> build the paths of the image subfolders
        if opt.mode == 'train':
            dir_A = os.path.join(opt.dataroot, opt.trainA)
            dir_B = os.path.join(opt.dataroot, opt.trainB)
        elif opt.mode == 'test':
            dir_A = os.path.join(opt.dataroot, opt.testA)
            dir_B = os.path.join(opt.dataroot, opt.testB)

        A_paths = os.listdir(dir_A)
        B_paths = os.listdir(dir_B)
        self.length = min(len(A_paths), len(B_paths))      #-> list the file names of every image in domain A and domain B; define the dataset size as the smaller of the two domain sizes and store it in a new attribute, self.length

        ## get full path
        for i in range(len(A_paths)):
            A_paths[i] = os.path.join(dir_A, A_paths[i])
        for i in range(len(B_paths)):
            B_paths[i] = os.path.join(dir_B, B_paths[i])   #-> for convenience later, build the full (here, relative) path of each image first
        self.A_paths = A_paths 
        self.B_paths = B_paths                             #-> finally, store them as attributes so that the other member functions can access them directly

        self.input_nc = self.opt.input_nc                  #-> important options can also be pulled out of opt individually, so later we don't have to go through self.opt.xxx each time; you could still do that, it's just less elegant

        ## define transform
        transforms_list = [transforms.ToTensor(),                  #-> PIL Image / numpy array -> torch.Tensor, pixel values scaled to 0~1.0
                           transforms.Normalize((0.5, 0.5, 0.5),   
                                                (0.5, 0.5, 0.5))]  #-> normalize to -1.0~+1.0
        self.transform = transforms.Compose(transforms_list)
        #-> This defines the preprocessing pipeline. Note that it is transforms.ToTensor() that converts the image (a uint8 PIL Image, 0~255) into a float torch.Tensor with values in the range 0~1.0.

    ## A special method of the Dataset class, invoked via len(dataset_object); returns the size of the dataset
    #-> Together with the DataLoader's batch_size, the dataset size determines the number of iterations in one epoch: length_of_dataset // batch_size
    def __len__(self):
        return self.length

    ## This special method is called internally whenever the dataset is indexed. Each time the DataLoader fetches the next batch, it internally uses batch_size indices to fetch samples and concatenates the batch_size results along dimension 0.
    '''
    For example, with images a network's input is usually (B, C, H, W); with videos it is usually (B, C, T, H, W).
    In __getitem__ we define what a single sample looks like, so the data returned here has shape (C, H, W) or (C, T, H, W).
    '''
    def __getitem__(self, index):
        #-> First get the image paths; note that our task needs an image from each of the two domains.
        #-> We index modulo the dataset size.
        A_pth = self.A_paths[index % self.length]
        B_pth = self.B_paths[index % self.length]    

        #-> Load the images
        x_img = Image.open(A_pth).convert('RGB')                                          #-> load the image
        x_img = x_img.resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)     #-> bicubic resize to the specified load size (288x288)
        x = self.transform(x_img)                                                         #-> preprocessing (ToTensor + Normalize)

        y_img = Image.open(B_pth).convert('RGB')
        y_img = y_img.resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
        y = self.transform(y_img)

        ## random crop
        h, w = x.size(1), x.size(2)
        h_offset = random.randint(0, max(0, h - self.opt.crop_size - 1))
        w_offset = random.randint(0, max(0, w - self.opt.crop_size - 1))
        x = x[:, h_offset:h_offset + self.opt.crop_size, w_offset:w_offset + self.opt.crop_size]
        y = y[:, h_offset:h_offset + self.opt.crop_size, w_offset:w_offset + self.opt.crop_size]

        ## RGB to grayscale when a single input channel is requested
        if self.opt.input_nc == 1:
            tmp_x = x[0, ...] * 0.299 + x[1, ...] * 0.587 + x[2, ...] * 0.114
            x = tmp_x.unsqueeze(0)  # (H,W) -> (C=1,H,W)
            tmp_y = y[0, ...] * 0.299 + y[1, ...] * 0.587 + y[2, ...] * 0.114
            y = tmp_y.unsqueeze(0)  # (H,W) -> (C=1,H,W)

        return {'A': x, 'B': y, 'A_pth': A_pth, 'B_pth': B_pth}
        '''
        What we return is entirely up to us. Later we will see how it is used:

        for i, data in enumerate(dataset):
            real_x = data['A']
            real_y = data['B']
            ...
        
        Notice that the DataLoader is only responsible for batching (when a sample has several parts, each part is batched separately); the actual content of a sample is user-defined.
        '''
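As a quick sanity check on the preprocessing above, this tiny sketch (illustration only) shows what Normalize with mean 0.5 and std 0.5 does, given that ToTensor() has already mapped pixel values to 0~1.0:

# Illustration only: Normalize computes out = (in - mean) / std per channel
t = torch.tensor([[[0.0, 0.5, 1.0]]])        # a fake 1-channel image of shape (C=1, H=1, W=3)
norm = transforms.Normalize((0.5,), (0.5,))
print(norm(t))                               # tensor([[[-1.,  0.,  1.]]])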

So we can see that the Dataset defines how a single sample is processed before being returned (pixel-value normalization, image cropping, color-space conversion, etc.; essentially every technique from "digital image processing" can be applied here). Some of these operations are available directly as transforms; if not, you can write a custom transform, or simply code the operation inside __getitem__, like the RGB-to-gray conversion above!
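For reference, a custom transform is just a callable object. A minimal sketch (the class name RGBToGray is illustrative, not from the original code), assuming the imports at the top of dataset.py:

## Illustration only: a custom transform is any callable taking and returning a tensor
class RGBToGray(object):
    def __call__(self, t):
        # t has shape (3, H, W); use the same luma weights as in __getitem__ above
        gray = t[0, ...] * 0.299 + t[1, ...] * 0.587 + t[2, ...] * 0.114
        return gray.unsqueeze(0)    # shape (1, H, W)

# It could then be inserted into the pipeline, e.g.:
# transforms.Compose([transforms.ToTensor(), RGBToGray(), transforms.Normalize((0.5,), (0.5,))])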

(3) Custom DataLoader

As mentioned earlier, the Dataset only defines how to return a single sample; forming batches, loading data quickly (with multiple workers), shuffling the dataset after each epoch, and so on are all handled by the DataLoader.
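Note that the wrapper classes below are a convenience, not a requirement; a minimal sketch of using torch.utils.data.DataLoader directly (assuming the UnAlignedDataset defined above and a parsed opt):

# Illustration only: using torch.utils.data.DataLoader without any wrapper
dataset = UnAlignedDataset()
dataset.initialize(opt)
loader = data.DataLoader(dataset,
                         batch_size=opt.batch_size,
                         shuffle=True,
                         num_workers=int(opt.n_threads))
for batch in loader:
    real_x, real_y = batch['A'], batch['B']
    # ... feed real_x / real_y to the model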

First, we define a base DataLoader; its main job is to take in opt, so we add the new member function initialize.

#### 02. Create a DataLoader
## BaseDataLoader
class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        return None

Below, we define our own data loader, UnAlignedDataLoader.

## Dataloader for self data
class UnAlignedDataLoader(BaseDataLoader):
    def name(self):
        return "UnAlignedDataLoader"

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)  # store opt as self.opt

        # add a dataset and initialize it
        self.dataset = UnAlignedDataset()
        self.dataset.initialize(opt)          # initialize is not a built-in method of torch.utils.data.Dataset, so we must call it manually to finish setting the dataset up

        # define a data loader
        self.dataloader = data.DataLoader(    # wrap the dataset with torch.utils.data.DataLoader
            self.dataset,
            batch_size=opt.batch_size,        # size of each batch
            shuffle=True,                     # whether to reshuffle after each epoch
            num_workers=int(opt.n_threads)    # number of worker processes used to load the data
        )

    def load_data(self):                      # returns the data loader object itself!!! very important
        return self

    def __len__(self):                        # returns the size of the dataset
        return len(self.dataset)

    def __iter__(self):
        for _, d in enumerate(self.dataloader):
            yield d                           # special method: iterates over the whole dataset batch by batch, i.e., one epoch

Now we can see that most of this is, in fact, boilerplate! What we really need to customize lives in UnAlignedDataset: collecting all the data paths in initialize, then loading the data in __getitem__ and applying custom processing (resizing, cropping, pixel-value normalization, etc.). These steps can use existing transforms, or be written by hand.

In addition, the structure and content of the other three classes (BaseDataset, BaseDataLoader, UnAlignedDataLoader) basically do not need to change.

(4) Test

Finally, the test:

#### Test data loader
from config import parser
opt = parser.parse_args() ##-> this parser is my own; you will need to define yours, roughly structured as follows:
'''
# config.py

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
...

'''
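For completeness, here is a fuller sketch of config.py covering the options used in this article (the defaults for batch_size and n_threads are assumptions; the others match the values listed earlier):

# config.py (sketch; covers the options used in this article)
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataroot',   type=str, default='__data__/horse2zebra')
parser.add_argument('--mode',       type=str, default='train', help="'train' or 'test'")
parser.add_argument('--trainA',     type=str, default='trainA')
parser.add_argument('--trainB',     type=str, default='trainB')
parser.add_argument('--testA',      type=str, default='testA')
parser.add_argument('--testB',      type=str, default='testB')
parser.add_argument('--load_size',  type=int, default=288)
parser.add_argument('--crop_size',  type=int, default=256)
parser.add_argument('--input_nc',   type=int, default=3)
parser.add_argument('--batch_size', type=int, default=1)    # assumed default
parser.add_argument('--n_threads',  type=int, default=4)    # assumed default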

data_loader = UnAlignedDataLoader()
data_loader.initialize(opt)

data_set = data_loader.load_data()

for i, data in enumerate(data_set):
    print(i, data['A'].size(), data['B'].size())

Each iteration prints the batch index and the sizes of the two batches (the original post showed a screenshot of the output here).
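With the settings above (crop_size = 256, input_nc = 3) and assuming batch_size = 1, the printed output should look like:

0 torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
1 torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
...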

That concludes this introduction to writing a simple custom PyTorch data loader for traversing a dataset. If anything is missing or wrong, please point it out!

Origin: blog.csdn.net/WinerChopin/article/details/100898131