TensorFlow提供了tf.train.batch和tf.train.shuffle_batch函数来讲单个样例组织成batch的形式输出:
import tensorflow as tf example, label = features['i'], features['j'] #一个batch中的样例的个数 batch_size = 3 #组合样例的队列最多可以储存的样例个数。队列太大会很占内存,太小,出队操作可能会因为没有数据而被阻碍(block),导致训练效率降低。 #一般来说这个队列大小回合每个batch的大小相关 capacity = 1000 + 3 * batch_size #使用tf.train.batch函数组合样例0。[example, label]参数给出了需要组合的参数,一般example和label分别代表训练样本和这个样本对应的正确标签。 #batch_size参数给出了每个batch中样例的个数。capacity给出了队列的最大容量。当队列长度等于容量,TensorFlow将暂停入队,而是等待出队;当小于容量时, #TensorFlow将自动重启入队操作 example_batch, label_batch = tf.train.batch([example, label], batch_size=batch_size, capacity=capacity) with tf.Session() as sess: tf.global_variables_initializer().run() tf.local_variables_initializer().run() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) for i in range(3): cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch]) #获取打印组合之后的样例。真是问题中一般作为神经网络输入 print(cur_example_batch, cur_label_batch) coord.request_stop() coord.join(threads)
其结果如下:
[0 0 1] [1 1 1] [0 1 0] [0 0 0] [0 0 1] [1 0 1]其结果是3个一组的batch
而 tf.train.shuffle_batch函数使用示例代码如下:
example, label = features['i'], features['j'] example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size = batch_size, capacity=capacity, min_after_dequeue = 30) with tf.Session() as sess: tf.global_variables_initializer().run() tf.local_variables_initializer().run() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) for i in range(3): cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch]) #获取打印组合之后的样例。真是问题中一般作为神经网络输入 print(cur_example_batch, cur_label_batch) coord.request_stop() coord.join(threads)其结果如下:
[1 1 0] [0 1 0] [0 1 1] [1 0 1] [0 0 1] [1 0 0]可以发现输出的样例顺序已经打乱了。
tf.train.batch和tf.train.shuffle_batch函数除了可以将单个训练数据整理成输入batch,也提供并行化处理输入数据的方法。通过设置tf.train.shuffle_batch函数的num_threads参数,可以指定多个线程同时执行入队操作,其入队操作就是数据读取以及预处理的过程。如果需要多个线程处理不同文件中的样例,可以使用tf.train.shuffle_batch_join函数。