pytorch事前トレーニングモデルには、resnetシリーズ、vggシリーズ、alexnetなどの複数のクラシックネットワークが含まれています。事前トレーニングモデルは、ネットワークの機能を抽出する能力を向上させ、トレーニングモデルのパフォーマンスを向上させることができます。事前にトレーニングされたモデルをロードする2つの方法を次に示します。1つ
はオンラインメソッドです。つまり、コードでオンラインロードモードが使用されます。
import torch
from torchvision import models
model = models.vgg16(pretrained=True)
このように、コードがモデルに対して実行されると、モデルはpytorchでのモデルの定義に従って検出され、事前にトレーニングされたモデルがURLを介して読み込まれ、。/ cache / checkpointsに配置されます。モデルパラメータは必要に応じてロードされます。
もう1つの方法は、オフラインモードをロードすることです。これには、事前トレーニングモデルを事前にダウンロードする必要があります。事前トレーニングモデルをダウンロードすると、次のWebサイトにアクセスできます
。https ://github.com/pytorch/vision/tree/master/torchvision / models
how insideネットワークの定義については、対応するネットワークのpyファイルを入力して、読み込みおよび効率モデルのURLを見つけます。
たとえば、開いた後、vggネットワークのvgg.pyファイルを確認できます。
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
from typing import Union, List, Dict, Any, cast
__all__ = [
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19_bn', 'vgg19',
]
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
class VGG(nn.Module):
def __init__(
self,
features: nn.Module,
num_classes: int = 1000,
init_weights: bool = True
) -> None:
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
さまざまなvggモデルのダウンロードアドレスはmodel_urlsに示されています。具体的なダウンロード方法は、URLを直接コピーして、vgg16などで開くことです。URLをコピーするhttps://download.pytorch.org/models/vgg16-397923af.pthダウンロード用のダウンロードインターフェイスが自動的に開きます。ダウンロードが完了したら、現在実行しているフォルダに配置します。ロードしたプログラムを次のように変更する必要があります。
....
model = models.vgg16(pretrained=False)#在线模式的True改为False
pre = torch.load('vgg16-397923af.pth')#进行加载
model.load_state_dict(pre)
....
上記は、事前にトレーニングされたモデルをオンラインとオフラインでロードする2つの方法です。