carregamento de peso pytorch e congelamento de algumas configurações de peso
Índice
1. Carregar pesos
Ignore pesos não correspondentes (chave(s) ausente(s) em state_dict):
model.load_state_dict(state_dict, strict=False)
2. Economize pesos
Salve todos os modelos:
# 模型保存
torch.save(model, 'model.pkl')
# 模型加载
model = torch.load('model.pkl')
Salve apenas os pesos:
# 模型参数保存
torch.save(model.state_dict(), 'model_param.pkl')
# 模型参数加载
model = ModelClass(...)
model.load_state_dict(torch.load('model_param.pkl'))
3. Congele alguns pesos para treinar
Congele todos os parâmetros para treinamento:
for p in self.parameters():
p.requires_grad = False
Congelar alguns parâmetros para treinamento
for name, para in self.named_parameters():
if "sampling_offsets" in name:
para.requires_grad = True
elif "sampling_scales" in name:
para.requires_grad = True
elif "sampling_angles" in name:
para.requires_grad = True
else:
para.requires_grad = False
O nome de cada camada do modelo não corresponde ao arquivo de peso, modifique o nome da chave de incompatibilidade:
if self.freeze_train:
#先加载model里的模型结构
model_dict = self.state_dict()
#加载权重文件里的权重,字典形式,pretrained为文件路径
checkpoint_dict = torch.load(pretrained)
#先将参数加载到模型中
new_checkpoint_dict = {
}
for key in checkpoint_dict:
if "image_encoder" in key: #这里的权重文件里各层名称前有“image_encoder”,要去掉
new_key = key[14:]
new_checkpoint_dict[new_key] = checkpoint_dict[key]
if new_checkpoint_dict[new_key].shape == (27, 64): #这里有一些层的权重维度和model里的不一样,进行一些处理
new_checkpoint_dict[new_key] = new_checkpoint_dict[new_key][0:13, :]
new_checkpoint_dict_1 = {
k:v for k,v in new_checkpoint_dict.items() if k in model_dict.keys()}
#strict=False设置后不会报错
self.load_state_dict(new_checkpoint_dict_1, strict=False)
#这里加载的全部参数我都设置为了不更新,即冻结训练
for p in self.parameters():
p.requires_grad = False