《实战Google深度学习框架》学习笔记之输入数据处理框架

  虽然使用上篇文章中介绍的数据预处理方法可以减少无关因素对图像识别模型效果的影响,但这些复杂的数据预处理过程也会减慢整个训练过程,因此,TensorFlow提供了一套多线程处理输入数据的框架,下图为一个经典的输入数据处理流程:
在这里插入图片描述
完整流程的代码及注释如下:

import tensorflow as tf

#创建文件列表,并通过文件列表创建输入文件队列。在调用输入数据处理流程前,需要统一所有原始数据的格式并将它们存储到
#TFRecord文件中。下面给出的文件列表应包含所有提供训练数据的TFRecord文件。shuffle参数可以选择是否将文件顺序打乱。
files = tf.train.match_filenames_once("/path/to/file_pattern-*")
filename_queue = tf.train.string_input_producer(files, shuffle=False)

#创建一个reader来读取TFRecord文件中的样例。这里假设image中存储的是图像的原始数据,label为样例对应的标签。
#height、width、和channels给出了图片的维度。
reader = tf.TFRecordReader()
#从文件中读取一个样例。也可以使用read_up_to一次性读取多个样例。
_,serialized_example = reader.read(filename_queue)
#解析读入的一个样例,如果需要解析多个样例,可以用parse_example函数。
features = tf.parse_single_example(serialized_example,features={
        #TensorFlow提供两种不同的属性解析方法,一种方法是tf.FixedLenFeature,这种方法的
        #解析结果为一个Tensor。另一种方法是tf.VarLenFeature,这种方法得到的解析结果为
        #SparseTensor,用于处理稀疏数据。。
        'image': tf.FixedLenFeature([],tf.string),
        'label': tf.FixedLenFeature([],tf.int64),
        'height': tf.FixedLenFeature([],tf.int64),
        'width': tf.FixedLenFeature([],tf.int64),
        'channels': tf.FixedLenFeature([],tf.int64),
        })

image, label = features['image'], features['label']
height, width = features['height'], features['width']
channels = features['channels']

#tf.decode_raw可以将字符串解析成图像对应的像素矩阵,并根据图像尺寸还原图像。
decoded_image = tf.decode_raw(image, tf.unit8)
decoded_image.set_shape([height,width,channels])

#定义神将网络输入层图片大小。
image_size = 299
#preprocess_for_train为上篇文章中介绍的图像预处理程序。
distorted_image = preprocess_for_train(decoded_image, image_size, image_size, None)

#将处理后的图像和标签数据通过tf.train.shuffle_batch整理成神将网络训练时需要的batch。
#min_after_dequeue参数是tf.train.shuffle_batch函数特有的,限制了出队时队列中元素的最少个数,保证随机打乱顺序的作用。
min_after_dequeue = 10000
#一个batch中样例的个数。
batch_size = 100

#组合样例的的队列中最多可以存储的样例个数。这个队列如果太大,那么需要占用很多内存资源;如果太小,那么出队操作可能会因为
#没有数据而被阻碍,从而导致训练效率降低。一般来说这个队列的大小会和每个batch的大小相关,下面一行代码给出一种设置队列大小的方式。
capacity = min_after_dequeue + 3*batch_size

#tf.train.shuffle_batch来组合样例,[distorted_image,label]参数给出了需要组合的元素,一般分别代表训练样本和样本
#对应的正确标签。batch_size给出了每个batch中的样例个数,capacity给出了队列的最大容量。当队列长度等于容量时,TensorFlow
#将暂停入队操作,而只是等待元素出队,当元素个数等于容量时,TensorFlow将自动重新启动入队操作。
image_batch, label_batch = tf.train.shuffle_batch(
        [distorted_image,label], batch_size=batch_size,
        capacity=capacity, min_after_dequeue=min_after_dequeue)

#定义神将网络的结构以及优化过程。image_batch可以作为输入提供给神将网络的输入层,label_batch提供了输入batch中样例的正确答案。
#inference为前文定义的神将网络的前向传播过程。
logit = inference(image_batch)
#loss为损失函数,计算当前时刻输出的损失。
loss = calc_loss(logit, label_batch)
#使用tf.train.GradientDescentOptimizer来优化损失函数。
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)


#声明会话并运行神将网络的优化过程。
with tf.Session() as sess:
    #神将网络训练准备工作,包括变量初始化,线程启动。
    tf.initialize_all_variables().run()
    #tf.train.Coordinator()来协同启动的线程。
    coord = tf.train.Coordinator()
    
    #使用tf.train.QueueRunner时,需要明确调用tf.train.start_queue_runners来启动所有线程。否则因为没有线程运行入队操作
    #当调用出队操作时,程序会一直等待入队操作被运行。tf.train.start_queue_runners函数会默认启动tf.GraphKeys.QUEEN_RUNNERS
    #集合中所有的QueenRnner,所以一般来说tf.train.add_quene_runner函数和tf.train.start_queue_runners函数会指定同一个集合。
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    
    #获取队列中的取值,神将网络训练过程。
    for i in range(TRAINING_ROUDNS):
        sess.run(train_step)
    
    #使用tf.train.Coordinator来停止所有线程。
    coord.request_stop()
    coord.join(threads)

猜你喜欢

转载自blog.csdn.net/qq_40739970/article/details/86775830
今日推荐