使用 Inception-v3 模型训练自己的图片分类

整体步骤:

步骤一:数据准备,准备自己要分类的图片训练样本。

                可以去 http://www.robots.ox.ac.uk/~vgg/data/ 下载数据集,

                下载之后保存到指定目录(我的文件夹路径及结构如下):

                 E:\TensorFlow\retrain\data\train\

                                                    -->>animal

                                                    -->>flower

                                                    -->>car

步骤二:下载 retrain.py 程序:

             去  https://github.com/tensorflow/hub 下载,

            找到 retrain.py 文件放到 E:\TensorFlow\retrain\

步骤三:下载inception-v3 模型

               http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz

                把压缩包存放到  E:\TensorFlow\inception_model  文件夹下,会自动寻找该压缩包。

步骤四:批处理命令文件 retrain.bat ,内容如下,把文件路径按自己的路径填写:

python E:/TensorFlow/retrain/retrain.py ^
--bottleneck_dir bottleneck ^
--how_many_training_steps 200 ^
--model_dir E:/Tensorflow/inception_model/ ^
--output_graph output_graph.pb ^
--output_labels output_labels.txt ^
--image_dir E:\TensorFlow\retrain\data\train
pause
    

    注释:

python E:/TensorFlow/retrain/retrain.py ^       #retrain.py 文件的路径 
--bottleneck_dir bottleneck ^                   #bottleneck 文件夹的路径 ,默认和 retrain.py 同一个文件夹
--how_many_training_steps 200 ^                 #迭代 200 次
--model_dir E:/Tensorflow/inception_model/ ^    #inception-v3 模型的压缩包路径
--output_graph output_graph.pb ^                #输出的模型文件名
--output_labels output_labels.txt ^             #输出的标签
--image_dir E:\TensorFlow\retrain\data\train    #自己的训练数据集存放路径
pause

    在  E:\TensorFlow\retrain\  下新建一个名叫 bottleneck 的文件夹,用于存放批处理之后各个图片的.txt文件。

    也就是图片输入到 inception-v3 模型之后,经过倒数第二层的输出值(运行retrain.bat 文件之后产生)。

步骤五:预测 prediction.py 程序,用于调用新生成的模型预测新数据的结果。

# coding: utf-8
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt


lines = tf.gfile.GFile('retrain/output_labels.txt').readlines()
uid_to_human = {}
#一行一行读取数据
for uid,line in enumerate(lines) :
    #去掉换行符
    line=line.strip('\n')
    uid_to_human[uid] = line

def id_to_string(node_id):
    if node_id not in uid_to_human:
        return ''
    return uid_to_human[node_id]


#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile('retrain/output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    #遍历目录
    for root,dirs,files in os.walk('retrain/images/'):  #测试图片存放位置
        for file in files:
            #载入图片
            image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
            predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
            predictions = np.squeeze(predictions)#把结果转为1维数据

            #打印图片路径及名称
            image_path = os.path.join(root,file)
            print(image_path)
            #显示图片
            img=Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()

            #排序
            top_k = predictions.argsort()[::-1]
            print(top_k)
            for node_id in top_k:     
                #获取分类名称
                human_string = id_to_string(node_id)
                #获取该分类的置信度
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()




猜你喜欢

转载自blog.csdn.net/weixin_38663832/article/details/80555341