Using an AI model to identify what an image shows (a table, a cat, a dog, or something else)

# -*- coding: utf-8 -*-
"""
Spyder Editor

This is a temporary script file.
"""

import sys

# Make the TF-slim model library (the `nets` and `datasets` packages) importable.
# NOTE(review): 'slim' is a relative path, so this only works when the script
# is launched from the directory that contains the slim checkout.
nets_path = r'slim'
if nets_path not in sys.path:
    sys.path.insert(0, nets_path)
else:
    print('already add slim')

import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
from nets.nasnet import pnasnet
import numpy as np
from datasets import imagenet

# tf.contrib was removed in TensorFlow 2.x, so this script requires TF 1.x.
slim = tf.contrib.slim

tf.reset_default_graph()
# Side length of the square input images the PNASNet-large graph is built for.
image_size = pnasnet.build_pnasnet_large.default_image_size
# Human-readable ImageNet class names from the slim datasets package.
labels = imagenet.create_readable_names_for_imagenet_labels()
print(len(labels), labels)

def getone(onestr):
    """Return *onestr* with every comma removed; all other characters are kept."""
    pieces = onestr.split(',')
    return ''.join(pieces)
# Load the label names from the Chinese label file, one label per line.
# Opened read-only: 'r+' would needlessly fail on a read-only file.
# NOTE(review): no encoding is given, so the platform default is used;
# pass encoding= explicitly once the file's encoding is confirmed.
with open('中文标签.csv', 'r') as f:
    # Each entry keeps its trailing newline; getone only removes commas.
    labels = [getone(line) for line in f]
    print(len(labels), type(labels), labels[:5])
    
# Recognize the sample images with the AI model.
sampe_images = ['hy.jpg', 'ps.jpg', '72.jpg']
input_imgs = tf.placeholder(tf.float32, [None, image_size, image_size, 3])

# Rescale pixel values from [0, 255] to [-1, 1] before feeding the network.
x1 = 2 * (input_imgs / 255.0) - 1.0
arg_scope = pnasnet.pnasnet_large_arg_scope()
with slim.arg_scope(arg_scope):
    # NOTE(review): 1001 presumably = 1000 ImageNet classes + background,
    # matching the pretrained checkpoint — confirm.
    logits, end_points = pnasnet.build_pnasnet_large(
        x1, num_classes=1001, is_training=False)
    prob = end_points['Predictions']
    # Index of the most probable class for each image in the batch.
    y = tf.argmax(prob, axis=1)


checkpoint_file = r'pnasnet-5_large_2017_12_13/model.ckpt'

# Saver used to load the pretrained weights. It must be instantiated:
# without the parentheses `saver` would be the Saver class itself and
# `saver.restore(sess, ...)` would bind `sess` as `self`.
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, checkpoint_file)

    def preimg(img):
        """Convert a PIL image into a float32 array the network can be fed.

        The image is converted to 3-channel RGB (an alpha channel is dropped,
        grayscale/palette images are expanded) and resized to
        image_size x image_size.

        Returns:
            np.ndarray of shape (image_size, image_size, 3), dtype float32,
            with pixel values in [0, 255].
        """
        rgb = img.convert('RGB').resize((image_size, image_size))
        return np.asarray(rgb, dtype=np.float32)

    # Open each sample once: the originals are kept for display, and the
    # preprocessed arrays form the input batch.
    orgImg = [Image.open(imgfilename) for imgfilename in sampe_images]
    batchImg = [preimg(img) for img in orgImg]

    # yv: predicted class index per image; img_norm: the [-1, 1] input
    # tensor the network actually received.
    yv, img_norm = sess.run([y, x1], feed_dict={input_imgs: batchImg})

    print(yv, np.shape(yv))

    def showresult(yv, img_norm, img_org):
        """Show an original image beside its network input and print the label.

        Args:
            yv: predicted class index for this image.
            img_norm: this image's network input, scaled to [-1, 1].
            img_org: the original PIL image.
        """
        plt.figure()
        p1 = plt.subplot(121)
        p2 = plt.subplot(122)
        p1.imshow(img_org)
        p1.axis('off')
        p1.set_title('original image')

        # Undo the [-1, 1] scaling before casting to uint8; clipping keeps
        # float rounding from wrapping around.
        p2.imshow(np.clip((img_norm + 1.0) * 127.5, 0, 255).astype(np.uint8))
        p2.axis('off')
        p2.set_title("input image")

        plt.show()
        # NOTE(review): assumes the label file has one line per model output
        # (background first) so indices line up — confirm.
        print(yv, labels[yv])

    for yy, img1, img2 in zip(yv, img_norm, orgImg):
        showresult(yy, img1, img2)
Published 41 original articles · 0 likes · 779 views

You may also like

Source: blog.csdn.net/qestion_yz_10086/article/details/104834044