【1】pytorch torchvision源码解读之Alexnet

最近开始学习一个新的深度学习框架PyTorch。

框架中有一个非常重要且好用的包:torchvision,顾名思义这个包主要是关于计算机视觉cv的。这个包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms。

具体介绍可以参考官网:https://pytorch.org/docs/master/torchvision

具体代码可以参考github:https://github.com/pytorch/vision

torchvision.models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用经典的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。

今天我们来解读一下Alexnet的源码实现。如果对AlexNet不是很了解 可以查看这里的论文笔记https://blog.csdn.net/sinat_33487968/article/details/83543406

如何使用呢?

import torchvision
model = torchvision.models.Alexnet(pretrained=True)

这样就可以获得网络的结构了,pretrained参数的意思是是否预训练,如果为True就会从网上下载好已经训练参数的模型。改参数默认是False。

import torch.utils.model_zoo as model_zoo

__all__ = ['AlexNet', 'alexnet']


model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}

首先是导入必要的库,其中model_zoo是和导入预训练模型相关的包,另外all变量定义了可以从外部import的函数名或类名。这也是前面为什么可以用torchvision.models.alexnet()来调用的原因。model_urls这个字典是预训练模型的下载地址。

接下来就是Alexnet这个类

class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),  # inplace为True,将会改变输入的数据 ,否则不会改变原输入,只会产生新的输出
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifer = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.feature(x)
        x = x.view(x.size(0), 256 * 6 * 6)  # reshape
        x = self.classifer(x)
        return x

AlexNet网络是通过AlexNet这个类实例化的。首先还是继承PyTorch中网络的基类:torch.nn.Module,其次主要的是重写初始化__init__和forward方法。在初始化__init__中主要是定义一些层的参数。forward方法中主要是定义数据在层之间的流动顺序,也就是层的连接顺序。基本上就是五层卷积加上三层全连接(不算relu和max max pooling)。注意到ReLU的inplace为True,将会改变输入的数据 ,否则不会改变原输入,只会产生新的输出。而 x = x.view(x.size(0), 256 * 6 * 6)  的意思是reshape卷积层得到的结果,为了匹配后面的全连接层。

具体结构可以参照下图:

图片一

最后呈现上源码

import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['Alexnet', 'alexnet']

model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),  # inplace为True,将会改变输入的数据 ,否则不会改变原输入,只会产生新的输出
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifer = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.feature(x)
        x = x.view(x.size(0), 256 * 6 * 6)  # reshape
        x = self.classifer(x)
        return x

def alexnet(pretrained = False,**kwargs):
    r"""AlexNet model architecture from the
    "One werid trick..."<https://arxiv.org/abs/1404.5997>_papper.
    Args:
        pretrained(bool):if True,returns a model pre-trained on ImagetNet
    """
    model = AlexNet(**kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
    return model


if __name__ == '__main__':
    alexnet()

猜你喜欢

转载自blog.csdn.net/sinat_33487968/article/details/83582299
今日推荐