PyTorch がモデルをロードすると、エラーが発生します: RuntimeError: ***** の state_dict のロード中にエラーが発生しました: state_dict にキーがありません:

問題の説明:

    元の作成者のコードにはブレークポイントの継続がありません。この関数を追加し、より多くのパラメーターを導入しました。モデルを保存するときに、epoch、net.state_dict()、optimizer.state_dict()、および Scheduler.state_dict() を追加しました。情報。

モデルを保存する元のコードは次のとおりです。

torch.save(net.state_dict(), model_dir)

情報を追加した後、モデルを保存するコードは次のとおりです。

 torch.save({'epoch': i,
             'model_state_dict': net.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),
             'scheduler': scheduler.state_dict(),},
              model_dir)

モデルをロードする元のコードは次のとおりです。

net.load_state_dict(torch.load(model_dir))

情報を追加した後、モデルをロードするコードは次のとおりです。

ckpt = torch.load(model_dir, map_location='cpu')
net.load_state_dict(ckpt['model_state_dict'])

推論をテストすると、モデルのロード時にエラーが報告されます。

net.load_state_dict(ckpt['model_state_dict'])
  File "/root/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for *****:
	Missing key(s) in state_dict: 

解決:

方法 1 は正常にロードできますが、一部のパラメータがロードされず、場合によっては推論結果が間違ってしまいます。

ckpt = torch.load(model_dir)
model.load_state_dict(ckpt['model_state_dict'],strict=False)

方法 2: ディクショナリ キー値の module. を置き換えるか、元のモデルの pth ファイルのキー出力と現在のモデルのキーを比較し、モデルのパラメーターを手動でロードします。

ckpt = torch.load(args.weights, map_location='cpu')
net.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['model_state_dict'].items()})

問題の原因:

次のコードがトレーニング コードに追加されました。

net = nn.DataParallel(net)

トレーニング コードで net = nn.DataParallel(net) を見つけてコメント化し、再トレーニングします。

または、上記の方法 2 を使用してモデルをロードします。

マルチ GPU 並列コンピューティングの場合、Pytorch の nn.DataParallel を使用して同じモデルをトレーニングできます。

おすすめ

転載: blog.csdn.net/mj412828668/article/details/130014232