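import torchvision.models as models

resnet152 = models.resnet152(pretrained=True)  # torchvision >= 0.13 prefers weights=models.ResNet152_Weights.DEFAULT
pretrained_dict = resnet152.state_dict()
# After loading the pretrained model from torchvision, extract its parameters with state_dict().
# Alternatively, download the parameters directly from the official model zoo
# (model_urls is only available in older torchvision releases):
# import torch.utils.model_zoo as model_zoo
# from torchvision.models.resnet import model_urls
# pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict()  # `model` is your own network, assumed to be defined already
# Remove keys in pretrained_dict that are not part of model_dict
# (matching keys must also have matching shapes, or load_state_dict raises a size-mismatch error)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Overwrite the matching entries of model_dict with the pretrained values
model_dict.update(pretrained_dict)
# Load the merged state_dict into the model
model.load_state_dict(model_dict)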
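The filter-then-update step above is plain dictionary manipulation, so it can be tried without PyTorch. The sketch below is a minimal illustration under stated assumptions: `FakeTensor` and `merge_pretrained` are hypothetical names standing in for `torch.Tensor` and the three lines of merging code, and the shape check is an addition beyond the original, which filters by key only. It also reports which pretrained keys were skipped, which helps when debugging a partial load.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only a shape and a provenance tag."""
    shape: Tuple[int, ...]
    source: str


def merge_pretrained(
    model_dict: Dict[str, FakeTensor],
    pretrained_dict: Dict[str, FakeTensor],
) -> Tuple[Dict[str, FakeTensor], List[str]]:
    # Keep only pretrained entries whose key exists in the target model AND
    # whose shape matches (the shape test is an addition to the original recipe).
    compatible = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and model_dict[k].shape == v.shape
    }
    # Pretrained keys that will not be loaded, for inspection.
    skipped = sorted(set(pretrained_dict) - set(compatible))
    merged = dict(model_dict)  # copy so the caller's dict is not mutated
    merged.update(compatible)  # overwrite matching entries with pretrained values
    return merged, skipped


# Example: a 10-class model reusing a 1000-class backbone (illustrative shapes).
model_dict = {
    "conv1.weight": FakeTensor((64, 3, 7, 7), "init"),
    "fc.weight": FakeTensor((10, 2048), "init"),
    "fc.bias": FakeTensor((10,), "init"),
    "head.extra": FakeTensor((5,), "init"),
}
pretrained_dict = {
    "conv1.weight": FakeTensor((64, 3, 7, 7), "pretrained"),
    "fc.weight": FakeTensor((1000, 2048), "pretrained"),
    "fc.bias": FakeTensor((1000,), "pretrained"),
    "layer1.0.conv1.weight": FakeTensor((64, 64, 1, 1), "pretrained"),
}

merged, skipped = merge_pretrained(model_dict, pretrained_dict)
print({k: v.source for k, v in merged.items()})
# {'conv1.weight': 'pretrained', 'fc.weight': 'init', 'fc.bias': 'init', 'head.extra': 'init'}
print(skipped)
# ['fc.bias', 'fc.weight', 'layer1.0.conv1.weight']
```

In real PyTorch code the same comparison would use `tensor.shape`. PyTorch's `load_state_dict` also accepts `strict=False`, which tolerates missing and unexpected keys but still raises on shape mismatches, so some form of shape filtering is needed when a layer such as the final classifier changes size.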
————————————————
Copyright notice: This is an original article by CSDN blogger "lscelory", licensed under the CC 4.0 BY-SA agreement. Please include the original source link and this statement when reprinting.
Original link: https://blog.csdn.net/lscelory/article/details/81482586