整体步骤:
步骤一:数据准备,准备自己要分类的图片训练样本。
可以去 http://www.robots.ox.ac.uk/~vgg/data/ 下载数据集,
下载之后保存到指定目录(我的文件夹路径及结构如下):
E:\TensorFlow\retrain\data\train\
-->>animal
-->>flower
-->>car
步骤二:下载 retrain.py 程序:
去 https://github.com/tensorflow/hub 下载,
找到 retrain.py 文件放到 E:\TensorFlow\retrain\ 下
步骤三:下载inception-v3 模型
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
把压缩包存放到 E:\TensorFlow\inception_model 文件夹下,会自动寻找该压缩包。
步骤四:批处理命令文件 retrain.bat ,内容如下,把文件路径按自己的路径填写:
python E:/TensorFlow/retrain/retrain.py ^ --bottleneck_dir bottleneck ^ --how_many_training_steps 200 ^ --model_dir E:/Tensorflow/inception_model/ ^ --output_graph output_graph.pb ^ --output_labels output_labels.txt ^ --image_dir E:\TensorFlow\retrain\data\train pause
注释:
python E:/TensorFlow/retrain/retrain.py ^ #retrain.py 文件的路径 --bottleneck_dir bottleneck ^ #bottleneck 文件夹的路径 ,默认和 retrain.py 同一个文件夹 --how_many_training_steps 200 ^ #迭代 200 次 --model_dir E:/Tensorflow/inception_model/ ^ #inception-v3 模型的压缩包路径 --output_graph output_graph.pb ^ #输出的模型文件名
--output_labels output_labels.txt ^ #输出的标签
--image_dir E:\TensorFlow\retrain\data\train #自己的训练数据集存放路径
pause
在 E:\TensorFlow\retrain\ 下新建一个名叫 bottleneck 的文件夹,用于存放批处理之后各个图片的.txt文件。
也就是图片输入到 inception-v3 模型之后,经过倒数第二层的输出值(运行retrain.bat 文件之后产生)。
步骤五:预测 prediction.py 程序,用于调用新生成的模型预测新数据的结果。
# coding: utf-8 import tensorflow as tf import os import numpy as np import re from PIL import Image import matplotlib.pyplot as plt lines = tf.gfile.GFile('retrain/output_labels.txt').readlines() uid_to_human = {} #一行一行读取数据 for uid,line in enumerate(lines) : #去掉换行符 line=line.strip('\n') uid_to_human[uid] = line def id_to_string(node_id): if node_id not in uid_to_human: return '' return uid_to_human[node_id] #创建一个图来存放google训练好的模型 with tf.gfile.FastGFile('retrain/output_graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) tf.import_graph_def(graph_def, name='') with tf.Session() as sess: softmax_tensor = sess.graph.get_tensor_by_name('final_result:0') #遍历目录 for root,dirs,files in os.walk('retrain/images/'): #测试图片存放位置 for file in files: #载入图片 image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read() predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式 predictions = np.squeeze(predictions)#把结果转为1维数据 #打印图片路径及名称 image_path = os.path.join(root,file) print(image_path) #显示图片 img=Image.open(image_path) plt.imshow(img) plt.axis('off') plt.show() #排序 top_k = predictions.argsort()[::-1] print(top_k) for node_id in top_k: #获取分类名称 human_string = id_to_string(node_id) #获取该分类的置信度 score = predictions[node_id] print('%s (score = %.5f)' % (human_string, score)) print()