pytorch权重加载以及冻结部分权重设置
1.加载权重
忽略不匹配的权重(strict=False 会同时忽略 Missing key(s) 和 Unexpected key(s) in state_dict 两类不匹配):
# strict=False skips mismatched keys instead of raising — it silences both
# "Missing key(s)" and "Unexpected key(s)" errors from load_state_dict.
model.load_state_dict(state_dict, strict=False)
2.保存权重
保存全部模型:
# Save the whole model (architecture + weights) via pickle.
torch.save(model, 'model.pkl')
# Load the whole model back. NOTE(review): this requires the model's class
# definition to be importable at load time, and on torch>=2.6 the default
# weights_only=True will reject full-model pickles — confirm torch version.
model = torch.load('model.pkl')
仅保存权重:
# Save only the parameters (state_dict) — the recommended approach.
torch.save(model.state_dict(), 'model_param.pkl')
# Load parameters: rebuild the architecture first, then restore the weights.
model = ModelClass(...)
model.load_state_dict(torch.load('model_param.pkl'))
3.冻结部分权重进行训练
冻结全部参数进行训练:
# Freeze every parameter of the module: autograd stops tracking them, so the
# optimizer will no longer update these weights.
# (Fix: the loop body had lost its indentation and was invalid Python.)
for p in self.parameters():
    p.requires_grad = False
冻结部分参数进行训练
# Freeze everything except the sampling-related parameters.
# Only parameters whose name contains one of these substrings stay trainable;
# all other parameters are frozen (requires_grad=False).
# (Fix: restored the lost indentation and collapsed the repetitive
# if/elif chain — each branch assigned the same value — into one any(...).)
trainable_substrings = ("sampling_offsets", "sampling_scales", "sampling_angles")
for name, para in self.named_parameters():
    para.requires_grad = any(key in name for key in trainable_substrings)
模型的每一层命名和权重文件里的不匹配,修改不匹配键的名称:
if self.freeze_train:
    # Snapshot this model's own state_dict so checkpoint keys can be
    # filtered against the layer names the model actually has.
    model_dict = self.state_dict()
    # Load the checkpoint (a dict of tensors); `pretrained` is the file path.
    # NOTE(review): consider map_location="cpu" when loading on a CPU-only
    # host — confirm the deployment target.
    checkpoint_dict = torch.load(pretrained)

    # The checkpoint prefixes every layer name with "image_encoder.";
    # strip that prefix so keys line up with this model's layer names.
    # (Fix: the magic slice key[14:] is now derived from the prefix length.)
    prefix = "image_encoder."
    new_checkpoint_dict = {}
    for key in checkpoint_dict:
        if "image_encoder" in key:
            new_key = key[len(prefix):]
            new_checkpoint_dict[new_key] = checkpoint_dict[key]
            # Some checkpoint tensors are (27, 64) while this model expects
            # fewer rows; keep only the first 13 rows to match.
            if new_checkpoint_dict[new_key].shape == (27, 64):
                new_checkpoint_dict[new_key] = new_checkpoint_dict[new_key][0:13, :]

    # Keep only the keys that actually exist in this model.
    new_checkpoint_dict_1 = {
        k: v for k, v in new_checkpoint_dict.items() if k in model_dict
    }
    # strict=False ignores any remaining mismatched keys instead of raising.
    self.load_state_dict(new_checkpoint_dict_1, strict=False)

    # Freeze every loaded parameter — no gradient updates during training.
    for p in self.parameters():
        p.requires_grad = False