Tensorflow多线程输入数据处理框架

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/fengxianghui01/article/details/82191233

一、多线程

         图像预处理方法可以减少无关因素对图像识别模型效果的影响,但会减慢整个训练过程。为了避免图像预处理成为神经网络模型训练效率的瓶颈,tensorflow提供了一套多线程处理输入数据的框架。

处理流程图:

                                                                                   

         队列不仅是一种数据结构,也提供了多线程机制。tensorflow提供了EnqueueEnqueueManyDequeue三种方式修改队列的状态。Tensorflow提供了FIFOQueueRandomShuffleQueue两种队列。

        队列不仅是一种数据结构,还是异步计算张量取值的一个重要机制,比如多个线程可以同时向一个队列中写元素,或者读取队列中的元素。Tensorflow提供了tf.Coordinatortf.QueueRunner两个类来完成多线程协同功能。

        tf.Coordinator主要协同多个线程一起停止,并提供should_stoprequest_stopjoin三个函数。启动的线程需要一直查询tf.Coordinator提供的should_stop函数,当这个函数返回True时,则当前线程需要退出。每一个启动的线程都可以调用request_stop函数来通知其他线程退出。当某一个线程调用request_stop函数后,should_stop返回值将被设置为True,其他线程可以同时终止。

    tf.QueueRunner主要用于启动多个线程来操作同一个队列。在使用tf.train.QueueRunner时,需要明确调用tf.train.start_queue_runners来启动所有线程。

# 部分代码
with tf.Session() as sess:
    print("start")
    i = 0
    # 开始输入队列监控,启动多线程处理数据
    coord = tf.train.Coordinator() # 
    threads = tf.train.start_queue_runners(coord = coord) # 启动入队线程
    
    try:
        while not coord.should_stop() and i<1:
            
            img, label = sess.run([image_batch, label_batch])# 输入list结构
            
            # just test one batch
            # arange返回一个array对象([ ])
            for j in np.arange(BATCH_SIZE):
                print('label: %d'%label[j])
                plt.imshow(img[j,:,:,:]) 
                plt.show()
            i += 1
    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        print('finished')
        coord.request_stop() # 通知其它线程关闭
    coord.join(threads) # 其他线程关闭之后,这一函数才能返回

        

二、输入文件队列

        Tensorflow提供创建队列的两种方式:tf.train.string_input_producer()tf.train.slice_input_producer()。两种方式的区别和使用参考:https://blog.csdn.net/qq_30666517/article/details/79715045

扫描二维码关注公众号,回复: 3448019 查看本文章

 三、组合训练数据

        Tensorflow提供了tf.train.batchtf.train.shuffle_batch函数来将单个的样例组织成batch的形式输出。这俩个函数都会生成一个队列,队列的入队操作是生成单个样例的方法, 而每次出队得到的是一个batch的样例。唯一的区别是是否会将数据打乱。

# 示例代码
image_batch, label_batch = tf.train.batch([image, label],
                                          batch_size = batch_size,
                                          num_threads = 64,
                                          capacity = capacity) 

两个函数除了可以将单个训练数据整理成输入batch之外,也提供了并行化处理输入数据的方法

综合实例:

    通过一个函数来进行说明,数据的预处理及线程运用

#%%
import tensorflow as tf
import numpy as np
import os

# 
img_width = 208
img_height = 208


#%% 获取图片 及 生成标签
train_dir = 'G:/tensorflow/cats_vs_dogs/data/train/'

def get_files(file_dir):
    '''
    args:
        file_dir: file directory
    Returns:
        ist of images and labels
    '''
    cats = []
    label_cats = []
    dogs = []
    label_dogs = []
    for file in os.listdir(file_dir): # 获取当前目录下的所有文件和目录名
        name = file.split('.') #分割字符段,返回name为一个列表
        if name[0] == 'cat':
            cats.append(file_dir + file)
            label_cats.append(0)
        else:
            dogs.append(file_dir + file)
            label_dogs.append(1)
    print('There are %d cats \nThere are %d dogs' %(len(cats), len(dogs)))
    
    image_list = np.hstack((cats, dogs)) ## 将图像堆叠在一起
    label_list = np.hstack((label_cats, label_dogs)) ## 将图像标签堆叠在一起
    
    temp = np.array([image_list, label_list]) # 将文件名和标签对应起来
    temp = temp.transpose() #矩阵转置
    np.random.shuffle(temp) # 打乱存放的顺序
    
    # 先集合起来打乱在分开的目的是为了获取打乱后的图形及其对应的标签
    image_list = list(temp[:, 0]) # 获取图片
    label_list = list(temp[:, 1]) # 获取标签
    label_list = [float(i) for i in label_list]
    
    return image_list, label_list

#%%
# 对图片进行裁剪
def get_batch(image, label, image_W, image_H, batch_size, capacity):
    '''
    args:
        image: list type
        label: list type
        image_W: image_width
        image_H: image_Height
        batch_size:batch size #每批次的图像量
        capacity: the maxmum elements in queue
    Returns:
        image_batch: 4D tensor [batch_size, width, height, 3],dtype=tf.float32
        label_batch: 1D tensor [batch_size], dtype = tf.float32
    '''
    # 类型转换函数,返回张量
    image = tf.cast(image, tf.string) # 数据类型转换 image->string
    label = tf.cast(label, tf.int32)  # 数据类型转换 label->int32
    
    # make an input queue 生成输入对列
    input_queue = tf.train.slice_input_producer([image, label])
    
    label = input_queue[1] # 读取标签
    image_contents = tf.read_file(input_queue[0]) # 读取图像 string类型
    image = tf.image.decode_jpeg(image_contents, channels = 3) #解码

    ########################################
    # data argumentatioan should go to here
    ########################################
    # 对图片进行裁剪或扩充【在图像中心处裁剪】,统一大小
    image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
    # 数据标准化 训练前需要对数据进行标准化
    image = tf.image.per_image_standardization(image) 
    # 生成批次 在输入的tensor中创建一些tensor数据的batch
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size = batch_size,
                                              num_threads = 64,
                                              capacity = capacity) 
    # 重新生成大小,即将label_batch变换成[batch_size]行的形式
    label_batch = tf.reshape(label_batch, [batch_size]) 
    
    return image_batch, label_batch
    
#%% test :  matplotlib.pyplot绘图 绘制直线、条形/矩形区域

import matplotlib.pyplot as plt

BATCH_SIZE = 5 # 批次中的图像数量
CAPACITY = 256 # 队列中最多容纳元素的个数
IMG_W = 208
IMG_H = 208

train_dir = 'data/train/'

image_list, label_list = get_files(train_dir)
image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H,
                                    BATCH_SIZE, CAPACITY)

with tf.Session() as sess:
    print("start")
    i = 0
    # 开始输入队列监控,启动多线程处理数据
    coord = tf.train.Coordinator() # 
    threads = tf.train.start_queue_runners(coord = coord) # 启动入队线程
    
    try:
        while not coord.should_stop() and i<1:
            
            img, label = sess.run([image_batch, label_batch])# 输入list结构
            
            # just test one batch
            # arange返回一个array对象([ ])
            for j in np.arange(BATCH_SIZE):
                print('label: %d'%label[j])
                plt.imshow(img[j,:,:,:]) 
                plt.show()
            i += 1
    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        print('finished')
        coord.request_stop() # 通知其它线程关闭
    coord.join(threads) # 其他线程关闭之后,这一函数才能返回

#%%

猜你喜欢

转载自blog.csdn.net/fengxianghui01/article/details/82191233
今日推荐