深度学习中基于tensorflow_slim进行复杂模型训练一之tensorflow_slim基本介绍

版权声明: https://blog.csdn.net/hh_2018/article/details/83474893

最近在进行微表情识别,但是目前没有查到比较有效的模型方式,考虑使用inception_v3的模型进行开发,但是该模的构造过程比较复杂,训练更是麻烦,因此考虑基于tensorflow_slim的模块进行二次训练,首先介绍一下关于tensorflow_slim的基本模块。

tensorflow_slim的模块主要包括以下几个部分deployment ,nets,dataset, preprocessing, scripts。其中scripts中主要介绍了如何使用各模型,相当于tensorflow_slim的使用字典。下面分别介绍剩下几个文件夹的作用。

1. dataset:

该文件夹主要存储了数据的读取方式,定义了数据读取的文件类型是tfrecord,文件名的格式是‘flower_%s_*.tfrecord’,文件的train部分数据的多少,文件的validation部分数据的大小(train:3200,validation:350),以及读取tfrecord格式数据的操作方式,下面表示tfrecord格式读取的结构。

keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
    'image/class/label': tf.FixedLenFeature(
        [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
}

接着定义了一个decord用来解码TFrecords形式的数据,并且将decord, redaer(tf.TFRecordReader),num, label一起放入到dataset中供后续使用。

dataset_factory:用来统一管理dataset中的数据,根据传入的数据名调用对应的dataset中的脚本。

2. preprocessing: 对数据进行预处理,该处理主要包括一些旋转,裁剪,随机增减光强等操作,不一样的数据处理不一样的数据集,最后通过preprocessing_factory控制调用的数据预处理类型。

 在数据处理时要先给出数据的大小,因为使用tfrecord对数据进行读取时没有读取数据的大小,因此需要对读取出的数据采用rand_crop或者tf.image.resize_image的方式指定读取出来的数据大小,(个人建议使用tf.image.resize_image方式,因为采用rand_crop的方式有可能会出现裁剪后的照片比需要的小)

3. nets : 该文件夹里面包含了比较经典的网络结构,有inception, Alexnet, cifarnet,  mobilenet, vgg 等一些提前训练好的结构。其中每一个模型的结构以一个脚本的形式存在,在每个脚本中定义了默认的输入大小,并且返回了模型最终的输出和各个节点的名字对应的值。

同样在该文件夹下也存在一个net_factory, 该脚本的作用是根据传入的模型名字找到对应的模型脚本,通过传入的数据结构(data目录下的脚本生成的)读取对应的num_label确定最后输出的大小。最终返回一个调用已写好的网络结构的接口。

4. deployment 中主要是一些基于单机单GPU和单机多GPU的一些配置参数,由于本人目前使用的是单机单GPU所以没有进行过深入的了解。

5 .  train_image_classifier。 该脚本的主要作用是为了对整个模型进行训练,其中涉及的参数较多,在本篇博客中不一一介绍其作用,当介绍如何使用时会详细介绍各参数的作用。在该脚本中主要有以下几个模块:

一、各种预处理

数据读取:

dataset = dataset_factory.get_dataset(
    FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

网络读取

network_fn = nets_factory.get_network_fn(
    FLAGS.model_name,
    num_classes=(dataset.num_classes - FLAGS.labels_offset),
    weight_decay=FLAGS.weight_decay,
    is_training=True)

读取数据的处理:

with tf.device(deploy_config.inputs_device()):
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=FLAGS.num_readers,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset
    train_image_size = FLAGS.train_image_size or network_fn.default_image_size

在DatasetDateProvider内部调用了parallel_read方法,该方法主要通过string_input_producer,RandomShuffleQueue,FIFOQueue的方法采用队列的方式对tfrecorf类型的数据进行读取。

数据的预处理
preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    preprocessing_name,
    is_training=True)  
image = image_preprocessing_fn(image, train_image_size, train_image_size)

最后对处理好的数据调用batch的方式分批处理,并通过one_hot_encoding的形式对标签进行编码,并调用slim.prefetch_queue.prefetch_queue的方法将数据处理成批队列。

二、利用网络和数据构建交叉熵(clone_fn): 该函数主要是通过处理好的数据调用预训练的网络得出最后的结果并构交叉熵。

三、 通过create_clones函数构建多个clones . clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue]),在每个GPU上都进行训练。

四、 定义优化器

learning_rate根据参数自行调整。

五、 计算中的损失和梯度:

total_loss, clones_gradients = model_deploy.optimize_clones(),该方法中回传入需要计算的参数列表,同时
在该函数中调用了total_loss = tf.add_n(clones_losses, name='total_loss')表明其计算的是总损失函数。

六、 进行梯度更新:grad_updates = optimizer.apply_gradients(clones_gradients, global_step=global_step), 该方法对梯度进行了更新。compute_graddients()和apply_gradients()的方法一起使用的效果相当于minimize()

七、 构建train_op并进行train:  通过train_tensor = tf.identity(total_loss, name='train_op')构建了需要的训练张量,通过  slim.learning.train进行了训练。

八、 在train方法中通过以下几个关键性的语句分别完成对数据的训练,存储和队列的开启

total_loss, should_stop = train_step_fn(sess, train_op, global_step, train_step_kwargs)
sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
threads = sv.start_queue_runners(sess)

九、 需要计算的参数列表:首先通过variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)得到模型中所有的参数,再去掉传入的不需要计算的参数,得到需要计算的参数。

十、 模型恢复部分:主要是 _get_init_fn()函数,调用ckpt文件然后通过saver.restore完成。

十一、模型转化,在slim文件下面存在一个export_interference_graph,该脚本的作用是将网络结构存储为.pb的形式,然后在通过free_graph的方法结合存储的模型参数将其保存为可以直接使用的.pb文件

下图为自己理解的流程图,供参考。

猜你喜欢

转载自blog.csdn.net/hh_2018/article/details/83474893