使用inception_model模型分类图片

1.下载源码包,并准备图片

将TF1.5源码包和inception_model 模型下载下来
在码云上将TF1.5的源码包下载下来
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
图片样本参考网址–https://www.robots.ox.ac.uk/~vgg/data/
图片至少两个类别且每个种类最少200张,放在目录data/train/

2.执行retrain.bat 训练模型

现将所在盘符下的tmp文件夹中的内容删除,防止训练报错

python D:/AI/tensorflow-v1.5.0/tensorflow-v1.5.0/tensorflow/examples/image_retraining/retrain.py ^
--bottleneck_dir bottleneck ^# 图片转换为向量地址
--how_many_training_steps 2 ^ #训练的周期
--model_dir C:/Users/Administrator/Tensorflowcx/inception_model/ ^ #模型加载路径
--output_graph output_graph.pb ^ #输出训练好的模型
--output_labels output_labels.txt ^ #将训练好的标签参数输出
--image_dir data/train/ # 你要分类的图片
pause

保存为retrain.bat

3.测试训练好的模型

在jupyter Notebook里新建python文件并执行

#测试训练好的模型
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt


# In[2]:

lines = tf.gfile.GFile('retrain/output_labels.txt').readlines()
uid_to_human = {
    
    }
#一行一行读取数据
for uid,line in enumerate(lines) :
    #去掉换行符
    line=line.strip('\n')
    uid_to_human[uid] = line

def id_to_string(node_id):
    if node_id not in uid_to_human:
        return ''
    return uid_to_human[node_id]


#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile('retrain/output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    #遍历目录
    for root,dirs,files in os.walk('retrain/data/train/1/'):
        for file in files:
            #载入图片
            image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
            predictions = sess.run(softmax_tensor,{
    
    'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
            predictions = np.squeeze(predictions)#把结果转为1维数据

            #打印图片路径及名称
            image_path = os.path.join(root,file)
            print(image_path)
            #显示图片
            img=Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()

            #排序
            top_k = predictions.argsort()[::-1]
            print(top_k)
            for node_id in top_k:     
                #获取分类名称
                human_string = id_to_string(node_id)
                #获取该分类的置信度
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()

Guess you like

Origin blog.csdn.net/chushudu/article/details/119955913