ResNet文章复现

ResNet:Deep Residual Learning for Image Recognition

摘要:越深的神经网络越难训练。本文提出了一个残差学习框架(residual learning framework)去减轻训练深度网络(比前人的网络深的多)的难度。我们重新定义层的学习过程为一个与输入有关的残差函数的学习,而不是一个与输入无关的函数。我们进行了全面的(comprehensive)实验,从而证明了这种残差网络是更容易优化的,并可以获得深度带来的准确率提高(虽然深度越大,准确率越高,但实验中,普通网络并不是这样的)。在ImageNet数据集上,作者评估了残差网络(最深到152层,虽然深度是VGG的8倍,但复杂性仍然很低)。集成多个残差网络将在ImageNet的test集上的错误率降低到了3.57%。这个结果赢得了ILSVRC 2015分类比赛桂冠。作者也在CIFAR数据集上进行了100和1000层的实验。

representation的深度对于很多视觉识别任务性能是至关重要的。仅通过极深的表示,我们就在COCO物体探测任务中取得了28%的性能提高。深度残差网络是我们在ILSVRC和COCO 2015比赛中提交的模型的基础,并取得了ImageNet探测、ImageNet定位、COCO探测、COCO分割比赛的桂冠。

1. 简介

深度网络自然地将低、中、高级别特征和分类器以端到端多层方式集成,并且特征的级别能够通过堆叠层数(深度)来丰富。最近的研究表明网络深度是至关重要的。

增加深度很流行,但是问题来了:堆叠更多的层是否能训练出来更好的网络?回答这个问题的一个阻碍是梯度弥散和爆炸问题,这从训练开始便阻碍收敛。梯度问题可以用
normalized initialization23,8,36,12BN来解决,这将使得10层深的网络可以用SGD(梯度反向传播算法)来训练。

更深的网络已经能够使用梯度反向传播训练(能够收敛),但另一个问题来了:随着网络深度的增加,准确率陷入饱和(可能不意外),然后快速降低。意外的是,这样的下降不是由过拟合引起的,并且给一个深度合适的网络增加更多的层导致更高的训练误差[图1][10]#10,41
图1
训练准确率的退化表明不是所有的系统都很容易去优化。考虑一个浅网络和为其添加更多的层产生的更深网络。有一种构建更深网络的方案:增加的层都是identity mapping,同时其它层和浅层完全一样。这种方案的存在表明,更深的模型不应该产生比浅网络更高的训练误差。但是实验表明,现有解决方案中无法找到比上面提出的解决方案相当或更好的解决方案了(或在有限时间内无法找到)。

在本文中,我们通过引入一个深度残差框架解决了上面提到的退化问题。我们直接让这些层拟合一个残差映射(residual mapping),而不是希望每几个堆叠的层直接拟合一个期望的基础映射(underlying mapping)。形式上,让期望的基础映射为 H ( x ) ,我们让堆叠的非线性层拟合另一个映射: F ( x ) := H ( x ) x 。这个原始的映射可以重写为 F ( x ) + x 。我们假设残差映射比原始无参考映射更容易优化。极端情况下,如果一个恒等映射(identity mapping)是最优的,那么将残差置为零比通过非线性层的叠加去拟合一个恒等映射更容易。

公式 F ( x ) + x 能够通过包含“shortcut connection”的前馈网络实现。(图2)shortcut connection实现了identity mapping,并且它们的输出被添加到堆叠层的输出。identity shortcut connection既没有增加额外的参数,也没有增加计算量。整个网络依然能够使用SGD进行端到端训练,能够能很容易地使用Caffe实现。
图2
作者在ImageNet上系统的实验去说明退化问题并且评估了作者的方法。我们说明了:1.我们的极深残差网络很容易去优化,但是plain网络(只是简单的堆叠)当深度增加时却呈现出更高的训练误差。2.我们的深度残差网络能够从网络深度的增加中获取到可观的准确率提升,从而产生了比以往网络更高的准确率。

CIFAR-10数据集上也呈现出了类似的现象,这表明优化困难这个问题和我们方法效果具有普遍性。在CIFAR-10上,我们成功训练出了超100层的模型,并且对超1000层的模型进行了探索。

