21个项目玩转深度学习:基于TensorFlow的实践详解03—打造自己的图像识别模型

书籍源码:https://github.com/hzy46/Deep-Learning-21-Examples

CNN的发展已经很多了,ImageNet引发的一系列方法,LeNet,GoogLeNet,VGGNet,ResNet每个方法都有很多版本的衍生,tensorflow中带有封装好各方法和网络的函数,只要喂食自己的训练集就可以完成自己的模型,感觉超方便!!!激动!!!因为虽然原理流程了解了,但是要写出来真的。。。。好难,臣妾做不到啊~~~~~~~~

START~~~~

1.数据准备

首先了解下微调的概念: 以VGG为例

他的结构是卷积+全连接,卷积层分为5个部分共13层,conv1~conv5。还有三层全连接,即fc6,fc7,fc8。总共16层,因此被称为VGG16。

a.如果要将VGG16的结构用于一个新的数据集,首先要去掉fc8,因为fc8原本的输出是1000类的概率。需要改为符合自身训练集的输出类别数。

b.训练的时候,网络的参数的初始值并不是随机化生成的,而是采用VGG16在ImageNet上已经训练好的参数作为训练的初始值。因为已经训练过的VGG16中的参数已经包含了大量有用的卷积过滤器,这样做不仅节约大量训练时间,而且有助于分类器性能的提高。

载入VGG16的参数后,即可开始训练。此时需要指定训练层数的范围。一般而言,可以选择以下几种:

  • 只训练fc8:训练范围一定要包含fc8这一层。这样的选择一般性能都不会太好,但速度很快;因为他只训练fc8,保持其他层的参数不动,相当于把VGG16当成一个“特征提取器”,用fc7层提取的特征做一个softmax的模型分类。
  • 训练所有参数:耗时较慢,但能取得较高性能。
  • 训练部分参数:通常是固定浅层参数不变,训练深层参数。如固定conv1、conv2部分的参数不训练,只训练conv3、conv4、conv5、fc6、fc7、fc8的参数。

这种训练方法就是对神经网络做微调。

1.1 切分train&test

书中提供了卫星图像数据集,有6个类别,分别是森林(wood),水域(water),岩石(rock),农田(wetland),冰川(glacier),城市区域(urban)

保存结构为data_prepare/pic,再下层有两个文件夹train和validation,各文件夹下有6个文件夹,放的是该类别下的图片。

1.2 转换成tfrecord格式

python data_convert.py -t pic/ \
  --train-shards 2 \
  --validation-shards 2 \
  --num-threads 2 \
  --dataset-name satellite

参数解释:

-t pic/ :表示转换pic文件夹下的数据,该文件夹必须与上面的文件结构保持一致

--train-shards 2 :把训练集分成两块,即最后的训练数据就是两个tfrecord格式的文件。若数据集更大,可以分更多数据块

--validation-shards 2 :把验证集分成两块

--num-thread 2 :用两个线程来产生数据。注意线程数必须要能整除train-shards和validation-shards,来保证每个线程处理的数据块是相同的。

--dataset-name :给生成的数据集起个名字,即表示最后生成文件的开头是satellite_train和satellite_validation

运行上述命令后,就可以在 pic 文件夹中找到 5 个新生成的文件 ,分别是:

  • 训练数据 satellite_train_00000-of-00002.tfrecord、satellite_train_00001-of-00002.tfrecord,
  • 验证数据 satellite_validation_00000-of-00002.tfrecord、satellite_validation_00001-of-00002.tfrecord。
  • label.txt 它表示图片的内部标签(数字)到真实类别(字符串)之间的映射顺序 。 如图片在 tfrecord 中的标签为 0 ,那么就对应 label.txt 第一行的类别,在 tfrecord的标签为1,就对应 label.txt 中第二行的类别,依此类推。

2.训练模型

2.1 TensorFlow Slim

Google 公司公布的一个图像分类工具包,它不仅定义了一些方便的接口,还提供了很多 ImageNet 数据集上常用的网络结构和预训练模型

截至2017年7月,Slim 提供包括 VGG16VGG19InceptioV1 ~ V4、ResNet 50、ResNet 101、MobileNet 在内大多数常用模型的结构及预训练模型,更多的模型还会被持续添加进来

源码地址: https://github.com/tensorflow/models/tree/master/research/slim

可以通过执行  git clone https://github.corn/tensorflow/models.git  来获取

2.2 定义新的datasets文件<修改slim源码>

2.3 准备训练文件夹

2.4 开始训练

3.验证准确率

4.导出模型并对单张图片分类

THE END~~~~

猜你喜欢

转载自www.cnblogs.com/helloworld0604/p/9936520.html