深度学习中基于tensorflow_slim进行复杂模型训练二之tensorflow_slim的使用

版权声明: https://blog.csdn.net/hh_2018/article/details/83511470

上篇博客主要介绍了tensorflow_slim的基本模块,本篇主要介绍一下如何使用该模块训练自己的模型。主要分为数据转化,数据读取,数据预处理,模型选择,训练参数设定,构建pb文件,固化pb文件中的参数几部分。

一、数据转化:

主要目的是将图片转化为TFrecords文件,该部分属于数据的预处理阶段,可以参考datasets中的download_and_conver_flower中的run函数实现。具体关于如何使用将会在后续介绍。

二、数据读取

该部分主要是在datasets中新建一个文件并将其命名为自己的名字,例如命名为emotion.py,然后将flowers.py中的内容复制到新建的文件中,并对以下部分进行修改:

1. _FILE_PATTERN = 'emotion_%s_*.tfrecord' 表示tfrecord文件名的格式

2. SPLITS_TO_SIZES = {'train': 18534, 'validation': 8331}表示用于训练和测试的数据个数

3. _NUM_CLASSES = 5,训练数据的类数,涉及到网络模型最后一层的输出。

最后需要在dataset_factory中增加自己新建的数据映射。

datasets_map = {
    'emotion': emotion,
}

三、数据增强

该过程主要是对读取的数据进行数据增强,可以有两种方式:1. 采用现有的增强模式(因为数据增强的大部分操作都是一样的),2. 构建自己的增强方式(可以使模型训练的时候传入的参数较统一)。

对于第二种方式依然需要构建新的文件夹,然后复制一个内容进行修改或者完全自己书写。本次采用的是复制cifarnet_preprocessing.py的内容进行修改得到的。具体修改的地方如下:将

distorted_image = tf.random_crop(image, [output_height, output_width, 3])    改为

distorted_image = tf.image.resize_images(image, [output_height, output_width], method=1)

主要是为了避免需要的图片比还未裁剪的小导致无法进行裁剪的错误

然后在preprocessing.py中增加新的映射:
  preprocessing_fn_map = {
      'emotion': emotion_preprocessing,
  }
  
四、模型选择

在nets中选择出自己需要使用的模型,并下载对应训练好的模型.ckpt文件,具体的下载地址可以参考README.md文件(以inception_v3模型为例)[inception_v3_2016_08_28.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) 

五、训练参数设定

准备工作完成后就是使用train_image_classifier.py对自己的数据进行训练,该部分涉及到较多的参数,具体设定如下:

1. tf.app.flags.DEFINE_string( 'dataset_name', 'emotion', 'The name of the dataset to load.')
    
表示数据的名字,在读取数据和数据增强的时候该值相当于是map中的key,根据该值找到对应的读取和增强的脚本。

tf.app.flags.DEFINE_string('dataset_split_name', 'train', 'The name of the train/test split.')

表示数据的作用,用来train还是validation,该值主要是产生emotion_train*.tfrecord的形式存于data_sources中,便于在后面的读取数据时使用

  if '*' in data_sources or '?' in data_sources or '[' in data_sources:  data_files = gfile.Glob(data_sources)

的方式对数据进行读取。
      

tf.app.flags.DEFINE_string( 'dataset_dir', "data_to_tfrecord", 'The directory where the dataset files are stored.')

表示tfrecord数据的存储路径,构建data结构读取数据时使用
      
tf.app.flags.DEFINE_integer( 'labels_offset',  0,  'An offset for the labels in the dataset. This flag is primarily used to  evaluate the VGG and ResNet architectures which do not use a background  class for the ImageNet dataset.') 表示标签的偏移量,即默认标签是从0开始,假如偏移2,那么标签就会从2开始,一般情况下选择默认的值即可。


tf.app.flags.DEFINE_string('model_name', 'inception_v3', 'The name of the architecture to train.')用于进行二次训练的模型名字,主要依赖于net_factory中的map是如何写的。

tf.app.flags.DEFINE_string( 'preprocessing_name', 'emotion', 'The name of the preprocessing to use. If left  as `None`, then the model_name flag is used.')  表示采用的预处理方式,主要依赖于preprocessing_factory中的map
                                     
tf.app.flags.DEFINE_integer('batch_size', 32, 'The number of samples in each batch.') 在将数据进行批处理时每批数据的多少
    
tf.app.flags.DEFINE_integer( 'train_image_size', 299, 'Train image size') 模型输入的图片大小

tf.app.flags.DEFINE_integer('max_number_of_steps', 50000, 'The maximum number of training steps.')   表示训练的步数
                            