在ImageNet分类数据集上,我们用极深残差网络获得了最好的结果。我们的152层残差网络是ImageNet上到目前为止最深的网络,同时有比VGG低的复杂性。残差网络的集成版本在ImageNet上取得了top-5: 3.57%的准确率,并且赢得了ILSVRC 2015的分类比赛。极深的表示同时在其它识别任务上有很好的泛化性能,并且帮助我们进一步赢得了ImageNet探测、ImageNet定位、COCO探测、COCO分割赛的第一。这强有力地说明了残差学习原则是通用的,并且我们期待它在其它视觉和非视觉问题上的应用。

2.相关工作

Residual Representation。在图像识别中,[VLAD][18]是一种通过字典的残差向量进行编码的表示形式,Fisher向量可以看做VLAD的一个概率版本。VLAD和Fisher都是图像检索和分类任务中强大的浅层表示。从向量量化来看,编码后的残差向量被证明比原始向量更有效。【这段话啥意思】

在低层视觉和计算机图形学中,为了解决偏微分方程(Partial Differential Equations (PDEs)),广泛使用的多尺度方法(Multigrid method)将这个系统重构为多个尺度上的子问题,这里每一个子问题负责粗和细粒度之间的残差解。Multigrid的一个替代方法是层次化的基础预处理,它依赖于两个尺度之间的残差向量的变量。已经证明这些求解器比不知道解的残差性质的标准求解器收敛的更快。这些方法表明,一个好的重构或者预处理可以简化优化。

Shortcut Connections。Shortcut connections的实践和理论已经被研究可很长时间。早期shortcut connection在MLPs中的实践都是从网络输入和输出之间添加一个线性层。在[43, 24]中,一些内部层被直接连接到辅助分类器for解决梯度弥散和爆炸问题。文章[38,37,31,46]提出了通过shortcut connections实现层间信息、梯度、反向传播误差的交流。在[43]中,一个Inception层由一个shortcut分支和一个不深的分支组成。

当前和我们类似的工作,highway network用门(gate)函数实现了shortcut connection。这些门是与数据相关的并且是有参数的。当一个gated shortcut是关闭的(接近0),highway网络中的门连接层表示一个non-residual函数。相反,我们的公式总是学习残差函数,我们的identity shortcut是从不关闭的,并且所有的信息总是通过,还有额外的残差函数要学习。此外,highway网络没有说明极度增加的深度带来的准确度的提高。

3.深度残差学习(Deep Residual Learning)

3.1残差学习(Residual Learning)

让我们把 H ( x ) 当做用很多堆叠层(不必是整个网络)拟合的基础映射,这里 x 是整个堆叠层的输入。如果我们假设多个非线性层能够渐进(近似)一个复杂函数,然后它等价于假设它们能够渐进地近似残差函数,例如 H ( x ) x (假设输入和输出的维度相同)。所以,我们显式地让这些层近似一个残差函数 F ( x ) := H ( x ) x ,而不是期待堆叠的层去近似 H ( x ) 。原始的函数因此变为 F ( x ) + x 。尽管两个公式都应该能够去渐进地近似希望的函数(如假设一样),但学习的难度可能是不同的。

退化问题的反直觉现象(图1左)激发了这个重构。正如我们在第1章的讨论,如果增加的层能够用identity mapping构建,一个更深的模型应该有更低的训练误差。退化问题说明求解器可能用多个非线性层近似identity mapping。有了残差学习的重构,如果identity mapping是最优的,那么求解器可能简单地多个线性层的权重推向0从而去近似identity mappings。

在实践中,identity不太可能是最优的,但是我们的重构可能有助于去预处理这个问题。如果与zero mapping相比,identity mapping更接近于最优的函数,则求解器应该更容易找到关于恒等映射的扰动,而不是将该函数作为新函数来学习。我们通过实验(图7)表明学习到的残差函数一般有小的响应(responses),说明identity mapping提供了合理的预处理。
图7

3.2 Shortcut实现的恒等映射(Identity Mapping by Shortcut)

我们每隔几个堆叠层应用残差学习。图2中展示了一个构件块。本文中我们考虑一个构建块的的数学公式如下:

y = F ( x , W i ) + x                      (1)

这里 x y 是考虑的层的输入和输出向量。函数 F ( x , W i ) 表示将要学习的残差映射。例如图2中的两层, F = W 2 σ ( W 1 x ) ,其中, σ 表示ReLU(这里将偏差省略掉了)。操作 F + x 通过有一个快捷链接(shortcut connection)和元素级别的加法来实现。我们在加法后,我们也使用了激活函数(可以从图2中看到)。

