Faster R-CNN Source Code Reading and Analysis (TF + Python version)

Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/ssmixi/article/details/75807714


Source code: https://github.com/CharlesShang/TFFRCNN

1) Running the pre-trained model directly

The demo performs detection using a VGG16 network trained for detection on PASCAL VOC 2007.

python demo.py --model /models/VGGnet_fast_rcnn_iter_150000.ckpt  (the path after --model is where I saved the downloaded model)

The test images are stored under /data/demo.


2) Training the model yourself

### Training on Pascal VOC 2007

1. Download the training, validation, test data and VOCdevkit

    ```Shell
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar
    ```

2. Extract all of these tars into one directory named `VOCdevkit`

    ```Shell
    tar xvf VOCtrainval_06-Nov-2007.tar
    tar xvf VOCtest_06-Nov-2007.tar
    tar xvf VOCdevkit_08-Jun-2007.tar
    ```

3. It should have this basic structure

    ```Shell
    $VOCdevkit/                           # development kit
    $VOCdevkit/VOCcode/                   # VOC utility code
    $VOCdevkit/VOC2007                    # image sets, annotations, etc.
    # ... and several other directories ...
    ```

4. Create symlinks for the PASCAL VOC dataset

    ```Shell
    cd $TFFRCNN/data
    ln -s $VOCdevkit VOCdevkit2007
    ```

5. Download the pre-trained [VGG16](https://drive.google.com/open?id=0ByuDEGFYmWsbNVF5eExySUtMZmM) model and put it at `./data/pretrain_model/VGG_imagenet.npy`

6. Run the training script

    ```Shell
    cd $TFFRCNN
    python ./faster_rcnn/train_net.py --gpu 0 --weights ./data/pretrain_model/VGG_imagenet.npy --imdb voc_2007_trainval --iters 70000 --cfg  ./experiments/cfgs/faster_rcnn_end2end.yml --network VGGnet_train --set EXP_DIR exp_dir
    ```

7. Run profiling

    ```Shell
    cd $TFFRCNN
    # install a visualization tool
    sudo apt-get install graphviz
    ./experiments/profiling/run_profiling.sh
    # this generates an image at ./experiments/profiling/profile.png
    ```
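Steps 3 and 4 assume a specific directory layout, so it can be worth checking it before launching training. The sketch below is a hypothetical helper (`missing_voc_dirs` is not part of TFFRCNN); the sub-directories listed are the ones the standard VOC2007 tarballs unpack to, and the `data/VOCdevkit2007` path is the symlink created in step 4.

```python
import os

# Hypothetical helper, not part of TFFRCNN: directories a VOC2007 devkit
# normally contains after extracting the three tarballs from step 1.
EXPECTED_SUBDIRS = [
    'VOCcode',
    os.path.join('VOC2007', 'Annotations'),
    os.path.join('VOC2007', 'ImageSets', 'Main'),
    os.path.join('VOC2007', 'JPEGImages'),
]


def missing_voc_dirs(devkit_root):
    """Return the expected sub-directories that do not exist under devkit_root."""
    return [d for d in EXPECTED_SUBDIRS
            if not os.path.isdir(os.path.join(devkit_root, d))]


if __name__ == '__main__':
    # Assumes the script is run from $TFFRCNN after the symlink in step 4.
    missing = missing_voc_dirs(os.path.join('data', 'VOCdevkit2007'))
    print('layout OK' if not missing else 'missing: {}'.format(missing))
```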

This training code mainly targets the VOC2007 dataset; to train on your own dataset, convert it to VOC format.
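Converting your own data mostly means writing one XML annotation per image in the VOC style. The sketch below builds such a file with the standard library; `voc_annotation` is a hypothetical helper, and it assumes TFFRCNN's `pascal_voc` loader reads the same tags as the py-faster-rcnn one (`object/name`, `difficult`, `bndbox/xmin` ... with 1-based pixel coordinates). You would most likely also need to list image IDs in `VOC2007/ImageSets/Main/trainval.txt` and, for new classes, edit the class list in the loader.

```python
import xml.etree.ElementTree as ET


def voc_annotation(filename, width, height, objects, depth=3):
    """Build a PASCAL VOC-style annotation tree (hypothetical helper).

    `objects` is a list of (class_name, xmin, ymin, xmax, ymax) tuples in
    VOC's 1-based pixel coordinates.
    """
    root = ET.Element('annotation')
    ET.SubElement(root, 'folder').text = 'VOC2007'
    ET.SubElement(root, 'filename').text = filename
    size = ET.SubElement(root, 'size')
    for tag, value in (('width', width), ('height', height), ('depth', depth)):
        ET.SubElement(size, tag).text = str(value)
    for name, xmin, ymin, xmax, ymax in objects:
        obj = ET.SubElement(root, 'object')
        ET.SubElement(obj, 'name').text = name
        ET.SubElement(obj, 'difficult').text = '0'
        box = ET.SubElement(obj, 'bndbox')
        for tag, value in (('xmin', xmin), ('ymin', ymin),
                           ('xmax', xmax), ('ymax', ymax)):
            ET.SubElement(box, tag).text = str(int(value))
    return ET.ElementTree(root)


if __name__ == '__main__':
    tree = voc_annotation('000001.jpg', 353, 500, [('dog', 48, 240, 195, 371)])
    # In a VOC layout this would be saved as VOC2007/Annotations/000001.xml.
    print(ET.tostring(tree.getroot()).decode('utf-8'))
```

Writing the file is then `tree.write(path)` for each image.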

3) Source code reading notes

(1) train_net.py in TFFRCNN (faster_rcnn/train_net.py)


```python
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network on a region of interest database."""

import argparse  # standard Python module for parsing command-line arguments and options
import pprint    # pretty-prints Python data structures (used below to dump the config)
import numpy as np
import pdb       # Python debugger; lets you debug a script written in a plain text editor
import sys
import os.path

# Put the repository root on the import path so the `lib.*` packages resolve.
this_dir = os.path.dirname(__file__)
sys.path.insert(0, this_dir + '/..')
# for p in sys.path: print(p)
# print(this_dir)

from lib.fast_rcnn.train import get_training_roidb, train_net
from lib.fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir, get_log_dir
from lib.datasets.factory import get_imdb
from lib.networks.factory import get_network

def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id',
                        help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--solver', dest='solver',
                        help='solver prototxt',
                        default=None, type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=70000, type=int)
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='kitti_train', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')
    parser.add_argument('--network', dest='network_name',
                        help='name of the network',
                        default=None, type=str)
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--restore', dest='restore',
                        help='restore or not',
                        default=1, type=int)

    if len(sys.argv) == 1:
        parser.print_help()
        # sys.exit(1)

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()

    print('Called with args:')
    print(args)

    # Override the default config: first from the YAML file, then from --set KEY VALUE pairs.
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    if not args.randomize:
        # fix the numpy random seed for reproducibility
        np.random.seed(cfg.RNG_SEED)
    imdb = get_imdb(args.imdb_name)
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    # Get the roidb used for training (defined in train.py): it appends horizontally
    # flipped images (if cfg.TRAIN.USE_FLIPPED is set) and adds some descriptive
    # attributes to the original roidb.
    roidb = get_training_roidb(imdb)

    output_dir = get_output_dir(imdb, None)
    log_dir = get_log_dir(imdb)
    print('Output will be saved to `{:s}`'.format(output_dir))
    print('Logs will be saved to `{:s}`'.format(log_dir))

    # device_name is only printed here; this script does not use it to place ops.
    device_name = '/gpu:{:d}'.format(args.gpu_id)
    print(device_name)

    network = get_network(args.network_name)
    print('Use network `{:s}` in training'.format(args.network_name))

    train_net(network, imdb, roidb,
              output_dir=output_dir,
              log_dir=log_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters,
              restore=bool(int(args.restore)))
```
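To see what `parse_args` and the `--set` flag produce for the training command in step 6, here is a trimmed, self-contained sketch. `build_parser` keeps only some of the flags from train_net.py, and `apply_set_cfgs` is a hypothetical, simplified stand-in for `cfg_from_list`, modelled on the Fast R-CNN-style version (dotted keys, `ast.literal_eval`, a type check); check lib/fast_rcnn/config.py for the real behaviour. The toy config values are illustrative, not TFFRCNN's defaults.

```python
import argparse
from ast import literal_eval


def build_parser():
    # Trimmed copy of parse_args() in train_net.py: only the flags used by the
    # training command in step 6 (help strings dropped for brevity).
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id', default=0, type=int)
    parser.add_argument('--iters', dest='max_iters', default=70000, type=int)
    parser.add_argument('--weights', dest='pretrained_model', default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file', default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name', default='kitti_train', type=str)
    parser.add_argument('--network', dest='network_name', default=None, type=str)
    # REMAINDER: every token after --set is collected into one list.
    parser.add_argument('--set', dest='set_cfgs', default=None,
                        nargs=argparse.REMAINDER)
    return parser


def apply_set_cfgs(cfg, pairs):
    """Hypothetical, simplified stand-in for cfg_from_list on a nested dict."""
    assert len(pairs) % 2 == 0, 'expected KEY VALUE pairs'
    for key, raw in zip(pairs[0::2], pairs[1::2]):
        parts = key.split('.')          # 'TRAIN.SCALES' -> ['TRAIN', 'SCALES']
        d = cfg
        for sub in parts[:-1]:
            d = d[sub]                  # KeyError for an unknown section
        leaf = parts[-1]
        if leaf not in d:
            raise KeyError(key)
        try:
            value = literal_eval(raw)   # '(800,)' -> (800,), '0.001' -> 0.001
        except (ValueError, SyntaxError):
            value = raw                 # bare words such as exp_dir stay strings
        if type(value) is not type(d[leaf]):
            raise TypeError('type {} does not match original type {}'.format(
                type(value), type(d[leaf])))
        d[leaf] = value
    return cfg


if __name__ == '__main__':
    argv = ['--gpu', '0', '--weights', './data/pretrain_model/VGG_imagenet.npy',
            '--imdb', 'voc_2007_trainval', '--iters', '70000',
            '--cfg', './experiments/cfgs/faster_rcnn_end2end.yml',
            '--network', 'VGGnet_train', '--set', 'EXP_DIR', 'exp_dir']
    args = build_parser().parse_args(argv)
    print(args.set_cfgs)   # ['EXP_DIR', 'exp_dir']
    # Toy config for illustration only, not TFFRCNN's real defaults.
    cfg = {'EXP_DIR': 'default', 'TRAIN': {'SCALES': (600,)}}
    apply_set_cfgs(cfg, args.set_cfgs)
    print(cfg['EXP_DIR'])  # exp_dir
```

Because `--set` uses `argparse.REMAINDER`, it should come last on the command line; anything after it is treated as config pairs rather than flags.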

(1) The get_imdb and get_roidb functions:
http://blog.csdn.net/sloanqin/article/details/51537713
http://www.cnblogs.com/alanma/p/6802835.html
http://www.cnblogs.com/alanma/p/6803713.html
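The comment in train_net.py notes that `get_training_roidb` appends horizontally flipped images. Below is a simplified sketch of that flip, modelled on py-faster-rcnn's `imdb.append_flipped_images`; the function name `append_flipped` and the entry layout (an `(N, 4)` array of 0-based `[x1, y1, x2, y2]` boxes under `'boxes'`) are assumptions, so check lib/datasets/imdb.py in TFFRCNN for the actual code.

```python
import numpy as np


def append_flipped(roidb, widths):
    """Return roidb plus a horizontally flipped copy of every entry (sketch)."""
    flipped = []
    for entry, width in zip(roidb, widths):
        boxes = entry['boxes'].copy()
        old_x1 = boxes[:, 0].copy()
        old_x2 = boxes[:, 2].copy()
        # Mirror about the image's vertical centre line: x -> width - x - 1,
        # so the old right edge becomes the new left edge and vice versa.
        boxes[:, 0] = width - old_x2 - 1
        boxes[:, 2] = width - old_x1 - 1
        assert (boxes[:, 2] >= boxes[:, 0]).all()
        # Other fields (e.g. gt_classes) are shared with the original entry.
        flipped.append(dict(entry, boxes=boxes, flipped=True))
    return roidb + flipped


if __name__ == '__main__':
    roidb = [{'boxes': np.array([[10, 20, 30, 40]]),
              'gt_classes': np.array([12]), 'flipped': False}]
    out = append_flipped(roidb, widths=[100])
    print(len(out))          # 2
    print(out[1]['boxes'])   # [[69 20 89 40]]
```

Only the x coordinates change; y coordinates and class labels are reused, which is why flipping doubles the training set cheaply.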

(2) A walkthrough of the cfg (config.py) module:
http://www.cnblogs.com/alanma/p/6800944.html
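In the Fast R-CNN-style config.py, `cfg_from_file` loads the YAML file and recursively merges it into the default `cfg`, rejecting unknown keys and type mismatches. The sketch below reproduces that merge idea on plain dicts; `merge_into` is a hypothetical name, the YAML loading step is omitted, and the values are toy examples rather than TFFRCNN's real defaults.

```python
def merge_into(a, b, prefix=''):
    """Recursively merge dict `a` (e.g. parsed from YAML) into defaults `b`.

    Every key in `a` must already exist in `b` with the same type; nested
    dicts are merged key by key instead of replacing the whole section.
    """
    for key, value in a.items():
        name = prefix + key
        if key not in b:
            raise KeyError('{} is not a valid config key'.format(name))
        if type(value) is not type(b[key]):
            raise ValueError('Type mismatch ({} vs. {}) for config key: {}'.format(
                type(b[key]), type(value), name))
        if isinstance(value, dict):
            merge_into(value, b[key], name + '.')  # descend into the sub-section
        else:
            b[key] = value
    return b


if __name__ == '__main__':
    # Toy defaults and a toy "YAML" override, not TFFRCNN's real values.
    defaults = {'EXP_DIR': 'default', 'TRAIN': {'HAS_RPN': False, 'BATCH_SIZE': 128}}
    from_yaml = {'EXP_DIR': 'my_experiment', 'TRAIN': {'HAS_RPN': True}}
    merge_into(from_yaml, defaults)
    print('HAS_RPN={} BATCH_SIZE={}'.format(
        defaults['TRAIN']['HAS_RPN'], defaults['TRAIN']['BATCH_SIZE']))
    # HAS_RPN=True BATCH_SIZE=128
```

Merging per key rather than replacing whole sections is what lets a short YAML file such as faster_rcnn_end2end.yml override a few settings while keeping all other defaults.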


Reposted from blog.csdn.net/ssmixi/article/details/75807714