Faster R-CNN源码阅读之二:Faster R-CNN/lib/networks/factory.py

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/DaVinciL/article/details/81812472
  1. Faster R-CNN源码阅读之零:写在前面
  2. Faster R-CNN源码阅读之一:Faster R-CNN/lib/networks/network.py
  3. Faster R-CNN源码阅读之二:Faster R-CNN/lib/networks/factory.py
  4. Faster R-CNN源码阅读之三:Faster R-CNN/lib/networks/VGGnet_test.py
  5. Faster R-CNN源码阅读之四:Faster R-CNN/lib/rpn_msr/generate_anchors.py
  6. Faster R-CNN源码阅读之五:Faster R-CNN/lib/rpn_msr/proposal_layer_tf.py
  7. Faster R-CNN源码阅读之六:Faster R-CNN/lib/fast_rcnn/bbox_transform.py
  8. Faster R-CNN源码阅读之七:Faster R-CNN/lib/rpn_msr/anchor_target_layer_tf.py
  9. Faster R-CNN源码阅读之八:Faster R-CNN/lib/rpn_msr/proposal_target_layer_tf.py
  10. Faster R-CNN源码阅读之九:Faster R-CNN/tools/train_net.py
  11. Faster R-CNN源码阅读之十:Faster R-CNN/lib/fast_rcnn/train.py
  12. Faster R-CNN源码阅读之十一:Faster R-CNN预测demo代码补完
  13. Faster R-CNN源码阅读之十二:写在最后

一、介绍
   本demo由Faster R-CNN官方提供,我只是在官方的代码上增加了注释,一方面方便我自己学习,另一方面贴出来和大家一起交流。
   该文件中的函数的主要目的是根据所传入的参数选择特定的test网络结构或者train网络结构。

二、代码以及注释

# -*- coding:utf-8 -*-
# --------------------------------------------------------
# SubCNN_TF
# Copyright (c) 2016 CVGL Stanford
# Licensed under The MIT License [see LICENSE for details]
# Written by Yu Xiang
# --------------------------------------------------------

"""Factory method for easily getting imdbs by name."""

__sets = {}

import networks.VGGnet_train
import networks.VGGnet_test
import pdb
import tensorflow as tf

#__sets['VGGnet_train'] = networks.VGGnet_train()

#__sets['VGGnet_test'] = networks.VGGnet_test()


# 根据名称选择使用test的网络结构还是train的网络结构
def get_network(name):
    """Get a network by name.
    :param name: name的一般选择是'VGGnet_test'或者'VGGnet_train'
    :return: 所选择的网络结构
    """
    #if not __sets.has_key(name):
    #    raise KeyError('Unknown dataset: {}'.format(name))
    #return __sets[name]
    if name.split('_')[1] == 'test':
       return networks.VGGnet_test()
    elif name.split('_')[1] == 'train':
       return networks.VGGnet_train()
    else:
       raise KeyError('Unknown dataset: {}'.format(name))


def list_networks():
    """List all registered imdbs."""
    return __sets.keys()

猜你喜欢

转载自blog.csdn.net/DaVinciL/article/details/81812472