Pytorch进行自定义Dataset 和 Dataloader 原理

目录

1、自定义加载数据

2、重写 Dataset 类

2.1、Pytorch自定义Dataset的步骤:

3、Dataloader

3.1、什么是 pin_memory

3.2、Dataloader 的多进程读数据细节

3.3、Pytorch Dataloader加速


1、自定义加载数据

在pytorch中,数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset,并实现相应的方法。

在学习Pytorch的教程时,加载数据许多时候都是直接调用torchvision.datasets里面集成的数据集,直接在线下载,然后使用torch.utils.data.DataLoader进行加载。
那么,我们怎么使用我们自己的数据集,然后用DataLoader进行加载呢?

扫描二维码关注公众号,回复: 15143961 查看本文章

常见的两种形式的导入:

1.1、一种是整个数据集都在一个文件下,内部再另附一个label文件,说明每个文件的状态。这种存放数据的方式可能更时候在非分类问题上得到应用。下面就是我们经常使用的数据存放方式。

1.2、一种则是更适合在分类问题上,即把不同种类的数据分为不同的文件夹存放起来。这样,我们可以从文件夹或文件名得到label。使用torchvision.datasets.imageFolder函数生成数据集。这种方式没有用过,暂时不介绍了

2、重写 Dataset 类

2.1、Pytorch自定义Dataset的步骤:

官方:torch.utils.data.Dataset 是一个抽象类,

def __getitem__(self, index):
	raise NotImplementedError

def __len__(self):
	raise NotImplementedError

用户想要加载自定义的数据只需要继承这个类(torch.util.data.Dataset),并且覆写__len__ 和 __getitem__两个方法, 不覆写这两个方法会直接返回错误因此步骤如下:

  1. 继承torch.util.data.Dataset
  2. __init__:改写__init__函数时,需要添加对父类的初始化,该方法主要就是一些参数初始化工作,定义一些路径或者变量什么的
  3. __getitem__: 该方法是加载数据用的,用于读取每一条数据,他会有一个参数idx,就是对应的索引,可以用来获取一些索引的数据,使dataset [i] 返回数据集中第i个样本。
  4. __len__:实现len(dataset),返回整个数据集的大小

建立的自定义类如下:

# 加载数据集,自己重写DataSet类
class dataset(Dataset):
    # image_dir为数据目录,label_file,为标签文件
    def __init__(self, image_dir, label_file, transform=None):
        super(dataset, self).__init__()    # 添加对父类的初始化
        self.image_dir = image_dir         # 图像文件所在路径
        self.labels = read(label_file)     # 图像对应的标签文件, read label_file之后的结果
        self.transform = transform         # 数据转换操作
        self.images = os.listdir(self.image_dir )#目录里的所有img文件
    
    # 加载每一项数据
    def __getitem__(self, idx):
        image_index = self.images[index]    #根据索引index获取该图片
        img_path = os.path.join(self.image_dir, image_index) #获取索引为index的图片的路径名    
        labels = self.labels[index]   # 对应标签

        image = Image.open(img_name)
        if self.transform:
            image = self.transform(image)
        # 返回一张照片,一个标签
        return image, labels
    
    # 数据集大小
    def __len__(self):
        return (len(self.images))

设置好数据类之后,我们就可以将其用torch.utils.data.DataLoader加载,并访问它。

if __name__=='__main__':
    data = AnimalData(img_dir_path, label_file, transform=None)#初始化类,设置数据集所在路径以及变换
    dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据
    for i_batch,batch_data in enumerate(dataloader):
        print(i_batch)#打印batch编号
        print(batch_data['image'].size())#打印该batch里面图片的大小
        print(batch_data['label'])#打印该batch里面图片的标签

其实Dataset类不局限于这么写,它可以实现多种数据读取方法,只需要把读取数据以及数据处理逻辑写在__getitem__方法中即可,然后将处理好后的数据以及标签返回即可。

3、Dataloader

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

参数解释:

  1. dataset(Dataset): 传入的数据集
  2. batch_size(int, optional): 每个batch有多少个样本
  3. shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序
  4. sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False
  5. batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)
  6. num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
  7. collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数
  8. pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
  9. drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…
  10. 如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
  11. timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0
  12. worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

Dataloader的处理逻辑是先通过Dataset类里面的 __getitem__ 函数获取单个的数据,然后组合成batch,再使用collate_fn所指定的函数对这个batch做一些操作,比如padding啊之类的。

因为dataloader是有batch_size参数的,我们可以通过自定义collate_fn=myfunction来设计数据收集的方式,意思是已经通过上面的Dataset类中的__getitem__函数采样了batch_size数据,以一个包的形式传递给collate_fn所指定的函数。

dataloader 对于数据的读取延迟主要取决于num_workerspin_memory这两个参数。首先,我先介绍一下比较简单的 pin_memory 参数。

3.1、什么是 pin_memory

所谓的 pin_memory 就是锁页内存的意思。

