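Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/ssmixi/article/details/75807714
Faster R-CNN Source Code Reading and Analysis
Source code: https://github.com/CharlesShang/TFFRCNN
1) Running an already-trained model directly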
The demo performs detection using a VGG16 network trained for detection on PASCAL VOC 2007.
Run `python demo.py --model /models/VGGnet_fast_rcnn_iter_150000.ckpt` (the path after `--model` is where I stored the model I downloaded).
The test images are stored under /data/demo.
2) Training your own model
### Training on Pascal VOC 2007
1. Download the training, validation, test data and VOCdevkit
```Shell
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar
```
2. Extract all of these tars into one directory named `VOCdevkit`
```Shell
tar xvf VOCtrainval_06-Nov-2007.tar
tar xvf VOCtest_06-Nov-2007.tar
tar xvf VOCdevkit_08-Jun-2007.tar
```
3. It should have this basic structure
```Shell
$VOCdevkit/ # development kit
$VOCdevkit/VOCcode/ # VOC utility code
$VOCdevkit/VOC2007 # image sets, annotations, etc.
# ... and several other directories ...
```
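Before creating the symlink, it can help to confirm the extracted layout. The sketch below uses a hypothetical helper (`check_voc_layout` is not part of TFFRCNN), and the list of expected sub-directories is my assumption of the minimum the VOC2007 loader reads: the devkit code, the annotations, the image-set lists and the JPEG images.

```python
import os

# Assumed minimal set of sub-directories a usable VOC2007 devkit contains.
EXPECTED_SUBDIRS = [
    "VOCcode",
    os.path.join("VOC2007", "Annotations"),
    os.path.join("VOC2007", "ImageSets", "Main"),
    os.path.join("VOC2007", "JPEGImages"),
]


def check_voc_layout(devkit_root):
    """Return the expected sub-directories that are missing under devkit_root."""
    return [
        sub for sub in EXPECTED_SUBDIRS
        if not os.path.isdir(os.path.join(devkit_root, sub))
    ]
```

`check_voc_layout("/path/to/VOCdevkit")` returns an empty list when the layout looks complete; otherwise it names what is missing.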
4. Create symlinks for the PASCAL VOC dataset
```Shell
cd $TFFRCNN/data
ln -s $VOCdevkit VOCdevkit2007
```
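The link name matters: in the py-faster-rcnn code that TFFRCNN is derived from, the `pascal_voc` dataset class looks for its data under `data/VOCdevkit<year>/VOC<year>`, which is presumably why the link is called `VOCdevkit2007`. A Python equivalent of the two commands above might look like this; `link_voc_devkit` is an illustrative name, not a TFFRCNN function.

```python
import os


def link_voc_devkit(tffrcnn_root, vocdevkit, year="2007"):
    """Create <tffrcnn_root>/data/VOCdevkit<year> -> vocdevkit, like `ln -s`.

    Returns the VOC<year> directory as seen through the link, which is where
    the dataset loader is assumed to look for Annotations/ and JPEGImages/.
    """
    data_dir = os.path.join(tffrcnn_root, "data")
    os.makedirs(data_dir, exist_ok=True)
    link = os.path.join(data_dir, "VOCdevkit" + year)
    if not os.path.islink(link):
        # Absolute target, so the link does not depend on the working directory.
        os.symlink(os.path.abspath(vocdevkit), link)
    # An existing link is left untouched.
    return os.path.join(link, "VOC" + year)
```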
5. Download pre-trained model [VGG16](https://drive.google.com/open?id=0ByuDEGFYmWsbNVF5eExySUtMZmM) and put it in the path `./data/pretrain_model/VGG_imagenet.npy`
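The pretrained weights are a NumPy `.npy` file rather than a TensorFlow checkpoint. Assuming the file stores a pickled dict of the form `{layer_name: {'weights': ndarray, 'biases': ndarray}}` (the layout I expect from the caffe-tensorflow conversion, not verified against this particular file), a quick way to inspect it is the sketch below; `summarize_pretrained` is a hypothetical helper.

```python
import numpy as np


def summarize_pretrained(npy_path):
    """Map each layer name to {parameter name: array shape}.

    Assumes the .npy file holds a pickled dict
    {layer_name: {"weights": ndarray, "biases": ndarray}}.
    """
    # The dict is saved as a 0-d object array; .item() unwraps it.
    # allow_pickle=True is required by newer NumPy to load object arrays, and
    # encoding="latin1" helps when a Python 2 pickle is read under Python 3.
    data = np.load(npy_path, allow_pickle=True, encoding="latin1").item()
    return {
        layer: {name: tuple(np.shape(value)) for name, value in params.items()}
        for layer, params in data.items()
    }
```

Calling `summarize_pretrained("./data/pretrain_model/VGG_imagenet.npy")` should then list every layer with its weight and bias shapes, which is a cheap sanity check before starting a long training run.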
6. Run training scripts
```Shell
cd $TFFRCNN
python ./faster_rcnn/train_net.py --gpu 0 --weights ./data/pretrain_model/VGG_imagenet.npy --imdb voc_2007_trainval --iters 70000 --cfg ./experiments/cfgs/faster_rcnn_end2end.yml --network VGGnet_train --set EXP_DIR exp_dir
```
7. Run profiling
```Shell
cd $TFFRCNN
# install a visualization tool
sudo apt-get install graphviz
./experiments/profiling/run_profiling.sh
# generate an image ./experiments/profiling/profile.png
```
The training code in this part targets the VOC2007 dataset; to train on your own data, convert it to VOC format.
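Converting your own dataset mostly means putting the images in `JPEGImages/`, one XML annotation per image in `Annotations/`, and an image-ID list such as `ImageSets/Main/trainval.txt`. A minimal sketch of the XML part follows; `make_voc_annotation` is a hypothetical helper, and it writes only the fields I believe the VOC loader reads (`filename`, `size`, and per object `name`, `difficult`, `bndbox`). Coordinates are VOC-style 1-based integers (the py-faster-rcnn-style loader subtracts 1 when reading them). For new categories, the class list in the dataset code (assumed to be the `_classes` tuple in `lib/datasets/pascal_voc.py`) has to be edited as well.

```python
import xml.etree.ElementTree as ET


def make_voc_annotation(filename, width, height, objects, depth=3):
    """Build a PASCAL VOC-style annotation XML string for one image.

    objects: list of (class_name, xmin, ymin, xmax, ymax) tuples in VOC's
    1-based pixel coordinates.
    """
    root = ET.Element("annotation")
    ET.SubElement(root, "folder").text = "VOC2007"
    ET.SubElement(root, "filename").text = filename
    size = ET.SubElement(root, "size")
    for tag, value in (("width", width), ("height", height), ("depth", depth)):
        ET.SubElement(size, tag).text = str(value)
    for name, xmin, ymin, xmax, ymax in objects:
        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = name
        ET.SubElement(obj, "difficult").text = "0"  # mark every box as not difficult
        box = ET.SubElement(obj, "bndbox")
        for tag, value in (("xmin", xmin), ("ymin", ymin),
                           ("xmax", xmax), ("ymax", ymax)):
            ET.SubElement(box, tag).text = str(int(value))
    return ET.tostring(root, encoding="unicode")
```

Each returned string would be written to `Annotations/<image id>.xml`, and the same image ID appended as one line of the image-set list.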
3) Source code reading notes
(1) `train_net.py` in TFFRCNN (`faster_rcnn/train_net.py`)
```python
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network on a region of interest database."""

import argparse  # standard-library module for parsing command-line arguments and options
import pprint    # pretty-prints Python data structures
import numpy as np
import pdb       # interactive debugger, handy when the script is edited in a plain text editor
import sys
import os.path

this_dir = os.path.dirname(__file__)
sys.path.insert(0, this_dir + '/..')
# for p in sys.path: print(p)
# print(this_dir)

from lib.fast_rcnn.train import get_training_roidb, train_net
from lib.fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir, get_log_dir
from lib.datasets.factory import get_imdb
from lib.networks.factory import get_network


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id',
                        help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--solver', dest='solver',
                        help='solver prototxt',
                        default=None, type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=70000, type=int)
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='kitti_train', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')
    parser.add_argument('--network', dest='network_name',
                        help='name of the network',
                        default=None, type=str)
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--restore', dest='restore',
                        help='restore or not',
                        default=1, type=int)

    if len(sys.argv) == 1:
        parser.print_help()
        # sys.exit(1)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    if not args.randomize:
        # fix the random seeds (numpy and caffe) for reproducibility
        np.random.seed(cfg.RNG_SEED)

    imdb = get_imdb(args.imdb_name)
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    # Get the roidb used for training (defined in train.py): it appends horizontally
    # flipped images and adds some descriptive attributes to the original roidb.
    roidb = get_training_roidb(imdb)

    output_dir = get_output_dir(imdb, None)
    log_dir = get_log_dir(imdb)
    print('Output will be saved to `{:s}`'.format(output_dir))
    print('Logs will be saved to `{:s}`'.format(log_dir))

    device_name = '/gpu:{:d}'.format(args.gpu_id)
    print(device_name)

    network = get_network(args.network_name)
    print('Use network `{:s}` in training'.format(args.network_name))

    train_net(network, imdb, roidb,
              output_dir=output_dir,
              log_dir=log_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters,
              restore=bool(int(args.restore)))
```
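To see what `parse_args()` produces for the README's training command, the sketch below rebuilds part of the parser (only the options that command uses, plus `--restore`) and feeds it the command's arguments directly. Because `--set` uses `nargs=argparse.REMAINDER`, every token after it is gathered into `set_cfgs`, so `--set` should come last on the command line, as it does in the README. `build_parser` is a hypothetical name; the option definitions are copied from `train_net.py` with the help strings dropped.

```python
import argparse


def build_parser():
    """Trimmed copy of train_net.py's parser: only the options used below."""
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id', default=0, type=int)
    parser.add_argument('--iters', dest='max_iters', default=70000, type=int)
    parser.add_argument('--weights', dest='pretrained_model', default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file', default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name', default='kitti_train', type=str)
    parser.add_argument('--network', dest='network_name', default=None, type=str)
    parser.add_argument('--restore', dest='restore', default=1, type=int)
    # REMAINDER: every token after --set is gathered into one list.
    parser.add_argument('--set', dest='set_cfgs', default=None,
                        nargs=argparse.REMAINDER)
    return parser


# Arguments of the README training command (after "python ./faster_rcnn/train_net.py").
argv = ['--gpu', '0',
        '--weights', './data/pretrain_model/VGG_imagenet.npy',
        '--imdb', 'voc_2007_trainval', '--iters', '70000',
        '--cfg', './experiments/cfgs/faster_rcnn_end2end.yml',
        '--network', 'VGGnet_train', '--set', 'EXP_DIR', 'exp_dir']
args = build_parser().parse_args(argv)
print(args.imdb_name, args.max_iters, args.set_cfgs)
# voc_2007_trainval 70000 ['EXP_DIR', 'exp_dir']
print(bool(int(args.restore)))  # --restore not given, so the default 1 becomes True
```

The `['EXP_DIR', 'exp_dir']` list is exactly what `cfg_from_list(args.set_cfgs)` receives in the main block.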
(1) The get_imdb and get_roidb functions:
http://blog.csdn.net/sloanqin/article/details/51537713
http://www.cnblogs.com/alanma/p/6802835.html
http://www.cnblogs.com/alanma/p/6803713.html
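The comment on `get_training_roidb` in the code above says it flips images horizontally and adds descriptive fields. In the py-faster-rcnn code that TFFRCNN inherits, the flip is done by `imdb.append_flipped_images()`, which mirrors each box about the image's vertical center line, and the extra fields (among them `image`, `width`, `height`, `max_classes`, `max_overlaps`) come from `prepare_roidb`. The sketch below re-implements only the box arithmetic and the two overlap fields for a single entry, with two simplifications: `gt_overlaps` is a dense array here (the real roidb stores a `scipy.sparse` matrix), and `flip_entry` / `add_overlap_fields` are illustrative names.

```python
import numpy as np


def flip_entry(entry, width):
    """Return a horizontally flipped copy of one roidb entry.

    Boxes are [x1, y1, x2, y2] in 0-based pixel coordinates, so column x maps
    to width - x - 1 and the old x2 becomes the new x1 (and vice versa).
    """
    boxes = entry["boxes"].copy()
    old_x1 = boxes[:, 0].copy()
    old_x2 = boxes[:, 2].copy()
    boxes[:, 0] = width - old_x2 - 1
    boxes[:, 2] = width - old_x1 - 1
    # Fails if a box extended past the image edge (x2 >= width).
    assert (boxes[:, 2] >= boxes[:, 0]).all()
    # Shallow copy of the entry with the new boxes and the flipped flag set.
    return dict(entry, boxes=boxes, flipped=True)


def add_overlap_fields(entry):
    """Add max_overlaps / max_classes derived from a dense gt_overlaps array."""
    overlaps = entry["gt_overlaps"]  # shape: (num_boxes, num_classes)
    entry["max_overlaps"] = overlaps.max(axis=1)
    entry["max_classes"] = overlaps.argmax(axis=1)
    return entry
```

For example, with `width=100` a box spanning x = 10..49 becomes 50..89: the new x1 is 100 - 49 - 1 and the new x2 is 100 - 10 - 1.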
(2) Walkthrough of the cfg module (config.py):
http://www.cnblogs.com/alanma/p/6800944.html
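`--set EXP_DIR exp_dir` in the training command reaches the config through `cfg_from_list(args.set_cfgs)`. In the py-faster-rcnn `config.py` that TFFRCNN builds on, `cfg_from_list` walks alternating key/value pairs, follows dotted keys into nested sections, parses each value with `ast.literal_eval` (falling back to the raw string) and refuses to change a key's type. The sketch below reproduces that logic on plain dicts instead of the `EasyDict` used in the repository; `apply_cfg_overrides` and the toy config in the usage are hypothetical.

```python
from ast import literal_eval


def apply_cfg_overrides(cfg, cfg_list):
    """Override entries of a nested dict from a flat [key, value, ...] list.

    Dotted keys such as "TRAIN.BATCH_SIZE" walk into sub-dicts. Values are
    parsed with ast.literal_eval and fall back to the raw string.
    """
    assert len(cfg_list) % 2 == 0, "expected alternating keys and values"
    for key, raw in zip(cfg_list[0::2], cfg_list[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            assert name in node, "unknown config key: " + key
            node = node[name]
        assert leaf in node, "unknown config key: " + key
        try:
            value = literal_eval(raw)
        except (ValueError, SyntaxError):
            value = raw  # plain strings such as exp_dir are not Python literals
        # Reject overrides that would change the type of an existing key.
        assert type(value) is type(node[leaf]), "type mismatch for " + key
        node[leaf] = value
    return cfg


cfg = {"EXP_DIR": "default", "TRAIN": {"USE_FLIPPED": True, "BATCH_SIZE": 128}}
apply_cfg_overrides(cfg, ["EXP_DIR", "exp_dir", "TRAIN.BATCH_SIZE", "256"])
```

The type check means that, in this toy config, `TRAIN.BATCH_SIZE 256.0` would be rejected because the default is an integer, and a misspelled key fails immediately instead of silently adding a new entry.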