解决AttributeError: Can‘t get attribute ‘VggNet‘ on <module ‘__main__‘ from ‘main.py‘>

问题描述

在Kaggle平台上使用GPU训练好模型,然后将模型保存,之后将模型下载到本地,然后使用 torch.load('mode.pkl') 进行读取出现 AttributeError: Can't get attribute 'VggNet' on <module '__main__' from 'main.py'>

在这里插入图片描述

Traceback (most recent call last):
  File "E:\Code\PyCharm\ImageClassificationPyQt5\logistic_regression\utils\Thread.py", line 226, in run
    model = torch.load('../checkpoints/model.pkl', map_location="cpu")
  File "D:\Anaconda\lib\site-packages\torch\serialization.py", line 713, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "D:\Anaconda\lib\site-packages\torch\serialization.py", line 930, in _legacy_load
    result = unpickler.load()
  File "D:\Anaconda\lib\site-packages\torch\serialization.py", line 746, in find_class
    return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'VGG' on <module '__main__' from 'E:\\Code\\PyCharm\\ImageClassificationPyQt5\\logistic_regression\\view\\MainWindow.py'>

原因分析:

PyTorch导入模型的时候是执行 pickle 的操作,但是因为主程序未知自定义的模型结构,所以无法解析模型。因为保存下来的模型和参数不能在没有类定义时直接使用。说白了就是现在无法获取你定义的网络模型结构,所以在加载模型的时候无法匹配参数,导致加载失败。

解决方案:

在你加载模型的位置,将你自定义的模型加入里面

 
import torch
import torch.nn as nn
 
class VggNet(nn.Module):
    def __init__(self):
        super(BertClassificationModel, self).__init__()
        
    def forward(self):
    	pass
 

path = './checkpoints/model.pkl'
 
# 加载模型
model = torch.load(path)

猜你喜欢

转载自blog.csdn.net/m0_47256162/article/details/129956074
今日推荐