pytorchは、事前にトレーニングされたモデルの一部をロードします(問題を解決するための詳細なプロセス)

まず、コードアドレスhttps://github.com/jfzhang95/pytorch-deeplab-xceptionと、事前にトレーニングされたモデルをロードするためのコードを、いくつかの変更を加えて指定します。

    pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
    model_dict = {
    
    }
    state_dict = model.state_dict()
    for k, v in pretrain_dict.items()
        if k in state_dict:
            model_dict[k] = v
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)

その中で、forループは、モデル内のキーとロードされた事前トレーニングモデルの同じ部分を探し、その値をモデルにロードし、最後にモデルを更新します。
ただし、事前トレーニングモデルを置き換えるときにエラーが発生しました:
RuntimeError:XXXのstate_dictの読み込み中にエラーが発生しました:state_dictに
キーがありません:
いくつかのデータ確認したところ、キーの位置がずれていることが原因である可能性があることがわかりました。 2つのモデルのキー出力を見てください。

    pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
    for k in pretrain_dict.keys():
        print(k)                         #查看预训练模型的keys
    model_dict = {
    
    }
    state_dict = model.state_dict()
    for k in state_dict.keys():
        print(k)                         #查看本地model的keys
    for k, v in pretrain_dict.items()
        if k in state_dict:
            model_dict[k] = v
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)
    for k in model_dict.keys():
        print(k)                         #查看model更新后的keys

次のような結果を確認します
...
module.backbone.conv1.weight
module.backbone.bn1.weight
...

...
Conv1.weight
bn1.weight
...
案の定、キーが対応してない多くのソリューションを読んだ後、私は1つによってロードされた一つであることが、彼らはすべての必要ことが判明し、私はそれが私自身の考えによると、正常にロードされた修正ので、それは、あまりにも面倒に感じています。。:

    pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
    for k in pretrain_dict.keys():
        print(k)
    model_dict = {
    
    }
    state_dict = model.state_dict()
    for k in state_dict.keys():
        print(k)
    print("分界线") 
    for k, v in pretrain_dict.items():
        for i, j in state_dict.items():    #加上前缀后寻找对应的keys
            m = 'module.backbone.' + i
            if k == m :
                model_dict[i] = v
                print(i)
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)
    for k in model_dict.keys():
        print(k)
    return model

実際、モデルキーには接頭辞が付いています。
初心者、問題を解決するプロセスを記録します。良い方法があります。コミュニケーションを歓迎します。

おすすめ

転載: blog.csdn.net/qq_36804414/article/details/106718214