用tensorflow中slim下的分类网络训练自己的数据集以及fine-tuning(可以直接实战使用)

目录

前期准备

训练flower数据集(包括fine-tuning)

训练自己的数据集(包括fine-tuning)

前期准备

前期了解

tensorflow models

在tensorflow models中有官方维护和非官方维护的models,official models就是官方维护的models,里面使用的接口都是一些官方的接口,比如tf.layers.conv2d之类。而research models是tensorflow的研究人员自己实现的一些流行网络,不受官方支持,里面会用到一些slim之类的非官方接口。但是因为research models实现的网络非常多,而且提供了完整的训练和评价方案,所以我们现在基于research models中的实现来部署网络。

环境配置

首先要保证tf.contrib.slim在你的tensorflow环境中是存在的,运行下面的脚本保证没有错误发生。

python -c "import tensorflow.contrib.slim as slim; eval = slim.evaluation.evaluate_once"

base代码准备

TF的库里面没有TF-slim的内容,所以我们需要将代码clone到本地

cd $HOME/workspace
git clone https://github.com/tensorflow/models/

运行以下脚本确定是否可用

cd $HOME/workspace/models/research/slim
python -c "from nets import cifarnet; mynet = cifarnet.cifarnet"

其实我们只需要使用research中的slim的代码,所以我是直接拷贝了slim的代码到本地,基于slim代码进行修改。

训练flower数据集

下载数据并创建tfrecord

官网提供了下载并且转换数据集的方法,运行如下脚本即可,脚本会直接下载flower数据集并且存储为TFRecord的格式。

$ python download_and_convert_data.py \
    --dataset_name=flowers \
    --dataset_dir=./tmp/data/flowers

为何官网要使用TFRecord呢?因为TFRecord和tensorflow内部有一个加速机制。实际读取tfrecord数据时,先以相应的tfrecord文件为参数,创建一个输入队列,这个队列有一定的容量,在一部分数据出队列时,tfrecord中的其他数据就可以通过预取进入队列,这个过程和网络的计算是独立进行的。也就是说,网络每一个iteration的训练不必等待数据队列准备好再开始,队列中的数据始终是充足的,而往队列中填充数据时,也可以使用多线程加速。

下载pre-trained checkpoint

每个网络对应的checkpoint可以从官网上找到,官网也提供了下载inception v3的checkpoint的例子

$ mkdir ./tmp/checkpoints
$ wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
$ tar -xvf inception_v3_2016_08_28.tar.gz
$ mv inception_v3.ckpt ./tmp/checkpoints
$ rm inception_v3_2016_08_28.tar.gz

从头开始训练

官网上提供了从头开始训练的例子,我根据我本地训练flowers数据集的存储位置而对脚本稍做修改

python train_image_classifier.py --train_dir=./tmp/train_logs \
--dataset_name=flowers --dataset_split_name=train \
--dataset_dir=./tmp/flowers --model_name=inception_v3

训练过程会打印出loss值

......
INFO:tensorflow:global step 10: loss = 3.3827 (0.384 sec/step)
INFO:tensorflow:global step 20: loss = 2.9981 (0.389 sec/step)
INFO:tensorflow:global step 30: loss = 3.8143 (0.392 sec/step)
INFO:tensorflow:global step 40: loss = 3.3529 (0.385 sec/step)
INFO:tensorflow:global step 50: loss = 3.1890 (0.388 sec/step)
INFO:tensorflow:global step 60: loss = 2.2893 (0.389 sec/step)
INFO:tensorflow:global step 70: loss = 2.5434 (0.386 sec/step)
INFO:tensorflow:global step 80: loss = 3.1224 (0.386 sec/step)
INFO:tensorflow:global step 90: loss = 3.4845 (0.387 sec/step)
INFO:tensorflow:global step 100: loss = 2.2984 (0.391 sec/step)
INFO:tensorflow:global step 110: loss = 2.5087 (0.387 sec/step)
INFO:tensorflow:global step 120: loss = 2.8148 (0.391 sec/step)
INFO:tensorflow:global step 130: loss = 2.4258 (0.390 sec/step)
INFO:tensorflow:global step 140: loss = 2.9289 (0.391 sec/step)
INFO:tensorflow:global step 150: loss = 2.5775 (0.391 sec/step)
INFO:tensorflow:global step 160: loss = 2.5603 (0.390 sec/step)
INFO:tensorflow:global step 170: loss = 2.8423 (0.392 sec/step)
INFO:tensorflow:global step 180: loss = 2.3163 (0.388 sec/step)
......

tensorboard

打开tensorboard,tensorboard --logdir=./tmp/train_logs

训练1000多次查看tensorboard的效果

Fine-tuning

--checkpoint_path:指定checkpoint文件的路径。

--checkpoint_exclude_scopes:当pre-trained checkpoint对应的网络最后一层分类的类别数量和现在数据集的类别数量不匹配时使用,可以指定checkpoint restore时哪些层的参数不恢复。

--trainable_scopes:如果只希望某些层参与训练,其他层的参数固定时,就使用这个flag,在这个flag中指定需要训练的参数。

python train_image_classifier.py \
    --train_dir=./tmp/train_logs \
    --dataset_dir=./tmp/flowers \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --model_name=inception_v3 \
    --checkpoint_path=./tmp/inception_v3_checkpoints/inception_v3.ckpt \
    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits

使用pre-trained checkpoint,loss会很快下降到一个比较小的值。

