caffe下py-faster-rcnn利用vgg16训练及预测

工作中经常用到py-faster-rcnn做图片的检测与识别,训练过程有必要记录一下,下面是参照网上的一些资料整理实践后的总结:
py-faster-rcnn的github地址:https://github.com/rbgirshick/py-faster-rcnn

数据采用VOC 2007格式。
这里写图片描述

一、制作数据集
程序/工具:VOC2007文件夹、labelImg
处理流程:图像重命名为6位数字,使用labelImg工具标定,根据xml生成四个txt(train.txt、val.txt、test.txt、trainval.txt),将jpg、xml、txt等文件按照逻辑图所示位置存放
数据生成工具类可参考:imagenet数据标注文件的read和write

二、修改网络文件
train.txt:
models/pascal_voc/VGG16/faster_rcnn_end2end/train.prototxt VGG16的train.prototxt
Line 11:’num_classes’: 2 修改成 损伤类型数目+1(背景算一类)
Line 530:’num_classes’: 2 修改成 损伤类型数目+1(背景算一类)
Line 620:num_output: 2 修改成 损伤类型数目+1(背景算一类)
Line 643:num_output: 8 此处数字应为 (损伤类别数+1)*4 “4”是指bbox的四个角

test.prototxt:
models/pascal_voc/VGG16/faster_rcnn_end2end/test.prototxt VGG16的test.prototxt
Line 567:num_output: 2 修改成 损伤类型数目+1(背景算一类)
Line 592:num_output: 8 此处数字应为 (损伤类别数+1)*4 “4”是指bbox的四个角

pascal_voc.py
lib/datasets/pascal_voc.py 修改line 31 修改为自定义类型
这里写图片描述

三、 运行程序
每次改动数据记得清空缓存 rm -rf data/cache
终端访问py-faster-rcnn目录,输入以下命令:
./experiments/scripts/faster_rcnn_end2end.sh 0 VGG16 pascal_voc
0表示使用GPU 0运行程序,可修改;VGG16表示使用的网络

四、 预测阶段
直接上代码:

#!/usr/bin/env python
# Copyright (c) 2016 Yuwen Xiong
# Licensed under The MIT License [see LICENSE for details]
# Written by Yuwen Xiong
# --------------------------------------------------------
"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse
CLASSES = ('__background__',
           'hand')
def get_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return None
    bboxs = []
    for i in inds:
        bbox = dets[i, :4]
        bboxs.append([int(bbox[0]),int(bbox[1]),int(bbox[2]),int(bbox[3])])
    return bboxs
def frcn_predict(net,img_im):
    # Detect all object classes and regress object bounds
    scores, boxes = im_detect(net, img_im)
    # Visualize detections for each class
    CONF_THRESH = 0.65
    NMS_THRESH = 0.15
    res_dict = {}
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1  # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        boxs = get_detections(img_im, cls, dets, thresh=CONF_THRESH)
        if boxs is not None:
           res_dict[cls] = boxs
    return res_dict
def get_init_net():
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    prototxt = r'models/online/models/pascal_voc/VGG16/faster_rcnn_end2end/hand_test.prototxt'
    caffemodel = r'models/online/faster_rcnn_models/vgg16_faster_rcnn_hand_iter_500000.caffemodel'
    if not os.path.isfile(caffemodel):
        raise IOError(('{:s} not found.\n').format(caffemodel))
    caffe.set_mode_gpu()
    caffe.set_device(0)
    cfg.GPU_ID = 0
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)
    print '\n\nLoaded network {:s}'.format(caffemodel)
    # Warmup on a dummy image
    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for i in xrange(2):
        _, _= im_detect(net, im)
    return net
if __name__ == '__main__':
    net = get_init_net()
    img_path = r'data/VOCdevkit/VOC2007_lisa/JPEGImages/5500.jpg'
    im = cv2.imread(img_path)
    res=frcn_predict(net, im)
    print(res)

返回结果res结果示例,数据格式:{label:[pic1_point,pic2_point,…]}:

{'hand': [[482, 347, 570, 438], [52, 289, 147, 362], [104, 261, 273, 375]]}

参考来源:
https://www.zhihu.com/question/57091642/answer/165134753
http://blog.csdn.net/otengyue/article/details/79243559

猜你喜欢

转载自blog.csdn.net/otengyue/article/details/79278486