Tensorflow 数据预读取--Queue

Tensorflow 数据预读取–Queue

本文大部分转载于:https://blog.csdn.net/wangjian1204/article/details/54728603

Google开源的深度学习框架Tensorflow在数据预取上做了一些特殊的特征来提高模型训练或者推理的效率,避免在IO上耗费过多的时间。本文通过几个简单例子介绍Tensorflow构建queue常用函数的使用方法。

深度学习训练模型通常是建立在大数据基础上,一般情况下可以把数据都加载到内存避免训练时数据读取IO。但是,当数据占用空间较大,如图片集或者视频集,无法全部载入内存;另一种方式是在训练时再读取需要的数据,但是增加的IO耗时会让模型训练过程很漫长很漫长。

Tensorflow提供了Queue这个工具来更好的解决这类问题。Queue构建了一个大小为capacity的缓存区,多线程执行数据的enqueue,神经网络模型从缓存区dequeue数据。如果capacity足够大,数据的加载和读取可以同时执行,没有阻塞,从而IO的时间几乎可以忽略不计。

slice_input_producer

过程描述:图片数据保存在本地,内存中保存所有图片的系统路径,现在构建Queue,从磁盘上读取并缓存数据。整个过程类似于:

这里写图片描述

def slice_input_producer_demo(image_pair_path, summary_path):
    # 重置graph
    tf.reset_default_graph() 
    # 获取<图片一系统路径,图片二系统路径,标签信息>三个list(load_data函数见supplementary)
    image_one_path_list, image_two_path_list, label_list = load_data()
    ## 构造数据queue
    train_input_queue = tf.train.slice_input_producer([image_one_path_list, 
image_two_path_list, label_list], capacity=10 * batch_size)

    ## queue输出数据
    img_one_queue = get_image(train_input_queue[0])
    img_two_queue = get_image(train_input_queue[1])
    label_queue = train_input_queue[2]

    ## shuffle_batch批量从queu批量读取数据
    batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch([img_one_queue, img_two_queue, label_queue],batch_size=batch_size,capacity =  10 + 10* batch_size,min_after_dequeue = 10,num_threads=16,shapes=[(image_width, image_height, image_channel),(image_width, image_height, image_channel),()])

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())

    summary_writer = tf.train.SummaryWriter(summary_path, graph_def=sess.graph)

    ## 启动queue线程
    coord = tf.train.Coordinator()  
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)  

    for i in range(10):
        batch_img_one_val, batch_img_two_val, label = sess.run([batch_img_one, batch_img_two,batch_label])
        for k in range(batch_size):
            fig = plt.figure()
            fig.add_subplot(1,2,1)
            plt.imshow(batch_img_one_val[k])
            fig.add_subplot(1,2,2)
            plt.imshow(batch_img_two_val[k])
            plt.show()


    coord.request_stop()  
    coord.join(threads)  
    sess.close()
    summary_writer.close()

整个过程很清晰,主要由以下几步组成:
1、图片的路径和标记信息载入内存:image_one_path_list, image_two_path_list, label_list = load_data()
2、构造第一个queue:train_input_queue = tf.train.slice_input_producer( [image_one_path_list, image_two_path_list, label_list], capacity=10 * batch_size)
3、从queue取出图片路径数据加载图片:img_one_queue = get_image(train_input_queue[0])
4、构造第二个queue:shuffle_queue,把图片数据enqueue到缓存区,批量dequeue输出结果。batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch([img_one_queue, img_two_queue, label_queue]…)

string_input_producer

string_input_producer从一个pipeline把字符串输出到一个queue。

def string_input_producer_demo(image_pair_path, summary_path):
    tf.reset_default_graph()

    image_one_path_list, image_two_path_list, label_list = load_data()
    ## 构造数据queue
    train_input_queue = tf.train.string_input_producer(image_one_path_list, capacity=10 * batch_size)

    ## queue输出数据
    img_one_queue = get_image(train_input_queue.dequeue())

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    summary_writer = tf.train.SummaryWriter(summary_path, graph_def=sess.graph)

    ## queue线程
    coord = tf.train.Coordinator()  
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)  

    for i in range(10):
        img_one_val = sess.run([img_one_queue])
        fig = plt.figure()
        plt.imshow(img_one_val[0])
        plt.show()


    coord.request_stop()  
    coord.join(threads)  
    sess.close()
    summary_writer.close()

