Implementing a Simple Residual Network in PyTorch

1. Implementation Process

A residual network (ResNet) is easy to optimize and can gain accuracy from considerably increased depth [1]. Its residual blocks use skip connections, which alleviate the vanishing-gradient problem caused by adding depth to a deep neural network.
This post implements the two-layer residual block shown in Figure 1 and uses it to classify the MNIST dataset; both layers in the block are convolutional.
[Figure 1: Residual building block]
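
Formally, if H(x) denotes the mapping the two layers are meant to learn, the block instead learns only the residual F(x) = H(x) - x and outputs ReLU(F(x) + x). Because the identity skip connection passes gradients through the addition unchanged, earlier layers keep receiving a usable gradient signal as depth grows, which is why residual networks remain easy to optimize.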

The residual building block is wrapped in a class as follows:

import torch
import torch.nn.functional as F

class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.channels = channels
        # 3x3 convolutions with padding=1 keep the spatial size unchanged,
        # so the output can be added to the input x element-wise.
        self.conv1 = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        y = F.relu(self.conv1(x))
        y = self.conv2(y)      # feed conv1's output into conv2, not the raw input x
        return F.relu(x + y)   # skip connection: add the input back before the final ReLU
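
A quick shape check (a minimal sketch; the random input below is illustrative, not from the original post) confirms that the block preserves its input's shape, which is exactly what makes the element-wise addition x + y valid:

block = ResidualBlock(16)
x = torch.randn(4, 16, 12, 12)   # (batch, channels, height, width)
print(block(x).shape)            # torch.Size([4, 16, 12, 12]) -- same as the input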

The network model with the residual blocks embedded is as follows:

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=5)
        self.mp = torch.nn.MaxPool2d(2)

        self.rblock1 = ResidualBlock(16)
        self.rblock2 = ResidualBlock(32)

        # After two conv+pool stages the feature map is 32 x 4 x 4 = 512.
        self.fc = torch.nn.Linear(512, 10)

    def forward(self, x):
        in_size = x.size(0)
        x = self.mp(F.relu(self.conv1(x)))  # (n,1,28,28) -> (n,16,24,24) -> (n,16,12,12)
        x = self.rblock1(x)                 # shape unchanged: (n,16,12,12)
        x = self.mp(F.relu(self.conv2(x)))  # (n,16,12,12) -> (n,32,8,8) -> (n,32,4,4)
        x = self.rblock2(x)                 # shape unchanged: (n,32,4,4)
        x = x.view(in_size, -1)             # flatten to (n, 512)
        return self.fc(x)

model = Net()
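
The original post does not show the training and test code. Below is a minimal sketch that would produce logs in the format shown next, assuming the usual setup from the referenced course [2]: batch size 64, SGD with learning rate 0.01 and momentum 0.5, cross-entropy loss, and the running loss printed every 300 mini-batches. These hyperparameters are assumptions, not taken from the original post.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Normalization constants are the commonly used MNIST mean/std.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_loader = DataLoader(datasets.MNIST('data', train=True, download=True, transform=transform),
                          batch_size=64, shuffle=True)
test_loader = DataLoader(datasets.MNIST('data', train=False, download=True, transform=transform),
                         batch_size=64, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)  # assumed hyperparameters

def train(epoch):
    running_loss = 0.0
    for batch_idx, (inputs, target) in enumerate(train_loader, 0):
        optimizer.zero_grad()
        loss = criterion(model(inputs), target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 300 == 299:  # print the average loss every 300 mini-batches
            print('[%d,%d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

def test():
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, target in test_loader:
            predicted = model(inputs).argmax(dim=1)  # class with the highest logit
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print('Accuracy on test set: %g %% [%d/%d]' % (100 * correct / total, correct, total))

for epoch in range(10):
    train(epoch)
    test()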

The results of a 10-epoch training run are as follows:

[1,300] loss: 0.486
[1,600] loss: 0.143
[1,900] loss: 0.103
Accuracy on test set: 97.34 % [9734/10000]
[2,300] loss: 0.082
[2,600] loss: 0.074
[2,900] loss: 0.066
Accuracy on test set: 98.37 % [9837/10000]
[3,300] loss: 0.058
[3,600] loss: 0.052
[3,900] loss: 0.051
Accuracy on test set: 98.68 % [9868/10000]
[4,300] loss: 0.044
[4,600] loss: 0.047
[4,900] loss: 0.038
Accuracy on test set: 98.81 % [9881/10000]
[5,300] loss: 0.037
[5,600] loss: 0.035
[5,900] loss: 0.038
Accuracy on test set: 98.8 % [9880/10000]
[6,300] loss: 0.030
[6,600] loss: 0.034
[6,900] loss: 0.032
Accuracy on test set: 98.89 % [9889/10000]
[7,300] loss: 0.029
[7,600] loss: 0.030
[7,900] loss: 0.026
Accuracy on test set: 98.83 % [9883/10000]
[8,300] loss: 0.026
[8,600] loss: 0.028
[8,900] loss: 0.021
Accuracy on test set: 99.04 % [9904/10000]
[9,300] loss: 0.021
[9,600] loss: 0.023
[9,900] loss: 0.022
Accuracy on test set: 99.05 % [9905/10000]
[10,300] loss: 0.019
[10,600] loss: 0.019
[10,900] loss: 0.022
Accuracy on test set: 99.05 % [9905/10000]

As the results show, the network reaches about 99% test accuracy within ten epochs: a deep network with residual connections learns better than a plain deep network.
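
The plain baseline behind this comparison is not included in the post; for reference, a no-skip counterpart of ResidualBlock (an assumption of what such a baseline could look like, not code from the original) would simply drop the x + y addition:

class PlainBlock(torch.nn.Module):
    """Same two 3x3 convolutions, but without the skip connection."""
    def __init__(self, channels):
        super(PlainBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        y = F.relu(self.conv1(x))
        return F.relu(self.conv2(y))  # no x added back: gradients must flow through both convs

Swapping PlainBlock for ResidualBlock in Net keeps the parameter count identical, so any difference in the learning curves is attributable to the skip connection.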

2. References

[1] K. He, X. Zhang, S. Ren, and J. Sun. Deep Residual Learning for Image Recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 770-778.
[2] https://www.bilibili.com/video/BV1Y7411d7Ys?p=11

Reposted from blog.csdn.net/weixin_43821559/article/details/123384077