TF提供了一种统一的格式来存储数据,这个格式就是TFRecord。TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。tf.train.Example中包括一个从属性名称到取值的字典。其中属性名称为一个字符串,取值为字符串、实数列表或者整数列表。下面为一个具体的样例程序将MNIST输入数据转化为TFRecord格式。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
#生成整数型的属性
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#生成字符串型的属性
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
mnist = input_data.read_data_sets('MNIST_data', dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
pixels = images.shape[1]
num_examples = mnist.train.num_examples
filename = 'mnist.tfrecords'
#创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
#将图像矩阵转化为一个字符串
images_raw = images[index].tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
'pixels': _int64_feature(pixels),
'label': _int64_feature(np.argmax(labels[index])),
'image_raw': _bytes_feature(images_raw)}))
writer.write(example.SerializeToString())
writer.close()
以下程序给出了如何读取TFRecord文件中的数据。
import tensorflow as tf
#创建一个reader来读取TFRecord文件中的样例
reader = tf.TFRecordReader()
#创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer('mnist.tfrecords')
#从文件中读出一个样例
_, serialized_example = reader.read(filename_queue)
#解析读入的一个样例
features = tf.parse_single_example(serialized_example, features={
#tf.FixedLenFeature解析结果为tensor
'image_raw': tf.FixedLenFeature([], tf.string),
'pixels': tf.FixedLenFeature([], tf.int64),
'label': tf.FixedLenFeature([],tf.int64)
})
#tf.decode_raw可以将字符串解析成图像对应的像素数组
images = c(features['image_raw'], tf.unit8)
labels = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)
sess = tf.Session()
#多线程。。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(10):
image, label, pixel = sess.run([images, labels, pixels])
多线程输入数据处理框架
为了避免图像预处理成为神经网络模型训练效率的瓶颈,TF提供了一套多线程处理输入数据的框架。
经典的输入数据处理流程为:指定原始数据的文件列表>创建文件列表队列>从文件中读取数据>数据预处理>整理成Batch作为神经网络的输入。队列不仅是一种数据结构,更提供了多线程机制,队列也是TF中多线程输入数据处理框架的基础。比如多个线程可以同时向一个队列中写元素,或者同时读取一个队列中的元素。
队列与多线程
队列和变量都是计算图上有状态的节点。通过赋值修改变量的取值;通过Enqueue、EnqueueMany、Dequeue函数来修改队列状态。以下程序展示了如何使用这些函数来操作一个队列。
import tensorflow as tf
#创建一个先进先出队列,指定队列中最多可以保存两个元素,并指定类型为整数。
q = tf.FIFOQueue(2, 'int32')
#使用enqueue_many函数来初始化队列元素,在使用队列之前要明确调用这个初始化过程
init = q.enqueue_many(([0, 10],))
#使用Dequeue函数将队列中第一个元素出队列,这个元素将被存在变量x
x = q.dequeue()
y = x + 1
#重新加入队列
q_inc = q.enqueue([y])
with tf.Session() as sess:
#运行初始化队列的操作
sess.run(init)
for _ in range(5):
v, _ = sess.run([x, q_inc])
print(v)
在TF中提供了FIFOQueue和RandomShuffleQueue两种队列。在上面的程序中展示了FIFOQueue队列。而RandomShuffleQueue会将队列中的元素打乱,每次enqueue_many操作得到的是从当前队列中随机选择的一个元素。
TF提供了tf.Coordinator和tf.QueueRunner两个类来完成多线程协同的功能。
tf.Coordinator主要用于协同多个线程一起停止,提供了should_stop,request_stop和join三个函数。启动的进程只有当should_stop函数为True时则退出。每一个启动的进程通过调用request_stop函数来通知其他线程退出。
import tensorflow as tf
import numpy as np
import threading
import time
#在线程中运行的程序,这个程序每隔1s判断是否需要停止打印自己的id
def MyLoop(coord, worker_id):
while not coord.should_stop():
if np.random.rand() < 0.1:
print('Stoping from id:%d' % worker_id)
coord.request_stop()
else:
print('Working on id:%d'% worker_id)
time.sleep(1)
#声明一个tf.train.Coordinator()类
coord = tf.train.Coordinator()
#创建五个线程
threads = [threading.Thread(target=MyLoop, args=(coord,i,))for i in range(5)]
#启动所有的线程
for t in threads: t.start()
#等待所有线程退出
coord.join(threads)
tf.QueueRunner主要用于启动多个线程来操作同一个队列,这些线程可以通过tf.Coordinator来进行统一管理。比如,
import tensorflow as tf
#声明队列,100个元素,类型实数
queue = tf.FIFOQueue(100, 'float')
#定义队列入队操作
enqueue_op = queue.enqueue([tf.random_normal([1])])
#使用tf.train.QueueRunner创建多个线程的入队操作
#第一个参数为被操作的队列,第二个表示需要启动五个线程,每个线程都是enqueue_op操作
qr = tf.train.QueueRunner(queue, [enqueue_op]*5)
#将定义过的qr加入tf计算图指定的集合
#若没有指定集合,则加入默认的集合tf.GraphKeys,QUEUE_RUNNERS
tf.train.add_queue_runner(qr)
#定义出队列操作
out_tensor = queue.dequeue()
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for _ in range(3): print(sess.run(out_tensor)[0])
coord.request_stop()
coord.join(threads)