PyTorch: loading part of a pre-trained model (a detailed record of solving the problem)

First, here is the code repository https://github.com/jfzhang95/pytorch-deeplab-xception and the code for loading the pre-trained model, with some modifications:

    import torch.utils.model_zoo as model_zoo

    pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
    model_dict = {}
    state_dict = model.state_dict()
    for k, v in pretrain_dict.items():
        if k in state_dict:                # keep only the keys that also exist in the local model
            model_dict[k] = v
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)

Here, the for loop looks for the keys that appear in both the local model and the loaded pre-trained model, copies their values into model_dict, and finally updates the model's state_dict with it.
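For reference, the same filtering step can also be written as a dictionary comprehension. This is only a compact sketch of the same idea, assuming `model` and `pretrain_dict` are defined as above:

    # Keep only the pre-trained entries whose keys also exist in the local model.
    state_dict = model.state_dict()
    filtered = {k: v for k, v in pretrain_dict.items() if k in state_dict}
    state_dict.update(filtered)
    model.load_state_dict(state_dict)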
However, an error occurred when I swapped in my own pre-trained model:

RuntimeError: Error(s) in loading state_dict for XXX:
Missing key(s) in state_dict: ...

I checked some references and found that this kind of error is usually caused by mismatched keys, so let's print the keys of both models and take a look:

    import torch

    pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
    for k in pretrain_dict.keys():
        print(k)                         # print the keys of the pre-trained model
    model_dict = {}
    state_dict = model.state_dict()
    for k in state_dict.keys():
        print(k)                         # print the keys of the local model
    for k, v in pretrain_dict.items():
        if k in state_dict:
            model_dict[k] = v
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)
    for k in model_dict.keys():
        print(k)                         # print the keys that actually got updated

The printed results look like this:

Keys of the pre-trained model:

    module.backbone.conv1.weight
    module.backbone.bn1.weight

Keys of the local model:

    conv1.weight
    bn1.weight
Sure enough, the keys do not correspond. Many of the solutions I found load the weights one key at a time, which felt too troublesome, so I modified the code in my own way and it loaded successfully:

    pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
    for k in pretrain_dict.keys():
        print(k)                           # keys of the pre-trained model
    model_dict = {}
    state_dict = model.state_dict()
    for k in state_dict.keys():
        print(k)                           # keys of the local model
    print("----- dividing line -----")
    for k, v in pretrain_dict.items():
        for i, j in state_dict.items():    # add the prefix and look for the matching key
            m = 'module.backbone.' + i
            if k == m:
                model_dict[i] = v
                print(i)
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)
    for k in model_dict.keys():
        print(k)                           # keys that were updated in the model
    return model

In fact, the only difference is that the pre-trained model's keys carry the extra 'module.backbone.' prefix.
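Once you know the prefix, a more direct variant is to strip 'module.backbone.' from the checkpoint keys before matching, which avoids the nested loop. Here is just a rough sketch of that idea (assuming every relevant key in pretrain_dict starts with that prefix), not the code I actually ran:

    # Sketch: strip the known prefix from the checkpoint keys, then match directly.
    # Assumes every relevant key in pretrain_dict starts with 'module.backbone.'.
    prefix = 'module.backbone.'
    state_dict = model.state_dict()
    for k, v in pretrain_dict.items():
        stripped = k[len(prefix):]
        if k.startswith(prefix) and stripped in state_dict:
            state_dict[stripped] = v
    model.load_state_dict(state_dict)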
I'm still a beginner and am just recording the process of solving this problem here. If you know a better way, feel free to share.


Origin blog.csdn.net/qq_36804414/article/details/106718214