tensorflow--开启线程操作

线程异步

在训练过程中,可以通过异步读取数据,加快程序执行速度

1. 操作过程

1.1. 定义要做的事情

例如读取文件,此处示例就变量自加好了:

var = tf.Variable(0.0)
data = tf.assign_add(var,tf.constant(1.0))

1.2. 定义队列

Q = tf.FIFOQueue(1000, tf.float32)

1.3. 定义入队操作

en_q = Q.enqueue(data)

1.4. 定义入队管理器

qr = tf.train.QueueRunner(Q, enqueue_ops=[en_q] * 5)

最后*5是指设置5个线程干en_q这件事

1.5. 定义线程管理器

coord = tf.train.Coordinator()

1.6. 开启子线程

注意,接下来操作都在session()中执行

threads = qr.create_threads(sess, coord=coord, start=True)

1.7. 读取数据

for i in range(1000):
    print(sess.run(Q.dequeue()))

1.8. 关闭与回收

coord.request_stop()
coord.join(threads)

2. 完整代码

import tensorflow as tf

# 定义操作0
var = tf.Variable(0.)
data = tf.assign_add(var, tf.constant(1.0))
# 定义队列
Q = tf.FIFOQueue(1000, tf.float32)

# 定义入队操作
en_q = Q.enqueue(data)
# 定义入队管理器
qr = tf.train.QueueRunner(Q, enqueue_ops=[en_q] * 5)
# 定义线程管理器
coord = tf.train.Coordinator()

with tf.Session() as sess:
    # 初始化变量
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # 开启子线程
    threads = qr.create_threads(sess, coord=coord, start=True)
    # 读取数据
    for i in range(1000):
        print(sess.run(Q.dequeue()))
    # 回收
    coord.request_stop()
    coord.join(threads)

猜你喜欢

转载自blog.csdn.net/weixin_43003274/article/details/83418589
今日推荐