Kaiming He's deep residual network (ResNet) played a major role in the development of deep learning. ResNet swept first place in several computer-vision competitions the year it appeared, and more importantly, its architecture alleviated the degradation problem that makes extremely deep networks hard to train.
Let us first look at ResNet's network structure. The variant chosen here is ResNet34, whose architecture is shown in the figure below. Apart from the initial convolution and pooling and the final pooling and fully connected layer, the network consists of many structurally similar units; what these repeated units have in common is a shortcut connection that skips across layers.
In ResNet, such a unit with a skip connection is called a residual block, whose structure is shown in the figure below. The left part is an ordinary convolutional path and the right part is the direct shortcut. If the input and output differ in channel count, or the stride is not 1, a dedicated unit (a 1x1 convolution) is needed to bring the two into the same shape so that they can be added.
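In code, the forward computation of a residual block boils down to the following pattern. This is a minimal sketch: `conv_path` and `shortcut` are hypothetical names for the two branches, not identifiers from the implementation below.

```python
import torch.nn.functional as F

def residual_forward(x, conv_path, shortcut=None):
    # Sketch of the residual pattern: main-branch output plus the shortcut,
    # with the ReLU applied after the addition.
    out = conv_path(x)                                 # e.g. conv-BN-ReLU-conv-BN
    residual = x if shortcut is None else shortcut(x)  # identity or 1x1 projection
    return F.relu(out + residual)
```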
We can also see a regular pattern in how the residual blocks are arranged: after the initial pooling there are several consecutive, identical residual blocks with the same channel count (in ResNet34 the four groups contain 3, 4, 6 and 3 blocks respectively). Here we call such a group of residual blocks a "layer". Note that this differs from the usual meaning of layer; a layer here is a collection of several layers.
Since both the residual block and the layer occur multiple times, we can implement each of them as a sub-module or a function. Here the residual block is implemented as a sub-module, and the layer is built by a function.
The implementation code follows:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    '''Sub-module: a residual block.'''

    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBlock, self).__init__()
        # Main (left) branch: two 3x3 convolutions with batch norm
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        # Shortcut (right) branch: identity if None, otherwise a projection
        self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)


class ResNet(nn.Module):
    '''
    Main module: ResNet34.
    ResNet34 consists of several layers, each of which in turn contains
    several residual blocks. Residual blocks are implemented as a
    sub-module; layers are built by the _make_layer function.
    '''

    def __init__(self, num_classes=1000):
        super(ResNet, self).__init__()
        # Stem: initial convolution and pooling
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1)
        )
        # Repeated layers with 3, 4, 6 and 3 residual blocks respectively
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(64, 128, 4, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, stride=2)
        self.layer4 = self._make_layer(256, 512, 3, stride=2)
        # Fully connected layer for classification
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        '''Build a layer containing several residual blocks.'''
        # 1x1 projection shortcut used by the first block of the layer
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        layers = []
        layers.append(ResidualBlock(inchannel, outchannel, stride, shortcut))
        for i in range(1, block_num):
            layers.append(ResidualBlock(outchannel, outchannel))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.pre(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.avg_pool2d(x, 7)
        x = x.view(x.size(0), -1)
        return self.fc(x)


model = ResNet()
print(model)
```
Output:
```
ResNet(
  (pre): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (layer1): Sequential(
    (0): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (right): Sequential(
        (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer2): Sequential(
    (0): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (right): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer3): Sequential(
    (0): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (right): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (4): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer4): Sequential(
    (0): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (right): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
```
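A quick sanity check confirms that the network runs end to end. The following lines are a minimal example added here (not part of the original listing); with a 224x224 input, the feature map reaching `F.avg_pool2d(x, 7)` is exactly 7x7:

```python
# Feed one dummy 224x224 RGB image through the model.
input = torch.randn(1, 3, 224, 224)
out = model(input)
print(out.shape)  # torch.Size([1, 1000]): one score per class
```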
The patterns can be summarized as follows:

- For the repeated parts of a model, implement them as a sub-module, or generate the corresponding modules with a function such as `_make_layer`.
- Combine `nn.Module` and `nn.functional` (see the sketch after this list).
- Use `nn.Sequential` whenever possible.
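Expanding on the second point: a reasonable rule of thumb is that operations with learnable parameters (convolutions, batch norm, linear layers) are declared as `nn.Module` attributes in `__init__`, while stateless operations (activations, pooling) can be called through `nn.functional` inside `forward`, just as the `ResNet` above uses `F.relu` and `F.avg_pool2d`. A minimal sketch (the `ConvBlock` module is a hypothetical example, not part of the original):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        # Parameterized ops live in __init__ as sub-modules
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        # Stateless ops are fine as nn.functional calls in forward
        x = F.relu(self.bn(self.conv(x)))
        return F.max_pool2d(x, 2)

block = ConvBlock()
print(block(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 16, 16, 16])
```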
In addition, torchvision, the image toolkit that accompanies PyTorch, already implements most of the classic deep learning models, including ResNet34, which can be used with just two lines of code:
```python
from torchvision import models

model = models.resnet34()
```
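torchvision can also load weights pre-trained on ImageNet. The exact argument depends on the installed torchvision version, so treat the two forms below as version-dependent rather than universal:

```python
from torchvision import models

# Older torchvision versions use a boolean flag:
model = models.resnet34(pretrained=True)
# Newer versions (roughly 0.13 and later) use a weights enum instead:
# model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
```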
The ResNet34 implementation in this example was adapted and simplified from the torchvision implementation; interested readers can read the corresponding source code and compare the two implementations.
Reference: https://github.com/chenyuntc/pytorch-book/blob/master/chapter4-%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E5%B7%A5%E5%85%B7%E7%AE%B1nn/chapter4.ipynb