First, the code comes from https://github.com/jfzhang95/pytorch-deeplab-xception. The code that loads the pre-trained model, with some modifications, is:
import torch.utils.model_zoo as model_zoo

pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
model_dict = {}
state_dict = model.state_dict()
for k, v in pretrain_dict.items():
    if k in state_dict:
        model_dict[k] = v  # keep only weights whose key also exists in the model
state_dict.update(model_dict)
model.load_state_dict(state_dict)
Here, the for loop finds the keys that appear in both the model and the loaded pre-trained weights and copies their values into model_dict. The last two lines merge those values into the model's state dict and load the result back into the model.
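To make the matching step concrete, here is a minimal sketch that uses plain dictionaries in place of real state dicts. The helper name filter_matching and the toy keys and values are only for illustration; in real use the values would be tensors and state_dict would come from model.state_dict().

```python
def filter_matching(pretrain_dict, state_dict):
    """Return the entries of pretrain_dict whose key also exists in state_dict."""
    return {k: v for k, v in pretrain_dict.items() if k in state_dict}


# Toy stand-ins for real state dicts (hypothetical keys and values).
pretrain_dict = {"conv1.weight": 1.0, "bn1.weight": 2.0, "fc.weight": 3.0}
state_dict = {"conv1.weight": 0.0, "bn1.weight": 0.0, "aspp.weight": 0.0}

matched = filter_matching(pretrain_dict, state_dict)
state_dict.update(matched)  # overwrite only the keys that matched

print(sorted(matched))
print(state_dict)
```

Keys that exist only in the pre-trained weights (fc.weight here) are dropped, and keys that exist only in the model (aspp.weight here) keep their original values.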
However, an error occurred when I swapped in a different pre-trained model:
RuntimeError: Error(s) in loading state_dict for XXX:
Missing key(s) in state_dict:
I looked into it and found that the problem may be caused by misaligned keys, so I printed the keys of both models:
import torch

pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
for k in pretrain_dict.keys():
    print(k)  # keys of the pre-trained model
model_dict = {}
state_dict = model.state_dict()
for k in state_dict.keys():
    print(k)  # keys of the local model
for k, v in pretrain_dict.items():
    if k in state_dict:
        model_dict[k] = v
state_dict.update(model_dict)
model.load_state_dict(state_dict)
for k in model_dict.keys():
    print(k)  # keys that were copied into the model by the update
The output is as follows (pre-trained keys first, then the local model's keys):
…
module.backbone.conv1.weight
module.backbone.bn1.weight
…
…
conv1.weight
bn1.weight
…
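Reading two long key listings by eye is error-prone, so the comparison can also be done with set operations. This is a small sketch; the helper name compare_keys is hypothetical, and the key lists are just the sample keys from the output above.

```python
def compare_keys(pretrain_keys, model_keys):
    """Group keys into shared, pre-trained-only, and model-only."""
    pretrain_keys, model_keys = set(pretrain_keys), set(model_keys)
    return {
        "shared": sorted(pretrain_keys & model_keys),
        "only_in_pretrain": sorted(pretrain_keys - model_keys),
        "only_in_model": sorted(model_keys - pretrain_keys),
    }


# Sample keys taken from the printed output.
pretrain_keys = ["module.backbone.conv1.weight", "module.backbone.bn1.weight"]
model_keys = ["conv1.weight", "bn1.weight"]

report = compare_keys(pretrain_keys, model_keys)
for name, keys in report.items():
    print(name, keys)
```

An empty "shared" list means that not a single key lines up, which points to a naming difference such as a prefix and not to a genuinely different architecture.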
Sure enough, the keys do not match: the pre-trained keys start with module.backbone. while the model's keys do not. Most of the solutions I read load the keys one by one, which felt too tedious, so I modified the loading code in my own way and the weights loaded successfully:
pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
for k in pretrain_dict.keys():
    print(k)
model_dict = {}
state_dict = model.state_dict()
for k in state_dict.keys():
    print(k)
print("dividing line")
for k, v in pretrain_dict.items():
    for i, j in state_dict.items():  # add the prefix, then look for the matching key
        m = 'module.backbone.' + i
        if k == m:
            model_dict[i] = v
            print(i)
state_dict.update(model_dict)
model.load_state_dict(state_dict)
for k in model_dict.keys():
    print(k)
return model  # assumes this snippet sits inside a function that builds and returns the model
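In short, the keys in the pre-trained file carry a module.backbone. prefix that the local model's keys lack, so the prefix has to be accounted for when matching.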
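The same prefix matching can also be done in a single pass over the pre-trained keys by stripping the prefix instead of comparing every pair of keys. The following is only a sketch under the assumption that the prefix is exactly module.backbone.; the helper name strip_prefix and the toy dictionaries are hypothetical, and real values would be tensors.

```python
PREFIX = "module.backbone."  # prefix observed in the pre-trained keys


def strip_prefix(pretrain_dict, state_dict, prefix=PREFIX):
    """Map prefixed pre-trained keys onto the model's unprefixed keys."""
    remapped = {}
    for k, v in pretrain_dict.items():
        if k.startswith(prefix):
            short = k[len(prefix):]  # drop the prefix
            if short in state_dict:  # keep only keys the model actually has
                remapped[short] = v
    return remapped


# Toy stand-ins for real state dicts (hypothetical values).
pretrain_dict = {
    "module.backbone.conv1.weight": 1.0,
    "module.backbone.bn1.weight": 2.0,
    "module.decoder.conv.weight": 3.0,
}
state_dict = {"conv1.weight": 0.0, "bn1.weight": 0.0}

model_dict = strip_prefix(pretrain_dict, state_dict)
state_dict.update(model_dict)
print(state_dict)
```

With a real model, the updated state_dict would then be passed to model.load_state_dict(state_dict) exactly as in the code above.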
I am a novice and wrote this down to record how I solved the problem. If you know a better way, feel free to share it.