计算机为了运行进程会先将进程和数据读到内存里。一般来说,计算机的内存都是比较小的,很难存的下太多的数据。但是,某个进程在某个时间段所需的进程和数据往往是比较少的,也就是说在某个时间点我们不需要将一个进程所需要的所有资源都放在内存里。我们可以将这些暂时用不到的数据或进程存放在硬盘一个被称为虚拟内存的地方。在进程运行的时候,我们可以不断交换内存和虚拟内存的数据以减少内存所需存储的数据。而且这些交换往往是通过某些规律预测下个时刻进程会用到的数据和代码并提前交换至内存的,这些规律的使用以及预测的准确性将会影响到进程的速度。

所谓的锁页内存就是说,我们不允许系统将某些内存里的数据交换至虚拟内存,毋庸置疑这将会提升进程的运行速度。但是也会是内存的存储占用消耗很多。

pin_memory 为 true 的时候速度的提升会有多大

3.2、Dataloader 的多进程读数据细节

Dataloader 多进程读取数据的参数是通过num_workers指定的,num_workers 为 0 的话就用主进程去读取数据,num_workers 为 N 的话就会多开 N 个进程去读取数据。这里的多进程是通过 python 的 multiprocessing module 实现的(其实 pytorch 在 multiprocessing 又加了一个 wraper 以实现shared memory)。

关于 num_workers的工作原理:

  1. 开启num_workers个子进程(worker)。
  2. 每个worker通过主进程获得自己需要采集的ids。
    ids的顺序由采样器(sampler)或shuffle得到。然后每个worker开始采集一个batch的数据。(因此增大num_workers的数量,内存占用也会增加。因为每个worker都需要缓存一个batch的数据)
  3. 在第一个worker数据采集完成后,会卡在这里,等着主进程把该batch取走,然后采集下一个batch。
  4. 主进程运算完成,从第二个worker里采集第二个batch,以此类推。
  5. 主进程采集完最后一个worker的batch。此时需要回去采集第一个worker产生的第二个batch。如果第一个worker此时没有采集完,主线程会卡在这里等。(这也是为什么在数据加载比较耗时的情况下,每隔num_workers个batch,主进程都会在这里卡一下。)

 所以:

  • 如果内存有限,过大的num_workers会很容易导致内存溢出。
  • 可以通过观察是否每隔num_workers个batch后出现长时间等待来判断是否需要继续增大num_workers。如果没有明显延时,说明读取速度已经饱和,不需要继续增大。反之,可以通过增大num_workers来缓解。
  • 如果性能瓶颈是在io上,那么num_workers超过(cpu核数*2)是有加速作用的。但如果性能瓶颈在cpu计算上,继续增大num_workers反而会降低性能。(因为现在cpu大多数是每个核可以硬件级别支持2个线程。超过后,每个进程都是操作系统调度的,所用时间更长)

Dataloader 读数据的整个流程:

  1. 首先每个 worker 的进程会拥有一个 index_queue,dataloader 初始化的时候,每个 worker 的 index_queue 会放入两个batch 的 index。index 的放入是根据 worker 的 id顺序放入的。
  2. 每个 worker 的进程会不断检查自己的 index_queue 里有没有值,没有的话就继续检查。有的话,就去读一个 batch(这个读的过程是通过调用 dataset 的get_item()实现的,并通过函数将数据合并为一个 batch)。放入所有 worker 共享的 data_queue(如果指定了 pin_memory,这个新加的 batch 是会被放入 pin_memory 的)
  3. Dataloader 会返回一个迭代器,每迭代一次,首先进程会检查这次要 load 的 idx 数据是不是之前已经 load 过了(已经从共享的 data_queue 里取出来了),并事先放在一个字典里存起来了(为什么会 load 过,下面会解释),如果是的话,就直接拿来用。 如果没有 load 过,就从 data_queue 获取下一个 batch 和相应的 idx,但是这里从 data_queue 获得的 batch 可能不是按顺序的,因为有的 worker 可能比较快提前将它的数据读好放到 data_queue 里了。这时候我们将这个提前来的 batch 先保存到 self.reorder_dict 这个字典里面,这就解释了上面为什么会出现 load 过的问题。如果一直等不到我们就会一直将提前来的 batch 放入 self.reorder_dict 暂存,直至我们等到那个按顺序来的 batch。
  4. 在每次迭代成功的时候,dataloader 会放入一个新的 batch_index 到特定 worker 的 index_queue 里面

可以看出,dataloader 只会在每次迭代成功的时候才会放入新的 index 到 index_queue 里面。因为上面写了在初始化 dataloader 的时候,我们一共放了 2 x self.num_workers 个 batch 的 index 到 index_queue。读了一个 batch 才会放新的 batch,所以这所有的 worker 进程最多缓存的 batch 数量就是 2 x self.num_workers 个。

 以上流程的如果想看代码可以参考:​​​​​​Pytorch Dataloader 学习笔记 · 大专栏

3.3、Pytorch Dataloader加速

Pytorch Dataloader加速_cwpeng.cn的博客-CSDN博客_dataloader 加速

猜你喜欢

转载自blog.csdn.net/ytusdc/article/details/128517308
今日推荐