Python 应用caffe模型进行分类(caffe接口)

遍历一个文件夹下的所有图片,逐张进行预测,并将图片复制到对应类别的文件夹


import caffe
#import lmdb
import numpy as np
import cv2
from caffe.proto import caffe_pb2
import os
import sys

# Select Caffe's GPU mode; this runs as a module-level side effect.
caffe.set_mode_gpu()


def dirlist(path, allfile):
    """Recursively collect the paths of all files below *path*.

    Args:
        path: Directory to scan.
        allfile: List that found file paths are appended to (mutated in
            place); pass ``[]`` to start a fresh collection.

    Returns:
        The same ``allfile`` list, extended with one entry per
        non-directory entry found under ``path``. Entries appear in the
        order ``os.listdir`` yields them, with each sub-directory expanded
        in place.
    """
    for entry in os.listdir(path):
        candidate = os.path.join(path, entry)
        if not os.path.isdir(candidate):
            allfile.append(candidate)
            continue
        # Sub-directory: descend and keep appending to the shared list.
        dirlist(candidate, allfile)
    return allfile

# sys.setrecursionlimit(1000000)


def is_bgr_img(img):
    """Return True when *img* has a three-dimensional ``shape``.

    Args:
        img: Any object; typically the result of ``cv2.imread``, which may
            be ``None`` when the file could not be read.

    Returns:
        ``True`` if ``img.shape`` exists and has exactly three entries,
        ``False`` otherwise. Objects without a ``shape`` attribute
        (including ``None``) and arrays of any other rank, such as 2-D
        grayscale images, yield ``False``.

    Note:
        Only the number of dimensions is checked; the size of the last
        axis (the channel count) is not inspected.
    """
    shape = getattr(img, 'shape', None)
    if shape is None:
        return False
    # The original tuple-unpacking raised ValueError for shapes that were
    # not exactly three long; comparing the length reports False instead.
    return len(shape) == 3


# --- Model setup: file locations, network and input preprocessing ---
root = 'D:/stomach_raw_data/deepid/'  # base directory holding the model files
# root already ends with a separator, so join() yields the same strings as
# plain concatenation would.
deploy = os.path.join(root, 'deploy_all.prototxt')  # network definition
caffe_model = os.path.join(
    root, 'id_128_net_iter_1695000.caffemodel')  # trained weights
labels_filename = os.path.join(root, 'labels.txt')  # numeric label -> class name

# Build the network in TEST mode from the definition and the trained weights.
net = caffe.Net(deploy, caffe_model, caffe.TEST)

# The transformer takes its input shape from the 'data_1' blob.
# NOTE(review): the original comment gives that shape as (1, 3, 28, 28),
# i.e. (count, channels, height, width) - confirm against the prototxt.
transformer = caffe.io.Transformer(
    {'data': net.blobs['data_1'].data.shape})
# Move the last axis to the front, e.g. (H, W, C) -> (C, H, W). This is an
# axis reordering only; it does not swap colour channels.
transformer.set_transpose('data', (2, 0, 1))
# No mean subtraction, raw scaling or channel swap is configured here. Per
# the original author's note, the model was trained without mean subtraction.

# One class name per row of the tab-delimited labels file.
labels = np.loadtxt(labels_filename, dtype=str, delimiter='\t')
# Output sub-directory for each predicted class index.
dirs = [
    '0_CA', '1_FV', '2_GB', '3_GA',
    '4_SV', '5_PY', '6_OT', '7_IV',
]


# Classify every image found under the input tree and save a copy into a
# per-class output directory (or 'unkown' when the prediction is not
# confident enough).
imgnames = dirlist('D:\\2D', [])
path = 'D:/sto_img_1695000/'

# Debug output for the first file found: the prefix of its parent directory
# name and its full path. Guarded so that an empty input tree does not crash
# with IndexError.
if imgnames:
    temp = imgnames[0]
    print(os.path.basename(os.path.dirname(temp)).split('_')[0])
    print(temp)

t = 0      # accumulated forward-pass time in milliseconds (not reported)
total = 0  # number of readable images classified so far, including the current one

for imgname in imgnames:
    image = cv2.imread(imgname)
    if image is None:
        # cv2.imread returns None for files it cannot read. As in the
        # original script, such files are reported and DELETED from the
        # input tree.
        print(imgname)
        os.remove(imgname)
        continue
    total += 1

    net.blobs['data_1'].data[...] = transformer.preprocess('data', image)
    t1 = cv2.getTickCount()
    net.forward()
    t += (cv2.getTickCount() - t1) * 1000 / cv2.getTickFrequency()

    prob = net.blobs['softmax'].data[0].flatten()
    order = prob.argsort()[-1]  # index of the highest probability
    prob_max = prob[order]
    print('max = %f,class = %d,all = %d\n' % (prob_max, order, total))

    # Confident predictions go to their class directory, everything else to
    # 'unkown' (spelling kept so existing output directories keep working).
    target = dirs[order] if prob_max > 0.70 else 'unkown'
    imgpath = path + target
    if not os.path.isdir(imgpath):
        # makedirs also creates the output root when it does not exist yet.
        os.makedirs(imgpath)
    cv2.imwrite(os.path.join(imgpath, os.path.basename(imgname)), image)

    cv2.imshow('cv2', image)
    k = cv2.waitKey(1)
    if k == 27:   # Esc: stop processing
        break
    if k == 32:   # Space: pause until any key is pressed
        cv2.waitKey()

cv2.destroyAllWindows()

猜你喜欢

转载自blog.csdn.net/penghejuan2012/article/details/83027132