Deep Eyes PyTorch check-in (7): PyTorch's data reading mechanism, DataLoader and Dataset

Preface


   Whether for model training or actual testing, data reading is the first step, because deep learning is ultimately driven by data. With a program that reads the data in correctly, you can combine it with PyTorch's preprocessing, pre-trained models and optimizers to build a simple predictive model. The core of PyTorch's data reading is the DataLoader method and the Dataset class. The framework of this note mainly comes from Deep Eyes, with some related expansions drawn from translating and digesting the PyTorch documentation.

   Data splitting: Deep Eyes PyTorch check-in (6): Splitting a dataset into training, validation and test sets


Dataset class


   Dataset is an abstract class representing a dataset; it defines where and how the data is read. Where is the data read from? From disk, which is handled by passing a path argument to the Dataset. How is it read? The reading logic is customized by us, and it differs depending on how the dataset was split.

class DataSet(Dataset):

    pass

   We customize a class representing the dataset by inheriting from the abstract Dataset class, then instantiate our subclass once each for the training, validation and test data. Every subclass of Dataset must override the __getitem__() method, and usually also overrides the optional __len__() and __init__() methods.

   The basic framework is shown in the code:

class DataSet(Dataset):

	def __init__(self):
		...
		
	def __getitem__(self, index):
		return ...
	
	def __len__(self):
		return ...
  • Method implementation and instantiation

  __init__(self) is used to set up attributes of the class itself, such as labels, data information, and whether to use data augmentation. __len__(self) returns the size of the dataset. __getitem__(self, index) receives an index and returns the corresponding sample and label from the dataset; it is the core of data reading, and the index is generated by the sampler class inside DataLoader().

  The first dataset-splitting method from the previous note, implemented as data reading:

import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

label_name = {'ants': 0, 'bees': 1}


class DataSet(Dataset):
    def __init__(self, data_path):  # other parameters can be added besides these two
        self.label_name = {'ants': 0, 'bees': 1}
        self.data_info = get_info(data_path)

    def __getitem__(self, index):
        label, img_path = self.data_info[index]
        pil_img = Image.open(img_path).convert('RGB')  # read the image from disk
        re_img = transforms.Resize((32, 32))(pil_img)
        img = transforms.ToTensor()(re_img)  # convert the PIL image to a tensor
        return img, label

    def __len__(self):
        return len(self.data_info)


def get_info(data_path):
    data_info = list()
    for root_dir, sub_dirs, _ in os.walk(data_path):
        for sub_dir in sub_dirs:
            file_names = os.listdir(os.path.join(root_dir, sub_dir))
            img_names = list(filter(lambda x: x.endswith('.jpg'), file_names))
            for i in range(len(img_names)):
                img_path = os.path.join(root_dir, sub_dir, img_names[i])
                img_label = label_name[sub_dir]
                data_info.append((img_label, img_path))

    return data_info


if __name__ == '__main__':

    train_set_path = os.path.join('data', 'train_set')
    val_set_path = os.path.join('data', 'val_set')
    test_set_path = os.path.join('data', 'test_set')
    train_set = DataSet(data_path=train_set_path)
    val_set = DataSet(data_path=val_set_path)
    test_set = DataSet(data_path=test_set_path)
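
    # A quick sanity check (a sketch; assumes the directory layout from the
    # previous note, with 'ants' and 'bees' subfolders under each split):
    print(len(train_set))       # calls DataSet.__len__
    img, label = train_set[0]   # calls DataSet.__getitem__ with index 0
    print(img.shape, label)     # e.g. torch.Size([3, 32, 32]) 0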

  Breakpoint debugging results:

[Figure 1. Debugging results of the first method]
  The second dataset-splitting method from the previous note, implemented as data reading; the main difference is in the get_info_list() function.

import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

label_name = {'ants': 0, 'bees': 1}


class DataSet(Dataset):
    def __init__(self, data_path):  # other parameters can be added besides these two
        self.label_name = {'ants': 0, 'bees': 1}
        self.data_info = get_info_list(data_path)

    def __getitem__(self, index):
        label, img_path = self.data_info[index]
        pil_img = Image.open(img_path).convert('RGB')  # read the image from disk
        re_img = transforms.Resize((32, 32))(pil_img)
        img = transforms.ToTensor()(re_img)  # convert the PIL image to a tensor
        return img, label

    def __len__(self):
        return len(self.data_info)

def get_info_list(list_path):
    data_info = list()
    with open(list_path, mode='r') as f:
        lines = f.readlines()
        for i in range(len(lines)):
            img_label = int(lines[i].split(' ')[0])
            img_path = lines[i].split(' ')[1].strip()  # strip the trailing newline
            data_info.append((img_label, img_path))
    return data_info


if __name__ == '__main__':

    train_list_path = os.path.join('old_data', 'train_set.txt')
    val_list_path = os.path.join('old_data', 'val_set.txt')
    test_list_path = os.path.join('old_data', 'test_set.txt')
    train_set = DataSet(data_path=train_list_path)
    val_set = DataSet(data_path=val_list_path)
    test_set = DataSet(data_path=test_list_path)
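
    # For reference, each line in these list files is expected to look like
    # "<label> <image_path>" (inferred from get_info_list above), e.g.:
    # 0 old_data/ants/xxx.jpg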

  Breakpoint debugging results:

[Figure 2. Debugging results of the second method]

DataLoader method


   The DataLoader() method provides an iterable over a given dataset: each time the model iterates, it obtains batch_size samples from the DataLoader(). Its signature and parameters are shown in the following code.

torch.utils.data.DataLoader(
                                  dataset,
                                  batch_size=1, 
                                  shuffle=False, 
                                  sampler=None, 
                                  batch_sampler=None, 
                                  num_workers=0, 
                                  collate_fn=None, 
                                  pin_memory=False, 
                                  drop_last=False, 
                                  timeout=0, 
                                  worker_init_fn=None, 
                                  multiprocessing_context=None)

   dataset: an instance of a subclass of the abstract Dataset class above, such as the training-set or validation-set class.
   batch_size: the batch size, i.e. the number of samples fetched per iteration; the default is 1.
   shuffle: whether to reshuffle the sample order at every epoch; the default is False. Passing all training samples through the model once is called one epoch, so from the size of the training data and batch_size you can work out how many iterations one epoch needs.
   sampler: the strategy for drawing samples from the dataset, i.e. how indices are generated, which can be sequential or random. The index received by __getitem__(self, index) in the Dataset above is produced by this sampler class. See the referenced blog post for a detailed explanation.
  batch_sampler: returns the indices of one batch at a time, i.e. the indices produced by the sampler are packed into groups, yielding the index list of one batch after another.
  num_workers: the number of subprocesses used to read data; the default 0 means data is read in the main process.
  collate_fn: merges the samples and labels of one batch.
  drop_last: when set to True, if the dataset size is not divisible by batch_size, the last incomplete batch is dropped; when set to False, the last batch is simply smaller, as the sketch below shows.
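
  As a small illustration of batch_size and drop_last, here is a sketch using a toy TensorDataset of 10 samples (TensorDataset comes from torch.utils.data and is not part of the example above):

import torch
from torch.utils.data import DataLoader, TensorDataset

toy_set = TensorDataset(torch.arange(10))  # 10 samples: 0..9
for drop in (False, True):
    loader = DataLoader(toy_set, batch_size=3, drop_last=drop)
    print([len(batch[0]) for batch in loader])
# drop_last=False -> [3, 3, 3, 1]  (the last, smaller batch is kept)
# drop_last=True  -> [3, 3, 3]     (the incomplete batch is dropped)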

  • DataLoader use

  A simple use of DataLoader just adds the following code after the Dataset instantiation: it reads the training set with a batch size of 2, in random order.

    from torch.utils.data import DataLoader

    train_loader = DataLoader(dataset=train_set, batch_size=2, shuffle=True)

    for i, data in enumerate(train_loader):
        inputs, labels = data
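        # inputs is a tensor of shape [2, 3, 32, 32] and labels has shape [2]
        # (the shapes follow from batch_size=2 and the 32x32 resize above)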

  Now set a breakpoint here and single-step through in the debugger to observe how DataLoader operates.

[Figure 3. Breakpoint location]

   First, the DataLoader initialization runs. Iterating over the loader enters the __iter__(self) method of the DataLoader(object) class, which selects a single-process or multi-process iterator. With the default num_workers=0 it enters the __init__(self, loader) method of the single-process iterator class _SingleProcessDataLoaderIter(_BaseDataLoaderIter), which in turn enters the __init__(self, loader) method of the _BaseDataLoaderIter(object) class, and so on.

class DataLoader(object):
# ......
    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
class _BaseDataLoaderIter(object):
    def __init__(self, loader):
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
        self._num_yielded = 0

  Once the DataLoader is initialized, data reading begins. Execution first enters the __next__(self) method of the _BaseDataLoaderIter(object) class, then jumps to the _next_data(self) method of the _SingleProcessDataLoaderIter class, and from there into the __iter__(self) method of the BatchSampler(Sampler) class, which generates the index values of all samples in one batch and collects them in a list, as shown in the figure. Control then returns to _next_data(self), which passes the generated indices to _dataset_fetcher.fetch(index). Finally execution jumps to the __getitem__(self, index) method of our custom DataSet(Dataset) class, which is given one index at a time, looping batch_size times. After the loop completes, the default_collate(batch) method is called by default to pack the samples of this batch together.

class _BaseDataLoaderIter(object):

    def __next__(self):
        data = self._next_data()
        # ......
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
     # ......
     
    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
class BatchSampler(Sampler):
     # ......
     
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch
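
  Putting the pieces together, this chain can be triggered by hand (a sketch, assuming the train_loader built earlier):

it = iter(train_loader)    # DataLoader.__iter__ -> _SingleProcessDataLoaderIter
inputs, labels = next(it)  # __next__ -> _next_data -> sampler -> fetch -> collate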

[Figure 4. The index list generated for one batch]
  Summary: from the analysis above, data reading ultimately depends entirely on our custom DataSet class; the DataLoader() method mainly drives the iteration. On each iteration, BatchSampler is called to generate the indices of one batch of data; these indices are then handed to the DataSet's __getitem__(self, index) to obtain the corresponding images and labels. Once the images and labels of a whole batch have been collected, they are packed together and returned. This completes one batch-read operation.
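
  This summary can also be reproduced manually. Below is a sketch that mimics one DataLoader iteration using the train_set instance built earlier; note that the import path torch.utils.data._utils.collate is an assumption matching the torch version discussed here (newer releases also export torch.utils.data.default_collate):

from torch.utils.data import RandomSampler, BatchSampler
from torch.utils.data._utils.collate import default_collate

batch_sampler = BatchSampler(RandomSampler(train_set), batch_size=2, drop_last=False)
indices = next(iter(batch_sampler))        # e.g. [57, 3]
samples = [train_set[i] for i in indices]  # one __getitem__ call per index
inputs, labels = default_collate(samples)  # pack into batch tensors
print(inputs.shape)                        # torch.Size([2, 3, 32, 32])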


References


  https://blog.csdn.net/qq_31622015/article/details/90573874
  https://www.cnblogs.com/marsggbo/p/11308889.html
  https://www.cnblogs.com/jiaxin359/p/7324077.html

Origin blog.csdn.net/sinat_35907936/article/details/105636697