公式(1)中的shortcut connection既没有引入额外的参数也没有引入额外的计算量。这不仅在实践中很有吸引力,而且在plain和残差网络的对比中很重要。我们可以在相同参数量、深度、宽度和计算量(不包含图2中最后的加法)的情况下,公平地比较plain和残差网络。

在公式(1)中, x F 的维度必须相等。如果两者维度不同(例如,输入输出通道数不一致时),我们可以通过一个线性投射 W s shortcut connection来匹配维度:

y = F ( x , { W i } ) + W s x                             ( 2 )

我们也可以使用公式(1)中的方形矩阵 W s 。但是我们将通过实验表明identity mapping足以解决退化问题,并且它是合算的;因此,只在维度匹配时使用 W s

残差函数 F 是灵活的。本文的实验中提到了一个有两或三层的残差函数,当然更多层数的情况也是可行的。但是如果 F 仅仅只有一层,公式(1)和线性层是相似的: y = W 1 x + x ,单层F在试验中,并没有观察到优势。

我们也注意到:尽管上面的辩证是关于FC层的(为了简单),但它也适用于conv层。函数 F ( x , { W i } ) 能够表示多个卷积层,元素加法在两个feature maps上逐通道进行。

3.3 网络架构

我们已经在测试了不同的plain / residual网络,并且观察到了一致的现象。为了提供讨论的实例,我们描述了ImageNet比赛中用的两个模型。
图3
Plain Network:我们的plain baselines(图3中间的模型)主要受VGG网络(图3左)的启发。卷积层主要有3x3滤波器,并遵守两个简单的设计原则:1. 输入输出有相同的feature map size,层有相同数量的滤波器(通道数一致)。2. 如果feature map size减半,则滤波器的数量增加一倍以保持每层的时间复杂度。我们通过stride 为2的卷积层来实现下采样。网络以全局平均池化层+激活函数为softmax的1000为FC层结束。网络中重要层的数量为34(图3中间)。

值得注意的是:我们的模型和VGG网络相比,有更少的滤波器和更低的复杂性。我们的34层baseline有36亿次FLOPs(multiply-adds),计算量仅仅是VGG-19(196亿次FLOPs)的18%。

Residual Network。基于上面的plain网络,我们插入shortcut connections,从而得到了Residual网络(图3右)。当输入和输出维度相同时,公式(1)中的shortcuts能够直接使用。当维度增加时(图3中的虚线shortcut),我们考虑两种选项:A. shortcut仍然执行恒等映射,用零来填充增加的维度。这个选项不会引入额外的参数。B. 使用公式(2)中的投射shortcut来匹配维度(done by 1x1 conv)。对于两个选项,当shortcut穿过两种尺寸的feature maps时,卷积都选用stride 2。

3.4 实现

ImageNet中我们的实现遵循了[21,41]两文的实践。The image is resized with its shorter side randomly sampled in [256, 480] for scale augmentation [40]。A 224×224 crop is randomly sampled from an image or its horizontal flip, with the per-pixel mean subtracted [21]. The standard color augmentation in [21] is used. We adopt batch normalization (BN) [16] right after each convolution and before activation, following [16]. We initialize the weights as in [12] and train all plain/residual nets from scratch. We use SGD with a mini-batch size of 256. The learning rate starts from 0.1 and is divided by 10 when the error plateaus, and the models are trained for up to 60×104 iterations. We use a weight decay of 0.0001 and a momentum of 0.9. We do not use dropout [13], following the practice in [16].

In testing, for comparison studies we adopt the standard 10-crop testing [21]. For best results, we adopt the fully-convolutional form as in [40, 12], and average the scores at multiple scales (images are resized such that the shorter side is in {224, 256, 384, 480, 640}).

4. 实验(Experiments)

4.1 ImageNet分类(ImageNet Classification)

作者在ImageNet分类数据集(1000类)上评估了residual模型。模型在280万张训练图片上进行了训练,在5万张验证图片上进行评估。在10万张测试图片上获得最终的结果(reported by the test server)。我们评估了top-1和top-5错误率。

Plain Network
作者首先评估了18层和34层plain网络。34层网络采用的是图3中间的模型。18层是其的简化版本。表1详细阐述了架构细节。
表1
表2
表2表明34层plain网络比18层plain网络有更高的验证错误。为了揭示原因,在图4左,我们比较了训练过程中的训练和验证误差。我们注意到退化问题(34层网络具有比18层网络更高的训练误差)贯穿整个训练过程,即使18层网络是34层网络的一个子空间。
图4
作者认为这里的优化困难不太可能是由梯度弥散引起的。这里的plain网络采用了BN,它能确保正向传播过程中的信号有非0的方差,同时确保了反向传播过程中梯度的有效传递。事实上,34层深是能够取得相当高的准确率的(表3),这表明在某种程度上来说求解器仍工作。我们推测深度plain网络可能有exponentially low收敛速度,这影响了训练误差的降低。这种优化困难的原因未来会进一步研究。

