PyTorch Source Code Walkthrough: torchvision.models


        <div class="article-info-box">
            <div class="article-bar-top d-flex">
                                                                            <span class="time">2018-01-21 13:28:35</span>
            </div>
        </div>

    </div>
</div>
<article>
    <div id="article_content" class="article_content clearfix csdn-tracking-statistics" data-pid="blog" data-mod="popu_307" data-dsm="post">
                <div class="markdown_views">
            <p><strong>The PyTorch ecosystem includes an important and very handy package, torchvision, which consists mainly of three subpackages: torchvision.datasets, torchvision.models, and torchvision.transforms</strong>. For details on these three subpackages, see the official documentation: <a href="http://pytorch.org/docs/master/torchvision/index.html" rel="nofollow" target="_blank">http://pytorch.org/docs/master/torchvision/index.html</a>. The source code is on GitHub: <a href="https://github.com/pytorch/vision/tree/master/torchvision" rel="nofollow" target="_blank">https://github.com/pytorch/vision/tree/master/torchvision</a>.</p>

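This post covers torchvision.models. The torchvision.models package contains commonly used network architectures such as alexnet, densenet, inception, resnet, squeezenet, and vgg. It also provides pre-trained models, so both the architecture and the pre-trained weights can be loaded with a simple call.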

Usage example:

import torchvision
model = torchvision.models.resnet50(pretrained=True)
     
     

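This loads the pre-trained ResNet-50 model. If you only need the network architecture and do not want to initialize it with the pre-trained parameters, use: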

model = torchvision.models.resnet50(pretrained=False)
     
     

Loading a densenet model works the same way. For example, to load densenet169 without the pre-trained weights:

model = torchvision.models.densenet169(pretrained=False)
     
     

Because the pretrained parameter defaults to False, this is equivalent to:

model = torchvision.models.densenet169()
     
     

For clarity, though, it is better to pass the argument explicitly.

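Next, taking resnet50 as the example, let's look at what the source code does when a model is imported. The call model = torchvision.models.resnet50(pretrained=True) is handled by the resnet.py script in the models package. The source is as follows.

First come the necessary imports; model_zoo is the module used for loading pre-trained models. The __all__ variable defines the function and class names that can be imported from outside the module, which is why the function can be called as torchvision.models.resnet50(). The model_urls dictionary holds the download URLs of the pre-trained models.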

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
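
To make the role of `__all__` concrete, here is a small sketch that uses only the standard library. It builds a throwaway in-memory module (the name `toy_resnet` and its functions are hypothetical) and performs a star import on it. This is, as far as I can tell, what the `torchvision/models/__init__.py` of that era did with `from .resnet import *`; treat that detail as an assumption.

```python
import sys
import types

# Source of a throwaway module; only resnet50 is listed in __all__.
source = '''
__all__ = ['resnet50']

def resnet50():
    return 'resnet50 constructor'

def helper():
    return 'not listed in __all__'
'''

toy = types.ModuleType("toy_resnet")
exec(source, toy.__dict__)
sys.modules["toy_resnet"] = toy  # make it importable by name

# Simulate a package __init__.py doing a star import.
namespace = {}
exec("from toy_resnet import *", namespace)

# Ignore dunder entries such as __builtins__ that exec adds.
exported = sorted(name for name in namespace if not name.startswith("__"))
print(exported)
```

Only the names listed in `__all__` reach the importing namespace; `helper` stays reachable as `toy.helper` but is not re-exported.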
     
     

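Next is the resnet50 function, whose pretrained parameter defaults to False. First, model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) builds the network structure. Bottleneck is a separate class that builds a bottleneck unit: the ResNet architecture contains many repeated substructures, and these are built by the Bottleneck class, described later. Then, if pretrained is True, the load_url function in model_zoo.py downloads or loads the corresponding pre-trained model according to the model_urls dictionary. Finally, the model's load_state_dict method initializes the network you built with the pre-trained parameters; this is PyTorch's general mechanism for initializing the layers of one model with the parameters of another. load_state_dict also has an important parameter, strict, which defaults to True and means that the layers of the pre-trained model must correspond exactly to the layers of your network (for example, in layer names and dimensions).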

def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model
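
The behavior of the strict parameter can be illustrated with a toy model. This is a minimal sketch, assuming a reasonably recent PyTorch in which load_state_dict returns an object with a missing_keys field; the class `Small` and its layer names are hypothetical and stand in for a real network.

```python
import torch
import torch.nn as nn


class Small(nn.Module):
    """Hypothetical two-layer model used only for this illustration."""

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Linear(8, 4)
        self.fc = nn.Linear(4, num_classes)


source = Small(num_classes=10)
target = Small(num_classes=10)

# Simulate a checkpoint that lacks the classifier weights.
partial = {k: v for k, v in source.state_dict().items()
           if not k.startswith("fc.")}

# strict=True (the default) rejects a state dict with missing keys.
try:
    target.load_state_dict(partial)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# strict=False loads what matches and reports the rest.
result = target.load_state_dict(partial, strict=False)
missing = sorted(result.missing_keys)

print(strict_failed, missing)
print(torch.equal(target.features.weight, source.features.weight))
```

This is the usual route when reusing a pre-trained backbone with a different classifier: load with strict=False and train the layers reported as missing.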
     
     

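The other functions, such as resnet18 and resnet101, are much the same as resnet50. The main differences are: (1) the block counts passed when building the network differ, for example [2, 2, 2, 2] for resnet18 and [3, 4, 23, 3] for resnet101; (2) the block class differs: resnet50, resnet101, and resnet152 use the Bottleneck class, while resnet18 and resnet34 use the BasicBlock class. The two classes differ mainly in the number of convolutional layers in each residual unit, which follows from the network architecture and is covered in detail later; (3) when a pre-trained model is downloaded, the key into the model_urls dictionary differs, selecting a different pre-trained model. The following sections therefore look at how the network is built and how the pre-trained model is loaded.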

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
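
The numbers in the function names follow directly from these block lists. The arithmetic below is a small worked sketch: it counts the convolutions inside the residual units plus the stem conv1 and the final fc layer, and it leaves out the 1x1 downsample convolutions on the shortcut path. The lists for resnet34 and resnet152 do not appear in this post and are assumed from the upstream resnet.py.

```python
# (convolutions per residual unit, residual units per stage)
CONFIGS = {
    "resnet18": (2, [2, 2, 2, 2]),    # BasicBlock
    "resnet34": (2, [3, 4, 6, 3]),    # BasicBlock (assumed from upstream)
    "resnet50": (3, [3, 4, 6, 3]),    # Bottleneck
    "resnet101": (3, [3, 4, 23, 3]),  # Bottleneck
    "resnet152": (3, [3, 8, 36, 3]),  # Bottleneck (assumed from upstream)
}


def weighted_depth(convs_per_unit, units_per_stage):
    # +2 accounts for the stem conv1 and the final fc layer.
    return convs_per_unit * sum(units_per_stage) + 2


depths = {name: weighted_depth(*cfg) for name, cfg in CONFIGS.items()}
for name, depth in depths.items():
    print(name, depth)
```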
     
     

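The ResNet network is built by the ResNet class. As usual, it inherits from torch.nn.Module, the base class of all networks in PyTorch, and mainly overrides the __init__ and forward methods. __init__ mainly defines the layers and their parameters, while forward defines how data flows between the layers, that is, the order in which the layers are connected. A class can also define other private methods to modularize certain operations; here, the _make_layer method builds the four stages of the ResNet (layer1 to layer4). The first argument of _make_layer, block, is the Bottleneck or BasicBlock class; the second, planes, is the stage's base channel count (its output has planes * block.expansion channels); the third, blocks, is the number of residual units in the stage, so the layers list is the [3, 4, 6, 3] passed in by resnet50 above.

The two most important lines in _make_layer are: (1) layers.append(block(self.inplanes, planes, stride, downsample)), which stores the first residual unit of the stage in the layers list; (2) for i in range(1, blocks): layers.append(block(self.inplanes, planes)), which stores the remaining residual units of the stage in the layers list and so completes the stage. In resnet50, both lines build each residual unit through the Bottleneck class, which is described next.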

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
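
The channel bookkeeping in _make_layer is easier to follow when traced without building any layers. The sketch below mirrors only the arithmetic of the listing above (the helper names `trace_make_layer` and `trace_resnet` are hypothetical) and records, for each residual unit, its input channels, output channels, stride, and whether a downsample branch is created.

```python
def trace_make_layer(inplanes, planes, blocks, stride, expansion):
    """Mirror the channel bookkeeping of ResNet._make_layer."""
    out_channels = planes * expansion
    # Same condition as in the listing: a shortcut projection is needed
    # when the spatial size or the channel count changes.
    needs_downsample = stride != 1 or inplanes != out_channels
    units = [(inplanes, out_channels, stride, needs_downsample)]
    for _ in range(1, blocks):
        # Remaining units keep stride 1 and have matching channels.
        units.append((out_channels, out_channels, 1, False))
    return out_channels, units


def trace_resnet(expansion, layers):
    inplanes = 64  # channels produced by the stem (conv1)
    stages = []
    for planes, blocks, stride in zip([64, 128, 256, 512], layers, [1, 2, 2, 2]):
        inplanes, units = trace_make_layer(inplanes, planes, blocks,
                                           stride, expansion)
        stages.append(units)
    return inplanes, stages


fc_inputs_50, stages_50 = trace_resnet(expansion=4, layers=[3, 4, 6, 3])
fc_inputs_18, stages_18 = trace_resnet(expansion=1, layers=[2, 2, 2, 2])

for index, units in enumerate(stages_50, start=1):
    print("layer%d first unit (in, out, stride, downsample): %s"
          % (index, units[0]))
print("fc input features:", fc_inputs_50, fc_inputs_18)
```

The final channel count equals 512 * block.expansion, which is exactly the input size given to self.fc in __init__. With BasicBlock (expansion 1), layer1 needs no downsample branch because both the stride and the channel count are unchanged.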
     
     

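As the ResNet class above shows, the most important piece when constructing a ResNet is the Bottleneck class: ResNet is composed of residual units, and Bottleneck is what builds each one. Like ResNet, Bottleneck inherits from torch.nn.Module and overrides the __init__ and forward methods. The forward method shows that a Bottleneck consists of the familiar three convolutional layers, each followed by a BN layer, with ReLU activations; the final out += residual is the element-wise addition.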

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
     
     

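The BasicBlock class is similar to Bottleneck. It is mainly used to build the ResNet18 and ResNet34 networks, whose residual units contain only two convolutional layers and have no bottleneck as in the Bottleneck class. In this class, therefore, the first convolutional layer is a kernel_size = 3 convolution, as shown in the conv3x3 function.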

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
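
The practical effect of the bottleneck can be seen by counting convolution weights. The following is a worked arithmetic sketch, not a measurement: it counts only the bias-free convolution weights of one residual unit, ignores the BN parameters and any downsample branch, and the helper names are hypothetical.

```python
def conv_weights(in_channels, out_channels, kernel_size):
    # Weight count of a bias-free nn.Conv2d: in * out * k * k.
    return in_channels * out_channels * kernel_size * kernel_size


def basic_block_weights(inplanes, planes):
    # conv3x3(inplanes, planes) followed by conv3x3(planes, planes)
    return conv_weights(inplanes, planes, 3) + conv_weights(planes, planes, 3)


def bottleneck_weights(inplanes, planes, expansion=4):
    # 1x1 reduce, 3x3 at the reduced width, 1x1 expand
    return (conv_weights(inplanes, planes, 1)
            + conv_weights(planes, planes, 3)
            + conv_weights(planes, planes * expansion, 1))


# Both units map 256 channels to 256 channels.
basic_256 = basic_block_weights(256, 256)
bottleneck_256 = bottleneck_weights(256, 64)

print(basic_256, bottleneck_256)
```

At a width of 256 channels, the three-layer bottleneck unit uses roughly one seventeenth of the convolution weights of two plain 3x3 layers, which is what makes the deeper Bottleneck-based networks affordable.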
     
     

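With the network construction covered, the next step is obtaining the pre-trained model. The line mentioned earlier, if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])), uses the load_url function in model_zoo.py to load the pre-trained model selected from the model_urls dictionary. The model_zoo.py script is on GitHub: https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py

The source of load_url is shown below. First, model_dir is the directory in which downloaded models are saved; if it is not specified, the code falls back to $TORCH_MODEL_ZOO, or else to the models directory under $TORCH_HOME (~/.torch by default). The statement if not os.path.exists(cached_file) checks whether the file to download already exists in that directory. If it exists, torch.load is called directly to load the model; if not, the file is first downloaded via _download_url_to_file(url, cached_file, hash_prefix, progress=progress), which is not covered in detail here. The key point is that the model is always loaded through torch.load(), whether it was downloaded from the web or was already present locally.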

def load_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file.

    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr

    Example:
        >>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # HASH_REGEX and _download_url_to_file are defined elsewhere in model_zoo.py.
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename).group(1)
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return torch.load(cached_file, map_location=map_location)
            </div>
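
The path and hash handling in load_url can be tried out without downloading anything. The sketch below uses only the standard library and reproduces just the bookkeeping: where the cached file would live and which hash prefix is taken from the filename. The regular expression is assumed to mirror the HASH_REGEX defined in model_zoo.py, and the helper names `resolve_cache_file` and `matches_prefix` are hypothetical.

```python
import hashlib
import os
import re
from urllib.parse import urlparse

# Assumed to mirror the HASH_REGEX defined in model_zoo.py.
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')


def resolve_cache_file(url, model_dir=None, env=None):
    """Return (cached_file, hash_prefix) the way load_url derives them."""
    env = os.environ if env is None else env
    if model_dir is None:
        torch_home = os.path.expanduser(env.get('TORCH_HOME', '~/.torch'))
        model_dir = env.get('TORCH_MODEL_ZOO',
                            os.path.join(torch_home, 'models'))
    filename = os.path.basename(urlparse(url).path)
    hash_prefix = HASH_REGEX.search(filename).group(1)
    return os.path.join(model_dir, filename), hash_prefix


def matches_prefix(payload, hash_prefix):
    """Check file contents against the SHA256 prefix from the filename."""
    return hashlib.sha256(payload).hexdigest().startswith(hash_prefix)


url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
cached_file, hash_prefix = resolve_cache_file(
    url, env={'TORCH_HOME': '/tmp/torch_home'})
print(cached_file)
print(hash_prefix)

# Stand-in bytes instead of a real download.
payload = b'stand-in for downloaded bytes'
good_prefix = hashlib.sha256(payload).hexdigest()[:8]
print(matches_prefix(payload, good_prefix))
```

Passing model_dir explicitly bypasses both environment variables, which is handy when the weights should live next to the project rather than in the home directory.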
        <link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/markdown_views-ea0013b516.css">
            </div>
        </article>

    <div class="article-bar-bottom" style="height: 48px; overflow: hidden;">
            <div class="article-copyright">
        Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission.          https://blog.csdn.net/u014380165/article/details/79119664       </div>
</div>
</div>


Reposted from blog.csdn.net/Jason_mmt/article/details/81910346