修改网络结构后加载预训练模型
model = new_network() # 修改后的网络
if pretrained: # 是否加载预训练模型
net_dict = model.state_dict() # 修改后的网络结构
predict_model = torch.load('xxxxx.pth') # 预训练模型加载
# 寻找网络中公共层,并保留预训练参数
state_dict = {
k: v for k, v in predict_model.items() if k in net_dict.keys()}
net_dict.update(state_dict) # 将预训练参数更新到新的网络层
model.load_state_dict(net_dict) # 加载
参考链接:
https://blog.csdn.net/biaoge6666/article/details/117043097