pytorch预训练模型包含多个经典网络,比如resnet系列、vgg系列和alexnet等,预训练模型可以提高网络提取特征的能力,提升训练模型的性能。下面介绍一下加载预训练模型的两种方式:
第一种是在线的方法,即在代码中采用在线加载模式,
import torch
from torchvision import models
model = models.vgg16(pretrained=True)
这样当代码运行到model时,就会根据pytorch中模型的定义找到该模型,并通过url加载预训练模型放在./cache/checkpoints中,需要时就会加载模型参数。
另一种方法是离线加载方式,这需要提前下载好预训练模型,预训练模型的下载可以进入该网站:
https://github.com/pytorch/vision/tree/master/torchvision/models
里面有多个网络的定义,进入相应网络的py文件即可找到加载与效率模型的网址:
比如vgg网络的vgg.py文件打开后可以看到
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
from typing import Union, List, Dict, Any, cast
__all__ = [
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19_bn', 'vgg19',
]
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
class VGG(nn.Module):
def __init__(
self,
features: nn.Module,
num_classes: int = 1000,
init_weights: bool = True
) -> None:
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
model_urls中注明了多种vgg模型的下载地址,具体下载方法是直接复制网址打开,比如vgg16,复制网址https://download.pytorch.org/models/vgg16-397923af.pth在浏览器中打开,就会自动打开下载界面进行下载,下载好之后放在你当前要运行的文件夹下,其中加载的程序要做一些修改,如下:
....
model = models.vgg16(pretrained=False)#在线模式的True改为False
pre = torch.load('vgg16-397923af.pth')#进行加载
model.load_state_dict(pre)
....
以上就是在线和离线加载预训练模型的两种方法。