ResNet: Source Code and Principle

Based on the PyTorch framework:

import torch
import torch.nn as nn


class Residual(nn.Module):
    """Pre-activation bottleneck residual block (BN -> ReLU -> Conv)."""

    def __init__(self, numIn, numOut):
        super(Residual, self).__init__()
        self.numIn = numIn
        self.numOut = numOut
        self.bn = nn.BatchNorm2d(self.numIn)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 conv squeezes the channels to numOut // 2; note integer
        # division (//), since channel counts must be ints, not floats
        self.conv1 = nn.Conv2d(self.numIn, self.numOut // 2, bias=True, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(self.numOut // 2)
        # 3x3 conv at the reduced width; padding=1 keeps the spatial size
        self.conv2 = nn.Conv2d(self.numOut // 2, self.numOut // 2, bias=True,
                               kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(self.numOut // 2)
        # 1x1 conv expands the channels back to numOut
        self.conv3 = nn.Conv2d(self.numOut // 2, self.numOut, bias=True, kernel_size=1)

        # If the channel count changes across the block, project the skip
        # connection with a 1x1 conv so the addition is well-defined
        if self.numIn != self.numOut:
            self.conv4 = nn.Conv2d(self.numIn, self.numOut, bias=True, kernel_size=1)

    def forward(self, x):
        residual = x
        out = self.bn(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.numIn != self.numOut:
            residual = self.conv4(x)

        # F(x) + x: the residual branch plus the (possibly projected) identity
        return out + residual
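To see the block in action, here is a quick shape check (an illustrative addition, not part of the original post; it reuses the Residual class and imports defined above):

x = torch.randn(1, 64, 56, 56)            # NCHW input with 64 channels
block = Residual(numIn=64, numOut=128)    # channel count changes across the block
y = block(x)
print(y.shape)                            # torch.Size([1, 128, 56, 56])

Because numIn != numOut here, conv4 projects the skip path from 64 to 128 channels so the final addition lines up.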

Borrowing a figure:

[Figure: the standard residual-block diagram, in which the identity shortcut x bypasses the weight layers and is added to the branch output F(x)]
Let the final output be H(x), so H(x) = F(x) + x. Why does this design ease the degradation and vanishing-gradient problems that appear as the network gets deeper? Suppose the input x is already optimal; then the optimal output H(x) is x itself, and the branch only has to learn F(x) = H(x) - x = 0. Driving a residual toward zero is far easier for the optimizer than making a stack of nonlinear layers fit an identity mapping. The shortcut also helps the backward pass: since ∂H/∂x = ∂F/∂x + I, the identity term gives gradients a direct path through every block, so they do not vanish even when the residual branch contributes little.
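A minimal numeric sketch of the gradient argument (my own illustration, not from the original post): below, the elementwise product w * x stands in for the residual branch F(x), with w initialized to zero so the branch contributes nothing at all.

import torch

x = torch.randn(4, requires_grad=True)
w = torch.zeros(4, requires_grad=True)   # residual branch weights, zero at init

f = w * x            # stand-in for the residual branch F(x)
h = f + x            # H(x) = F(x) + x, the skip connection
h.sum().backward()

print(x.grad)        # tensor([1., 1., 1., 1.]): dH/dx = w + 1 = 1 even though F is "dead"

Dropping the "+ x" makes x.grad all zeros here, which is exactly the blocked gradient path the shortcut avoids.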


Reposted from blog.csdn.net/cz_bykz/article/details/79954491