【PyTorch】Detailed explanation of data reading

DataBase + DataSet + Sampler = DataLoader

from torch.utils.data import *

IMDB + Dataset + (Sampler or BatchSampler) = DataLoader

1 DataBase

Image DataBase, abbreviated IMDB, refers to the data information stored in files.

The file format can vary: XML, YAML, JSON, SQL, and so on.

VOC uses the XML format, while COCO uses JSON.

Constructing the IMDB is the process of parsing these files and establishing a data index.

Generally, they are parsed into a Python list for easy reading in subsequent iterations.
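
As a concrete illustration, here is a minimal sketch of building such an index. The VOC-style directory layout and XML field names are assumptions for the example, not part of any library:

import os
import xml.etree.ElementTree as ET

def build_imdb(ann_dir):
    """Parse all XML annotation files under ann_dir into a Python list."""
    imdb = []
    for fname in sorted(os.listdir(ann_dir)):
        if not fname.endswith('.xml'):
            continue
        root = ET.parse(os.path.join(ann_dir, fname)).getroot()
        # Keep one record per image: filename plus its object labels
        imdb.append({
            'image': root.findtext('filename'),
            'labels': [obj.findtext('name') for obj in root.iter('object')],
        })
    return imdb  # a plain list: easy to index in later iterations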

2 DataSet

DataSet: built on top of the IMDB, it provides singleton (single-item) or slice access to the data.

In other words, it defines the indexing mechanism for objects in the database: how single-item or slice indexing is implemented.

In short, a DataSet defines __getitem__, which makes it an indexable object (An Indexable Object).

That is, given an index, it performs singleton or slice access; which of the two happens depends on whether the index is a single value or a list.

The PyTorch source code is as follows:

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
    # Defines singleton/slice access, i.e. dataItem = dataset[index]
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
    def __add__(self, other):
        return ConcatDataset([self, other])

There are two ways to customize a dataset based on the aforementioned Dataset base class and an IMDB class.

# Method 1: single inheritance
class XxDataset(Dataset):
    # Pass the IMDB in as a parameter and wrap it a second time
    imdb = IMDB()
    pass
# Method 2: dual inheritance
class XxDataset(IMDB, Dataset):
    pass
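
To make method one concrete, here is a minimal runnable sketch; the IMDB is stood in for by a plain Python list, and all names are illustrative:

from torch.utils.data import Dataset

class XxDataset(Dataset):
    """Wraps an IMDB (here, a plain list) with the Dataset indexing protocol."""
    def __init__(self, imdb):
        self.imdb = imdb
    def __getitem__(self, index):
        return self.imdb[index]   # singleton access by integer index
    def __len__(self):
        return len(self.imdb)

dataset = XxDataset(imdb=['a', 'b', 'c'])
print(dataset[1], len(dataset))   # b 3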

3 Sampler & BatchSampler

In practical applications, data is not always accessed in sequential order; random access, or weighted random access, is often required.

Reading data according to a specific rule is therefore a sampling operation, and a sampler needs to be defined: Sampler.

In addition, data may need to be read in batches rather than one item at a time, i.e. a batch sampling operation is needed, so a batch sampler is defined: BatchSampler.

The singleton access method of Dataset alone is therefore not enough; batch access methods must be defined on top of it.

In short, the sampler defines the index generation rule: it produces indices according to the specified rule and thereby controls the order in which data is read.

BatchSampler is built on top of Sampler: BatchSampler = Sampler + batch_size

The PyTorch source code is as follows:

class Sampler(object):
    """Base class for all Samplers.
    Sampler base class; custom samplers can be built on top of it.
    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    """
    def __init__(self, data_source):
        pass
    def __iter__(self):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
# Sequential sampling
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(range(len(self.data_source)))
    def __len__(self):
        return len(self.data_source)
# Random sampling
class RandomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())
    def __len__(self):
        return len(self.data_source)
# Random subset sampling
class SubsetRandomSampler(Sampler):
    pass
# Weighted random sampling
class WeightedRandomSampler(Sampler):
    pass
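
The two elided classes are real samplers in torch.utils.data; a brief usage sketch (the index list and weights below are made up for illustration):

from torch.utils.data import SubsetRandomSampler, WeightedRandomSampler

# Draw only from the given subset of indices, in random order
subset = SubsetRandomSampler(indices=[0, 2, 4, 6, 8])
# Draw 6 indices with replacement; index 3 is ten times more likely
weighted = WeightedRandomSampler(weights=[1, 1, 1, 10, 1], num_samples=6, replacement=True)

print(list(subset))    # e.g. [4, 0, 8, 2, 6]
print(list(weighted))  # e.g. [3, 3, 1, 3, 0, 3]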

class BatchSampler(object):
    """Wraps another sampler to yield a mini-batch of indices.
    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
    Example:
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler  # the wrapped base sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch
    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size


From the above, a Sampler is essentially an iterable object that follows a specific rule, but it yields only a single index at a time.

For example, in [x for x in range(10)], range(10) acts as the most basic sampler: each loop iteration retrieves one value.

[x for x in range(10)]
Out[10]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import SequentialSampler
[x for x in SequentialSampler(range(10))]
Out[14]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import RandomSampler
[x for x in RandomSampler(range(10))]
Out[12]: [4, 9, 5, 0, 2, 8, 3, 1, 7, 6]


BatchSampler wraps a Sampler in a second layer and introduces the batch_size parameter to enable batch iteration.

from torch.utils.data.sampler import BatchSampler
[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]

4 DataLoader

In actual computation, when the dataset is large, memory is limited and I/O is slow,

so the data can neither be loaded into memory all at once nor loaded by a single worker.

Iterative, multi-process loading is therefore required, and a dedicated loader is defined: DataLoader.

DataLoader is an iterable object (An Iterable Object): it implements the magic function __iter__, and calling that function returns an iterator.

It can be invoked through the built-in function iter(), i.e. data_iter = iter(dataloader).

dataloader = DataLoader(dataset=Dataset(imdb=IMDB()), sampler=Sampler(), num_workers=..., ...)

The __init__ parameters fall into two parts: the first half specifies the dataset + sampler, the second half the multi-process loading parameters.

class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    """
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        if timeout < 0:
            raise ValueError('timeout option should be non-negative')
        # Detect parameter conflicts: the default BatchSampler vs. a user-supplied batch_sampler
        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')
        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')
        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')
        # A BatchSampler is always put in place here
        if batch_sampler is None:
            # A Sampler is always put in place here
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        # Store the (possibly user-supplied) sampler and batch sampler
        self.sampler = sampler
        self.batch_sampler = batch_sampler
    def __iter__(self):
        # Return PyTorch's multi-process iterator to load the data
        return DataLoaderIter(self)
    def __len__(self):
        return len(self.batch_sampler)
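
A minimal runnable example of the above; TensorDataset stands in for a custom dataset, and the shapes and sizes are arbitrary:

import torch
from torch.utils.data import TensorDataset, DataLoader

# 10 samples with 3 features each, plus integer labels
dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for x, y in loader:
    print(x.shape, y)   # torch.Size([4, 3]) ...; the last batch holds 2 samples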


5 DataLoaderIter

There is a difference between iterators and iterable objects.

An iterable object is one on which the built-in function iter() can be used: it returns an iterator, so the object can be accessed iteratively.

An iterator object additionally implements the magic function __next__; applying the built-in function next() to it keeps producing the next piece of data, with the generation rule determined by that function.

Being iterable describes that the object can be iterated over, but the concrete iteration rules are described by the iterator. The advantage of this decoupling is that the same iterable object can be configured with multiple iterators following different rules.
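
This can be seen directly on a DataLoader: iter() hands back a fresh iterator each time, and next() drives each one independently. A small illustration (sizes are arbitrary):

import torch
from torch.utils.data import TensorDataset, DataLoader

loader = DataLoader(TensorDataset(torch.arange(6)), batch_size=2)
it1 = iter(loader)   # the iterable (DataLoader) returns an iterator
it2 = iter(loader)   # a second, independent iterator over the same iterable
print(next(it1))     # [tensor([0, 1])]
print(next(it1))     # [tensor([2, 3])]
print(next(it2))     # [tensor([0, 1])] -- it2 starts from the beginning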


6 Generalized process of dataset/container traversal: NILIS

NILIS rule: data = next(iter(loader(DataSet[sampler])))

  1. The Sampler defines the index generation rule and returns an index list, controlling the subsequent data access order.
  2. The Indexer: __getitem__ defines index-based access on the container, making it an indexable object that supports the [] operation.
  3. The Loader: __iter__ makes the container iterable and describes the loading rules; it returns an iterator, so the container becomes an iterable object that iter() can operate on.
  4. Next: __next__ defines an iterator over the container and describes the concrete iteration rules, making it an iterator object that next() can operate on.
## Initialization
sampler = Sampler()
dataSet = DataSet(sampler)                        # __getitem__
dataLoader = DataLoader(dataSet, sampler)         # __iter__, i.e. an iterable
dataIterator = DataLoaderIter(dataLoader)         # __next__
data_iter = iter(dataLoader)
## Traversal method 1
for _ in range(len(data_iter)):
    data = next(data_iter)
## Traversal method 2
for i, data in enumerate(dataLoader):
    data = data
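
A runnable rendering of the pseudocode above, assuming a simple TensorDataset as the container:

import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

dataset = TensorDataset(torch.arange(10))                    # __getitem__
sampler = SequentialSampler(dataset)                         # index generation rule
loader = DataLoader(dataset, sampler=sampler, batch_size=3)  # __iter__

## Traversal method 1
data_iter = iter(loader)       # __next__ lives on the returned iterator
for _ in range(len(loader)):
    data = next(data_iter)
## Traversal method 2
for i, data in enumerate(loader):
    pass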


Origin: blog.csdn.net/u013066730/article/details/114288514