carregamento de peso pytorch e congelamento de algumas configurações de peso

carregamento de peso pytorch e congelamento de algumas configurações de peso

1. Carregar pesos

Ignore pesos não correspondentes (chave(s) ausente(s) em state_dict):

model.load_state_dict(state_dict, strict=False)

2. Economize pesos

Salve todos os modelos:

# 模型保存
torch.save(model, 'model.pkl')
# 模型加载
model = torch.load('model.pkl')

Salve apenas os pesos:

# 模型参数保存
torch.save(model.state_dict(), 'model_param.pkl')
# 模型参数加载
model = ModelClass(...)
model.load_state_dict(torch.load('model_param.pkl'))

3. Congele alguns pesos para treinar

Congele todos os parâmetros para treinamento:

for p in self.parameters():
    p.requires_grad = False

Congelar alguns parâmetros para treinamento

for name, para in self.named_parameters():
    if "sampling_offsets" in name:
        para.requires_grad = True
    elif "sampling_scales" in name:
        para.requires_grad = True
    elif "sampling_angles" in name:
        para.requires_grad = True
    else:
        para.requires_grad = False

O nome de cada camada do modelo não corresponde ao arquivo de peso, modifique o nome da chave de incompatibilidade:

if self.freeze_train:
    #先加载model里的模型结构
    model_dict = self.state_dict()
    #加载权重文件里的权重,字典形式,pretrained为文件路径
    checkpoint_dict = torch.load(pretrained)
    #先将参数加载到模型中
    new_checkpoint_dict = {
    
    }
    for key in checkpoint_dict:
        if "image_encoder" in key:  #这里的权重文件里各层名称前有“image_encoder”,要去掉
            new_key = key[14:]
            new_checkpoint_dict[new_key] = checkpoint_dict[key]
            if new_checkpoint_dict[new_key].shape == (27, 64): #这里有一些层的权重维度和model里的不一样,进行一些处理
                new_checkpoint_dict[new_key] = new_checkpoint_dict[new_key][0:13, :]
    new_checkpoint_dict_1 = {
    
    k:v for k,v in new_checkpoint_dict.items() if k in model_dict.keys()}
    #strict=False设置后不会报错
    self.load_state_dict(new_checkpoint_dict_1, strict=False)
    #这里加载的全部参数我都设置为了不更新,即冻结训练
    for p in self.parameters():
        p.requires_grad = False

Supongo que te gusta

Origin blog.csdn.net/weixin_45453121/article/details/131864211
Recomendado
Clasificación