这里将要解析的是Faster-RCNN Tensorflow版本,fork自githubFaster-RCNN_TF。
1. 背景交代
Faster-RCNN_TF中,网络的训练文件是 Faster-RCNN_TF/tools/train_net.py。我们已经在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中对该文件进行了源码解析,现在来解析一下该文件中用到的函数。
Faster-RCNN_TF/tools/train_net.py中用到的函数有以下几个:
def parse_args():解析输入参数。 这里的参数指的是,在运行train_net.py这个文件时,需要的输入参数。该函数的定义就在Faster-RCNN_TF/tools/train_net.py中。
get_imdb():加载训练数据。函数get_imdb在Faster-RCNN/lib/datasetes/factory.py中被定义。
get_training_roidb():将训练数据变成minibatch的形式。函数get_training_roidb在Faster-RCNN/lib/fast_rcnn/train.py中被定义。
get_output_dir():设置保存(训练好的模型)的目录。如果该目录没有,会自动新建一个。函数get_output_dir在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义。
get_network():按照args.network_name获取网络。选择train网络或者test网络。为什么参数args.network_name的值有固定的格式,看函数get_network就知道了。函数get_network在Faster-RCNN_TF/lib/networks/factory.py中被定义。
train_net():启动Faster-RCNN网络训练。函数train_net在Faster-RCNN_TF/lib/fast_rcnn/train.py中被定义
2. 下面来分析一下上面每个函数的源码
2.1. def parse_args():
该函数在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中已经进行了解析。
2.2. get_imdb():
作用:加载训练数据。
定义位置:该函数在Faster-RCNN/lib/datasetes/factory.py中被定义。
文件factory.py是个工厂类,用类生成imdb类并且返回数据库供网络训练和测试使用。
factory.py 源码如下:
# coding=utf-8 #有中文注释的时候,记得加上这个
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""Factory method for easily getting imdbs by name."""
__sets = {}
import datasets.pascal_voc
import datasets.imagenet3d
import datasets.kitti
import datasets.kitti_tracking
import numpy as np
def _selective_search_IJCV_top_k(split, year, top_k):
"""Return an imdb that uses the top k proposals from the selective search
IJCV code.
"""
imdb = datasets.pascal_voc(split, year)
imdb.roidb_handler = imdb.selective_search_IJCV_roidb
imdb.config['top_k'] = top_k
return imdb
# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007', '2012']:
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year:
datasets.pascal_voc(split, year))
"""
# Set up voc_<year>_<split>_top_<k> using selective search "quality" mode
# but only returning the first k boxes
for top_k in np.arange(1000, 11000, 1000):
for year in ['2007', '2012']:
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}_top_{:d}'.format(year, split, top_k)
__sets[name] = (lambda split=split, year=year, top_k=top_k:
_selective_search_IJCV_top_k(split, year, top_k))
"""
# Set up voc_<year>_<split> using selective search "fast" mode
'''
主要解析一下这部分,其他类似。
该部分用到的数据库是pascal_voc 2007数据库,
该数据库由几个部分组成,名称name分别是voc_2007_train、voc_2007_val、voc_2007_trainval、voc_2007_test,
看你的任务是训练还是测试,选择相对应的数据库名称。
这个数据库名称对应的就是(网络训练文件Faster-RCNN_TF/tools/train_net.py)中的参数--imdb的值,
'''
for year in ['2007']:
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}'.format(year, split)
print name
__sets[name] = (lambda split=split, year=year:
datasets.pascal_voc(split, year)) #这是一个lambda函数。所用的函数是datasets.pascal_voc。
#pascal_voc是一个类,在Faster-RCNN_TF/lib/datasets/pascal_voc.py中被定义
#(文件pascal_voc.py就是数据库voc_2007_train的数据读写接口)。
#datasets.pascal_voc的作用就是加载voc_2007_train数据库
#lambda函数也叫匿名函数,即,函数没有具体的名称,而用def创建的方法是有名称的。
#lambda允许用户快速定义单行函数,当然用户也可以按照典型的函数定义完成函数。
#lambda的目的就是简化用户定义使用函数的过程。
# KITTI dataset
for split in ['train', 'val', 'trainval', 'test']:
name = 'kitti_{}'.format(split)
print name
__sets[name] = (lambda split=split:
datasets.kitti(split))
# Set up coco_2014_<split>
for year in ['2014']:
for split in ['train', 'val', 'minival', 'valminusminival']:
name = 'coco_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: coco(split, year))
# Set up coco_2015_<split>
for year in ['2015']:
for split in ['test', 'test-dev']:
name = 'coco_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: coco(split, year))
# NTHU dataset
for split in ['71', '370']:
name = 'nthu_{}'.format(split)
print name
__sets[name] = (lambda split=split:
datasets.nthu(split))
def get_imdb(name): #加载训练数据。
"""Get an imdb (image database) by name."""
'''
在Faster-RCNN_TF/tools/train_net.py中被用到。
传进来的形参name的值就是train_net.py中的args.imdb_name,
也就是train_net.py中的参数--imdb的值
参数--imdb的值,代表的是训练数据库的名字
'''
if not __sets.has_key(name): #如果没有该训练数据库的名字
raise KeyError('Unknown dataset: {}'.format(name)) #报错
return __sets[name]() #如果有该训练数据库的名字,执行__sets[name](),该函数是在本文件中(在上面)定义的
def list_imdbs():
"""List all registered imdbs."""
return __sets.keys()
2.3. get_training_roidb():
作用:将训练数据变成minibatch的形式。
定义位置:该函数在Faster-RCNN/lib/fast_rcnn/train.py中被定义。
def get_training_roidb(imdb):
"""Returns a roidb (Region of Interest database) for use in training."""
if cfg.TRAIN.USE_FLIPPED:
print 'Appending horizontally-flipped training examples...'
imdb.append_flipped_images()
print 'done'
print 'Preparing training data...'
if cfg.TRAIN.HAS_RPN:#如果使用RPN(参数cfg.TRAIN.HAS_RPN在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义)
if cfg.IS_MULTISCALE:
gdl_roidb.prepare_roidb(imdb)
else:
rdl_roidb.prepare_roidb(imdb) #rdl_roidb.prepare_roidb()在Faster-RCNN_TF/lib/roi_data_layer/roidb.py中
else:
rdl_roidb.prepare_roidb(imdb)
print 'done'
return imdb.roidb
2.4. get_output_dir():
作用:设置保存(训练好的模型)的目录。如果该目录没有,会自动新建一个。
定义位置:函数get_output_dir()在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义。