Tensorflow 入门学习7.TensorFlow的队列

版权声明:(谢厂节的博客)博主文章绝大部分非原创,转载望留链接。 https://blog.csdn.net/xundh/article/details/82765006

本节学习资源来自《TensorFlow深度学习应用实践》

TensorFlow 队列

队列(Queue)是一种最为常用的数据输入输出方式,其通过先进先出的线性数据结构,一端只负责增加队列中的数据元素,而数据的输出和删除在队列的另一端实现。

队列的创建

队列的使用和Python中队列的函数类似。

操作 描述
class tf.QueueBase 基本的队列应用类,队列(queue)是一种数据结构,该结构通过多个步骤存储tensors,并且对tensors进行入列(enqueue)与出列(dequeue)操作。
tf.enqueue(vals,name=None) 将一个元素编入该队列中。如果在执行该操作时队列已满,那么将会阻塞起到元素输入队列之中。
tf.enqueue_many(vals,name=None) 将零个或多个元素编入该队列中
tf.dequeue(name=None) 将元素从队列中移出,如果在执行该操作时队列已空,那么将会阻塞元素出殡,返回出殡的tensors的tuple
tf.size(name=None) 计算队列中的元素个数
tf.close 关闭该队列
f.dequeue_up_to(n,name=None) 从该队列中移出n个元素并将之连接
tf.dtypes 列出组成元素的数据类型
tf.from_list(index,queues) 根据queues[index]的参考队列创建一个队列
tf.name 返回队列最下面元素的名称
tf.names 返回队列每一个组成部分的名称
class tf.FIFOQueue 在出殡时依照先入先出顺序
class tf.PaddingFIFOQueue 一个FIFOQueue,同时根据padding支持batching变长的tensor
class tf.RandomShuffleQueue 该队列将随机元素出列

一般而言,创建一个队列首先要选定数据出入类型,例如是使用FIFOQueue函数设定数据为先入先出,还是RandomShuffleQueue这种随机元素出列的方式。

q = tf.FIFOQueue(3,"float")  # 第一个参数为队列中数据的个数,第二个参数是队列中元素的类型

之后要对队列中元素进行初始化和进行操作。

示例:

import tensorflow as tf

with tf.Session() as sess:
    q = tf.FIFOQueue(3, "float")              # 设定一个先进先出队列
    init = q.enqueue_many(([0.1, 0.2, 0.3],))  # 填进数据
    init2 = q.dequeue()                        # 弹出数据   [0.2, 0.3]
    init3 = q.enqueue(1.)                      # [0.2, 0.3, 1]

    sess.run(init)
    sess.run(init2)
    sess.run(init3)

    quelen = sess.run(q.size())                # 获取队列的数据个数
    for i in range(quelen):                   # 循环弹出数据
        print(sess.run(q.dequeue()))

另外,TensorFlow 提供了QueueRunner函数用以解决异步操作问题。其可创建一系列的线程同时进入主线程内进行操作,数据的读取与操作是同步,即主线程在进行训练模型的工作的同时将数据从硬盘读入。

import tensorflow as tf

q = tf.FIFOQueue(1000, "float32")
counter = tf.Variable(0.0)
add_op = tf.assign_add(counter,tf.constant(1.0))
enqueueData_op = q.enqueue(counter)

sess = tf.Session()
qr = tf.train.QueueRunner(q,enqueue_ops=[add_op, enqueueData_op] * 2)
sess.run(tf.initialize_all_variables())
enqueue_threads = qr.create_threads(sess, start=True)    # 启动入队线程

for i in range(10):
    print(sess.run(q.dequeue()))
    

运行会话是正确的,但程序也没有结束,而是被挂起。造成这种情况的原因是add操作和入队操作没有同步,即TensorFlow在队列设计时为了优化IO系统,队列的操作一般使用批处理,这样入队线程没有发送结束的信息而程序主线程期望将程序结束,因此造成线程堵塞程序被挂起。

线程同步与停止

TensorFlow中的会话是支持多线程的,多个线程可以很方便地在一个会话下共同工作,并行地相互执行。但通过上面程序看到,这种同步会造成某个线程想要关闭对话时,对话被强行关闭而未完成工作的线程也被强行关闭。

TensorFlow为了解决多线程的同步和处理问题,提供了Coordinator和QueueRunner函数来对线程进行控制和协调。在使用上,这2个类必须被同时工作,共同协作来停止会话中所有线程,并向等待所有工作线程终止的程序报告。

import tensorflow as tf

q = tf.FIFOQueue(1000, "float32")
counter = tf.Variable(0.0)
add_op = tf.assign_add(counter,tf.constant(1.0))
enqueueData_op = q.enqueue(counter)

sess = tf.Session()
qr = tf.train.QueueRunner(q, enqueue_ops=[add_op, enqueueData_op] * 2)
sess.run(tf.initialize_all_variables())
enqueue_threads = qr.create_threads(sess, start=True)    # 启动入队线程

coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord = coord, start=True)
for i in range(0, 10):
    print(sess.run(q.dequeue()))

coord.request_stop()
coord.join(enqueue_threads)


这里create_threads 函数被添加了一个新的参数:线程协调器,用于协调线程之间的关系,之后启动线程以后,线程协调器在最后负责所有线程的接受和处理,即当一个线程结束时,线程协调器会对所有的线程发出通知,协调其完毕。

从有磁盘读取数据步骤

  1. 从磁盘读取数据的名称与路径
  2. 将文件名堆入列队尾部
  3. 从队列头部读取文件名并读取数据
  4. Decoder将读取的数据解码
  5. 将数据输入样本队列,供后续使用。

猜你喜欢

转载自blog.csdn.net/xundh/article/details/82765006