resnet_v1_50源码的理解与分析

源代码链接:
https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_utils.py
https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v1.py

1、TensorFlow中resnet_v1_50的用法

首先总结一下用法,源码中resnet_v1_50的参数如下:

def resnet_v1_50(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 spatial_squeeze=True,
                 store_non_strided_activations=False,
                 min_base_depth=8,
                 depth_multiplier=1,
                 reuse=None,
                 scope='resnet_v1_50'):

其中:

  • input:训练集,其格式为[batch, height_in, width_in, channels]
  • num_classes:样本的种类别数,用于定义出上层的节点个数。如果为“None”的话,其最终输出的应该是[batch,1,1,2048],若“spatial_stride=True”,则其最终输出为[batch,2048]
  • is_training:是否在训练模型中加入“Batch_Norm”层
  • global_pool:该层位于整个网络结构之后,位于“num_classes”之前。为“True”则表示对于网络最后一个“net”层的输出结果做一个全局的average pooling。所谓全局池化就是池化的stride等于输入的size,得到一个标量。
  • spatial_squeeze:将列表中维度等于1的维度去掉,如spatial_squeeze([B,1,1,C])=[B,C]
  • store_non_strided_activations:在多尺度图片处理中有用,可以将不同size的输出存储下来

简单来讲,就是我们将上述模块import以后,就构建好了 ResNet50  的网络结构,主要输入训练集“input”和类别数“num_classes”就可以了。如果“num_classes=None”则我们构建的是一个特征提取器的网络架构,只能提取图片的特征,可能是很高的维度,比如2048,;如果“num_classes=10”则表示,我们将“input”中的数据分为了10类,网络架构的最后一层输出的就是10维的一个向量,可以表示该图片属于某一类的概率。

2、resnet_v1_50源码构建的框架

“resnet_utils.py”和“resnet_v1.py”为源码中构建resnet_v1_50的两个模块,里面重要的函数已经用黑框圈出。Block定义了一个类,‘scope’是命名空间属性,‘unit_fn’是一个函数,用于处理网络架构中的unit块,args是其参数。“stack_blocks_dense”则是处理ResNet_Block块。"bottleneck"处理的的网络 架构中的bottleneck部分,包括论文中的“shortcut”部分。“resnet_v1”则是“ResNet50”的主架构。“resnet_v1_block”里将Block.unit_fn赋值为bottleneck。

我们接下来看看源码中,Block、bottleneck和unit表示的含义,如下图3:

图3 ResNet_Block, bottleneck和unit

一个“ResNet_Block”代表的是原论文Table 1中的conv2_x,不包括max pool层,为图3中蓝色虚线框内的部分;“bottleneck”表示的是绿色虚线框内包括黑色曲线的内容;而“unit”则表示的是红色虚线框内那个3*1的表格。对照Table1可以更容易理解。

源码中“resnet_v1”的内容如下:

def resnet_v1(inputs,
              blocks,
              num_classes=None,
              is_training=True,
              global_pool=True,
              output_stride=None,
              include_root_block=True,
              spatial_squeeze=True,
              store_non_strided_activations=False,
              reuse=None,
              scope=None):
  with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    with slim.arg_scope([slim.conv2d, bottleneck,
                         resnet_utils.stack_blocks_dense],
                        outputs_collections=end_points_collection):
      with (slim.arg_scope([slim.batch_norm], is_training=is_training)
            if is_training is not None else NoOpScope()):
        net = inputs
        if include_root_block:
          if output_stride is not None:
            if output_stride % 4 != 0:
              raise ValueError('The output_stride needs to be a multiple of 4.')
            output_stride /= 4
          net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
          net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
        net = resnet_utils.stack_blocks_dense(net, blocks, output_stride,
                                              store_non_strided_activations)
        # Convert end_points_collection into a dictionary of end_points.
        end_points = slim.utils.convert_collection_to_dict(
            end_points_collection)

        if global_pool:
          # Global average pooling.
          net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
          end_points['global_pool'] = net
        if num_classes:
          net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                            normalizer_fn=None, scope='logits')
          end_points[sc.name + '/logits'] = net
          if spatial_squeeze:
            net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
            end_points[sc.name + '/spatial_squeeze'] = net
          end_points['predictions'] = slim.softmax(net, scope='predictions')
        return net, end_points
resnet_v1.default_image_size = 224

源码中语句“ net = resnet_utils.stack_blocks_dense(net, blocks, output_stride, store_non_strided_activations) ”之前的内容处理的是ResNet50主体架构之间的内容,包括Table1中的“conv1”卷积层和“conv2_1”的池化层;而该语句处理的是网络的主体架构——"conv2_2"到"conv4_x"的所有部分,由参数"blocks"定义。该语句之后是“global_pool”、“num_classes”和“spatial_squeeze”,只有在“num_classes”不为“None”的时候,“spatial_squeeze”才会生效 。

Table2表示的函数的调用过程,以及部分参数在ResNet50调用过程中的实际值。白色部分的黑色字体表示的是函数的形参,而蓝色部分则表示参数的实际值。其中“resnet_v1_50()”和“resnet_v1()”的形式参数并没有全部列举出来,但是未列出的部分,并不影响对于源码的理解。

Table2 函数的调用过程

“resnet_v1_block()”是主体结构,定义了ResNet50的4个“Block”块,包括block块的名称和每个块里面包含的unit个数,以及对应的stride,这里面的depth或者说base_depth实际上就是网络架构中kernel的channel数。“resnet_v1_block()”在执行过程中又调用“resnet_utils.Block()”定义了每个block块中的unit_fn()函数,其等于bottleneck(),同时给其中的部分参数赋了值,在“conv2_x”中就是“[{256,64,1},{256,64,1},{256,64,2}]”每个花括号中的参数分别对应一个unit。Table2中最后一行实际上是倒数的第二行的函数的调用,其处理的每个“bottleneck()”里具体的东西。处理完后输出,继续处理下一个unit,最后返回到“resnet_v1()”,在处理下一个block块。一直到主体架构部分处理完成,也就是得到“ net = resnet_utils.stack_blocks_dense(net, blocks, output_stride, store_non_strided_activations) ”的返回值,最后处理后续部分,返回整个程序的输出。

猜你喜欢

转载自blog.csdn.net/Huang_Fj/article/details/100575180
今日推荐