Loading and training with PyTorch custom datasets

Table of contents

1) Datasets source code

2) Datasets overall framework

3) Custom Datasets framework

4) Use of DataLoader

5) Generate txt file


Datasets is the library of built-in datasets we use. PyTorch comes with a variety of datasets, such as CIFAR-10, which lives in the torchvision datasets library.

PyTorch provides a utility class, torch.utils.data.DataLoader. With it we can load a dataset in mini-batches using multiple worker processes in parallel, which speeds up data preparation. A Dataset instance is one of the parameters used to construct this loader.
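For example, one of the built-in datasets can be wrapped directly in a DataLoader. Here is a minimal sketch (the download path and batch size are chosen arbitrarily for illustration):

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# Built-in dataset: torchvision downloads CIFAR-10 and exposes it as a Dataset object
cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                       transform=transforms.ToTensor())
# DataLoader batches the samples and can use several worker processes in parallel
loader = DataLoader(cifar10, batch_size=32, shuffle=True, num_workers=2)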

The Dataset class is the most important class for image datasets in PyTorch, and it is the parent class that every dataset-loading class should inherit from. Two member functions of the parent class must be overridden, otherwise an error is raised:

1. def __getitem__(self, index):
2. def __len__(self):

Among them, __len__ should return the size of the dataset, and __getitem__ should support indexing into the dataset.
Here we focus on __getitem__. It receives an index and returns the image data and label of that sample. The index refers into a list, where each element of the list holds the path and label information of one image.

To build that list, the usual approach is to store the path and label information of each image in a txt file and then read it back from the txt file.
The basic process of reading your own data is therefore:

1. Make a txt file that stores the path and label information of each image (a short sketch follows this list);
2. Convert this information into a list, where each element of the list corresponds to one sample;
3. Read the data and label through the __getitem__ function and return them.
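As a concrete illustration of steps 1 and 2, each line of the txt file holds an image path and an integer label separated by a space, and reading it produces a list of (path, label) tuples. The file names below are made up for illustration:

# Two hypothetical lines of train.txt:
#   ./data/images/cat_001.jpg 0
#   ./data/images/dog_002.jpg 1
imgs = []
with open('train.txt', 'r') as fh:
    for line in fh:
        words = line.strip().split()            # split on whitespace
        imgs.append((words[0], int(words[1])))  # (image path, integer label)
# imgs is now e.g. [('./data/images/cat_001.jpg', 0), ('./data/images/dog_002.jpg', 1)]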

Define your own dataset class

1) Datasets source code

All datasets are subclasses of torch.utils.data.Dataset, i.e., they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers.

Source code address: https://pytorch.org/docs/stable/torchvision/datasets.html
From the source code we can see that a class inheriting from Dataset must implement __init__() and __getitem__().
So first inherit from the Dataset class above, then collect the image paths in the __init__() method and store them in a list, so that __getitem__() can read them directly by index.

2) Datasets overall framework

from torch.utils import data

class FirstDataset(data.Dataset):  # must inherit data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize the file paths or the list of file names.
        # In other words, all we do here is initialize the basic parameters of the class.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one sample from file (e.g. using numpy.fromfile or PIL.Image.open).
        # 2. Preprocess the data (e.g. with torchvision.transforms).
        # 3. Return a data pair (e.g. image and label).
        # Note that step 1 reads ONE sample, not the whole dataset.
        pass

    def __len__(self):
        # You should change 0 to the total size of the dataset.
        return 0

3) Custom Datasets framework

# *************************** Required imports ********************************
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torch.optim as optim
import os

# *************************** Initialize some settings ************************
# torch.cuda.set_device(gpu_id)  # use the GPU
learning_rate = 0.0001  # learning rate

# *************************** Dataset settings ********************************
root = os.getcwd() + '/data1/'  # root directory of the dataset


# Define how a file is read
def default_loader(path):
    return Image.open(path).convert('RGB')


class MyDataset(Dataset):
    # Create our own class MyDataset, which inherits torch.utils.data.Dataset
    # ********** Use __init__() to receive the parameters and set up the dataset **********
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        # Initialize the attributes inherited from the parent class
        fh = open(txt, 'r')
        # Open the txt file passed in, in read-only mode
        imgs = []
        for line in fh:  # loop over the txt file line by line
            line = line.strip('\n')
            line = line.rstrip('\n')
            # Remove the trailing newline characters from this line
            words = line.split()
            # split() with no argument splits on whitespace, giving a list
            imgs.append((words[0], int(words[1])))
            # Store the txt contents in the imgs list; which words[i] to use depends on the txt layout
        # According to the txt format described above, words[0] is the image path and words[1] is the label
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        # ********** Use __getitem__() to preprocess the data and return what we want **********

    def __getitem__(self, index):  # required method; reads the contents of one element by index
        fn, label = self.imgs[index]
        # fn is the image path; fn and label are words[0] and words[1] of that line of the txt file
        img = self.loader(fn)
        # read the image from its path
        if self.transform is not None:
            img = self.transform(img)
            # apply the transform (e.g. ToTensor converts the image to a Tensor)
        return img, label
        # whatever is returned here is what each batch yields when we loop over the loader during training
        # ********** Use __len__() to report the size of the dataset **********

    def __len__(self):
        # this method is also required; it returns the length of the dataset, i.e. the number of
        # images, which should not be confused with the length of the loader
        return len(self.imgs)


train_data = MyDataset(txt=root + 'train.txt', transform=transforms.ToTensor())
test_data = MyDataset(txt=root + 'text.txt', transform=transforms.ToTensor())
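As a quick sanity check (assuming root + 'train.txt' exists and the paths in it are valid), the dataset can be indexed directly, which calls __getitem__:

img, label = train_data[0]               # invokes MyDataset.__getitem__(0)
print(img.shape, label)                  # e.g. torch.Size([3, H, W]) and the integer label
print('dataset size:', len(train_data))  # invokes MyDataset.__len__()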

I won't describe it line by line here; the comments in the code are already quite clear. Below are detailed explanations of a few minor points:

Question 1. Why is RGB returned?

def default_loader(path):
    return Image.open(path).convert('RGB')

Answer: For color images, regardless of whether the format is PNG, BMP, or JPG, opening the file with the open() function of PIL's Image module typically returns an image object whose mode is "RGB".
For grayscale images, regardless of the format, the mode after opening is "L". Calling convert('RGB') therefore guarantees that every image ends up with three channels.
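A minimal sketch of this behaviour (the file name below is hypothetical):

from PIL import Image

img = Image.open('gray_example.png')    # hypothetical grayscale file
print(img.mode)                         # 'L' for a grayscale image
img_rgb = img.convert('RGB')
print(img_rgb.mode)                     # 'RGB', i.e. three channels, after conversion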

Question 2: Why is words[0] the image information and words[1] the label information?

imgs.append((words[0], int(words[1])))
# Store the txt contents in the imgs list; which words[i] to use depends on the txt layout
# According to the txt format used here, words[0] is the image path and words[1] is the label

When generating the txt file, I set it up by default so that the first field is the image path and the second field is the label. Of course, if you don't like this order, you can change it.

4) Use of DataLoader

From the Datasets source code above, we know that a Dataset is usually paired with the DataLoader class, which is torch.utils.data.DataLoader and uses torch.multiprocessing workers to load multiple samples in parallel.

The DataLoader class

The Dataset class described above reads the dataset and indexes individual samples. But that alone is not enough: in practice the amount of data is often very large, so we also need a few extra capabilities:

batch_size:  read the data in mini-batches

shuffle=True  read the data in random order, i.e. shuffle the dataset so that the order of the samples is randomized

num_workers=2  load the data in parallel (use multiple worker processes / CPU cores to speed up loading)

This is where the DataLoader class comes in. We do not need to write it ourselves; we only need to use DataLoader to read the MyDataset we just designed:

train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False, num_workers=4)

That is all it generally takes to read a dataset. Of course, real applications may involve more complex operations, which we will not describe here; for details, take a look at the official documentation and source code.
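A minimal sketch of consuming the loader (the shapes assume all images in a batch have the same size, e.g. after a resizing transform):

for batch_idx, (imgs, labels) in enumerate(train_loader):
    # imgs is a Tensor of shape [batch_size, 3, H, W]; labels is a Tensor of shape [batch_size]
    print(batch_idx, imgs.shape, labels)
    break   # only look at the first batch here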

5) Generate txt file

Besides the code above for reading the dataset, how should the actual image data be organized? This is where the txt file comes in.
Generally speaking, a dataset we make ourselves contains three parts: a training set, a validation set, and a test set.
Because the dataset is large, we usually put these three parts in three folders and access them from code; this is simple, convenient, and less error-prone.
When accessing them we need not only the images themselves but also their paths and labels, so we use a txt file that stores two kinds of information: the path of each image, which lets us locate and read the image, and the label of each image, so that every image is paired one-to-one with its label.
The following is the code that generates the txt file for the images:

import os

a = 0
while (a < 1024):               # 1024 is our number of classes
    dir = './data/images/'      # directory of the image files
    label = a
    # os.listdir returns a list, so we can sort it with list.sort(); if the file
    # names contain numbers, they are sorted as strings
    files = os.listdir(dir)     # list the directories and files under dir
    files.sort()                # sort them
    train = open('./data/train.txt', 'a')
    text = open('./data/text.txt', 'a')
    i = 1
    for file in files:
        if i < 200000:
            fileType = os.path.splitext(file)  # os.path.splitext(): split the file name into (name, extension)
            if fileType[1] == '.txt':          # skip txt files
                continue
            name = str(dir) + file + ' ' + str(int(label)) + '\n'
            train.write(name)
            i = i + 1

        else:
            fileType = os.path.splitext(file)
            if fileType[1] == '.txt':
                continue
            name = str(dir) + file + ' ' + str(int(label)) + '\n'
            text.write(name)
            i = i + 1
    text.close()
    train.close()
    a = a + 1

The txt-file generation code is adapted from the article https://blog.csdn.net/Jack_and_monkey/article/details/86677253.
The entire code for defining the dataset in this article is as follows:

import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torch.optim as optim
import os

# torch.cuda.set_device(gpu_id)  # use the GPU
learning_rate = 0.0001

# Dataset settings ************************************************************************************************
root = os.getcwd() + '/data1/'  # root directory of the images


# Define how a file is read
def default_loader(path):
    return Image.open(path).convert('RGB')


# First inherit from the Dataset class above. Collect the image paths in __init__() and store them in a list,
# so that __getitem__() can read them directly by index:
class MyDataset(Dataset):  # create our own class MyDataset, which inherits torch.utils.data.Dataset
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):  # initialize the parameters passed in
        super(MyDataset, self).__init__()   # initialize the attributes inherited from the parent class
        fh = open(txt, 'r')                 # open the txt file passed in and read its contents
        imgs = []
        for line in fh:  # loop over the txt file line by line
            line = line.strip('\n')
            line = line.rstrip('\n')    # remove the trailing newline characters from this line
            words = line.split()        # split() with no argument splits on whitespace, giving a list
            imgs.append((words[0], int(words[1])))  # store the txt contents in the imgs list; which words[i] to use depends on the txt layout
            # according to the txt format used here, words[0] is the image path and words[1] is the label
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):       # required method; reads the contents of one element by index
        fn, label = self.imgs[index]    # fn is the image path; fn and label are words[0] and words[1] of that line
        img = self.loader(fn)           # read the image from its path
        if self.transform is not None:
            img = self.transform(img)   # apply the transform (e.g. ToTensor converts the image to a Tensor)
        return img, label               # whatever is returned here is what each batch yields during training

    def __len__(self):                  # required method; returns the number of images, not the length of the loader
        return len(self.imgs)


# Create the datasets from our own MyDataset. Note: these are datasets, not loader iterators!
# ********************************* Dataset reading finished *********************************
# Image preprocessing
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop((227, 227)),
    transforms.ToTensor(),
])
text_transforms = transforms.Compose([
    transforms.RandomResizedCrop((227, 227)),
    transforms.ToTensor(),
])

# How the datasets are loaded
train_data = MyDataset(txt=root + 'train.txt', transform=transforms.ToTensor())
test_data = MyDataset(txt=root + 'text.txt', transform=transforms.ToTensor())
# Then call DataLoader on the datasets we just created to build the loaders. Note that the length
# of a loader is the number of batches, so it depends on batch_size.
train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False, num_workers=4)
print('num_of_trainData:', len(train_data))
print('num_of_testData:', len(test_data))
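To actually train with these loaders, a minimal sketch looks like the following. The tiny model, loss, and number of epochs are placeholders added for illustration and are not part of the original article; it also assumes all images in a batch have the same size (e.g. via a resizing transform) so they can be stacked.

# Placeholder model: a small convolutional classifier, used only to show how train_loader feeds training
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1024),                     # assumes 1024 classes, matching the txt-generation code
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(2):                      # a couple of epochs, just for illustration
    for imgs, labels in train_loader:       # each iteration yields one mini-batch from MyDataset
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print('epoch', epoch, 'last batch loss:', loss.item())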

Original link: https://blog.csdn.net/sinat_42239797/article/details/90641659
