『PyTorch』加载模型的bug

在 PyTorch 中,保存和加载模型有两种方法:

torch.save(net,'./model.pth')   # 保存整个模型及其参数
net = torch.load('./model.pth')  # 加载整个模型及其参数
# 或者
torch.save(net.state_dict(),'./model-dict.pth')# 仅仅保存模型参数
net.load_state_dict(torch.load('./model-dict.pth')) # 仅仅加载模型参数(所以需要事先定义一个模型 net)

net.load_state_dict() 和 torch.load() 的不同在于,前者需要你先定义一个模型,然后再 load_state_dict()。
torch.load() 直接加载整个模型,会把模型和模型参数一起 load 进来。完成了模型的定义和加载参数的两个过程。

需要注意的是,在保存模型之前,需要把模型进行 eval() , 即把模型从训练阶段转化为测试阶段,固定当下的模型参数,用于接下来的模型预测。如果不指定模型 eval() 模式,那么加载回来的模型并不是和原先保存的模型相同。

简单说,原先的 net 在保存之前,要 eval() 一下,load() 之后的 net 也要 eval() 一下,把所有参数 freeze 掉。才保证两个 net 完全相同(输入相同 tensor 得到完全一致的结果)。

因为在模型的训练阶段,在进行有 BN 层或者有 Dropout 层的模型训练中,获取的批次数据属性(均值、方差)会被记录下来,用于对测试数据的标准化。或者对于 Dropout 层,在训练的阶段会有一些神经元权重被置零,但是在测试阶段,这些神经元又被重新使用。如果不进行 model.eval() 的话,那么每次测试阶段这些参数的值会在前向传播的时候发生改变。导致模型不稳定。

使用 PyTorch 进行训练和测试时一定注意要把实例化的 model 指定 train/eval,eval() 时,框架会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值,不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。

  • model.train()
    启用 BatchNormalizationDropout
  • model.eval()
    不启用 BatchNormalizationDropout

另外,如果我们想在其他项目中用已经训练的模型,就个人经验而言,使用 torch.load() 并不能成功加载模型,而是会报错,所以只能通过存储 state_dict 来进行保存模型。

这里注意,模型代码要原封不动的复制到目标工程中!

因为,你要先创建模型,然后加载参数,关于 load_state_dict() 方法,一般可能出现 2 种错误:

  • Missing key(s) in state_dict
  • Unexpected key(s) in state_dict

前者表示加载的模型中应该有的参数你没有,后者表示你加载的参数存在,但是你的模型代码里并没有这些参数
至于多或者缺哪些,可以根据保存看, pth 保存参数时时,命名是多级Module名,最后是列表的下标,用 . 连接,根据此规则去找到对应模型代码位置修改即可。

猜你喜欢

转载自blog.csdn.net/dreaming_coder/article/details/109264090