pytorchは、すでに保存されているモデルファイルをロードし、それを別のネットワークの事前トレーニング済みウェイトとして使用します

pytorchは、すでに保存されているモデルファイルをロードし、それを別のネットワークの事前トレーニング済みウェイトとして使用します


問題の説明

(自分でメモをとるために)以前にトレーニングしたモデルを保存しましたが、今度はモデルをにロードします。

解決

#加载保存好的模型
Layer1pre = torch.load('./ResultData_earlystop/savemodel/checkpoint_model_layer1.pt')

#定义自己的模型
model = CNNLayer(num_classes=10, aux_logits=True)
if use_gpu:
    model = model.cuda()

#将模型权重更新到新的网络中
model.load_state_dict(Layer1pre, strict=False)

参照リンク:https ://blog.csdn.net/my_kingdom/article/details/85218478

おすすめ

転載: blog.csdn.net/qq_38703529/article/details/122208005