自己造轮子:深度学习dataloader自己实现

自己造轮子:深度学习dataloader自己实现

**摘要:**因为计算机性能的限制,所有的深度学习框架都是采用批量随机梯度下降,所以每次计算都要读取batch_size的数据。这里以自己实现的方式介绍深度学习框架实现批量读取数据的原理,不涉及具体细节和一些逻辑,只注重大体流程和原理。

总体流程:

  • 采用yield写一个生成器函数实现批量图片/标注信息的读取
  • 采用multiprocessing/threading加速文件读取
  • 时间对比

深度学习大体流程

for i in range(epoch):
    data, lable = dataloader.next(batch_size=16)         # 读取batch_size的数据
    output = model(data)            # 前向传播
    loss = crition(output, label)   # 求损失函数
    loss.backward()                 # 反向传播

在dataloader的时候,一般会采用多个进程(num_workers
)加快文件I/O的速度,避免网络反向传播过了,还没有数据。

1. 用yield写一个生成器函数

# coding:utf-8
# 自己造轮子,实现深度学习批量数据的读取
import os
import glob
import numpy as np 
import cv2  


def get_images(path):
    files = []
    for ext in ['jpg', 'png', 'jpeg', 'JPG']:
        files.extend(glob.glob(
            os.path.join(path, '*.{}'.format(ext))))
    return files


def dataset(batch_size=2, path='/media/chenjun/data/1_deeplearning/7_ammeter_data/test'):
    """
        写一个读取图片的生成器
        batch_size:批量大小
        path:图片路径
    """
    # 1. 读取所有图片名字
    image_list = get_images(path)
    index = np.arange(0, len(image_list))
    while True:
        np.random.shuffle(index)
        images = []
        image_names = []
        for i in index:
            try:
                im_name = image_list[i]
                im = cv2.imread(im_name)    # 读取图片
                # 读取相应图片的标注信息
                # text_polys = fun1()
                images.append(im[:,:, ::-1].astype(np.float32))     # cv2读取图片的顺序为BGR,转换成RGB格式
                image_names.append(im_name)

                if len(images) == batch_size:
                    yield images, image_names        # 采用函数生成器,生成一个可迭代对象
                    images = []
                    image_names = []
            
            except Exception as e:
                import traceback
                traceback.print_exc()
                continue                # 所有图片已经读完一遍,跳出for循环,再打乱图片的顺序进行第二次读取

2. 使用muitlprocessing加速文件读取速度

<!-- 采用正常模式进行图片读取,读取100个batch -->
import time
mydataset = dataset()
start = time.time()
for _ in range(100):
    im, im_name = next(mydataset)
#     print(im_name)
print('use time:{}'.format(time.time() - start))
>>>  use time:0.16786599159240723


<!-- 采用muitlprocessing模式进行图片读取,读取100个batch -->
import multiprocessing
def data_generator(data, q):
    for _ in range(100):                # 循环多少次
        generator_output = next(data)
        q.put(generator_output)

q = multiprocessing.Queue()
start2 = time.time()
thread = multiprocessing.Process(target=data_generator, args=(dataset(), q))
thread.start()              # 多进程开始读取图片
print('mulprocess time is:{}'.format(time.time() - start2))
>>>  mulprocess time is:0.002292633056640625

可以看到读取100个batch,时间提高了80倍。
同时,一般的深度学习框架都会使用几个多进程处理上面的功能。
eg:

for _ in range(workers):
                if self._use_multiprocessing:
                    # Reset random seed else all children processes
                    # share the same seed
                    np.random.seed(self.random_seed)
                    thread = multiprocessing.Process(target=data_generator_task)
                    

网上的资料显示threading的效率没有muitlprocessing高,这里就不测试了。

reference

[1] 莫烦python
[2] argman/EAST

猜你喜欢

转载自blog.csdn.net/u011622208/article/details/84717983
今日推荐