tf.app.flags.DEFINE_integer( 'log_every_n_steps', 10,  'The frequency with which logs are print.') log的输出频率,即每运行多少步输出一个log

tf.app.flags.DEFINE_integer( 'save_summaries_secs', 100,  'The frequency with which summaries are saved, in seconds.')    
表示存储summaries的频率

tf.app.flags.DEFINE_integer('save_interval_secs', 600, 'The frequency with which the model is saved, in seconds.')  表示存储模型的频率

tf.app.flags.DEFINE_float( 'weight_decay', 0.00004, 'The weight decay on the model weights.') 表示为了避免过拟合采用正则化的系数


tf.app.flags.DEFINE_string( 'train_dir', 'train_result', 'Directory where checkpoints and event logs are written to.')表示训练参数存储的地方    
    

tf.app.flags.DEFINE_string(  'checkpoint_path', "pre_trained_check/inception_v3_2016_08_28/inception_v3.ckpt", 'The path to a checkpoint from which to fine-tune.')  表示提前处理好的模型参数存储的地方
    

tf.app.flags.DEFINE_string('checkpoint_exclude_scopes', "InceptionV3/Logits,InceptionV3/AuxLogits",'Comma-separated list of scopes of variables to exclude when restoring  from a checkpoint.')模型中不用恢复的节点,一般均为模型的输出层,因为输出层需要结合自己实际的类进行训练确定数据的输出大小,当该值为空时,则表示所有的变量均恢复。


tf.app.flags.DEFINE_string('trainable_scopes', None, 'Comma-separated list of scopes to filter the set of variables to train. By default, None would train all the variables.')  表示再次训练的节点,None表示所有的都参与训练。
    
tf.app.flags.DEFINE_string( 'learning_rate_decay_type',  'exponential',  'Specifies how the learning rate is decayed. One of "fixed", "exponential", or "polynomial"') 表示学习率衰减的方式。


对于该模块在使用中涉及到的其他参数均使用默认的即可。

在使用脚本时有时候会报出部分操作无法在GPU上运行的错误,此时train的上面增加config = tf.ConfigProto(allow_soft_placement=True)表示当无法采用GPU计算时使用cpu进行。并将该参数传递给train。

五、构建pb文件

此时直接使用export_interence_graph.py可以将模型结构变成.pb的,涉及的参数如下:

tf.app.flags.DEFINE_string( 'model_name', 'inception_v3', 'The name of the architecture to save.') 表示要调用的模型结构

tf.app.flags.DEFINE_boolean( 'is_training', False, 'Whether to save out a training-focused version of the model.') 表示在模型中的参数是否用来进行训练

tf.app.flags.DEFINE_integer( 'image_size', 299, 'The image size to use, otherwise use the model default_image_size.')定义一个输入占位符的二三维大小

tf.app.flags.DEFINE_string('dataset_name', 'emotion',  'The name of the dataset to use with the model.') 主要根据传入的名字确定一个对应的数据集,确定其num_class的值用于构建模型结构
    

tf.app.flags.DEFINE_integer( 'labels_offset', 0, 'An offset for the labels in the dataset. This flag is primarily used to  evaluate the VGG and ResNet architectures which do not use a background  class for the ImageNet dataset.')偏移量,用来构建模型的时候会用到

tf.app.flags.DEFINE_string(  'output_file', 'train_pb/motion_inception_v3_graph.pb', 'Where to save the resulting file to.')输出的.pb文件名字和存储地方

tf.app.flags.DEFINE_integer( 'batch_size', None,'Batch size for the exported model. Defaulted to "None" so batch size can ')定义输入占位符的第一维度大小。


六、在模型结构中放入自己训练的结果并固化

该过程实现的原理是先读入一个结构图,然后在使用saver.restore()恢复图中对应
参数的值,最后再存储。

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.training import saver as saver_lib

# 定义一些参数
input_graph = 'train_pb\\test_1.pb'
output_graph = 'train_pb\\test_2.pb'
input_checkpoint = 'train_result\\model.ckpt-20'
output_node_names = 'InceptionV3/Predictions/Reshape_1'

detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(input_graph, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        var_list = {}
        reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
            try:
                tensor = sess.graph.get_tensor_by_name(key + ":0")
            except KeyError:
                continue
            var_list[key] = tensor
        saver = tf.train.Saver(var_list=var_list)
        saver.restore(sess, input_checkpoint)

        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,
                                                                   output_node_names.split(","))
        with tf.gfile.FastGFile(output_graph, mode='wb') as f:
            f.write(constant_graph.SerializeToString())


至此,就完成了模型训练和固化,然后可以根据具体需要自行进行使用。

猜你喜欢

转载自blog.csdn.net/hh_2018/article/details/83511470