(2)caffe学习之调用预训练的网络训练自己的数据集--补充

参考:https://blog.csdn.net/wangbo792450190/article/details/52583749

上一篇文章介绍了训练网络的过程。测试时需要用到均值文件，因此要先把训练时生成的二进制均值文件（本例中为imagenet_mean.binaryproto）转换成NumPy可直接加载的mean.npy。为此新建一个create_npy.py，内容如下：

# create_npy.py: convert Caffe's binary mean file (.binaryproto) into a NumPy .npy file.
#
# Usage:
#   python create_npy.py [mean.binaryproto out.npy]
# With no arguments, the default paths below are used.
import os
import sys
caffe_root = '/home/qf/git/caffe/'  # root of the local Caffe checkout -- edit for your machine
sys.path.insert(0, caffe_root + 'python')  # make pycaffe importable
import caffe
import numpy as np

# Default input/output locations (overridable from the command line).
BINARY_PROTO_FILE_NAME = '/home/qf/git/myfile/build_lmdb/imagenet_mean.binaryproto'
NPY_FILE_NAME = '/home/qf/git/myfile/build_lmdb/mean.npy'
if len(sys.argv) == 3:
    BINARY_PROTO_FILE_NAME, NPY_FILE_NAME = sys.argv[1], sys.argv[2]
else:
    print("Usage: python create_npy.py mean.binaryproto out.npy  (using default paths)")

# Resolve relative paths against the launch directory *before* os.chdir below
# changes the working directory; absolute paths pass through unchanged.
BINARY_PROTO_FILE_PATH = os.path.abspath(BINARY_PROTO_FILE_NAME)
NPY_FILE_PATH = os.path.abspath(NPY_FILE_NAME)
os.chdir(caffe_root)

blob = caffe.proto.caffe_pb2.BlobProto()
with open(BINARY_PROTO_FILE_PATH, 'rb') as f:  # `with` guarantees the handle is closed
    blob.ParseFromString(f.read())
arr = np.array(caffe.io.blobproto_to_array(blob))
# Keep only the first image along the leading axis; presumably the array is
# (1, C, H, W) for a mean file -- TODO confirm for your mean file.
out = arr[0]
np.save(NPY_FILE_PATH, out)

在终端中运行 python create_npy.py，即可在build_lmdb文件夹下得到mean.npy文件。

在dataset文件夹下新建test文件夹，放入待预测的图片；同时在dataset文件夹下准备类别标签文件label.txt（注意测试程序读取的文件名是label.txt），每行一个类别名称，第i行对应网络输出的第i个类别。

测试程序:

import os
import caffe
import numpy as np

# Base directory for every path below -- edit for your machine.
root = '/home/qf/git/'
deploy = root + 'myfile/bvlc_reference_caffenet/deploy.prototxt'  # network definition used for inference
caffe_model = root + 'myfile/bvlc_reference_caffenet/caffenet_train_iter_1500.caffemodel'  # trained weights

# Directory holding the images to classify.
test_dir = root + 'myfile/dataset/test/'
# Full paths of the regular files in test_dir. Sorting makes the processing
# order (and the line order of the prediction file) reproducible, because
# os.listdir returns entries in arbitrary order. Sub-directories are skipped
# since they are not images.
candidate_paths = [os.path.join(test_dir, fn) for fn in sorted(os.listdir(test_dir))]
filelist = [path for path in candidate_paths if os.path.isfile(path)]

# To classify a single image instead, call e.g.:
#   Test(root + 'data/DRIVE/test/60337.jpg')
 
# Lazily-filled cache shared by all Test() calls; see _get_classifier().
_classifier_cache = {}


def _get_classifier():
    """Return (net, transformer, labels), building them on the first call only.

    Loading the weights, the mean file and the label file does not depend on
    the image, so it is done once and reused for every image.
    """
    if not _classifier_cache:
        net = caffe.Net(deploy, caffe_model, caffe.TEST)

        # Preprocessing that turns a caffe.io.load_image result into the
        # layout of the 'data' input blob.
        transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
        transformer.set_transpose('data', (2, 0, 1))  # move the channel axis first (HxWxC -> CxHxW)
        # Average mean.npy over its last two axes -> one mean value per channel
        # (assumes the (C, H, W) layout written by create_npy.py).
        transformer.set_mean('data', np.load(root + 'myfile/build_lmdb/mean.npy').mean(1).mean(1))
        # load_image yields floats in [0, 1] (per pycaffe docs); rescale to 0-255.
        transformer.set_raw_scale('data', 255)
        transformer.set_channel_swap('data', (2, 1, 0))  # reverse channel order (RGB -> BGR)

        # One class name per line; the line index is the class index.
        labels = np.loadtxt(root + 'myfile/dataset/label.txt', str, delimiter='\t')

        _classifier_cache.update(net=net, transformer=transformer, labels=labels)
    return (_classifier_cache['net'],
            _classifier_cache['transformer'],
            _classifier_cache['labels'])


def Test(img):
    """Classify one image, print the result and append it to pre_label.txt.

    img: path of an image file readable by caffe.io.load_image.

    Side effect: appends the line "<img> <predicted label>" to
    root + 'myfile/dataset/pre_label.txt'.
    """
    net, transformer, labels = _get_classifier()

    im = caffe.io.load_image(img)
    print("im {}".format(im))
    # The preprocessed image is broadcast into the data blob.
    net.blobs['data'].data[...] = transformer.preprocess('data', im)
    out = net.forward()
    print("out {}".format(out))

    print(out['prob'].argmax())
    prob = net.blobs['prob'].data[0].flatten()  # class scores of the first image in the batch
    print(prob)
    # argmax selects the top class for any number of classes; a fixed index
    # into argsort() is only correct for one particular class count.
    order = prob.argmax()

    print('the class is: {}'.format(labels[order]))
    with open(root + 'myfile/dataset/pre_label.txt', 'a') as f:
        f.write(img + ' ' + labels[order] + '\n')
 
# Path of the class-label file (the file Test() reads). The previous value,
# root + '/home/qf/git/...', duplicated the base directory.
labels_filename = root + 'myfile/dataset/label.txt'

# Classify every image collected in filelist.
for img in filelist:
    Test(img)


猜你喜欢

转载自blog.csdn.net/qq_38096703/article/details/79695667
今日推荐