TensorFlow细节-P194-组合训练数据

import tensorflow as tf


# Input pipeline: a filename queue over the TFRecord shards, a reader that
# pulls one serialized record at a time, and a parser for its two int64 fields.
files = tf.train.match_filenames_once("data.tfrecords-*")
# Filename (input) queue: shards in a fixed order, cycled num_epochs=3 times.
filename = tf.train.string_input_producer(files, num_epochs=3, shuffle=False)
reader = tf.TFRecordReader()
# read() yields a (key, value) pair; only the serialized record value is used.
serialized_example = reader.read(filename)[1]

# Decode the record: 'i' and 'j' are both scalar int64 features.
features = tf.parse_single_example(
    serialized_example,
    features={key: tf.FixedLenFeature([], tf.int64) for key in ('i', 'j')},
)

example = features['i']
label = features['j']

batch_size = 3
# Upper bound on how many elements the batching queue may hold.
capacity = 3 * batch_size + 1000
# Alternatives with the same arguments: tf.train.batch(...) keeps input order,
# and tf.train.shuffle_batch(...) without num_threads uses a single enqueue
# thread. The call used here is shuffle_batch with two enqueue threads;
# min_after_dequeue=30 keeps at least 30 elements queued after each dequeue
# so there is a pool to shuffle from.
example_batch, label_batch = tf.train.shuffle_batch(
    [example, label],
    batch_size=batch_size,
    capacity=capacity,
    min_after_dequeue=30,
    num_threads=2,
)
# capacity: TensorFlow pauses enqueueing while the queue is full and restarts
# it once the queue length drops below capacity.
# tf.train.batch / tf.train.shuffle_batch each create a new queue of their own,
# separate from the filename queue above.
# Run the pipeline: initialize local state, start the queue-runner threads,
# fetch up to 4 batches, and always shut the threads down cleanly.
with tf.Session() as sess:
    # match_filenames_once() and string_input_producer(num_epochs=...) both
    # create local variables, so local (not global) initialization is required.
    tf.local_variables_initializer().run()
    print(sess.run(files))
    coord = tf.train.Coordinator()
    # Starts the threads feeding the filename queue and the shuffle_batch queue.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(4):
            if coord.should_stop():
                # A queue-runner thread requested a stop; join() below
                # re-raises any error it reported.
                break
            cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])
            print(cur_example_batch, cur_label_batch)
    except tf.errors.OutOfRangeError:
        # The input queues close after num_epochs passes over the files;
        # fewer than 4 full batches were available.
        print("Input exhausted before all batches were read.")
    finally:
        # Stop and join the queue-runner threads even if sess.run raised,
        # so they are not left running against a closing session.
        coord.request_stop()
        coord.join(threads)

猜你喜欢

转载自www.cnblogs.com/liuboblog/p/11651021.html