PyTorch weight loading and partial weight freezing

1. Loading weights

To ignore mismatched weights (and avoid the "Missing key(s) in state_dict" error), load non-strictly:

model.load_state_dict(state_dict, strict=False)
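
With strict=False, load_state_dict also returns the keys it could not match, which is worth printing to confirm what was actually loaded. A minimal sketch (the checkpoint path is a placeholder):

import torch

state_dict = torch.load('model_param.pkl', map_location='cpu')
# load_state_dict returns an IncompatibleKeys namedtuple
result = model.load_state_dict(state_dict, strict=False)
print('missing keys:', result.missing_keys)        # in the model, not in the checkpoint
print('unexpected keys:', result.unexpected_keys)  # in the checkpoint, not in the model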

2. Saving weights

Save the entire model:

# Save the whole model (structure + weights)
torch.save(model, 'model.pkl')
# Load the whole model
model = torch.load('model.pkl')
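
Note that saving the whole model pickles the class by reference, so the original class definition must be importable at load time. A small loading sketch onto CPU (same file as above; on newer PyTorch releases torch.load may also need weights_only=False to unpickle a full model):

import torch

# map_location moves GPU-saved tensors to CPU at load time
model = torch.load('model.pkl', map_location='cpu')
model.eval()  # switch to inference mode unless training continues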

Save only the weights (the state_dict):

# Save only the model parameters
torch.save(model.state_dict(), 'model_param.pkl')
# Load the parameters into a freshly built model
model = ModelClass(...)
model.load_state_dict(torch.load('model_param.pkl'))
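
A common extension of this (a sketch; the epoch and optimizer variables are assumed to exist in your training loop) bundles the optimizer state and the epoch into one checkpoint dict so training can resume exactly:

# save a full training checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pkl')

# resume
ckpt = torch.load('checkpoint.pkl', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1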

3. Freezing selected weights for training

Freeze all parameters:

# Disable gradients for every parameter in the module
for p in self.parameters():
    p.requires_grad = False
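
A quick way to verify the freeze took effect (not from the original post) is to count trainable parameters:

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'trainable params: {trainable} / {total}')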

Freeze all but a few parameters:

# Only the sampling-related parameters stay trainable; everything else is frozen
for name, para in self.named_parameters():
    if ("sampling_offsets" in name
            or "sampling_scales" in name
            or "sampling_angles" in name):
        para.requires_grad = True
    else:
        para.requires_grad = False
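
After freezing, pass only the trainable parameters to the optimizer so the frozen ones carry no optimizer state. A sketch (the optimizer type and learning rate are placeholders):

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, self.parameters()),
    lr=1e-4,  # placeholder learning rate
)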

If the layer names in the model do not match the keys in the checkpoint file, rename the mismatched keys:

if self.freeze_train:
    # State dict of the model itself (the target structure)
    model_dict = self.state_dict()
    # Load the checkpoint as a dict of weights; `pretrained` is the file path
    checkpoint_dict = torch.load(pretrained)
    # Build a new dict whose keys match the model's naming
    new_checkpoint_dict = {}
    for key in checkpoint_dict:
        if "image_encoder" in key:  # this checkpoint prefixes every layer name with "image_encoder.", which must be stripped
            new_key = key[14:]  # len("image_encoder.") == 14
            new_checkpoint_dict[new_key] = checkpoint_dict[key]
            if new_checkpoint_dict[new_key].shape == (27, 64):  # some layers differ in shape from the model; slice them down to match
                new_checkpoint_dict[new_key] = new_checkpoint_dict[new_key][0:13, :]
    # Keep only the keys that actually exist in the model
    new_checkpoint_dict_1 = {k: v for k, v in new_checkpoint_dict.items() if k in model_dict.keys()}
    # strict=False prevents errors from any remaining mismatches
    self.load_state_dict(new_checkpoint_dict_1, strict=False)
    # All loaded parameters are frozen here, i.e. excluded from training updates
    for p in self.parameters():
        p.requires_grad = False
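
The same key-renaming pattern covers the common case where nn.DataParallel prepends "module." to every key. A general sketch (independent of the specific checkpoint above; str.removeprefix needs Python 3.9+):

checkpoint_dict = torch.load(pretrained, map_location='cpu')
cleaned = {k.removeprefix('module.'): v for k, v in checkpoint_dict.items()}
model.load_state_dict(cleaned, strict=False)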

Reposted from blog.csdn.net/weixin_45453121/article/details/131864211