[Pytorch] torch.load() を使用してモデルをロードすると、ロードされたモデルが zip 形式であることを示すエラーが発生します。

質問

最適なモデルを pytorch バージョン 1.6.0 に保存します。モデルを他のデバイスにダウンロードすると、モデルが .pt から zip 圧縮パッケージ形式に変換されているか、他のバージョンの pytorch に直接ロードするとエラーが発生します。 。

解決

torch官网上指出
PyTorch 1.6 リリースでは、新しい zipfile ベースのファイル形式を使用するように torch.save が切り替えられました。torch.load は、古い形式でファイルをロードする機能を引き続き保持しています。何らかの理由で torch.save で古い形式を使用したい場合は、kwarg _use_new_zipfile_serialization=False を渡します。
ここに画像の説明を挿入します

# 保存模型时
torch.save(model.state_dict(), model_cp,_use_new_zipfile_serialization=False)

参考

https://www.freesion.com/article/87611412928/

おすすめ

転載: blog.csdn.net/qq_44846512/article/details/118677603