【多线程管理】TensorFlow线程管理器——Coordinator

TensorFlow线程管理器——Coordinator

  简介

       再现是过程中,有效的运行多个线程可能会更加复杂。线程应该能够正常停止(避免“僵尸”线程,或当一个线程出故障时,要一起关闭所有线程),停止后需要关闭线程队列,还有很多的问题需要解决。
       
        TensorFlow 配备了工具帮我们完成这个过程。 CoordinatorTensorFlow 的一个多线程管理器,用来管理在 Session 中的多个线程,可以用来同时停止多个工作线程。除此之外,它还能向等待所有工作线程终止的程序报告异常,该线程捕获到异常后就会终止所有线程。
       

  代码示例

       可以通过下面的代码学习 Coordinator 的使用方法:

import tensorflow as tf
import threading
import time

gen_random_normal = tf.random_normal(shape=())
queue = tf.FIFOQueue(capacity=100, dtypes=[tf.float32], shapes=())
enque = queue.enqueue(gen_random_normal)

with tf.Session() as sess:
    def add(coord, i):
        try:
            while not coord.should_stop():
                sess.run(enque)
                if i == 9:
                    # 抛出异常做测试
                    raise Exception("Invalid id: !", i)
        except Exception:
            print("Invalid id: !", i)
        finally:
            # 请求关闭所有线程
            coord.request_stop()


    coord = tf.train.Coordinator()
    threads = [threading.Thread(target=add, args=(coord, i)) for i in range(10)]
    coord.join(threads)

    for t in threads:
        t.start()

    print(sess.run(queue.size()))
    time.sleep(0.003)
    print(sess.run(queue.size()))
    time.sleep(0.01)
    print(sess.run(queue.size()))

    #把开启的线程加入主线程,等待threads结束
    coord.join(threads)
    print('所有进程终止!')

  测试结果

	0
	Invalid id: ! 9
	30
	33
	所有进程终止!

       

  使用QueueRunner,RandomShuffleQueue改进

       虽然我们可以创建多个线程重复运行入队操作,但最好使用内置的 tf.train.QueueRunner ,这会完成我们需要偶的动作,且遇到异常会关闭队列
       
       改进代码如下:

import tensorflow as tf

gen_random_normal = tf.random_normal(shape=())
queue = tf.RandomShuffleQueue(capacity=100, dtypes=[tf.float32], min_after_dequeue=1)
enqueue = queue.enqueue(gen_random_normal)

qr = tf.train.QueueRunner(queue, [enqueue] * 4)
coord = tf.train.Coordinator()

with tf.Session() as sess:

    enqueue_threads = qr.create_threads(sess, coord=coord, start=True)

    for i in range(5):
        my_dequeue = sess.run(queue.dequeue())
        print(my_dequeue)

    coord.request_stop()
    coord.join(enqueue_threads)

       在这个例子中,我们使用 tf.RandomShuffleQuene 而不是FIFO队列。RandomShuffleQuene 就是一个带有出队操作的队列,它一随机的顺序弹出项。当使用随机梯度下降优化来训练深度神经网络使,这是非常有用的,当然我们需要对数据进行随机排序。
       
       min_after_dequeue参数指定在调用出队操作之后将保留在队列中的最小数目——更大的数字需要更好的混合(随机采样),也需要更多的内存。
       
       输出结果:

	-0.5699217
	0.23737773
	0.018851127
	0.8081313
	1.6825532
	

       

本文示例参考《TensorFlow学习指南——深度学习系统构建详解》第八章第三节

       
       欢迎各位大佬交流讨论!

猜你喜欢

转载自blog.csdn.net/weixin_42721167/article/details/112795491