Residual Network。接下来,我们评估了18层和34层残差网络(ResNet)。baseline架构是和上述的plain网络一样,并每两层增加一个shortcut连接。在第一次对比中,我们在维度发生变化时,采用了用0填充增加的维度的方法,所以两者的参数量完全一样。对比结果见表2和图4。

从表2和图4中,我们有三个主要发现:1. 通过残差学习,34层(ResNet-34)超过了18层(ResNet-18)的准确率(超过了2.8%)。更重要的是,ResNet-34在训练集上的错误率很低的同时,在验证集上有很好的泛化性。这表明退化问题被很好的解决了。

2,相比于plain网络,ResNet网络将top-1准确率减少了将近3.5%(表2)。这验证了残差机制在非常深的网络上的效果。

3我们同时注意到:18层plain / residual网络的准确率是差不多的,但是ResNet-18收敛的更快。当网络不是过于深时,当前的SGD求解器仍然能够很好的优化plain网络。在这种情况下,残差机制可以加快网络早期的收敛速度。

Identity v s . Projection Shortcut。我们已经证明无参数,identity shortcut有助于训练。接下来我们研究projection shortcut(公式2)。在表3中我们比较了三个选项:(A) 使用zero-padding shortcut来增加维度,所有的shortcuts是没有参数的;(B)用Projection shortcut来增加维度,其它的shortcut是identity;(C)所有的shortcut都是projections。
表3
表3说明三个选项都比对应的plain网络好很多。B比A略好。我们认为这是因为A中的零填充确实没有残差学习。C比B稍好,我们把这归因于projection shortcut引入了额外参数。但A/B/C之间的细微差异表明,projection shortcut对于解决退化问题不是至关重要的。所以我们在本文的剩余部分不再使用C,以减少内存/时间复杂性和模型大小。identity shortcut对于下面介绍的瓶颈结构尤为重要。

Deeper Bottleneck Architecture。接下来我们描述几个更深的网络(for ImageNet)。由于对训练时间的关注(我们要能承受),我们将构建块修改为瓶颈设计。对于每个残差函数F,我们使用堆叠3层而不是2层(图5)。三层分别是1×1,3×3和1×1卷积,其中的两个1×1卷积层分别负责降低维度和增加(恢复)维度,从而在3×3卷积层这里产生一个瓶颈。图5是具体的架构图,图中两个设计有相似的时间复杂性。
图5
无参数identity shortcut对于瓶颈架构尤为重要。如果图5(右)中的恒等快捷连接被投影替换,我们将看到时间复杂度和模型大小加倍,因为identity shortcut是连接到两个高维端的。因此,恒等快捷连接可以为瓶颈设计得到更高效的模型。

50-layer ResNet。作者将2层块的模块更新为图5右的3层瓶颈块,从而得到一个50层深的网络。作者使用B来增加维度。该模型有38亿FLOPs。

101-layer and 152-layer ResNet。作者使用更多的3层瓶颈块来构建101层和152层网络(细节见表1)。值的注意的是,ResNet-152深度增加了许多,但其复杂性(113亿FLOPs)仍比VGG-16/19(153/196亿FLOPs)低。

ResNet-50/101/152比ResNet-34的准确率要高得多(表3和4)。我们没有观察到退化问题,因此可以从显著增加的深度中获得显著的准确性提高。所有评估指标都能证明深度的提高(表3和表4)。
表4
表5

4.2 CIFAR-10上的实现和分析(CIFAR-10 and Analysis)

我们在CIFAR-10(5万张训练图片和1万张测试图片)数据集上进行了更多的研究。实验中在训练集上训练,测试集上评估。我们的关注点是极深网络的行为,而不是去获得更好的结果,所以我们故意(intentionally)使用简单的架构(如下)。

