Caffe Python API 之 Inference

# Example: running an SSD detection test with the Caffe Python API.
def detetion(image_dir, weight, deploy, resolution=300):
    """Run detection on every file in ``image_dir`` and save annotated copies.

    Each file is loaded through the ``cut`` helper (defined elsewhere;
    presumably it returns a ``resolution`` x ``resolution`` HWC image array --
    TODO confirm), pushed through the network, and every box found in the
    ``detection_out`` blob is drawn onto the image. The result is written to
    ``det_results/<file name>`` relative to the current working directory.

    Args:
        image_dir: Directory whose entries are all treated as input images.
        weight: Path to the trained weights file.
        deploy: Path to the deploy network definition.
        resolution: Side length, in pixels, of the square network input.

    Returns:
        None. Annotated images are written to disk and progress is printed.
    """
    # Stay compatible with callers that adapted to the previous argument
    # order by passing the two paths swapped.
    if str(weight).endswith('.prototxt') and not str(deploy).endswith('.prototxt'):
        weight, deploy = deploy, weight

    caffe.set_mode_gpu()
    # pycaffe expects the network definition first, then the weights.
    net = caffe.Net(deploy, weight, caffe.TEST)

    # Fix the input size once, *before* building the transformer, so that the
    # transformer's target shape matches the blob that is filled below.
    net.blobs['data'].reshape(1, 3, resolution, resolution)
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))  # HWC -> CHW
    # Per-channel mean pixel (presumably in BGR order to match OpenCV images
    # -- TODO confirm against how ``cut`` loads the image).
    transformer.set_mean('data', np.array([104, 117, 123]))

    target_dir = "det_results"
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)

    for image in os.listdir(image_dir):
        image_path = os.path.join(image_dir, image)
        target_path = os.path.join(target_dir, image)

        cropped = cut(image_path, resolution)
        net.blobs['data'].data[...] = transformer.preprocess('data', cropped)

        start = time.time()
        net.forward()
        elapsed = time.time() - start
        # Keep sub-second precision; truncating to int almost always gave 0.
        print("Forward time is {} s.".format(round(elapsed, 3)))

        raw = net.blobs["detection_out"].data
        # Flatten to one row per detection. Unlike np.squeeze this keeps a
        # 2-D array even when there is only a single row.
        detections = raw.reshape(-1, raw.shape[-1])

        # Row layout: image_id, label, conf, xmin, ymin, xmax, ymax
        # (coordinates are scaled by ``resolution`` below, i.e. treated as
        # normalized to the input size).
        for box in detections:
            conf = box[2]
            # NOTE(review): SSD's DetectionOutput layer is understood to emit
            # a single placeholder row filled with -1 when nothing is
            # detected; skip such rows instead of drawing a degenerate box.
            # Raise this threshold (e.g. to 0.1) to hide weak detections.
            if conf < 0:
                continue
            xmin, ymin, xmax, ymax = [
                int(coord * resolution) if coord > 0 else 0
                for coord in box[3:7]
            ]
            cv2.rectangle(cropped, (xmin, ymin), (xmax, ymax), (0, 255, 0), 1)

        cv2.imwrite(target_path, cropped)
        print(target_path)

猜你喜欢

转载自www.cnblogs.com/houjun/p/9912498.html
今日推荐