【youcans动手学模型】Wide ResNet 模型

欢迎关注『youcans动手学模型』系列
本专栏内容和资源同步到 GitHub/youcans



本文用 PyTorch 实现 WideResNet 网络模型,使用 CIFAR10 数据集训练模型,进行图像分类。


1. Wide ResNet 卷积神经网络模型

Sergey Zagoruyko, Nikos Komodakis 等在 2016 年发表论文 “Wide Residual Networks”,提出一种宽残差网络( wide residual networks , WRN)架构,通过增加残差网络的宽度和减少网络深度,提高了残差网络的准确性和效率。

【论文下载地址】:【Wide Residual Networks】https://arxiv.org/abs/1605.07146

【GitHub 地址】:参考例程1参考例程2


1.1 模型简介

Wide ResNet 是一种宽残差网络( wide residual networks , WRN)架构,通过增加残差网络的宽度和减少网络深度,提高了残差网络的准确性和效率。

在这里插入图片描述


1.2 论文介绍

【论文摘要】

深度残差网络被证明能够扩展到数千层,并且仍然可以提升网络的性能。然而,每提高 1% 精度的成本大约是层数的 2倍。因此,训练非常深的残差网络存在重复使用特征的效率降低问题,这使得网络训练非常缓慢。

我们对 ResNet 架构进行了详细的实验研究,发现残差网络的能力主要由残差 block 提供,网络深度只有补充性的作用。

我们提出一种宽残差网络(WRN)架构,可以减少残差网络的深度并增加其宽度。这种网络结构远优于常用的非常深而窄的网络结构。即使是一个简单的 16层的宽残差网络,其准确性和效率也优于所有以前的深残差网络,包括千层深度的网络。在 CIFAR、SVHN、COCO 数据集上实现了最先进的结果,并在 ImageNet 上实现了显著改进。


【论文背景】

从 AlexNet、VGG、Inception 到残差网络,随着神经网络层数的不断增加,性能也不断提升。然而,由于梯度爆炸/消失和退化,训练越来越深的神经网络非常困难。近来,残差网络 ResNet 取得了巨大成功,残差链路加快了深度网络的收敛。稍早的 Highway 网络,也可以成功地训练非常深入的网络。

Residual block 提出了跳跃连接(skip connection)的概念,将层的输入直接加到输出端,构成一个残差路径(shortcut),使网络可以直接学习残差部分的变化 。其数学描述如下:

C o n v L a y e r : x l = F ( x l − 1 , W l − 1 ) R e s B l o c k : x l = F ( x l − 1 , W l − 1 ) + x l − 1 \begin{matrix} Conv Layer: &x_l = &F(x_{l-1}, W_{l-1})\\ Res Block: &x_l = &F(x_{l-1}, W_{l-1}) &+ &x_{l-1} \end{matrix} ConvLayer:ResBlock:xl=xl=F(xl1,Wl1)F(xl1,Wl1)+xl1

残差块(Residual block)有两种结构形式。

(1)基本残差块:如图(a) basic 所示,具有两个连续的 3*3 卷积,带有 BN 批量归一化和 ReLU 激活函数。

(2)瓶颈式(bottleneck)残差块:如图(b) bottleneck 所示,一个 3*3 卷积层,前后各有一个 1*1 卷积层分别进行降维和扩展。具体而言,先通过一个 1*1 卷积层对输入特征图进行降维,然后进行 3*3 卷积,最后又通过一个 1*1 卷积层进行升维扩展。3*3 卷积层较窄(薄),形成瓶颈(bottleneck),可以减少 Residual block 的计算量。

到目前为止,关于残差网络的研究主要集中在 ResNet block 内激活的顺序和残差网络的深度上。与原始架构 ResNet 相比,改进的残差块中的批量归一化、激活和卷积操作的顺序从conv-BN-ReLU 更改为 BN-ReLU-conv。后者被证明训练更快,效果更好。