Plain和Residual架构使用(follow)图3的形式(中、右)。网络输入是32x32图像(per-pixel mean subtracted)。第一层是3x3卷积。然后我们使用 一个堆叠起来的6n层的3x3卷积,feature maps的size分别为32,16,8,size每减半,channel就变为2倍,channel分别为16,32,64。下采样过程通过stride为2的卷积来实现。网络最后接一个全局平均池化层,一个10-way的FC层和softmax。总共有6n+2个有权重的层。下表总结了架构
表6
当使用shortcut连接时,每两层有一个shortcut(connected to the pairs of 3x3 layers)总共有3n个shortcut。在这个数据集,我们在所有的模型中使用identity shortcut(选项A)。所以我们的残差模型有着和plain模型相同深度、宽度和参数量。
模型中使用了weight decay(0.001),momentum(0.9),【13】权重初始化方法和BN,但没有使用Dropout。模型训练过程在两块GPU上以128的batch_size完成。开始的学习速率为0.1,在第32k和48k次迭代,学习速率除以10,迭代次数达到64k时,终止了训练(训练集合划分为45k的训练集,5k的验证集)。我们使用了【24】中类似的数据增强:图像每边pad 4个像素,然后随机裁剪出一个32x32的图像,并且对图像进行随机翻转。测试过程中,只使用原始的32x32图像。

我们比较了n={3,5,7,9}(分别对应20,32,44,56层)时的网络。图6左说明了plain网络的行为,深度plain网络的性能随深度的增加而降低,并且当深度增加时,训练误差更高。在ImageNet(图4左)和MNIST(42)上也有类似的现象。表明这个优化困难是一个基础性的问题。
图6
图6中表明了ResNet的行为。ImageNet上行为也类似(图4右)。我们的ResNet成功克服了优化困难并且能从深度的增加得到性能提升。

我们进一步探索了n=18时(110层)的ResNet。在这个模型中,我们发现0.1的初始学习速率稍微有点过于大以至于不收敛。所以我们使用0.01的初始学习速率,直到训练误差低于80%(大概400次迭代),然后学习速率恢复为0.1并且继续训练。剩下的学习过程和前面一样。这个100层网络收敛的很好(图6中)。它有着比其它deep和thin网络(FitNet和Highway)更少的参数,同时性能基本接近state of art。

Analysis of Layer Responses:图7表明了层相应的标准差(standard deviations (std))的变化。响应是3x3卷积的输出(BN之后,激活之前)。对于ResNet,这个分析揭示了残差函数的相应特点。图7表明ResNet一般有着比plain网络更小的响应。这些结果支持了我们的基础动机(3.1节):残差函数可能与非残差函数相比一般更接近于0。我们也注意到更深的ResNet有着更小的响应幅度(通过图7中ResNet-20,56,110的对比可以证明)。当网络加深后,ResNet中的每一层趋向于去更少地修改信号(modify signal less)。
图7
Exploring Over 1000 layers:我们探索一个非常非常深的模型(1000层)。我们设置n=200从而产生一个1202层网络,训练方法采用ResNet-110类似的训练方法。我们的方法表明没有优化困难并且1000层网络能够去取得小于0.1%的训练误差(图6右)。它的测试误差仍然是相当好的(7.93%,表6)。

但是在这样非常非常深的模型中,仍然有一个显而易见的问题。1202层网络的测试结果比110层网络要差,尽管有着相同的训练误差。我们认为这是因为过拟合。1202层网络可能是不必要的(太大了,19.4M)对于CIFAR-10。例如Maxout或dropout的强正则被应用去取得最好的结果(在CIFAR-10上)。在本文中,我们没有使用maxout和dropout并且仅仅采用了深窄这样的网络进行正则。但是结合强正则可能提高结果(未来要研究)。

4.3 在PASCAL和MS COCO数据集上的物体探测

我们的方法在其他识别任务上有很好的泛化性能。表7和表8展示了PASCAL 2007、2012和COCO上的物体探测基线。我们采用了Faster R-CNN作为探测方法。然后用ResNet替换其中的部件,从而提高性能。
表78

TensorFlow实现ResNet-34:

#coding:utf-8
'''
ResNet

网络配置来自文章Deep Residual Learning for Image Recognition
'''

import tensorflow as tf
keras = tf.keras
from tensorflow.python.keras.layers import Conv2D,MaxPool2D,Dropout,Dense