评价

使用上一个步骤训练出来的checkpoint进行评估。注意一个差别,fine-tuning时--checkpoint_path需要指定到具体文件,但是评估的时候--checkpoint_path只需要指定到文件夹路径即可,代码会根据文件夹下的内容自动选定以最新的checkpoint来进行评估。

python eval_image_classifier.py \
    --alsologtostderr \
    --checkpoint_path=./tmp/train_logs \
    --dataset_dir=./tmp/flowers \
    --dataset_name=flowers \
    --dataset_split_name=validation \
    --model_name=inception_v3

结果如下:

INFO:tensorflow:Restoring parameters from ./tmp/train_logs/model.ckpt-0
INFO:tensorflow:Evaluation [1/4]
INFO:tensorflow:Evaluation [2/4]
INFO:tensorflow:Evaluation [3/4]
INFO:tensorflow:Evaluation [4/4]
2018-08-23 18:07:13.349935: I tensorflow/core/kernels/logging_ops.cc:79] eval/Recall_5[1]
2018-08-23 18:07:13.350030: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.1675]
INFO:tensorflow:Finished evaluation at 2018-08-23-10:07:13

保存模型

python export_inference_graph.py \
  --alsologtostderr \
  --model_name=inception_v3 \
  --output_file=./tmp/inception_v3_inf_graph.pb

可以将模型导出,后续可以直接load这个模型来使用

小结

现有数据的训练方式就介绍完了,基本上脚本都可以解决,所以要训练自己的数据集就需要模仿这些代码的实现。

训练自己的数据

创建自己的数据集

首先要准备自己的数据集,保证相同类别的图片放在同一个文件夹下,文件夹的名字就是这个类别的名称。注意,图片数据最好备份一份,因为执行完后图片数据会全部被删除,只保留生成的tfrecord文件,除非修改代码删除这个步骤

接着需要仿照download_and_convert_flowers.py中对flowers数据转tfrecord的处理,来实现对自己的数据转tfrecord的处理。

主要以下几个改动点:

1.创建convert_mydata.py文件,等同于download_and_convert_flowers.py,因为我们自己的数据不用下载,所以文件命名为convert_flowers.py

2.在download_and_convert_data.py中添加处理,这样运行download_and_convert_data.py时,传入mydata数据集就可以走到convert_mydata.py里的run函数。

  # add by stesha
  elif FLAGS.dataset_name == 'mydata':
    convert_mydata.run(FLAGS.dataset_dir)

3.convert_mydata.py的实现基本和download_and_convert_flowers.py类似,只是去掉里面关于download部分的代码,比如run函数中去掉dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)

4._NUM_SHARDS表示有多少类别,_NUM_VALIDATION表示用多少张图片作为validation,根据实际情况填写即可。

5.如果不希望自己的图片数据在执行完后被删掉,可以去掉run中_clean_up_temporary_files(dataset_dir)代码。

参考代码:convert_mydata.py

代码实现后运行下面的脚本就可以将数据转换成tfrecord格式了。

python download_and_convert_data.py \
    --dataset_name=mydata \
    --dataset_dir=./data/mydata

下载pre-trained checkpoint

训练自己的数据我打算用准确率相对比较高的inception v4,所以我们需要下载inception v4的checkpoint。

下载地址:http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz

下载完成后解压放入一个文件夹中,比如我放入了./data/checkpoint中

从头开始训练

训练时数据需要从tfrecord中读取出来,所以代码需要稍作改动

1.在dataset_factory.py中增加mydata数据集

2.创建mydata.py,参考flowers.py的实现,基本需要改动的只有SPLITS_TO_SIZES和_NUM_CLASSES。前者只需要将测试集和训练集的大小写入,后者分类的数量。参考:mydata.py

准备好后,只需要运行下面的脚本就可以开始训练了,新的checkpoint文件会存放在指定的train_dir中。

python train_image_classifier.py --train_dir=./data/train_logs \
--dataset_name=mydata --dataset_split_name=train \
--dataset_dir=./data/mydata --model_name=inception_v4

fine-tuning

python train_image_classifier.py \
    --train_dir=./data/train_logs \
    --dataset_dir=./data/mydata \
    --dataset_name=mydata \
    --dataset_split_name=train \
    --model_name=inception_v4 \
    --checkpoint_path=./data/checkpoint/inception_v4.ckpt \
    --checkpoint_exclude_scopes=InceptionV4/Logits,InceptionV4/AuxLogits \
    --trainable_scopes=InceptionV4/Logits,InceptionV4/AuxLogits

评估

python eval_image_classifier.py \
    --alsologtostderr \
    --checkpoint_path=./data/train_logs \
    --dataset_dir=./data/mydata \
    --dataset_name=mydata \
    --dataset_split_name=validation \
    --model_name=inception_v4

预测

tf-slim中并没有提供predict某张图片的脚本,我这边简单实现了一下,可以作为参考。predict.py

python predict.py --model_name=inception_v4 \
		  --predict_file=./backup/mydata/km1_back/km1_back.jpg \
                  --checkpoint_path=./data/train_logs

结语

使用tensorflow的slim model来训练自己的数据集还是很简单的,基本上要改动的代码不多,这样能够方便我们很快的实施自己的想法,而且基于已经训练好的checkpoint来fine-tuning很快也能得到不错的精确度,使神经网络的部署更加方便快捷。

猜你喜欢

转载自blog.csdn.net/stesha_chen/article/details/81976415