Loading a subset of a pre-trained model's parameters in PyTorch

import torchvision.models as models

# Load the pre-trained model and its weights from torchvision, then extract
# the parameters with the state_dict() method.
resnet152 = models.resnet152(pretrained=True)
pretrained_dict = resnet152.state_dict()
# Alternatively, download the weights directly from the official model_zoo:
# pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
# `model` is the custom network that will receive the matching parameters.
model_dict = model.state_dict()

Remove keys in pretrained_dict that are not part of model_dict

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
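The key filter above only compares names. If a layer keeps its name but changes size (for example, a final layer with a different number of classes), load_state_dict will still raise a size-mismatch error. The following is a minimal sketch of a stricter filter that also compares tensor shapes. The two small nn.Sequential networks are hypothetical stand-ins for the pre-trained network and the custom model, not part of the original article.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: same first layer, different output size in the last layer.
pretrained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10))
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

pretrained_dict = pretrained.state_dict()
model_dict = model.state_dict()

# Keep an entry only if the key exists in the target model AND the shapes agree.
filtered_dict = {
    k: v
    for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}

# Report which pre-trained entries were dropped, so nothing is skipped silently.
skipped = sorted(set(pretrained_dict) - set(filtered_dict))
print("kept:", sorted(filtered_dict))
print("skipped:", skipped)
```

Printing the skipped keys is a design choice: it makes it obvious when a layer that was expected to transfer was dropped by the filter.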

Update model_dict with the filtered pre-trained parameters

model_dict.update(pretrained_dict)

Load the merged state_dict into the model

model.load_state_dict(model_dict)
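The three steps above can be tried end to end without downloading any weights. In the sketch below, Backbone and Custom are hypothetical toy modules standing in for resnet152 and the custom model. It also shows, as an alternative rather than the article's method, passing strict=False to load_state_dict, which skips non-matching keys and returns the lists of missing and unexpected keys.

```python
import torch
import torch.nn as nn

class Backbone(nn.Module):
    """Hypothetical stand-in for the pre-trained network."""
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 8)
        self.classifier = nn.Linear(8, 10)

class Custom(nn.Module):
    """Hypothetical target model: shares `features`, has a new `head`."""
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 8)
        self.head = nn.Linear(8, 3)

source = Backbone()
model = Custom()

# Filter, update, load: the same three steps as in the article.
pretrained_dict = source.state_dict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

# The shared layer now carries the source weights.
print(torch.equal(model.features.weight, source.features.weight))

# Alternative: let load_state_dict skip the non-matching keys itself.
other = Custom()
result = other.load_state_dict(source.state_dict(), strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

Note that strict=False only tolerates missing or extra keys. A key that matches by name but differs in shape still raises an error, so the explicit filter remains useful in that case.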
————————————————
Copyright notice: This is an original article by CSDN blogger "lscelory", licensed under CC 4.0 BY-SA. When reprinting, please include the original source link and this statement.
Original link: https://blog.csdn.net/lscelory/article/details/81482586

Origin blog.csdn.net/qq_34291583/article/details/100539275