PyTorch DataLoader source code analysis (2): Sampler / Fetcher components and the DataLoader core code

DataLoader components

Sampler class

Before looking at the concrete Sampler implementations, let's see where DataLoader creates the Sampler objects:

class DataLoader(object):
    def __init__(self, ...):
        ...
        if sampler is None:
            ...
            # use random sampling when shuffle is set, otherwise sequential sampling
            if shuffle:
                sampler = RandomSampler(dataset, generator=generator)
            else:
                sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # batch_size given and no custom batch_sampler: enable automatic batch sampling
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        ...

We can see that the main responsibility of a Sampler object is to generate the indices used to access the Dataset. The relevant Sampler subclasses are:

  • SequentialSampler: sequential sampling
  • RandomSampler: random sampling
  • BatchSampler: batch sampling

There are other samplers as well, but since they are rarely used, this article focuses on the three above. All of them are subclasses of Sampler, whose __iter__ method simply raises NotImplementedError:

class Sampler(Generic[T_co]):
    def __init__(self, data_source: Optional[Sized]) -> None:
        pass
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

SequentialSampler

SequentialSampler implementation:

class SequentialSampler(Sampler[int]):
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        # create an iterator that yields 0, 1, ..., len(data_source) - 1
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

Here we mainly care about the __iter__ method: it simply returns an iterator over range(len(self.data_source)), where len(data_source) is the number of samples in the Dataset. Once the iterator is created, each call to __next__ returns the indices 0, 1, 2, 3, 4, ... in increasing order.
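
A quick toy illustration (my own example, not from the original post):

from torch.utils.data import SequentialSampler

# any object with __len__ works as a stand-in for a map-style Dataset here
sampler = SequentialSampler(range(5))
print(list(sampler))  # [0, 1, 2, 3, 4]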

RandomSampler

RandomSampler implementation:

class RandomSampler(Sampler[int]):
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator    
		...
    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator
        # replacement indicates whether the same index may be drawn more than once
        if self.replacement:
            # num_samples is the total number of indices to draw in one pass
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples

RandomSampler essentially calls torch.randperm (or torch.randint when replacement=True) to generate a random sequence of indices. Pay attention to the replacement parameter: it controls whether sampling is done with replacement, i.e. whether an index that has already been drawn can be drawn again. True means it can; False means every index appears at most once per pass.
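
A small toy example (mine, not from the original post) makes the difference visible; the exact output depends on the seed and PyTorch version:

import torch
from torch.utils.data import RandomSampler

g = torch.Generator()
g.manual_seed(0)

# without replacement: a permutation of 0..9, every index appears exactly once
print(list(RandomSampler(range(10), generator=g)))

# with replacement: indices may repeat; num_samples controls how many are drawn
print(list(RandomSampler(range(10), replacement=True, num_samples=15, generator=g)))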

BatchSampler

Unlike SequentialSampler and RandomSampler, which yield one index at a time, BatchSampler yields a list of batch_size indices at a time. Internally it wraps another sampler (a SequentialSampler or RandomSampler), so the indices inside each list may be sequential or random. The BatchSampler implementation:

class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        ...
        # the wrapped sampler decides whether indices are sequential or random
        # (SequentialSampler / RandomSampler)
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        
        if self.drop_last:
            sampler_iter = iter(self.sampler)
            while True:
                try:
                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
                    yield batch
                except StopIteration:
                    break
        else:
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

As the docstring example shows, with drop_last=True the trailing indices that cannot fill a complete batch of batch_size are discarded; otherwise they are yielded as a final, shorter list. From the source we can see that BatchSampler is just a thin wrapper around another Sampler that groups its indices into lists.

In short, SequentialSampler and RandomSampler yield one index at a time, while BatchSampler yields a batch of indices at a time.
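
To tie the two together, here is a tiny example of BatchSampler wrapping a RandomSampler (my own illustration; the exact batches depend on the seed):

import torch
from torch.utils.data import BatchSampler, RandomSampler

g = torch.Generator()
g.manual_seed(0)
sampler = RandomSampler(range(10), generator=g)
print(list(BatchSampler(sampler, batch_size=4, drop_last=False)))
# e.g. [[2, 9, 5, 0], [8, 3, 7, 1], [6, 4]] -- three lists, the last one shorter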

Fetcher class

The Fetcher class is very simple: it is a thin wrapper around the Dataset. A Fetcher object receives an index (or a list of indices) and returns the corresponding tensors. It is invoked from the DataLoader iterator:

data = self._dataset_fetcher.fetch(index)

There are two Fetcher variants, _IterableDatasetFetcher and _MapDatasetFetcher; only _MapDatasetFetcher is covered here:

class _BaseDatasetFetcher(object):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        self.dataset = dataset
        self.auto_collation = auto_collation
        self.collate_fn = collate_fn
        self.drop_last = drop_last

    def fetch(self, possibly_batched_index):
        raise NotImplementedError()

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

The _auto_collation property is defined in DataLoader:

class DataLoader(object):
    ...
    @property
    def _auto_collation(self):
        return self.batch_sampler is not None
    ...

The default collate_fn called by the Fetcher converts the fetched samples (for example NumPy arrays) into PyTorch tensors and stacks them into a batch. possibly_batched_index is a list of indices when automatic batching is enabled (a batch_sampler is in use), and a single index otherwise. So what the Fetcher actually does is (a minimal sketch follows the list):

  1. Pull data from the Dataset by index
  2. Apply collate_fn to the fetched data to obtain the corresponding PyTorch tensors
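
Below is a minimal sketch of these two steps (my own toy code imitating what _MapDatasetFetcher does under automatic batching, not the PyTorch source; ToyDataset and the index list are made up, and default_collate is the public helper exposed by recent PyTorch versions):

import numpy as np
import torch
from torch.utils.data import Dataset, default_collate

class ToyDataset(Dataset):
    def __init__(self):
        self.x = np.arange(20, dtype=np.float32).reshape(10, 2)  # 10 samples, 2 features

    def __getitem__(self, idx):
        return self.x[idx], idx  # (sample, target)

    def __len__(self):
        return len(self.x)

dataset = ToyDataset()
batch_index = [3, 7, 1]                       # what a BatchSampler would yield
samples = [dataset[i] for i in batch_index]   # step 1: pull data from the Dataset by index
batch = default_collate(samples)              # step 2: collate into batched tensors
print(batch[0].shape, batch[1])               # torch.Size([3, 2]) tensor([3, 7, 1])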

DataLoader core code

With the DataLoader components covered, we can finally look at the DataLoader source itself. Iteration is implemented mainly by the _SingleProcessDataLoaderIter and _MultiProcessingDataLoaderIter classes. Let's first review how a DataLoader is typically used in a training script:

train_loader = torch.utils.data.DataLoader(dataset, ...)
# style 1
for input, target in train_loader:
    # forward pass
    output = model(input)
    # compute the loss
    loss = loss_fn(output, target)
    # backward pass
    optimizer.zero_grad()
    loss.backward()
    # update the parameters
    optimizer.step()

# style 2
for i, data in enumerate(train_loader):
    input, target = data
    # forward pass
    output = model(input)
    # compute the loss
    loss = loss_fn(output, target)
    # backward pass
    optimizer.zero_grad()
    loss.backward()
    # update the parameters
    optimizer.step()

When we run a for loop (or enumerate) over a DataLoader object, Python calls its __iter__ method to obtain an iterator. Each step of the iteration takes an index generated by the Sampler and hands it to a Fetcher object, which uses the Dataset's indexing interface to load the data; the resulting samples and targets are then returned to the caller.

First look at the __iter__ method of DataLoader:

class DataLoader(Generic[T_co]):
    ...
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        # single-process case
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        # multi-process case
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

    def __iter__(self) -> '_BaseDataLoaderIter':
        # delegate to _get_iterator
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

So the classes we need to pay attention to are _SingleProcessDataLoaderIter and _MultiProcessingDataLoaderIter. Let's start with the simpler _SingleProcessDataLoaderIter.

Single-process scenario

In this scenario, num_workers is set to 0:

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0
		
        # call create_fetcher to create the Fetcher object (create_fetcher is shown below)
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    # _SingleProcessDataLoaderIter inherits from _BaseDataLoaderIter;
    # _BaseDataLoaderIter.__next__ calls the subclass's _next_data method
    def _next_data(self):
        index = self._next_index()
        data = self._dataset_fetcher.fetch(index)
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
        return data

create_fetcher method:

class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

Since the dataset we use is map-style, a _MapDatasetFetcher object is created here; its definition was covered in the Fetcher section above. Note that __iter__ and __next__ are not defined in _SingleProcessDataLoaderIter because they live in the base class, _BaseDataLoaderIter, which is defined as follows:

class _BaseDataLoaderIter(object):

    def __init__(self, loader: DataLoader) -> None:
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._auto_collation = loader._auto_collation
        self._num_workers = loader.num_workers
        self._sampler_iter = iter(self._index_sampler)
        self._collate_fn = loader.collate_fn
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        ...
    # return itself: the object is its own iterator
    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    # _next_index asks the Sampler iterator for the next index
    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    # _next_data is implemented by the subclasses
    def _next_data(self):
        raise NotImplementedError

    # __next__ calls _next_data to produce the data
    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            data = self._next_data()
            self._num_yielded += 1
            return data

    def __len__(self) -> int:
        return len(self._index_sampler)

Therefore, when __iter__ and __next__ are called, a _SingleProcessDataLoaderIter returns either a batch of data or a single sample, depending on whether automatic batching is enabled. Note that in this single-process scenario, every step (index generation, fetching, collating, pinning) runs serially in the main process.
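
An end-to-end toy run (my own example) exercises exactly this path; with num_workers=0 the calls below go through _SingleProcessDataLoaderIter:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

it = iter(loader)   # DataLoader.__iter__ -> _get_iterator -> _SingleProcessDataLoaderIter
x, y = next(it)     # __next__ -> _next_data -> _next_index + fetcher.fetch
print(x.shape, y)   # torch.Size([4, 1]) tensor([0, 1, 2, 3])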

Multi-process scenario

When num_workers is set to a value > 0, DataLoader returns a _MultiProcessingDataLoaderIter object. Let's first look at how _MultiProcessingDataLoaderIter works from a macro perspective; the data flow between the main process, the worker processes, the pin_memory thread, and the queues is summarized below.

In fact, the iterator in the multi-process scenario does the same work as in the single-process scenario, except that the Fetcher and pin_memory steps run in separate worker processes and a separate thread, respectively. The Fetcher step involves I/O, which can easily become the bottleneck during training, so multiple workers are used to let I/O overlap with the computation of network training. (For I/O considerations, see my earlier article: Thread pool size selection for I/O-intensive and CPU-intensive scenarios.) pin_memory involves memory copying, which can easily make the CPU a bottleneck, so a dedicated thread is set up to copy tensors into page-locked memory.

A _MultiProcessingDataLoaderIter object is essentially a set of worker processes, a pin_memory thread, and several queues. The queues act as buffers: producers put data in, and consumers take it out.

  • When the main process calls __next__ on the DataLoader iterator, it takes one batch of data from the data_queue, then obtains the next batch of indices from the sampler and puts them into an index_queue for the worker processes to consume.
  • Each worker process takes an index entry from its index_queue, loads the corresponding samples from disk into memory, applies the user-specified transform (preprocessing, if any), and puts the resulting data into the worker_result_queue.
  • The pin_memory thread takes data out of the worker_result_queue and converts it from pageable tensors to pinned tensors, i.e. copies the data into page-locked memory, which allows the subsequent host-to-GPU copies to run asynchronously (a small illustration of this copy follows the list).
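
The effect of pinning can be reproduced by hand (a toy illustration of mine, not the DataLoader internals); DataLoader(pin_memory=True) does this for every fetched batch, which is what allows non-blocking host-to-GPU copies to overlap with computation:

import torch

x = torch.randn(1024, 1024)   # ordinary pageable host memory
x_pinned = x.pin_memory()     # copy into page-locked (pinned) host memory

if torch.cuda.is_available():
    # non_blocking=True only truly overlaps the copy with compute
    # when the source tensor lives in pinned memory
    x_gpu = x_pinned.to("cuda", non_blocking=True)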

Several queues serve as buffers between these producers and consumers (a simplified producer/consumer sketch follows the list):

  • index_queue: stores (send_idx, index), where index is the subscript the Fetcher uses to read from the dataset and send_idx is used to preserve data ordering, as discussed later.
  • worker_result_queue: stores (send_idx, pageable_tensor)
  • data_queue: stores (send_idx, pinned_tensor)
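
As a rough mental model, here is a deliberately simplified producer/consumer sketch (my own toy code with made-up names; the real worker loop lives in torch/utils/data/_utils/worker.py and handles far more edge cases):

import multiprocessing as mp

def toy_worker_loop(dataset, index_queue, result_queue):
    """Simplified stand-in for a DataLoader worker process."""
    while True:
        task = index_queue.get()
        if task is None:              # sentinel: shut down
            break
        send_idx, batch_index = task
        # "Fetcher" step: read the samples (disk I/O + transforms);
        # a real worker would also apply collate_fn here
        data = [dataset[i] for i in batch_index]
        result_queue.put((send_idx, data))

if __name__ == "__main__":
    dataset = list(range(100))        # stand-in for a map-style Dataset
    index_queue, result_queue = mp.Queue(), mp.Queue()
    w = mp.Process(target=toy_worker_loop, args=(dataset, index_queue, result_queue))
    w.daemon = True
    w.start()

    index_queue.put((0, [3, 1, 4]))   # main process: produce (send_idx, index)
    print(result_queue.get())         # (0, [3, 1, 4]) -- consume the worker's result
    index_queue.put(None)             # tell the worker to exit
    w.join()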

First look at the main methods of _MultiProcessingDataLoaderIter:

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        # When called: when the user constructs a DataLoader with num_workers > 0,
        #   a _MultiProcessingDataLoaderIter is built and this __init__ runs.
        # Responsibility: read the user's parameters from the DataLoader, start
        #   num_workers worker processes, the pin_memory thread and the queues,
        #   and dispatch 2 * num_workers initial tasks (i.e. indices).

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # When called: from _get_data.
        # Responsibility: take data from data_queue and handle the various failure cases.

    def _get_data(self):
        # When called: from _next_data.
        # Responsibility: call _try_get_data to fetch data and check whether it succeeded.

    def _next_data(self):
        # When called: every time the user's for loop over the DataLoader asks for the next item.
        # Responsibility: entry point of the iterator; it returns the data the user needs.
        #   Each call works as follows:
        #     1. Check whether the data to return this time is already in the cache;
        #        if so, take it from there.
        #     2. If not, call _get_data to fetch data from the queue.
        #     3. If the fetched data is not the one we are waiting for (its idx != rcvd_idx),
        #        store it in the cache and go back to step 1; otherwise continue.
        #     4. Once the right data is obtained, call _process_data to finish up and return it.

    def _try_put_index(self):
        # When called: from _process_data.
        # Responsibility: 1. get an index from the sampler (via the parent class's _next_index)
        #                 2. put (send_idx, index) into the chosen index_queue
        #                 3. increment send_idx

    def _process_data(self, data):
        # When called: from _next_data.
        # Responsibility: increment rcvd_idx, call _try_put_index, then return the data
        #   previously obtained from _get_data.

Next, let's take a closer look at the actual implementation. A lot of robustness-related code is omitted; we focus on the core parts.

__init__ method

def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)
        ...
        # data structures used to preserve data ordering across processes/threads
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        self._task_info = {}

        # create and initialize num_workers worker processes and the pin_memory thread
        # according to the user's parameters
        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()
            # index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(... ...)
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(... ...)
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
        else:
            self._data_queue = self._worker_result_queue

        # finally, send 2 * num_workers indices to get the worker processes going
        for _ in range(2 * self._num_workers):
            self._try_put_index()

Here send_idx, rcvd_idx and task_info come into play; these three members exist to preserve the ordering of the data. In the single-process scenario every step runs serially: the iterator gets an index from the sampler and then fetches the data from disk, so the order in which data is returned matches the order in which indices are generated. In the multi-process scenario this consistency is hard to guarantee: an index may be put into any of the index_queues, and a worker may be preempted while doing CPU work (the transform), so the order of results in the worker_result_queue is not necessarily the order in which the indices were generated. Without extra bookkeeping, the DataLoader could therefore return data out of order, and this is the problem these members solve.

Concretely, send_idx is the serial number of the next index dispatched by the main process, rcvd_idx is the serial number of the data the main process expects to return next, and task_info is a buffer that holds out-of-order results taken from the data_queue. __init__ merely initializes these members; the logic that uses them is analyzed in the methods below.

_next_data method

The _next_data method is what the iterator actually runs when __next__ is called:

def _next_data(self):
        while True:
            ...
            # check whether the data we need this time is already in the cache
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            # not in the cache: call _get_data to take something from the queue
            idx, data = self._get_data()

            # check whether what we just got is the one we are waiting for
            if idx != self._rcvd_idx:
                # out of order: park it in the cache
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                # in order: hand it to _process_data
                return self._process_data(data)

The code above implements the ordering guarantee. task_info is a dict mapping send_idx to a tuple that eventually contains the data (tensor). When the iterator needs to return the next item, it first checks whether that item is already in task_info; if not, it takes something from the data_queue (via _get_data). If the idx of what it took is not equal to rcvd_idx, the data is parked in task_info; otherwise it proceeds to the next step (_process_data). The same reordering trick is shown in isolation below.
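
Here is the reordering trick in isolation (a toy sketch of mine with hypothetical names, not the actual _next_data code): results arriving out of order are parked in a buffer keyed by their send index and released only when their turn comes.

def reorder(results):
    """Yield (idx, data) pairs in send order, even if they arrive out of order."""
    buffer = {}       # plays the role of _task_info
    rcvd_idx = 0      # the next idx the consumer is waiting for
    for idx, data in results:
        buffer[idx] = data
        while rcvd_idx in buffer:     # release everything that is now in order
            yield rcvd_idx, buffer.pop(rcvd_idx)
            rcvd_idx += 1

arrivals = [(1, "b"), (0, "a"), (3, "d"), (2, "c")]   # out-of-order worker results
print(list(reorder(arrivals)))  # [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'd')]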

_get_data method

def _get_data(self):
        if self._timeout > 0:
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory:
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                raise RuntimeError('Pin memory thread exited unexpectedly')
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

_get_data mainly calls _try_get_data to obtain data from the data_queue; the rest of the code is there for robustness (timeouts and detecting a dead pin_memory thread).

_try_get_data method

def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            ...
            if isinstance(e, queue.Empty):
                return (False, None)

The logic here is simple: if data arrives in the queue within the timeout, return (True, data); otherwise return (False, None).

_process_data method

def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

This part increments rcvd_idx, calls _try_put_index to dispatch a new index, and returns the data (re-raising any exception that a worker wrapped).

_try_put_index method

def _try_put_index(self):
        try:
            # ask the sampler for the next index
            index = self._next_index()
        ...
        # pick an index_queue belonging to an active worker
        for _ in range(self._num_workers):
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break

        # pack the index together with send_idx and put it into the chosen index_queue
        self._index_queues[worker_queue_idx].put((self._send_idx, index))

        # update the members that preserve data ordering:
        # an index has been produced and sent, so record it in task_info
        # (len == 1 at this point) and increment send_idx
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._send_idx += 1

The logic is as described in the comments: _try_put_index is called by _process_data; it is responsible for putting the next index into a worker's queue, and it updates task_info and send_idx to keep the ordering bookkeeping consistent. The round-robin worker selection is illustrated below.
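
The round-robin pick of an active worker relies on an itertools.cycle over worker ids (that is what _worker_queue_idx_cycle is); a stripped-down illustration (my own toy code, not the real method):

import itertools

num_workers = 3
worker_queue_idx_cycle = itertools.cycle(range(num_workers))
workers_status = [True, False, True]   # suppose worker 1 has already shut down

# pick the next active worker, skipping inactive ones
for _ in range(num_workers):
    worker_queue_idx = next(worker_queue_idx_cycle)
    if workers_status[worker_queue_idx]:
        break

print(worker_queue_idx)  # 0 on the first pick; later picks continue around the cycle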

Summary

That wraps up the DataLoader part. I wrote up this source code mainly because a project at work recently required modifying Dataset and DataLoader, which pushed me to read this code carefully.

All in all, the Dataset and DataLoader modules provide a common data loading and preprocessing interface for PyTorch, and the code is highly robust. If there is room for improvement, it mainly lies in the I/O path: when shuffling is enabled, the Dataset is read with random I/O, which can greatly reduce I/O throughput compared with sequential reads.

Origin blog.csdn.net/weixin_41670608/article/details/126588396