我们的目标是探索一组更丰富的 ResNet block 的网络架构,并研究除了激活顺序之外的其他几个不同方面如何影响性能。

残差网络中的宽度与深度

电路复杂性理论表明,浅层电路可能比深层电路需要指数级多的组件。残差网络的作者试图使其尽可能薄(瘦),以便于增加深度和减少参数,甚至引入了一个“瓶颈”块,使 ResNet block 变得更薄。

然而,虽然残差块允许训练非常深入的网络,但当梯度流经网络时,并不能保证通过残差块权重,可能只有少数块学习有用的表示,或者许多块共享很少的信息,对最终目标的贡献很小。这个问题被称为为特征重用的效率降低。

我们试图回答深度残差网络应该有多宽的问题。实验表明,深度残差网络的性能主要来自残差块,深度的影响是补充的。

我们的研究表明,与增加残差网络的深度相比,扩大 ResNet block 可以更有效地提高残差网络的性能。我们提出了更宽的宽残差网络(Wide ResNet),其层数减少了 50倍,速度提高了 2倍以上。

在 ResNet block 中使用 Dropout

Dropout 主要应用于具有大量参数的顶层,以防止过拟合,后来经常被批量归一化(batch normalization,BN)所取代。批量归一化也可以作为正则化子,实验表明具有 BN 的网络比具有 Dropout 的网络的精度更好。

由于残差块的加宽导致参数数量的增加,我们研究了 Dropout 对正则化训练和防止过拟合的影响,认为 Dropout 应该插入到卷积层之间。在宽残差网络上的实验结果表明 Dropout 是有效的,例如具有 Dropout 的 16层的宽残差网络在 SVHN 上实现了 1.64% 的误差。


【主要创新】

有三种简单的方法来增加残差块的特征表达能力:

  • 增加层数:为每个残差块添加更多的卷积层;
  • 增加宽度:添加更多的特征图来加宽卷积层;
  • 增大卷积核:增加卷积层中的卷积核的大小。

由于 VGG 等研究证明小卷积核的有效性,因此我们不考虑使用大于 3*3 的卷积核。

我们引入两个参数,深度因子 l l l 和宽度因子 k k k,其中 l l l 是残差块中卷积层的数量, k k k 是输入特征的扩展倍数,因此图(a) basic 所示的残差块对应于 l = 2 , k = 1 l=2, k=1 l=2,k=1,图© basic-wide 所示的残差块对应于 l = 2 , k > 1 l=2, k>1 l=2,k>1

增加卷积层的宽度(增大 k k k),参数数量和计算复杂度增大。但是由于 GPU 在大张量上的并行计算效率很高,使用加宽层的计算效率更高。残差网络之前的所有架构,例如 VGG 和 Inception 都使用了很宽的卷积层。因此我们希望通过试验找到残差块数量 d 与 宽度因子 k 的最佳比率。

另外,我们通过 Dropout 进行正则化。残差网络通过批量归一化 BN 进行正则化,我们则将 Dropout 层添加到残差块中来防止过拟合,如图(d) wide-dropout 所示。


【模型结构】

我们的残差网络的一般结构如表 1所示:首先是初始卷积层 conv1,随后是 3组(每个大小为N)残差块 conv2、conv3 和 conv4,最后是平均池化层和分类层。

在这里插入图片描述

在我们的所有实验中,conv1 的大小都是固定的,而引入的宽度因子 k k k 缩放了三组 conv2~conv4 中残差块的宽度。

B ( M ) B(M) B(M) 表示残差块结构,其中 M 是残差块中卷积核的列表。例如,B(3,1) 表示一个 3*3 卷积层和一个 1*1 卷积层,B(1,3,1) 表示如图(b) bottleneck 所示的 1*1 卷积层、 3*3 卷积层和 1*1 卷积层。


