ImageNet打造自己的图像识别

一、原理
      在自己的数据集上训练一个新的深度学习模型时,一般采取在预训练ImageNet上进行微调的方法。什么是微调?这里以VGG16为例进行讲解。
      VGG16网络结构:http://ethereon.github.io/netscope/#/preset/vgg-16
如下图:
在这里插入图片描述

在这里插入图片描述

      VGG16的结构为卷积+全连接层。卷积层分为5个部分共13层,即conv1~conv5。还有三层全连接层,即fc6、fc7、fc8。卷积层加上全连接层合起来一共为16层。如果要将VGG16的结构用于一个新的数据集,首先要去掉fc8这一层。原因是fc8层的输入是fc7层的特征,输出是1000类的概率,这1000类正好对应了ImageNet模型中的1000个类别。在自己的数据中,类别数一般不是1000类,因此fc8层的结构在此时是不适用的。必须将fc8层去掉,重新采用符合数据集类别数的全连接层,作为新的fc8.比如数据集为5类,那么新的fc8的输出也应当是5类。

      此外,在训练的时候,网络的参数的初始值并不是随机化生成的,而是采用VGG16在ImageNet上已经训练好的参数作为训练的初始值。这样做的原因在于,在ImageNet数据集上训练过的VGG16的参数已经包含了大量有用的卷积过滤器,与其从零开始初始化VGG16的所有参数,不如使用已经训练好的参数当作训练的起点。这样做不仅可以节约大量训练时间,而且有助于分类起性能的提高。

      载入VGG16的参数后,就可以开始训练了。此时需要指定训练层数的范围。一般来说,可以选择以下几种范围进行训练:

  • 只训练fc8.训练范围一定要包含fc8这一层。之前说过,fc8的结构被调整过,因此它的参数不能直接从ImageNet预训练模型中取得。可以只训练fc8,保持其他层的参数不动。这就相当于将VGG16当作一个特征提取器,用fc7层提取的特征做一个softmax模型分类。这样做的好处是训练速度块,但往往性能不会太好。
  • 训练所有参数。还可以对网络中的所有参数进行训练,这种方法的训练速度可能比较慢,但是能取得较高的性能,可以充分发挥深度模型的威力。
  • 训练部分参数。通常是固定浅层参数不变,训练深层参数。如固定conv1、conv2部分的参数不训练,只训练conv3、conv4、conv5、fc6、fc7、fc8的参数。

二、使用Tensorflow Slim微调模型
      slim是google公司公布的一个图像分类工具包,不仅定义了一些方便的接口,还提供了很多ImageNet数据集上常用的网络结构和预训练模型。包括VGG16\VGG19、Inception v1~v4、ResNet 50、ResNet101、MobileNet在内大多数常用模型的结构以及预训练模型,更多的模型会被持续添加进来。

  1. 下载Tensorflow Slim的源代码

      git clone https://github.com/tensorflow/models.git

      找到models/research/slim文件夹。

  1. 数据准备
    将jpg格式样本集合转化为tfrecord格式。
    首先做数据准备的工作,一是将数据集切分为训练集和验证集,二是转换为tfrecord格式。建立data_prepare目录,建立目录结构如下:
    在这里插入图片描述

      在data_prepare目录下运行脚本:

python  data_convert.py -t pic/ --train-shards 2 --validation-shards 2 --num-threads 2 --dataset-name satellite

      其中dataset-name为给数据集起的名字。

      运行上述命令后,pic目录生成如下5个文件
在这里插入图片描述
      tfrecord文件就是对应的训练集和验证集,另外还有label.txt,为类别映射关系。

  1. 定义新的dataset
    在slim/dataset中,定义所有可用的数据库,前面定义的新的satellite数据集,在这里也要定义对应的dataset。
    新建 satellite.py文件,把 flowers.py复制到其中。如下
_FILE_PATTERN='satellite_%s_*.tfrecord'//改成自己的图片的命名

SPLITS_TO_SIZES={‘train’:4800,'validation':1200}//训练集和测试集的总数目

_NUM_CLASSES=6  //类别数目

修改图片的默认格式

keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

修改完satellite.py文件后,还要在dataset_factory.py注册satellite数据库

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
	'satellite':satellite,
}
  1. 准备训练文件夹
    slim下新建satellite目录,完成以下准备工作:
  • data目录,把之前生成好的5个文件复制进来
  • 新建一个空的train_dir目录,用来保存训练过程中的日志和模型。
  • 新建一个pretrained目录,在slim的GitHubi页面找到Inception-V3模型的下载地址http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz,下载并解压后,得到 inception_v3.ckpt文件,将该文件复制到pretrained目录下。

在这里插入图片描述

  1. 开始训练
python train_image_classifier.py --train_dir=satellite\train_dir --dataset_name=satellite --dataset_split_name=train --dataset_dir=satellite\data --model_name=inception_v3 --checkpoint_path=satellite\pretrained\inception_v3.ckpt --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits --max_number_of_steps=100000 --batch_size=32 --learning_rate=0.001 --learning_rate_decay_type=fixed --save_interval_secs=120 --save_summaries_secs=20 --log_every_n_steps=10 --optimizer=rmsprop --weight_decay=0.00004   

      trainable_scopes=InceptionV3/Logits, InceptionV3/AuxLogits。trainable_scopes规定了在模型中微调变量的范围。这里的设定表示只对 InceptionV3/Logits, InceptionV3/AuxLogits两个变量进行微调,其他变量都保持不动。
      InceptionV3/Logits, InceptionV3/AuxLogits是inception V3的末端层。只对最后一层分类层进行训练,比如原来是1000类,现在训练的只是2类。如果不设定trainable_scopes,就只会对模型中所有的参数进行训练。

  1. 验证模型
python eval_image_classifier.py --checkpoint_path=satellite/train_dir --eval_dir=satellite/eval_dir --dataset_name=satellite --dataset_split_name=validation --dataset_dir=satellite/data --model_name=inception_v3

最后显示:eval/Recall_5[0.979166687]eval/Accuracy[0.561666667]
其中Accuracy为分类准确率,Recall_5表示Top 5的准确率,即在输出的类别概率中,正确的类别只有落在前5就算对。

  1. Tensorboard可视化
tensorboard –logdir satellite/train_dir

猜你喜欢

转载自blog.csdn.net/mozf881/article/details/83267768