修改网络结构后加载预训练模型

修改网络结构后加载预训练模型

model = new_network()  # 修改后的网络
if pretrained: # 是否加载预训练模型
    net_dict = model.state_dict() # 修改后的网络结构
	predict_model = torch.load('xxxxx.pth') # 预训练模型加载
    # 寻找网络中公共层,并保留预训练参数
	state_dict = {
    
    k: v for k, v in predict_model.items() if k in net_dict.keys()}
	net_dict.update(state_dict)  # 将预训练参数更新到新的网络层
	model.load_state_dict(net_dict) # 加载

参考链接:
https://blog.csdn.net/biaoge6666/article/details/117043097

猜你喜欢

转载自blog.csdn.net/holly_Z_P_F/article/details/129691001