TensorFlow学习笔记(9) TFRecord 输入数据格式

TF提供了一种统一的格式来存储数据,这个格式就是TFRecord。TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。tf.train.Example中包括一个从属性名称到取值的字典。其中属性名称为一个字符串,取值为字符串、实数列表或者整数列表。下面为一个具体的样例程序将MNIST输入数据转化为TFRecord格式。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

#生成整数型的属性
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#生成字符串型的属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

mnist = input_data.read_data_sets('MNIST_data', dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
pixels = images.shape[1]
num_examples = mnist.train.num_examples

filename = 'mnist.tfrecords'
#创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
    #将图像矩阵转化为一个字符串
    images_raw = images[index].tobytes()
    example = tf.train.Example(features=tf.train.Features(feature={
        'pixels': _int64_feature(pixels),
        'label': _int64_feature(np.argmax(labels[index])),
        'image_raw': _bytes_feature(images_raw)}))
    writer.write(example.SerializeToString())
writer.close()

以下程序给出了如何读取TFRecord文件中的数据。

import tensorflow as tf

#创建一个reader来读取TFRecord文件中的样例
reader = tf.TFRecordReader()
#创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer('mnist.tfrecords')
#从文件中读出一个样例
_, serialized_example = reader.read(filename_queue)
#解析读入的一个样例
features = tf.parse_single_example(serialized_example, features={
    #tf.FixedLenFeature解析结果为tensor
    'image_raw': tf.FixedLenFeature([], tf.string),
    'pixels': tf.FixedLenFeature([], tf.int64),
    'label': tf.FixedLenFeature([],tf.int64)
})

#tf.decode_raw可以将字符串解析成图像对应的像素数组
images = c(features['image_raw'], tf.unit8)
labels = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)

sess = tf.Session()
#多线程。。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

for i in range(10):
    image, label, pixel = sess.run([images, labels, pixels])

多线程输入数据处理框架

为了避免图像预处理成为神经网络模型训练效率的瓶颈,TF提供了一套多线程处理输入数据的框架。

经典的输入数据处理流程为:指定原始数据的文件列表>创建文件列表队列>从文件中读取数据>数据预处理>整理成Batch作为神经网络的输入。队列不仅是一种数据结构,更提供了多线程机制,队列也是TF中多线程输入数据处理框架的基础。比如多个线程可以同时向一个队列中写元素,或者同时读取一个队列中的元素。

队列与多线程

队列和变量都是计算图上有状态的节点。通过赋值修改变量的取值;通过Enqueue、EnqueueMany、Dequeue函数来修改队列状态。以下程序展示了如何使用这些函数来操作一个队列。

import tensorflow as tf

#创建一个先进先出队列,指定队列中最多可以保存两个元素,并指定类型为整数。
q = tf.FIFOQueue(2, 'int32')
#使用enqueue_many函数来初始化队列元素,在使用队列之前要明确调用这个初始化过程
init = q.enqueue_many(([0, 10],))
#使用Dequeue函数将队列中第一个元素出队列,这个元素将被存在变量x
x = q.dequeue()
y = x + 1
#重新加入队列
q_inc = q.enqueue([y])

with tf.Session() as sess:
    #运行初始化队列的操作
    sess.run(init)
    for _ in range(5):
        v, _ = sess.run([x, q_inc])
        print(v)

在TF中提供了FIFOQueue和RandomShuffleQueue两种队列。在上面的程序中展示了FIFOQueue队列。而RandomShuffleQueue会将队列中的元素打乱,每次enqueue_many操作得到的是从当前队列中随机选择的一个元素。

TF提供了tf.Coordinator和tf.QueueRunner两个类来完成多线程协同的功能。

tf.Coordinator主要用于协同多个线程一起停止,提供了should_stop,request_stop和join三个函数。启动的进程只有当should_stop函数为True时则退出。每一个启动的进程通过调用request_stop函数来通知其他线程退出。

import tensorflow as tf
import numpy as np
import threading
import time

#在线程中运行的程序,这个程序每隔1s判断是否需要停止打印自己的id
def MyLoop(coord, worker_id):
    while not coord.should_stop():
        if np.random.rand() < 0.1:
            print('Stoping from id:%d' % worker_id)
            coord.request_stop()
        else:
            print('Working on id:%d'% worker_id)
        time.sleep(1)

#声明一个tf.train.Coordinator()类
coord = tf.train.Coordinator()
#创建五个线程
threads = [threading.Thread(target=MyLoop, args=(coord,i,))for i in range(5)]
#启动所有的线程
for t in threads: t.start()
#等待所有线程退出
coord.join(threads)

tf.QueueRunner主要用于启动多个线程来操作同一个队列,这些线程可以通过tf.Coordinator来进行统一管理。比如,

import tensorflow as tf

#声明队列,100个元素,类型实数
queue = tf.FIFOQueue(100, 'float')
#定义队列入队操作
enqueue_op = queue.enqueue([tf.random_normal([1])])
#使用tf.train.QueueRunner创建多个线程的入队操作
#第一个参数为被操作的队列,第二个表示需要启动五个线程,每个线程都是enqueue_op操作
qr = tf.train.QueueRunner(queue, [enqueue_op]*5)

#将定义过的qr加入tf计算图指定的集合
#若没有指定集合,则加入默认的集合tf.GraphKeys,QUEUE_RUNNERS
tf.train.add_queue_runner(qr)
#定义出队列操作
out_tensor = queue.dequeue()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for _ in range(3): print(sess.run(out_tensor)[0])

    coord.request_stop()
    coord.join(threads)




猜你喜欢

转载自blog.csdn.net/qyf394613530/article/details/79321872