前言
计算机视觉系列之学习笔记主要是本人进行学习人工智能(计算机视觉方向)的代码整理。本系列所有代码是用python3编写,在平台Anaconda中运行实现,在使用代码时,默认你已经安装相关的python库,这方面不做多余的说明。本系列所涉及的所有代码和资料可在我的github上下载到,gitbub地址:https://github.com/mcyJacky/DeepLearning-CV,如有问题,欢迎指出。
一、图像数据准备
①、在/data文件夹下,准备5种不同类型图片的文件夹,每个文件夹下大约有500张该类型的图片用于训练,5种文件夹如下图1.1所示:
②、与上一篇一样,准备inception_model文件,其中包含下载的inception-v3模型训练的.pb文件和标签的映射文件。
③、准备一个空的bottleneck文件夹,用于存放图片的训练数据。
④、准备测试图片,放于/images文件夹下。
准备文件的目录结构如下图1.2所示:
二、图像重新训练(迁移学习)
下面我们为我们准备的图像类型进行重新训练,我们在训练的过程中会执行retrain.py文件(具体程序的内容,这边不做解析,具体内容去我开头介绍的github地址上去下载),为了训练方便,我们建立批处理文件retrain.bat,其中retrain.bat中内容为retrain.py要传入的参数,它的具体格式如下图2.1所示(自己可以进行相应的更改):
如图2.1所示,bottleneck_dir参数为bottleneck文件夹路径,how_many_training_steps为训练的迭代步长,model_dir为inception-v3模型路径,output_graph为训练好后文件输出路径,output_labels为训练好后标签输出路径,image_dir为训练图片的路径。
配置后后,我们就可以双击执行retrain.bat批处理文件,执行完成后会生成相应的pb文件和标签文件,此时迁移学习就训练完了。
三、新训练好的图像识别模型做预测
通过以上训练完成的.pb文件和标签文件,我们就可以对相应的新的图像做预测,具体实现如下:
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt
# 读取标签文件
lines = tf.gfile.GFile('output_labels.txt').readlines()
uid_to_human = {} #{0: 'animal', 1: 'house', 2: 'plane', 3: 'flower', 4: 'guitar'}
# 一行一行读取数据
for uid, line in enumerate(lines) :
#去掉换行符
line=line.strip('\n')
uid_to_human[uid] = line
# 分类编号变成描述
def id_to_string(node_id):
if node_id not in uid_to_human:
return ''
return uid_to_human[node_id]
# 创建一个图来存放训练好的模型
with tf.gfile.GFile('output_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
# final_result为输出tensor的名字,具体在retrain.py中定义
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
# 遍历目录
for root,dirs,files in os.walk('images/'):
for file in files:
# 载入图片
image_data = tf.gfile.GFile(os.path.join(root,file), 'rb').read()
# 把图像数据传入模型获得模型输出结果
predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})
# 把结果转为1维数据
predictions = np.squeeze(predictions)
# 打印图片路径及名称
image_path = os.path.join(root,file)
print(image_path)
# 显示图片
img=Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show()
# 排序
top_k = predictions.argsort()[::-1]
for node_id in top_k:
# 获取分类名称
human_string = id_to_string(node_id)
# 获取该分类的置信度
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
print()
部分图像识别结果如下图3.1所示:
由图3.1识别结果可知,用该方法进行迁移学习后预测不同类型图片的有效,对不同种图片的预测准确率也会比较高,因为不同类型图片的区分比较明显。
【参考】:
1. 城市数据团课程《AI工程师》计算机视觉方向
2. deeplearning.ai 吴恩达《深度学习工程师》
3. 《机器学习》作者:周志华
4. 《深度学习》作者:Ian Goodfellow
转载声明:
版权声明:非商用自由转载-保持署名-注明出处
署名 :mcyJacky
文章出处:https://blog.csdn.net/mcyJacky