深度之眼Pytorch打卡(七):Pytorch数据读取机制,DataLoader()和Dataset

前言


   无论是模型训练还是实际测试,数据读取都是第一步,因为深度学习说到底是由数据驱动的。如果有能够准确的读入数据的程序,后面再结合预处理、Pytorch中预训练的模型和优化器等就可以构建一个简易的用于预测的模型了。Pytorch的数据读取的核心是DataLoader方法和Dataset类。本笔记的框架主要来源于深度之眼,并作了一些相关的拓展,拓展内容主要源自对torch文档的翻译理解。

   数据切分:深度之眼Pytorch打卡(六):将数据集切分成训练集、验证集和测试集的方法


Dataset类


   Dataset是一个代表数据的抽象类,定义数据从哪里读取以及如何读取。数据从哪里读取?当然是从硬盘中读取,通过给Dataset传入一个路径参数来实现的。数据如何读取?读取方式需要我们自定义,不同的数据集划分方式我们有不同的读取方法。

class DataSet(Dataset):

   pass

   我们需要自定义的代表数据集的类都需要继承Dataset这个抽象类。然后实例化我们创建的Dataset子类就可以用来代表训练集,验证集和测试集数据。每个Dataset的子类,都必须要复写__getitem__()方法,常常还选着性的复写__len__()方法和__init__()方法。

   基本框架如代码所示:

class Dataset(Dataset):

	def __init__(self):
		...
		
	def __getitem__(self, index):
		return ...
	
	def __len__(self):
		return ...
  • 方法实现与实例化

  __init__(self)用于添加类自身的一些属性,如标签、数据信息和是否数据增强等。__len__(self)用于返回数据集的大小。__getitem__(self, index)用于接收一个索引index,并返数据集中对应的数据与标签,是读取数据的核心,index由DataLoader()中的sampler类产生。

  上一篇笔记中第一种数据集划分方式数据读取实现:

import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

label_name = {
    
    'ants': 0, 'bees': 1}


class DataSet(Dataset):
    def __init__(self, data_path):  # 除了这两个参数之外,还可以设置其它参数
        self.label_name = {
    
    'ants': 0, 'bees': 1}
        self.data_info = get_info(data_path)

    def __getitem__(self, index):
        label, img_path = self.data_info[index]
        pil_img = Image.open(img_path).convert('RGB')  # 读数据
        re_img = transforms.Resize((32, 32))(pil_img)
        img = transforms.ToTensor()(re_img)  # PIL转张量
        return img, label

    def __len__(self):
        return len(self.data_info)


def get_info(data_path):
    data_info = list()
    for root_dir, sub_dirs, _ in os.walk(data_path):
        for sub_dir in sub_dirs:
            file_names = os.listdir(os.path.join(root_dir, sub_dir))
            img_names = list(filter(lambda x: x.endswith('.jpg'), file_names))
            for i in range(len(img_names)):
                img_path = os.path.join(root_dir, sub_dir, img_names[i])
                img_label = label_name[sub_dir]
                data_info.append((img_label, img_path))

    return data_info


if __name__ == '__main__':

    train_set_path = os.path.join('data', 'train_set')
    val_set_path = os.path.join('data', 'val_set')
    test_set_path = os.path.join('data', 'test_set')
    train_set = DataSet(data_path=train_set_path)
    val_set = DataSet(data_path=val_set_path)
    test_set = DataSet(data_path=test_set_path)

图1.第一种方式调试结果

  断点调试结果:
在这里插入图片描述
  上一篇笔记中第二种数据集划分方式数据读取实现:主要是get_info()函数的区别。

import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

label_name = {
    
    'ants': 0, 'bees': 1}


class DataSet(Dataset):
    def __init__(self, data_path):  # 除了这两个参数之外,还可以设置其它参数
        self.label_name = {
    
    'ants': 0, 'bees': 1}
        self.data_info = get_info_list(data_path)

    def __getitem__(self, index):
        label, img_path = self.data_info[index]
        pil_img = Image.open(img_path).convert('RGB')  # 读数据
        re_img = transforms.Resize((32, 32))(pil_img)
        img = transforms.ToTensor()(re_img)  # PIL转张量
        return img, label

    def __len__(self):
        return len(self.data_info)

def get_info_list(list_path):
    data_info = list()
    with open(list_path, mode='r') as f:
        lines = f.readlines()
        for i in range(len(lines)):
            img_label = int(lines[i].split(' ')[0])
            img_path = lines[i].split(' ')[1]
            data_info.append((img_label, img_path))
    return data_info


if __name__ == '__main__':

    train_list_path = os.path.join('old_data', 'train_set.txt')
    val_list_path = os.path.join('old_data', 'val_set.txt')
    test_list_path = os.path.join('old_data', 'test_set.txt')
    train_set = DataSet(data_path=train_list_path)
    val_set = DataSet(data_path=val_list_path)
    test_set = DataSet(data_path=test_list_path)

  断点调试结果:
在这里插入图片描述

图2.第二种方式调试结果

DataLoader方法


   DataLoader()方法,在给定数据集上提供可迭代的数据加载,即模型每进行一次迭代,就从DataLoader()中获取一个batch_size的数据。其函数形式与参数如下述代码所示。

torch.utils.data.DataLoader(
                                  dataset,
                                  batch_size=1, 
                                  shuffle=False, 
                                  sampler=None, 
                                  batch_sampler=None, 
                                  num_workers=0, 
                                  collate_fn=None, 
                                  pin_memory=False, 
                                  drop_last=False, 
                                  timeout=0, 
                                  worker_init_fn=None, 
                                  multiprocessing_context=None)

   dataset: 继承于上述抽象类Dataset的子类的实例,比如训练数据类,验证集类等。
   batch_size: 批大小,即进行一次迭代的数据大小,默认为1
   shuffle: 设置每个epoch中,样本的顺序是否乱序,默认为False。所有训练样本输入到模型中一次,称为一个epoch。根据训练数据大小和batch_size大小,就可以计算出一个epoch需要进行多少次迭代。
   sampler: 定义从数据集中提取样本的策略,即生成index的方式,可以顺序也可以乱序。上述在Dataset实例中,复写的__getitem__(self, index)中的index就是由这个sampler类产生的。详细解释见这篇博文
   batch_sampler: 一次返回一个batch数据的index,即将sampler生成的indices打包分组,得到一个又一个batch的index。
  num_workers: 读取数据是否采用多进程,默认0,即在主进程中读取数据。
  collate_fn: 将一个batch的数据和标签进行合并操作。
  drop_last: 设置为True时,如果数据集大小不能被batch_size整除,那么删除最后一个不完整的batche。设置为False,且数据集的大小不能被batch_size整除,那么最后一个batch将更小一些。

  • DataLoader使用

  将如下代码加在Dataset实例化后即可实现DataLoader的简单使用。读取训练集,批大小为10,随机读取。

    train_loader = DataLoader(dataset=train_set, batch_size=2, shuffle=True)

    for i, data in enumerate(train_loader):
        inputs, labels = data

  然后在此处设断点,debug,单步调试,观察DataLoader的运行情况。
在这里插入图片描述

图3.断点位置

  首先执行的是DataLoader的初始化工作,比如依次进入DataLoader(object)类的__iter__(self)方法,选择单进程还是多进程的DataLoader迭代器,然后默认进入单进程的DataLoader迭代器类_SingleProcessDataLoaderIter(_BaseDataLoaderIter)初始化方法__init__(self, loader),然后进入_BaseDataLoaderIter(object)类的初始化方法 init(self, loader)等。

class DataLoader(object):
# ......
    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
class _BaseDataLoaderIter(object):
    def __init__(self, loader):
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
        self._num_yielded = 0

  DataLoader的初始化完成之后,便开始读数据。首先进入的是_BaseDataLoaderIter(object)def __next__(self)方法。然后跳转到 _SingleProcessDataLoaderIter类的 _next_data(self)方法,然后进入BatchSampler(Sampler)采样器类的__iter__(self)方法,就是在这里生成一个batch数据的所有索引值,并放到一个列表中,如图4所示。再次进入_SingleProcessDataLoaderIter类的 _next_data(self)方法,把生成的index给_dataset_fetcher.fetch(index) 。最后跳到我们自定义的DataSet(Dataset):类的__getitem__(self, index),一次给它一个索引,循环batch_size次。循环完成之后默认将进入default_collate(batch)方法,整合这个batch的数据。

class BaseDataLoaderIter(object):

    def __next__(self):
        data = self._next_data()
        # ......
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
     # ......
     
    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
class BatchSampler(Sampler):
     # ......
     
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

在这里插入图片描述

图4.生成一个batch的index
def default_collate(batch):

在这里插入图片描述
  小结:由以上的分析,我们可以知道,其实数据读取完全依赖于我们自定义的DataSet类。DataLoader()方法,则主要是用于迭代的,我们每进行一次迭代,会调用BatchSampler,产生一个batch数据的索引。返回索引,再调用DataSet的__getitem__(self, index)获取索引在数据集中对应的图片和标签,获取了一个batch数据的图片和标签之后,打包返回。这样就实现了一个batch数据的读取操作。


参考


  https://blog.csdn.net/qq_31622015/article/details/90573874
  https://www.cnblogs.com/marsggbo/p/11308889.html
  https://www.cnblogs.com/jiaxin359/p/7324077.html
  https://ww1··w.cnblogs.com/jiaxin359/p/7324077.html

猜你喜欢

转载自blog.csdn.net/sinat_35907936/article/details/105636697