【性能】

  1. 在 Cifar10 和 Cifar100上,左图没用dropout的比较,它效果最好。右图是用dropout的比较,深度28宽度10时效果最好
  2. 左图上面的没用瓶颈设计,大型数据集,网络很深的才用,这是也就用不着dropout了,如WRN-50-2-bottleneck。右图它在CIFAR-10、CIFAR-100、SVHN和COCO达到了最佳结果。

2. 在 PyTorch 中定义 WideResNet 模型类

2.1 自定义 WideResNet 模型类

PyTorch 通过 torch.nn 模块提供了高阶的 API,可以从头开始构建网络。

# https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                               padding=0, bias=False) or None
    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)

class NetworkBlock(nn.Module):
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
        return nn.Sequential(*layers)
    def forward(self, x):
        return self.layer(x)

class WideResNet1(nn.Module):
    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet1, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=nChannels[0], kernel_size=3,
                               stride=1, padding=1, bias=False)
        # 1st block
        self.conv2 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.conv3 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.conv4 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.relu(self.bn(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)

2.2 自定义 WideResNet 模型类 之二

PyTorch 通过 torch.nn 模块提供了高阶的 API,可以从头开始构建网络。

# https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py
class wide_basic(nn.Module):
    def __init__(self, ch_in, ch_out, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(ch_in)
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=stride, padding=1, bias=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or ch_in != ch_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride, bias=True),
            )

    def forward(self, x):
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.shortcut(x)
        return out

