Faster-RCNN Tensorflow版本源码解析(二)train_net.py所用到的函数

这里将要解析的是Faster-RCNN Tensorflow版本,fork自githubFaster-RCNN_TF

1. 背景交代

Faster-RCNN_TF中,网络的训练文件是 Faster-RCNN_TF/tools/train_net.py。我们已经在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中对该文件进行了源码解析,现在来解析一下该文件中用到的函数。

Faster-RCNN_TF/tools/train_net.py中用到的函数有以下几个:

def parse_args():解析输入参数。 这里的参数指的是,在运行train_net.py这个文件时,需要的输入参数。该函数的定义就在Faster-RCNN_TF/tools/train_net.py中。
get_imdb():加载训练数据。函数get_imdb在Faster-RCNN/lib/datasetes/factory.py中被定义。
get_training_roidb():将训练数据变成minibatch的形式。函数get_training_roidb在Faster-RCNN/lib/fast_rcnn/train.py中被定义。
get_output_dir():设置保存(训练好的模型)的目录。如果该目录没有,会自动新建一个。函数get_output_dir在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义。
get_network():按照args.network_name获取网络。选择train网络或者test网络。为什么参数args.network_name的值有固定的格式,看函数get_network就知道了。函数get_network在Faster-RCNN_TF/lib/networks/factory.py中被定义。
train_net():启动Faster-RCNN网络训练。函数train_net在Faster-RCNN_TF/lib/fast_rcnn/train.py中被定义

2. 下面来分析一下上面每个函数的源码

2.1. def parse_args():

该函数在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中已经进行了解析。

2.2. get_imdb():

作用:加载训练数据。
定义位置:该函数在Faster-RCNN/lib/datasetes/factory.py中被定义。
文件factory.py是个工厂类,用类生成imdb类并且返回数据库供网络训练和测试使用。
factory.py 源码如下:

# coding=utf-8 #有中文注释的时候,记得加上这个
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Factory method for easily getting imdbs by name."""

__sets = {}

import datasets.pascal_voc
import datasets.imagenet3d
import datasets.kitti
import datasets.kitti_tracking
import numpy as np

def _selective_search_IJCV_top_k(split, year, top_k):
    """Return an imdb that uses the top k proposals from the selective search
    IJCV code.
    """
    imdb = datasets.pascal_voc(split, year)
    imdb.roidb_handler = imdb.selective_search_IJCV_roidb
    imdb.config['top_k'] = top_k
    return imdb

# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007', '2012']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year:
                datasets.pascal_voc(split, year))
"""
# Set up voc_<year>_<split>_top_<k> using selective search "quality" mode
# but only returning the first k boxes
for top_k in np.arange(1000, 11000, 1000):
    for year in ['2007', '2012']:
        for split in ['train', 'val', 'trainval', 'test']:
            name = 'voc_{}_{}_top_{:d}'.format(year, split, top_k)
            __sets[name] = (lambda split=split, year=year, top_k=top_k:
                    _selective_search_IJCV_top_k(split, year, top_k))
"""

# Set up voc_<year>_<split> using selective search "fast" mode
'''
主要解析一下这部分,其他类似。
该部分用到的数据库是pascal_voc 2007数据库,
该数据库由几个部分组成,名称name分别是voc_2007_train、voc_2007_val、voc_2007_trainval、voc_2007_test,
看你的任务是训练还是测试,选择相对应的数据库名称。
这个数据库名称对应的就是(网络训练文件Faster-RCNN_TF/tools/train_net.py)中的参数--imdb的值,
'''
for year in ['2007']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        print name
        __sets[name] = (lambda split=split, year=year:
                datasets.pascal_voc(split, year)) #这是一个lambda函数。所用的函数是datasets.pascal_voc。
        #pascal_voc是一个类,在Faster-RCNN_TF/lib/datasets/pascal_voc.py中被定义
        #(文件pascal_voc.py就是数据库voc_2007_train的数据读写接口)。
        #datasets.pascal_voc的作用就是加载voc_2007_train数据库

    #lambda函数也叫匿名函数,即,函数没有具体的名称,而用def创建的方法是有名称的。
    #lambda允许用户快速定义单行函数,当然用户也可以按照典型的函数定义完成函数。
    #lambda的目的就是简化用户定义使用函数的过程。

# KITTI dataset
for split in ['train', 'val', 'trainval', 'test']:
    name = 'kitti_{}'.format(split)
    print name
    __sets[name] = (lambda split=split:
            datasets.kitti(split))

# Set up coco_2014_<split>
for year in ['2014']:
    for split in ['train', 'val', 'minival', 'valminusminival']:
        name = 'coco_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: coco(split, year))

# Set up coco_2015_<split>
for year in ['2015']:
    for split in ['test', 'test-dev']:
        name = 'coco_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: coco(split, year))

# NTHU dataset
for split in ['71', '370']:
    name = 'nthu_{}'.format(split)
    print name
    __sets[name] = (lambda split=split:
            datasets.nthu(split))


def get_imdb(name):  #加载训练数据。
    """Get an imdb (image database) by name."""
    '''
    在Faster-RCNN_TF/tools/train_net.py中被用到。
    传进来的形参name的值就是train_net.py中的args.imdb_name,
    也就是train_net.py中的参数--imdb的值
    参数--imdb的值,代表的是训练数据库的名字
    '''
    if not __sets.has_key(name): #如果没有该训练数据库的名字
        raise KeyError('Unknown dataset: {}'.format(name)) #报错
    return __sets[name]() #如果有该训练数据库的名字,执行__sets[name](),该函数是在本文件中(在上面)定义的

def list_imdbs():
    """List all registered imdbs."""
    return __sets.keys()

2.3. get_training_roidb():

作用:将训练数据变成minibatch的形式。
定义位置:该函数在Faster-RCNN/lib/fast_rcnn/train.py中被定义。

def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training."""
    if cfg.TRAIN.USE_FLIPPED:
        print 'Appending horizontally-flipped training examples...'
        imdb.append_flipped_images()
        print 'done'

    print 'Preparing training data...'
    if cfg.TRAIN.HAS_RPN:#如果使用RPN(参数cfg.TRAIN.HAS_RPN在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义)
        if cfg.IS_MULTISCALE:
            gdl_roidb.prepare_roidb(imdb)
        else:
            rdl_roidb.prepare_roidb(imdb) #rdl_roidb.prepare_roidb()在Faster-RCNN_TF/lib/roi_data_layer/roidb.py中
    else:
        rdl_roidb.prepare_roidb(imdb)
    print 'done'

    return imdb.roidb

2.4. get_output_dir():

作用:设置保存(训练好的模型)的目录。如果该目录没有,会自动新建一个。
定义位置:函数get_output_dir()在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义。

猜你喜欢

转载自blog.csdn.net/lanyuelvyun/article/details/78194722