range_input_producer:生成0到limit-1的queue

def range_input_producer_demo(image_pair_path, summary_path):
    tf.reset_default_graph()

    image_one_path_list, image_two_path_list, label_list = load_data()
    length_data = len(image_one_path_list)

    image_one_path_list = tf.convert_to_tensor(image_one_path_list)
    image_two_path_list = tf.convert_to_tensor(image_two_path_list)
    label_list = tf.convert_to_tensor(label_list)

    ## 构造数据queue
    train_input_queue = tf.train.range_input_producer(length_data, capacity=10 * batch_size)

    ## queue输出数据
    range_index = train_input_queue.dequeue()
    img_one_queue = get_image(tf.gather(image_one_path_list, range_index))
    img_two_queue = get_image(tf.gather(image_two_path_list, range_index))
    label_queue = range_index 

    ## 批量从queu读取数据
    batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch([img_one_queue, img_two_queue, label_queue],batch_size=batch_size,capacity =  10 + 10* batch_size,min_after_dequeue = 10,num_threads=16,shapes=[(image_width, image_height, image_channel),(image_width, image_height, image_channel),()])

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())

    summary_writer = tf.train.SummaryWriter(summary_path, graph_def=sess.graph)

    ## queue线程
    coord = tf.train.Coordinator()  
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)  

    for i in range(10):
        batch_img_one_val, batch_img_two_val, label = sess.run([batch_img_one, batch_img_two,batch_label])
        for k in range(batch_size):
            fig = plt.figure()
            fig.add_subplot(1,2,1)
            plt.imshow(batch_img_one_val[k])
            fig.add_subplot(1,2,2)
            plt.imshow(batch_img_two_val[k])
            plt.show()


    coord.request_stop()  
    coord.join(threads)  
    sess.close()
    summary_writer.close()

input_producer:input_tensor里的行构成queue

def input_producer_demo(image_pair_path, summary_path):
    tf.reset_default_graph()

    image_one_path_list, image_two_path_list, label_list = load_data()
    length_data = len(image_one_path_list)

    image_one_path_list = tf.convert_to_tensor(image_one_path_list)

    ## 构造数据queue
    train_input_queue = tf.train.input_producer(image_one_path_list, capacity=10 * batch_size)

    ## Expected string, got <tensorflow.python.ops.data_flow_ops.FIFOQueue object of type 'FIFOQueue' instead.
    img_one_queue = get_image(train_input_queue.dequeue())

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())

    summary_writer = tf.train.SummaryWriter(summary_path, graph_def=sess.graph)

    ## queue线程
    coord = tf.train.Coordinator()  
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)  

    for i in range(10):
        batch_img_one_val = sess.run([img_one_queue])

#         for k in range(batch_size):
        print(len(batch_img_one_val))
        fig = plt.figure()
        plt.imshow(batch_img_one_val[0])
        plt.show()


    coord.request_stop()  
    coord.join(threads)  
    sess.close()
    summary_writer.close()

supplementary

数据格式:
    /home/Alex/4000.jpg /home/Alex/4001.jpg 0 
    /home/Alex/4000.jpg /home/Alex/4002.jpg 1
# 获取《图片一本地路径,图片二本地路径,标记》数据对
def load_data():
    reader_handler = open(image_pair_path, 'r')

    image_one_path_list = []
    image_two_path_list = []
    label_list = []

    count = 0
    for line in reader_handler:
        count = count + 1
        elems = line.split("\t")
        if len(elems) < 3:
            print("len(elems) < 3:" + line)
            continue
        image_one_path = elems[0].strip()
        image_two_path = elems[1].strip()
        label = int(elems[2].strip())

        image_one_path_list.append(image_one_path)
        image_two_path_list.append(image_two_path)
        label_list.append(label)

    return image_one_path_list, image_two_path_list, label_list


# 根据图片路径读取图片
def get_image(image_path):  
    """Reads the jpg image from image_path. 
    Returns the image as a tf.float32 tensor 
    Args: 
        image_path: tf.string tensor 
    Reuturn: 
        the decoded jpeg image casted to float32 
    """  
    content = tf.read_file(image_path)
    tf_image = tf.image.decode_jpeg(content, channels=3)

    return tf_image

猜你喜欢

转载自blog.csdn.net/chaowang1994/article/details/80281410