TensorFlow组合训练数据

TensorFlow提供了tf.train.batch和tf.train.shuffle_batch函数来讲单个样例组织成batch的形式输出:

import tensorflow as tf
example, label = features['i'], features['j']
#一个batch中的样例的个数
batch_size = 3
#组合样例的队列最多可以储存的样例个数。队列太大会很占内存,太小,出队操作可能会因为没有数据而被阻碍(block),导致训练效率降低。
#一般来说这个队列大小回合每个batch的大小相关
capacity = 1000 + 3 * batch_size
#使用tf.train.batch函数组合样例0。[example, label]参数给出了需要组合的参数,一般example和label分别代表训练样本和这个样本对应的正确标签。
#batch_size参数给出了每个batch中样例的个数。capacity给出了队列的最大容量。当队列长度等于容量,TensorFlow将暂停入队,而是等待出队;当小于容量时,
#TensorFlow将自动重启入队操作
example_batch, label_batch = tf.train.batch([example, label], batch_size=batch_size, capacity=capacity)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(3):
        cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])
        #获取打印组合之后的样例。真是问题中一般作为神经网络输入
        print(cur_example_batch, cur_label_batch)
    coord.request_stop()
    coord.join(threads)

其结果如下:

[0 0 1] [1 1 1]
[0 1 0] [0 0 0]
[0 0 1] [1 0 1]
其结果是3个一组的batch

而 tf.train.shuffle_batch函数使用示例代码如下:

example, label = features['i'], features['j']
example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size = batch_size, capacity=capacity, min_after_dequeue = 30)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(3):
        cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])
        #获取打印组合之后的样例。真是问题中一般作为神经网络输入
        print(cur_example_batch, cur_label_batch)
    coord.request_stop()
    coord.join(threads)
其结果如下:
[1 1 0] [0 1 0]
[0 1 1] [1 0 1]
[0 0 1] [1 0 0]
可以发现输出的样例顺序已经打乱了。

tf.train.batch和tf.train.shuffle_batch函数除了可以将单个训练数据整理成输入batch,也提供并行化处理输入数据的方法。通过设置tf.train.shuffle_batch函数的num_threads参数,可以指定多个线程同时执行入队操作,其入队操作就是数据读取以及预处理的过程。如果需要多个线程处理不同文件中的样例,可以使用tf.train.shuffle_batch_join函数。

猜你喜欢

转载自blog.csdn.net/dz4543/article/details/79653075