pytorchの事前トレーニング済みモデルvgg、resnet、alexnetなどをオフラインまたはオンラインでロードします

pytorch事前トレーニングモデルには、resnetシリーズ、vggシリーズ、alexnetなどの複数のクラシックネットワークが含まれています。事前トレーニングモデルは、ネットワークの機能を抽出する能力を向上させ、トレーニングモデルのパフォーマンスを向上させることができます。事前にトレーニングされたモデルをロードする2つの方法を次に示します。1つ
はオンラインメソッドです。つまり、コードでオンラインロードモードが使用されます。

import torch
from torchvision import models

model = models.vgg16(pretrained=True)

このように、コードがモデルに対して実行されると、モデルはpytorchでのモデルの定義に従って検出され、事前にトレーニングされたモデルがURLを介して読み込まれ、。/ cache / checkpointsに配置されます。モデルパラメータは必要に応じてロードされます。

もう1つの方法は、オフラインモードをロードすることです。これには、事前トレーニングモデルを事前にダウンロードする必要があります。事前トレーニングモデルをダウンロードすると、次のWebサイトにアクセスできます
。https //github.com/pytorch/vision/tree/master/torchvision / models
how insideネットワークの定義については、対応するネットワークのpyファイルを入力して、読み込みおよび効率モデルのURLを見つけます。
ここに画像の説明を挿入します
たとえば、開いた後、vggネットワークのvgg.pyファイルを確認できます。

import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
from typing import Union, List, Dict, Any, cast


__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


model_urls = {
    
    
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class VGG(nn.Module):

    def __init__(
        self,
        features: nn.Module,
        num_classes: int = 1000,
        init_weights: bool = True
    ) -> None:
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

さまざまなvggモデルのダウンロードアドレスはmodel_urlsに示されています。具体的なダウンロード方法は、URLを直接コピーして、vgg16などで開くことです。URLをコピーするhttps://download.pytorch.org/models/vgg16-397923af.pthダウンロード用のダウンロードインターフェイスが自動的に開きます。ダウンロードが完了したら、現在実行しているフォルダに配置します。ロードしたプログラムを次のように変更する必要があります。

....
model = models.vgg16(pretrained=False)#在线模式的True改为False
pre = torch.load('vgg16-397923af.pth')#进行加载
model.load_state_dict(pre)
....

上記は、事前にトレーニングされたモデルをオンラインとオフラインでロードする2つの方法です。

おすすめ

転載: blog.csdn.net/qq_44442727/article/details/112972992