TensorFlow多线程读取机制

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/MOU_IT/article/details/82797610

1、TensorFlow读取机制图解

    我们必须要把数据先读入后才能进行计算,假设读入用时0.1s,计算用时0.9s,那么就意味着每过1s,GPU都会有0.1s无事可做,这就大大降低了运算的效率。解决这个问题方法就是将读入数据和计算分别放在两个线程中,将数据读入内存的一个队列,如下图所示:

    读取线程源源不断地将文件系统中的图片读入到一个内存的队列中,而负责计算的是另一个线程,计算需要数据时,直接从内存队列中取就可以了。这样就可以解决GPU因为IO而空闲的问题!在tensorflow中,为了方便管理,在内存队列前又添加了一层所谓的“文件名队列”。tensorflow使用文件名队列+内存队列双队列的形式读入文件,可以很好地管理epoch。下面我们用图片的形式来说明这个机制的运行方式:

2、相关函数简介

(1)三个机制:

      Queue是TF队列和缓存机制的实现,它本质上是一个队列;
      QueueRunner是TF中对操作Queue的线程的封装,它本质上是一个线程;
      Coordinator是TF中用来协调线程运行的工具,保存线程组的运行状态;

(2)读取步骤

    步骤:
     1)获取文件名列表list
    2)创建文件名队列,调用tf.train.string_input_producer(),参数包含:文件名列表,num_epochs【定义重复次数】,shuffle【定义是否打乱文件的顺序】
   3)定义对应文件的阅读器:tf.ReaderBase、tf.TFRecordReader 、tf.TextLineReader 、tf.WholeFileReader 、tf.IdentityReader 、tf.FixedLengthRecordReader。
     4)解析器 : tf.decode_csv 、tf.decode_raw 、 tf.image.decode_image 。
     5)预处理,对原始数据进行处理,以适应network输入所需
     6)生成batch,调用tf.train.batch() 或者 tf.train.shuffle_batch()
     7)prefetch【可选】使用预加载队列slim.prefetch_queue.prefetch_queue()
     8)启动填充队列的线程,调用tf.train.start_queue_runners

(3)tf.train.string_input_producer():生成文件名队列

    这个函数需要传入一个文件名list,系统会自动将它转为一个文件名队列。tf.train.string_input_producer还有两个重要的参数,一个是num_epochs,表示epoch数。另外一个就是shuffle是指在一个epoch内文件的顺序是否被打乱。在tensorflow中,内存队列不需要我们自己建立,我们只需要使用reader对象从文件名队列中读取数据就可以了。
    在我们使用tf.train.string_input_producer创建文件名队列后,整个系统其实还是处于“停滞状态”的,也就是说,我们文件名并没有真正被加入到队列中,此时如果我们开始计算,因为内存队列中什么也没有,计算单元就会一直等待,导致整个系统被阻塞。使用tf.train.start_queue_runners之后,才会启动填充队列的线程,这时系统就不再“停滞”。此后计算单元就可以拿到数据并进行计算,整个程序也就跑起来了。

import tensorflow as tf 

with tf.Session() as sess:
    filename = ['A.jpg', 'B.jpg', 'C.jpg']
    filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)
    tf.local_variables_initializer().run()
    threads = tf.train.start_queue_runners(sess=sess)
    i = 0
    while True:
        i += 1
        # 获取图片数据并保存
        image_data = sess.run(value)
        with open('read/test_%d.jpg' % i, 'wb') as f:
            f.write(image_data)

(4) Queue    

    tf.FIFOQueue: 按入列顺序出列的队列
    tf.RandomShuffleQueue: 随机顺序出列的队列
    tf.PaddingFIFOQueue: 以固定长度批量出列的队列
    tf.PriorityQueue: 带优先级出列的队列
  函数原型:tf.FIFOQueue(capacity, dtypes, shapes=None, names=None ...)
    Queue主要包含入列(enqueue)和出列(dequeue)两个操作。enqueue操作返回计算图中的一个Operation节点,dequeue操作返回一个Tensor值。Tensor在创建时同样只是一个定义(或称为“声明”),需要放在Session中运行才能获得真正的数值。

import tensorflow as tf
tf.InteractiveSession()

q = tf.FIFOQueue(2, "float")
init = q.enqueue_many(([0,0],))
x = q.dequeue()
y = x+1
q_inc = q.enqueue([y])
init.run()
q_inc.run()
q_inc.run()
q_inc.run()
x.eval()  # 返回1
x.eval()  # 返回2
x.eval()  # 卡住