class WideResNet2(nn.Module):
    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet2, self).__init__()
        self.ch_in = 16

        assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4'
        n = (depth-4)/6
        k = widen_factor
        nStages = [16, 16*k, 32*k, 64*k]
        print('| Wide-Resnet %dx%d' %(depth, k))

        # 1st conv before any network block
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=nStages[0], kernel_size=3,
                               stride=1, padding=1, bias=False)
        # Wide Res Block
        self.conv2 = self._wide_layer(wide_basic, nStages[1], n, dropRate, stride=1)
        self.conv3 = self._wide_layer(wide_basic, nStages[2], n, dropRate, stride=2)
        self.conv4 = self._wide_layer(wide_basic, nStages[3], n, dropRate, stride=2)
        # global average pooling and classifier
        self.bn = nn.BatchNorm2d(nStages[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nStages[3], num_classes)

    def _wide_layer(self, block, planes, num_blocks, dropRate, stride):
        strides = [stride] + [1]*(int(num_blocks)-1)
        layers = []
        for stride in strides:
            layers.append(block(self.ch_in, planes, dropRate, stride))
            self.ch_in = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.relu(self.bn(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

2.3 从 torchvision.model 加载预定义模型

Torchvision 自带了很多经典的网络模型,可以直接加载这些预定义模型。我们可以只使用预定义的模型类来创建实例化模型对象(不加载预训练的模型参数)用于模型训练,也可以在实例化模型对象的同时加载预训练的模型参数,还可以基于预训练模型进行模型微调或迁移学习。

torchvision.models 包和 Torch Hub 中都提供了 Wide ResNet 模型。torchvision.models 提供了 Wide ResNet模型类和预训练模型 Wide ResNet 模型 |PyTorch 可以直接使用,原始代码可以参考: 源代码。 Torch Hub 中提供了 Wide ResNet 模型 |PyTorch Hub

以下模型构建器可用于实例化 Wide ResNet 模型,无论是否具有预先训练的权重。所有的模型构建器内部都依赖于torchvision.models.resnet.resnet 基类。

wide_resnet50_2(*[, weights, progress])  # Wide ResNet-50-2 model from Wide Residual Networks.

wide_resnet101_2(*[, weights, progress])  # Wide ResNet-101-2 model from Wide Residual Networks.

Torch Hub 中提供了 Wide ResNet 模型类,也提供了在 ImageNet 数据集上训练好的预训练模型,可以直接用来进行图像分类或进行迁移学习。

# torch.hub 方式加载 load WRN-50-2
model = torch.hub.load('pytorch/vision:v0.10.0', 'wide_resnet50_2', pretrained=True)

# torch.hub 方式加载 load WRN-101-2
model = torch.hub.load('pytorch/vision:v0.10.0', 'wide_resnet101_2', pretrained=True)

3. 基于 WideResNet 模型的 CIFAR10 图像分类

3.1 PyTorch 建立神经网络模型的基本步骤

使用 PyTorch 建立、训练和使用神经网络模型的基本步骤如下。

  1. 准备数据集(Prepare dataset):加载数据集,对数据进行预处理。
  2. 建立模型(Design the model):实例化模型类,定义损失函数和优化器,确定模型结构和训练方法。
  3. 模型训练(Model trainning):使用训练数据集对模型进行训练,确定模型参数。
  4. 模型推理(Model inferring):使用训练好的模型进行推理,对输入数据预测输出结果。
  5. 模型保存与加载(Model saving/loading):保存训练好的模型,以便以后使用或部署。

以下按此步骤讲解 AlexNet 模型的例程。


3.2 加载 CIFAR10 数据集

通用数据集的样本结构均衡、信息高效,而且组织规范、易于处理。使用通用的数据集训练神经网络,不仅可以提高工作效率,而且便于评估模型性能。

PyTorch 提供了一些常用的图像数据集,预加载在 torchvision.datasets 类中。torchvision 模块实现神经网络所需的核心类和方法, torchvision.datasets 包含流行的数据集、模型架构和常用的图像转换方法。

CIFAR 数据集是一个经典的图像分类小型数据集,有 CIFAR10 和 CIFAR100 两个版本。CIFAR10 有 10 个类别,CIFAR100 有 100 个类别。CIFAR10 每张图像大小为 32*32,包括飞机、小汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车 10 个类别。CIFAR10 共有 60000 张图像,其中训练集 50000张,测试集 10000张。每个类别有 6000张图片,数据集平衡。

加载和使用 CIFAR 数据集的方法为:

torchvision.datasets.CIFAR10()
torchvision.datasets.CIFAR100()

CIFAR 数据集可以从官网下载:http://www.cs.toronto.edu/~kriz/cifar.html 后使用,也可以使用 datasets 类自动加载(如果本地路径没有该文件则自动下载)。

下载数据集时,使用预定义的 transform 方法进行数据预处理,包括调整图像尺寸、标准化处理,将数据格式转换为张量。标准化处理所使用 CIFAR10 数据集的均值和方差为 (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)。transform_train在训练过程中,增加随机性,提高泛化能力。

大型训练数据集不能一次性加载全部样本来训练,可以使用 Dataloader 类自动加载数据。Dataloader 是一个迭代器,基本功能是传入一个 Dataset 对象,根据参数 batch_size 生成一个 batch 的数据。

使用 DataLoader 类加载 CIFAR-10 数据集的例程如下。

    # (1) 将[0,1]的PILImage 转换为[-1,1]的Tensor
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # 随机水平翻转
        transforms.RandomRotation(10),  # 随机旋转
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.Resize(32),  # 图像大小调整为 (w,h)=(32,32)
        transforms.ToTensor(),  # 将图像转换为张量 Tensor
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
    # 测试集不需要进行数据增强
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

    # (2) 加载 CIFAR10 数据集
    batchsize = 128
    # 加载 CIFAR10 数据集, 如果 root 路径加载失败, 则自动在线下载
    # 加载 CIFAR10 训练数据集, 50000张训练图片
    train_set = torchvision.datasets.CIFAR10(root='../dataset', train=True,
                                            download=True, transform=transform_train)
    # train_loader = torch.utils.data.DataLoader(train_set, batch_size=batchsize)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batchsize,
                                              shuffle=True, num_workers=8)
    # 加载 CIFAR10 验证数据集, 10000张验证图片
    test_set = torchvision.datasets.CIFAR10(root='../dataset', train=False,
                                           download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000,
                                              shuffle=True, num_workers=8)
    # 创建生成器,用 next 获取一个批次的数据
    valid_data_iter = iter(test_loader)  # _SingleProcessDataLoaderIter 对象
    valid_images, valid_labels = next(valid_data_iter)  # images: [batch,3,32,32], labels: [batch]
    valid_size = valid_labels.size(0)  # 验证数据集大小,batch
    print(valid_images.shape, valid_labels.shape)

    # 定义类别名称,CIFAR10 数据集的 10个类别
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')


3.3 建立 WideResNet 网络模型

建立一个 WideResNet 网络模型进行训练,包括三个步骤:

  • 实例化 WideResNet 模型对象;
  • 设置训练的损失函数;
  • 设置训练的优化器。

torch.nn.functional 模块提供了各种内置损失函数,本例使用交叉熵损失函数 CrossEntropyLoss。

torch.optim 模块提供了各种优化方法,本例使用 Adam 优化器。注意要将 model 的参数 model.parameters() 传给优化器对象,以便优化器扫描需要优化的参数。

    # (3) 构造 WideResNet 网络模型
    model = WideResNet1(depth=28, num_classes=10, widen_factor=2, dropRate=0.1)  # 实例化 WideResNet 网络模型
    model.to(device)  # 将网络分配到指定的device中
    print(model)

    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()  # 定义损失函数 CrossEntropy
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)  # 定义优化器 SGD

3.4 WideResNet 模型训练

PyTorch 模型训练的基本步骤是:

  1. 前馈计算模型的输出值;
  2. 计算损失函数值;
  3. 计算权重 weight 和偏差 bias 的梯度;
  4. 根据梯度值调整模型参数;
  5. 将梯度重置为 0(用于下一循环)。

在模型训练过程中,可以使用验证集数据评价训练过程中的模型精度,以便控制训练过程。模型验证就是用验证数据进行模型推理,前向计算得到模型输出,但不反向计算模型误差,因此需要设置 torch.no_grad()。

使用 PyTorch 进行模型训练的例程如下。

    # (4) 训练 WideResNet 模型
    epoch_list = []  # 记录训练轮次
    loss_list = []  # 记录训练集的损失值
    accu_list = []  # 记录验证集的准确率
    num_epochs = 100  # 训练轮次
    for epoch in range(num_epochs):  # 训练轮次 epoch
        running_loss = 0.0  # 每个轮次的累加损失值清零
        for step, data in enumerate(train_loader, start=0):  # 迭代器加载数据
            optimizer.zero_grad()  # 损失梯度清零

            inputs, labels = data  # inputs: [batch,3,32,32] labels: [batch]
            outputs = model(inputs.to(device))  # 正向传播
            loss = criterion(outputs, labels.to(device))  # 计算损失函数
            loss.backward()  # 反向传播
            optimizer.step()  # 参数更新

            # 累加训练损失值
            running_loss += loss.item()
            # if step%100==99:  # 每 100 个 step 打印一次训练信息
            #     print("\t epoch {}, step {}: loss = {:.4f}".format(epoch, step, loss.item()))

        # 计算每个轮次的验证集准确率
        with torch.no_grad():  # 验证过程, 不计算损失函数梯度
            outputs_valid = model(valid_images.to(device))  # 模型对验证集进行推理, [batch, 10]
        pred_labels = torch.max(outputs_valid, dim=1)[1]  # 预测类别, [batch]
        accuracy = torch.eq(pred_labels, valid_labels.to(device)).sum().item() / valid_size * 100  # 计算准确率
        print("Epoch {}: train loss={:.4f}, accuracy={:.2f}%".format(epoch, running_loss, accuracy))

        # 记录训练过程的统计数据
        epoch_list.append(epoch)  # 记录迭代次数
        loss_list.append(running_loss)  # 记录训练集的损失函数
        accu_list.append(accuracy)  # 记录验证集的准确率

程序运行结果如下:

Epoch 0: train loss=717.2183, accuracy=44.00%
Epoch 1: train loss=556.0241, accuracy=55.60%
Epoch 2: train loss=473.6305, accuracy=63.00%

Epoch 97: train loss=87.7806, accuracy=90.90%
Epoch 98: train loss=90.4091, accuracy=89.60%
Epoch 99: train loss=89.8074, accuracy=91.00%

经过 20 轮左右的训练,使用验证集中的 1000 张图片进行验证,模型准确率达到 80%以上。继续训练可以进一步降低训练损失函数值,验证集的准确率达到 90% 左右。

在这里插入图片描述


3.5 WideResNet 模型的保存与加载

模型训练好以后,将模型保存起来,以便下次使用。PyTorch 中模型保存主要有两种方式,一是保存模型权值,二是保存整个模型。本例使用 model.state_dict() 方法以字典形式返回模型权值,torch.save() 方法将权值字典序列化到磁盘,将模型保存为 .pth 文件。

    # (5) 保存 WideResNet 网络模型
    save_path = "../models/WideResNet_Cifar1"
    model_cpu = model.cpu()  # 将模型移动到 CPU
    model_path = save_path + ".pth"  # 模型文件路径
    torch.save(model.state_dict(), model_path)  # 保存模型权值
    # 优化结果写入数据文件
    result_path = save_path + ".csv"  # 优化结果文件路径
    WriteDataFile(epoch_list, loss_list, accu_list, result_path)

使用训练好的模型,首先要实例化模型类,然后调用 load_state_dict() 方法加载模型的权值参数。

    # 训练结果可视化
    plt.figure(figsize=(11, 5))
    plt.suptitle("WideResNet Model in CIFAR10")
    plt.subplot(121), plt.title("Train loss")
    plt.plot(epoch_list, loss_list)
    plt.xlabel('epoch'), plt.ylabel('loss')
    plt.subplot(122), plt.title("Valid accuracy")
    plt.plot(epoch_list, accu_list)
    plt.xlabel('epoch'), plt.ylabel('accuracy')
    plt.show()

需要特别注意的是:

(1)PyTorch 中的 .pth 文件只保存了模型的权值参数,而没有模型的结构信息,因此必须先实例化模型对象,再加载模型参数。

(2)模型对象必须与模型参数严格对应,才能正常使用。注意即使都是 WideResNet 模型,模型类的具体定义也可能有细微的区别。如果从一个来源获取模型类的定义,从另一个来源获取模型参数文件,就很容易造成模型结构与参数不能匹配。

(3)无论从 PyTorch 模型仓库加载的模型和参数,或从其它来源获取的预训练模型,或自己训练得到的模型,模型加载的方法都是相同的,也都要注意模型结构与参数的匹配问题。


3.6 模型检验

使用加载的 WideResNet 模型,输入新的图片进行模型推理,可以由模型输出结果确定输入图片所属的类别。

使用测试集数据进行模型推理,根据模型预测结果与图片标签进行比较,可以检验模型的准确率。模型验证集与模型检验集不能交叉使用,但为了简化例程在本程序中未做区分。

    # (7) 模型检验
    correct = 0
    total = 0
    for data in test_loader:  # 迭代器加载测试数据集
        imgs, labels = data  # torch.Size([batch,3,224,224]) torch.Size([batch])
        # print(imgs.shape, labels.shape)
        outputs = model(imgs.to(device))  # 正向传播, 模型推理, [batch, 10]
        labels_pred = torch.max(outputs, dim=1)[1]  # 模型预测的类别 [batch]
        # _, labels_pred = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += torch.eq(labels_pred, labels.to(device)).sum().item()
    accuracy = 100. * correct / total
    print("Test samples: {}".format(total))
    print("Test accuracy={:.2f}%".format(accuracy))

使用测试集进行模型推理,测试模型准确率为 90.03%。

Test samples: 10000
Test accuracy=90.03%


3.7 模型推理

使用加载的 WideResNet 模型,输入新的图片进行模型推理,可以由模型输出结果确定输入图片所属的类别。

从测试集中提取几张图片,或者读取图像文件,进行模型推理,获得图片的分类类别。在提取图片或读取文件时,要注意对图片格式和图片大小进行适当的转换。

    # (8) 提取测试集图片进行模型推理
    batch = 8  # 批次大小
    data_set = torchvision.datasets.CIFAR10(root='../dataset', train=False,
                                           download=False, transform=None)
    plt.figure(figsize=(9, 6))
    for i in range(batch):
        imgPIL = data_set[i][0]  # 提取 PIL 图片
        label = data_set[i][1]  # 提取 图片标签
        # 预处理/模型推理/后处理
        imgTrans = transform(imgPIL)  # 预处理变换, torch.Size([3, 224, 224])
        imgBatch = torch.unsqueeze(imgTrans, 0)  # 转为批处理,torch.Size([batch=1, 3, 224, 224])
        outputs = model(imgBatch.to(device))  # 模型推理, 返回 [batch=1, 10]
        indexes = torch.max(outputs, dim=1)[1]  # 注意 [batch=1], device = 'device
        index = indexes[0].item()  # 预测类别,整数
        # 绘制第 i 张图片
        imgNP = np.array(imgPIL)  # PIL -> Numpy
        out_text = "label:{}/model:{}".format(classes[label], classes[index])
        plt.subplot(2, 4 ,i+1)
        plt.imshow(imgNP)
        plt.title(out_text)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

结果如下。

在这里插入图片描述

    # (9) 读取图像文件进行模型推理
    from PIL import Image
    filePath = "../images/img_plane_01.jpg"  # 数据文件的地址和文件名
    imgPIL = Image.open(filePath)  # PIL 读取图像文件, <class 'PIL.Image.Image'>

    # 预处理/模型推理/后处理
    imgTrans = transform["test"](imgPIL)  # 预处理变换, torch.Size([3, 224, 224])
    imgBatch = torch.unsqueeze(imgTrans, 0)  # 转为批处理,torch.Size([batch=1, 3, 224, 224])
    outputs = model(imgBatch.to(device))  # 模型推理, 返回 [batch=1, 10]
    indexes = torch.max(outputs, dim=1)[1]  # 注意 [batch=1], device = 'device
    percentages = nn.functional.softmax(outputs, dim=1)[0] * 100
    index = indexes[0].item()  # 预测类别,整数
    percent = percentages[index].item()  # 预测类别的概率,浮点数

    # 绘制第 i 张图片
    imgNP = np.array(imgPIL)  # PIL -> Numpy
    out_text = "Prediction:{}, {}, {:.2f}%".format(index, classes[index], percent)
    print(out_text)
    plt.imshow(imgNP)
    plt.title(out_text)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

4. 使用 WideResNet 预训练模型进行图像分类

Torchvision.models 包和 Torch Hub 中不仅提供了 Wide ResNet 模型类,也提供了在 ImageNet 数据集上训练好的预训练模型,可以直接用来进行图像分类或进行迁移学习。

使用 Wide ResNet 预训练模型进行图像分类的完整例程如下。

# Begin_WideResNet_3.py
# WideResNet model for beginner with PyTorch
# 加载 WideResNet 预训练模型和参数,对图像进行分类
# Copyright: [email protected]
# Crated: Huang Shan, 2023/06/10

# _*_coding:utf-8_*_
import torch
from torchvision import models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
import numpy as np

if __name__ == '__main__':

    # (1) 加载 WideResNet/PyTorch 预训练模型
    # torch.hub 方式加载 load WRN-50-2:
    model = torch.hub.load('pytorch/vision:v0.10.0', 'wide_resnet50_2', pretrained=True)
    model.eval()

    # (2) 定义输入图像的预处理变换,将 [0,1] 的 PILImage 转换为 [-1,1] 的Tensor
    transform = transforms.Compose([  # 定义图像变换组合
        transforms.Resize([256,256]),  # 图像大小调整为 (w,h)=(256,256)
        transforms.CenterCrop([224,224]),  # 图像中心裁剪为 (w,h)=(224,224)
        transforms.ToTensor(),  # 将图像转换为张量 Tensor
        transforms.Normalize(  # 对图像进行归一化
            mean=[0.485, 0.456, 0.406],  # 均值
            std=[0.229, 0.224, 0.225]  # 标准差
        )])

    # (3) 加载输入图像并进行预处理
    from PIL import Image
    filePath = "../images/img_car_01.jpg"  # 数据文件的地址和文件名
    imgPIL = Image.open(filePath)  # PIL 读取图像文件, <class 'PIL.Image.Image'>
    # 预处理/模型推理/后处理
    imgTrans = transform(imgPIL)  # 预处理变换, torch.Size([3,224,224])
    input_batch = torch.unsqueeze(imgTrans, 0)  # 转为批处理,torch.Size([batch=1,3,224,224])

    # (4) 模型推理
    with torch.no_grad():
        outputs = model(input_batch)  # 返回所有类别的置信度score,torch.Size([batch, 1000])
    # _, index = torch.max(outputs, 1)  # Top-1 类别的索引,tensor([208])
    # print("index: ", index.item())  # 208 : sports car, sport car

    # (5) 模型输出后处理
    # 读取 ImageNet 文本格式类别名称文件
    with open("../dataset/imagenet_classes.txt") as f:  # 类别名称保存为 txt 文件
        categories = [line.strip() for line in f.readlines()]
    print(type(categories), len(categories))  # <class 'list'> 1000

    # 计算所有类别的概率
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0) * 100  # 所有类别的概率,torch.Size([batch, 1000])
    # 查找 Top-5 类别的索引
    top5_prob, top5_idx = torch.topk(probabilities, 5)  # Top-5 类别的概率和索引, torch.Size([5])
    print("Top-5 possible categories:")
    for i in range(top5_prob.size(0)):
        print(top5_idx[i], categories[top5_idx[i]], top5_prob[i].item())

    # (6) 图像分类结果的可视化
    import cv2
    imgCV = cv2.cvtColor(np.asarray(imgPIL), cv2.COLOR_RGB2BGR)  # PIL 转换为 CV 格式
    out_text = f"{
      
      categories[top5_idx[0]]}, {
      
      top5_prob[0].item():.3f}"  # 类别标签 + 概率
    cv2.putText(imgCV, out_text, (25, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)  # 在图像上添加类别标签
    cv2.imshow('Image classification', imgCV)
    key = cv2.waitKey(0)  # delay=0, 不自动关闭
    cv2.destroyAllWindows()

结果如下。

<class 'list'> 1000
Top-5 possible categories:
tensor(436) 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', 49.345211029052734
tensor(656) 656: 'minivan', 31.259532928466797
tensor(581) 581: 'grille, radiator grille', 12.88940715789795
tensor(479) 479: 'car wheel', 2.3268837928771973
tensor(627) 627: 'limousine, limo', 2.073991060256958

在这里插入图片描述


参考文献

  1. Sergey Zagoruyko, Nikos Komodakis, Wide Residual Networks, 2016

  2. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, Deep Residual Learning for Image Recognition, 2015

  3. Wide ResNet model |PyTorch


【本节完】


版权声明:
欢迎关注『youcans动手学模型』系列
转发请注明原文链接:
【youcans动手学模型】Wide ResNet 模型
Copyright 2023 youcans, XUPT
Crated:2023-07-02


猜你喜欢

转载自blog.csdn.net/youcans/article/details/131497790
今日推荐