def inference(inputs,
              num_classes=1000,
              is_training=True,
              dropout_keep_prob=0.5):
  '''
  inputs: a tensor of images
  num_classes: num of category
  is_training: key of dropout
  dropout_keep_prob: dropout rate
  '''

  x = inputs
  with tf.variable_scope('conv1'):
  # conv1/pool1
    x = Conv2D(64, [7,7], 2, activation='relu', padding='same', name='conv1')(x)
    x = MaxPool2D([3,3], 2, name='pool1')(x)
  with tf.variable_scope('conv2_x'):
    # conv2-7
    for i in range(0,3):
      orig_x = x
      x = Conv2D(64, [3,3], activation='relu', padding='same', name='conv'+str(2*i+2))(x)
      x = Conv2D(64, [3,3], activation='relu', padding='same', name='conv'+str(2*i+3))(x)
      x = tf.add(x,orig_x)
  with tf.variable_scope('conv3_x'):
    # conv8-9
    orig_x = x
    x = Conv2D(128, [3,3], 2, activation='relu', padding='same', name='conv8')(x)
    x = Conv2D(128, [3,3], activation='relu', padding='same', name='conv9')(x)
    orig_x = tf.nn.avg_pool(orig_x, [1,2,2,1], [1,2,2,1], 'SAME')
    orig_x = tf.pad(orig_x, [[0, 0], [0, 0], [0, 0], [(128-64)//2, (128-64)//2]])
    x = tf.add(x,orig_x)
    # conv10-15
    for i in range(0,3):
      orig_x = x
      x = Conv2D(128, [3,3], activation='relu', padding='same', name='conv'+str(2*i+10))(x)
      x = Conv2D(128, [3,3], activation='relu', padding='same', name='conv'+str(2*i+11))(x)
      x = tf.add(x,orig_x)
  with tf.variable_scope('conv4_x'):
    # conv16-17
    orig_x = x
    x = Conv2D(256, [3,3], 2, activation='relu', padding='same', name='conv16')(x)
    x = Conv2D(256, [3,3], activation='relu', padding='same', name='conv17')(x)
    orig_x = tf.nn.avg_pool(orig_x, [1,2,2,1], [1,2,2,1], 'SAME')
    orig_x = tf.pad(orig_x, [[0, 0], [0, 0], [0, 0], [(256-128)//2, (256-128)//2]])
    x = tf.add(x,orig_x)
    # conv18-27
    for i in range(0,5):
      orig_x = x
      x = Conv2D(256, [3,3], activation='relu', padding='same', name='conv'+str(2*i+18))(x)
      x = Conv2D(256, [3,3], activation='relu', padding='same', name='conv'+str(2*i+19))(x)
      x = tf.add(x,orig_x)
  with tf.variable_scope('conv5_x'):
    # conv28-29
    orig_x = x
    x = Conv2D(512, [3,3], 2, activation='relu', padding='same', name='conv28')(x)
    x = Conv2D(512, [3,3], activation='relu', padding='same', name='conv29')(x)
    orig_x = tf.nn.avg_pool(orig_x, [1,2,2,1], [1,2,2,1], 'SAME')
    orig_x = tf.pad(orig_x, [[0, 0], [0, 0], [0, 0], [(512-256)//2, (512-256)//2]])
    x = tf.add(x,orig_x)
    # conv30-33
    for i in range(0,2):
      orig_x = x
      x = Conv2D(512, [3,3], activation='relu', padding='same', name='conv'+str(2*i+30))(x)
      x = Conv2D(512, [3,3], activation='relu', padding='same', name='conv'+str(2*i+31))(x)
      x = tf.add(x,orig_x)
  # fc34
  x = keras.layers.GlobalAveragePooling2D(name='global_avg_pool')(x)
  logits = Dense(num_classes, activation='relu', name='logits')(x)
  return logits


def build_cost():
  pass


def build_train_op():
  pass


if __name__ == '__main__':
  images = tf.placeholder(tf.float32, [None, 224, 224, 3])
  labels = tf.placeholder(tf.float32, [None, 1000])
  logits = inference(inputs=images,
                     num_classes=1000)
  print('inference: good job')

#  sess = tf.Session()
#  sess.run(tf.global_variables_initializer())
#  Writer = tf.summary.FileWriter('./tmp',sess.graph)

上面的代码其实是比较烂的
读者有兴趣的可以将上面的残差模块编写成一个module,然后在编写ResNet时,将大大降低编写的难度。
改进后,可以极大的简化代码的编写,并且便于代码的维护。

另外,不一定要使用keras,其实TensorFlow里的layers和losses模块也是很好用的。

注:任何对于本文代码的使用都必须注明引用

猜你喜欢

转载自blog.csdn.net/u014061630/article/details/80408056