pytorch加载部分预训练模型(解决问题的详细过程)

首先,给出代码地址https://github.com/jfzhang95/pytorch-deeplab-xception以及加载预训练模型的代码,做了些许修改

    pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
    model_dict = {
    
    }
    state_dict = model.state_dict()
    for k, v in pretrain_dict.items()
        if k in state_dict:
            model_dict[k] = v
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)

其中,for循环寻找了model与加载的预训练模型中keys相同的部分,并将其值加载到model中,最后对model进行更新。
但是,在更换预训练模型的时候出现了错误:
RuntimeError: Error(s) in loading state_dict for XXX:
Missing key(s) in state_dict:
查看了一些资料,发现可能是keys不对齐造成的问题,于是,将两种模型的keys输出看一看:

    pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
    for k in pretrain_dict.keys():
        print(k)                         #查看预训练模型的keys
    model_dict = {
    
    }
    state_dict = model.state_dict()
    for k in state_dict.keys():
        print(k)                         #查看本地model的keys
    for k, v in pretrain_dict.items()
        if k in state_dict:
            model_dict[k] = v
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)
    for k in model_dict.keys():
        print(k)                         #查看model更新后的keys

查看结果如下:

module.backbone.conv1.weight
module.backbone.bn1.weight


conv1.weight
bn1.weight

果然,keys不对应,看了很多解决方法,发现都是需要一个一个加载,感觉太麻烦,所以按照自己的想法修改了一下,成功加载:

    pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
    for k in pretrain_dict.keys():
        print(k)
    model_dict = {
    
    }
    state_dict = model.state_dict()
    for k in state_dict.keys():
        print(k)
    print("分界线") 
    for k, v in pretrain_dict.items():
        for i, j in state_dict.items():    #加上前缀后寻找对应的keys
            m = 'module.backbone.' + i
            if k == m :
                model_dict[i] = v
                print(i)
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)
    for k in model_dict.keys():
        print(k)
    return model

其实,也就是把model的keys加了前缀。
新手上路,记录下解决问题的过程,有好的方法,欢迎交流。

猜你喜欢

转载自blog.csdn.net/qq_36804414/article/details/106718214