【pytorch】加载模型出现的bug

在模型训练完后再进行测试加载模型后出现bug,显示如下错误

据了解是由于pytorch版本导致的错误,可能与自己训练阶段保持的模型方式有关,训练阶段保存方式如下:

解决方案如下:

方法一:

generator.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(generator_1_10.pth).items()})

实际上就是将load进行的权重的有序字典里面的键值前面的的7个字符’module.'去掉。加载进行的权重有序字典如下图所示:

 键就是每层的权重或者 bias 的名称,value就是其具体的张量值。

方法二:重新新建个有序字典:

from collections import OrderedDict
    #     new_state_dict = OrderedDict()
    #     for k, v in a.items():
    #         name=k[7:]  # reduce `module.`
    #         new_state_dict[name] = v
    #     # load params
    #     # model.load_state_dict(new_state_dict)
    #     model.load_state_dict(new_state_dict)

显然方法一更简洁明了。

猜你喜欢

转载自blog.csdn.net/lyl771857509/article/details/84642555
今日推荐