Load PyTorch pre-trained models (VGG, ResNet, AlexNet, etc.) offline or online

PyTorch (through torchvision) provides pre-trained weights for several classic networks, such as the ResNet series, the VGG series, and AlexNet. Starting from pre-trained weights improves the network's ability to extract features and usually improves the performance of the trained model. There are two ways to load a pre-trained model.
The first is the online method, where the code downloads the weights automatically:

import torch
from torchvision import models

model = models.vgg16(pretrained=True)

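With this approach, when the model is constructed, torchvision builds the network from its definition, downloads the pre-trained weights from their URL, caches the file (by default under ~/.cache/torch/hub/checkpoints; the exact location depends on the PyTorch version and on the TORCH_HOME environment variable), and loads the parameters into the model. Later runs reuse the cached file instead of downloading it again.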
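A quick way to see where the file is cached on your own machine is to ask torch.hub for its directory; downloaded weights go into its checkpoints subfolder. A small sketch, where checkpoint_dir is a hypothetical helper name:

```python
import os

import torch


def checkpoint_dir() -> str:
    """Hypothetical helper: directory where torch.hub caches downloaded weights."""
    # torch.hub.get_dir() honours torch.hub.set_dir(), $TORCH_HOME and
    # $XDG_CACHE_HOME; on Linux the default is ~/.cache/torch/hub
    return os.path.join(torch.hub.get_dir(), "checkpoints")


print(checkpoint_dir())
```

If TORCH_HOME is set before Python starts, the printed path moves accordingly.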

The other method is offline loading, which requires downloading the pre-trained weights in advance. To find the download address, visit:
https://github.com/pytorch/vision/tree/master/torchvision/models
This folder contains the definition of each network. Open the .py file of the corresponding network to find the URL of its pre-trained weights. For example, opening vgg.py for the VGG networks shows:

import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
from typing import Union, List, Dict, Any, cast


__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class VGG(nn.Module):

    def __init__(
        self,
        features: nn.Module,
        num_classes: int = 1000,
        init_weights: bool = True
    ) -> None:
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

The download addresses of the various VGG models are listed in model_urls. To download one, copy its URL and open it in a browser. For vgg16, for example, open https://download.pytorch.org/models/vgg16-397923af.pth and the browser will start downloading the file. When the download finishes, put the file in the directory you run the script from (the current working directory), keeping its original filename. Then modify the loading code as follows:
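Because the file is downloaded by hand in offline mode, nothing checks it for corruption. By torch.hub's naming convention, the part of the filename after the dash (397923af in vgg16-397923af.pth) is the first hex digits of the file's SHA-256 digest, so a manual download can be checked against it. A minimal sketch using only the standard library; verify_checkpoint is a hypothetical helper name:

```python
import hashlib
import re
from pathlib import Path

# Same pattern torch.hub uses to extract the hash prefix from a filename
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")


def verify_checkpoint(path: str, chunk_size: int = 1 << 20) -> bool:
    """Hypothetical helper: True if the file's SHA-256 digest starts with
    the hash prefix embedded in its name (e.g. vgg16-397923af.pth)."""
    match = HASH_REGEX.search(Path(path).name)
    if match is None or not match.group(1):
        raise ValueError(f"no hash prefix in filename: {path}")
    prefix = match.group(1)
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in chunks so a large checkpoint is never held in memory at once
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)
    return sha256.hexdigest().startswith(prefix)
```

For example, verify_checkpoint("vgg16-397923af.pth") should return True for a complete download; False suggests the file is truncated or corrupted and should be downloaded again.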

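import torch
from torchvision import models

model = models.vgg16(pretrained=False)  # online mode used True; set it to False here
pre = torch.load('vgg16-397923af.pth')  # load the manually downloaded weights file
model.load_state_dict(pre)  # copy the weights into the network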
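An alternative that leaves the loading code unchanged: torch.hub only downloads a checkpoint when a file with the same name is not already in its checkpoints directory, so copying the manually downloaded file there (keeping the original filename) lets models.vgg16(pretrained=True) run on an offline machine. A sketch with the hypothetical helper install_checkpoint:

```python
import os
import shutil

import torch


def install_checkpoint(local_file: str) -> str:
    """Hypothetical helper: copy a manually downloaded .pth file into the
    torch.hub checkpoint cache so torchvision finds it without downloading.

    The filename must stay exactly as in the URL (e.g. vgg16-397923af.pth),
    because the cache lookup is done by filename.
    """
    ckpt_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    dst = os.path.join(ckpt_dir, os.path.basename(local_file))
    # Skip the copy if a file with this name is already cached
    if not os.path.exists(dst):
        shutil.copyfile(local_file, dst)
    return dst
```

Usage: call install_checkpoint("vgg16-397923af.pth") once, after which models.vgg16(pretrained=True) should load the cached copy instead of going to the network.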

These are the two ways to load pre-trained models: online and offline.


Origin blog.csdn.net/qq_44442727/article/details/112972992