(5) QueueRunner

    Tensorflow的计算主要在使用CPU/GPU和内存,而数据读取涉及磁盘操作,速度远低于前者操作。因此通常会使用多个线程读取数据,然后使用一个线程消费数据,QueueRunner就是来管理这些读写队列的线程。

import tensorflow as tf  
import sys  

q = tf.FIFOQueue(10, "float")  
counter = tf.Variable(0.0)  #计数器
# 给计数器加一
increment_op = tf.assign_add(counter, 1.0)
# 将计数器加入队列
enqueue_op = q.enqueue(counter)

# 创建QueueRunner,用多个线程向队列添加数据
# 这里实际创建了4个线程,两个增加计数,两个执行入队
qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2)

# 主线程
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# 启动入队线程
qr.create_threads(sess, start=True)
for i in range(20):
    print (sess.run(q.dequeue()))

    增加计数的进程会不停的后台运行,执行入队的进程会先执行10次(因为队列长度只有10),然后主线程开始消费数据,当一部分数据消费被后,入队的进程又会开始执行。最终主线程消费完20个数据后停止,但其他线程继续运行,程序不会结束。

(6)Coordinator

    用来保存线程组运行状态的协调器对象

import tensorflow as tf
import threading, time

# 子线程函数
def loop(coord, id):
    t = 0
    while not coord.should_stop():
        print(id)
        time.sleep(1)
        t += 1
        # 只有1号线程调用request_stop方法
        if (t >= 2 and id == 1):
            coord.request_stop()

# 主线程
coord = tf.train.Coordinator()
# 使用Python API创建10个线程
threads = [threading.Thread(target=loop, args=(coord, i)) for i in range(10)]

# 启动所有线程,并等待线程结束
for t in threads: t.start()
#join操作经常用在线程当中,其作用是等待某线程结束,其他所有线程关闭之后,这一函数才能返回
coord.join(threads)

   所有的子线程执行完两个周期后都会停止,主线程会等待所有子线程都停止后结束,从而使整个程序结束。由此可见,只要有任何一个线程调用了Coordinator的request_stop()方法,所有的线程都可以通过should_stop()方法感知并停止当前线程。

3、两种使用多线程的方式

(1)第一种,显式的创建QueueRunner,然后调用它的create_threads方法启动线程。例如下面这段代码:

import tensorflow as tf

# 1000个4维输入向量,每个数取值为1-10之间的随机数
data = 10 * np.random.randn(1000, 4) + 1
# 1000个随机的目标值,值为0或1
target = np.random.randint(0, 2, size=1000)

# 创建Queue,队列中每一项包含一个输入数据和相应的目标值
queue = tf.FIFOQueue(capacity=50, dtypes=[tf.float32, tf.int32], shapes=[[4], []])

# 批量入列数据(这是一个Operation)
enqueue_op = queue.enqueue_many([data, target])
# 出列数据(这是一个Tensor定义)
data_sample, label_sample = queue.dequeue()

# 创建包含4个线程的QueueRunner
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)

with tf.Session() as sess:
    # 创建Coordinator
    coord = tf.train.Coordinator()
    # 启动QueueRunner管理的线程
    enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
    # 主线程,消费100个数据
    for step in range(100):
        if coord.should_stop():
            break
        data_batch, label_batch = sess.run([data_sample, label_sample])
    # 主线程计算完成,停止所有采集数据的进程
    coord.request_stop()
    coord.join(enqueue_threads)

(2)第二种,使用全局的start_queue_runners方法启动线程。
    在这个例子中,tf.train.string_input_produecer()将一个隐含的QueueRunner添加到全局图中,类似的操作还有tf.train.shuffle_batch()等)。由于没有显式地返回QueueRunner来用create_threads启动线程,这里用tf.train.start_queue_runners()方法直接启动tf.GraphKeys.QUEUE_RUNNERS集合中的所有队列线程。

import tensorflow as tf

# 同时打开多个文件,显示创建Queue,同时隐含了QueueRunner的创建
filename_queue = tf.train.string_input_producer(["data1.csv","data2.csv"])
reader = tf.TextLineReader(skip_header_lines=1)
# Tensorflow的Reader对象可以直接接受一个Queue作为输入
key, value = reader.read(filename_queue)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # 启动计算图中所有的队列线程
    threads = tf.train.start_queue_runners(coord=coord)
    # 主线程,消费100个数据
    for _ in range(100):
        features, labels = sess.run([data_batch, label_batch])
    # 主线程计算完成,停止所有采集数据的进程
    coord.request_stop()
    coord.join(threads)

参考:https://www.cnblogs.com/demian/p/8005407.html

猜你喜欢

转载自blog.csdn.net/MOU_IT/article/details/82797610