[pytorch、学习] - 5.7 使用重复元素的网络(VGG)

参考

5.7 使用重复元素的网络(VGG)

AlexNet在LeNet的基础上增加了3个卷积层。但AlexNet作者对它们的卷积窗口、输出通道数和构造顺序均做了大量的调整。虽然AlexNet指明了深度卷积神经网络可以取得出色的结果,但并没有提供简单的规则以指导后来的研究者如何设计新的网络。我们将在本章的后续几节里介绍几种不同的深度网络设计思路。

下面介绍VGG

5.7.1 VGG块

VGG块的组成规律是:连续使用数个相同的填充为1、窗口形状为3×3的卷积层后接上一个步幅为2、窗口形状为2×2的最大池化层。卷积层保持输入的高和宽不变,而池化层则对其减半。我们使用vgg_block函数来实现这个基础的VGG块,它可以指定卷积层的数量和输入输出通道数。

import time
import torch
from torch import nn, optim

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def vgg_block(num_convs, in_channels, out_channels):
    blk = []
    for i in range(num_convs):
        if i == 0:
            blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        else:
            blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        blk.append(nn.ReLU())
    blk.append(nn.MaxPool2d(kernel_size=2, stride=2)) # 这里会使宽高减半
    return nn.Sequential(*blk)

5.7.2 VGG网络

与AlexNet和LeNet一样,VGG网络由卷积层模块后接全连接层模块构成。卷积层模块串联数个vgg_block,其超参数由变量conv_arch定义。该变量指定了每个VGG块里卷积层个数和输入输出通道数。全连接模块则跟AlexNet中的一样。

现在我们构造一个VGG网络。它有5个卷积块,前2块使用单卷积层,而后3块使用双卷积层。第一块的输入输出通道分别是1和64,之后每次对输出通道数翻倍,直到变为512。因为这个网络使用了8个卷积层和3个全连接层,所以经常被称为VGG-11。

conv_arch = ((1, 1, 64), (1, 64, 128), (2, 128, 256), (2, 256, 512), (2, 512, 512))
# 经过5个 vgg_block, 宽高会减半5次, 变成 224/32 = 7
fc_features = 512 * 7 * 7
fc_hidden_units = 4096 # 任意

下面实现VGG-11

def vgg(conv_arch, fc_features, fc_hidden_units=4096):
    net = nn.Sequential()
    # 卷积层部分
    # conv_arch: ((1,1,64),(1,64,128),(2,128,256),(2,256,512),(2,512,512))
    for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch):
        # 每经过一个vgg_block都会使宽高减半
        """
            (1,1,64):
                - 0: nn.Conv2d(1, 64, kernel_size=3, padding=1)  # (1, 1, 224, 224) -> (1, 64, 224, 224)
                nn.MaxPool2d(kernel_size=2, stride=2)  # (1, 64, 224, 224) -> (1, 64, 112, 112)
            (1,64,128):
                - 0: nn.Conv2d(64, 128, kernel_size=3, padding=1)  # (1, 64, 112, 112) -> (1, 128, 112, 112)
                nn.MaxPool2d(kernel_size=2, stride=2)  #  (1, 128, 112, 112) -> (1, 128, 56, 56)
            (2,128,256):
                - 0: nn.Conv2d(128, 256, kernel_size=3, padding=1)  #  (1, 128, 56, 56) -> (1, 256, 56, 56)
                - 1: nn.Conv2d(256, 256, kernel_size=3, padding=1)
                nn.MaxPool2d(kernel_size=2, stride=2)  # (1, 256, 56, 56) -> (1, 256, 28, 28)
            (2,256,512):
                - 0: nn.Conv2d(256, 512, kernel_size=3, padding=1)  # (1, 256, 28, 28) -> (1, 512, 28, 28)
                - 1: nn.Conv2d(512, 512, kernel_size=3, padding=1)
                nn.MaxPool2d(kernel_size=2, stride=2)  # (1, 512, 28, 28) -> (1, 512, 14, 14)
            (2,512,512):
                - 0: nn.Conv2d(512, 512, kernel_size=3, padding=1)
                - 1: nn.Conv2d(512, 512, kernel_size=3, padding=1)
                nn.MaxPool2d(kernel_size=2, stride=2)  # (1, 512, 14, 14) -> (1, 512, 7, 7)
        """
        net.add_module("vgg_block_" + str(i+1), vgg_block(num_convs, in_channels, out_channels))

    # 全连接层部分
    net.add_module("fc", nn.Sequential(d2l.FlattenLayer(),
                                       nn.Linear(fc_features, fc_hidden_units),
                                       nn.ReLU(),
                                       nn.Dropout(0.5),
                                       nn.Linear(fc_hidden_units, fc_hidden_units),
                                       nn.ReLU(),
                                       nn.Dropout(0.5),
                                       nn.Linear(fc_hidden_units, 10)
                                      ))
    return net
net = vgg(conv_arch, fc_features, fc_hidden_units)
print(net)

# X = torch.rand(1, 1, 224, 224)

# for name, blk in net.named_children():
#     X = blk(X)
#     print(name, "output shape: ", X.shape)

在这里插入图片描述

5.7.3 获取数据和训练模型

ratio = 8
small_conv_arch = [(1, 1, 64//ratio), (1, 64//ratio, 128//ratio), (2, 128//ratio, 256//ratio), 
                   (2, 256//ratio, 512//ratio), (2, 512//ratio, 512//ratio)]
net = vgg(small_conv_arch, fc_features // ratio, fc_hidden_units // ratio)

print(net)

在这里插入图片描述

batch_size = 64
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/piano9425/article/details/107200015
5.7