pytorchは事前学習済みモデルのいくつかのパラメーターをロードします

resnet152 = models.resnet152(pretrained = True)
pretrained_dict = resnet152.state_dict()
"" "事前トレーニング済みのモデルとパラメーターをtorchvisionに読み込んだ後、state_dict()メソッドを使用してパラメーターを抽出します。
公式のmodel_zooから直接ダウンロードすることもできます:
pretrained_dict = model_zoo.load_url (model_urls ['resnet152']) "" "
model_dict = model.state_dict()

model_dictの一部ではないpretrained_dictのキーを削除します

pretrained_dict = {k:vはk、vはpretrained_dict.items()でkがmodel_dictの場合}

既存のmodel_dictを更新する

model_dict.update(pretrained_dict)

本当に必要なstate_dictをロードします

model.load_state_dict(model_dict)
————————————————
Copyright Notice:この記事はCSDNブロガー "lscelory"によるオリジナルの記事であり、CC 4.0 BY-SAの著作権契約に従っています。再印刷のために添付してください元のソースリンクとこのステートメント。
元のリンク:https://blog.csdn.net/lscelory/article/details/81482586

元の記事を36件公開しました 賞賛されました1 訪問6384

おすすめ

転載: blog.csdn.net/qq_34291583/article/details/100539275