tensorflow函数笔记

1.tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=20, allow_smaller_final_batch=False)  # 获取一个batch的数据

import numpy as np
import tensorflow as tf


def next_batch():
    datasets = np.asarray(range(0, 20))
    input_queue = tf.train.slice_input_producer([datasets], shuffle=False, num_epochs=1)
    data_batchs = tf.train.batch(input_queue, batch_size=5, num_threads=1,
                                 capacity=20, allow_smaller_final_batch=False) # capacity读取20个数据
    return data_batchs


if __name__ == "__main__":
    data_batchs = next_batch()
    sess = tf.Session()
    sess.run(tf.initialize_local_variables())
    coord = tf.train.Coordinator() # 构造管道
    threads = tf.train.start_queue_runners(sess, coord) # 构造线程
    try:
        while not coord.should_stop():
            data = sess.run([data_batchs])
            print(data)
    except tf.errors.OutOfRangeError:
        print("complete")
    finally:
        coord.request_stop()
    coord.join(threads)  # 将线程继续添加
    sess.close()

猜你喜欢

转载自www.cnblogs.com/my-love-is-python/p/10918171.html