如何使用PyTorch简易搭建残差网络

作者:Ta-ying Chen,牛津大学机器学习博士研究生,Medium知名技术博主

译者:颂贤

图源:Unsplash 时兴的自动驾驶和人脸检测等众多计算机视觉应用之所以能够实现都要归功于深度神经网络。然而,许多人可能都不知道的是,近年来计算机视觉的突破性进步都是由一种特定类型的网络架构推动的,也就是所谓的残差网络(residual network,ResNet)。事实上,我们所看到的诸多先进的人工智能成果,没有残差块(residual blocks)的发明都是不可能实现的。是残差块这个如此简单而优雅的概念,使我们有了真正的“深度“网络。

本文将带大家探讨残差网络背后的基本原理,并介绍如何在PyTorch中简单实现残差网络并训练ResNets进行图像分类。

退化问题:层数越多越强大吗?

理论上来说,变量更多的深层网络应该能够更好完成图像理解这样的困难任务。然而有证据表明,按照传统方法加深层数的网络实际上更难训练,表现甚至要比浅层网络更差。这就是我们所说的退化问题(degradation problem)。

退化问题似乎在我们的直觉上并不合理。按理说,如果我们有两个层数相同的网络,并在第二个网络前面增加x层,最坏的情况应该是前x层输出了恒等映射,导致两个网络最后的性能相同。一个可能的猜想是,层数更多的网络之所以性能较差是因为恒等映射被网络忽略掉了。因此在2010年代早期,像VGG-16这样的网络一般都会被限制在10-20层左右。

残差架构

残差网络针对上面所讲的退化问题提出了一种简单明了的解决方案。残差网络能够创建一条捷径,称为跳跃连接(skip connection),将原始输入放入网络并使其经过几个堆叠层后最终与输出特征相结合。 图1 简单的残差块(source:https://arxiv.org/abs/1512.03385) 如图1所示,设要进入堆叠层的输入为x,中间的堆叠层为函数F,那么最终的输出y为

y=F(x)+x\

\

F(x)x的维度不匹配时,我们可以简单地在跳跃连接的过程中进行线性投影来改变x的维度。

我们把上面的整个pipeline称为一个残差块,我们可以结合使用多个残差块来构建一个非常深的网络并规避原有的退化问题。

计算环境

我们将使用PyTorch(包括torchvision)来搭建残差网络,下面的代码能够导入我们所需要的所有库:

"""
The following is an import of PyTorch libraries.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import random
复制代码

数据集

为了展示残差网络的强大能力,我们将使用两个数据集进行测试:较简单的MINST数据集(包括60000张0到9的手写阿拉伯数字图像)以及更复杂的CIFAR-10数据集。

下载链接如下:1. MNIST数据集 2. CIFAR-10数据集

在测试网络的过程中,我们往往需要参考多个不同数据集的训练结果。格物钛的公开数据集平台提供一站式平台获取、筛选和管理高质量的数据集,我们的工作就会变得非常方便。免费提供了许多知名的非结构化数据集。使用格物钛提供的SDK,我们甚至可以直接将这些数据集,集成到我们的代码中来进行训练和测试。

硬件要求

一般来说,我们虽然可以使用CPU来训练神经网络,但最佳选择其实是GPU,因为这样可以大幅提升训练速度。本文带领大家搭建的残差网络比较简单,CPU和GPU都能运行,但实际应用中会出现更加复杂的网络(如ResNet-152),一般用GPU为佳。我们可以用下面的代码来测试自己的机器能否用GPU来训练:

"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
复制代码

搭建残差块

这里,我们将介绍如何在卷积神经网络上创建最简单的残差块,其输入和输出的维度是相同的。下面的代码结合了PyTorch的nn.Module来创建残差块。

"""
Define an nn.Module class for a simple residual block with equal dimensions
"""
class ResBlock(nn.Module):
​
    """
    Initialize a residual block with two convolutions followed by batchnorm layers
    """
    def __init__(self, in_size:int, hidden_size:int, out_size:int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_size, hidden_size, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_size, out_size, 3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(hidden_size)
        self.batchnorm2 = nn.BatchNorm2d(out_size)
​
    def convblock(self, x):
        x = F.relu(self.batchnorm1(self.conv1(x)))
        x = F.relu(self.batchnorm2(self.conv2(x)))
        return x
​
    """
    Combine output with the original input
    """
    def forward(self, x): return x + self.convblock(x) # skip connection
复制代码

使用既有的ResNet模型

事实上,有许多利用了残差结构的网络在使用ImageNet这样的大型数据集训练之后结果非常出色,我们完全可以利用这些既有的模型而无需重复造轮子。Torchvision在其库中就提供了ResNet-34、ResNet-50和ResNet-152等网络的检查点和架构的预建模型。我们可以通过以下代码得到上述模型:

"""
Creates a pretrained on ImageNet Resnet34
"""
resnet = torchvision.models.resnet34(pretrained=True)
复制代码

不过需要注意的是,如果我们需要用ImageNet之外的数据集对模型进行微调,务必注意要对ResNet的最后一层进行改动,因为最后的one-hot向量的维度需要等于数据集的类的数量。

结果

经过50轮训练,我们搭建的网络在MNIST数据集上可以轻松达到99%左右的准确度,而在CIFAR-10数据集上ResNet-34和ResNet-152都可以达到90%。根据何恺明等人的原始论文,我们也可以看到,残差结构在ImageNet数据集上的表现明显好于VGG,同时也比一个具有相同层数但没有残差结构的网络要更强。

结语

何恺明等人的残差结构可以说是近期计算机视觉方向的神经网络领域中最出色的发明之一。今天几乎所有的网络(甚至超越卷积网络的网络)为了在层数大量堆叠之后也能维持良好的表现,都或多或少地使用了残差块这一结构。这一简单而优雅的方法为推动机器理解人类世界创造了无数的可能性。

猜你喜欢

转载自juejin.im/post/7018109158247366686